geoai-py 0.15.0__py2.py3-none-any.whl → 0.17.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -383,6 +383,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
383
383
  return accum_mean / len(files), accum_std / len(files)
384
384
 
385
385
 
386
def calc_iou(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate Intersection over Union (IoU) between ground truth and prediction masks.

    Supports binary and multi-class segmentation. Each input may be a numpy
    array, a PyTorch tensor, an array-like (e.g. nested lists), or a path to a
    raster file that is read with rasterio.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path to a raster file, numpy array, or PyTorch tensor.
            For binary segmentation: any non-zero value counts as foreground.
            For multi-class segmentation: values are class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Same accepted types as ground_truth; must have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class
            segmentation. If None, binary segmentation is assumed. Defaults to None.
        ignore_index (Optional[int], optional): Class index whose score is reported
            as NaN instead of being computed (multi-class only). Defaults to None.
        smooth (float, optional): Smoothing term added to both numerator and
            denominator of the ratio. Defaults to 1e-6.
        band (int, optional): Band index to read from a raster file (1-based).
            Only used when an input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float. Two
            empty masks score 1.0. For multi-class segmentation, an array of
            length num_classes; entries are NaN for the ignored class and for
            classes absent from both masks.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt, pred)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> iou = calc_iou(gt, pred, num_classes=3)
        >>> print(", ".join(f"{v:.4f}" for v in iou))
        0.7500, 0.6667, 0.3333

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt_tensor, pred_tensor)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Using raster file paths (result depends on the rasters)
        >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
        >>> mean_iou = np.nanmean(iou)
    """

    def _to_array(mask):
        """Coerce a path, tensor, or array-like into a numpy array."""
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        if isinstance(mask, torch.Tensor):
            # detach() is required: numpy() raises on tensors that track gradients.
            return mask.detach().cpu().numpy()
        # Leave ndarray subclasses (e.g. masked arrays) untouched.
        return mask if isinstance(mask, np.ndarray) else np.asarray(mask)

    gt = _to_array(ground_truth)
    pred = _to_array(prediction)

    if gt.shape != pred.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {gt.shape} vs prediction {pred.shape}"
        )

    # Binary segmentation
    if num_classes is None:
        gt_mask = gt.astype(bool)
        pred_mask = pred.astype(bool)

        intersection = np.logical_and(gt_mask, pred_mask).sum()
        union = np.logical_or(gt_mask, pred_mask).sum()

        # Both masks empty: they agree everywhere, so report a perfect score.
        if union == 0:
            return 1.0

        return float((intersection + smooth) / (union + smooth))

    # Multi-class segmentation: one-vs-rest IoU for every class index
    iou_per_class = []
    for class_idx in range(num_classes):
        if ignore_index is not None and class_idx == ignore_index:
            iou_per_class.append(np.nan)
            continue

        gt_class = gt == class_idx
        pred_class = pred == class_idx

        intersection = np.logical_and(gt_class, pred_class).sum()
        union = np.logical_or(gt_class, pred_class).sum()

        if union == 0:
            # Class absent from both masks: undefined rather than 0 or 1.
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append(float((intersection + smooth) / (union + smooth)))

    return np.array(iou_per_class)
507
+
508
+
509
def calc_f1_score(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate F1 score between ground truth and prediction masks.

    The F1 score is the harmonic mean of precision and recall:
        F1 = 2 * (precision * recall) / (precision + recall)
    where precision = TP / (TP + FP) and recall = TP / (TP + FN). It is
    evaluated here in the algebraically equivalent count form
        F1 = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)
    so that a perfect prediction scores exactly 1.0 and no intermediate 0/0
    can occur when smooth is 0.

    Supports binary and multi-class segmentation. Each input may be a numpy
    array, a PyTorch tensor, an array-like (e.g. nested lists), or a path to a
    raster file that is read with rasterio.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path to a raster file, numpy array, or PyTorch tensor.
            For binary segmentation: any non-zero value counts as foreground.
            For multi-class segmentation: values are class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Same accepted types as ground_truth; must have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class
            segmentation. If None, binary segmentation is assumed. Defaults to None.
        ignore_index (Optional[int], optional): Class index whose score is reported
            as NaN instead of being computed (multi-class only). Defaults to None.
        smooth (float, optional): Smoothing term added to both numerator and
            denominator of the ratio. Defaults to 1e-6.
        band (int, optional): Band index to read from a raster file (1-based).
            Only used when an input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float. Two
            empty masks score 1.0. For multi-class segmentation, an array of
            length num_classes; entries are NaN for the ignored class and for
            classes absent from both masks.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt, pred)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> f1 = calc_f1_score(gt, pred, num_classes=3)
        >>> print(", ".join(f"{v:.4f}" for v in f1))
        0.8571, 0.8000, 0.5000

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Using raster file paths (result depends on the rasters)
        >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
        >>> mean_f1 = np.nanmean(f1)
    """

    def _to_array(mask):
        """Coerce a path, tensor, or array-like into a numpy array."""
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        if isinstance(mask, torch.Tensor):
            # detach() is required: numpy() raises on tensors that track gradients.
            return mask.detach().cpu().numpy()
        # Leave ndarray subclasses (e.g. masked arrays) untouched.
        return mask if isinstance(mask, np.ndarray) else np.asarray(mask)

    def _f1_from_masks(gt_mask, pred_mask):
        """Return F1 for two boolean masks, or None when both are empty."""
        tp = np.logical_and(gt_mask, pred_mask).sum()
        fp = np.logical_and(~gt_mask, pred_mask).sum()
        fn = np.logical_and(gt_mask, ~pred_mask).sum()

        if tp + fp + fn == 0:
            return None

        return float((2 * tp + smooth) / (2 * tp + fp + fn + smooth))

    gt = _to_array(ground_truth)
    pred = _to_array(prediction)

    if gt.shape != pred.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {gt.shape} vs prediction {pred.shape}"
        )

    # Binary segmentation
    if num_classes is None:
        score = _f1_from_masks(gt.astype(bool), pred.astype(bool))
        # Both masks empty: they agree everywhere, so report a perfect score.
        return 1.0 if score is None else score

    # Multi-class segmentation: one-vs-rest F1 for every class index
    f1_per_class = []
    for class_idx in range(num_classes):
        if ignore_index is not None and class_idx == ignore_index:
            f1_per_class.append(np.nan)
            continue

        score = _f1_from_masks(gt == class_idx, pred == class_idx)
        # Class absent from both masks: undefined rather than 0 or 1.
        f1_per_class.append(np.nan if score is None else score)

    return np.array(f1_per_class)
645
+
646
+
647
def calc_segmentation_metrics(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    metrics: Optional[List[str]] = None,
    band: int = 1,
) -> Dict[str, Union[float, np.ndarray]]:
    """
    Calculate multiple segmentation metrics between ground truth and prediction masks.

    Convenience wrapper around calc_iou and calc_f1_score. Raster file inputs
    are read once here and the resulting arrays are shared by every requested
    metric.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path to a raster file, numpy array, or PyTorch tensor.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Same accepted types as ground_truth; should have the same shape.
        num_classes (Optional[int], optional): Number of classes for multi-class
            segmentation. If None, binary segmentation is assumed. Defaults to None.
        ignore_index (Optional[int], optional): Class index to ignore in computation.
            Defaults to None.
        smooth (float, optional): Smoothing factor forwarded to each metric.
            Defaults to 1e-6.
        metrics (Optional[List[str]], optional): Metrics to calculate.
            Options: "iou", "f1". Names other than these are ignored.
            Defaults to None, which means ["iou", "f1"].
        band (int, optional): Band index to read from a raster file (1-based).
            Only used when an input is a file path. Defaults to 1.

    Returns:
        Dict[str, Union[float, np.ndarray]]: Dictionary of the computed metrics.
            Keys are the requested metric names ("iou", "f1"). For binary
            segmentation the values are floats. For multi-class segmentation the
            values are per-class numpy arrays and the dictionary additionally
            holds "mean_iou" / "mean_f1": the mean over non-NaN classes, or 0.0
            when every class is NaN.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt, pred)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833
        >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
        Mean F1: 0.7190
        >>> print(", ".join(f"{v:.4f}" for v in metrics['iou']))
        0.7500, 0.6667, 0.3333

        >>> # Calculate only IoU
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833

        >>> # Using raster file paths (results depend on the rasters)
        >>> metrics = calc_segmentation_metrics(
        ...     "ground_truth.tif", "prediction.tif", num_classes=3
        ... )
        >>> mean_iou, mean_f1 = metrics["mean_iou"], metrics["mean_f1"]
    """
    if metrics is None:
        metrics = ["iou", "f1"]

    # Fixed evaluation order keeps the result keys in a stable order.
    requested = [name for name in ("iou", "f1") if name in metrics]
    if not requested:
        return {}

    def _read_if_path(mask):
        """Read the requested band when given a raster path; else pass through."""
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        return mask

    # Read rasters once so each metric does not reopen the same files.
    ground_truth = _read_if_path(ground_truth)
    prediction = _read_if_path(prediction)

    metric_functions = {"iou": calc_iou, "f1": calc_f1_score}
    results = {}

    for name in requested:
        scores = metric_functions[name](
            ground_truth,
            prediction,
            num_classes=num_classes,
            ignore_index=ignore_index,
            smooth=smooth,
            band=band,
        )
        results[name] = scores

        # Per-class arrays also get a NaN-aware mean.
        if num_classes is not None and isinstance(scores, np.ndarray):
            valid_scores = scores[~np.isnan(scores)]
            results[f"mean_{name}"] = (
                float(np.mean(valid_scores)) if len(valid_scores) > 0 else 0.0
            )

    return results
772
+
773
+
386
774
  def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
387
775
  """Convert a dictionary to a xarray DataArray. The dictionary should contain the
388
776
  following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
@@ -2605,7 +2993,7 @@ def batch_vector_to_raster(
2605
2993
  def export_geotiff_tiles(
2606
2994
  in_raster,
2607
2995
  out_folder,
2608
- in_class_data,
2996
+ in_class_data=None,
2609
2997
  tile_size=256,
2610
2998
  stride=128,
2611
2999
  class_value_field="class",
@@ -2623,7 +3011,8 @@ def export_geotiff_tiles(
2623
3011
  Args:
2624
3012
  in_raster (str): Path to input raster image
2625
3013
  out_folder (str): Path to output folder
2626
- in_class_data (str): Path to classification data - can be vector file or raster
3014
+ in_class_data (str, optional): Path to classification data - can be vector file or raster.
3015
+ If None, only image tiles will be exported without labels. Defaults to None.
2627
3016
  tile_size (int): Size of tiles in pixels (square)
2628
3017
  stride (int): Step size between tiles
2629
3018
  class_value_field (str): Field containing class values (for vector data)
@@ -2644,36 +3033,42 @@ def export_geotiff_tiles(
2644
3033
  os.makedirs(out_folder, exist_ok=True)
2645
3034
  image_dir = os.path.join(out_folder, "images")
2646
3035
  os.makedirs(image_dir, exist_ok=True)
2647
- label_dir = os.path.join(out_folder, "labels")
2648
- os.makedirs(label_dir, exist_ok=True)
2649
3036
 
2650
- # Create annotation directory based on metadata format
2651
- if metadata_format in ["PASCAL_VOC", "COCO"]:
2652
- ann_dir = os.path.join(out_folder, "annotations")
2653
- os.makedirs(ann_dir, exist_ok=True)
3037
+ # Only create label and annotation directories if class data is provided
3038
+ if in_class_data is not None:
3039
+ label_dir = os.path.join(out_folder, "labels")
3040
+ os.makedirs(label_dir, exist_ok=True)
2654
3041
 
2655
- # Initialize COCO annotations dictionary
2656
- if metadata_format == "COCO":
2657
- coco_annotations = {"images": [], "annotations": [], "categories": []}
2658
- ann_id = 0
3042
+ # Create annotation directory based on metadata format
3043
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3044
+ ann_dir = os.path.join(out_folder, "annotations")
3045
+ os.makedirs(ann_dir, exist_ok=True)
2659
3046
 
2660
- # Determine if class data is raster or vector
3047
+ # Initialize COCO annotations dictionary
3048
+ if metadata_format == "COCO":
3049
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3050
+ ann_id = 0
3051
+
3052
+ # Determine if class data is raster or vector (only if class data provided)
2661
3053
  is_class_data_raster = False
2662
- if isinstance(in_class_data, str):
2663
- file_ext = Path(in_class_data).suffix.lower()
2664
- # Common raster extensions
2665
- if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2666
- try:
2667
- with rasterio.open(in_class_data) as src:
2668
- is_class_data_raster = True
3054
+ if in_class_data is not None:
3055
+ if isinstance(in_class_data, str):
3056
+ file_ext = Path(in_class_data).suffix.lower()
3057
+ # Common raster extensions
3058
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
3059
+ try:
3060
+ with rasterio.open(in_class_data) as src:
3061
+ is_class_data_raster = True
3062
+ if not quiet:
3063
+ print(f"Detected in_class_data as raster: {in_class_data}")
3064
+ print(f"Raster CRS: {src.crs}")
3065
+ print(f"Raster dimensions: {src.width} x {src.height}")
3066
+ except Exception:
3067
+ is_class_data_raster = False
2669
3068
  if not quiet:
2670
- print(f"Detected in_class_data as raster: {in_class_data}")
2671
- print(f"Raster CRS: {src.crs}")
2672
- print(f"Raster dimensions: {src.width} x {src.height}")
2673
- except Exception:
2674
- is_class_data_raster = False
2675
- if not quiet:
2676
- print(f"Unable to open {in_class_data} as raster, trying as vector")
3069
+ print(
3070
+ f"Unable to open {in_class_data} as raster, trying as vector"
3071
+ )
2677
3072
 
2678
3073
  # Open the input raster
2679
3074
  with rasterio.open(in_raster) as src:
@@ -2693,10 +3088,10 @@ def export_geotiff_tiles(
2693
3088
  if max_tiles is None:
2694
3089
  max_tiles = total_tiles
2695
3090
 
2696
- # Process classification data
3091
+ # Process classification data (only if class data provided)
2697
3092
  class_to_id = {}
2698
3093
 
2699
- if is_class_data_raster:
3094
+ if in_class_data is not None and is_class_data_raster:
2700
3095
  # Load raster class data
2701
3096
  with rasterio.open(in_class_data) as class_src:
2702
3097
  # Check if raster CRS matches
@@ -2740,7 +3135,7 @@ def export_geotiff_tiles(
2740
3135
  "supercategory": "object",
2741
3136
  }
2742
3137
  )
2743
- else:
3138
+ elif in_class_data is not None:
2744
3139
  # Load vector class data
2745
3140
  try:
2746
3141
  gdf = gpd.read_file(in_class_data)
@@ -2862,8 +3257,8 @@ def export_geotiff_tiles(
2862
3257
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
2863
3258
  has_features = False
2864
3259
 
2865
- # Process classification data to create labels
2866
- if is_class_data_raster:
3260
+ # Process classification data to create labels (only if class data provided)
3261
+ if in_class_data is not None and is_class_data_raster:
2867
3262
  # For raster class data
2868
3263
  with rasterio.open(in_class_data) as class_src:
2869
3264
  # Calculate window in class raster
@@ -2913,7 +3308,7 @@ def export_geotiff_tiles(
2913
3308
  except Exception as e:
2914
3309
  pbar.write(f"Error reading class raster window: {e}")
2915
3310
  stats["errors"] += 1
2916
- else:
3311
+ elif in_class_data is not None:
2917
3312
  # For vector class data
2918
3313
  # Find features that intersect with window
2919
3314
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -2956,8 +3351,8 @@ def export_geotiff_tiles(
2956
3351
  pbar.write(f"Error rasterizing feature {idx}: {e}")
2957
3352
  stats["errors"] += 1
2958
3353
 
2959
- # Skip tile if no features and skip_empty_tiles is True
2960
- if skip_empty_tiles and not has_features:
3354
+ # Skip tile if no features and skip_empty_tiles is True (only when class data provided)
3355
+ if in_class_data is not None and skip_empty_tiles and not has_features:
2961
3356
  pbar.update(1)
2962
3357
  tile_index += 1
2963
3358
  continue
@@ -2988,33 +3383,35 @@ def export_geotiff_tiles(
2988
3383
  pbar.write(f"ERROR saving image GeoTIFF: {e}")
2989
3384
  stats["errors"] += 1
2990
3385
 
2991
- # Create profile for label GeoTIFF
2992
- label_profile = {
2993
- "driver": "GTiff",
2994
- "height": tile_size,
2995
- "width": tile_size,
2996
- "count": 1,
2997
- "dtype": "uint8",
2998
- "crs": src.crs,
2999
- "transform": window_transform,
3000
- }
3386
+ # Export label as GeoTIFF (only if class data provided)
3387
+ if in_class_data is not None:
3388
+ # Create profile for label GeoTIFF
3389
+ label_profile = {
3390
+ "driver": "GTiff",
3391
+ "height": tile_size,
3392
+ "width": tile_size,
3393
+ "count": 1,
3394
+ "dtype": "uint8",
3395
+ "crs": src.crs,
3396
+ "transform": window_transform,
3397
+ }
3001
3398
 
3002
- # Export label as GeoTIFF
3003
- label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
3004
- try:
3005
- with rasterio.open(label_path, "w", **label_profile) as dst:
3006
- dst.write(label_mask.astype(np.uint8), 1)
3399
+ label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
3400
+ try:
3401
+ with rasterio.open(label_path, "w", **label_profile) as dst:
3402
+ dst.write(label_mask.astype(np.uint8), 1)
3007
3403
 
3008
- if has_features:
3009
- stats["tiles_with_features"] += 1
3010
- stats["feature_pixels"] += np.count_nonzero(label_mask)
3011
- except Exception as e:
3012
- pbar.write(f"ERROR saving label GeoTIFF: {e}")
3013
- stats["errors"] += 1
3404
+ if has_features:
3405
+ stats["tiles_with_features"] += 1
3406
+ stats["feature_pixels"] += np.count_nonzero(label_mask)
3407
+ except Exception as e:
3408
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
3409
+ stats["errors"] += 1
3014
3410
 
3015
3411
  # Create annotations for object detection if using vector class data
3016
3412
  if (
3017
- not is_class_data_raster
3413
+ in_class_data is not None
3414
+ and not is_class_data_raster
3018
3415
  and "gdf" in locals()
3019
3416
  and len(window_features) > 0
3020
3417
  ):
@@ -3209,8 +3606,8 @@ def export_geotiff_tiles(
3209
3606
  # Close progress bar
3210
3607
  pbar.close()
3211
3608
 
3212
- # Save COCO annotations if applicable
3213
- if metadata_format == "COCO":
3609
+ # Save COCO annotations if applicable (only if class data provided)
3610
+ if in_class_data is not None and metadata_format == "COCO":
3214
3611
  try:
3215
3612
  with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3216
3613
  json.dump(coco_annotations, f, indent=2)
@@ -3225,8 +3622,8 @@ def export_geotiff_tiles(
3225
3622
  print(f"ERROR saving COCO annotations: {e}")
3226
3623
  stats["errors"] += 1
3227
3624
 
3228
- # Save YOLO classes file if applicable
3229
- if metadata_format == "YOLO":
3625
+ # Save YOLO classes file if applicable (only if class data provided)
3626
+ if in_class_data is not None and metadata_format == "YOLO":
3230
3627
  try:
3231
3628
  # Create classes.txt with class names
3232
3629
  classes_path = os.path.join(out_folder, "classes.txt")
@@ -3259,13 +3656,14 @@ def export_geotiff_tiles(
3259
3656
  if not quiet:
3260
3657
  print("\n------- Export Summary -------")
3261
3658
  print(f"Total tiles exported: {stats['total_tiles']}")
3262
- print(
3263
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3264
- )
3265
- if stats["tiles_with_features"] > 0:
3659
+ if in_class_data is not None:
3266
3660
  print(
3267
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3661
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3268
3662
  )
3663
+ if stats["tiles_with_features"] > 0:
3664
+ print(
3665
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3666
+ )
3269
3667
  if stats["errors"] > 0:
3270
3668
  print(f"Errors encountered: {stats['errors']}")
3271
3669
  print(f"Output saved to: {out_folder}")
@@ -3274,7 +3672,6 @@ def export_geotiff_tiles(
3274
3672
  if stats["total_tiles"] > 0:
3275
3673
  print("\n------- Georeference Verification -------")
3276
3674
  sample_image = os.path.join(image_dir, f"tile_0.tif")
3277
- sample_label = os.path.join(label_dir, f"tile_0.tif")
3278
3675
 
3279
3676
  if os.path.exists(sample_image):
3280
3677
  try:
@@ -3290,19 +3687,22 @@ def export_geotiff_tiles(
3290
3687
  except Exception as e:
3291
3688
  print(f"Error verifying image georeference: {e}")
3292
3689
 
3293
- if os.path.exists(sample_label):
3294
- try:
3295
- with rasterio.open(sample_label) as lbl:
3296
- print(f"Label CRS: {lbl.crs}")
3297
- print(f"Label transform: {lbl.transform}")
3298
- print(
3299
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3300
- )
3301
- print(
3302
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3303
- )
3304
- except Exception as e:
3305
- print(f"Error verifying label georeference: {e}")
3690
+ # Only verify label if class data was provided
3691
+ if in_class_data is not None:
3692
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
3693
+ if os.path.exists(sample_label):
3694
+ try:
3695
+ with rasterio.open(sample_label) as lbl:
3696
+ print(f"Label CRS: {lbl.crs}")
3697
+ print(f"Label transform: {lbl.transform}")
3698
+ print(
3699
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3700
+ )
3701
+ print(
3702
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3703
+ )
3704
+ except Exception as e:
3705
+ print(f"Error verifying label georeference: {e}")
3306
3706
 
3307
3707
  # Return statistics dictionary for further processing if needed
3308
3708
  return stats
@@ -3323,33 +3723,38 @@ def export_geotiff_tiles_batch(
3323
3723
  skip_empty_tiles=False,
3324
3724
  image_extensions=None,
3325
3725
  mask_extensions=None,
3326
- match_by_name=True,
3726
+ match_by_name=False,
3327
3727
  metadata_format="PASCAL_VOC",
3328
3728
  ) -> Dict[str, Any]:
3329
3729
  """
3330
- Export georeferenced GeoTIFF tiles from images and masks.
3730
+ Export georeferenced GeoTIFF tiles from images and optionally masks.
3731
+
3732
+ This function supports four modes:
3733
+ 1. Images only (no masks) - when neither masks_file nor masks_folder is provided
3734
+ 2. Single vector file covering all images (masks_file parameter)
3735
+ 3. Multiple vector files, one per image (masks_folder parameter)
3736
+ 4. Multiple raster mask files (masks_folder parameter)
3331
3737
 
3332
- This function supports three mask input modes:
3333
- 1. Single vector file covering all images (masks_file parameter)
3334
- 2. Multiple vector files, one per image (masks_folder parameter)
3335
- 3. Multiple raster mask files (masks_folder parameter)
3738
+ For mode 1 (images only), only image tiles will be exported without labels.
3336
3739
 
3337
- For mode 1 (single vector file), specify masks_file path. The function will
3740
+ For mode 2 (single vector file), specify masks_file path. The function will
3338
3741
  use spatial intersection to determine which features apply to each image.
3339
3742
 
3340
- For mode 2/3 (multiple mask files), specify masks_folder path. Images and masks
3743
+ For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
3341
3744
  are paired either by matching filenames (match_by_name=True) or by sorted order
3342
3745
  (match_by_name=False).
3343
3746
 
3344
- All image tiles are saved to a single 'images' folder and all mask tiles to a
3345
- single 'masks' folder within the output directory.
3747
+ All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
3748
+ to a single 'masks' folder within the output directory.
3346
3749
 
3347
3750
  Args:
3348
3751
  images_folder (str): Path to folder containing raster images
3349
3752
  masks_folder (str, optional): Path to folder containing classification masks/vectors.
3350
- Use this for multiple mask files (one per image or raster masks).
3753
+ Use this for multiple mask files (one per image or raster masks). If not provided
3754
+ and masks_file is also not provided, only image tiles will be exported.
3351
3755
  masks_file (str, optional): Path to a single vector file covering all images.
3352
- Use this for a single GeoJSON/Shapefile that covers multiple images.
3756
+ Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
3757
+ and masks_folder is also not provided, only image tiles will be exported.
3353
3758
  output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3354
3759
  subfolder in images_folder.
3355
3760
  tile_size (int): Size of tiles in pixels (square)
@@ -3373,10 +3778,15 @@ def export_geotiff_tiles_batch(
3373
3778
 
3374
3779
  Raises:
3375
3780
  ValueError: If no images found, or if masks_folder and masks_file are both specified,
3376
- or if neither is specified, or if counts don't match when using masks_folder with
3377
- match_by_name=False.
3781
+ or if masks_file does not exist, or if no image-mask pairs match by filename, or if counts don't match when using masks_folder with match_by_name=False.
3378
3782
 
3379
3783
  Examples:
3784
+ # Images only (no masks)
3785
+ >>> stats = export_geotiff_tiles_batch(
3786
+ ... images_folder='data/images',
3787
+ ... output_folder='output/tiles'
3788
+ ... )
3789
+
3380
3790
  # Single vector file covering all images
3381
3791
  >>> stats = export_geotiff_tiles_batch(
3382
3792
  ... images_folder='data/images',
@@ -3411,11 +3821,6 @@ def export_geotiff_tiles_batch(
3411
3821
  "Cannot specify both masks_folder and masks_file. Please use only one."
3412
3822
  )
3413
3823
 
3414
- if masks_folder is None and masks_file is None:
3415
- raise ValueError(
3416
- "Must specify either masks_folder or masks_file for mask data source."
3417
- )
3418
-
3419
3824
  # Default output folder if not specified
3420
3825
  if output_folder is None:
3421
3826
  output_folder = os.path.join(images_folder, "tiles")
@@ -3446,22 +3851,37 @@ def export_geotiff_tiles_batch(
3446
3851
  # Create output folder structure
3447
3852
  os.makedirs(output_folder, exist_ok=True)
3448
3853
  output_images_dir = os.path.join(output_folder, "images")
3449
- output_masks_dir = os.path.join(output_folder, "masks")
3450
3854
  os.makedirs(output_images_dir, exist_ok=True)
3451
- os.makedirs(output_masks_dir, exist_ok=True)
3452
3855
 
3453
- # Create annotation directory based on metadata format
3454
- if metadata_format in ["PASCAL_VOC", "COCO"]:
3856
+ # Only create masks directory if masks are provided
3857
+ output_masks_dir = None
3858
+ if masks_folder is not None or masks_file is not None:
3859
+ output_masks_dir = os.path.join(output_folder, "masks")
3860
+ os.makedirs(output_masks_dir, exist_ok=True)
3861
+
3862
+ # Create annotation directory based on metadata format (only if masks are provided)
3863
+ ann_dir = None
3864
+ if (masks_folder is not None or masks_file is not None) and metadata_format in [
3865
+ "PASCAL_VOC",
3866
+ "COCO",
3867
+ ]:
3455
3868
  ann_dir = os.path.join(output_folder, "annotations")
3456
3869
  os.makedirs(ann_dir, exist_ok=True)
3457
3870
 
3458
- # Initialize COCO annotations dictionary
3871
+ # Initialize COCO annotations dictionary (only if masks are provided)
3459
3872
  coco_annotations = None
3460
- if metadata_format == "COCO":
3873
+ if (
3874
+ masks_folder is not None or masks_file is not None
3875
+ ) and metadata_format == "COCO":
3461
3876
  coco_annotations = {"images": [], "annotations": [], "categories": []}
3462
3877
 
3463
- # Initialize YOLO class set
3464
- yolo_classes = set() if metadata_format == "YOLO" else None
3878
+ # Initialize YOLO class set (only if masks are provided)
3879
+ yolo_classes = (
3880
+ set()
3881
+ if (masks_folder is not None or masks_file is not None)
3882
+ and metadata_format == "YOLO"
3883
+ else None
3884
+ )
3465
3885
 
3466
3886
  # Get list of image files
3467
3887
  image_files = []
@@ -3479,10 +3899,16 @@ def export_geotiff_tiles_batch(
3479
3899
 
3480
3900
  # Handle different mask input modes
3481
3901
  use_single_mask_file = masks_file is not None
3902
+ has_masks = masks_file is not None or masks_folder is not None
3482
3903
  mask_files = []
3483
3904
  image_mask_pairs = []
3484
3905
 
3485
- if use_single_mask_file:
3906
+ if not has_masks:
3907
+ # Images-only mode (mode 1 in the docstring): no masks - create pairs with None for mask
3908
+ for image_file in image_files:
3909
+ image_mask_pairs.append((image_file, None, None))
3910
+
3911
+ elif use_single_mask_file:
3486
3912
  # Mode 1: Single vector file covering all images
3487
3913
  if not os.path.exists(masks_file):
3488
3914
  raise ValueError(f"Mask file not found: {masks_file}")
@@ -3534,10 +3960,21 @@ def export_geotiff_tiles_batch(
3534
3960
  print(f"Warning: No mask found for image {img_base}")
3535
3961
 
3536
3962
  if not image_mask_pairs:
3537
- raise ValueError(
3963
+ # Provide detailed error message with found files
3964
+ image_bases = list(image_dict.keys())
3965
+ mask_bases = list(mask_dict.keys())
3966
+ error_msg = (
3538
3967
  "No matching image-mask pairs found when matching by filename. "
3539
- "Check that image and mask files have matching base names."
3968
+ "Check that image and mask files have matching base names.\n"
3969
+ f"Found {len(image_bases)} image(s): "
3970
+ f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
3971
+ f"{'...' if len(image_bases) > 5 else ''}\n"
3972
+ f"Found {len(mask_bases)} mask(s): "
3973
+ f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
3974
+ f"{'...' if len(mask_bases) > 5 else ''}\n"
3975
+ "Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
3540
3976
  )
3977
+ raise ValueError(error_msg)
3541
3978
 
3542
3979
  else:
3543
3980
  # Match by sorted order
@@ -3564,7 +4001,11 @@ def export_geotiff_tiles_batch(
3564
4001
  }
3565
4002
 
3566
4003
  if not quiet:
3567
- if use_single_mask_file:
4004
+ if not has_masks:
4005
+ print(
4006
+ f"Found {len(image_files)} image files to process (images only, no masks)"
4007
+ )
4008
+ elif use_single_mask_file:
3568
4009
  print(f"Found {len(image_files)} image files to process")
3569
4010
  print(f"Using single mask file: {masks_file}")
3570
4011
  else:
@@ -3593,10 +4034,15 @@ def export_geotiff_tiles_batch(
3593
4034
  if not quiet:
3594
4035
  print(f"\nProcessing: {base_name}")
3595
4036
  print(f" Image: {os.path.basename(image_file)}")
3596
- if use_single_mask_file:
3597
- print(f" Mask: {os.path.basename(mask_file)} (spatially filtered)")
4037
+ if mask_file is not None:
4038
+ if use_single_mask_file:
4039
+ print(
4040
+ f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
4041
+ )
4042
+ else:
4043
+ print(f" Mask: {os.path.basename(mask_file)}")
3598
4044
  else:
3599
- print(f" Mask: {os.path.basename(mask_file)}")
4045
+ print(f" Mask: None (images only)")
3600
4046
 
3601
4047
  # Process the image-mask pair
3602
4048
  tiles_generated = _process_image_mask_pair(
@@ -3718,11 +4164,12 @@ def export_geotiff_tiles_batch(
3718
4164
 
3719
4165
  print(f"Output saved to: {output_folder}")
3720
4166
  print(f" Images: {output_images_dir}")
3721
- print(f" Masks: {output_masks_dir}")
3722
- if metadata_format in ["PASCAL_VOC", "COCO"]:
3723
- print(f" Annotations: {ann_dir}")
3724
- elif metadata_format == "YOLO":
3725
- print(f" Labels: {os.path.join(output_folder, 'labels')}")
4167
+ if output_masks_dir is not None:
4168
+ print(f" Masks: {output_masks_dir}")
4169
+ if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
4170
+ print(f" Annotations: {ann_dir}")
4171
+ elif metadata_format == "YOLO":
4172
+ print(f" Labels: {os.path.join(output_folder, 'labels')}")
3726
4173
 
3727
4174
  # List failed files if any
3728
4175
  if batch_stats["failed_files"]:
@@ -3765,9 +4212,9 @@ def _process_image_mask_pair(
3765
4212
  """
3766
4213
  import warnings
3767
4214
 
3768
- # Determine if mask data is raster or vector
4215
+ # Determine if mask data is raster or vector (only if mask_file is provided)
3769
4216
  is_class_data_raster = False
3770
- if isinstance(mask_file, str):
4217
+ if mask_file is not None and isinstance(mask_file, str):
3771
4218
  file_ext = Path(mask_file).suffix.lower()
3772
4219
  # Common raster extensions
3773
4220
  if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
@@ -3801,10 +4248,10 @@ def _process_image_mask_pair(
3801
4248
  if max_tiles is None:
3802
4249
  max_tiles = total_tiles
3803
4250
 
3804
- # Process classification data
4251
+ # Process classification data (only if mask_file is provided)
3805
4252
  class_to_id = {}
3806
4253
 
3807
- if is_class_data_raster:
4254
+ if mask_file is not None and is_class_data_raster:
3808
4255
  # Load raster class data
3809
4256
  with rasterio.open(mask_file) as class_src:
3810
4257
  # Check if raster CRS matches
@@ -3831,7 +4278,7 @@ def _process_image_mask_pair(
3831
4278
 
3832
4279
  # Create class mapping
3833
4280
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
3834
- else:
4281
+ elif mask_file is not None:
3835
4282
  # Load vector class data
3836
4283
  try:
3837
4284
  if use_single_mask_file and mask_gdf is not None:
@@ -3907,12 +4354,12 @@ def _process_image_mask_pair(
3907
4354
 
3908
4355
  window_bounds = box(minx, miny, maxx, maxy)
3909
4356
 
3910
- # Create label mask
4357
+ # Create label mask (only if mask_file is provided)
3911
4358
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
3912
4359
  has_features = False
3913
4360
 
3914
- # Process classification data to create labels
3915
- if is_class_data_raster:
4361
+ # Process classification data to create labels (only if mask_file is provided)
4362
+ if mask_file is not None and is_class_data_raster:
3916
4363
  # For raster class data
3917
4364
  with rasterio.open(mask_file) as class_src:
3918
4365
  # Get corresponding window in class raster
@@ -3945,7 +4392,7 @@ def _process_image_mask_pair(
3945
4392
  if not quiet:
3946
4393
  print(f"Error reading class raster window: {e}")
3947
4394
  stats["errors"] += 1
3948
- else:
4395
+ elif mask_file is not None:
3949
4396
  # For vector class data
3950
4397
  # Find features that intersect with window
3951
4398
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -3983,8 +4430,8 @@ def _process_image_mask_pair(
3983
4430
  print(f"Error rasterizing feature {idx}: {e}")
3984
4431
  stats["errors"] += 1
3985
4432
 
3986
- # Skip tile if no features and skip_empty_tiles is True
3987
- if skip_empty_tiles and not has_features:
4433
+ # Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
4434
+ if mask_file is not None and skip_empty_tiles and not has_features:
3988
4435
  continue
3989
4436
 
3990
4437
  # Check if we've reached max_tiles before saving
@@ -4021,32 +4468,37 @@ def _process_image_mask_pair(
4021
4468
  print(f"ERROR saving image GeoTIFF: {e}")
4022
4469
  stats["errors"] += 1
4023
4470
 
4024
- # Create profile for label GeoTIFF
4025
- label_profile = {
4026
- "driver": "GTiff",
4027
- "height": tile_size,
4028
- "width": tile_size,
4029
- "count": 1,
4030
- "dtype": "uint8",
4031
- "crs": src.crs,
4032
- "transform": window_transform,
4033
- }
4471
+ # Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
4472
+ if mask_file is not None and output_masks_dir is not None:
4473
+ # Create profile for label GeoTIFF
4474
+ label_profile = {
4475
+ "driver": "GTiff",
4476
+ "height": tile_size,
4477
+ "width": tile_size,
4478
+ "count": 1,
4479
+ "dtype": "uint8",
4480
+ "crs": src.crs,
4481
+ "transform": window_transform,
4482
+ }
4034
4483
 
4035
- # Export label as GeoTIFF
4036
- label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4037
- try:
4038
- with rasterio.open(label_path, "w", **label_profile) as dst:
4039
- dst.write(label_mask.astype(np.uint8), 1)
4484
+ label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4485
+ try:
4486
+ with rasterio.open(label_path, "w", **label_profile) as dst:
4487
+ dst.write(label_mask.astype(np.uint8), 1)
4040
4488
 
4041
- if has_features:
4042
- stats["tiles_with_features"] += 1
4043
- except Exception as e:
4044
- if not quiet:
4045
- print(f"ERROR saving label GeoTIFF: {e}")
4046
- stats["errors"] += 1
4489
+ if has_features:
4490
+ stats["tiles_with_features"] += 1
4491
+ except Exception as e:
4492
+ if not quiet:
4493
+ print(f"ERROR saving label GeoTIFF: {e}")
4494
+ stats["errors"] += 1
4047
4495
 
4048
- # Generate annotation metadata based on format
4049
- if metadata_format == "PASCAL_VOC" and ann_dir:
4496
+ # Generate annotation metadata based on format (only if mask_file is provided)
4497
+ if (
4498
+ mask_file is not None
4499
+ and metadata_format == "PASCAL_VOC"
4500
+ and ann_dir
4501
+ ):
4050
4502
  # Create PASCAL VOC XML annotation
4051
4503
  from lxml import etree as ET
4052
4504
 
@@ -4108,7 +4560,7 @@ def _process_image_mask_pair(
4108
4560
  tree = ET.ElementTree(annotation)
4109
4561
  tree.write(xml_path, pretty_print=True, encoding="utf-8")
4110
4562
 
4111
- elif metadata_format == "COCO":
4563
+ elif mask_file is not None and metadata_format == "COCO":
4112
4564
  # Add COCO image entry
4113
4565
  image_id = int(global_tile_counter + tile_index)
4114
4566
  stats["coco_data"]["images"].append(
@@ -4188,7 +4640,7 @@ def _process_image_mask_pair(
4188
4640
  )
4189
4641
  coco_ann_id += 1
4190
4642
 
4191
- elif metadata_format == "YOLO":
4643
+ elif mask_file is not None and metadata_format == "YOLO":
4192
4644
  # Create YOLO labels directory if needed
4193
4645
  labels_dir = os.path.join(
4194
4646
  os.path.dirname(output_images_dir), "labels"
@@ -8313,17 +8765,39 @@ def write_colormap(
8313
8765
 
8314
8766
  def plot_performance_metrics(
8315
8767
  history_path: str,
8316
- figsize: Tuple[int, int] = (15, 5),
8768
+ figsize: Optional[Tuple[int, int]] = None,
8317
8769
  verbose: bool = True,
8318
8770
  save_path: Optional[str] = None,
8771
+ csv_path: Optional[str] = None,
8319
8772
  kwargs: Optional[Dict] = None,
8320
- ) -> None:
8321
- """Plot performance metrics from a history object.
8773
+ ) -> pd.DataFrame:
8774
+ """Plot performance metrics from a saved training history file and return them as a DataFrame.
8775
+
8776
+ This function loads training history, plots available metrics (loss, IoU, F1,
8777
+ precision, recall), optionally exports to CSV, and returns all metrics as a
8778
+ pandas DataFrame for further analysis.
8322
8779
 
8323
8780
  Args:
8324
- history_path: The history object to plot.
8325
- figsize: The figure size.
8326
- verbose: Whether to print the best and final metrics.
8781
+ history_path (str): Path to the saved training history (.pth file).
8782
+ figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
8783
+ automatically determined based on number of metrics.
8784
+ verbose (bool): Whether to print best and final metric values. Defaults to True.
8785
+ save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
8786
+ csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
8787
+ kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
8788
+
8789
+ Returns:
8790
+ pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
8791
+ Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
8792
+ 'val_precision', 'val_recall' (depending on availability in history).
8793
+
8794
+ Example:
8795
+ >>> df = plot_performance_metrics(
8796
+ ... 'training_history.pth',
8797
+ ... save_path='metrics_plot.png',
8798
+ ... csv_path='metrics.csv'
8799
+ ... )
8800
+ >>> print(df.head())
8327
8801
  """
8328
8802
  if kwargs is None:
8329
8803
  kwargs = {}
@@ -8333,65 +8807,135 @@ def plot_performance_metrics(
8333
8807
  train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
8334
8808
  val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
8335
8809
  val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
8336
- val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
8810
+ # Support both new (f1) and old (dice) key formats for backward compatibility
8811
+ val_f1_key = (
8812
+ "val_f1s"
8813
+ if "val_f1s" in history
8814
+ else ("val_dices" if "val_dices" in history else "val_dice")
8815
+ )
8816
+ # Add support for precision and recall
8817
+ val_precision_key = (
8818
+ "val_precisions" if "val_precisions" in history else "val_precision"
8819
+ )
8820
+ val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
8821
+
8822
+ # Collect available metrics for plotting
8823
+ available_metrics = []
8824
+ metric_info = {
8825
+ "Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
8826
+ "IoU": (val_iou_key, None, ["Val IoU"]),
8827
+ "F1": (val_f1_key, None, ["Val F1"]),
8828
+ "Precision": (val_precision_key, None, ["Val Precision"]),
8829
+ "Recall": (val_recall_key, None, ["Val Recall"]),
8830
+ }
8831
+
8832
+ for metric_name, (key1, key2, labels) in metric_info.items():
8833
+ if key1 in history or (key2 and key2 in history):
8834
+ available_metrics.append((metric_name, key1, key2, labels))
8337
8835
 
8338
- # Determine number of subplots based on available metrics
8339
- has_dice = val_dice_key in history
8340
- n_plots = 3 if has_dice else 2
8341
- figsize = (15, 5) if has_dice else (10, 5)
8836
+ # Determine number of subplots and figure size
8837
+ n_plots = len(available_metrics)
8838
+ if figsize is None:
8839
+ figsize = (5 * n_plots, 5)
8342
8840
 
8343
- plt.figure(figsize=figsize)
8841
+ # Create DataFrame for all metrics
8842
+ n_epochs = 0
8843
+ df_data = {}
8344
8844
 
8345
- # Plot loss
8346
- plt.subplot(1, n_plots, 1)
8845
+ # Add epochs
8846
+ if "epochs" in history:
8847
+ df_data["epoch"] = history["epochs"]
8848
+ n_epochs = len(history["epochs"])
8849
+ elif train_loss_key in history:
8850
+ n_epochs = len(history[train_loss_key])
8851
+ df_data["epoch"] = list(range(1, n_epochs + 1))
8852
+
8853
+ # Add all available metrics to DataFrame
8347
8854
  if train_loss_key in history:
8348
- plt.plot(history[train_loss_key], label="Train Loss")
8855
+ df_data["train_loss"] = history[train_loss_key]
8349
8856
  if val_loss_key in history:
8350
- plt.plot(history[val_loss_key], label="Val Loss")
8351
- plt.title("Loss")
8352
- plt.xlabel("Epoch")
8353
- plt.ylabel("Loss")
8354
- plt.legend()
8355
- plt.grid(True)
8356
-
8357
- # Plot IoU
8358
- plt.subplot(1, n_plots, 2)
8857
+ df_data["val_loss"] = history[val_loss_key]
8359
8858
  if val_iou_key in history:
8360
- plt.plot(history[val_iou_key], label="Val IoU")
8361
- plt.title("IoU Score")
8362
- plt.xlabel("Epoch")
8363
- plt.ylabel("IoU")
8364
- plt.legend()
8365
- plt.grid(True)
8366
-
8367
- # Plot Dice if available
8368
- if has_dice:
8369
- plt.subplot(1, n_plots, 3)
8370
- plt.plot(history[val_dice_key], label="Val Dice")
8371
- plt.title("Dice Score")
8372
- plt.xlabel("Epoch")
8373
- plt.ylabel("Dice")
8374
- plt.legend()
8375
- plt.grid(True)
8859
+ df_data["val_iou"] = history[val_iou_key]
8860
+ if val_f1_key in history:
8861
+ df_data["val_f1"] = history[val_f1_key]
8862
+ if val_precision_key in history:
8863
+ df_data["val_precision"] = history[val_precision_key]
8864
+ if val_recall_key in history:
8865
+ df_data["val_recall"] = history[val_recall_key]
8866
+
8867
+ # Create DataFrame
8868
+ df = pd.DataFrame(df_data)
8869
+
8870
+ # Export to CSV if requested
8871
+ if csv_path:
8872
+ df.to_csv(csv_path, index=False)
8873
+ if verbose:
8874
+ print(f"Metrics exported to: {csv_path}")
8875
+
8876
+ # Create plots
8877
+ if n_plots > 0:
8878
+ fig, axes = plt.subplots(1, n_plots, figsize=figsize)
8879
+ if n_plots == 1:
8880
+ axes = [axes]
8881
+
8882
+ for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
8883
+ ax = axes[idx]
8884
+
8885
+ if metric_name == "Loss":
8886
+ # Special handling for loss (has both train and val)
8887
+ if key1 in history:
8888
+ ax.plot(history[key1], label=labels[0])
8889
+ if key2 and key2 in history:
8890
+ ax.plot(history[key2], label=labels[1])
8891
+ else:
8892
+ # Single metric plots
8893
+ if key1 in history:
8894
+ ax.plot(history[key1], label=labels[0])
8376
8895
 
8377
- plt.tight_layout()
8896
+ ax.set_title(metric_name)
8897
+ ax.set_xlabel("Epoch")
8898
+ ax.set_ylabel(metric_name)
8899
+ ax.legend()
8900
+ ax.grid(True)
8378
8901
 
8379
- if save_path:
8380
- if "dpi" not in kwargs:
8381
- kwargs["dpi"] = 150
8382
- if "bbox_inches" not in kwargs:
8383
- kwargs["bbox_inches"] = "tight"
8384
- plt.savefig(save_path, **kwargs)
8902
+ plt.tight_layout()
8385
8903
 
8386
- plt.show()
8904
+ if save_path:
8905
+ if "dpi" not in kwargs:
8906
+ kwargs["dpi"] = 150
8907
+ if "bbox_inches" not in kwargs:
8908
+ kwargs["bbox_inches"] = "tight"
8909
+ plt.savefig(save_path, **kwargs)
8387
8910
 
8911
+ plt.show()
8912
+
8913
+ # Print summary statistics
8388
8914
  if verbose:
8915
+ print("\n=== Performance Metrics Summary ===")
8389
8916
  if val_iou_key in history:
8390
- print(f"Best IoU: {max(history[val_iou_key]):.4f}")
8391
- print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
8392
- if val_dice_key in history:
8393
- print(f"Best Dice: {max(history[val_dice_key]):.4f}")
8394
- print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
8917
+ print(
8918
+ f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
8919
+ )
8920
+ if val_f1_key in history:
8921
+ print(
8922
+ f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
8923
+ )
8924
+ if val_precision_key in history:
8925
+ print(
8926
+ f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
8927
+ )
8928
+ if val_recall_key in history:
8929
+ print(
8930
+ f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
8931
+ )
8932
+ if val_loss_key in history:
8933
+ print(
8934
+ f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
8935
+ )
8936
+ print("===================================\n")
8937
+
8938
+ return df
8395
8939
 
8396
8940
 
8397
8941
  def get_device() -> torch.device: