geoai-py 0.14.0__py2.py3-none-any.whl → 0.16.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +21 -1
- geoai/change_detection.py +16 -6
- geoai/geoai.py +3 -0
- geoai/timm_segment.py +1097 -0
- geoai/timm_train.py +658 -0
- geoai/train.py +796 -107
- geoai/utils.py +1427 -245
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/METADATA +9 -1
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/RECORD +13 -11
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/licenses/LICENSE +1 -2
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.14.0.dist-info → geoai_py-0.16.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
@@ -383,6 +383,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
|
|
383
383
|
return accum_mean / len(files), accum_std / len(files)
|
384
384
|
|
385
385
|
|
386
|
+
def calc_iou(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate Intersection over Union (IoU) between ground truth and prediction masks.

    This function computes the IoU metric for segmentation tasks. It supports both
    binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
    or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            For binary segmentation: shape (H, W) with values {0, 1}; any non-zero
            value is treated as foreground.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            Must have the same shape as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels whose
            ground truth equals this value (e.g. 255 for unlabeled / nodata pixels) are
            excluded from the computation for every class. In multi-class mode, if it
            is also a valid class index, that class's score is NaN. Defaults to None.
        smooth (float, optional): Smoothing factor added to numerator and denominator
            to avoid division by zero. Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float IoU score
            (1.0 when both masks are empty). For multi-class segmentation, an array of
            shape (num_classes,) with per-class IoU; classes that are ignored or absent
            from both masks are NaN.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt, pred)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> iou = calc_iou(gt, pred, num_classes=3)
        >>> print([round(float(v), 4) for v in iou])
        [0.75, 0.6667, 0.3333]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> iou = calc_iou(gt_tensor, pred_tensor)
        >>> print(f"IoU: {iou:.4f}")
        IoU: 0.8000

        >>> # Using raster file paths
        >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)  # doctest: +SKIP
        >>> mean_iou = float(np.nanmean(iou))  # doctest: +SKIP
    """
    # Load from file if a string path is provided
    if isinstance(ground_truth, str):
        with rasterio.open(ground_truth) as src:
            ground_truth = src.read(band)
    if isinstance(prediction, str):
        with rasterio.open(prediction) as src:
            prediction = src.read(band)

    # Convert torch tensors to numpy. detach() so tensors that require grad
    # (e.g. raw model outputs) do not make .numpy() raise.
    if isinstance(ground_truth, torch.Tensor):
        ground_truth = ground_truth.detach().cpu().numpy()
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    ground_truth = np.asarray(ground_truth)
    prediction = np.asarray(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Pixels whose ground truth equals ignore_index must not contribute to any
    # class (neither as TP nor as FP), so drop them from both masks up front.
    if ignore_index is not None:
        valid = ground_truth != ignore_index
        ground_truth = ground_truth[valid]
        prediction = prediction[valid]

    # Binary segmentation: any non-zero value is foreground
    if num_classes is None:
        gt_mask = ground_truth.astype(bool)
        pred_mask = prediction.astype(bool)

        intersection = np.logical_and(gt_mask, pred_mask).sum()
        union = np.logical_or(gt_mask, pred_mask).sum()

        # Both masks empty: the prediction agrees perfectly with the ground truth
        if union == 0:
            return 1.0

        return float((intersection + smooth) / (union + smooth))

    # Multi-class segmentation: one-vs-rest IoU per class, NaN by default
    iou_per_class = np.full(num_classes, np.nan, dtype=np.float64)
    for class_idx in range(num_classes):
        if class_idx == ignore_index:
            continue  # ignored class keeps NaN

        gt_class = ground_truth == class_idx
        pred_class = prediction == class_idx

        union = np.logical_or(gt_class, pred_class).sum()
        if union == 0:
            continue  # class absent from both gt and pred keeps NaN

        intersection = np.logical_and(gt_class, pred_class).sum()
        iou_per_class[class_idx] = (intersection + smooth) / (union + smooth)

    return iou_per_class
|
507
|
+
|
508
|
+
|
509
|
+
def calc_f1_score(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    band: int = 1,
) -> Union[float, np.ndarray]:
    """
    Calculate F1 score between ground truth and prediction masks.

    The F1 score is the harmonic mean of precision and recall:
        F1 = 2 * (precision * recall) / (precision + recall)
    where precision = TP / (TP + FP) and recall = TP / (TP + FN). It is computed
    in the algebraically equivalent form
        F1 = (2 * TP + smooth) / (2 * TP + FP + FN + smooth)
    so that a perfect prediction scores exactly 1.0.

    This function supports both binary and multi-class segmentation, and can handle
    numpy arrays, PyTorch tensors, or file paths to raster files.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            For binary segmentation: shape (H, W) with values {0, 1}; any non-zero
            value is treated as foreground.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            Must have the same shape as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Ground-truth value to ignore. Pixels whose
            ground truth equals this value (e.g. 255 for unlabeled / nodata pixels) are
            excluded from the computation for every class. In multi-class mode, if it
            is also a valid class index, that class's score is NaN. Defaults to None.
        smooth (float, optional): Smoothing factor added to numerator and denominator
            to avoid division by zero. Defaults to 1e-6.
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Union[float, np.ndarray]: For binary segmentation, a single float F1 score
            (1.0 when both masks are empty). For multi-class segmentation, an array of
            shape (num_classes,) with per-class F1; classes that are ignored or absent
            from both masks are NaN.

    Raises:
        ValueError: If ground_truth and prediction have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt, pred)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> f1 = calc_f1_score(gt, pred, num_classes=3)
        >>> print([round(float(v), 4) for v in f1])
        [0.8571, 0.8, 0.5]

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
        >>> print(f"F1 Score: {f1:.4f}")
        F1 Score: 0.8889

        >>> # Using raster file paths
        >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)  # doctest: +SKIP
        >>> mean_f1 = float(np.nanmean(f1))  # doctest: +SKIP
    """
    # Load from file if a string path is provided
    if isinstance(ground_truth, str):
        with rasterio.open(ground_truth) as src:
            ground_truth = src.read(band)
    if isinstance(prediction, str):
        with rasterio.open(prediction) as src:
            prediction = src.read(band)

    # Convert torch tensors to numpy. detach() so tensors that require grad
    # (e.g. raw model outputs) do not make .numpy() raise.
    if isinstance(ground_truth, torch.Tensor):
        ground_truth = ground_truth.detach().cpu().numpy()
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    ground_truth = np.asarray(ground_truth)
    prediction = np.asarray(prediction)

    # Ensure inputs have the same shape
    if ground_truth.shape != prediction.shape:
        raise ValueError(
            f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
        )

    # Pixels whose ground truth equals ignore_index must not contribute to any
    # class (neither as TP nor as FP), so drop them from both masks up front.
    if ignore_index is not None:
        valid = ground_truth != ignore_index
        ground_truth = ground_truth[valid]
        prediction = prediction[valid]

    # Binary segmentation: any non-zero value is foreground
    if num_classes is None:
        gt_mask = ground_truth.astype(bool)
        pred_mask = prediction.astype(bool)

        tp = np.logical_and(gt_mask, pred_mask).sum()
        fp = np.logical_and(~gt_mask, pred_mask).sum()
        fn = np.logical_and(gt_mask, ~pred_mask).sum()

        # Both masks empty gives smooth / smooth == 1.0
        return float((2 * tp + smooth) / (2 * tp + fp + fn + smooth))

    # Multi-class segmentation: one-vs-rest F1 per class, NaN by default
    f1_per_class = np.full(num_classes, np.nan, dtype=np.float64)
    for class_idx in range(num_classes):
        if class_idx == ignore_index:
            continue  # ignored class keeps NaN

        gt_class = ground_truth == class_idx
        pred_class = prediction == class_idx

        tp = np.logical_and(gt_class, pred_class).sum()
        fp = np.logical_and(~gt_class, pred_class).sum()
        fn = np.logical_and(gt_class, ~pred_class).sum()

        if tp + fp + fn == 0:
            continue  # class absent from both gt and pred keeps NaN

        f1_per_class[class_idx] = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)

    return f1_per_class
|
645
|
+
|
646
|
+
|
647
|
+
def calc_segmentation_metrics(
    ground_truth: Union[str, np.ndarray, torch.Tensor],
    prediction: Union[str, np.ndarray, torch.Tensor],
    num_classes: Optional[int] = None,
    ignore_index: Optional[int] = None,
    smooth: float = 1e-6,
    metrics: Optional[List[str]] = None,
    band: int = 1,
) -> Dict[str, Union[float, np.ndarray]]:
    """
    Calculate multiple segmentation metrics between ground truth and prediction masks.

    This is a convenience wrapper that computes several metrics at once, including
    IoU (Intersection over Union) and F1 score, by delegating to ``calc_iou`` and
    ``calc_f1_score``. It supports both binary and multi-class segmentation, and can
    handle numpy arrays, PyTorch tensors, or file paths to raster files. Inputs are
    loaded once and shared across all requested metrics.

    Args:
        ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            For binary segmentation: shape (H, W) with values {0, 1}.
            For multi-class segmentation: shape (H, W) with class indices.
        prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
            Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
            Should have the same shape and format as ground_truth.
        num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
            If None, assumes binary segmentation. Defaults to None.
        ignore_index (Optional[int], optional): Class index / ground-truth value to ignore;
            forwarded unchanged to ``calc_iou`` and ``calc_f1_score``. Defaults to None.
        smooth (float, optional): Smoothing factor to avoid division by zero.
            Defaults to 1e-6.
        metrics (Optional[List[str]], optional): Metrics to calculate. Options: "iou", "f1".
            A single metric name may be passed as a plain string. Defaults to None,
            which means ["iou", "f1"].
        band (int, optional): Band index to read from raster file (1-based indexing).
            Only used when input is a file path. Defaults to 1.

    Returns:
        Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
            Keys are metric names ("iou", "f1"), values are the metric scores.
            For binary segmentation, values are floats.
            For multi-class segmentation, values are numpy arrays with per-class scores,
            and "mean_iou" / "mean_f1" are also included (mean over non-NaN classes,
            0.0 if every class is NaN).

    Raises:
        ValueError: If ``metrics`` contains an unsupported metric name, or if the
            masks have different shapes.

    Examples:
        >>> # Binary segmentation with arrays
        >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt, pred)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Multi-class segmentation
        >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
        >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833
        >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
        Mean F1: 0.7190
        >>> print([round(float(v), 4) for v in metrics['iou']])
        [0.75, 0.6667, 0.3333]

        >>> # Calculate only IoU
        >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
        >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
        Mean IoU: 0.5833

        >>> # Using PyTorch tensors
        >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
        >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
        >>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
        >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
        IoU: 0.8000, F1: 0.8889

        >>> # Using raster file paths
        >>> metrics = calc_segmentation_metrics(
        ...     "ground_truth.tif", "prediction.tif", num_classes=3
        ... )  # doctest: +SKIP
    """
    # Avoid a shared mutable default; accept a single metric name as a string.
    if metrics is None:
        metrics = ["iou", "f1"]
    elif isinstance(metrics, str):
        metrics = [metrics]

    # Reject typos instead of silently returning an incomplete dictionary.
    unknown = sorted(set(metrics) - {"iou", "f1"})
    if unknown:
        raise ValueError(
            f"Unsupported metric(s): {unknown}. Supported options are 'iou' and 'f1'."
        )

    def _load(mask):
        """Read a raster path or convert a tensor to numpy; pass arrays through."""
        if isinstance(mask, str):
            with rasterio.open(mask) as src:
                return src.read(band)
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return mask

    def _nan_mean(scores: np.ndarray) -> float:
        """Mean over non-NaN entries, or 0.0 when every entry is NaN."""
        valid_scores = scores[~np.isnan(scores)]
        return float(np.mean(valid_scores)) if valid_scores.size > 0 else 0.0

    # Load inputs once so raster files are not re-read for every metric.
    ground_truth = _load(ground_truth)
    prediction = _load(prediction)

    results = {}

    # Calculate IoU if requested
    if "iou" in metrics:
        iou = calc_iou(
            ground_truth,
            prediction,
            num_classes=num_classes,
            ignore_index=ignore_index,
            smooth=smooth,
            band=band,
        )
        results["iou"] = iou
        if num_classes is not None and isinstance(iou, np.ndarray):
            results["mean_iou"] = _nan_mean(iou)

    # Calculate F1 score if requested
    if "f1" in metrics:
        f1 = calc_f1_score(
            ground_truth,
            prediction,
            num_classes=num_classes,
            ignore_index=ignore_index,
            smooth=smooth,
            band=band,
        )
        results["f1"] = f1
        if num_classes is not None and isinstance(f1, np.ndarray):
            results["mean_f1"] = _nan_mean(f1)

    return results
|
772
|
+
|
773
|
+
|
386
774
|
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
|
387
775
|
"""Convert a dictionary to a xarray DataArray. The dictionary should contain the
|
388
776
|
following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
|
@@ -751,6 +1139,12 @@ def view_vector_interactive(
|
|
751
1139
|
},
|
752
1140
|
}
|
753
1141
|
|
1142
|
+
# Make it compatible with binder and JupyterHub
|
1143
|
+
if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
|
1144
|
+
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
|
1145
|
+
f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
|
1146
|
+
)
|
1147
|
+
|
754
1148
|
basemap_layer_name = None
|
755
1149
|
raster_layer = None
|
756
1150
|
|
@@ -2599,7 +2993,7 @@ def batch_vector_to_raster(
|
|
2599
2993
|
def export_geotiff_tiles(
|
2600
2994
|
in_raster,
|
2601
2995
|
out_folder,
|
2602
|
-
in_class_data,
|
2996
|
+
in_class_data=None,
|
2603
2997
|
tile_size=256,
|
2604
2998
|
stride=128,
|
2605
2999
|
class_value_field="class",
|
@@ -2609,6 +3003,7 @@ def export_geotiff_tiles(
|
|
2609
3003
|
all_touched=True,
|
2610
3004
|
create_overview=False,
|
2611
3005
|
skip_empty_tiles=False,
|
3006
|
+
metadata_format="PASCAL_VOC",
|
2612
3007
|
):
|
2613
3008
|
"""
|
2614
3009
|
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
@@ -2616,7 +3011,8 @@ def export_geotiff_tiles(
|
|
2616
3011
|
Args:
|
2617
3012
|
in_raster (str): Path to input raster image
|
2618
3013
|
out_folder (str): Path to output folder
|
2619
|
-
in_class_data (str): Path to classification data - can be vector file or raster
|
3014
|
+
in_class_data (str, optional): Path to classification data - can be vector file or raster.
|
3015
|
+
If None, only image tiles will be exported without labels. Defaults to None.
|
2620
3016
|
tile_size (int): Size of tiles in pixels (square)
|
2621
3017
|
stride (int): Step size between tiles
|
2622
3018
|
class_value_field (str): Field containing class values (for vector data)
|
@@ -2626,6 +3022,7 @@ def export_geotiff_tiles(
|
|
2626
3022
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
2627
3023
|
create_overview (bool): Whether to create an overview image of all tiles
|
2628
3024
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
3025
|
+
metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
|
2629
3026
|
"""
|
2630
3027
|
|
2631
3028
|
import logging
|
@@ -2636,28 +3033,42 @@ def export_geotiff_tiles(
|
|
2636
3033
|
os.makedirs(out_folder, exist_ok=True)
|
2637
3034
|
image_dir = os.path.join(out_folder, "images")
|
2638
3035
|
os.makedirs(image_dir, exist_ok=True)
|
2639
|
-
label_dir = os.path.join(out_folder, "labels")
|
2640
|
-
os.makedirs(label_dir, exist_ok=True)
|
2641
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
2642
|
-
os.makedirs(ann_dir, exist_ok=True)
|
2643
3036
|
|
2644
|
-
#
|
3037
|
+
# Only create label and annotation directories if class data is provided
|
3038
|
+
if in_class_data is not None:
|
3039
|
+
label_dir = os.path.join(out_folder, "labels")
|
3040
|
+
os.makedirs(label_dir, exist_ok=True)
|
3041
|
+
|
3042
|
+
# Create annotation directory based on metadata format
|
3043
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
3044
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
3045
|
+
os.makedirs(ann_dir, exist_ok=True)
|
3046
|
+
|
3047
|
+
# Initialize COCO annotations dictionary
|
3048
|
+
if metadata_format == "COCO":
|
3049
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
3050
|
+
ann_id = 0
|
3051
|
+
|
3052
|
+
# Determine if class data is raster or vector (only if class data provided)
|
2645
3053
|
is_class_data_raster = False
|
2646
|
-
if
|
2647
|
-
|
2648
|
-
|
2649
|
-
|
2650
|
-
|
2651
|
-
|
2652
|
-
|
3054
|
+
if in_class_data is not None:
|
3055
|
+
if isinstance(in_class_data, str):
|
3056
|
+
file_ext = Path(in_class_data).suffix.lower()
|
3057
|
+
# Common raster extensions
|
3058
|
+
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
3059
|
+
try:
|
3060
|
+
with rasterio.open(in_class_data) as src:
|
3061
|
+
is_class_data_raster = True
|
3062
|
+
if not quiet:
|
3063
|
+
print(f"Detected in_class_data as raster: {in_class_data}")
|
3064
|
+
print(f"Raster CRS: {src.crs}")
|
3065
|
+
print(f"Raster dimensions: {src.width} x {src.height}")
|
3066
|
+
except Exception:
|
3067
|
+
is_class_data_raster = False
|
2653
3068
|
if not quiet:
|
2654
|
-
print(
|
2655
|
-
|
2656
|
-
|
2657
|
-
except Exception:
|
2658
|
-
is_class_data_raster = False
|
2659
|
-
if not quiet:
|
2660
|
-
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
3069
|
+
print(
|
3070
|
+
f"Unable to open {in_class_data} as raster, trying as vector"
|
3071
|
+
)
|
2661
3072
|
|
2662
3073
|
# Open the input raster
|
2663
3074
|
with rasterio.open(in_raster) as src:
|
@@ -2677,10 +3088,10 @@ def export_geotiff_tiles(
|
|
2677
3088
|
if max_tiles is None:
|
2678
3089
|
max_tiles = total_tiles
|
2679
3090
|
|
2680
|
-
# Process classification data
|
3091
|
+
# Process classification data (only if class data provided)
|
2681
3092
|
class_to_id = {}
|
2682
3093
|
|
2683
|
-
if is_class_data_raster:
|
3094
|
+
if in_class_data is not None and is_class_data_raster:
|
2684
3095
|
# Load raster class data
|
2685
3096
|
with rasterio.open(in_class_data) as class_src:
|
2686
3097
|
# Check if raster CRS matches
|
@@ -2713,7 +3124,18 @@ def export_geotiff_tiles(
|
|
2713
3124
|
|
2714
3125
|
# Create class mapping
|
2715
3126
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
2716
|
-
|
3127
|
+
|
3128
|
+
# Populate COCO categories
|
3129
|
+
if metadata_format == "COCO":
|
3130
|
+
for cls_val in unique_classes:
|
3131
|
+
coco_annotations["categories"].append(
|
3132
|
+
{
|
3133
|
+
"id": class_to_id[int(cls_val)],
|
3134
|
+
"name": str(int(cls_val)),
|
3135
|
+
"supercategory": "object",
|
3136
|
+
}
|
3137
|
+
)
|
3138
|
+
elif in_class_data is not None:
|
2717
3139
|
# Load vector class data
|
2718
3140
|
try:
|
2719
3141
|
gdf = gpd.read_file(in_class_data)
|
@@ -2742,12 +3164,33 @@ def export_geotiff_tiles(
|
|
2742
3164
|
)
|
2743
3165
|
# Create class mapping
|
2744
3166
|
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
3167
|
+
|
3168
|
+
# Populate COCO categories
|
3169
|
+
if metadata_format == "COCO":
|
3170
|
+
for cls_val in unique_classes:
|
3171
|
+
coco_annotations["categories"].append(
|
3172
|
+
{
|
3173
|
+
"id": class_to_id[cls_val],
|
3174
|
+
"name": str(cls_val),
|
3175
|
+
"supercategory": "object",
|
3176
|
+
}
|
3177
|
+
)
|
2745
3178
|
else:
|
2746
3179
|
if not quiet:
|
2747
3180
|
print(
|
2748
3181
|
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
2749
3182
|
)
|
2750
3183
|
class_to_id = {1: 1} # Default mapping
|
3184
|
+
|
3185
|
+
# Populate COCO categories with default
|
3186
|
+
if metadata_format == "COCO":
|
3187
|
+
coco_annotations["categories"].append(
|
3188
|
+
{
|
3189
|
+
"id": 1,
|
3190
|
+
"name": "object",
|
3191
|
+
"supercategory": "object",
|
3192
|
+
}
|
3193
|
+
)
|
2751
3194
|
except Exception as e:
|
2752
3195
|
raise ValueError(f"Error processing vector data: {e}")
|
2753
3196
|
|
@@ -2814,8 +3257,8 @@ def export_geotiff_tiles(
|
|
2814
3257
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
2815
3258
|
has_features = False
|
2816
3259
|
|
2817
|
-
# Process classification data to create labels
|
2818
|
-
if is_class_data_raster:
|
3260
|
+
# Process classification data to create labels (only if class data provided)
|
3261
|
+
if in_class_data is not None and is_class_data_raster:
|
2819
3262
|
# For raster class data
|
2820
3263
|
with rasterio.open(in_class_data) as class_src:
|
2821
3264
|
# Calculate window in class raster
|
@@ -2865,7 +3308,7 @@ def export_geotiff_tiles(
|
|
2865
3308
|
except Exception as e:
|
2866
3309
|
pbar.write(f"Error reading class raster window: {e}")
|
2867
3310
|
stats["errors"] += 1
|
2868
|
-
|
3311
|
+
elif in_class_data is not None:
|
2869
3312
|
# For vector class data
|
2870
3313
|
# Find features that intersect with window
|
2871
3314
|
window_features = gdf[gdf.intersects(window_bounds)]
|
@@ -2908,8 +3351,8 @@ def export_geotiff_tiles(
|
|
2908
3351
|
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
2909
3352
|
stats["errors"] += 1
|
2910
3353
|
|
2911
|
-
# Skip tile if no features and skip_empty_tiles is True
|
2912
|
-
if skip_empty_tiles and not has_features:
|
3354
|
+
# Skip tile if no features and skip_empty_tiles is True (only when class data provided)
|
3355
|
+
if in_class_data is not None and skip_empty_tiles and not has_features:
|
2913
3356
|
pbar.update(1)
|
2914
3357
|
tile_index += 1
|
2915
3358
|
continue
|
@@ -2940,96 +3383,212 @@ def export_geotiff_tiles(
|
|
2940
3383
|
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
2941
3384
|
stats["errors"] += 1
|
2942
3385
|
|
2943
|
-
#
|
2944
|
-
|
2945
|
-
|
2946
|
-
|
2947
|
-
|
2948
|
-
|
2949
|
-
|
2950
|
-
|
2951
|
-
|
2952
|
-
|
3386
|
+
# Export label as GeoTIFF (only if class data provided)
|
3387
|
+
if in_class_data is not None:
|
3388
|
+
# Create profile for label GeoTIFF
|
3389
|
+
label_profile = {
|
3390
|
+
"driver": "GTiff",
|
3391
|
+
"height": tile_size,
|
3392
|
+
"width": tile_size,
|
3393
|
+
"count": 1,
|
3394
|
+
"dtype": "uint8",
|
3395
|
+
"crs": src.crs,
|
3396
|
+
"transform": window_transform,
|
3397
|
+
}
|
2953
3398
|
|
2954
|
-
|
2955
|
-
|
2956
|
-
|
2957
|
-
|
2958
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
3399
|
+
label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
|
3400
|
+
try:
|
3401
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
3402
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
2959
3403
|
|
2960
|
-
|
2961
|
-
|
2962
|
-
|
2963
|
-
|
2964
|
-
|
2965
|
-
|
3404
|
+
if has_features:
|
3405
|
+
stats["tiles_with_features"] += 1
|
3406
|
+
stats["feature_pixels"] += np.count_nonzero(label_mask)
|
3407
|
+
except Exception as e:
|
3408
|
+
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
3409
|
+
stats["errors"] += 1
|
2966
3410
|
|
2967
|
-
# Create
|
3411
|
+
# Create annotations for object detection if using vector class data
|
2968
3412
|
if (
|
2969
|
-
not
|
3413
|
+
in_class_data is not None
|
3414
|
+
and not is_class_data_raster
|
2970
3415
|
and "gdf" in locals()
|
2971
3416
|
and len(window_features) > 0
|
2972
3417
|
):
|
2973
|
-
|
2974
|
-
|
2975
|
-
|
2976
|
-
|
3418
|
+
if metadata_format == "PASCAL_VOC":
|
3419
|
+
# Create XML annotation
|
3420
|
+
root = ET.Element("annotation")
|
3421
|
+
ET.SubElement(root, "folder").text = "images"
|
3422
|
+
ET.SubElement(root, "filename").text = (
|
3423
|
+
f"tile_{tile_index:06d}.tif"
|
3424
|
+
)
|
2977
3425
|
|
2978
|
-
|
2979
|
-
|
2980
|
-
|
2981
|
-
|
3426
|
+
size = ET.SubElement(root, "size")
|
3427
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
3428
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
3429
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
3430
|
+
|
3431
|
+
# Add georeference information
|
3432
|
+
geo = ET.SubElement(root, "georeference")
|
3433
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
3434
|
+
ET.SubElement(geo, "transform").text = str(
|
3435
|
+
window_transform
|
3436
|
+
).replace("\n", "")
|
3437
|
+
ET.SubElement(geo, "bounds").text = (
|
3438
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
3439
|
+
)
|
2982
3440
|
|
2983
|
-
|
2984
|
-
|
2985
|
-
|
2986
|
-
|
2987
|
-
|
2988
|
-
|
2989
|
-
|
2990
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
2991
|
-
)
|
3441
|
+
# Add objects
|
3442
|
+
for idx, feature in window_features.iterrows():
|
3443
|
+
# Get feature class
|
3444
|
+
if class_value_field in feature:
|
3445
|
+
class_val = feature[class_value_field]
|
3446
|
+
else:
|
3447
|
+
class_val = "object"
|
2992
3448
|
|
2993
|
-
|
2994
|
-
|
2995
|
-
|
2996
|
-
|
2997
|
-
|
2998
|
-
|
2999
|
-
|
3449
|
+
# Get geometry bounds in pixel coordinates
|
3450
|
+
geom = feature.geometry.intersection(window_bounds)
|
3451
|
+
if not geom.is_empty:
|
3452
|
+
# Get bounds in world coordinates
|
3453
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3454
|
+
|
3455
|
+
# Convert to pixel coordinates
|
3456
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3457
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3458
|
+
|
3459
|
+
# Ensure coordinates are within tile bounds
|
3460
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
3461
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
3462
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
3463
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
3464
|
+
|
3465
|
+
# Only add if the box has non-zero area
|
3466
|
+
if xmax > xmin and ymax > ymin:
|
3467
|
+
obj = ET.SubElement(root, "object")
|
3468
|
+
ET.SubElement(obj, "name").text = str(class_val)
|
3469
|
+
ET.SubElement(obj, "difficult").text = "0"
|
3470
|
+
|
3471
|
+
bbox = ET.SubElement(obj, "bndbox")
|
3472
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
3473
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
3474
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
3475
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
3476
|
+
|
3477
|
+
# Save XML
|
3478
|
+
tree = ET.ElementTree(root)
|
3479
|
+
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
3480
|
+
tree.write(xml_path)
|
3000
3481
|
|
3001
|
-
|
3002
|
-
|
3003
|
-
|
3004
|
-
|
3005
|
-
|
3006
|
-
|
3007
|
-
|
3008
|
-
|
3009
|
-
|
3010
|
-
|
3011
|
-
|
3012
|
-
|
3013
|
-
|
3014
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
3015
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
3016
|
-
|
3017
|
-
# Only add if the box has non-zero area
|
3018
|
-
if xmax > xmin and ymax > ymin:
|
3019
|
-
obj = ET.SubElement(root, "object")
|
3020
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
3021
|
-
ET.SubElement(obj, "difficult").text = "0"
|
3022
|
-
|
3023
|
-
bbox = ET.SubElement(obj, "bndbox")
|
3024
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
3025
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
3026
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
3027
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
3482
|
+
elif metadata_format == "COCO":
|
3483
|
+
# Add image info
|
3484
|
+
image_id = tile_index
|
3485
|
+
coco_annotations["images"].append(
|
3486
|
+
{
|
3487
|
+
"id": image_id,
|
3488
|
+
"file_name": f"tile_{tile_index:06d}.tif",
|
3489
|
+
"width": tile_size,
|
3490
|
+
"height": tile_size,
|
3491
|
+
"crs": str(src.crs),
|
3492
|
+
"transform": str(window_transform),
|
3493
|
+
}
|
3494
|
+
)
|
3028
3495
|
|
3029
|
-
|
3030
|
-
|
3031
|
-
|
3032
|
-
|
3496
|
+
# Add annotations for each feature
|
3497
|
+
for _, feature in window_features.iterrows():
|
3498
|
+
# Get feature class
|
3499
|
+
if class_value_field in feature:
|
3500
|
+
class_val = feature[class_value_field]
|
3501
|
+
category_id = class_to_id.get(class_val, 1)
|
3502
|
+
else:
|
3503
|
+
category_id = 1
|
3504
|
+
|
3505
|
+
# Get geometry bounds
|
3506
|
+
geom = feature.geometry.intersection(window_bounds)
|
3507
|
+
if not geom.is_empty:
|
3508
|
+
# Get bounds in world coordinates
|
3509
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3510
|
+
|
3511
|
+
# Convert to pixel coordinates
|
3512
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3513
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3514
|
+
|
3515
|
+
# Ensure coordinates are within tile bounds
|
3516
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
3517
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
3518
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
3519
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
3520
|
+
|
3521
|
+
# Skip if box is too small
|
3522
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
3523
|
+
continue
|
3524
|
+
|
3525
|
+
width = xmax - xmin
|
3526
|
+
height = ymax - ymin
|
3527
|
+
|
3528
|
+
# Add annotation
|
3529
|
+
ann_id += 1
|
3530
|
+
coco_annotations["annotations"].append(
|
3531
|
+
{
|
3532
|
+
"id": ann_id,
|
3533
|
+
"image_id": image_id,
|
3534
|
+
"category_id": category_id,
|
3535
|
+
"bbox": [xmin, ymin, width, height],
|
3536
|
+
"area": width * height,
|
3537
|
+
"iscrowd": 0,
|
3538
|
+
}
|
3539
|
+
)
|
3540
|
+
|
3541
|
+
elif metadata_format == "YOLO":
|
3542
|
+
# Create YOLO format annotations
|
3543
|
+
yolo_annotations = []
|
3544
|
+
|
3545
|
+
for _, feature in window_features.iterrows():
|
3546
|
+
# Get feature class
|
3547
|
+
if class_value_field in feature:
|
3548
|
+
class_val = feature[class_value_field]
|
3549
|
+
# YOLO uses 0-indexed class IDs
|
3550
|
+
class_id = class_to_id.get(class_val, 1) - 1
|
3551
|
+
else:
|
3552
|
+
class_id = 0
|
3553
|
+
|
3554
|
+
# Get geometry bounds
|
3555
|
+
geom = feature.geometry.intersection(window_bounds)
|
3556
|
+
if not geom.is_empty:
|
3557
|
+
# Get bounds in world coordinates
|
3558
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3559
|
+
|
3560
|
+
# Convert to pixel coordinates
|
3561
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3562
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3563
|
+
|
3564
|
+
# Ensure coordinates are within tile bounds
|
3565
|
+
xmin = max(0, min(tile_size, col_min))
|
3566
|
+
ymin = max(0, min(tile_size, row_min))
|
3567
|
+
xmax = max(0, min(tile_size, col_max))
|
3568
|
+
ymax = max(0, min(tile_size, row_max))
|
3569
|
+
|
3570
|
+
# Skip if box is too small
|
3571
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
3572
|
+
continue
|
3573
|
+
|
3574
|
+
# Calculate normalized coordinates (YOLO format)
|
3575
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
3576
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
3577
|
+
width = (xmax - xmin) / tile_size
|
3578
|
+
height = (ymax - ymin) / tile_size
|
3579
|
+
|
3580
|
+
# Add YOLO annotation line
|
3581
|
+
yolo_annotations.append(
|
3582
|
+
f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
3583
|
+
)
|
3584
|
+
|
3585
|
+
# Save YOLO annotations to text file
|
3586
|
+
if yolo_annotations:
|
3587
|
+
yolo_path = os.path.join(
|
3588
|
+
label_dir, f"tile_{tile_index:06d}.txt"
|
3589
|
+
)
|
3590
|
+
with open(yolo_path, "w") as f:
|
3591
|
+
f.write("\n".join(yolo_annotations))
|
3033
3592
|
|
3034
3593
|
# Update progress bar
|
3035
3594
|
pbar.update(1)
|
@@ -3047,6 +3606,39 @@ def export_geotiff_tiles(
|
|
3047
3606
|
# Close progress bar
|
3048
3607
|
pbar.close()
|
3049
3608
|
|
3609
|
+
# Save COCO annotations if applicable (only if class data provided)
|
3610
|
+
if in_class_data is not None and metadata_format == "COCO":
|
3611
|
+
try:
|
3612
|
+
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
3613
|
+
json.dump(coco_annotations, f, indent=2)
|
3614
|
+
if not quiet:
|
3615
|
+
print(
|
3616
|
+
f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
|
3617
|
+
f"{len(coco_annotations['annotations'])} annotations, "
|
3618
|
+
f"{len(coco_annotations['categories'])} categories"
|
3619
|
+
)
|
3620
|
+
except Exception as e:
|
3621
|
+
if not quiet:
|
3622
|
+
print(f"ERROR saving COCO annotations: {e}")
|
3623
|
+
stats["errors"] += 1
|
3624
|
+
|
3625
|
+
# Save YOLO classes file if applicable (only if class data provided)
|
3626
|
+
if in_class_data is not None and metadata_format == "YOLO":
|
3627
|
+
try:
|
3628
|
+
# Create classes.txt with class names
|
3629
|
+
classes_path = os.path.join(out_folder, "classes.txt")
|
3630
|
+
# Sort by class ID to ensure correct order
|
3631
|
+
sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
|
3632
|
+
with open(classes_path, "w") as f:
|
3633
|
+
for class_val, _ in sorted_classes:
|
3634
|
+
f.write(f"{class_val}\n")
|
3635
|
+
if not quiet:
|
3636
|
+
print(f"Saved YOLO classes file with {len(class_to_id)} classes")
|
3637
|
+
except Exception as e:
|
3638
|
+
if not quiet:
|
3639
|
+
print(f"ERROR saving YOLO classes file: {e}")
|
3640
|
+
stats["errors"] += 1
|
3641
|
+
|
3050
3642
|
# Create overview image if requested
|
3051
3643
|
if create_overview and stats["tile_coordinates"]:
|
3052
3644
|
try:
|
@@ -3064,13 +3656,14 @@ def export_geotiff_tiles(
|
|
3064
3656
|
if not quiet:
|
3065
3657
|
print("\n------- Export Summary -------")
|
3066
3658
|
print(f"Total tiles exported: {stats['total_tiles']}")
|
3067
|
-
|
3068
|
-
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
3069
|
-
)
|
3070
|
-
if stats["tiles_with_features"] > 0:
|
3659
|
+
if in_class_data is not None:
|
3071
3660
|
print(
|
3072
|
-
f"
|
3661
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
3073
3662
|
)
|
3663
|
+
if stats["tiles_with_features"] > 0:
|
3664
|
+
print(
|
3665
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
3666
|
+
)
|
3074
3667
|
if stats["errors"] > 0:
|
3075
3668
|
print(f"Errors encountered: {stats['errors']}")
|
3076
3669
|
print(f"Output saved to: {out_folder}")
|
@@ -3079,7 +3672,6 @@ def export_geotiff_tiles(
|
|
3079
3672
|
if stats["total_tiles"] > 0:
|
3080
3673
|
print("\n------- Georeference Verification -------")
|
3081
3674
|
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
3082
|
-
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
3083
3675
|
|
3084
3676
|
if os.path.exists(sample_image):
|
3085
3677
|
try:
|
@@ -3095,19 +3687,22 @@ def export_geotiff_tiles(
|
|
3095
3687
|
except Exception as e:
|
3096
3688
|
print(f"Error verifying image georeference: {e}")
|
3097
3689
|
|
3098
|
-
if
|
3099
|
-
|
3100
|
-
|
3101
|
-
|
3102
|
-
|
3103
|
-
|
3104
|
-
f"Label
|
3105
|
-
|
3106
|
-
|
3107
|
-
|
3108
|
-
|
3109
|
-
|
3110
|
-
|
3690
|
+
# Only verify label if class data was provided
|
3691
|
+
if in_class_data is not None:
|
3692
|
+
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
3693
|
+
if os.path.exists(sample_label):
|
3694
|
+
try:
|
3695
|
+
with rasterio.open(sample_label) as lbl:
|
3696
|
+
print(f"Label CRS: {lbl.crs}")
|
3697
|
+
print(f"Label transform: {lbl.transform}")
|
3698
|
+
print(
|
3699
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
3700
|
+
)
|
3701
|
+
print(
|
3702
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
3703
|
+
)
|
3704
|
+
except Exception as e:
|
3705
|
+
print(f"Error verifying label georeference: {e}")
|
3111
3706
|
|
3112
3707
|
# Return statistics dictionary for further processing if needed
|
3113
3708
|
return stats
|
@@ -3125,36 +3720,41 @@ def export_geotiff_tiles_batch(
|
|
3125
3720
|
max_tiles=None,
|
3126
3721
|
quiet=False,
|
3127
3722
|
all_touched=True,
|
3128
|
-
create_overview=False,
|
3129
3723
|
skip_empty_tiles=False,
|
3130
3724
|
image_extensions=None,
|
3131
3725
|
mask_extensions=None,
|
3132
|
-
match_by_name=
|
3726
|
+
match_by_name=False,
|
3727
|
+
metadata_format="PASCAL_VOC",
|
3133
3728
|
) -> Dict[str, Any]:
|
3134
3729
|
"""
|
3135
|
-
Export georeferenced GeoTIFF tiles from images and masks.
|
3730
|
+
Export georeferenced GeoTIFF tiles from images and optionally masks.
|
3731
|
+
|
3732
|
+
This function supports four modes:
|
3733
|
+
1. Images only (no masks) - when neither masks_file nor masks_folder is provided
|
3734
|
+
2. Single vector file covering all images (masks_file parameter)
|
3735
|
+
3. Multiple vector files, one per image (masks_folder parameter)
|
3736
|
+
4. Multiple raster mask files (masks_folder parameter)
|
3136
3737
|
|
3137
|
-
|
3138
|
-
1. Single vector file covering all images (masks_file parameter)
|
3139
|
-
2. Multiple vector files, one per image (masks_folder parameter)
|
3140
|
-
3. Multiple raster mask files (masks_folder parameter)
|
3738
|
+
For mode 1 (images only), only image tiles will be exported without labels.
|
3141
3739
|
|
3142
|
-
For mode
|
3740
|
+
For mode 2 (single vector file), specify masks_file path. The function will
|
3143
3741
|
use spatial intersection to determine which features apply to each image.
|
3144
3742
|
|
3145
|
-
For mode
|
3743
|
+
For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
|
3146
3744
|
are paired either by matching filenames (match_by_name=True) or by sorted order
|
3147
3745
|
(match_by_name=False).
|
3148
3746
|
|
3149
|
-
All image tiles are saved to a single 'images' folder and all mask tiles
|
3150
|
-
single 'masks' folder within the output directory.
|
3747
|
+
All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
|
3748
|
+
to a single 'masks' folder within the output directory.
|
3151
3749
|
|
3152
3750
|
Args:
|
3153
3751
|
images_folder (str): Path to folder containing raster images
|
3154
3752
|
masks_folder (str, optional): Path to folder containing classification masks/vectors.
|
3155
|
-
Use this for multiple mask files (one per image or raster masks).
|
3753
|
+
Use this for multiple mask files (one per image or raster masks). If not provided
|
3754
|
+
and masks_file is also not provided, only image tiles will be exported.
|
3156
3755
|
masks_file (str, optional): Path to a single vector file covering all images.
|
3157
|
-
Use this for a single GeoJSON/Shapefile that covers multiple images.
|
3756
|
+
Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
|
3757
|
+
and masks_folder is also not provided, only image tiles will be exported.
|
3158
3758
|
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
3159
3759
|
subfolder in images_folder.
|
3160
3760
|
tile_size (int): Size of tiles in pixels (square)
|
@@ -3170,16 +3770,23 @@ def export_geotiff_tiles_batch(
|
|
3170
3770
|
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
3171
3771
|
match_by_name (bool): If True, match image and mask files by base filename.
|
3172
3772
|
If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
|
3773
|
+
metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
|
3774
|
+
Default is "PASCAL_VOC".
|
3173
3775
|
|
3174
3776
|
Returns:
|
3175
3777
|
Dict[str, Any]: Dictionary containing batch processing statistics
|
3176
3778
|
|
3177
3779
|
Raises:
|
3178
3780
|
ValueError: If no images found, or if masks_folder and masks_file are both specified,
|
3179
|
-
or if
|
3180
|
-
match_by_name=False.
|
3781
|
+
or if counts don't match when using masks_folder with match_by_name=False.
|
3181
3782
|
|
3182
3783
|
Examples:
|
3784
|
+
# Images only (no masks)
|
3785
|
+
>>> stats = export_geotiff_tiles_batch(
|
3786
|
+
... images_folder='data/images',
|
3787
|
+
... output_folder='output/tiles'
|
3788
|
+
... )
|
3789
|
+
|
3183
3790
|
# Single vector file covering all images
|
3184
3791
|
>>> stats = export_geotiff_tiles_batch(
|
3185
3792
|
... images_folder='data/images',
|
@@ -3214,11 +3821,6 @@ def export_geotiff_tiles_batch(
|
|
3214
3821
|
"Cannot specify both masks_folder and masks_file. Please use only one."
|
3215
3822
|
)
|
3216
3823
|
|
3217
|
-
if masks_folder is None and masks_file is None:
|
3218
|
-
raise ValueError(
|
3219
|
-
"Must specify either masks_folder or masks_file for mask data source."
|
3220
|
-
)
|
3221
|
-
|
3222
3824
|
# Default output folder if not specified
|
3223
3825
|
if output_folder is None:
|
3224
3826
|
output_folder = os.path.join(images_folder, "tiles")
|
@@ -3249,9 +3851,37 @@ def export_geotiff_tiles_batch(
|
|
3249
3851
|
# Create output folder structure
|
3250
3852
|
os.makedirs(output_folder, exist_ok=True)
|
3251
3853
|
output_images_dir = os.path.join(output_folder, "images")
|
3252
|
-
output_masks_dir = os.path.join(output_folder, "masks")
|
3253
3854
|
os.makedirs(output_images_dir, exist_ok=True)
|
3254
|
-
|
3855
|
+
|
3856
|
+
# Only create masks directory if masks are provided
|
3857
|
+
output_masks_dir = None
|
3858
|
+
if masks_folder is not None or masks_file is not None:
|
3859
|
+
output_masks_dir = os.path.join(output_folder, "masks")
|
3860
|
+
os.makedirs(output_masks_dir, exist_ok=True)
|
3861
|
+
|
3862
|
+
# Create annotation directory based on metadata format (only if masks are provided)
|
3863
|
+
ann_dir = None
|
3864
|
+
if (masks_folder is not None or masks_file is not None) and metadata_format in [
|
3865
|
+
"PASCAL_VOC",
|
3866
|
+
"COCO",
|
3867
|
+
]:
|
3868
|
+
ann_dir = os.path.join(output_folder, "annotations")
|
3869
|
+
os.makedirs(ann_dir, exist_ok=True)
|
3870
|
+
|
3871
|
+
# Initialize COCO annotations dictionary (only if masks are provided)
|
3872
|
+
coco_annotations = None
|
3873
|
+
if (
|
3874
|
+
masks_folder is not None or masks_file is not None
|
3875
|
+
) and metadata_format == "COCO":
|
3876
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
3877
|
+
|
3878
|
+
# Initialize YOLO class set (only if masks are provided)
|
3879
|
+
yolo_classes = (
|
3880
|
+
set()
|
3881
|
+
if (masks_folder is not None or masks_file is not None)
|
3882
|
+
and metadata_format == "YOLO"
|
3883
|
+
else None
|
3884
|
+
)
|
3255
3885
|
|
3256
3886
|
# Get list of image files
|
3257
3887
|
image_files = []
|
@@ -3269,10 +3899,16 @@ def export_geotiff_tiles_batch(
|
|
3269
3899
|
|
3270
3900
|
# Handle different mask input modes
|
3271
3901
|
use_single_mask_file = masks_file is not None
|
3902
|
+
has_masks = masks_file is not None or masks_folder is not None
|
3272
3903
|
mask_files = []
|
3273
3904
|
image_mask_pairs = []
|
3274
3905
|
|
3275
|
-
if
|
3906
|
+
if not has_masks:
|
3907
|
+
# Mode 0: No masks - create pairs with None for mask
|
3908
|
+
for image_file in image_files:
|
3909
|
+
image_mask_pairs.append((image_file, None, None))
|
3910
|
+
|
3911
|
+
elif use_single_mask_file:
|
3276
3912
|
# Mode 1: Single vector file covering all images
|
3277
3913
|
if not os.path.exists(masks_file):
|
3278
3914
|
raise ValueError(f"Mask file not found: {masks_file}")
|
@@ -3324,10 +3960,21 @@ def export_geotiff_tiles_batch(
|
|
3324
3960
|
print(f"Warning: No mask found for image {img_base}")
|
3325
3961
|
|
3326
3962
|
if not image_mask_pairs:
|
3327
|
-
|
3963
|
+
# Provide detailed error message with found files
|
3964
|
+
image_bases = list(image_dict.keys())
|
3965
|
+
mask_bases = list(mask_dict.keys())
|
3966
|
+
error_msg = (
|
3328
3967
|
"No matching image-mask pairs found when matching by filename. "
|
3329
|
-
"Check that image and mask files have matching base names
|
3968
|
+
"Check that image and mask files have matching base names.\n"
|
3969
|
+
f"Found {len(image_bases)} image(s): "
|
3970
|
+
f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
|
3971
|
+
f"{'...' if len(image_bases) > 5 else ''}\n"
|
3972
|
+
f"Found {len(mask_bases)} mask(s): "
|
3973
|
+
f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
|
3974
|
+
f"{'...' if len(mask_bases) > 5 else ''}\n"
|
3975
|
+
"Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
|
3330
3976
|
)
|
3977
|
+
raise ValueError(error_msg)
|
3331
3978
|
|
3332
3979
|
else:
|
3333
3980
|
# Match by sorted order
|
@@ -3354,7 +4001,11 @@ def export_geotiff_tiles_batch(
|
|
3354
4001
|
}
|
3355
4002
|
|
3356
4003
|
if not quiet:
|
3357
|
-
if
|
4004
|
+
if not has_masks:
|
4005
|
+
print(
|
4006
|
+
f"Found {len(image_files)} image files to process (images only, no masks)"
|
4007
|
+
)
|
4008
|
+
elif use_single_mask_file:
|
3358
4009
|
print(f"Found {len(image_files)} image files to process")
|
3359
4010
|
print(f"Using single mask file: {masks_file}")
|
3360
4011
|
else:
|
@@ -3383,10 +4034,15 @@ def export_geotiff_tiles_batch(
|
|
3383
4034
|
if not quiet:
|
3384
4035
|
print(f"\nProcessing: {base_name}")
|
3385
4036
|
print(f" Image: {os.path.basename(image_file)}")
|
3386
|
-
if
|
3387
|
-
|
4037
|
+
if mask_file is not None:
|
4038
|
+
if use_single_mask_file:
|
4039
|
+
print(
|
4040
|
+
f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
|
4041
|
+
)
|
4042
|
+
else:
|
4043
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
3388
4044
|
else:
|
3389
|
-
print(f" Mask:
|
4045
|
+
print(f" Mask: None (images only)")
|
3390
4046
|
|
3391
4047
|
# Process the image-mask pair
|
3392
4048
|
tiles_generated = _process_image_mask_pair(
|
@@ -3406,6 +4062,13 @@ def export_geotiff_tiles_batch(
|
|
3406
4062
|
quiet=quiet,
|
3407
4063
|
mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
|
3408
4064
|
use_single_mask_file=use_single_mask_file,
|
4065
|
+
metadata_format=metadata_format,
|
4066
|
+
ann_dir=(
|
4067
|
+
ann_dir
|
4068
|
+
if "ann_dir" in locals()
|
4069
|
+
and metadata_format in ["PASCAL_VOC", "COCO"]
|
4070
|
+
else None
|
4071
|
+
),
|
3409
4072
|
)
|
3410
4073
|
|
3411
4074
|
# Update counters
|
@@ -3427,6 +4090,23 @@ def export_geotiff_tiles_batch(
|
|
3427
4090
|
}
|
3428
4091
|
)
|
3429
4092
|
|
4093
|
+
# Aggregate COCO annotations
|
4094
|
+
if metadata_format == "COCO" and "coco_data" in tiles_generated:
|
4095
|
+
coco_data = tiles_generated["coco_data"]
|
4096
|
+
# Add images and annotations
|
4097
|
+
coco_annotations["images"].extend(coco_data.get("images", []))
|
4098
|
+
coco_annotations["annotations"].extend(coco_data.get("annotations", []))
|
4099
|
+
# Merge categories (avoid duplicates)
|
4100
|
+
for cat in coco_data.get("categories", []):
|
4101
|
+
if not any(
|
4102
|
+
c["id"] == cat["id"] for c in coco_annotations["categories"]
|
4103
|
+
):
|
4104
|
+
coco_annotations["categories"].append(cat)
|
4105
|
+
|
4106
|
+
# Aggregate YOLO classes
|
4107
|
+
if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
|
4108
|
+
yolo_classes.update(tiles_generated["yolo_classes"])
|
4109
|
+
|
3430
4110
|
except Exception as e:
|
3431
4111
|
if not quiet:
|
3432
4112
|
print(f"ERROR processing {base_name}: {e}")
|
@@ -3435,6 +4115,33 @@ def export_geotiff_tiles_batch(
|
|
3435
4115
|
)
|
3436
4116
|
batch_stats["errors"] += 1
|
3437
4117
|
|
4118
|
+
# Save aggregated COCO annotations
|
4119
|
+
if metadata_format == "COCO" and coco_annotations:
|
4120
|
+
import json
|
4121
|
+
|
4122
|
+
coco_path = os.path.join(ann_dir, "instances.json")
|
4123
|
+
with open(coco_path, "w") as f:
|
4124
|
+
json.dump(coco_annotations, f, indent=2)
|
4125
|
+
if not quiet:
|
4126
|
+
print(f"\nSaved COCO annotations: {coco_path}")
|
4127
|
+
print(
|
4128
|
+
f" Images: {len(coco_annotations['images'])}, "
|
4129
|
+
f"Annotations: {len(coco_annotations['annotations'])}, "
|
4130
|
+
f"Categories: {len(coco_annotations['categories'])}"
|
4131
|
+
)
|
4132
|
+
|
4133
|
+
# Save aggregated YOLO classes
|
4134
|
+
if metadata_format == "YOLO" and yolo_classes:
|
4135
|
+
classes_path = os.path.join(output_folder, "labels", "classes.txt")
|
4136
|
+
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
|
4137
|
+
sorted_classes = sorted(yolo_classes)
|
4138
|
+
with open(classes_path, "w") as f:
|
4139
|
+
for cls in sorted_classes:
|
4140
|
+
f.write(f"{cls}\n")
|
4141
|
+
if not quiet:
|
4142
|
+
print(f"\nSaved YOLO classes: {classes_path}")
|
4143
|
+
print(f" Total classes: {len(sorted_classes)}")
|
4144
|
+
|
3438
4145
|
# Print batch summary
|
3439
4146
|
if not quiet:
|
3440
4147
|
print("\n" + "=" * 60)
|
@@ -3457,7 +4164,12 @@ def export_geotiff_tiles_batch(
|
|
3457
4164
|
|
3458
4165
|
print(f"Output saved to: {output_folder}")
|
3459
4166
|
print(f" Images: {output_images_dir}")
|
3460
|
-
|
4167
|
+
if output_masks_dir is not None:
|
4168
|
+
print(f" Masks: {output_masks_dir}")
|
4169
|
+
if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
|
4170
|
+
print(f" Annotations: {ann_dir}")
|
4171
|
+
elif metadata_format == "YOLO":
|
4172
|
+
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
3461
4173
|
|
3462
4174
|
# List failed files if any
|
3463
4175
|
if batch_stats["failed_files"]:
|
@@ -3485,6 +4197,8 @@ def _process_image_mask_pair(
|
|
3485
4197
|
quiet=False,
|
3486
4198
|
mask_gdf=None,
|
3487
4199
|
use_single_mask_file=False,
|
4200
|
+
metadata_format="PASCAL_VOC",
|
4201
|
+
ann_dir=None,
|
3488
4202
|
):
|
3489
4203
|
"""
|
3490
4204
|
Process a single image-mask pair and save tiles directly to output directories.
|
@@ -3498,9 +4212,9 @@ def _process_image_mask_pair(
|
|
3498
4212
|
"""
|
3499
4213
|
import warnings
|
3500
4214
|
|
3501
|
-
# Determine if mask data is raster or vector
|
4215
|
+
# Determine if mask data is raster or vector (only if mask_file is provided)
|
3502
4216
|
is_class_data_raster = False
|
3503
|
-
if isinstance(mask_file, str):
|
4217
|
+
if mask_file is not None and isinstance(mask_file, str):
|
3504
4218
|
file_ext = Path(mask_file).suffix.lower()
|
3505
4219
|
# Common raster extensions
|
3506
4220
|
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
@@ -3517,6 +4231,13 @@ def _process_image_mask_pair(
|
|
3517
4231
|
"errors": 0,
|
3518
4232
|
}
|
3519
4233
|
|
4234
|
+
# Initialize COCO/YOLO tracking for this image
|
4235
|
+
if metadata_format == "COCO":
|
4236
|
+
stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
|
4237
|
+
coco_ann_id = 0
|
4238
|
+
if metadata_format == "YOLO":
|
4239
|
+
stats["yolo_classes"] = set()
|
4240
|
+
|
3520
4241
|
# Open the input raster
|
3521
4242
|
with rasterio.open(image_file) as src:
|
3522
4243
|
# Calculate number of tiles
|
@@ -3527,10 +4248,10 @@ def _process_image_mask_pair(
|
|
3527
4248
|
if max_tiles is None:
|
3528
4249
|
max_tiles = total_tiles
|
3529
4250
|
|
3530
|
-
# Process classification data
|
4251
|
+
# Process classification data (only if mask_file is provided)
|
3531
4252
|
class_to_id = {}
|
3532
4253
|
|
3533
|
-
if is_class_data_raster:
|
4254
|
+
if mask_file is not None and is_class_data_raster:
|
3534
4255
|
# Load raster class data
|
3535
4256
|
with rasterio.open(mask_file) as class_src:
|
3536
4257
|
# Check if raster CRS matches
|
@@ -3557,7 +4278,7 @@ def _process_image_mask_pair(
|
|
3557
4278
|
|
3558
4279
|
# Create class mapping
|
3559
4280
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
3560
|
-
|
4281
|
+
elif mask_file is not None:
|
3561
4282
|
# Load vector class data
|
3562
4283
|
try:
|
3563
4284
|
if use_single_mask_file and mask_gdf is not None:
|
@@ -3609,9 +4330,6 @@ def _process_image_mask_pair(
|
|
3609
4330
|
tile_index = 0
|
3610
4331
|
for y in range(num_tiles_y):
|
3611
4332
|
for x in range(num_tiles_x):
|
3612
|
-
if tile_index >= max_tiles:
|
3613
|
-
break
|
3614
|
-
|
3615
4333
|
# Calculate window coordinates
|
3616
4334
|
window_x = x * stride
|
3617
4335
|
window_y = y * stride
|
@@ -3636,12 +4354,12 @@ def _process_image_mask_pair(
|
|
3636
4354
|
|
3637
4355
|
window_bounds = box(minx, miny, maxx, maxy)
|
3638
4356
|
|
3639
|
-
# Create label mask
|
4357
|
+
# Create label mask (only if mask_file is provided)
|
3640
4358
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
3641
4359
|
has_features = False
|
3642
4360
|
|
3643
|
-
# Process classification data to create labels
|
3644
|
-
if is_class_data_raster:
|
4361
|
+
# Process classification data to create labels (only if mask_file is provided)
|
4362
|
+
if mask_file is not None and is_class_data_raster:
|
3645
4363
|
# For raster class data
|
3646
4364
|
with rasterio.open(mask_file) as class_src:
|
3647
4365
|
# Get corresponding window in class raster
|
@@ -3674,7 +4392,7 @@ def _process_image_mask_pair(
|
|
3674
4392
|
if not quiet:
|
3675
4393
|
print(f"Error reading class raster window: {e}")
|
3676
4394
|
stats["errors"] += 1
|
3677
|
-
|
4395
|
+
elif mask_file is not None:
|
3678
4396
|
# For vector class data
|
3679
4397
|
# Find features that intersect with window
|
3680
4398
|
window_features = gdf[gdf.intersects(window_bounds)]
|
@@ -3712,11 +4430,14 @@ def _process_image_mask_pair(
|
|
3712
4430
|
print(f"Error rasterizing feature {idx}: {e}")
|
3713
4431
|
stats["errors"] += 1
|
3714
4432
|
|
3715
|
-
# Skip tile if no features and skip_empty_tiles is True
|
3716
|
-
if skip_empty_tiles and not has_features:
|
3717
|
-
tile_index += 1
|
4433
|
+
# Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
|
4434
|
+
if mask_file is not None and skip_empty_tiles and not has_features:
|
3718
4435
|
continue
|
3719
4436
|
|
4437
|
+
# Check if we've reached max_tiles before saving
|
4438
|
+
if tile_index >= max_tiles:
|
4439
|
+
break
|
4440
|
+
|
3720
4441
|
# Generate unique tile name
|
3721
4442
|
tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
|
3722
4443
|
|
@@ -3747,29 +4468,225 @@ def _process_image_mask_pair(
|
|
3747
4468
|
print(f"ERROR saving image GeoTIFF: {e}")
|
3748
4469
|
stats["errors"] += 1
|
3749
4470
|
|
3750
|
-
#
|
3751
|
-
|
3752
|
-
|
3753
|
-
|
3754
|
-
|
3755
|
-
|
3756
|
-
|
3757
|
-
|
3758
|
-
|
3759
|
-
|
4471
|
+
# Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
|
4472
|
+
if mask_file is not None and output_masks_dir is not None:
|
4473
|
+
# Create profile for label GeoTIFF
|
4474
|
+
label_profile = {
|
4475
|
+
"driver": "GTiff",
|
4476
|
+
"height": tile_size,
|
4477
|
+
"width": tile_size,
|
4478
|
+
"count": 1,
|
4479
|
+
"dtype": "uint8",
|
4480
|
+
"crs": src.crs,
|
4481
|
+
"transform": window_transform,
|
4482
|
+
}
|
3760
4483
|
|
3761
|
-
|
3762
|
-
|
3763
|
-
|
3764
|
-
|
3765
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
4484
|
+
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
4485
|
+
try:
|
4486
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
4487
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
3766
4488
|
|
3767
|
-
|
3768
|
-
|
3769
|
-
|
3770
|
-
|
3771
|
-
|
3772
|
-
|
4489
|
+
if has_features:
|
4490
|
+
stats["tiles_with_features"] += 1
|
4491
|
+
except Exception as e:
|
4492
|
+
if not quiet:
|
4493
|
+
print(f"ERROR saving label GeoTIFF: {e}")
|
4494
|
+
stats["errors"] += 1
|
4495
|
+
|
4496
|
+
# Generate annotation metadata based on format (only if mask_file is provided)
|
4497
|
+
if (
|
4498
|
+
mask_file is not None
|
4499
|
+
and metadata_format == "PASCAL_VOC"
|
4500
|
+
and ann_dir
|
4501
|
+
):
|
4502
|
+
# Create PASCAL VOC XML annotation
|
4503
|
+
from lxml import etree as ET
|
4504
|
+
|
4505
|
+
annotation = ET.Element("annotation")
|
4506
|
+
ET.SubElement(annotation, "folder").text = os.path.basename(
|
4507
|
+
output_images_dir
|
4508
|
+
)
|
4509
|
+
ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
|
4510
|
+
ET.SubElement(annotation, "path").text = image_path
|
4511
|
+
|
4512
|
+
source = ET.SubElement(annotation, "source")
|
4513
|
+
ET.SubElement(source, "database").text = "GeoAI"
|
4514
|
+
|
4515
|
+
size = ET.SubElement(annotation, "size")
|
4516
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
4517
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
4518
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
4519
|
+
|
4520
|
+
ET.SubElement(annotation, "segmented").text = "1"
|
4521
|
+
|
4522
|
+
# Find connected components for instance segmentation
|
4523
|
+
from scipy import ndimage
|
4524
|
+
|
4525
|
+
for class_id in np.unique(label_mask):
|
4526
|
+
if class_id == 0:
|
4527
|
+
continue
|
4528
|
+
|
4529
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4530
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4531
|
+
|
4532
|
+
for instance_id in range(1, num_features + 1):
|
4533
|
+
instance_mask = labeled_array == instance_id
|
4534
|
+
coords = np.argwhere(instance_mask)
|
4535
|
+
|
4536
|
+
if len(coords) == 0:
|
4537
|
+
continue
|
4538
|
+
|
4539
|
+
ymin, xmin = coords.min(axis=0)
|
4540
|
+
ymax, xmax = coords.max(axis=0)
|
4541
|
+
|
4542
|
+
obj = ET.SubElement(annotation, "object")
|
4543
|
+
class_name = next(
|
4544
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
4545
|
+
str(class_id),
|
4546
|
+
)
|
4547
|
+
ET.SubElement(obj, "name").text = str(class_name)
|
4548
|
+
ET.SubElement(obj, "pose").text = "Unspecified"
|
4549
|
+
ET.SubElement(obj, "truncated").text = "0"
|
4550
|
+
ET.SubElement(obj, "difficult").text = "0"
|
4551
|
+
|
4552
|
+
bndbox = ET.SubElement(obj, "bndbox")
|
4553
|
+
ET.SubElement(bndbox, "xmin").text = str(int(xmin))
|
4554
|
+
ET.SubElement(bndbox, "ymin").text = str(int(ymin))
|
4555
|
+
ET.SubElement(bndbox, "xmax").text = str(int(xmax))
|
4556
|
+
ET.SubElement(bndbox, "ymax").text = str(int(ymax))
|
4557
|
+
|
4558
|
+
# Save XML file
|
4559
|
+
xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
|
4560
|
+
tree = ET.ElementTree(annotation)
|
4561
|
+
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
4562
|
+
|
4563
|
+
elif mask_file is not None and metadata_format == "COCO":
|
4564
|
+
# Add COCO image entry
|
4565
|
+
image_id = int(global_tile_counter + tile_index)
|
4566
|
+
stats["coco_data"]["images"].append(
|
4567
|
+
{
|
4568
|
+
"id": image_id,
|
4569
|
+
"file_name": f"{tile_name}.tif",
|
4570
|
+
"width": int(tile_size),
|
4571
|
+
"height": int(tile_size),
|
4572
|
+
}
|
4573
|
+
)
|
4574
|
+
|
4575
|
+
# Add COCO categories (only once per unique class)
|
4576
|
+
for class_val, class_id in class_to_id.items():
|
4577
|
+
if not any(
|
4578
|
+
c["id"] == class_id
|
4579
|
+
for c in stats["coco_data"]["categories"]
|
4580
|
+
):
|
4581
|
+
stats["coco_data"]["categories"].append(
|
4582
|
+
{
|
4583
|
+
"id": int(class_id),
|
4584
|
+
"name": str(class_val),
|
4585
|
+
"supercategory": "object",
|
4586
|
+
}
|
4587
|
+
)
|
4588
|
+
|
4589
|
+
# Add COCO annotations (instance segmentation)
|
4590
|
+
from scipy import ndimage
|
4591
|
+
from skimage import measure
|
4592
|
+
|
4593
|
+
for class_id in np.unique(label_mask):
|
4594
|
+
if class_id == 0:
|
4595
|
+
continue
|
4596
|
+
|
4597
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4598
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4599
|
+
|
4600
|
+
for instance_id in range(1, num_features + 1):
|
4601
|
+
instance_mask = (labeled_array == instance_id).astype(
|
4602
|
+
np.uint8
|
4603
|
+
)
|
4604
|
+
coords = np.argwhere(instance_mask)
|
4605
|
+
|
4606
|
+
if len(coords) == 0:
|
4607
|
+
continue
|
4608
|
+
|
4609
|
+
ymin, xmin = coords.min(axis=0)
|
4610
|
+
ymax, xmax = coords.max(axis=0)
|
4611
|
+
|
4612
|
+
bbox = [
|
4613
|
+
int(xmin),
|
4614
|
+
int(ymin),
|
4615
|
+
int(xmax - xmin),
|
4616
|
+
int(ymax - ymin),
|
4617
|
+
]
|
4618
|
+
area = int(np.sum(instance_mask))
|
4619
|
+
|
4620
|
+
# Find contours for segmentation
|
4621
|
+
contours = measure.find_contours(instance_mask, 0.5)
|
4622
|
+
segmentation = []
|
4623
|
+
for contour in contours:
|
4624
|
+
contour = np.flip(contour, axis=1)
|
4625
|
+
segmentation_points = contour.ravel().tolist()
|
4626
|
+
if len(segmentation_points) >= 6:
|
4627
|
+
segmentation.append(segmentation_points)
|
4628
|
+
|
4629
|
+
if segmentation:
|
4630
|
+
stats["coco_data"]["annotations"].append(
|
4631
|
+
{
|
4632
|
+
"id": int(coco_ann_id),
|
4633
|
+
"image_id": int(image_id),
|
4634
|
+
"category_id": int(class_id),
|
4635
|
+
"bbox": bbox,
|
4636
|
+
"area": area,
|
4637
|
+
"segmentation": segmentation,
|
4638
|
+
"iscrowd": 0,
|
4639
|
+
}
|
4640
|
+
)
|
4641
|
+
coco_ann_id += 1
|
4642
|
+
|
4643
|
+
elif mask_file is not None and metadata_format == "YOLO":
|
4644
|
+
# Create YOLO labels directory if needed
|
4645
|
+
labels_dir = os.path.join(
|
4646
|
+
os.path.dirname(output_images_dir), "labels"
|
4647
|
+
)
|
4648
|
+
os.makedirs(labels_dir, exist_ok=True)
|
4649
|
+
|
4650
|
+
# Generate YOLO annotation file
|
4651
|
+
yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
|
4652
|
+
from scipy import ndimage
|
4653
|
+
|
4654
|
+
with open(yolo_path, "w") as yolo_file:
|
4655
|
+
for class_id in np.unique(label_mask):
|
4656
|
+
if class_id == 0:
|
4657
|
+
continue
|
4658
|
+
|
4659
|
+
# Track class for classes.txt
|
4660
|
+
class_name = next(
|
4661
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
4662
|
+
str(class_id),
|
4663
|
+
)
|
4664
|
+
stats["yolo_classes"].add(class_name)
|
4665
|
+
|
4666
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4667
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4668
|
+
|
4669
|
+
for instance_id in range(1, num_features + 1):
|
4670
|
+
instance_mask = labeled_array == instance_id
|
4671
|
+
coords = np.argwhere(instance_mask)
|
4672
|
+
|
4673
|
+
if len(coords) == 0:
|
4674
|
+
continue
|
4675
|
+
|
4676
|
+
ymin, xmin = coords.min(axis=0)
|
4677
|
+
ymax, xmax = coords.max(axis=0)
|
4678
|
+
|
4679
|
+
# Convert to YOLO format (normalized center coordinates)
|
4680
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
4681
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
4682
|
+
width = (xmax - xmin) / tile_size
|
4683
|
+
height = (ymax - ymin) / tile_size
|
4684
|
+
|
4685
|
+
# YOLO uses 0-based class indices
|
4686
|
+
yolo_class_id = class_id - 1
|
4687
|
+
yolo_file.write(
|
4688
|
+
f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
4689
|
+
)
|
3773
4690
|
|
3774
4691
|
tile_index += 1
|
3775
4692
|
if tile_index >= max_tiles:
|
@@ -3781,6 +4698,179 @@ def _process_image_mask_pair(
|
|
3781
4698
|
return stats
|
3782
4699
|
|
3783
4700
|
|
4701
|
+
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    The first ``num_tiles`` files (in sorted name order) found in
    ``<output_dir>/images`` are drawn in the top row. The file with the same
    name in ``<output_dir>/masks`` is drawn beneath each image; if it does not
    exist, a "Mask not found" placeholder is shown instead.

    Hidden files, sub-directories and GDAL sidecar files (``.aux.xml``,
    ``.ovr``) inside the images directory are ignored, because they are not
    tiles and cannot be opened as rasters.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Number of tile pairs to display (default: 6). Must be at least 1.
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and axes objects. ``axes`` is a
        2-row array indexed as ``axes[row, column]`` (row 0 = images,
        row 1 = masks), even when a single tile is displayed.

    Raises:
        ValueError: If ``num_tiles`` is less than 1, the images directory does
            not exist, or it contains no tile files.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    if num_tiles < 1:
        raise ValueError(f"num_tiles must be at least 1, got {num_tiles}")

    images_dir = os.path.join(output_dir, "images")
    masks_dir = os.path.join(output_dir, "masks")
    if not os.path.isdir(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Auxiliary files GDAL may write next to rasters; they are not tiles and
    # would make rasterio.open fail.
    sidecar_suffixes = (".aux.xml", ".ovr")
    image_tiles = sorted(
        name
        for name in os.listdir(images_dir)
        if not name.startswith(".")
        and not name.lower().endswith(sidecar_suffixes)
        and os.path.isfile(os.path.join(images_dir, name))
    )[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Fewer tiles may be available than were requested.
    num_tiles = len(image_tiles)

    # Imported lazily so argument validation does not require matplotlib.
    import matplotlib.pyplot as plt

    # squeeze=False keeps axes 2-D with shape (2, num_tiles), even for one tile.
    fig, axes = plt.subplots(2, num_tiles, figsize=figsize, squeeze=False)

    try:
        for idx, tile_name in enumerate(image_tiles):
            image_ax = axes[0, idx]
            mask_ax = axes[1, idx]

            # Top row: the image tile.
            with rasterio.open(os.path.join(images_dir, tile_name)) as src:
                show(src, ax=image_ax, title=f"Image {idx+1}")

            # Bottom row: the mask tile with the same file name, if present.
            mask_path = os.path.join(masks_dir, tile_name)
            if os.path.exists(mask_path):
                with rasterio.open(mask_path) as src:
                    show(src, ax=mask_ax, title=f"Mask {idx+1}", cmap=cmap)
            else:
                mask_ax.text(
                    0.5,
                    0.5,
                    "Mask not found",
                    ha="center",
                    va="center",
                    transform=mask_ax.transAxes,
                )
                mask_ax.set_title(f"Mask {idx+1}")
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
|
4781
|
+
|
4782
|
+
|
4783
|
+
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Display a raster image alongside the same image with vector overlay.

    The left panel shows the raster alone. The right panel shows the raster
    with the vector features drawn on top. When both datasets define a CRS
    and they differ, the vector data is reprojected to the image CRS before
    plotting. When only one of them defines a CRS, reprojection is impossible:
    a warning is issued and the features are drawn as-is.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector
        metadata (``image_shape``, ``image_crs``, ``image_bounds``,
        ``num_features``, ``vector_crs``, ``vector_bounds``). The vector values
        describe the data after any reprojection.

    Raises:
        ValueError: If ``image_path`` or ``vector_path`` does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    # Validate inputs before importing matplotlib or opening any file.
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    import warnings

    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    try:
        with rasterio.open(image_path) as src:
            # Left panel: image only.
            show(src, ax=ax1, title="Image")

            vector_data = gpd.read_file(vector_path)

            # An undefined CRS may be None or an empty (falsy) CRS object.
            has_image_crs = bool(src.crs)
            has_vector_crs = bool(vector_data.crs)
            if has_image_crs and has_vector_crs:
                if vector_data.crs != src.crs:
                    vector_data = vector_data.to_crs(src.crs)
            elif has_image_crs != has_vector_crs:
                # to_crs() cannot reproject from or to an undefined CRS.
                warnings.warn(
                    "Only one of the image and the vector data defines a CRS; "
                    "skipping reprojection, so the overlay may be misaligned."
                )

            # Right panel: image with vector overlay.
            show(
                src,
                ax=ax2,
                title=f"Image with {len(vector_data)} Vector Features",
            )
            vector_data.plot(
                ax=ax2,
                facecolor=vector_facecolor,
                edgecolor=vector_color,
                linewidth=vector_linewidth,
            )

            # Collect metadata while the dataset is still open.
            info = {
                "image_shape": src.shape,
                "image_crs": src.crs,
                "image_bounds": src.bounds,
                "num_features": len(vector_data),
                "vector_crs": vector_data.crs,
                "vector_bounds": vector_data.total_bounds,
            }
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax1, ax2), info
|
4872
|
+
|
4873
|
+
|
3784
4874
|
def create_overview_image(
|
3785
4875
|
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
3786
4876
|
) -> str:
|
@@ -7675,17 +8765,39 @@ def write_colormap(
|
|
7675
8765
|
|
7676
8766
|
def plot_performance_metrics(
|
7677
8767
|
history_path: str,
|
7678
|
-
figsize: Tuple[int, int] =
|
8768
|
+
figsize: Optional[Tuple[int, int]] = None,
|
7679
8769
|
verbose: bool = True,
|
7680
8770
|
save_path: Optional[str] = None,
|
8771
|
+
csv_path: Optional[str] = None,
|
7681
8772
|
kwargs: Optional[Dict] = None,
|
7682
|
-
) ->
|
7683
|
-
"""Plot performance metrics from a history object.
|
8773
|
+
) -> pd.DataFrame:
|
8774
|
+
"""Plot performance metrics from a training history object and return as DataFrame.
|
8775
|
+
|
8776
|
+
This function loads training history, plots available metrics (loss, IoU, F1,
|
8777
|
+
precision, recall), optionally exports to CSV, and returns all metrics as a
|
8778
|
+
pandas DataFrame for further analysis.
|
7684
8779
|
|
7685
8780
|
Args:
|
7686
|
-
history_path:
|
7687
|
-
figsize:
|
7688
|
-
|
8781
|
+
history_path (str): Path to the saved training history (.pth file).
|
8782
|
+
figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
|
8783
|
+
automatically determined based on number of metrics.
|
8784
|
+
verbose (bool): Whether to print best and final metric values. Defaults to True.
|
8785
|
+
save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
|
8786
|
+
csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
|
8787
|
+
kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
|
8788
|
+
|
8789
|
+
Returns:
|
8790
|
+
pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
|
8791
|
+
Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
|
8792
|
+
'val_precision', 'val_recall' (depending on availability in history).
|
8793
|
+
|
8794
|
+
Example:
|
8795
|
+
>>> df = plot_performance_metrics(
|
8796
|
+
... 'training_history.pth',
|
8797
|
+
... save_path='metrics_plot.png',
|
8798
|
+
... csv_path='metrics.csv'
|
8799
|
+
... )
|
8800
|
+
>>> print(df.head())
|
7689
8801
|
"""
|
7690
8802
|
if kwargs is None:
|
7691
8803
|
kwargs = {}
|
@@ -7695,65 +8807,135 @@ def plot_performance_metrics(
|
|
7695
8807
|
train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
|
7696
8808
|
val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
|
7697
8809
|
val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
|
7698
|
-
|
8810
|
+
# Support both new (f1) and old (dice) key formats for backward compatibility
|
8811
|
+
val_f1_key = (
|
8812
|
+
"val_f1s"
|
8813
|
+
if "val_f1s" in history
|
8814
|
+
else ("val_dices" if "val_dices" in history else "val_dice")
|
8815
|
+
)
|
8816
|
+
# Add support for precision and recall
|
8817
|
+
val_precision_key = (
|
8818
|
+
"val_precisions" if "val_precisions" in history else "val_precision"
|
8819
|
+
)
|
8820
|
+
val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
|
8821
|
+
|
8822
|
+
# Collect available metrics for plotting
|
8823
|
+
available_metrics = []
|
8824
|
+
metric_info = {
|
8825
|
+
"Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
|
8826
|
+
"IoU": (val_iou_key, None, ["Val IoU"]),
|
8827
|
+
"F1": (val_f1_key, None, ["Val F1"]),
|
8828
|
+
"Precision": (val_precision_key, None, ["Val Precision"]),
|
8829
|
+
"Recall": (val_recall_key, None, ["Val Recall"]),
|
8830
|
+
}
|
8831
|
+
|
8832
|
+
for metric_name, (key1, key2, labels) in metric_info.items():
|
8833
|
+
if key1 in history or (key2 and key2 in history):
|
8834
|
+
available_metrics.append((metric_name, key1, key2, labels))
|
7699
8835
|
|
7700
|
-
# Determine number of subplots
|
7701
|
-
|
7702
|
-
|
7703
|
-
|
8836
|
+
# Determine number of subplots and figure size
|
8837
|
+
n_plots = len(available_metrics)
|
8838
|
+
if figsize is None:
|
8839
|
+
figsize = (5 * n_plots, 5)
|
7704
8840
|
|
7705
|
-
|
8841
|
+
# Create DataFrame for all metrics
|
8842
|
+
n_epochs = 0
|
8843
|
+
df_data = {}
|
7706
8844
|
|
7707
|
-
#
|
7708
|
-
|
8845
|
+
# Add epochs
|
8846
|
+
if "epochs" in history:
|
8847
|
+
df_data["epoch"] = history["epochs"]
|
8848
|
+
n_epochs = len(history["epochs"])
|
8849
|
+
elif train_loss_key in history:
|
8850
|
+
n_epochs = len(history[train_loss_key])
|
8851
|
+
df_data["epoch"] = list(range(1, n_epochs + 1))
|
8852
|
+
|
8853
|
+
# Add all available metrics to DataFrame
|
7709
8854
|
if train_loss_key in history:
|
7710
|
-
|
8855
|
+
df_data["train_loss"] = history[train_loss_key]
|
7711
8856
|
if val_loss_key in history:
|
7712
|
-
|
7713
|
-
plt.title("Loss")
|
7714
|
-
plt.xlabel("Epoch")
|
7715
|
-
plt.ylabel("Loss")
|
7716
|
-
plt.legend()
|
7717
|
-
plt.grid(True)
|
7718
|
-
|
7719
|
-
# Plot IoU
|
7720
|
-
plt.subplot(1, n_plots, 2)
|
8857
|
+
df_data["val_loss"] = history[val_loss_key]
|
7721
8858
|
if val_iou_key in history:
|
7722
|
-
|
7723
|
-
|
7724
|
-
|
7725
|
-
|
7726
|
-
|
7727
|
-
|
7728
|
-
|
7729
|
-
|
7730
|
-
|
7731
|
-
|
7732
|
-
|
7733
|
-
|
7734
|
-
|
7735
|
-
|
7736
|
-
|
7737
|
-
|
8859
|
+
df_data["val_iou"] = history[val_iou_key]
|
8860
|
+
if val_f1_key in history:
|
8861
|
+
df_data["val_f1"] = history[val_f1_key]
|
8862
|
+
if val_precision_key in history:
|
8863
|
+
df_data["val_precision"] = history[val_precision_key]
|
8864
|
+
if val_recall_key in history:
|
8865
|
+
df_data["val_recall"] = history[val_recall_key]
|
8866
|
+
|
8867
|
+
# Create DataFrame
|
8868
|
+
df = pd.DataFrame(df_data)
|
8869
|
+
|
8870
|
+
# Export to CSV if requested
|
8871
|
+
if csv_path:
|
8872
|
+
df.to_csv(csv_path, index=False)
|
8873
|
+
if verbose:
|
8874
|
+
print(f"Metrics exported to: {csv_path}")
|
8875
|
+
|
8876
|
+
# Create plots
|
8877
|
+
if n_plots > 0:
|
8878
|
+
fig, axes = plt.subplots(1, n_plots, figsize=figsize)
|
8879
|
+
if n_plots == 1:
|
8880
|
+
axes = [axes]
|
8881
|
+
|
8882
|
+
for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
|
8883
|
+
ax = axes[idx]
|
8884
|
+
|
8885
|
+
if metric_name == "Loss":
|
8886
|
+
# Special handling for loss (has both train and val)
|
8887
|
+
if key1 in history:
|
8888
|
+
ax.plot(history[key1], label=labels[0])
|
8889
|
+
if key2 and key2 in history:
|
8890
|
+
ax.plot(history[key2], label=labels[1])
|
8891
|
+
else:
|
8892
|
+
# Single metric plots
|
8893
|
+
if key1 in history:
|
8894
|
+
ax.plot(history[key1], label=labels[0])
|
7738
8895
|
|
7739
|
-
|
8896
|
+
ax.set_title(metric_name)
|
8897
|
+
ax.set_xlabel("Epoch")
|
8898
|
+
ax.set_ylabel(metric_name)
|
8899
|
+
ax.legend()
|
8900
|
+
ax.grid(True)
|
7740
8901
|
|
7741
|
-
|
7742
|
-
if "dpi" not in kwargs:
|
7743
|
-
kwargs["dpi"] = 150
|
7744
|
-
if "bbox_inches" not in kwargs:
|
7745
|
-
kwargs["bbox_inches"] = "tight"
|
7746
|
-
plt.savefig(save_path, **kwargs)
|
8902
|
+
plt.tight_layout()
|
7747
8903
|
|
7748
|
-
|
8904
|
+
if save_path:
|
8905
|
+
if "dpi" not in kwargs:
|
8906
|
+
kwargs["dpi"] = 150
|
8907
|
+
if "bbox_inches" not in kwargs:
|
8908
|
+
kwargs["bbox_inches"] = "tight"
|
8909
|
+
plt.savefig(save_path, **kwargs)
|
7749
8910
|
|
8911
|
+
plt.show()
|
8912
|
+
|
8913
|
+
# Print summary statistics
|
7750
8914
|
if verbose:
|
8915
|
+
print("\n=== Performance Metrics Summary ===")
|
7751
8916
|
if val_iou_key in history:
|
7752
|
-
print(
|
7753
|
-
|
7754
|
-
|
7755
|
-
|
7756
|
-
print(
|
8917
|
+
print(
|
8918
|
+
f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
|
8919
|
+
)
|
8920
|
+
if val_f1_key in history:
|
8921
|
+
print(
|
8922
|
+
f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
|
8923
|
+
)
|
8924
|
+
if val_precision_key in history:
|
8925
|
+
print(
|
8926
|
+
f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
|
8927
|
+
)
|
8928
|
+
if val_recall_key in history:
|
8929
|
+
print(
|
8930
|
+
f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
|
8931
|
+
)
|
8932
|
+
if val_loss_key in history:
|
8933
|
+
print(
|
8934
|
+
f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
|
8935
|
+
)
|
8936
|
+
print("===================================\n")
|
8937
|
+
|
8938
|
+
return df
|
7757
8939
|
|
7758
8940
|
|
7759
8941
|
def get_device() -> torch.device:
|