geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -383,6 +383,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
383
383
  return accum_mean / len(files), accum_std / len(files)
384
384
 
385
385
 
386
def calc_iou(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate Intersection over Union (IoU) between ground truth and prediction masks.

    This function computes the IoU metric for segmentation tasks. It supports both
    binary and multi-class segmentation, and can handle numpy arrays (or other
    array-likes), PyTorch tensors, or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, a numpy array,
            or a PyTorch tensor.
            For binary segmentation: any non-zero value is treated as foreground.
            For multi-class segmentation: integer class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Accepts the same types as ground_truth and must have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels whose
            ground-truth value equals ``ignore_index`` are excluded from the computation
            entirely (e.g. 255 for unlabeled pixels). In multi-class mode, if
            ``0 <= ignore_index < num_classes``, that class's entry is NaN.
            Defaults to None.
        smooth (float, optional): Smoothing factor to avoid division by zero.
            Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float IoU score
            (1.0 when both masks are empty). For multi-class segmentation, an array of
            length ``num_classes`` with per-class IoU scores; classes that appear in
            neither mask, and the ignored class, are NaN.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt, pred)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> iou = calc_iou(gt, pred, num_classes=3)
        >>> print(np.round(iou, 4))
        [0.75   0.6667 0.3333]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt_tensor, pred_tensor)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Using raster file paths (result depends on the files)
        >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
        >>> mean_iou = np.nanmean(iou)
    """

    def _to_numpy(mask):
        # Raster path -> read the requested band.
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        # detach() lets tensors that still track gradients be converted.
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    ground_truth = _to_numpy(ground_truth)
    prediction = _to_numpy(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Drop pixels labelled ignore_index so they cannot count as errors for any class.
    if ignore_index is not None:
        valid = ground_truth != ignore_index
        ground_truth = ground_truth[valid]
        prediction = prediction[valid]

    # Binary segmentation
    if num_classes is None:
        gt_mask = ground_truth.astype(bool)
        pred_mask = prediction.astype(bool)

        intersection = np.logical_and(gt_mask, pred_mask).sum()
        union = np.logical_or(gt_mask, pred_mask).sum()

        if union == 0:
            # Both masks empty: perfect agreement by convention.
            return 1.0

        return float((intersection + smooth) / (union + smooth))

    # Multi-class segmentation
    iou_per_class = []
    for class_idx in range(num_classes):
        if ignore_index is not None and class_idx == ignore_index:
            iou_per_class.append(np.nan)
            continue

        gt_class = ground_truth == class_idx
        pred_class = prediction == class_idx

        intersection = np.logical_and(gt_class, pred_class).sum()
        union = np.logical_or(gt_class, pred_class).sum()

        if union == 0:
            # Class absent from both ground truth and prediction.
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append((intersection + smooth) / (union + smooth))

    return np.array(iou_per_class, dtype=float)
507
+
508
+
509
def calc_f1_score(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate F1 score between ground truth and prediction masks.

    The F1 score is the harmonic mean of precision and recall, computed as:
        F1 = 2 * (precision * recall) / (precision + recall)
    where precision = TP / (TP + FP) and recall = TP / (TP + FN).
    ``smooth`` is added to the numerators/denominators to keep them finite.

    This function supports both binary and multi-class segmentation, and can handle
    numpy arrays (or other array-likes), PyTorch tensors, or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, a numpy array,
            or a PyTorch tensor.
            For binary segmentation: any non-zero value is treated as foreground.
            For multi-class segmentation: integer class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Accepts the same types as ground_truth and must have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels whose
            ground-truth value equals ``ignore_index`` are excluded from the computation
            entirely (e.g. 255 for unlabeled pixels). In multi-class mode, if
            ``0 <= ignore_index < num_classes``, that class's entry is NaN.
            Defaults to None.
        smooth (float, optional): Smoothing factor to avoid division by zero.
            Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float F1 score
            (1.0 when both masks are empty, matching calc_iou). For multi-class
            segmentation, an array of length ``num_classes`` with per-class F1 scores;
            classes that appear in neither mask, and the ignored class, are NaN.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt, pred)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> f1 = calc_f1_score(gt, pred, num_classes=3)
        >>> print(np.round(f1, 4))
        [0.8571 0.8    0.5   ]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Using raster file paths (result depends on the files)
        >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
        >>> mean_f1 = np.nanmean(f1)
    """

    def _to_numpy(mask):
        # Raster path -> read the requested band.
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        # detach() lets tensors that still track gradients be converted.
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    def _smoothed_f1(tp, fp, fn):
        # Smoothed precision/recall stay finite when TP + FP or TP + FN is zero.
        precision = (tp + smooth) / (tp + fp + smooth)
        recall = (tp + smooth) / (tp + fn + smooth)
        return 2 * (precision * recall) / (precision + recall + smooth)

    ground_truth = _to_numpy(ground_truth)
    prediction = _to_numpy(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Drop pixels labelled ignore_index so they cannot count as errors for any class.
    if ignore_index is not None:
        valid = ground_truth != ignore_index
        ground_truth = ground_truth[valid]
        prediction = prediction[valid]

    # Binary segmentation
    if num_classes is None:
        gt_mask = ground_truth.astype(bool)
        pred_mask = prediction.astype(bool)

        tp = int(np.logical_and(gt_mask, pred_mask).sum())
        fp = int(np.logical_and(~gt_mask, pred_mask).sum())
        fn = int(np.logical_and(gt_mask, ~pred_mask).sum())

        if tp + fp + fn == 0:
            # Both masks empty: perfect agreement, consistent with calc_iou.
            return 1.0

        return float(_smoothed_f1(tp, fp, fn))

    # Multi-class segmentation
    f1_per_class = []
    for class_idx in range(num_classes):
        if ignore_index is not None and class_idx == ignore_index:
            f1_per_class.append(np.nan)
            continue

        gt_class = ground_truth == class_idx
        pred_class = prediction == class_idx

        tp = int(np.logical_and(gt_class, pred_class).sum())
        fp = int(np.logical_and(~gt_class, pred_class).sum())
        fn = int(np.logical_and(gt_class, ~pred_class).sum())

        if tp + fp + fn == 0:
            # Class absent from both ground truth and prediction.
            f1_per_class.append(np.nan)
        else:
            f1_per_class.append(_smoothed_f1(tp, fp, fn))

    return np.array(f1_per_class, dtype=float)
645
+
646
+
647
def calc_segmentation_metrics(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    metrics: Optional[List[str]] = None,
    band: int = 1,
) -> Dict[str, Union[float, np.ndarray]]:
    """
    Calculate multiple segmentation metrics between ground truth and prediction masks.

    This is a convenient wrapper function that computes multiple metrics at once,
    including IoU (Intersection over Union) and F1 score, by delegating to
    ``calc_iou`` and ``calc_f1_score``. It supports both binary and multi-class
    segmentation, and can handle numpy arrays, PyTorch tensors, or file paths to
    raster files. Raster paths are read only once and reused for every metric.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor.
            For binary segmentation: any non-zero value is treated as foreground.
            For multi-class segmentation: integer class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Accepts the same types as ground_truth and must have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore; passed
            through to the individual metric functions. Defaults to None.
        smooth (float, optional): Smoothing factor to avoid division by zero.
            Defaults to 1e-6.
        metrics (Optional[List[str]], optional): Metrics to calculate.
            Options: "iou", "f1". A single metric name as a string is also accepted.
            If None, computes ["iou", "f1"]. Defaults to None.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
            Keys are metric names ("iou", "f1"), values are the metric scores.
            For binary segmentation, values are floats.
            For multi-class segmentation, values are numpy arrays with per-class scores,
            and "mean_iou" / "mean_f1" are also included (mean over non-NaN classes,
            0.0 if every class is NaN).

    Raises:
        ValueError: If ``metrics`` contains an unsupported metric name, or if the
            mask shapes differ.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt, pred)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833
        >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
        Mean F1: 0.7190
        >>> print(np.round(metrics['iou'], 4))
        [0.75   0.6667 0.3333]

        >>> # Calculate only IoU
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833

        >>> # Using raster file paths (result depends on the files)
        >>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
    """
    supported = ("iou", "f1")
    if metrics is None:
        metrics = list(supported)
    elif isinstance(metrics, str):
        # A bare string previously matched by substring; treat it as one metric name.
        metrics = [metrics]

    unknown = [name for name in metrics if name not in supported]
    if unknown:
        raise ValueError(
            f"Unsupported metric(s): {unknown}. Supported metrics: {list(supported)}"
        )

    def _read_if_path(mask):
        # Read raster paths once so every metric reuses the same array.
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        return mask

    def _valid_mean(scores: np.ndarray) -> float:
        # Mean over classes that have a defined score (non-NaN).
        valid = scores[~np.isnan(scores)]
        return float(np.mean(valid)) if valid.size > 0 else 0.0

    ground_truth = _read_if_path(ground_truth)
    prediction = _read_if_path(prediction)

    shared_kwargs = {
        "num_classes": num_classes,
        "ignore_index": ignore_index,
        "smooth": smooth,
        "band": band,
    }

    results = {}

    if "iou" in metrics:
        iou = calc_iou(ground_truth, prediction, **shared_kwargs)
        results["iou"] = iou
        if num_classes is not None and isinstance(iou, np.ndarray):
            results["mean_iou"] = _valid_mean(iou)

    if "f1" in metrics:
        f1 = calc_f1_score(ground_truth, prediction, **shared_kwargs)
        results["f1"] = f1
        if num_classes is not None and isinstance(f1, np.ndarray):
            results["mean_f1"] = _valid_mean(f1)

    return results
772
+
773
+
386
774
  def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
387
775
  """Convert a dictionary to a xarray DataArray. The dictionary should contain the
388
776
  following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
@@ -751,6 +1139,12 @@ def view_vector_interactive(
751
1139
  },
752
1140
  }
753
1141
 
1142
+ # Make it compatible with binder and JupyterHub
1143
+ if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
1144
+ os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
1145
+ f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
1146
+ )
1147
+
754
1148
  basemap_layer_name = None
755
1149
  raster_layer = None
756
1150
 
@@ -2599,7 +2993,7 @@ def batch_vector_to_raster(
2599
2993
  def export_geotiff_tiles(
2600
2994
  in_raster,
2601
2995
  out_folder,
2602
- in_class_data,
2996
+ in_class_data=None,
2603
2997
  tile_size=256,
2604
2998
  stride=128,
2605
2999
  class_value_field="class",
@@ -2609,6 +3003,7 @@ def export_geotiff_tiles(
2609
3003
  all_touched=True,
2610
3004
  create_overview=False,
2611
3005
  skip_empty_tiles=False,
3006
+ metadata_format="PASCAL_VOC",
2612
3007
  ):
2613
3008
  """
2614
3009
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -2616,7 +3011,8 @@ def export_geotiff_tiles(
2616
3011
  Args:
2617
3012
  in_raster (str): Path to input raster image
2618
3013
  out_folder (str): Path to output folder
2619
- in_class_data (str): Path to classification data - can be vector file or raster
3014
+ in_class_data (str, optional): Path to classification data - can be vector file or raster.
3015
+ If None, only image tiles will be exported without labels. Defaults to None.
2620
3016
  tile_size (int): Size of tiles in pixels (square)
2621
3017
  stride (int): Step size between tiles
2622
3018
  class_value_field (str): Field containing class values (for vector data)
@@ -2626,6 +3022,7 @@ def export_geotiff_tiles(
2626
3022
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
2627
3023
  create_overview (bool): Whether to create an overview image of all tiles
2628
3024
  skip_empty_tiles (bool): If True, skip tiles with no features
3025
+ metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
2629
3026
  """
2630
3027
 
2631
3028
  import logging
@@ -2636,28 +3033,42 @@ def export_geotiff_tiles(
2636
3033
  os.makedirs(out_folder, exist_ok=True)
2637
3034
  image_dir = os.path.join(out_folder, "images")
2638
3035
  os.makedirs(image_dir, exist_ok=True)
2639
- label_dir = os.path.join(out_folder, "labels")
2640
- os.makedirs(label_dir, exist_ok=True)
2641
- ann_dir = os.path.join(out_folder, "annotations")
2642
- os.makedirs(ann_dir, exist_ok=True)
2643
3036
 
2644
- # Determine if class data is raster or vector
3037
+ # Only create label and annotation directories if class data is provided
3038
+ if in_class_data is not None:
3039
+ label_dir = os.path.join(out_folder, "labels")
3040
+ os.makedirs(label_dir, exist_ok=True)
3041
+
3042
+ # Create annotation directory based on metadata format
3043
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3044
+ ann_dir = os.path.join(out_folder, "annotations")
3045
+ os.makedirs(ann_dir, exist_ok=True)
3046
+
3047
+ # Initialize COCO annotations dictionary
3048
+ if metadata_format == "COCO":
3049
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3050
+ ann_id = 0
3051
+
3052
+ # Determine if class data is raster or vector (only if class data provided)
2645
3053
  is_class_data_raster = False
2646
- if isinstance(in_class_data, str):
2647
- file_ext = Path(in_class_data).suffix.lower()
2648
- # Common raster extensions
2649
- if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2650
- try:
2651
- with rasterio.open(in_class_data) as src:
2652
- is_class_data_raster = True
3054
+ if in_class_data is not None:
3055
+ if isinstance(in_class_data, str):
3056
+ file_ext = Path(in_class_data).suffix.lower()
3057
+ # Common raster extensions
3058
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
3059
+ try:
3060
+ with rasterio.open(in_class_data) as src:
3061
+ is_class_data_raster = True
3062
+ if not quiet:
3063
+ print(f"Detected in_class_data as raster: {in_class_data}")
3064
+ print(f"Raster CRS: {src.crs}")
3065
+ print(f"Raster dimensions: {src.width} x {src.height}")
3066
+ except Exception:
3067
+ is_class_data_raster = False
2653
3068
  if not quiet:
2654
- print(f"Detected in_class_data as raster: {in_class_data}")
2655
- print(f"Raster CRS: {src.crs}")
2656
- print(f"Raster dimensions: {src.width} x {src.height}")
2657
- except Exception:
2658
- is_class_data_raster = False
2659
- if not quiet:
2660
- print(f"Unable to open {in_class_data} as raster, trying as vector")
3069
+ print(
3070
+ f"Unable to open {in_class_data} as raster, trying as vector"
3071
+ )
2661
3072
 
2662
3073
  # Open the input raster
2663
3074
  with rasterio.open(in_raster) as src:
@@ -2677,10 +3088,10 @@ def export_geotiff_tiles(
2677
3088
  if max_tiles is None:
2678
3089
  max_tiles = total_tiles
2679
3090
 
2680
- # Process classification data
3091
+ # Process classification data (only if class data provided)
2681
3092
  class_to_id = {}
2682
3093
 
2683
- if is_class_data_raster:
3094
+ if in_class_data is not None and is_class_data_raster:
2684
3095
  # Load raster class data
2685
3096
  with rasterio.open(in_class_data) as class_src:
2686
3097
  # Check if raster CRS matches
@@ -2713,7 +3124,18 @@ def export_geotiff_tiles(
2713
3124
 
2714
3125
  # Create class mapping
2715
3126
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
2716
- else:
3127
+
3128
+ # Populate COCO categories
3129
+ if metadata_format == "COCO":
3130
+ for cls_val in unique_classes:
3131
+ coco_annotations["categories"].append(
3132
+ {
3133
+ "id": class_to_id[int(cls_val)],
3134
+ "name": str(int(cls_val)),
3135
+ "supercategory": "object",
3136
+ }
3137
+ )
3138
+ elif in_class_data is not None:
2717
3139
  # Load vector class data
2718
3140
  try:
2719
3141
  gdf = gpd.read_file(in_class_data)
@@ -2742,12 +3164,33 @@ def export_geotiff_tiles(
2742
3164
  )
2743
3165
  # Create class mapping
2744
3166
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
3167
+
3168
+ # Populate COCO categories
3169
+ if metadata_format == "COCO":
3170
+ for cls_val in unique_classes:
3171
+ coco_annotations["categories"].append(
3172
+ {
3173
+ "id": class_to_id[cls_val],
3174
+ "name": str(cls_val),
3175
+ "supercategory": "object",
3176
+ }
3177
+ )
2745
3178
  else:
2746
3179
  if not quiet:
2747
3180
  print(
2748
3181
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
2749
3182
  )
2750
3183
  class_to_id = {1: 1} # Default mapping
3184
+
3185
+ # Populate COCO categories with default
3186
+ if metadata_format == "COCO":
3187
+ coco_annotations["categories"].append(
3188
+ {
3189
+ "id": 1,
3190
+ "name": "object",
3191
+ "supercategory": "object",
3192
+ }
3193
+ )
2751
3194
  except Exception as e:
2752
3195
  raise ValueError(f"Error processing vector data: {e}")
2753
3196
 
@@ -2814,8 +3257,8 @@ def export_geotiff_tiles(
2814
3257
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
2815
3258
  has_features = False
2816
3259
 
2817
- # Process classification data to create labels
2818
- if is_class_data_raster:
3260
+ # Process classification data to create labels (only if class data provided)
3261
+ if in_class_data is not None and is_class_data_raster:
2819
3262
  # For raster class data
2820
3263
  with rasterio.open(in_class_data) as class_src:
2821
3264
  # Calculate window in class raster
@@ -2865,7 +3308,7 @@ def export_geotiff_tiles(
2865
3308
  except Exception as e:
2866
3309
  pbar.write(f"Error reading class raster window: {e}")
2867
3310
  stats["errors"] += 1
2868
- else:
3311
+ elif in_class_data is not None:
2869
3312
  # For vector class data
2870
3313
  # Find features that intersect with window
2871
3314
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -2908,8 +3351,8 @@ def export_geotiff_tiles(
2908
3351
  pbar.write(f"Error rasterizing feature {idx}: {e}")
2909
3352
  stats["errors"] += 1
2910
3353
 
2911
- # Skip tile if no features and skip_empty_tiles is True
2912
- if skip_empty_tiles and not has_features:
3354
+ # Skip tile if no features and skip_empty_tiles is True (only when class data provided)
3355
+ if in_class_data is not None and skip_empty_tiles and not has_features:
2913
3356
  pbar.update(1)
2914
3357
  tile_index += 1
2915
3358
  continue
@@ -2940,96 +3383,212 @@ def export_geotiff_tiles(
2940
3383
  pbar.write(f"ERROR saving image GeoTIFF: {e}")
2941
3384
  stats["errors"] += 1
2942
3385
 
2943
- # Create profile for label GeoTIFF
2944
- label_profile = {
2945
- "driver": "GTiff",
2946
- "height": tile_size,
2947
- "width": tile_size,
2948
- "count": 1,
2949
- "dtype": "uint8",
2950
- "crs": src.crs,
2951
- "transform": window_transform,
2952
- }
3386
+ # Export label as GeoTIFF (only if class data provided)
3387
+ if in_class_data is not None:
3388
+ # Create profile for label GeoTIFF
3389
+ label_profile = {
3390
+ "driver": "GTiff",
3391
+ "height": tile_size,
3392
+ "width": tile_size,
3393
+ "count": 1,
3394
+ "dtype": "uint8",
3395
+ "crs": src.crs,
3396
+ "transform": window_transform,
3397
+ }
2953
3398
 
2954
- # Export label as GeoTIFF
2955
- label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
2956
- try:
2957
- with rasterio.open(label_path, "w", **label_profile) as dst:
2958
- dst.write(label_mask.astype(np.uint8), 1)
3399
+ label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
3400
+ try:
3401
+ with rasterio.open(label_path, "w", **label_profile) as dst:
3402
+ dst.write(label_mask.astype(np.uint8), 1)
2959
3403
 
2960
- if has_features:
2961
- stats["tiles_with_features"] += 1
2962
- stats["feature_pixels"] += np.count_nonzero(label_mask)
2963
- except Exception as e:
2964
- pbar.write(f"ERROR saving label GeoTIFF: {e}")
2965
- stats["errors"] += 1
3404
+ if has_features:
3405
+ stats["tiles_with_features"] += 1
3406
+ stats["feature_pixels"] += np.count_nonzero(label_mask)
3407
+ except Exception as e:
3408
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
3409
+ stats["errors"] += 1
2966
3410
 
2967
- # Create XML annotation for object detection if using vector class data
3411
+ # Create annotations for object detection if using vector class data
2968
3412
  if (
2969
- not is_class_data_raster
3413
+ in_class_data is not None
3414
+ and not is_class_data_raster
2970
3415
  and "gdf" in locals()
2971
3416
  and len(window_features) > 0
2972
3417
  ):
2973
- # Create XML annotation
2974
- root = ET.Element("annotation")
2975
- ET.SubElement(root, "folder").text = "images"
2976
- ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3418
+ if metadata_format == "PASCAL_VOC":
3419
+ # Create XML annotation
3420
+ root = ET.Element("annotation")
3421
+ ET.SubElement(root, "folder").text = "images"
3422
+ ET.SubElement(root, "filename").text = (
3423
+ f"tile_{tile_index:06d}.tif"
3424
+ )
2977
3425
 
2978
- size = ET.SubElement(root, "size")
2979
- ET.SubElement(size, "width").text = str(tile_size)
2980
- ET.SubElement(size, "height").text = str(tile_size)
2981
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3426
+ size = ET.SubElement(root, "size")
3427
+ ET.SubElement(size, "width").text = str(tile_size)
3428
+ ET.SubElement(size, "height").text = str(tile_size)
3429
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3430
+
3431
+ # Add georeference information
3432
+ geo = ET.SubElement(root, "georeference")
3433
+ ET.SubElement(geo, "crs").text = str(src.crs)
3434
+ ET.SubElement(geo, "transform").text = str(
3435
+ window_transform
3436
+ ).replace("\n", "")
3437
+ ET.SubElement(geo, "bounds").text = (
3438
+ f"{minx}, {miny}, {maxx}, {maxy}"
3439
+ )
2982
3440
 
2983
- # Add georeference information
2984
- geo = ET.SubElement(root, "georeference")
2985
- ET.SubElement(geo, "crs").text = str(src.crs)
2986
- ET.SubElement(geo, "transform").text = str(
2987
- window_transform
2988
- ).replace("\n", "")
2989
- ET.SubElement(geo, "bounds").text = (
2990
- f"{minx}, {miny}, {maxx}, {maxy}"
2991
- )
3441
+ # Add objects
3442
+ for idx, feature in window_features.iterrows():
3443
+ # Get feature class
3444
+ if class_value_field in feature:
3445
+ class_val = feature[class_value_field]
3446
+ else:
3447
+ class_val = "object"
2992
3448
 
2993
- # Add objects
2994
- for idx, feature in window_features.iterrows():
2995
- # Get feature class
2996
- if class_value_field in feature:
2997
- class_val = feature[class_value_field]
2998
- else:
2999
- class_val = "object"
3449
+ # Get geometry bounds in pixel coordinates
3450
+ geom = feature.geometry.intersection(window_bounds)
3451
+ if not geom.is_empty:
3452
+ # Get bounds in world coordinates
3453
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3454
+
3455
+ # Convert to pixel coordinates
3456
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3457
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3458
+
3459
+ # Ensure coordinates are within tile bounds
3460
+ xmin = max(0, min(tile_size, int(col_min)))
3461
+ ymin = max(0, min(tile_size, int(row_min)))
3462
+ xmax = max(0, min(tile_size, int(col_max)))
3463
+ ymax = max(0, min(tile_size, int(row_max)))
3464
+
3465
+ # Only add if the box has non-zero area
3466
+ if xmax > xmin and ymax > ymin:
3467
+ obj = ET.SubElement(root, "object")
3468
+ ET.SubElement(obj, "name").text = str(class_val)
3469
+ ET.SubElement(obj, "difficult").text = "0"
3470
+
3471
+ bbox = ET.SubElement(obj, "bndbox")
3472
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3473
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3474
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3475
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3476
+
3477
+ # Save XML
3478
+ tree = ET.ElementTree(root)
3479
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3480
+ tree.write(xml_path)
3000
3481
 
3001
- # Get geometry bounds in pixel coordinates
3002
- geom = feature.geometry.intersection(window_bounds)
3003
- if not geom.is_empty:
3004
- # Get bounds in world coordinates
3005
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3006
-
3007
- # Convert to pixel coordinates
3008
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3009
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3010
-
3011
- # Ensure coordinates are within tile bounds
3012
- xmin = max(0, min(tile_size, int(col_min)))
3013
- ymin = max(0, min(tile_size, int(row_min)))
3014
- xmax = max(0, min(tile_size, int(col_max)))
3015
- ymax = max(0, min(tile_size, int(row_max)))
3016
-
3017
- # Only add if the box has non-zero area
3018
- if xmax > xmin and ymax > ymin:
3019
- obj = ET.SubElement(root, "object")
3020
- ET.SubElement(obj, "name").text = str(class_val)
3021
- ET.SubElement(obj, "difficult").text = "0"
3022
-
3023
- bbox = ET.SubElement(obj, "bndbox")
3024
- ET.SubElement(bbox, "xmin").text = str(xmin)
3025
- ET.SubElement(bbox, "ymin").text = str(ymin)
3026
- ET.SubElement(bbox, "xmax").text = str(xmax)
3027
- ET.SubElement(bbox, "ymax").text = str(ymax)
3482
+ elif metadata_format == "COCO":
3483
+ # Add image info
3484
+ image_id = tile_index
3485
+ coco_annotations["images"].append(
3486
+ {
3487
+ "id": image_id,
3488
+ "file_name": f"tile_{tile_index:06d}.tif",
3489
+ "width": tile_size,
3490
+ "height": tile_size,
3491
+ "crs": str(src.crs),
3492
+ "transform": str(window_transform),
3493
+ }
3494
+ )
3028
3495
 
3029
- # Save XML
3030
- tree = ET.ElementTree(root)
3031
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3032
- tree.write(xml_path)
3496
+ # Add annotations for each feature
3497
+ for _, feature in window_features.iterrows():
3498
+ # Get feature class
3499
+ if class_value_field in feature:
3500
+ class_val = feature[class_value_field]
3501
+ category_id = class_to_id.get(class_val, 1)
3502
+ else:
3503
+ category_id = 1
3504
+
3505
+ # Get geometry bounds
3506
+ geom = feature.geometry.intersection(window_bounds)
3507
+ if not geom.is_empty:
3508
+ # Get bounds in world coordinates
3509
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3510
+
3511
+ # Convert to pixel coordinates
3512
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3513
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3514
+
3515
+ # Ensure coordinates are within tile bounds
3516
+ xmin = max(0, min(tile_size, int(col_min)))
3517
+ ymin = max(0, min(tile_size, int(row_min)))
3518
+ xmax = max(0, min(tile_size, int(col_max)))
3519
+ ymax = max(0, min(tile_size, int(row_max)))
3520
+
3521
+ # Skip if box is too small
3522
+ if xmax - xmin < 1 or ymax - ymin < 1:
3523
+ continue
3524
+
3525
+ width = xmax - xmin
3526
+ height = ymax - ymin
3527
+
3528
+ # Add annotation
3529
+ ann_id += 1
3530
+ coco_annotations["annotations"].append(
3531
+ {
3532
+ "id": ann_id,
3533
+ "image_id": image_id,
3534
+ "category_id": category_id,
3535
+ "bbox": [xmin, ymin, width, height],
3536
+ "area": width * height,
3537
+ "iscrowd": 0,
3538
+ }
3539
+ )
3540
+
3541
+ elif metadata_format == "YOLO":
3542
+ # Create YOLO format annotations
3543
+ yolo_annotations = []
3544
+
3545
+ for _, feature in window_features.iterrows():
3546
+ # Get feature class
3547
+ if class_value_field in feature:
3548
+ class_val = feature[class_value_field]
3549
+ # YOLO uses 0-indexed class IDs
3550
+ class_id = class_to_id.get(class_val, 1) - 1
3551
+ else:
3552
+ class_id = 0
3553
+
3554
+ # Get geometry bounds
3555
+ geom = feature.geometry.intersection(window_bounds)
3556
+ if not geom.is_empty:
3557
+ # Get bounds in world coordinates
3558
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3559
+
3560
+ # Convert to pixel coordinates
3561
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3562
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3563
+
3564
+ # Ensure coordinates are within tile bounds
3565
+ xmin = max(0, min(tile_size, col_min))
3566
+ ymin = max(0, min(tile_size, row_min))
3567
+ xmax = max(0, min(tile_size, col_max))
3568
+ ymax = max(0, min(tile_size, row_max))
3569
+
3570
+ # Skip if box is too small
3571
+ if xmax - xmin < 1 or ymax - ymin < 1:
3572
+ continue
3573
+
3574
+ # Calculate normalized coordinates (YOLO format)
3575
+ x_center = ((xmin + xmax) / 2) / tile_size
3576
+ y_center = ((ymin + ymax) / 2) / tile_size
3577
+ width = (xmax - xmin) / tile_size
3578
+ height = (ymax - ymin) / tile_size
3579
+
3580
+ # Add YOLO annotation line
3581
+ yolo_annotations.append(
3582
+ f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3583
+ )
3584
+
3585
+ # Save YOLO annotations to text file
3586
+ if yolo_annotations:
3587
+ yolo_path = os.path.join(
3588
+ label_dir, f"tile_{tile_index:06d}.txt"
3589
+ )
3590
+ with open(yolo_path, "w") as f:
3591
+ f.write("\n".join(yolo_annotations))
3033
3592
 
3034
3593
  # Update progress bar
3035
3594
  pbar.update(1)
@@ -3047,6 +3606,39 @@ def export_geotiff_tiles(
3047
3606
  # Close progress bar
3048
3607
  pbar.close()
3049
3608
 
3609
+ # Save COCO annotations if applicable (only if class data provided)
3610
+ if in_class_data is not None and metadata_format == "COCO":
3611
+ try:
3612
+ with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3613
+ json.dump(coco_annotations, f, indent=2)
3614
+ if not quiet:
3615
+ print(
3616
+ f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3617
+ f"{len(coco_annotations['annotations'])} annotations, "
3618
+ f"{len(coco_annotations['categories'])} categories"
3619
+ )
3620
+ except Exception as e:
3621
+ if not quiet:
3622
+ print(f"ERROR saving COCO annotations: {e}")
3623
+ stats["errors"] += 1
3624
+
3625
+ # Save YOLO classes file if applicable (only if class data provided)
3626
+ if in_class_data is not None and metadata_format == "YOLO":
3627
+ try:
3628
+ # Create classes.txt with class names
3629
+ classes_path = os.path.join(out_folder, "classes.txt")
3630
+ # Sort by class ID to ensure correct order
3631
+ sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3632
+ with open(classes_path, "w") as f:
3633
+ for class_val, _ in sorted_classes:
3634
+ f.write(f"{class_val}\n")
3635
+ if not quiet:
3636
+ print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3637
+ except Exception as e:
3638
+ if not quiet:
3639
+ print(f"ERROR saving YOLO classes file: {e}")
3640
+ stats["errors"] += 1
3641
+
3050
3642
  # Create overview image if requested
3051
3643
  if create_overview and stats["tile_coordinates"]:
3052
3644
  try:
@@ -3064,13 +3656,14 @@ def export_geotiff_tiles(
3064
3656
  if not quiet:
3065
3657
  print("\n------- Export Summary -------")
3066
3658
  print(f"Total tiles exported: {stats['total_tiles']}")
3067
- print(
3068
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3069
- )
3070
- if stats["tiles_with_features"] > 0:
3659
+ if in_class_data is not None:
3071
3660
  print(
3072
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3661
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3073
3662
  )
3663
+ if stats["tiles_with_features"] > 0:
3664
+ print(
3665
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3666
+ )
3074
3667
  if stats["errors"] > 0:
3075
3668
  print(f"Errors encountered: {stats['errors']}")
3076
3669
  print(f"Output saved to: {out_folder}")
@@ -3079,7 +3672,6 @@ def export_geotiff_tiles(
3079
3672
  if stats["total_tiles"] > 0:
3080
3673
  print("\n------- Georeference Verification -------")
3081
3674
  sample_image = os.path.join(image_dir, f"tile_0.tif")
3082
- sample_label = os.path.join(label_dir, f"tile_0.tif")
3083
3675
 
3084
3676
  if os.path.exists(sample_image):
3085
3677
  try:
@@ -3095,19 +3687,22 @@ def export_geotiff_tiles(
3095
3687
  except Exception as e:
3096
3688
  print(f"Error verifying image georeference: {e}")
3097
3689
 
3098
- if os.path.exists(sample_label):
3099
- try:
3100
- with rasterio.open(sample_label) as lbl:
3101
- print(f"Label CRS: {lbl.crs}")
3102
- print(f"Label transform: {lbl.transform}")
3103
- print(
3104
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3105
- )
3106
- print(
3107
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3108
- )
3109
- except Exception as e:
3110
- print(f"Error verifying label georeference: {e}")
3690
+ # Only verify label if class data was provided
3691
+ if in_class_data is not None:
3692
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
3693
+ if os.path.exists(sample_label):
3694
+ try:
3695
+ with rasterio.open(sample_label) as lbl:
3696
+ print(f"Label CRS: {lbl.crs}")
3697
+ print(f"Label transform: {lbl.transform}")
3698
+ print(
3699
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3700
+ )
3701
+ print(
3702
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3703
+ )
3704
+ except Exception as e:
3705
+ print(f"Error verifying label georeference: {e}")
3111
3706
 
3112
3707
  # Return statistics dictionary for further processing if needed
3113
3708
  return stats
@@ -3125,36 +3720,41 @@ def export_geotiff_tiles_batch(
3125
3720
  max_tiles=None,
3126
3721
  quiet=False,
3127
3722
  all_touched=True,
3128
- create_overview=False,
3129
3723
  skip_empty_tiles=False,
3130
3724
  image_extensions=None,
3131
3725
  mask_extensions=None,
3132
- match_by_name=True,
3726
+ match_by_name=False,
3727
+ metadata_format="PASCAL_VOC",
3133
3728
  ) -> Dict[str, Any]:
3134
3729
  """
3135
- Export georeferenced GeoTIFF tiles from images and masks.
3730
+ Export georeferenced GeoTIFF tiles from images and optionally masks.
3731
+
3732
+ This function supports four modes:
3733
+ 1. Images only (no masks) - when neither masks_file nor masks_folder is provided
3734
+ 2. Single vector file covering all images (masks_file parameter)
3735
+ 3. Multiple vector files, one per image (masks_folder parameter)
3736
+ 4. Multiple raster mask files (masks_folder parameter)
3136
3737
 
3137
- This function supports three mask input modes:
3138
- 1. Single vector file covering all images (masks_file parameter)
3139
- 2. Multiple vector files, one per image (masks_folder parameter)
3140
- 3. Multiple raster mask files (masks_folder parameter)
3738
+ For mode 1 (images only), only image tiles will be exported without labels.
3141
3739
 
3142
- For mode 1 (single vector file), specify masks_file path. The function will
3740
+ For mode 2 (single vector file), specify masks_file path. The function will
3143
3741
  use spatial intersection to determine which features apply to each image.
3144
3742
 
3145
- For mode 2/3 (multiple mask files), specify masks_folder path. Images and masks
3743
+ For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
3146
3744
  are paired either by matching filenames (match_by_name=True) or by sorted order
3147
3745
  (match_by_name=False).
3148
3746
 
3149
- All image tiles are saved to a single 'images' folder and all mask tiles to a
3150
- single 'masks' folder within the output directory.
3747
+ All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
3748
+ to a single 'masks' folder within the output directory.
3151
3749
 
3152
3750
  Args:
3153
3751
  images_folder (str): Path to folder containing raster images
3154
3752
  masks_folder (str, optional): Path to folder containing classification masks/vectors.
3155
- Use this for multiple mask files (one per image or raster masks).
3753
+ Use this for multiple mask files (one per image or raster masks). If not provided
3754
+ and masks_file is also not provided, only image tiles will be exported.
3156
3755
  masks_file (str, optional): Path to a single vector file covering all images.
3157
- Use this for a single GeoJSON/Shapefile that covers multiple images.
3756
+ Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
3757
+ and masks_folder is also not provided, only image tiles will be exported.
3158
3758
  output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3159
3759
  subfolder in images_folder.
3160
3760
  tile_size (int): Size of tiles in pixels (square)
@@ -3170,16 +3770,23 @@ def export_geotiff_tiles_batch(
3170
3770
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3171
3771
  match_by_name (bool): If True, match image and mask files by base filename.
3172
3772
  If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3773
+ metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3774
+ Default is "PASCAL_VOC".
3173
3775
 
3174
3776
  Returns:
3175
3777
  Dict[str, Any]: Dictionary containing batch processing statistics
3176
3778
 
3177
3779
  Raises:
3178
3780
  ValueError: If no images found, or if masks_folder and masks_file are both specified,
3179
- or if neither is specified, or if counts don't match when using masks_folder with
3180
- match_by_name=False.
3781
+ or if counts don't match when using masks_folder with match_by_name=False.
3181
3782
 
3182
3783
  Examples:
3784
+ # Images only (no masks)
3785
+ >>> stats = export_geotiff_tiles_batch(
3786
+ ... images_folder='data/images',
3787
+ ... output_folder='output/tiles'
3788
+ ... )
3789
+
3183
3790
  # Single vector file covering all images
3184
3791
  >>> stats = export_geotiff_tiles_batch(
3185
3792
  ... images_folder='data/images',
@@ -3214,11 +3821,6 @@ def export_geotiff_tiles_batch(
3214
3821
  "Cannot specify both masks_folder and masks_file. Please use only one."
3215
3822
  )
3216
3823
 
3217
- if masks_folder is None and masks_file is None:
3218
- raise ValueError(
3219
- "Must specify either masks_folder or masks_file for mask data source."
3220
- )
3221
-
3222
3824
  # Default output folder if not specified
3223
3825
  if output_folder is None:
3224
3826
  output_folder = os.path.join(images_folder, "tiles")
@@ -3249,9 +3851,37 @@ def export_geotiff_tiles_batch(
3249
3851
  # Create output folder structure
3250
3852
  os.makedirs(output_folder, exist_ok=True)
3251
3853
  output_images_dir = os.path.join(output_folder, "images")
3252
- output_masks_dir = os.path.join(output_folder, "masks")
3253
3854
  os.makedirs(output_images_dir, exist_ok=True)
3254
- os.makedirs(output_masks_dir, exist_ok=True)
3855
+
3856
+ # Only create masks directory if masks are provided
3857
+ output_masks_dir = None
3858
+ if masks_folder is not None or masks_file is not None:
3859
+ output_masks_dir = os.path.join(output_folder, "masks")
3860
+ os.makedirs(output_masks_dir, exist_ok=True)
3861
+
3862
+ # Create annotation directory based on metadata format (only if masks are provided)
3863
+ ann_dir = None
3864
+ if (masks_folder is not None or masks_file is not None) and metadata_format in [
3865
+ "PASCAL_VOC",
3866
+ "COCO",
3867
+ ]:
3868
+ ann_dir = os.path.join(output_folder, "annotations")
3869
+ os.makedirs(ann_dir, exist_ok=True)
3870
+
3871
+ # Initialize COCO annotations dictionary (only if masks are provided)
3872
+ coco_annotations = None
3873
+ if (
3874
+ masks_folder is not None or masks_file is not None
3875
+ ) and metadata_format == "COCO":
3876
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3877
+
3878
+ # Initialize YOLO class set (only if masks are provided)
3879
+ yolo_classes = (
3880
+ set()
3881
+ if (masks_folder is not None or masks_file is not None)
3882
+ and metadata_format == "YOLO"
3883
+ else None
3884
+ )
3255
3885
 
3256
3886
  # Get list of image files
3257
3887
  image_files = []
@@ -3269,10 +3899,16 @@ def export_geotiff_tiles_batch(
3269
3899
 
3270
3900
  # Handle different mask input modes
3271
3901
  use_single_mask_file = masks_file is not None
3902
+ has_masks = masks_file is not None or masks_folder is not None
3272
3903
  mask_files = []
3273
3904
  image_mask_pairs = []
3274
3905
 
3275
- if use_single_mask_file:
3906
+ if not has_masks:
3907
+ # Mode 0: No masks - create pairs with None for mask
3908
+ for image_file in image_files:
3909
+ image_mask_pairs.append((image_file, None, None))
3910
+
3911
+ elif use_single_mask_file:
3276
3912
  # Mode 1: Single vector file covering all images
3277
3913
  if not os.path.exists(masks_file):
3278
3914
  raise ValueError(f"Mask file not found: {masks_file}")
@@ -3324,10 +3960,21 @@ def export_geotiff_tiles_batch(
3324
3960
  print(f"Warning: No mask found for image {img_base}")
3325
3961
 
3326
3962
  if not image_mask_pairs:
3327
- raise ValueError(
3963
+ # Provide detailed error message with found files
3964
+ image_bases = list(image_dict.keys())
3965
+ mask_bases = list(mask_dict.keys())
3966
+ error_msg = (
3328
3967
  "No matching image-mask pairs found when matching by filename. "
3329
- "Check that image and mask files have matching base names."
3968
+ "Check that image and mask files have matching base names.\n"
3969
+ f"Found {len(image_bases)} image(s): "
3970
+ f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
3971
+ f"{'...' if len(image_bases) > 5 else ''}\n"
3972
+ f"Found {len(mask_bases)} mask(s): "
3973
+ f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
3974
+ f"{'...' if len(mask_bases) > 5 else ''}\n"
3975
+ "Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
3330
3976
  )
3977
+ raise ValueError(error_msg)
3331
3978
 
3332
3979
  else:
3333
3980
  # Match by sorted order
@@ -3354,7 +4001,11 @@ def export_geotiff_tiles_batch(
3354
4001
  }
3355
4002
 
3356
4003
  if not quiet:
3357
- if use_single_mask_file:
4004
+ if not has_masks:
4005
+ print(
4006
+ f"Found {len(image_files)} image files to process (images only, no masks)"
4007
+ )
4008
+ elif use_single_mask_file:
3358
4009
  print(f"Found {len(image_files)} image files to process")
3359
4010
  print(f"Using single mask file: {masks_file}")
3360
4011
  else:
@@ -3383,10 +4034,15 @@ def export_geotiff_tiles_batch(
3383
4034
  if not quiet:
3384
4035
  print(f"\nProcessing: {base_name}")
3385
4036
  print(f" Image: {os.path.basename(image_file)}")
3386
- if use_single_mask_file:
3387
- print(f" Mask: {os.path.basename(mask_file)} (spatially filtered)")
4037
+ if mask_file is not None:
4038
+ if use_single_mask_file:
4039
+ print(
4040
+ f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
4041
+ )
4042
+ else:
4043
+ print(f" Mask: {os.path.basename(mask_file)}")
3388
4044
  else:
3389
- print(f" Mask: {os.path.basename(mask_file)}")
4045
+ print(f" Mask: None (images only)")
3390
4046
 
3391
4047
  # Process the image-mask pair
3392
4048
  tiles_generated = _process_image_mask_pair(
@@ -3406,6 +4062,13 @@ def export_geotiff_tiles_batch(
3406
4062
  quiet=quiet,
3407
4063
  mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
3408
4064
  use_single_mask_file=use_single_mask_file,
4065
+ metadata_format=metadata_format,
4066
+ ann_dir=(
4067
+ ann_dir
4068
+ if "ann_dir" in locals()
4069
+ and metadata_format in ["PASCAL_VOC", "COCO"]
4070
+ else None
4071
+ ),
3409
4072
  )
3410
4073
 
3411
4074
  # Update counters
@@ -3427,6 +4090,23 @@ def export_geotiff_tiles_batch(
3427
4090
  }
3428
4091
  )
3429
4092
 
4093
+ # Aggregate COCO annotations
4094
+ if metadata_format == "COCO" and "coco_data" in tiles_generated:
4095
+ coco_data = tiles_generated["coco_data"]
4096
+ # Add images and annotations
4097
+ coco_annotations["images"].extend(coco_data.get("images", []))
4098
+ coco_annotations["annotations"].extend(coco_data.get("annotations", []))
4099
+ # Merge categories (avoid duplicates)
4100
+ for cat in coco_data.get("categories", []):
4101
+ if not any(
4102
+ c["id"] == cat["id"] for c in coco_annotations["categories"]
4103
+ ):
4104
+ coco_annotations["categories"].append(cat)
4105
+
4106
+ # Aggregate YOLO classes
4107
+ if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
4108
+ yolo_classes.update(tiles_generated["yolo_classes"])
4109
+
3430
4110
  except Exception as e:
3431
4111
  if not quiet:
3432
4112
  print(f"ERROR processing {base_name}: {e}")
@@ -3435,6 +4115,33 @@ def export_geotiff_tiles_batch(
3435
4115
  )
3436
4116
  batch_stats["errors"] += 1
3437
4117
 
4118
+ # Save aggregated COCO annotations
4119
+ if metadata_format == "COCO" and coco_annotations:
4120
+ import json
4121
+
4122
+ coco_path = os.path.join(ann_dir, "instances.json")
4123
+ with open(coco_path, "w") as f:
4124
+ json.dump(coco_annotations, f, indent=2)
4125
+ if not quiet:
4126
+ print(f"\nSaved COCO annotations: {coco_path}")
4127
+ print(
4128
+ f" Images: {len(coco_annotations['images'])}, "
4129
+ f"Annotations: {len(coco_annotations['annotations'])}, "
4130
+ f"Categories: {len(coco_annotations['categories'])}"
4131
+ )
4132
+
4133
+ # Save aggregated YOLO classes
4134
+ if metadata_format == "YOLO" and yolo_classes:
4135
+ classes_path = os.path.join(output_folder, "labels", "classes.txt")
4136
+ os.makedirs(os.path.dirname(classes_path), exist_ok=True)
4137
+ sorted_classes = sorted(yolo_classes)
4138
+ with open(classes_path, "w") as f:
4139
+ for cls in sorted_classes:
4140
+ f.write(f"{cls}\n")
4141
+ if not quiet:
4142
+ print(f"\nSaved YOLO classes: {classes_path}")
4143
+ print(f" Total classes: {len(sorted_classes)}")
4144
+
3438
4145
  # Print batch summary
3439
4146
  if not quiet:
3440
4147
  print("\n" + "=" * 60)
@@ -3457,7 +4164,12 @@ def export_geotiff_tiles_batch(
3457
4164
 
3458
4165
  print(f"Output saved to: {output_folder}")
3459
4166
  print(f" Images: {output_images_dir}")
3460
- print(f" Masks: {output_masks_dir}")
4167
+ if output_masks_dir is not None:
4168
+ print(f" Masks: {output_masks_dir}")
4169
+ if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
4170
+ print(f" Annotations: {ann_dir}")
4171
+ elif metadata_format == "YOLO":
4172
+ print(f" Labels: {os.path.join(output_folder, 'labels')}")
3461
4173
 
3462
4174
  # List failed files if any
3463
4175
  if batch_stats["failed_files"]:
@@ -3485,6 +4197,8 @@ def _process_image_mask_pair(
3485
4197
  quiet=False,
3486
4198
  mask_gdf=None,
3487
4199
  use_single_mask_file=False,
4200
+ metadata_format="PASCAL_VOC",
4201
+ ann_dir=None,
3488
4202
  ):
3489
4203
  """
3490
4204
  Process a single image-mask pair and save tiles directly to output directories.
@@ -3498,9 +4212,9 @@ def _process_image_mask_pair(
3498
4212
  """
3499
4213
  import warnings
3500
4214
 
3501
- # Determine if mask data is raster or vector
4215
+ # Determine if mask data is raster or vector (only if mask_file is provided)
3502
4216
  is_class_data_raster = False
3503
- if isinstance(mask_file, str):
4217
+ if mask_file is not None and isinstance(mask_file, str):
3504
4218
  file_ext = Path(mask_file).suffix.lower()
3505
4219
  # Common raster extensions
3506
4220
  if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
@@ -3517,6 +4231,13 @@ def _process_image_mask_pair(
3517
4231
  "errors": 0,
3518
4232
  }
3519
4233
 
4234
+ # Initialize COCO/YOLO tracking for this image
4235
+ if metadata_format == "COCO":
4236
+ stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
4237
+ coco_ann_id = 0
4238
+ if metadata_format == "YOLO":
4239
+ stats["yolo_classes"] = set()
4240
+
3520
4241
  # Open the input raster
3521
4242
  with rasterio.open(image_file) as src:
3522
4243
  # Calculate number of tiles
@@ -3527,10 +4248,10 @@ def _process_image_mask_pair(
3527
4248
  if max_tiles is None:
3528
4249
  max_tiles = total_tiles
3529
4250
 
3530
- # Process classification data
4251
+ # Process classification data (only if mask_file is provided)
3531
4252
  class_to_id = {}
3532
4253
 
3533
- if is_class_data_raster:
4254
+ if mask_file is not None and is_class_data_raster:
3534
4255
  # Load raster class data
3535
4256
  with rasterio.open(mask_file) as class_src:
3536
4257
  # Check if raster CRS matches
@@ -3557,7 +4278,7 @@ def _process_image_mask_pair(
3557
4278
 
3558
4279
  # Create class mapping
3559
4280
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
3560
- else:
4281
+ elif mask_file is not None:
3561
4282
  # Load vector class data
3562
4283
  try:
3563
4284
  if use_single_mask_file and mask_gdf is not None:
@@ -3609,9 +4330,6 @@ def _process_image_mask_pair(
3609
4330
  tile_index = 0
3610
4331
  for y in range(num_tiles_y):
3611
4332
  for x in range(num_tiles_x):
3612
- if tile_index >= max_tiles:
3613
- break
3614
-
3615
4333
  # Calculate window coordinates
3616
4334
  window_x = x * stride
3617
4335
  window_y = y * stride
@@ -3636,12 +4354,12 @@ def _process_image_mask_pair(
3636
4354
 
3637
4355
  window_bounds = box(minx, miny, maxx, maxy)
3638
4356
 
3639
- # Create label mask
4357
+ # Create label mask (only if mask_file is provided)
3640
4358
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
3641
4359
  has_features = False
3642
4360
 
3643
- # Process classification data to create labels
3644
- if is_class_data_raster:
4361
+ # Process classification data to create labels (only if mask_file is provided)
4362
+ if mask_file is not None and is_class_data_raster:
3645
4363
  # For raster class data
3646
4364
  with rasterio.open(mask_file) as class_src:
3647
4365
  # Get corresponding window in class raster
@@ -3674,7 +4392,7 @@ def _process_image_mask_pair(
3674
4392
  if not quiet:
3675
4393
  print(f"Error reading class raster window: {e}")
3676
4394
  stats["errors"] += 1
3677
- else:
4395
+ elif mask_file is not None:
3678
4396
  # For vector class data
3679
4397
  # Find features that intersect with window
3680
4398
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -3712,11 +4430,14 @@ def _process_image_mask_pair(
3712
4430
  print(f"Error rasterizing feature {idx}: {e}")
3713
4431
  stats["errors"] += 1
3714
4432
 
3715
- # Skip tile if no features and skip_empty_tiles is True
3716
- if skip_empty_tiles and not has_features:
3717
- tile_index += 1
4433
+ # Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
4434
+ if mask_file is not None and skip_empty_tiles and not has_features:
3718
4435
  continue
3719
4436
 
4437
+ # Check if we've reached max_tiles before saving
4438
+ if tile_index >= max_tiles:
4439
+ break
4440
+
3720
4441
  # Generate unique tile name
3721
4442
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
3722
4443
 
@@ -3747,29 +4468,225 @@ def _process_image_mask_pair(
3747
4468
  print(f"ERROR saving image GeoTIFF: {e}")
3748
4469
  stats["errors"] += 1
3749
4470
 
3750
- # Create profile for label GeoTIFF
3751
- label_profile = {
3752
- "driver": "GTiff",
3753
- "height": tile_size,
3754
- "width": tile_size,
3755
- "count": 1,
3756
- "dtype": "uint8",
3757
- "crs": src.crs,
3758
- "transform": window_transform,
3759
- }
4471
+ # Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
4472
+ if mask_file is not None and output_masks_dir is not None:
4473
+ # Create profile for label GeoTIFF
4474
+ label_profile = {
4475
+ "driver": "GTiff",
4476
+ "height": tile_size,
4477
+ "width": tile_size,
4478
+ "count": 1,
4479
+ "dtype": "uint8",
4480
+ "crs": src.crs,
4481
+ "transform": window_transform,
4482
+ }
3760
4483
 
3761
- # Export label as GeoTIFF
3762
- label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
3763
- try:
3764
- with rasterio.open(label_path, "w", **label_profile) as dst:
3765
- dst.write(label_mask.astype(np.uint8), 1)
4484
+ label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4485
+ try:
4486
+ with rasterio.open(label_path, "w", **label_profile) as dst:
4487
+ dst.write(label_mask.astype(np.uint8), 1)
3766
4488
 
3767
- if has_features:
3768
- stats["tiles_with_features"] += 1
3769
- except Exception as e:
3770
- if not quiet:
3771
- print(f"ERROR saving label GeoTIFF: {e}")
3772
- stats["errors"] += 1
4489
+ if has_features:
4490
+ stats["tiles_with_features"] += 1
4491
+ except Exception as e:
4492
+ if not quiet:
4493
+ print(f"ERROR saving label GeoTIFF: {e}")
4494
+ stats["errors"] += 1
4495
+
4496
+ # Generate annotation metadata based on format (only if mask_file is provided)
4497
+ if (
4498
+ mask_file is not None
4499
+ and metadata_format == "PASCAL_VOC"
4500
+ and ann_dir
4501
+ ):
4502
+ # Create PASCAL VOC XML annotation
4503
+ from lxml import etree as ET
4504
+
4505
+ annotation = ET.Element("annotation")
4506
+ ET.SubElement(annotation, "folder").text = os.path.basename(
4507
+ output_images_dir
4508
+ )
4509
+ ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4510
+ ET.SubElement(annotation, "path").text = image_path
4511
+
4512
+ source = ET.SubElement(annotation, "source")
4513
+ ET.SubElement(source, "database").text = "GeoAI"
4514
+
4515
+ size = ET.SubElement(annotation, "size")
4516
+ ET.SubElement(size, "width").text = str(tile_size)
4517
+ ET.SubElement(size, "height").text = str(tile_size)
4518
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
4519
+
4520
+ ET.SubElement(annotation, "segmented").text = "1"
4521
+
4522
+ # Find connected components for instance segmentation
4523
+ from scipy import ndimage
4524
+
4525
+ for class_id in np.unique(label_mask):
4526
+ if class_id == 0:
4527
+ continue
4528
+
4529
+ class_mask = (label_mask == class_id).astype(np.uint8)
4530
+ labeled_array, num_features = ndimage.label(class_mask)
4531
+
4532
+ for instance_id in range(1, num_features + 1):
4533
+ instance_mask = labeled_array == instance_id
4534
+ coords = np.argwhere(instance_mask)
4535
+
4536
+ if len(coords) == 0:
4537
+ continue
4538
+
4539
+ ymin, xmin = coords.min(axis=0)
4540
+ ymax, xmax = coords.max(axis=0)
4541
+
4542
+ obj = ET.SubElement(annotation, "object")
4543
+ class_name = next(
4544
+ (k for k, v in class_to_id.items() if v == class_id),
4545
+ str(class_id),
4546
+ )
4547
+ ET.SubElement(obj, "name").text = str(class_name)
4548
+ ET.SubElement(obj, "pose").text = "Unspecified"
4549
+ ET.SubElement(obj, "truncated").text = "0"
4550
+ ET.SubElement(obj, "difficult").text = "0"
4551
+
4552
+ bndbox = ET.SubElement(obj, "bndbox")
4553
+ ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4554
+ ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4555
+ ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4556
+ ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4557
+
4558
+ # Save XML file
4559
+ xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4560
+ tree = ET.ElementTree(annotation)
4561
+ tree.write(xml_path, pretty_print=True, encoding="utf-8")
4562
+
4563
+ elif mask_file is not None and metadata_format == "COCO":
4564
+ # Add COCO image entry
4565
+ image_id = int(global_tile_counter + tile_index)
4566
+ stats["coco_data"]["images"].append(
4567
+ {
4568
+ "id": image_id,
4569
+ "file_name": f"{tile_name}.tif",
4570
+ "width": int(tile_size),
4571
+ "height": int(tile_size),
4572
+ }
4573
+ )
4574
+
4575
+ # Add COCO categories (only once per unique class)
4576
+ for class_val, class_id in class_to_id.items():
4577
+ if not any(
4578
+ c["id"] == class_id
4579
+ for c in stats["coco_data"]["categories"]
4580
+ ):
4581
+ stats["coco_data"]["categories"].append(
4582
+ {
4583
+ "id": int(class_id),
4584
+ "name": str(class_val),
4585
+ "supercategory": "object",
4586
+ }
4587
+ )
4588
+
4589
+ # Add COCO annotations (instance segmentation)
4590
+ from scipy import ndimage
4591
+ from skimage import measure
4592
+
4593
+ for class_id in np.unique(label_mask):
4594
+ if class_id == 0:
4595
+ continue
4596
+
4597
+ class_mask = (label_mask == class_id).astype(np.uint8)
4598
+ labeled_array, num_features = ndimage.label(class_mask)
4599
+
4600
+ for instance_id in range(1, num_features + 1):
4601
+ instance_mask = (labeled_array == instance_id).astype(
4602
+ np.uint8
4603
+ )
4604
+ coords = np.argwhere(instance_mask)
4605
+
4606
+ if len(coords) == 0:
4607
+ continue
4608
+
4609
+ ymin, xmin = coords.min(axis=0)
4610
+ ymax, xmax = coords.max(axis=0)
4611
+
4612
+ bbox = [
4613
+ int(xmin),
4614
+ int(ymin),
4615
+ int(xmax - xmin),
4616
+ int(ymax - ymin),
4617
+ ]
4618
+ area = int(np.sum(instance_mask))
4619
+
4620
+ # Find contours for segmentation
4621
+ contours = measure.find_contours(instance_mask, 0.5)
4622
+ segmentation = []
4623
+ for contour in contours:
4624
+ contour = np.flip(contour, axis=1)
4625
+ segmentation_points = contour.ravel().tolist()
4626
+ if len(segmentation_points) >= 6:
4627
+ segmentation.append(segmentation_points)
4628
+
4629
+ if segmentation:
4630
+ stats["coco_data"]["annotations"].append(
4631
+ {
4632
+ "id": int(coco_ann_id),
4633
+ "image_id": int(image_id),
4634
+ "category_id": int(class_id),
4635
+ "bbox": bbox,
4636
+ "area": area,
4637
+ "segmentation": segmentation,
4638
+ "iscrowd": 0,
4639
+ }
4640
+ )
4641
+ coco_ann_id += 1
4642
+
4643
+ elif mask_file is not None and metadata_format == "YOLO":
4644
+ # Create YOLO labels directory if needed
4645
+ labels_dir = os.path.join(
4646
+ os.path.dirname(output_images_dir), "labels"
4647
+ )
4648
+ os.makedirs(labels_dir, exist_ok=True)
4649
+
4650
+ # Generate YOLO annotation file
4651
+ yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4652
+ from scipy import ndimage
4653
+
4654
+ with open(yolo_path, "w") as yolo_file:
4655
+ for class_id in np.unique(label_mask):
4656
+ if class_id == 0:
4657
+ continue
4658
+
4659
+ # Track class for classes.txt
4660
+ class_name = next(
4661
+ (k for k, v in class_to_id.items() if v == class_id),
4662
+ str(class_id),
4663
+ )
4664
+ stats["yolo_classes"].add(class_name)
4665
+
4666
+ class_mask = (label_mask == class_id).astype(np.uint8)
4667
+ labeled_array, num_features = ndimage.label(class_mask)
4668
+
4669
+ for instance_id in range(1, num_features + 1):
4670
+ instance_mask = labeled_array == instance_id
4671
+ coords = np.argwhere(instance_mask)
4672
+
4673
+ if len(coords) == 0:
4674
+ continue
4675
+
4676
+ ymin, xmin = coords.min(axis=0)
4677
+ ymax, xmax = coords.max(axis=0)
4678
+
4679
+ # Convert to YOLO format (normalized center coordinates)
4680
+ x_center = ((xmin + xmax) / 2) / tile_size
4681
+ y_center = ((ymin + ymax) / 2) / tile_size
4682
+ width = (xmax - xmin) / tile_size
4683
+ height = (ymax - ymin) / tile_size
4684
+
4685
+ # YOLO uses 0-based class indices
4686
+ yolo_class_id = class_id - 1
4687
+ yolo_file.write(
4688
+ f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4689
+ )
3773
4690
 
3774
4691
  tile_index += 1
3775
4692
  if tile_index >= max_tiles:
@@ -3781,6 +4698,179 @@ def _process_image_mask_pair(
3781
4698
  return stats
3782
4699
 
3783
4700
 
4701
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    Tiles are taken in sorted filename order from ``<output_dir>/images``. The
    mask for each tile is looked up under the same filename in
    ``<output_dir>/masks``. Images are drawn on the top row and masks on the
    bottom row. A placeholder text is shown where a mask file is missing.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Maximum number of tile pairs to display; must be >= 1 (default: 6)
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and a 2 x N array of axes objects,
            where N is the number of tile pairs actually displayed.

    Raises:
        ValueError: If ``num_tiles`` is less than 1, the images directory does
            not exist, or it contains no tile files.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    # Validate arguments before importing matplotlib so bad input fails fast.
    # A negative value would otherwise slice tiles off the END of the listing.
    if num_tiles < 1:
        raise ValueError(f"num_tiles must be at least 1, got {num_tiles}")

    images_dir = os.path.join(output_dir, "images")
    masks_dir = os.path.join(output_dir, "masks")
    if not os.path.isdir(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Skip hidden files (e.g. .DS_Store) and sub-directories; rasterio cannot
    # open them as raster tiles.
    candidate_tiles = sorted(
        name
        for name in os.listdir(images_dir)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(images_dir, name))
    )
    image_tiles = candidate_tiles[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    import matplotlib.pyplot as plt

    # Limit to available tiles
    num_tiles = len(image_tiles)

    # squeeze=False keeps axes 2-D (2 x num_tiles) even for a single tile.
    fig, axes = plt.subplots(2, num_tiles, figsize=figsize, squeeze=False)

    for idx, tile_name in enumerate(image_tiles):
        # Load and display image tile (top row)
        image_path = os.path.join(images_dir, tile_name)
        with rasterio.open(image_path) as src:
            show(src, ax=axes[0, idx], title=f"Image {idx+1}")

        # Load and display the matching mask tile (bottom row)
        mask_path = os.path.join(masks_dir, tile_name)
        if os.path.exists(mask_path):
            with rasterio.open(mask_path) as src:
                show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
        else:
            axes[1, idx].text(
                0.5,
                0.5,
                "Mask not found",
                ha="center",
                va="center",
                transform=axes[1, idx].transAxes,
            )
            axes[1, idx].set_title(f"Mask {idx+1}")

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
4781
+
4782
+
4783
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Display a raster image alongside the same image with vector overlay.

    The left panel shows the raster alone. The right panel shows the raster
    with the vector features drawn on top. The vector layer is reprojected to
    the image CRS when both datasets declare a CRS and the two differ. If
    either CRS is missing, no reprojection is attempted and a warning is
    emitted. In that case the features are drawn in their native coordinates.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector
            metadata: 'image_shape', 'image_crs', 'image_bounds',
            'num_features', 'vector_crs' (after any reprojection) and
            'vector_bounds'.

    Raises:
        ValueError: If ``image_path`` or ``vector_path`` does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    # Validate inputs before importing matplotlib so bad paths fail fast.
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    import warnings

    import matplotlib.pyplot as plt

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Load and display image
    with rasterio.open(image_path) as src:
        # Plot image only
        show(src, ax=ax1, title="Image")

        # Load vector data
        vector_data = gpd.read_file(vector_path)

        # Reproject to the image CRS only when both CRSs are known.
        # GeoDataFrame.to_crs raises on naive geometries (crs is None), and a
        # raster without a CRS gives no target to reproject to.
        if vector_data.crs is None or src.crs is None:
            warnings.warn(
                "Image or vector data has no CRS; overlaying vector features "
                "without reprojection.",
                UserWarning,
            )
        elif vector_data.crs != src.crs:
            vector_data = vector_data.to_crs(src.crs)

        # Plot image with vector overlay
        show(
            src,
            ax=ax2,
            title=f"Image with {len(vector_data)} Vector Features",
        )
        vector_data.plot(
            ax=ax2,
            facecolor=vector_facecolor,
            edgecolor=vector_color,
            linewidth=vector_linewidth,
        )

        # Collect metadata while the dataset is still open
        info = {
            "image_shape": src.shape,
            "image_crs": src.crs,
            "image_bounds": src.bounds,
            "num_features": len(vector_data),
            "vector_crs": vector_data.crs,
            "vector_bounds": vector_data.total_bounds,
        }

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax1, ax2), info
4872
+
4873
+
3784
4874
  def create_overview_image(
3785
4875
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
3786
4876
  ) -> str:
@@ -7675,17 +8765,39 @@ def write_colormap(
7675
8765
 
7676
8766
  def plot_performance_metrics(
7677
8767
  history_path: str,
7678
- figsize: Tuple[int, int] = (15, 5),
8768
+ figsize: Optional[Tuple[int, int]] = None,
7679
8769
  verbose: bool = True,
7680
8770
  save_path: Optional[str] = None,
8771
+ csv_path: Optional[str] = None,
7681
8772
  kwargs: Optional[Dict] = None,
7682
- ) -> None:
7683
- """Plot performance metrics from a history object.
8773
+ ) -> pd.DataFrame:
8774
+ """Plot performance metrics from a training history object and return as DataFrame.
8775
+
8776
+ This function loads training history, plots available metrics (loss, IoU, F1,
8777
+ precision, recall), optionally exports to CSV, and returns all metrics as a
8778
+ pandas DataFrame for further analysis.
7684
8779
 
7685
8780
  Args:
7686
- history_path: The history object to plot.
7687
- figsize: The figure size.
7688
- verbose: Whether to print the best and final metrics.
8781
+ history_path (str): Path to the saved training history (.pth file).
8782
+ figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
8783
+ automatically determined based on number of metrics.
8784
+ verbose (bool): Whether to print best and final metric values. Defaults to True.
8785
+ save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
8786
+ csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
8787
+ kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
8788
+
8789
+ Returns:
8790
+ pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
8791
+ Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
8792
+ 'val_precision', 'val_recall' (depending on availability in history).
8793
+
8794
+ Example:
8795
+ >>> df = plot_performance_metrics(
8796
+ ... 'training_history.pth',
8797
+ ... save_path='metrics_plot.png',
8798
+ ... csv_path='metrics.csv'
8799
+ ... )
8800
+ >>> print(df.head())
7689
8801
  """
7690
8802
  if kwargs is None:
7691
8803
  kwargs = {}
@@ -7695,65 +8807,135 @@ def plot_performance_metrics(
7695
8807
  train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
7696
8808
  val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
7697
8809
  val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
7698
- val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
8810
+ # Support both new (f1) and old (dice) key formats for backward compatibility
8811
+ val_f1_key = (
8812
+ "val_f1s"
8813
+ if "val_f1s" in history
8814
+ else ("val_dices" if "val_dices" in history else "val_dice")
8815
+ )
8816
+ # Add support for precision and recall
8817
+ val_precision_key = (
8818
+ "val_precisions" if "val_precisions" in history else "val_precision"
8819
+ )
8820
+ val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
8821
+
8822
+ # Collect available metrics for plotting
8823
+ available_metrics = []
8824
+ metric_info = {
8825
+ "Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
8826
+ "IoU": (val_iou_key, None, ["Val IoU"]),
8827
+ "F1": (val_f1_key, None, ["Val F1"]),
8828
+ "Precision": (val_precision_key, None, ["Val Precision"]),
8829
+ "Recall": (val_recall_key, None, ["Val Recall"]),
8830
+ }
8831
+
8832
+ for metric_name, (key1, key2, labels) in metric_info.items():
8833
+ if key1 in history or (key2 and key2 in history):
8834
+ available_metrics.append((metric_name, key1, key2, labels))
7699
8835
 
7700
- # Determine number of subplots based on available metrics
7701
- has_dice = val_dice_key in history
7702
- n_plots = 3 if has_dice else 2
7703
- figsize = (15, 5) if has_dice else (10, 5)
8836
+ # Determine number of subplots and figure size
8837
+ n_plots = len(available_metrics)
8838
+ if figsize is None:
8839
+ figsize = (5 * n_plots, 5)
7704
8840
 
7705
- plt.figure(figsize=figsize)
8841
+ # Create DataFrame for all metrics
8842
+ n_epochs = 0
8843
+ df_data = {}
7706
8844
 
7707
- # Plot loss
7708
- plt.subplot(1, n_plots, 1)
8845
+ # Add epochs
8846
+ if "epochs" in history:
8847
+ df_data["epoch"] = history["epochs"]
8848
+ n_epochs = len(history["epochs"])
8849
+ elif train_loss_key in history:
8850
+ n_epochs = len(history[train_loss_key])
8851
+ df_data["epoch"] = list(range(1, n_epochs + 1))
8852
+
8853
+ # Add all available metrics to DataFrame
7709
8854
  if train_loss_key in history:
7710
- plt.plot(history[train_loss_key], label="Train Loss")
8855
+ df_data["train_loss"] = history[train_loss_key]
7711
8856
  if val_loss_key in history:
7712
- plt.plot(history[val_loss_key], label="Val Loss")
7713
- plt.title("Loss")
7714
- plt.xlabel("Epoch")
7715
- plt.ylabel("Loss")
7716
- plt.legend()
7717
- plt.grid(True)
7718
-
7719
- # Plot IoU
7720
- plt.subplot(1, n_plots, 2)
8857
+ df_data["val_loss"] = history[val_loss_key]
7721
8858
  if val_iou_key in history:
7722
- plt.plot(history[val_iou_key], label="Val IoU")
7723
- plt.title("IoU Score")
7724
- plt.xlabel("Epoch")
7725
- plt.ylabel("IoU")
7726
- plt.legend()
7727
- plt.grid(True)
7728
-
7729
- # Plot Dice if available
7730
- if has_dice:
7731
- plt.subplot(1, n_plots, 3)
7732
- plt.plot(history[val_dice_key], label="Val Dice")
7733
- plt.title("Dice Score")
7734
- plt.xlabel("Epoch")
7735
- plt.ylabel("Dice")
7736
- plt.legend()
7737
- plt.grid(True)
8859
+ df_data["val_iou"] = history[val_iou_key]
8860
+ if val_f1_key in history:
8861
+ df_data["val_f1"] = history[val_f1_key]
8862
+ if val_precision_key in history:
8863
+ df_data["val_precision"] = history[val_precision_key]
8864
+ if val_recall_key in history:
8865
+ df_data["val_recall"] = history[val_recall_key]
8866
+
8867
+ # Create DataFrame
8868
+ df = pd.DataFrame(df_data)
8869
+
8870
+ # Export to CSV if requested
8871
+ if csv_path:
8872
+ df.to_csv(csv_path, index=False)
8873
+ if verbose:
8874
+ print(f"Metrics exported to: {csv_path}")
8875
+
8876
+ # Create plots
8877
+ if n_plots > 0:
8878
+ fig, axes = plt.subplots(1, n_plots, figsize=figsize)
8879
+ if n_plots == 1:
8880
+ axes = [axes]
8881
+
8882
+ for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
8883
+ ax = axes[idx]
8884
+
8885
+ if metric_name == "Loss":
8886
+ # Special handling for loss (has both train and val)
8887
+ if key1 in history:
8888
+ ax.plot(history[key1], label=labels[0])
8889
+ if key2 and key2 in history:
8890
+ ax.plot(history[key2], label=labels[1])
8891
+ else:
8892
+ # Single metric plots
8893
+ if key1 in history:
8894
+ ax.plot(history[key1], label=labels[0])
7738
8895
 
7739
- plt.tight_layout()
8896
+ ax.set_title(metric_name)
8897
+ ax.set_xlabel("Epoch")
8898
+ ax.set_ylabel(metric_name)
8899
+ ax.legend()
8900
+ ax.grid(True)
7740
8901
 
7741
- if save_path:
7742
- if "dpi" not in kwargs:
7743
- kwargs["dpi"] = 150
7744
- if "bbox_inches" not in kwargs:
7745
- kwargs["bbox_inches"] = "tight"
7746
- plt.savefig(save_path, **kwargs)
8902
+ plt.tight_layout()
7747
8903
 
7748
- plt.show()
8904
+ if save_path:
8905
+ if "dpi" not in kwargs:
8906
+ kwargs["dpi"] = 150
8907
+ if "bbox_inches" not in kwargs:
8908
+ kwargs["bbox_inches"] = "tight"
8909
+ plt.savefig(save_path, **kwargs)
7749
8910
 
8911
+ plt.show()
8912
+
8913
+ # Print summary statistics
7750
8914
  if verbose:
8915
+ print("\n=== Performance Metrics Summary ===")
7751
8916
  if val_iou_key in history:
7752
- print(f"Best IoU: {max(history[val_iou_key]):.4f}")
7753
- print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
7754
- if val_dice_key in history:
7755
- print(f"Best Dice: {max(history[val_dice_key]):.4f}")
7756
- print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
8917
+ print(
8918
+ f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
8919
+ )
8920
+ if val_f1_key in history:
8921
+ print(
8922
+ f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
8923
+ )
8924
+ if val_precision_key in history:
8925
+ print(
8926
+ f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
8927
+ )
8928
+ if val_recall_key in history:
8929
+ print(
8930
+ f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
8931
+ )
8932
+ if val_loss_key in history:
8933
+ print(
8934
+ f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
8935
+ )
8936
+ print("===================================\n")
8937
+
8938
+ return df
7757
8939
 
7758
8940
 
7759
8941
  def get_device() -> torch.device: