geoai-py 0.15.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/change_detection.py +16 -6
- geoai/geoai.py +3 -0
- geoai/train.py +573 -31
- geoai/utils.py +752 -208
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/METADATA +2 -1
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/RECORD +11 -11
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.16.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
@@ -383,6 +383,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
|
|
383
383
|
return accum_mean / len(files), accum_std / len(files)
|
384
384
|
|
385
385
|
|
386
|
+
def calc_iou(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate Intersection over Union (IoU) between ground truth and prediction masks.

    This function computes the IoU metric for segmentation tasks. It supports both
    binary and multi-class segmentation, and can handle numpy arrays (or array-like
    objects such as nested lists), PyTorch tensors, or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor (tensors are detached and moved to CPU).
            For binary segmentation: shape (H, W) with values {0, 1}.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor. Must have the same shape as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation (any non-zero value is foreground).
            Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels
            whose ground-truth value equals ignore_index are excluded from the
            computation entirely (e.g. 255 for unlabeled pixels). In multi-class mode,
            if ignore_index is a valid class index, that class's score is NaN.
            Defaults to None.
        smooth (float, optional): Smoothing factor added to numerator and denominator
            to avoid division by zero. Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score
            (1.0 when both masks are empty). For multi-class segmentation, returns an
            array of IoU scores for each class, with NaN for classes absent from both
            masks or equal to ignore_index.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt, pred)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> iou = calc_iou(gt, pred, num_classes=3)
        >>> print([round(float(x), 4) for x in iou])
        [0.75, 0.6667, 0.3333]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt_tensor, pred_tensor)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Using raster file paths
        >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)  # doctest: +SKIP
        >>> print(f"Mean IoU: {np.nanmean(iou):.4f}")  # doctest: +SKIP
    """

    def _as_array(mask):
        # Read raster paths, detach tensors (they may require grad, in which case
        # .numpy() would fail), and coerce any other array-like to an ndarray.
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    ground_truth = _as_array(ground_truth)
    prediction = _as_array(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Pixels labelled ignore_index in the ground truth do not contribute at all,
    # neither as true positives nor as false positives/negatives of any class.
    if ignore_index is None:
        valid = np.ones(ground_truth.shape, dtype=bool)
    else:
        valid = ground_truth != ignore_index

    # Binary segmentation
    if num_classes is None:
        gt_mask = ground_truth.astype(bool) & valid
        pred_mask = prediction.astype(bool) & valid

        intersection = np.logical_and(gt_mask, pred_mask).sum()
        union = np.logical_or(gt_mask, pred_mask).sum()

        if union == 0:
            # Both masks are empty (intersection <= union): perfect agreement.
            return 1.0

        return float((intersection + smooth) / (union + smooth))

    # Multi-class segmentation
    iou_per_class = []
    for class_idx in range(num_classes):
        # The ignored class itself is reported as NaN
        if ignore_index is not None and class_idx == ignore_index:
            iou_per_class.append(np.nan)
            continue

        # Binary masks for the current class, restricted to valid pixels
        gt_class = (ground_truth == class_idx) & valid
        pred_class = (prediction == class_idx) & valid

        intersection = np.logical_and(gt_class, pred_class).sum()
        union = np.logical_or(gt_class, pred_class).sum()

        if union == 0:
            # Class is absent from both ground truth and prediction
            iou_per_class.append(np.nan)
        else:
            iou_per_class.append((intersection + smooth) / (union + smooth))

    return np.array(iou_per_class)
|
507
|
+
|
508
|
+
|
509
|
+
def calc_f1_score(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate F1 score between ground truth and prediction masks.

    The F1 score is the harmonic mean of precision and recall:
        F1 = 2 * (precision * recall) / (precision + recall)
    where precision = TP / (TP + FP) and recall = TP / (TP + FN). This is
    computed in the algebraically equivalent form
        F1 = 2 * TP / (2 * TP + FP + FN)
    so that smoothing is applied only once and a perfect prediction scores 1.0.

    This function supports both binary and multi-class segmentation, and can handle
    numpy arrays (or array-like objects such as nested lists), PyTorch tensors,
    or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor (tensors are detached and moved to CPU).
            For binary segmentation: shape (H, W) with values {0, 1}.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor. Must have the same shape as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation (any non-zero value is foreground).
            Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels
            whose ground-truth value equals ignore_index are excluded from the
            computation entirely (e.g. 255 for unlabeled pixels). In multi-class mode,
            if ignore_index is a valid class index, that class's score is NaN.
            Defaults to None.
        smooth (float, optional): Smoothing factor added to numerator and denominator
            to avoid division by zero. Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score
            (1.0 when both masks are empty). For multi-class segmentation, returns an
            array of F1 scores for each class, with NaN for classes absent from both
            masks or equal to ignore_index.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt, pred)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> f1 = calc_f1_score(gt, pred, num_classes=3)
        >>> print([round(float(x), 4) for x in f1])
        [0.8571, 0.8, 0.5]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Using raster file paths
        >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)  # doctest: +SKIP
        >>> print(f"Mean F1: {np.nanmean(f1):.4f}")  # doctest: +SKIP
    """

    def _as_array(mask):
        # Read raster paths, detach tensors (they may require grad, in which case
        # .numpy() would fail), and coerce any other array-like to an ndarray.
        if isinstance(mask, (str, os.PathLike)):
            with rasterio.open(mask) as src:
                return src.read(band)
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return np.asarray(mask)

    def _f1_from_counts(tp, fp, fn):
        # 2TP / (2TP + FP + FN) equals the harmonic mean of precision and recall.
        return (2 * tp + smooth) / (2 * tp + fp + fn + smooth)

    ground_truth = _as_array(ground_truth)
    prediction = _as_array(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Pixels labelled ignore_index in the ground truth do not contribute at all.
    if ignore_index is None:
        valid = np.ones(ground_truth.shape, dtype=bool)
    else:
        valid = ground_truth != ignore_index

    # Binary segmentation
    if num_classes is None:
        gt_mask = ground_truth.astype(bool) & valid
        pred_mask = prediction.astype(bool) & valid

        # True Positives, False Positives, False Negatives (valid pixels only)
        tp = np.logical_and(gt_mask, pred_mask).sum()
        fp = np.logical_and(~gt_mask, pred_mask).sum()
        fn = np.logical_and(gt_mask, ~pred_mask).sum()

        if tp + fp + fn == 0:
            # Both masks are empty: perfect agreement (consistent with calc_iou).
            return 1.0

        return float(_f1_from_counts(tp, fp, fn))

    # Multi-class segmentation
    f1_per_class = []
    for class_idx in range(num_classes):
        # The ignored class itself is reported as NaN
        if ignore_index is not None and class_idx == ignore_index:
            f1_per_class.append(np.nan)
            continue

        # Binary masks for the current class, restricted to valid pixels
        gt_class = (ground_truth == class_idx) & valid
        pred_class = (prediction == class_idx) & valid

        tp = np.logical_and(gt_class, pred_class).sum()
        fp = np.logical_and(~gt_class, pred_class).sum()
        fn = np.logical_and(gt_class, ~pred_class).sum()

        if tp + fp + fn == 0:
            # Class is absent from both ground truth and prediction
            f1_per_class.append(np.nan)
        else:
            f1_per_class.append(_f1_from_counts(tp, fp, fn))

    return np.array(f1_per_class)
|
645
|
+
|
646
|
+
|
647
|
+
def calc_segmentation_metrics(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    metrics: Optional[List[str]] = None,
    band: int = 1,
) -> Dict[str, Union[float, np.ndarray]]:
    """
    Calculate multiple segmentation metrics between ground truth and prediction masks.

    This is a convenience wrapper that computes several metrics at once,
    including IoU (Intersection over Union) and F1 score, by delegating to
    calc_iou and calc_f1_score. It supports both binary and multi-class
    segmentation, and can handle numpy arrays, PyTorch tensors, or file paths
    to raster files. Raster files are read only once, even when several
    metrics are requested.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor.
            For binary segmentation: shape (H, W) with values {0, 1}.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str or os.PathLike) to a raster file, numpy array,
            or PyTorch tensor. Should have the same shape as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Index to ignore; forwarded to
            calc_iou and calc_f1_score. Defaults to None.
        smooth (float, optional): Smoothing factor to avoid division by zero.
            Defaults to 1e-6.
        metrics (Optional[List[str]], optional): Metrics to calculate
            (case-insensitive). Options: "iou", "f1". A single string is also
            accepted. Defaults to None, meaning ["iou", "f1"].
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
            Keys are metric names ("iou", "f1"), values are the metric scores.
            For binary segmentation, values are floats.
            For multi-class segmentation, values are numpy arrays with per-class scores.
            Also includes "mean_iou" and "mean_f1" for multi-class segmentation
            (mean computed over valid classes, ignoring NaN values; 0.0 if every
            class is NaN).

    Raises:
        ValueError: If an unsupported metric name is requested.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt, pred)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833
        >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
        Mean F1: 0.7190
        >>> print([round(float(x), 4) for x in metrics['iou']])
        [0.75, 0.6667, 0.3333]

        >>> # Calculate only IoU
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Using raster file paths
        >>> metrics = calc_segmentation_metrics(
        ...     "ground_truth.tif", "prediction.tif", num_classes=3
        ... )  # doctest: +SKIP
    """
    supported = {"iou", "f1"}

    # Avoid a shared mutable default; accept a bare string for convenience.
    if metrics is None:
        metrics = ["iou", "f1"]
    elif isinstance(metrics, str):
        metrics = [metrics]

    requested = {name.lower() for name in metrics}
    unknown = requested - supported
    if unknown:
        # Fail loudly instead of silently returning fewer metrics than asked for.
        raise ValueError(
            f"Unsupported metric(s): {sorted(unknown)}. Supported metrics: {sorted(supported)}"
        )

    # Read raster paths once here so each metric reuses the same arrays instead
    # of re-reading the files from disk.
    if isinstance(ground_truth, (str, os.PathLike)):
        with rasterio.open(ground_truth) as src:
            ground_truth = src.read(band)
    if isinstance(prediction, (str, os.PathLike)):
        with rasterio.open(prediction) as src:
            prediction = src.read(band)

    def _nan_mean(scores: np.ndarray) -> float:
        # Mean over classes with a defined score; 0.0 when none are defined.
        defined = scores[~np.isnan(scores)]
        return float(np.mean(defined)) if defined.size > 0 else 0.0

    results = {}

    # Calculate IoU if requested
    if "iou" in requested:
        iou = calc_iou(
            ground_truth,
            prediction,
            num_classes=num_classes,
            ignore_index=ignore_index,
            smooth=smooth,
            band=band,
        )
        results["iou"] = iou
        if num_classes is not None and isinstance(iou, np.ndarray):
            results["mean_iou"] = _nan_mean(iou)

    # Calculate F1 score if requested
    if "f1" in requested:
        f1 = calc_f1_score(
            ground_truth,
            prediction,
            num_classes=num_classes,
            ignore_index=ignore_index,
            smooth=smooth,
            band=band,
        )
        results["f1"] = f1
        if num_classes is not None and isinstance(f1, np.ndarray):
            results["mean_f1"] = _nan_mean(f1)

    return results
|
772
|
+
|
773
|
+
|
386
774
|
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
|
387
775
|
"""Convert a dictionary to a xarray DataArray. The dictionary should contain the
|
388
776
|
following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
|
@@ -2605,7 +2993,7 @@ def batch_vector_to_raster(
|
|
2605
2993
|
def export_geotiff_tiles(
|
2606
2994
|
in_raster,
|
2607
2995
|
out_folder,
|
2608
|
-
in_class_data,
|
2996
|
+
in_class_data=None,
|
2609
2997
|
tile_size=256,
|
2610
2998
|
stride=128,
|
2611
2999
|
class_value_field="class",
|
@@ -2623,7 +3011,8 @@ def export_geotiff_tiles(
|
|
2623
3011
|
Args:
|
2624
3012
|
in_raster (str): Path to input raster image
|
2625
3013
|
out_folder (str): Path to output folder
|
2626
|
-
in_class_data (str): Path to classification data - can be vector file or raster
|
3014
|
+
in_class_data (str, optional): Path to classification data - can be vector file or raster.
|
3015
|
+
If None, only image tiles will be exported without labels. Defaults to None.
|
2627
3016
|
tile_size (int): Size of tiles in pixels (square)
|
2628
3017
|
stride (int): Step size between tiles
|
2629
3018
|
class_value_field (str): Field containing class values (for vector data)
|
@@ -2644,36 +3033,42 @@ def export_geotiff_tiles(
|
|
2644
3033
|
os.makedirs(out_folder, exist_ok=True)
|
2645
3034
|
image_dir = os.path.join(out_folder, "images")
|
2646
3035
|
os.makedirs(image_dir, exist_ok=True)
|
2647
|
-
label_dir = os.path.join(out_folder, "labels")
|
2648
|
-
os.makedirs(label_dir, exist_ok=True)
|
2649
3036
|
|
2650
|
-
#
|
2651
|
-
if
|
2652
|
-
|
2653
|
-
os.makedirs(
|
3037
|
+
# Only create label and annotation directories if class data is provided
|
3038
|
+
if in_class_data is not None:
|
3039
|
+
label_dir = os.path.join(out_folder, "labels")
|
3040
|
+
os.makedirs(label_dir, exist_ok=True)
|
2654
3041
|
|
2655
|
-
|
2656
|
-
|
2657
|
-
|
2658
|
-
|
3042
|
+
# Create annotation directory based on metadata format
|
3043
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
3044
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
3045
|
+
os.makedirs(ann_dir, exist_ok=True)
|
2659
3046
|
|
2660
|
-
|
3047
|
+
# Initialize COCO annotations dictionary
|
3048
|
+
if metadata_format == "COCO":
|
3049
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
3050
|
+
ann_id = 0
|
3051
|
+
|
3052
|
+
# Determine if class data is raster or vector (only if class data provided)
|
2661
3053
|
is_class_data_raster = False
|
2662
|
-
if
|
2663
|
-
|
2664
|
-
|
2665
|
-
|
2666
|
-
|
2667
|
-
|
2668
|
-
|
3054
|
+
if in_class_data is not None:
|
3055
|
+
if isinstance(in_class_data, str):
|
3056
|
+
file_ext = Path(in_class_data).suffix.lower()
|
3057
|
+
# Common raster extensions
|
3058
|
+
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
3059
|
+
try:
|
3060
|
+
with rasterio.open(in_class_data) as src:
|
3061
|
+
is_class_data_raster = True
|
3062
|
+
if not quiet:
|
3063
|
+
print(f"Detected in_class_data as raster: {in_class_data}")
|
3064
|
+
print(f"Raster CRS: {src.crs}")
|
3065
|
+
print(f"Raster dimensions: {src.width} x {src.height}")
|
3066
|
+
except Exception:
|
3067
|
+
is_class_data_raster = False
|
2669
3068
|
if not quiet:
|
2670
|
-
print(
|
2671
|
-
|
2672
|
-
|
2673
|
-
except Exception:
|
2674
|
-
is_class_data_raster = False
|
2675
|
-
if not quiet:
|
2676
|
-
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
3069
|
+
print(
|
3070
|
+
f"Unable to open {in_class_data} as raster, trying as vector"
|
3071
|
+
)
|
2677
3072
|
|
2678
3073
|
# Open the input raster
|
2679
3074
|
with rasterio.open(in_raster) as src:
|
@@ -2693,10 +3088,10 @@ def export_geotiff_tiles(
|
|
2693
3088
|
if max_tiles is None:
|
2694
3089
|
max_tiles = total_tiles
|
2695
3090
|
|
2696
|
-
# Process classification data
|
3091
|
+
# Process classification data (only if class data provided)
|
2697
3092
|
class_to_id = {}
|
2698
3093
|
|
2699
|
-
if is_class_data_raster:
|
3094
|
+
if in_class_data is not None and is_class_data_raster:
|
2700
3095
|
# Load raster class data
|
2701
3096
|
with rasterio.open(in_class_data) as class_src:
|
2702
3097
|
# Check if raster CRS matches
|
@@ -2740,7 +3135,7 @@ def export_geotiff_tiles(
|
|
2740
3135
|
"supercategory": "object",
|
2741
3136
|
}
|
2742
3137
|
)
|
2743
|
-
|
3138
|
+
elif in_class_data is not None:
|
2744
3139
|
# Load vector class data
|
2745
3140
|
try:
|
2746
3141
|
gdf = gpd.read_file(in_class_data)
|
@@ -2862,8 +3257,8 @@ def export_geotiff_tiles(
|
|
2862
3257
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
2863
3258
|
has_features = False
|
2864
3259
|
|
2865
|
-
# Process classification data to create labels
|
2866
|
-
if is_class_data_raster:
|
3260
|
+
# Process classification data to create labels (only if class data provided)
|
3261
|
+
if in_class_data is not None and is_class_data_raster:
|
2867
3262
|
# For raster class data
|
2868
3263
|
with rasterio.open(in_class_data) as class_src:
|
2869
3264
|
# Calculate window in class raster
|
@@ -2913,7 +3308,7 @@ def export_geotiff_tiles(
|
|
2913
3308
|
except Exception as e:
|
2914
3309
|
pbar.write(f"Error reading class raster window: {e}")
|
2915
3310
|
stats["errors"] += 1
|
2916
|
-
|
3311
|
+
elif in_class_data is not None:
|
2917
3312
|
# For vector class data
|
2918
3313
|
# Find features that intersect with window
|
2919
3314
|
window_features = gdf[gdf.intersects(window_bounds)]
|
@@ -2956,8 +3351,8 @@ def export_geotiff_tiles(
|
|
2956
3351
|
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
2957
3352
|
stats["errors"] += 1
|
2958
3353
|
|
2959
|
-
# Skip tile if no features and skip_empty_tiles is True
|
2960
|
-
if skip_empty_tiles and not has_features:
|
3354
|
+
# Skip tile if no features and skip_empty_tiles is True (only when class data provided)
|
3355
|
+
if in_class_data is not None and skip_empty_tiles and not has_features:
|
2961
3356
|
pbar.update(1)
|
2962
3357
|
tile_index += 1
|
2963
3358
|
continue
|
@@ -2988,33 +3383,35 @@ def export_geotiff_tiles(
|
|
2988
3383
|
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
2989
3384
|
stats["errors"] += 1
|
2990
3385
|
|
2991
|
-
#
|
2992
|
-
|
2993
|
-
|
2994
|
-
|
2995
|
-
|
2996
|
-
|
2997
|
-
|
2998
|
-
|
2999
|
-
|
3000
|
-
|
3386
|
+
# Export label as GeoTIFF (only if class data provided)
|
3387
|
+
if in_class_data is not None:
|
3388
|
+
# Create profile for label GeoTIFF
|
3389
|
+
label_profile = {
|
3390
|
+
"driver": "GTiff",
|
3391
|
+
"height": tile_size,
|
3392
|
+
"width": tile_size,
|
3393
|
+
"count": 1,
|
3394
|
+
"dtype": "uint8",
|
3395
|
+
"crs": src.crs,
|
3396
|
+
"transform": window_transform,
|
3397
|
+
}
|
3001
3398
|
|
3002
|
-
|
3003
|
-
|
3004
|
-
|
3005
|
-
|
3006
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
3399
|
+
label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
|
3400
|
+
try:
|
3401
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
3402
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
3007
3403
|
|
3008
|
-
|
3009
|
-
|
3010
|
-
|
3011
|
-
|
3012
|
-
|
3013
|
-
|
3404
|
+
if has_features:
|
3405
|
+
stats["tiles_with_features"] += 1
|
3406
|
+
stats["feature_pixels"] += np.count_nonzero(label_mask)
|
3407
|
+
except Exception as e:
|
3408
|
+
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
3409
|
+
stats["errors"] += 1
|
3014
3410
|
|
3015
3411
|
# Create annotations for object detection if using vector class data
|
3016
3412
|
if (
|
3017
|
-
not
|
3413
|
+
in_class_data is not None
|
3414
|
+
and not is_class_data_raster
|
3018
3415
|
and "gdf" in locals()
|
3019
3416
|
and len(window_features) > 0
|
3020
3417
|
):
|
@@ -3209,8 +3606,8 @@ def export_geotiff_tiles(
|
|
3209
3606
|
# Close progress bar
|
3210
3607
|
pbar.close()
|
3211
3608
|
|
3212
|
-
# Save COCO annotations if applicable
|
3213
|
-
if metadata_format == "COCO":
|
3609
|
+
# Save COCO annotations if applicable (only if class data provided)
|
3610
|
+
if in_class_data is not None and metadata_format == "COCO":
|
3214
3611
|
try:
|
3215
3612
|
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
3216
3613
|
json.dump(coco_annotations, f, indent=2)
|
@@ -3225,8 +3622,8 @@ def export_geotiff_tiles(
|
|
3225
3622
|
print(f"ERROR saving COCO annotations: {e}")
|
3226
3623
|
stats["errors"] += 1
|
3227
3624
|
|
3228
|
-
# Save YOLO classes file if applicable
|
3229
|
-
if metadata_format == "YOLO":
|
3625
|
+
# Save YOLO classes file if applicable (only if class data provided)
|
3626
|
+
if in_class_data is not None and metadata_format == "YOLO":
|
3230
3627
|
try:
|
3231
3628
|
# Create classes.txt with class names
|
3232
3629
|
classes_path = os.path.join(out_folder, "classes.txt")
|
@@ -3259,13 +3656,14 @@ def export_geotiff_tiles(
|
|
3259
3656
|
if not quiet:
|
3260
3657
|
print("\n------- Export Summary -------")
|
3261
3658
|
print(f"Total tiles exported: {stats['total_tiles']}")
|
3262
|
-
|
3263
|
-
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
3264
|
-
)
|
3265
|
-
if stats["tiles_with_features"] > 0:
|
3659
|
+
if in_class_data is not None:
|
3266
3660
|
print(
|
3267
|
-
f"
|
3661
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
3268
3662
|
)
|
3663
|
+
if stats["tiles_with_features"] > 0:
|
3664
|
+
print(
|
3665
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
3666
|
+
)
|
3269
3667
|
if stats["errors"] > 0:
|
3270
3668
|
print(f"Errors encountered: {stats['errors']}")
|
3271
3669
|
print(f"Output saved to: {out_folder}")
|
@@ -3274,7 +3672,6 @@ def export_geotiff_tiles(
|
|
3274
3672
|
if stats["total_tiles"] > 0:
|
3275
3673
|
print("\n------- Georeference Verification -------")
|
3276
3674
|
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
3277
|
-
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
3278
3675
|
|
3279
3676
|
if os.path.exists(sample_image):
|
3280
3677
|
try:
|
@@ -3290,19 +3687,22 @@ def export_geotiff_tiles(
|
|
3290
3687
|
except Exception as e:
|
3291
3688
|
print(f"Error verifying image georeference: {e}")
|
3292
3689
|
|
3293
|
-
if
|
3294
|
-
|
3295
|
-
|
3296
|
-
|
3297
|
-
|
3298
|
-
|
3299
|
-
f"Label
|
3300
|
-
|
3301
|
-
|
3302
|
-
|
3303
|
-
|
3304
|
-
|
3305
|
-
|
3690
|
+
# Only verify label if class data was provided
|
3691
|
+
if in_class_data is not None:
|
3692
|
+
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
3693
|
+
if os.path.exists(sample_label):
|
3694
|
+
try:
|
3695
|
+
with rasterio.open(sample_label) as lbl:
|
3696
|
+
print(f"Label CRS: {lbl.crs}")
|
3697
|
+
print(f"Label transform: {lbl.transform}")
|
3698
|
+
print(
|
3699
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
3700
|
+
)
|
3701
|
+
print(
|
3702
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
3703
|
+
)
|
3704
|
+
except Exception as e:
|
3705
|
+
print(f"Error verifying label georeference: {e}")
|
3306
3706
|
|
3307
3707
|
# Return statistics dictionary for further processing if needed
|
3308
3708
|
return stats
|
@@ -3323,33 +3723,38 @@ def export_geotiff_tiles_batch(
|
|
3323
3723
|
skip_empty_tiles=False,
|
3324
3724
|
image_extensions=None,
|
3325
3725
|
mask_extensions=None,
|
3326
|
-
match_by_name=
|
3726
|
+
match_by_name=False,
|
3327
3727
|
metadata_format="PASCAL_VOC",
|
3328
3728
|
) -> Dict[str, Any]:
|
3329
3729
|
"""
|
3330
|
-
Export georeferenced GeoTIFF tiles from images and masks.
|
3730
|
+
Export georeferenced GeoTIFF tiles from images and optionally masks.
|
3731
|
+
|
3732
|
+
This function supports four modes:
|
3733
|
+
1. Images only (no masks) - when neither masks_file nor masks_folder is provided
|
3734
|
+
2. Single vector file covering all images (masks_file parameter)
|
3735
|
+
3. Multiple vector files, one per image (masks_folder parameter)
|
3736
|
+
4. Multiple raster mask files (masks_folder parameter)
|
3331
3737
|
|
3332
|
-
|
3333
|
-
1. Single vector file covering all images (masks_file parameter)
|
3334
|
-
2. Multiple vector files, one per image (masks_folder parameter)
|
3335
|
-
3. Multiple raster mask files (masks_folder parameter)
|
3738
|
+
For mode 1 (images only), only image tiles will be exported without labels.
|
3336
3739
|
|
3337
|
-
For mode
|
3740
|
+
For mode 2 (single vector file), specify masks_file path. The function will
|
3338
3741
|
use spatial intersection to determine which features apply to each image.
|
3339
3742
|
|
3340
|
-
For mode
|
3743
|
+
For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
|
3341
3744
|
are paired either by matching filenames (match_by_name=True) or by sorted order
|
3342
3745
|
(match_by_name=False).
|
3343
3746
|
|
3344
|
-
All image tiles are saved to a single 'images' folder and all mask tiles
|
3345
|
-
single 'masks' folder within the output directory.
|
3747
|
+
All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
|
3748
|
+
to a single 'masks' folder within the output directory.
|
3346
3749
|
|
3347
3750
|
Args:
|
3348
3751
|
images_folder (str): Path to folder containing raster images
|
3349
3752
|
masks_folder (str, optional): Path to folder containing classification masks/vectors.
|
3350
|
-
Use this for multiple mask files (one per image or raster masks).
|
3753
|
+
Use this for multiple mask files (one per image or raster masks). If not provided
|
3754
|
+
and masks_file is also not provided, only image tiles will be exported.
|
3351
3755
|
masks_file (str, optional): Path to a single vector file covering all images.
|
3352
|
-
Use this for a single GeoJSON/Shapefile that covers multiple images.
|
3756
|
+
Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
|
3757
|
+
and masks_folder is also not provided, only image tiles will be exported.
|
3353
3758
|
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
3354
3759
|
subfolder in images_folder.
|
3355
3760
|
tile_size (int): Size of tiles in pixels (square)
|
@@ -3373,10 +3778,15 @@ def export_geotiff_tiles_batch(
|
|
3373
3778
|
|
3374
3779
|
Raises:
|
3375
3780
|
ValueError: If no images found, or if masks_folder and masks_file are both specified,
|
3376
|
-
or if
|
3377
|
-
match_by_name=False.
|
3781
|
+
or if counts don't match when using masks_folder with match_by_name=False.
|
3378
3782
|
|
3379
3783
|
Examples:
|
3784
|
+
# Images only (no masks)
|
3785
|
+
>>> stats = export_geotiff_tiles_batch(
|
3786
|
+
... images_folder='data/images',
|
3787
|
+
... output_folder='output/tiles'
|
3788
|
+
... )
|
3789
|
+
|
3380
3790
|
# Single vector file covering all images
|
3381
3791
|
>>> stats = export_geotiff_tiles_batch(
|
3382
3792
|
... images_folder='data/images',
|
@@ -3411,11 +3821,6 @@ def export_geotiff_tiles_batch(
|
|
3411
3821
|
"Cannot specify both masks_folder and masks_file. Please use only one."
|
3412
3822
|
)
|
3413
3823
|
|
3414
|
-
if masks_folder is None and masks_file is None:
|
3415
|
-
raise ValueError(
|
3416
|
-
"Must specify either masks_folder or masks_file for mask data source."
|
3417
|
-
)
|
3418
|
-
|
3419
3824
|
# Default output folder if not specified
|
3420
3825
|
if output_folder is None:
|
3421
3826
|
output_folder = os.path.join(images_folder, "tiles")
|
@@ -3446,22 +3851,37 @@ def export_geotiff_tiles_batch(
|
|
3446
3851
|
# Create output folder structure
|
3447
3852
|
os.makedirs(output_folder, exist_ok=True)
|
3448
3853
|
output_images_dir = os.path.join(output_folder, "images")
|
3449
|
-
output_masks_dir = os.path.join(output_folder, "masks")
|
3450
3854
|
os.makedirs(output_images_dir, exist_ok=True)
|
3451
|
-
os.makedirs(output_masks_dir, exist_ok=True)
|
3452
3855
|
|
3453
|
-
#
|
3454
|
-
|
3856
|
+
# Only create masks directory if masks are provided
|
3857
|
+
output_masks_dir = None
|
3858
|
+
if masks_folder is not None or masks_file is not None:
|
3859
|
+
output_masks_dir = os.path.join(output_folder, "masks")
|
3860
|
+
os.makedirs(output_masks_dir, exist_ok=True)
|
3861
|
+
|
3862
|
+
# Create annotation directory based on metadata format (only if masks are provided)
|
3863
|
+
ann_dir = None
|
3864
|
+
if (masks_folder is not None or masks_file is not None) and metadata_format in [
|
3865
|
+
"PASCAL_VOC",
|
3866
|
+
"COCO",
|
3867
|
+
]:
|
3455
3868
|
ann_dir = os.path.join(output_folder, "annotations")
|
3456
3869
|
os.makedirs(ann_dir, exist_ok=True)
|
3457
3870
|
|
3458
|
-
# Initialize COCO annotations dictionary
|
3871
|
+
# Initialize COCO annotations dictionary (only if masks are provided)
|
3459
3872
|
coco_annotations = None
|
3460
|
-
if
|
3873
|
+
if (
|
3874
|
+
masks_folder is not None or masks_file is not None
|
3875
|
+
) and metadata_format == "COCO":
|
3461
3876
|
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
3462
3877
|
|
3463
|
-
# Initialize YOLO class set
|
3464
|
-
yolo_classes =
|
3878
|
+
# Initialize YOLO class set (only if masks are provided)
|
3879
|
+
yolo_classes = (
|
3880
|
+
set()
|
3881
|
+
if (masks_folder is not None or masks_file is not None)
|
3882
|
+
and metadata_format == "YOLO"
|
3883
|
+
else None
|
3884
|
+
)
|
3465
3885
|
|
3466
3886
|
# Get list of image files
|
3467
3887
|
image_files = []
|
@@ -3479,10 +3899,16 @@ def export_geotiff_tiles_batch(
|
|
3479
3899
|
|
3480
3900
|
# Handle different mask input modes
|
3481
3901
|
use_single_mask_file = masks_file is not None
|
3902
|
+
has_masks = masks_file is not None or masks_folder is not None
|
3482
3903
|
mask_files = []
|
3483
3904
|
image_mask_pairs = []
|
3484
3905
|
|
3485
|
-
if
|
3906
|
+
if not has_masks:
|
3907
|
+
# Mode 0: No masks - create pairs with None for mask
|
3908
|
+
for image_file in image_files:
|
3909
|
+
image_mask_pairs.append((image_file, None, None))
|
3910
|
+
|
3911
|
+
elif use_single_mask_file:
|
3486
3912
|
# Mode 1: Single vector file covering all images
|
3487
3913
|
if not os.path.exists(masks_file):
|
3488
3914
|
raise ValueError(f"Mask file not found: {masks_file}")
|
@@ -3534,10 +3960,21 @@ def export_geotiff_tiles_batch(
|
|
3534
3960
|
print(f"Warning: No mask found for image {img_base}")
|
3535
3961
|
|
3536
3962
|
if not image_mask_pairs:
|
3537
|
-
|
3963
|
+
# Provide detailed error message with found files
|
3964
|
+
image_bases = list(image_dict.keys())
|
3965
|
+
mask_bases = list(mask_dict.keys())
|
3966
|
+
error_msg = (
|
3538
3967
|
"No matching image-mask pairs found when matching by filename. "
|
3539
|
-
"Check that image and mask files have matching base names
|
3968
|
+
"Check that image and mask files have matching base names.\n"
|
3969
|
+
f"Found {len(image_bases)} image(s): "
|
3970
|
+
f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
|
3971
|
+
f"{'...' if len(image_bases) > 5 else ''}\n"
|
3972
|
+
f"Found {len(mask_bases)} mask(s): "
|
3973
|
+
f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
|
3974
|
+
f"{'...' if len(mask_bases) > 5 else ''}\n"
|
3975
|
+
"Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
|
3540
3976
|
)
|
3977
|
+
raise ValueError(error_msg)
|
3541
3978
|
|
3542
3979
|
else:
|
3543
3980
|
# Match by sorted order
|
@@ -3564,7 +4001,11 @@ def export_geotiff_tiles_batch(
|
|
3564
4001
|
}
|
3565
4002
|
|
3566
4003
|
if not quiet:
|
3567
|
-
if
|
4004
|
+
if not has_masks:
|
4005
|
+
print(
|
4006
|
+
f"Found {len(image_files)} image files to process (images only, no masks)"
|
4007
|
+
)
|
4008
|
+
elif use_single_mask_file:
|
3568
4009
|
print(f"Found {len(image_files)} image files to process")
|
3569
4010
|
print(f"Using single mask file: {masks_file}")
|
3570
4011
|
else:
|
@@ -3593,10 +4034,15 @@ def export_geotiff_tiles_batch(
|
|
3593
4034
|
if not quiet:
|
3594
4035
|
print(f"\nProcessing: {base_name}")
|
3595
4036
|
print(f" Image: {os.path.basename(image_file)}")
|
3596
|
-
if
|
3597
|
-
|
4037
|
+
if mask_file is not None:
|
4038
|
+
if use_single_mask_file:
|
4039
|
+
print(
|
4040
|
+
f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
|
4041
|
+
)
|
4042
|
+
else:
|
4043
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
3598
4044
|
else:
|
3599
|
-
print(f" Mask:
|
4045
|
+
print(f" Mask: None (images only)")
|
3600
4046
|
|
3601
4047
|
# Process the image-mask pair
|
3602
4048
|
tiles_generated = _process_image_mask_pair(
|
@@ -3718,11 +4164,12 @@ def export_geotiff_tiles_batch(
|
|
3718
4164
|
|
3719
4165
|
print(f"Output saved to: {output_folder}")
|
3720
4166
|
print(f" Images: {output_images_dir}")
|
3721
|
-
|
3722
|
-
|
3723
|
-
|
3724
|
-
|
3725
|
-
|
4167
|
+
if output_masks_dir is not None:
|
4168
|
+
print(f" Masks: {output_masks_dir}")
|
4169
|
+
if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
|
4170
|
+
print(f" Annotations: {ann_dir}")
|
4171
|
+
elif metadata_format == "YOLO":
|
4172
|
+
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
3726
4173
|
|
3727
4174
|
# List failed files if any
|
3728
4175
|
if batch_stats["failed_files"]:
|
@@ -3765,9 +4212,9 @@ def _process_image_mask_pair(
|
|
3765
4212
|
"""
|
3766
4213
|
import warnings
|
3767
4214
|
|
3768
|
-
# Determine if mask data is raster or vector
|
4215
|
+
# Determine if mask data is raster or vector (only if mask_file is provided)
|
3769
4216
|
is_class_data_raster = False
|
3770
|
-
if isinstance(mask_file, str):
|
4217
|
+
if mask_file is not None and isinstance(mask_file, str):
|
3771
4218
|
file_ext = Path(mask_file).suffix.lower()
|
3772
4219
|
# Common raster extensions
|
3773
4220
|
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
@@ -3801,10 +4248,10 @@ def _process_image_mask_pair(
|
|
3801
4248
|
if max_tiles is None:
|
3802
4249
|
max_tiles = total_tiles
|
3803
4250
|
|
3804
|
-
# Process classification data
|
4251
|
+
# Process classification data (only if mask_file is provided)
|
3805
4252
|
class_to_id = {}
|
3806
4253
|
|
3807
|
-
if is_class_data_raster:
|
4254
|
+
if mask_file is not None and is_class_data_raster:
|
3808
4255
|
# Load raster class data
|
3809
4256
|
with rasterio.open(mask_file) as class_src:
|
3810
4257
|
# Check if raster CRS matches
|
@@ -3831,7 +4278,7 @@ def _process_image_mask_pair(
|
|
3831
4278
|
|
3832
4279
|
# Create class mapping
|
3833
4280
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
3834
|
-
|
4281
|
+
elif mask_file is not None:
|
3835
4282
|
# Load vector class data
|
3836
4283
|
try:
|
3837
4284
|
if use_single_mask_file and mask_gdf is not None:
|
@@ -3907,12 +4354,12 @@ def _process_image_mask_pair(
|
|
3907
4354
|
|
3908
4355
|
window_bounds = box(minx, miny, maxx, maxy)
|
3909
4356
|
|
3910
|
-
# Create label mask
|
4357
|
+
# Create label mask (only if mask_file is provided)
|
3911
4358
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
3912
4359
|
has_features = False
|
3913
4360
|
|
3914
|
-
# Process classification data to create labels
|
3915
|
-
if is_class_data_raster:
|
4361
|
+
# Process classification data to create labels (only if mask_file is provided)
|
4362
|
+
if mask_file is not None and is_class_data_raster:
|
3916
4363
|
# For raster class data
|
3917
4364
|
with rasterio.open(mask_file) as class_src:
|
3918
4365
|
# Get corresponding window in class raster
|
@@ -3945,7 +4392,7 @@ def _process_image_mask_pair(
|
|
3945
4392
|
if not quiet:
|
3946
4393
|
print(f"Error reading class raster window: {e}")
|
3947
4394
|
stats["errors"] += 1
|
3948
|
-
|
4395
|
+
elif mask_file is not None:
|
3949
4396
|
# For vector class data
|
3950
4397
|
# Find features that intersect with window
|
3951
4398
|
window_features = gdf[gdf.intersects(window_bounds)]
|
@@ -3983,8 +4430,8 @@ def _process_image_mask_pair(
|
|
3983
4430
|
print(f"Error rasterizing feature {idx}: {e}")
|
3984
4431
|
stats["errors"] += 1
|
3985
4432
|
|
3986
|
-
# Skip tile if no features and skip_empty_tiles is True
|
3987
|
-
if skip_empty_tiles and not has_features:
|
4433
|
+
# Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
|
4434
|
+
if mask_file is not None and skip_empty_tiles and not has_features:
|
3988
4435
|
continue
|
3989
4436
|
|
3990
4437
|
# Check if we've reached max_tiles before saving
|
@@ -4021,32 +4468,37 @@ def _process_image_mask_pair(
|
|
4021
4468
|
print(f"ERROR saving image GeoTIFF: {e}")
|
4022
4469
|
stats["errors"] += 1
|
4023
4470
|
|
4024
|
-
#
|
4025
|
-
|
4026
|
-
|
4027
|
-
|
4028
|
-
|
4029
|
-
|
4030
|
-
|
4031
|
-
|
4032
|
-
|
4033
|
-
|
4471
|
+
# Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
|
4472
|
+
if mask_file is not None and output_masks_dir is not None:
|
4473
|
+
# Create profile for label GeoTIFF
|
4474
|
+
label_profile = {
|
4475
|
+
"driver": "GTiff",
|
4476
|
+
"height": tile_size,
|
4477
|
+
"width": tile_size,
|
4478
|
+
"count": 1,
|
4479
|
+
"dtype": "uint8",
|
4480
|
+
"crs": src.crs,
|
4481
|
+
"transform": window_transform,
|
4482
|
+
}
|
4034
4483
|
|
4035
|
-
|
4036
|
-
|
4037
|
-
|
4038
|
-
|
4039
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
4484
|
+
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
4485
|
+
try:
|
4486
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
4487
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
4040
4488
|
|
4041
|
-
|
4042
|
-
|
4043
|
-
|
4044
|
-
|
4045
|
-
|
4046
|
-
|
4489
|
+
if has_features:
|
4490
|
+
stats["tiles_with_features"] += 1
|
4491
|
+
except Exception as e:
|
4492
|
+
if not quiet:
|
4493
|
+
print(f"ERROR saving label GeoTIFF: {e}")
|
4494
|
+
stats["errors"] += 1
|
4047
4495
|
|
4048
|
-
# Generate annotation metadata based on format
|
4049
|
-
if
|
4496
|
+
# Generate annotation metadata based on format (only if mask_file is provided)
|
4497
|
+
if (
|
4498
|
+
mask_file is not None
|
4499
|
+
and metadata_format == "PASCAL_VOC"
|
4500
|
+
and ann_dir
|
4501
|
+
):
|
4050
4502
|
# Create PASCAL VOC XML annotation
|
4051
4503
|
from lxml import etree as ET
|
4052
4504
|
|
@@ -4108,7 +4560,7 @@ def _process_image_mask_pair(
|
|
4108
4560
|
tree = ET.ElementTree(annotation)
|
4109
4561
|
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
4110
4562
|
|
4111
|
-
elif metadata_format == "COCO":
|
4563
|
+
elif mask_file is not None and metadata_format == "COCO":
|
4112
4564
|
# Add COCO image entry
|
4113
4565
|
image_id = int(global_tile_counter + tile_index)
|
4114
4566
|
stats["coco_data"]["images"].append(
|
@@ -4188,7 +4640,7 @@ def _process_image_mask_pair(
|
|
4188
4640
|
)
|
4189
4641
|
coco_ann_id += 1
|
4190
4642
|
|
4191
|
-
elif metadata_format == "YOLO":
|
4643
|
+
elif mask_file is not None and metadata_format == "YOLO":
|
4192
4644
|
# Create YOLO labels directory if needed
|
4193
4645
|
labels_dir = os.path.join(
|
4194
4646
|
os.path.dirname(output_images_dir), "labels"
|
@@ -8313,17 +8765,39 @@ def write_colormap(
|
|
8313
8765
|
|
8314
8766
|
def plot_performance_metrics(
|
8315
8767
|
history_path: str,
|
8316
|
-
figsize: Tuple[int, int] =
|
8768
|
+
figsize: Optional[Tuple[int, int]] = None,
|
8317
8769
|
verbose: bool = True,
|
8318
8770
|
save_path: Optional[str] = None,
|
8771
|
+
csv_path: Optional[str] = None,
|
8319
8772
|
kwargs: Optional[Dict] = None,
|
8320
|
-
) ->
|
8321
|
-
"""Plot performance metrics from a history object.
|
8773
|
+
) -> pd.DataFrame:
|
8774
|
+
"""Plot performance metrics from a training history object and return as DataFrame.
|
8775
|
+
|
8776
|
+
This function loads training history, plots available metrics (loss, IoU, F1,
|
8777
|
+
precision, recall), optionally exports to CSV, and returns all metrics as a
|
8778
|
+
pandas DataFrame for further analysis.
|
8322
8779
|
|
8323
8780
|
Args:
|
8324
|
-
history_path:
|
8325
|
-
figsize:
|
8326
|
-
|
8781
|
+
history_path (str): Path to the saved training history (.pth file).
|
8782
|
+
figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
|
8783
|
+
automatically determined based on number of metrics.
|
8784
|
+
verbose (bool): Whether to print best and final metric values. Defaults to True.
|
8785
|
+
save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
|
8786
|
+
csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
|
8787
|
+
kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
|
8788
|
+
|
8789
|
+
Returns:
|
8790
|
+
pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
|
8791
|
+
Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
|
8792
|
+
'val_precision', 'val_recall' (depending on availability in history).
|
8793
|
+
|
8794
|
+
Example:
|
8795
|
+
>>> df = plot_performance_metrics(
|
8796
|
+
... 'training_history.pth',
|
8797
|
+
... save_path='metrics_plot.png',
|
8798
|
+
... csv_path='metrics.csv'
|
8799
|
+
... )
|
8800
|
+
>>> print(df.head())
|
8327
8801
|
"""
|
8328
8802
|
if kwargs is None:
|
8329
8803
|
kwargs = {}
|
@@ -8333,65 +8807,135 @@ def plot_performance_metrics(
|
|
8333
8807
|
train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
|
8334
8808
|
val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
|
8335
8809
|
val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
|
8336
|
-
|
8810
|
+
# Support both new (f1) and old (dice) key formats for backward compatibility
|
8811
|
+
val_f1_key = (
|
8812
|
+
"val_f1s"
|
8813
|
+
if "val_f1s" in history
|
8814
|
+
else ("val_dices" if "val_dices" in history else "val_dice")
|
8815
|
+
)
|
8816
|
+
# Add support for precision and recall
|
8817
|
+
val_precision_key = (
|
8818
|
+
"val_precisions" if "val_precisions" in history else "val_precision"
|
8819
|
+
)
|
8820
|
+
val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
|
8821
|
+
|
8822
|
+
# Collect available metrics for plotting
|
8823
|
+
available_metrics = []
|
8824
|
+
metric_info = {
|
8825
|
+
"Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
|
8826
|
+
"IoU": (val_iou_key, None, ["Val IoU"]),
|
8827
|
+
"F1": (val_f1_key, None, ["Val F1"]),
|
8828
|
+
"Precision": (val_precision_key, None, ["Val Precision"]),
|
8829
|
+
"Recall": (val_recall_key, None, ["Val Recall"]),
|
8830
|
+
}
|
8831
|
+
|
8832
|
+
for metric_name, (key1, key2, labels) in metric_info.items():
|
8833
|
+
if key1 in history or (key2 and key2 in history):
|
8834
|
+
available_metrics.append((metric_name, key1, key2, labels))
|
8337
8835
|
|
8338
|
-
# Determine number of subplots
|
8339
|
-
|
8340
|
-
|
8341
|
-
|
8836
|
+
# Determine number of subplots and figure size
|
8837
|
+
n_plots = len(available_metrics)
|
8838
|
+
if figsize is None:
|
8839
|
+
figsize = (5 * n_plots, 5)
|
8342
8840
|
|
8343
|
-
|
8841
|
+
# Create DataFrame for all metrics
|
8842
|
+
n_epochs = 0
|
8843
|
+
df_data = {}
|
8344
8844
|
|
8345
|
-
#
|
8346
|
-
|
8845
|
+
# Add epochs
|
8846
|
+
if "epochs" in history:
|
8847
|
+
df_data["epoch"] = history["epochs"]
|
8848
|
+
n_epochs = len(history["epochs"])
|
8849
|
+
elif train_loss_key in history:
|
8850
|
+
n_epochs = len(history[train_loss_key])
|
8851
|
+
df_data["epoch"] = list(range(1, n_epochs + 1))
|
8852
|
+
|
8853
|
+
# Add all available metrics to DataFrame
|
8347
8854
|
if train_loss_key in history:
|
8348
|
-
|
8855
|
+
df_data["train_loss"] = history[train_loss_key]
|
8349
8856
|
if val_loss_key in history:
|
8350
|
-
|
8351
|
-
plt.title("Loss")
|
8352
|
-
plt.xlabel("Epoch")
|
8353
|
-
plt.ylabel("Loss")
|
8354
|
-
plt.legend()
|
8355
|
-
plt.grid(True)
|
8356
|
-
|
8357
|
-
# Plot IoU
|
8358
|
-
plt.subplot(1, n_plots, 2)
|
8857
|
+
df_data["val_loss"] = history[val_loss_key]
|
8359
8858
|
if val_iou_key in history:
|
8360
|
-
|
8361
|
-
|
8362
|
-
|
8363
|
-
|
8364
|
-
|
8365
|
-
|
8366
|
-
|
8367
|
-
|
8368
|
-
|
8369
|
-
|
8370
|
-
|
8371
|
-
|
8372
|
-
|
8373
|
-
|
8374
|
-
|
8375
|
-
|
8859
|
+
df_data["val_iou"] = history[val_iou_key]
|
8860
|
+
if val_f1_key in history:
|
8861
|
+
df_data["val_f1"] = history[val_f1_key]
|
8862
|
+
if val_precision_key in history:
|
8863
|
+
df_data["val_precision"] = history[val_precision_key]
|
8864
|
+
if val_recall_key in history:
|
8865
|
+
df_data["val_recall"] = history[val_recall_key]
|
8866
|
+
|
8867
|
+
# Create DataFrame
|
8868
|
+
df = pd.DataFrame(df_data)
|
8869
|
+
|
8870
|
+
# Export to CSV if requested
|
8871
|
+
if csv_path:
|
8872
|
+
df.to_csv(csv_path, index=False)
|
8873
|
+
if verbose:
|
8874
|
+
print(f"Metrics exported to: {csv_path}")
|
8875
|
+
|
8876
|
+
# Create plots
|
8877
|
+
if n_plots > 0:
|
8878
|
+
fig, axes = plt.subplots(1, n_plots, figsize=figsize)
|
8879
|
+
if n_plots == 1:
|
8880
|
+
axes = [axes]
|
8881
|
+
|
8882
|
+
for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
|
8883
|
+
ax = axes[idx]
|
8884
|
+
|
8885
|
+
if metric_name == "Loss":
|
8886
|
+
# Special handling for loss (has both train and val)
|
8887
|
+
if key1 in history:
|
8888
|
+
ax.plot(history[key1], label=labels[0])
|
8889
|
+
if key2 and key2 in history:
|
8890
|
+
ax.plot(history[key2], label=labels[1])
|
8891
|
+
else:
|
8892
|
+
# Single metric plots
|
8893
|
+
if key1 in history:
|
8894
|
+
ax.plot(history[key1], label=labels[0])
|
8376
8895
|
|
8377
|
-
|
8896
|
+
ax.set_title(metric_name)
|
8897
|
+
ax.set_xlabel("Epoch")
|
8898
|
+
ax.set_ylabel(metric_name)
|
8899
|
+
ax.legend()
|
8900
|
+
ax.grid(True)
|
8378
8901
|
|
8379
|
-
|
8380
|
-
if "dpi" not in kwargs:
|
8381
|
-
kwargs["dpi"] = 150
|
8382
|
-
if "bbox_inches" not in kwargs:
|
8383
|
-
kwargs["bbox_inches"] = "tight"
|
8384
|
-
plt.savefig(save_path, **kwargs)
|
8902
|
+
plt.tight_layout()
|
8385
8903
|
|
8386
|
-
|
8904
|
+
if save_path:
|
8905
|
+
if "dpi" not in kwargs:
|
8906
|
+
kwargs["dpi"] = 150
|
8907
|
+
if "bbox_inches" not in kwargs:
|
8908
|
+
kwargs["bbox_inches"] = "tight"
|
8909
|
+
plt.savefig(save_path, **kwargs)
|
8387
8910
|
|
8911
|
+
plt.show()
|
8912
|
+
|
8913
|
+
# Print summary statistics
|
8388
8914
|
if verbose:
|
8915
|
+
print("\n=== Performance Metrics Summary ===")
|
8389
8916
|
if val_iou_key in history:
|
8390
|
-
print(
|
8391
|
-
|
8392
|
-
|
8393
|
-
|
8394
|
-
print(
|
8917
|
+
print(
|
8918
|
+
f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
|
8919
|
+
)
|
8920
|
+
if val_f1_key in history:
|
8921
|
+
print(
|
8922
|
+
f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
|
8923
|
+
)
|
8924
|
+
if val_precision_key in history:
|
8925
|
+
print(
|
8926
|
+
f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
|
8927
|
+
)
|
8928
|
+
if val_recall_key in history:
|
8929
|
+
print(
|
8930
|
+
f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
|
8931
|
+
)
|
8932
|
+
if val_loss_key in history:
|
8933
|
+
print(
|
8934
|
+
f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
|
8935
|
+
)
|
8936
|
+
print("===================================\n")
|
8937
|
+
|
8938
|
+
return df
|
8395
8939
|
|
8396
8940
|
|
8397
8941
|
def get_device() -> torch.device:
|