geoai-py 0.24.0__py2.py3-none-any.whl → 0.26.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +16 -1
- geoai/auto.py +0 -1
- geoai/change_detection.py +1 -1
- geoai/moondream.py +748 -1
- geoai/prithvi.py +1253 -0
- geoai/utils.py +1877 -327
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/METADATA +3 -2
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/RECORD +12 -11
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/WHEEL +1 -1
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.24.0.dist-info → geoai_py-0.26.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
|
@@ -64,7 +64,7 @@ def view_raster(
|
|
|
64
64
|
client_args: Optional[Dict] = {"cors_all": False},
|
|
65
65
|
basemap: Optional[str] = "OpenStreetMap",
|
|
66
66
|
basemap_args: Optional[Dict] = None,
|
|
67
|
-
backend: Optional[str] = "
|
|
67
|
+
backend: Optional[str] = "folium",
|
|
68
68
|
**kwargs: Any,
|
|
69
69
|
) -> Any:
|
|
70
70
|
"""
|
|
@@ -87,7 +87,7 @@ def view_raster(
|
|
|
87
87
|
client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
|
|
88
88
|
basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
|
|
89
89
|
basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
|
|
90
|
-
backend (Optional[str], optional): The backend to use. Defaults to "
|
|
90
|
+
backend (Optional[str], optional): The backend to use. Defaults to "folium".
|
|
91
91
|
**kwargs (Any): Additional keyword arguments.
|
|
92
92
|
|
|
93
93
|
Returns:
|
|
@@ -123,26 +123,39 @@ def view_raster(
|
|
|
123
123
|
if isinstance(source, dict):
|
|
124
124
|
source = dict_to_image(source)
|
|
125
125
|
|
|
126
|
-
if (
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
126
|
+
if isinstance(source, str) and source.startswith("http"):
|
|
127
|
+
if backend == "folium":
|
|
128
|
+
|
|
129
|
+
m.add_geotiff(
|
|
130
|
+
url=source,
|
|
131
|
+
name=layer_name,
|
|
132
|
+
opacity=opacity,
|
|
133
|
+
attribution=attribution,
|
|
134
|
+
fit_bounds=zoom_to_layer,
|
|
135
|
+
palette=colormap,
|
|
136
|
+
vmin=vmin,
|
|
137
|
+
vmax=vmax,
|
|
138
|
+
**kwargs,
|
|
139
|
+
)
|
|
140
|
+
m.add_layer_control()
|
|
141
|
+
m.add_opacity_control()
|
|
142
|
+
|
|
143
|
+
else:
|
|
144
|
+
if indexes is not None:
|
|
145
|
+
kwargs["bidx"] = indexes
|
|
146
|
+
if colormap is not None:
|
|
147
|
+
kwargs["colormap_name"] = colormap
|
|
148
|
+
if attribution is None:
|
|
149
|
+
attribution = "TiTiler"
|
|
150
|
+
|
|
151
|
+
m.add_cog_layer(
|
|
152
|
+
source,
|
|
153
|
+
name=layer_name,
|
|
154
|
+
opacity=opacity,
|
|
155
|
+
attribution=attribution,
|
|
156
|
+
zoom_to_layer=zoom_to_layer,
|
|
157
|
+
**kwargs,
|
|
158
|
+
)
|
|
146
159
|
else:
|
|
147
160
|
m.add_raster(
|
|
148
161
|
source=source,
|
|
@@ -381,6 +394,394 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
|
|
|
381
394
|
return accum_mean / len(files), accum_std / len(files)
|
|
382
395
|
|
|
383
396
|
|
|
397
|
+
def calc_iou(
|
|
398
|
+
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
399
|
+
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
400
|
+
num_classes: Optional[int] = None,
|
|
401
|
+
ignore_index: Optional[int] = None,
|
|
402
|
+
smooth: float = 1e-6,
|
|
403
|
+
band: int = 1,
|
|
404
|
+
) -> Union[float, np.ndarray]:
|
|
405
|
+
"""
|
|
406
|
+
Calculate Intersection over Union (IoU) between ground truth and prediction masks.
|
|
407
|
+
|
|
408
|
+
This function computes the IoU metric for segmentation tasks. It supports both
|
|
409
|
+
binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
410
|
+
or file paths to raster files.
|
|
411
|
+
|
|
412
|
+
Args:
|
|
413
|
+
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
414
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
415
|
+
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
416
|
+
For multi-class segmentation: shape (H, W) with class indices.
|
|
417
|
+
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
418
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
419
|
+
Should have the same shape and format as ground_truth.
|
|
420
|
+
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
421
|
+
If None, assumes binary segmentation. Defaults to None.
|
|
422
|
+
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
423
|
+
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
424
|
+
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
425
|
+
Defaults to 1e-6.
|
|
426
|
+
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
427
|
+
Only used when input is a file path. Defaults to 1.
|
|
428
|
+
|
|
429
|
+
Returns:
|
|
430
|
+
Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
|
|
431
|
+
For multi-class segmentation, returns an array of IoU scores for each class.
|
|
432
|
+
|
|
433
|
+
Examples:
|
|
434
|
+
>>> # Binary segmentation with arrays
|
|
435
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
436
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
437
|
+
>>> iou = calc_iou(gt, pred)
|
|
438
|
+
>>> print(f"IoU: {iou:.4f}")
|
|
439
|
+
IoU: 0.8333
|
|
440
|
+
|
|
441
|
+
>>> # Multi-class segmentation
|
|
442
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
443
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
444
|
+
>>> iou = calc_iou(gt, pred, num_classes=3)
|
|
445
|
+
>>> print(f"IoU per class: {iou}")
|
|
446
|
+
IoU per class: [0.8333 0.5000 0.5000]
|
|
447
|
+
|
|
448
|
+
>>> # Using PyTorch tensors
|
|
449
|
+
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
450
|
+
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
451
|
+
>>> iou = calc_iou(gt_tensor, pred_tensor)
|
|
452
|
+
>>> print(f"IoU: {iou:.4f}")
|
|
453
|
+
IoU: 0.8333
|
|
454
|
+
|
|
455
|
+
>>> # Using raster file paths
|
|
456
|
+
>>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
457
|
+
>>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
|
|
458
|
+
Mean IoU: 0.7500
|
|
459
|
+
"""
|
|
460
|
+
# Load from file if string path is provided
|
|
461
|
+
if isinstance(ground_truth, str):
|
|
462
|
+
with rasterio.open(ground_truth) as src:
|
|
463
|
+
ground_truth = src.read(band)
|
|
464
|
+
if isinstance(prediction, str):
|
|
465
|
+
with rasterio.open(prediction) as src:
|
|
466
|
+
prediction = src.read(band)
|
|
467
|
+
|
|
468
|
+
# Convert to numpy if torch tensor
|
|
469
|
+
if isinstance(ground_truth, torch.Tensor):
|
|
470
|
+
ground_truth = ground_truth.cpu().numpy()
|
|
471
|
+
if isinstance(prediction, torch.Tensor):
|
|
472
|
+
prediction = prediction.cpu().numpy()
|
|
473
|
+
|
|
474
|
+
# Ensure inputs have the same shape
|
|
475
|
+
if ground_truth.shape != prediction.shape:
|
|
476
|
+
raise ValueError(
|
|
477
|
+
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
# Binary segmentation
|
|
481
|
+
if num_classes is None:
|
|
482
|
+
ground_truth = ground_truth.astype(bool)
|
|
483
|
+
prediction = prediction.astype(bool)
|
|
484
|
+
|
|
485
|
+
intersection = np.logical_and(ground_truth, prediction).sum()
|
|
486
|
+
union = np.logical_or(ground_truth, prediction).sum()
|
|
487
|
+
|
|
488
|
+
if union == 0:
|
|
489
|
+
return 1.0 if intersection == 0 else 0.0
|
|
490
|
+
|
|
491
|
+
iou = (intersection + smooth) / (union + smooth)
|
|
492
|
+
return float(iou)
|
|
493
|
+
|
|
494
|
+
# Multi-class segmentation
|
|
495
|
+
else:
|
|
496
|
+
iou_per_class = []
|
|
497
|
+
|
|
498
|
+
for class_idx in range(num_classes):
|
|
499
|
+
# Handle ignored class by appending np.nan
|
|
500
|
+
if ignore_index is not None and class_idx == ignore_index:
|
|
501
|
+
iou_per_class.append(np.nan)
|
|
502
|
+
continue
|
|
503
|
+
|
|
504
|
+
# Create binary masks for current class
|
|
505
|
+
gt_class = (ground_truth == class_idx).astype(bool)
|
|
506
|
+
pred_class = (prediction == class_idx).astype(bool)
|
|
507
|
+
|
|
508
|
+
intersection = np.logical_and(gt_class, pred_class).sum()
|
|
509
|
+
union = np.logical_or(gt_class, pred_class).sum()
|
|
510
|
+
|
|
511
|
+
if union == 0:
|
|
512
|
+
# If class is not present in both gt and pred
|
|
513
|
+
iou_per_class.append(np.nan)
|
|
514
|
+
else:
|
|
515
|
+
iou_per_class.append((intersection + smooth) / (union + smooth))
|
|
516
|
+
|
|
517
|
+
return np.array(iou_per_class)
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def calc_f1_score(
|
|
521
|
+
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
522
|
+
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
523
|
+
num_classes: Optional[int] = None,
|
|
524
|
+
ignore_index: Optional[int] = None,
|
|
525
|
+
smooth: float = 1e-6,
|
|
526
|
+
band: int = 1,
|
|
527
|
+
) -> Union[float, np.ndarray]:
|
|
528
|
+
"""
|
|
529
|
+
Calculate F1 score between ground truth and prediction masks.
|
|
530
|
+
|
|
531
|
+
The F1 score is the harmonic mean of precision and recall, computed as:
|
|
532
|
+
F1 = 2 * (precision * recall) / (precision + recall)
|
|
533
|
+
where precision = TP / (TP + FP) and recall = TP / (TP + FN).
|
|
534
|
+
|
|
535
|
+
This function supports both binary and multi-class segmentation, and can handle
|
|
536
|
+
numpy arrays, PyTorch tensors, or file paths to raster files.
|
|
537
|
+
|
|
538
|
+
Args:
|
|
539
|
+
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
540
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
541
|
+
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
542
|
+
For multi-class segmentation: shape (H, W) with class indices.
|
|
543
|
+
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
544
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
545
|
+
Should have the same shape and format as ground_truth.
|
|
546
|
+
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
547
|
+
If None, assumes binary segmentation. Defaults to None.
|
|
548
|
+
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
549
|
+
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
550
|
+
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
551
|
+
Defaults to 1e-6.
|
|
552
|
+
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
553
|
+
Only used when input is a file path. Defaults to 1.
|
|
554
|
+
|
|
555
|
+
Returns:
|
|
556
|
+
Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
|
|
557
|
+
For multi-class segmentation, returns an array of F1 scores for each class.
|
|
558
|
+
|
|
559
|
+
Examples:
|
|
560
|
+
>>> # Binary segmentation with arrays
|
|
561
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
562
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
563
|
+
>>> f1 = calc_f1_score(gt, pred)
|
|
564
|
+
>>> print(f"F1 Score: {f1:.4f}")
|
|
565
|
+
F1 Score: 0.8571
|
|
566
|
+
|
|
567
|
+
>>> # Multi-class segmentation
|
|
568
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
569
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
570
|
+
>>> f1 = calc_f1_score(gt, pred, num_classes=3)
|
|
571
|
+
>>> print(f"F1 Score per class: {f1}")
|
|
572
|
+
F1 Score per class: [0.8571 0.6667 0.6667]
|
|
573
|
+
|
|
574
|
+
>>> # Using PyTorch tensors
|
|
575
|
+
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
576
|
+
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
577
|
+
>>> f1 = calc_f1_score(gt_tensor, pred_tensor)
|
|
578
|
+
>>> print(f"F1 Score: {f1:.4f}")
|
|
579
|
+
F1 Score: 0.8571
|
|
580
|
+
|
|
581
|
+
>>> # Using raster file paths
|
|
582
|
+
>>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
583
|
+
>>> print(f"Mean F1: {np.nanmean(f1):.4f}")
|
|
584
|
+
Mean F1: 0.7302
|
|
585
|
+
"""
|
|
586
|
+
# Load from file if string path is provided
|
|
587
|
+
if isinstance(ground_truth, str):
|
|
588
|
+
with rasterio.open(ground_truth) as src:
|
|
589
|
+
ground_truth = src.read(band)
|
|
590
|
+
if isinstance(prediction, str):
|
|
591
|
+
with rasterio.open(prediction) as src:
|
|
592
|
+
prediction = src.read(band)
|
|
593
|
+
|
|
594
|
+
# Convert to numpy if torch tensor
|
|
595
|
+
if isinstance(ground_truth, torch.Tensor):
|
|
596
|
+
ground_truth = ground_truth.cpu().numpy()
|
|
597
|
+
if isinstance(prediction, torch.Tensor):
|
|
598
|
+
prediction = prediction.cpu().numpy()
|
|
599
|
+
|
|
600
|
+
# Ensure inputs have the same shape
|
|
601
|
+
if ground_truth.shape != prediction.shape:
|
|
602
|
+
raise ValueError(
|
|
603
|
+
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
604
|
+
)
|
|
605
|
+
|
|
606
|
+
# Binary segmentation
|
|
607
|
+
if num_classes is None:
|
|
608
|
+
ground_truth = ground_truth.astype(bool)
|
|
609
|
+
prediction = prediction.astype(bool)
|
|
610
|
+
|
|
611
|
+
# Calculate True Positives, False Positives, False Negatives
|
|
612
|
+
tp = np.logical_and(ground_truth, prediction).sum()
|
|
613
|
+
fp = np.logical_and(~ground_truth, prediction).sum()
|
|
614
|
+
fn = np.logical_and(ground_truth, ~prediction).sum()
|
|
615
|
+
|
|
616
|
+
# Calculate precision and recall
|
|
617
|
+
precision = (tp + smooth) / (tp + fp + smooth)
|
|
618
|
+
recall = (tp + smooth) / (tp + fn + smooth)
|
|
619
|
+
|
|
620
|
+
# Calculate F1 score
|
|
621
|
+
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
622
|
+
return float(f1)
|
|
623
|
+
|
|
624
|
+
# Multi-class segmentation
|
|
625
|
+
else:
|
|
626
|
+
f1_per_class = []
|
|
627
|
+
|
|
628
|
+
for class_idx in range(num_classes):
|
|
629
|
+
# Mark ignored class with np.nan
|
|
630
|
+
if ignore_index is not None and class_idx == ignore_index:
|
|
631
|
+
f1_per_class.append(np.nan)
|
|
632
|
+
continue
|
|
633
|
+
|
|
634
|
+
# Create binary masks for current class
|
|
635
|
+
gt_class = (ground_truth == class_idx).astype(bool)
|
|
636
|
+
pred_class = (prediction == class_idx).astype(bool)
|
|
637
|
+
|
|
638
|
+
# Calculate True Positives, False Positives, False Negatives
|
|
639
|
+
tp = np.logical_and(gt_class, pred_class).sum()
|
|
640
|
+
fp = np.logical_and(~gt_class, pred_class).sum()
|
|
641
|
+
fn = np.logical_and(gt_class, ~pred_class).sum()
|
|
642
|
+
|
|
643
|
+
# Calculate precision and recall
|
|
644
|
+
precision = (tp + smooth) / (tp + fp + smooth)
|
|
645
|
+
recall = (tp + smooth) / (tp + fn + smooth)
|
|
646
|
+
|
|
647
|
+
# Calculate F1 score
|
|
648
|
+
if tp + fp + fn == 0:
|
|
649
|
+
# If class is not present in both gt and pred
|
|
650
|
+
f1_per_class.append(np.nan)
|
|
651
|
+
else:
|
|
652
|
+
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
653
|
+
f1_per_class.append(f1)
|
|
654
|
+
|
|
655
|
+
return np.array(f1_per_class)
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def calc_segmentation_metrics(
|
|
659
|
+
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
660
|
+
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
661
|
+
num_classes: Optional[int] = None,
|
|
662
|
+
ignore_index: Optional[int] = None,
|
|
663
|
+
smooth: float = 1e-6,
|
|
664
|
+
metrics: List[str] = ["iou", "f1"],
|
|
665
|
+
band: int = 1,
|
|
666
|
+
) -> Dict[str, Union[float, np.ndarray]]:
|
|
667
|
+
"""
|
|
668
|
+
Calculate multiple segmentation metrics between ground truth and prediction masks.
|
|
669
|
+
|
|
670
|
+
This is a convenient wrapper function that computes multiple metrics at once,
|
|
671
|
+
including IoU (Intersection over Union) and F1 score. It supports both binary
|
|
672
|
+
and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
673
|
+
or file paths to raster files.
|
|
674
|
+
|
|
675
|
+
Args:
|
|
676
|
+
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
677
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
678
|
+
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
679
|
+
For multi-class segmentation: shape (H, W) with class indices.
|
|
680
|
+
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
681
|
+
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
682
|
+
Should have the same shape and format as ground_truth.
|
|
683
|
+
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
684
|
+
If None, assumes binary segmentation. Defaults to None.
|
|
685
|
+
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
686
|
+
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
687
|
+
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
688
|
+
Defaults to 1e-6.
|
|
689
|
+
metrics (List[str], optional): List of metrics to calculate.
|
|
690
|
+
Options: "iou", "f1". Defaults to ["iou", "f1"].
|
|
691
|
+
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
692
|
+
Only used when input is a file path. Defaults to 1.
|
|
693
|
+
|
|
694
|
+
Returns:
|
|
695
|
+
Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
|
|
696
|
+
Keys are metric names ("iou", "f1"), values are the metric scores.
|
|
697
|
+
For binary segmentation, values are floats.
|
|
698
|
+
For multi-class segmentation, values are numpy arrays with per-class scores.
|
|
699
|
+
Also includes "mean_iou" and "mean_f1" for multi-class segmentation
|
|
700
|
+
(mean computed over valid classes, ignoring NaN values).
|
|
701
|
+
|
|
702
|
+
Examples:
|
|
703
|
+
>>> # Binary segmentation with arrays
|
|
704
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
705
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
706
|
+
>>> metrics = calc_segmentation_metrics(gt, pred)
|
|
707
|
+
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
708
|
+
IoU: 0.8333, F1: 0.8571
|
|
709
|
+
|
|
710
|
+
>>> # Multi-class segmentation
|
|
711
|
+
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
712
|
+
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
713
|
+
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
|
|
714
|
+
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
715
|
+
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
716
|
+
>>> print(f"Per-class IoU: {metrics['iou']}")
|
|
717
|
+
Mean IoU: 0.6111
|
|
718
|
+
Mean F1: 0.7302
|
|
719
|
+
Per-class IoU: [0.8333 0.5000 0.5000]
|
|
720
|
+
|
|
721
|
+
>>> # Calculate only IoU
|
|
722
|
+
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
|
|
723
|
+
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
724
|
+
Mean IoU: 0.6111
|
|
725
|
+
|
|
726
|
+
>>> # Using PyTorch tensors
|
|
727
|
+
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
728
|
+
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
729
|
+
>>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
|
|
730
|
+
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
731
|
+
IoU: 0.8333, F1: 0.8571
|
|
732
|
+
|
|
733
|
+
>>> # Using raster file paths
|
|
734
|
+
>>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
735
|
+
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
736
|
+
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
737
|
+
Mean IoU: 0.6111
|
|
738
|
+
Mean F1: 0.7302
|
|
739
|
+
"""
|
|
740
|
+
results = {}
|
|
741
|
+
|
|
742
|
+
# Calculate IoU if requested
|
|
743
|
+
if "iou" in metrics:
|
|
744
|
+
iou = calc_iou(
|
|
745
|
+
ground_truth,
|
|
746
|
+
prediction,
|
|
747
|
+
num_classes=num_classes,
|
|
748
|
+
ignore_index=ignore_index,
|
|
749
|
+
smooth=smooth,
|
|
750
|
+
band=band,
|
|
751
|
+
)
|
|
752
|
+
results["iou"] = iou
|
|
753
|
+
|
|
754
|
+
# Add mean IoU for multi-class
|
|
755
|
+
if num_classes is not None and isinstance(iou, np.ndarray):
|
|
756
|
+
# Calculate mean ignoring NaN values
|
|
757
|
+
valid_ious = iou[~np.isnan(iou)]
|
|
758
|
+
results["mean_iou"] = (
|
|
759
|
+
float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
# Calculate F1 score if requested
|
|
763
|
+
if "f1" in metrics:
|
|
764
|
+
f1 = calc_f1_score(
|
|
765
|
+
ground_truth,
|
|
766
|
+
prediction,
|
|
767
|
+
num_classes=num_classes,
|
|
768
|
+
ignore_index=ignore_index,
|
|
769
|
+
smooth=smooth,
|
|
770
|
+
band=band,
|
|
771
|
+
)
|
|
772
|
+
results["f1"] = f1
|
|
773
|
+
|
|
774
|
+
# Add mean F1 for multi-class
|
|
775
|
+
if num_classes is not None and isinstance(f1, np.ndarray):
|
|
776
|
+
# Calculate mean ignoring NaN values
|
|
777
|
+
valid_f1s = f1[~np.isnan(f1)]
|
|
778
|
+
results["mean_f1"] = (
|
|
779
|
+
float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
|
|
780
|
+
)
|
|
781
|
+
|
|
782
|
+
return results
|
|
783
|
+
|
|
784
|
+
|
|
384
785
|
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
|
|
385
786
|
"""Convert a dictionary to a xarray DataArray. The dictionary should contain the
|
|
386
787
|
following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
|
|
@@ -691,8 +1092,9 @@ def view_vector(
|
|
|
691
1092
|
|
|
692
1093
|
def view_vector_interactive(
|
|
693
1094
|
vector_data: Union[str, gpd.GeoDataFrame],
|
|
694
|
-
layer_name: str = "Vector
|
|
1095
|
+
layer_name: str = "Vector",
|
|
695
1096
|
tiles_args: Optional[Dict] = None,
|
|
1097
|
+
opacity: float = 0.7,
|
|
696
1098
|
**kwargs: Any,
|
|
697
1099
|
) -> Any:
|
|
698
1100
|
"""
|
|
@@ -704,9 +1106,10 @@ def view_vector_interactive(
|
|
|
704
1106
|
|
|
705
1107
|
Args:
|
|
706
1108
|
vector_data (geopandas.GeoDataFrame): The vector dataset to visualize.
|
|
707
|
-
layer_name (str, optional): The name of the layer. Defaults to "Vector
|
|
1109
|
+
layer_name (str, optional): The name of the layer. Defaults to "Vector".
|
|
708
1110
|
tiles_args (dict, optional): Additional arguments for the localtileserver client.
|
|
709
1111
|
get_folium_tile_layer function. Defaults to None.
|
|
1112
|
+
opacity (float, optional): The opacity of the layer. Defaults to 0.7.
|
|
710
1113
|
**kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
|
|
711
1114
|
See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
|
|
712
1115
|
|
|
@@ -721,9 +1124,8 @@ def view_vector_interactive(
|
|
|
721
1124
|
>>> roads = gpd.read_file("roads.shp")
|
|
722
1125
|
>>> view_vector_interactive(roads, figsize=(12, 8))
|
|
723
1126
|
"""
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
from leafmap import cog_tile
|
|
1127
|
+
|
|
1128
|
+
from leafmap.foliumap import Map
|
|
727
1129
|
from localtileserver import TileClient, get_folium_tile_layer
|
|
728
1130
|
|
|
729
1131
|
google_tiles = {
|
|
@@ -749,9 +1151,17 @@ def view_vector_interactive(
|
|
|
749
1151
|
},
|
|
750
1152
|
}
|
|
751
1153
|
|
|
1154
|
+
# Make it compatible with binder and JupyterHub
|
|
1155
|
+
if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
|
|
1156
|
+
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
|
|
1157
|
+
f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
|
|
1158
|
+
)
|
|
1159
|
+
|
|
752
1160
|
basemap_layer_name = None
|
|
753
1161
|
raster_layer = None
|
|
754
1162
|
|
|
1163
|
+
m = Map()
|
|
1164
|
+
|
|
755
1165
|
if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
|
|
756
1166
|
if kwargs["tiles"].title() in google_tiles:
|
|
757
1167
|
basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
|
|
@@ -762,14 +1172,17 @@ def view_vector_interactive(
|
|
|
762
1172
|
tiles_args = {}
|
|
763
1173
|
if kwargs["tiles"].lower().startswith("http"):
|
|
764
1174
|
basemap_layer_name = "Remote Raster"
|
|
765
|
-
kwargs["tiles"] =
|
|
766
|
-
kwargs["attr"] = "TiTiler"
|
|
1175
|
+
m.add_geotiff(kwargs["tiles"], name=basemap_layer_name, **tiles_args)
|
|
767
1176
|
else:
|
|
768
1177
|
basemap_layer_name = "Local Raster"
|
|
769
1178
|
client = TileClient(kwargs["tiles"])
|
|
770
1179
|
raster_layer = get_folium_tile_layer(client, **tiles_args)
|
|
771
|
-
|
|
772
|
-
|
|
1180
|
+
m.add_tile_layer(
|
|
1181
|
+
raster_layer.tiles,
|
|
1182
|
+
name=basemap_layer_name,
|
|
1183
|
+
attribution="localtileserver",
|
|
1184
|
+
**tiles_args,
|
|
1185
|
+
)
|
|
773
1186
|
|
|
774
1187
|
if "max_zoom" not in kwargs:
|
|
775
1188
|
kwargs["max_zoom"] = 30
|
|
@@ -784,23 +1197,18 @@ def view_vector_interactive(
|
|
|
784
1197
|
if not isinstance(vector_data, gpd.GeoDataFrame):
|
|
785
1198
|
raise TypeError("Input data must be a GeoDataFrame")
|
|
786
1199
|
|
|
787
|
-
|
|
788
|
-
|
|
1200
|
+
if "column" in kwargs:
|
|
1201
|
+
if "legend_position" not in kwargs:
|
|
1202
|
+
kwargs["legend_position"] = "bottomleft"
|
|
1203
|
+
if "cmap" not in kwargs:
|
|
1204
|
+
kwargs["cmap"] = "viridis"
|
|
1205
|
+
m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
|
|
789
1206
|
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
# Change the layer name
|
|
793
|
-
for layer in m._children.values():
|
|
794
|
-
if isinstance(layer, folium.GeoJson):
|
|
795
|
-
layer.layer_name = layer_name
|
|
796
|
-
if isinstance(layer, folium.TileLayer) and basemap_layer_name:
|
|
797
|
-
layer.layer_name = basemap_layer_name
|
|
798
|
-
|
|
799
|
-
if layer_control:
|
|
800
|
-
m.add_child(folium.LayerControl())
|
|
1207
|
+
else:
|
|
1208
|
+
m.add_gdf(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
|
|
801
1209
|
|
|
802
|
-
|
|
803
|
-
|
|
1210
|
+
m.add_layer_control()
|
|
1211
|
+
m.add_opacity_control()
|
|
804
1212
|
|
|
805
1213
|
return m
|
|
806
1214
|
|
|
@@ -2594,10 +3002,86 @@ def batch_vector_to_raster(
|
|
|
2594
3002
|
return output_files
|
|
2595
3003
|
|
|
2596
3004
|
|
|
3005
|
+
def get_default_augmentation_transforms(
|
|
3006
|
+
tile_size: int = 256,
|
|
3007
|
+
include_normalize: bool = False,
|
|
3008
|
+
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
|
|
3009
|
+
std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
|
|
3010
|
+
) -> Any:
|
|
3011
|
+
"""
|
|
3012
|
+
Get default data augmentation transforms for geospatial imagery using albumentations.
|
|
3013
|
+
|
|
3014
|
+
This function returns a composition of augmentation transforms commonly used
|
|
3015
|
+
for remote sensing and geospatial data. The transforms include geometric
|
|
3016
|
+
transformations (flips, rotations) and photometric adjustments (brightness,
|
|
3017
|
+
contrast, saturation).
|
|
3018
|
+
|
|
3019
|
+
Args:
|
|
3020
|
+
tile_size (int): Target size for tiles. Defaults to 256.
|
|
3021
|
+
include_normalize (bool): Whether to include normalization transform.
|
|
3022
|
+
Defaults to False. Set to True if using for training with pretrained models.
|
|
3023
|
+
mean (tuple): Mean values for normalization (RGB). Defaults to ImageNet values.
|
|
3024
|
+
std (tuple): Standard deviation for normalization (RGB). Defaults to ImageNet values.
|
|
3025
|
+
|
|
3026
|
+
Returns:
|
|
3027
|
+
albumentations.Compose: A composition of augmentation transforms.
|
|
3028
|
+
|
|
3029
|
+
Example:
|
|
3030
|
+
>>> import albumentations as A
|
|
3031
|
+
>>> # Get default transforms
|
|
3032
|
+
>>> transform = get_default_augmentation_transforms()
|
|
3033
|
+
>>> # Apply to image and mask
|
|
3034
|
+
>>> augmented = transform(image=image, mask=mask)
|
|
3035
|
+
>>> aug_image = augmented['image']
|
|
3036
|
+
>>> aug_mask = augmented['mask']
|
|
3037
|
+
"""
|
|
3038
|
+
try:
|
|
3039
|
+
import albumentations as A
|
|
3040
|
+
except ImportError:
|
|
3041
|
+
raise ImportError(
|
|
3042
|
+
"albumentations is required for data augmentation. "
|
|
3043
|
+
"Install it with: pip install albumentations"
|
|
3044
|
+
)
|
|
3045
|
+
|
|
3046
|
+
transforms_list = [
|
|
3047
|
+
# Geometric transforms
|
|
3048
|
+
A.HorizontalFlip(p=0.5),
|
|
3049
|
+
A.VerticalFlip(p=0.5),
|
|
3050
|
+
A.RandomRotate90(p=0.5),
|
|
3051
|
+
A.ShiftScaleRotate(
|
|
3052
|
+
shift_limit=0.1,
|
|
3053
|
+
scale_limit=0.1,
|
|
3054
|
+
rotate_limit=45,
|
|
3055
|
+
border_mode=0,
|
|
3056
|
+
p=0.5,
|
|
3057
|
+
),
|
|
3058
|
+
# Photometric transforms
|
|
3059
|
+
A.RandomBrightnessContrast(
|
|
3060
|
+
brightness_limit=0.2,
|
|
3061
|
+
contrast_limit=0.2,
|
|
3062
|
+
p=0.5,
|
|
3063
|
+
),
|
|
3064
|
+
A.HueSaturationValue(
|
|
3065
|
+
hue_shift_limit=10,
|
|
3066
|
+
sat_shift_limit=20,
|
|
3067
|
+
val_shift_limit=10,
|
|
3068
|
+
p=0.3,
|
|
3069
|
+
),
|
|
3070
|
+
A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
|
|
3071
|
+
A.GaussianBlur(blur_limit=(3, 5), p=0.2),
|
|
3072
|
+
]
|
|
3073
|
+
|
|
3074
|
+
# Add normalization if requested
|
|
3075
|
+
if include_normalize:
|
|
3076
|
+
transforms_list.append(A.Normalize(mean=mean, std=std))
|
|
3077
|
+
|
|
3078
|
+
return A.Compose(transforms_list)
|
|
3079
|
+
|
|
3080
|
+
|
|
2597
3081
|
def export_geotiff_tiles(
|
|
2598
3082
|
in_raster,
|
|
2599
3083
|
out_folder,
|
|
2600
|
-
in_class_data,
|
|
3084
|
+
in_class_data=None,
|
|
2601
3085
|
tile_size=256,
|
|
2602
3086
|
stride=128,
|
|
2603
3087
|
class_value_field="class",
|
|
@@ -2607,6 +3091,10 @@ def export_geotiff_tiles(
|
|
|
2607
3091
|
all_touched=True,
|
|
2608
3092
|
create_overview=False,
|
|
2609
3093
|
skip_empty_tiles=False,
|
|
3094
|
+
metadata_format="PASCAL_VOC",
|
|
3095
|
+
apply_augmentation=False,
|
|
3096
|
+
augmentation_count=3,
|
|
3097
|
+
augmentation_transforms=None,
|
|
2610
3098
|
):
|
|
2611
3099
|
"""
|
|
2612
3100
|
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
|
@@ -2614,7 +3102,8 @@ def export_geotiff_tiles(
|
|
|
2614
3102
|
Args:
|
|
2615
3103
|
in_raster (str): Path to input raster image
|
|
2616
3104
|
out_folder (str): Path to output folder
|
|
2617
|
-
in_class_data (str): Path to classification data - can be vector file or raster
|
|
3105
|
+
in_class_data (str, optional): Path to classification data - can be vector file or raster.
|
|
3106
|
+
If None, only image tiles will be exported without labels. Defaults to None.
|
|
2618
3107
|
tile_size (int): Size of tiles in pixels (square)
|
|
2619
3108
|
stride (int): Step size between tiles
|
|
2620
3109
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -2624,38 +3113,95 @@ def export_geotiff_tiles(
|
|
|
2624
3113
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
2625
3114
|
create_overview (bool): Whether to create an overview image of all tiles
|
|
2626
3115
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3116
|
+
metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
|
|
3117
|
+
apply_augmentation (bool): If True, generate augmented versions of each tile.
|
|
3118
|
+
This will create multiple variants of each tile using data augmentation techniques.
|
|
3119
|
+
Defaults to False.
|
|
3120
|
+
augmentation_count (int): Number of augmented versions to generate per tile
|
|
3121
|
+
(only used if apply_augmentation=True). Defaults to 3.
|
|
3122
|
+
augmentation_transforms (albumentations.Compose, optional): Custom augmentation transforms.
|
|
3123
|
+
If None and apply_augmentation=True, uses default transforms from
|
|
3124
|
+
get_default_augmentation_transforms(). Should be an albumentations.Compose object.
|
|
3125
|
+
Defaults to None.
|
|
3126
|
+
|
|
3127
|
+
Returns:
|
|
3128
|
+
None: Tiles and labels are saved to out_folder.
|
|
3129
|
+
|
|
3130
|
+
Example:
|
|
3131
|
+
>>> # Export tiles without augmentation
|
|
3132
|
+
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif')
|
|
3133
|
+
>>>
|
|
3134
|
+
>>> # Export tiles with default augmentation (3 augmented versions per tile)
|
|
3135
|
+
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
|
|
3136
|
+
... apply_augmentation=True)
|
|
3137
|
+
>>>
|
|
3138
|
+
>>> # Export with custom augmentation
|
|
3139
|
+
>>> import albumentations as A
|
|
3140
|
+
>>> custom_transform = A.Compose([
|
|
3141
|
+
... A.HorizontalFlip(p=0.5),
|
|
3142
|
+
... A.RandomBrightnessContrast(p=0.5),
|
|
3143
|
+
... ])
|
|
3144
|
+
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
|
|
3145
|
+
... apply_augmentation=True,
|
|
3146
|
+
... augmentation_count=5,
|
|
3147
|
+
... augmentation_transforms=custom_transform)
|
|
2627
3148
|
"""
|
|
2628
3149
|
|
|
2629
3150
|
import logging
|
|
2630
3151
|
|
|
2631
3152
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
|
2632
3153
|
|
|
3154
|
+
# Initialize augmentation transforms if needed
|
|
3155
|
+
if apply_augmentation:
|
|
3156
|
+
if augmentation_transforms is None:
|
|
3157
|
+
augmentation_transforms = get_default_augmentation_transforms(
|
|
3158
|
+
tile_size=tile_size
|
|
3159
|
+
)
|
|
3160
|
+
if not quiet:
|
|
3161
|
+
print(
|
|
3162
|
+
f"Data augmentation enabled: generating {augmentation_count} augmented versions per tile"
|
|
3163
|
+
)
|
|
3164
|
+
|
|
2633
3165
|
# Create output directories
|
|
2634
3166
|
os.makedirs(out_folder, exist_ok=True)
|
|
2635
3167
|
image_dir = os.path.join(out_folder, "images")
|
|
2636
3168
|
os.makedirs(image_dir, exist_ok=True)
|
|
2637
|
-
label_dir = os.path.join(out_folder, "labels")
|
|
2638
|
-
os.makedirs(label_dir, exist_ok=True)
|
|
2639
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
2640
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
2641
3169
|
|
|
2642
|
-
#
|
|
3170
|
+
# Only create label and annotation directories if class data is provided
|
|
3171
|
+
if in_class_data is not None:
|
|
3172
|
+
label_dir = os.path.join(out_folder, "labels")
|
|
3173
|
+
os.makedirs(label_dir, exist_ok=True)
|
|
3174
|
+
|
|
3175
|
+
# Create annotation directory based on metadata format
|
|
3176
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
|
3177
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
|
3178
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
3179
|
+
|
|
3180
|
+
# Initialize COCO annotations dictionary
|
|
3181
|
+
if metadata_format == "COCO":
|
|
3182
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
3183
|
+
ann_id = 0
|
|
3184
|
+
|
|
3185
|
+
# Determine if class data is raster or vector (only if class data provided)
|
|
2643
3186
|
is_class_data_raster = False
|
|
2644
|
-
if
|
|
2645
|
-
|
|
2646
|
-
|
|
2647
|
-
|
|
2648
|
-
|
|
2649
|
-
|
|
2650
|
-
|
|
3187
|
+
if in_class_data is not None:
|
|
3188
|
+
if isinstance(in_class_data, str):
|
|
3189
|
+
file_ext = Path(in_class_data).suffix.lower()
|
|
3190
|
+
# Common raster extensions
|
|
3191
|
+
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
3192
|
+
try:
|
|
3193
|
+
with rasterio.open(in_class_data) as src:
|
|
3194
|
+
is_class_data_raster = True
|
|
3195
|
+
if not quiet:
|
|
3196
|
+
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
3197
|
+
print(f"Raster CRS: {src.crs}")
|
|
3198
|
+
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
3199
|
+
except Exception:
|
|
3200
|
+
is_class_data_raster = False
|
|
2651
3201
|
if not quiet:
|
|
2652
|
-
print(
|
|
2653
|
-
|
|
2654
|
-
|
|
2655
|
-
except Exception:
|
|
2656
|
-
is_class_data_raster = False
|
|
2657
|
-
if not quiet:
|
|
2658
|
-
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
|
3202
|
+
print(
|
|
3203
|
+
f"Unable to open {in_class_data} as raster, trying as vector"
|
|
3204
|
+
)
|
|
2659
3205
|
|
|
2660
3206
|
# Open the input raster
|
|
2661
3207
|
with rasterio.open(in_raster) as src:
|
|
@@ -2675,10 +3221,10 @@ def export_geotiff_tiles(
|
|
|
2675
3221
|
if max_tiles is None:
|
|
2676
3222
|
max_tiles = total_tiles
|
|
2677
3223
|
|
|
2678
|
-
# Process classification data
|
|
3224
|
+
# Process classification data (only if class data provided)
|
|
2679
3225
|
class_to_id = {}
|
|
2680
3226
|
|
|
2681
|
-
if is_class_data_raster:
|
|
3227
|
+
if in_class_data is not None and is_class_data_raster:
|
|
2682
3228
|
# Load raster class data
|
|
2683
3229
|
with rasterio.open(in_class_data) as class_src:
|
|
2684
3230
|
# Check if raster CRS matches
|
|
@@ -2711,7 +3257,18 @@ def export_geotiff_tiles(
|
|
|
2711
3257
|
|
|
2712
3258
|
# Create class mapping
|
|
2713
3259
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
2714
|
-
|
|
3260
|
+
|
|
3261
|
+
# Populate COCO categories
|
|
3262
|
+
if metadata_format == "COCO":
|
|
3263
|
+
for cls_val in unique_classes:
|
|
3264
|
+
coco_annotations["categories"].append(
|
|
3265
|
+
{
|
|
3266
|
+
"id": class_to_id[int(cls_val)],
|
|
3267
|
+
"name": str(int(cls_val)),
|
|
3268
|
+
"supercategory": "object",
|
|
3269
|
+
}
|
|
3270
|
+
)
|
|
3271
|
+
elif in_class_data is not None:
|
|
2715
3272
|
# Load vector class data
|
|
2716
3273
|
try:
|
|
2717
3274
|
gdf = gpd.read_file(in_class_data)
|
|
@@ -2740,12 +3297,33 @@ def export_geotiff_tiles(
|
|
|
2740
3297
|
)
|
|
2741
3298
|
# Create class mapping
|
|
2742
3299
|
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
3300
|
+
|
|
3301
|
+
# Populate COCO categories
|
|
3302
|
+
if metadata_format == "COCO":
|
|
3303
|
+
for cls_val in unique_classes:
|
|
3304
|
+
coco_annotations["categories"].append(
|
|
3305
|
+
{
|
|
3306
|
+
"id": class_to_id[cls_val],
|
|
3307
|
+
"name": str(cls_val),
|
|
3308
|
+
"supercategory": "object",
|
|
3309
|
+
}
|
|
3310
|
+
)
|
|
2743
3311
|
else:
|
|
2744
3312
|
if not quiet:
|
|
2745
3313
|
print(
|
|
2746
3314
|
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
|
2747
3315
|
)
|
|
2748
3316
|
class_to_id = {1: 1} # Default mapping
|
|
3317
|
+
|
|
3318
|
+
# Populate COCO categories with default
|
|
3319
|
+
if metadata_format == "COCO":
|
|
3320
|
+
coco_annotations["categories"].append(
|
|
3321
|
+
{
|
|
3322
|
+
"id": 1,
|
|
3323
|
+
"name": "object",
|
|
3324
|
+
"supercategory": "object",
|
|
3325
|
+
}
|
|
3326
|
+
)
|
|
2749
3327
|
except Exception as e:
|
|
2750
3328
|
raise ValueError(f"Error processing vector data: {e}")
|
|
2751
3329
|
|
|
@@ -2812,8 +3390,8 @@ def export_geotiff_tiles(
|
|
|
2812
3390
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
2813
3391
|
has_features = False
|
|
2814
3392
|
|
|
2815
|
-
# Process classification data to create labels
|
|
2816
|
-
if is_class_data_raster:
|
|
3393
|
+
# Process classification data to create labels (only if class data provided)
|
|
3394
|
+
if in_class_data is not None and is_class_data_raster:
|
|
2817
3395
|
# For raster class data
|
|
2818
3396
|
with rasterio.open(in_class_data) as class_src:
|
|
2819
3397
|
# Calculate window in class raster
|
|
@@ -2863,7 +3441,7 @@ def export_geotiff_tiles(
|
|
|
2863
3441
|
except Exception as e:
|
|
2864
3442
|
pbar.write(f"Error reading class raster window: {e}")
|
|
2865
3443
|
stats["errors"] += 1
|
|
2866
|
-
|
|
3444
|
+
elif in_class_data is not None:
|
|
2867
3445
|
# For vector class data
|
|
2868
3446
|
# Find features that intersect with window
|
|
2869
3447
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -2906,8 +3484,8 @@ def export_geotiff_tiles(
|
|
|
2906
3484
|
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
2907
3485
|
stats["errors"] += 1
|
|
2908
3486
|
|
|
2909
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
2910
|
-
if skip_empty_tiles and not has_features:
|
|
3487
|
+
# Skip tile if no features and skip_empty_tiles is True (only when class data provided)
|
|
3488
|
+
if in_class_data is not None and skip_empty_tiles and not has_features:
|
|
2911
3489
|
pbar.update(1)
|
|
2912
3490
|
tile_index += 1
|
|
2913
3491
|
continue
|
|
@@ -2915,119 +3493,316 @@ def export_geotiff_tiles(
|
|
|
2915
3493
|
# Read image data
|
|
2916
3494
|
image_data = src.read(window=window)
|
|
2917
3495
|
|
|
2918
|
-
#
|
|
2919
|
-
|
|
3496
|
+
# Helper function to save a single tile (original or augmented)
|
|
3497
|
+
def save_tile(
|
|
3498
|
+
img_data,
|
|
3499
|
+
lbl_mask,
|
|
3500
|
+
tile_id,
|
|
3501
|
+
img_profile,
|
|
3502
|
+
window_trans,
|
|
3503
|
+
is_augmented=False,
|
|
3504
|
+
):
|
|
3505
|
+
"""Save a single image and label tile."""
|
|
3506
|
+
# Export image as GeoTIFF
|
|
3507
|
+
image_path = os.path.join(image_dir, f"tile_{tile_id:06d}.tif")
|
|
2920
3508
|
|
|
2921
|
-
|
|
2922
|
-
|
|
2923
|
-
|
|
2924
|
-
|
|
2925
|
-
|
|
2926
|
-
|
|
2927
|
-
|
|
2928
|
-
|
|
2929
|
-
|
|
3509
|
+
# Update profile
|
|
3510
|
+
img_profile_copy = img_profile.copy()
|
|
3511
|
+
img_profile_copy.update(
|
|
3512
|
+
{
|
|
3513
|
+
"height": tile_size,
|
|
3514
|
+
"width": tile_size,
|
|
3515
|
+
"count": img_data.shape[0],
|
|
3516
|
+
"transform": window_trans,
|
|
3517
|
+
}
|
|
3518
|
+
)
|
|
3519
|
+
|
|
3520
|
+
# Save image as GeoTIFF
|
|
3521
|
+
try:
|
|
3522
|
+
with rasterio.open(image_path, "w", **img_profile_copy) as dst:
|
|
3523
|
+
dst.write(img_data)
|
|
3524
|
+
stats["total_tiles"] += 1
|
|
3525
|
+
except Exception as e:
|
|
3526
|
+
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
|
3527
|
+
stats["errors"] += 1
|
|
3528
|
+
return
|
|
3529
|
+
|
|
3530
|
+
# Export label as GeoTIFF (only if class data provided)
|
|
3531
|
+
if in_class_data is not None:
|
|
3532
|
+
# Create profile for label GeoTIFF
|
|
3533
|
+
label_profile = {
|
|
3534
|
+
"driver": "GTiff",
|
|
3535
|
+
"height": tile_size,
|
|
3536
|
+
"width": tile_size,
|
|
3537
|
+
"count": 1,
|
|
3538
|
+
"dtype": "uint8",
|
|
3539
|
+
"crs": src.crs,
|
|
3540
|
+
"transform": window_trans,
|
|
3541
|
+
}
|
|
3542
|
+
|
|
3543
|
+
label_path = os.path.join(label_dir, f"tile_{tile_id:06d}.tif")
|
|
3544
|
+
try:
|
|
3545
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
3546
|
+
dst.write(lbl_mask.astype(np.uint8), 1)
|
|
3547
|
+
|
|
3548
|
+
if not is_augmented and np.any(lbl_mask > 0):
|
|
3549
|
+
stats["tiles_with_features"] += 1
|
|
3550
|
+
stats["feature_pixels"] += np.count_nonzero(lbl_mask)
|
|
3551
|
+
except Exception as e:
|
|
3552
|
+
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
|
3553
|
+
stats["errors"] += 1
|
|
3554
|
+
|
|
3555
|
+
# Save original tile
|
|
3556
|
+
save_tile(
|
|
3557
|
+
image_data,
|
|
3558
|
+
label_mask,
|
|
3559
|
+
tile_index,
|
|
3560
|
+
src.profile,
|
|
3561
|
+
window_transform,
|
|
3562
|
+
is_augmented=False,
|
|
2930
3563
|
)
|
|
2931
3564
|
|
|
2932
|
-
#
|
|
2933
|
-
|
|
2934
|
-
|
|
2935
|
-
|
|
2936
|
-
|
|
2937
|
-
|
|
2938
|
-
|
|
2939
|
-
|
|
3565
|
+
# Generate and save augmented tiles if enabled
|
|
3566
|
+
if apply_augmentation:
|
|
3567
|
+
for aug_idx in range(augmentation_count):
|
|
3568
|
+
# Prepare image for augmentation (convert from CHW to HWC)
|
|
3569
|
+
img_for_aug = np.transpose(image_data, (1, 2, 0))
|
|
3570
|
+
|
|
3571
|
+
# Ensure uint8 data type for albumentations
|
|
3572
|
+
# Albumentations expects uint8 for most transforms
|
|
3573
|
+
if not np.issubdtype(img_for_aug.dtype, np.uint8):
|
|
3574
|
+
# If image is float, scale to 0-255 and convert to uint8
|
|
3575
|
+
if np.issubdtype(img_for_aug.dtype, np.floating):
|
|
3576
|
+
img_for_aug = (
|
|
3577
|
+
(img_for_aug * 255).clip(0, 255).astype(np.uint8)
|
|
3578
|
+
)
|
|
3579
|
+
else:
|
|
3580
|
+
img_for_aug = img_for_aug.astype(np.uint8)
|
|
2940
3581
|
|
|
2941
|
-
|
|
2942
|
-
|
|
2943
|
-
|
|
2944
|
-
|
|
2945
|
-
|
|
2946
|
-
|
|
2947
|
-
|
|
2948
|
-
|
|
2949
|
-
|
|
2950
|
-
|
|
3582
|
+
# Apply augmentation
|
|
3583
|
+
try:
|
|
3584
|
+
if in_class_data is not None:
|
|
3585
|
+
# Augment both image and mask
|
|
3586
|
+
augmented = augmentation_transforms(
|
|
3587
|
+
image=img_for_aug, mask=label_mask
|
|
3588
|
+
)
|
|
3589
|
+
aug_image = augmented["image"]
|
|
3590
|
+
aug_mask = augmented["mask"]
|
|
3591
|
+
else:
|
|
3592
|
+
# Augment only image
|
|
3593
|
+
augmented = augmentation_transforms(image=img_for_aug)
|
|
3594
|
+
aug_image = augmented["image"]
|
|
3595
|
+
aug_mask = label_mask
|
|
2951
3596
|
|
|
2952
|
-
|
|
2953
|
-
|
|
2954
|
-
try:
|
|
2955
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
2956
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
|
3597
|
+
# Convert back from HWC to CHW
|
|
3598
|
+
aug_image = np.transpose(aug_image, (2, 0, 1))
|
|
2957
3599
|
|
|
2958
|
-
|
|
2959
|
-
|
|
2960
|
-
|
|
2961
|
-
|
|
2962
|
-
|
|
2963
|
-
|
|
3600
|
+
# Ensure correct dtype for saving
|
|
3601
|
+
aug_image = aug_image.astype(image_data.dtype)
|
|
3602
|
+
|
|
3603
|
+
# Generate unique tile ID for augmented version
|
|
3604
|
+
# Use a collision-free numbering scheme: (tile_index * (augmentation_count + 1)) + aug_idx + 1
|
|
3605
|
+
aug_tile_id = (
|
|
3606
|
+
(tile_index * (augmentation_count + 1)) + aug_idx + 1
|
|
3607
|
+
)
|
|
3608
|
+
|
|
3609
|
+
# Save augmented tile
|
|
3610
|
+
save_tile(
|
|
3611
|
+
aug_image,
|
|
3612
|
+
aug_mask,
|
|
3613
|
+
aug_tile_id,
|
|
3614
|
+
src.profile,
|
|
3615
|
+
window_transform,
|
|
3616
|
+
is_augmented=True,
|
|
3617
|
+
)
|
|
3618
|
+
|
|
3619
|
+
except Exception as e:
|
|
3620
|
+
pbar.write(
|
|
3621
|
+
f"ERROR applying augmentation {aug_idx} to tile {tile_index}: {e}"
|
|
3622
|
+
)
|
|
3623
|
+
stats["errors"] += 1
|
|
2964
3624
|
|
|
2965
|
-
# Create
|
|
3625
|
+
# Create annotations for object detection if using vector class data
|
|
2966
3626
|
if (
|
|
2967
|
-
not
|
|
3627
|
+
in_class_data is not None
|
|
3628
|
+
and not is_class_data_raster
|
|
2968
3629
|
and "gdf" in locals()
|
|
2969
3630
|
and len(window_features) > 0
|
|
2970
3631
|
):
|
|
2971
|
-
|
|
2972
|
-
|
|
2973
|
-
|
|
2974
|
-
|
|
3632
|
+
if metadata_format == "PASCAL_VOC":
|
|
3633
|
+
# Create XML annotation
|
|
3634
|
+
root = ET.Element("annotation")
|
|
3635
|
+
ET.SubElement(root, "folder").text = "images"
|
|
3636
|
+
ET.SubElement(root, "filename").text = (
|
|
3637
|
+
f"tile_{tile_index:06d}.tif"
|
|
3638
|
+
)
|
|
2975
3639
|
|
|
2976
|
-
|
|
2977
|
-
|
|
2978
|
-
|
|
2979
|
-
|
|
3640
|
+
size = ET.SubElement(root, "size")
|
|
3641
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
|
3642
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
|
3643
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
3644
|
+
|
|
3645
|
+
# Add georeference information
|
|
3646
|
+
geo = ET.SubElement(root, "georeference")
|
|
3647
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
3648
|
+
ET.SubElement(geo, "transform").text = str(
|
|
3649
|
+
window_transform
|
|
3650
|
+
).replace("\n", "")
|
|
3651
|
+
ET.SubElement(geo, "bounds").text = (
|
|
3652
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
3653
|
+
)
|
|
2980
3654
|
|
|
2981
|
-
|
|
2982
|
-
|
|
2983
|
-
|
|
2984
|
-
|
|
2985
|
-
|
|
2986
|
-
|
|
2987
|
-
|
|
2988
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
2989
|
-
)
|
|
3655
|
+
# Add objects
|
|
3656
|
+
for idx, feature in window_features.iterrows():
|
|
3657
|
+
# Get feature class
|
|
3658
|
+
if class_value_field in feature:
|
|
3659
|
+
class_val = feature[class_value_field]
|
|
3660
|
+
else:
|
|
3661
|
+
class_val = "object"
|
|
2990
3662
|
|
|
2991
|
-
|
|
2992
|
-
|
|
2993
|
-
|
|
2994
|
-
|
|
2995
|
-
|
|
2996
|
-
|
|
2997
|
-
|
|
3663
|
+
# Get geometry bounds in pixel coordinates
|
|
3664
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
3665
|
+
if not geom.is_empty:
|
|
3666
|
+
# Get bounds in world coordinates
|
|
3667
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3668
|
+
|
|
3669
|
+
# Convert to pixel coordinates
|
|
3670
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3671
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3672
|
+
|
|
3673
|
+
# Ensure coordinates are within tile bounds
|
|
3674
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
|
3675
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
|
3676
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
|
3677
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
|
3678
|
+
|
|
3679
|
+
# Only add if the box has non-zero area
|
|
3680
|
+
if xmax > xmin and ymax > ymin:
|
|
3681
|
+
obj = ET.SubElement(root, "object")
|
|
3682
|
+
ET.SubElement(obj, "name").text = str(class_val)
|
|
3683
|
+
ET.SubElement(obj, "difficult").text = "0"
|
|
3684
|
+
|
|
3685
|
+
bbox = ET.SubElement(obj, "bndbox")
|
|
3686
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3687
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3688
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3689
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3690
|
+
|
|
3691
|
+
# Save XML
|
|
3692
|
+
tree = ET.ElementTree(root)
|
|
3693
|
+
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
3694
|
+
tree.write(xml_path)
|
|
2998
3695
|
|
|
2999
|
-
|
|
3000
|
-
|
|
3001
|
-
|
|
3002
|
-
|
|
3003
|
-
|
|
3004
|
-
|
|
3005
|
-
|
|
3006
|
-
|
|
3007
|
-
|
|
3008
|
-
|
|
3009
|
-
|
|
3010
|
-
|
|
3011
|
-
|
|
3012
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
3013
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
3014
|
-
|
|
3015
|
-
# Only add if the box has non-zero area
|
|
3016
|
-
if xmax > xmin and ymax > ymin:
|
|
3017
|
-
obj = ET.SubElement(root, "object")
|
|
3018
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
|
3019
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
3020
|
-
|
|
3021
|
-
bbox = ET.SubElement(obj, "bndbox")
|
|
3022
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3023
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3024
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3025
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3696
|
+
elif metadata_format == "COCO":
|
|
3697
|
+
# Add image info
|
|
3698
|
+
image_id = tile_index
|
|
3699
|
+
coco_annotations["images"].append(
|
|
3700
|
+
{
|
|
3701
|
+
"id": image_id,
|
|
3702
|
+
"file_name": f"tile_{tile_index:06d}.tif",
|
|
3703
|
+
"width": tile_size,
|
|
3704
|
+
"height": tile_size,
|
|
3705
|
+
"crs": str(src.crs),
|
|
3706
|
+
"transform": str(window_transform),
|
|
3707
|
+
}
|
|
3708
|
+
)
|
|
3026
3709
|
|
|
3027
|
-
|
|
3028
|
-
|
|
3029
|
-
|
|
3030
|
-
|
|
3710
|
+
# Add annotations for each feature
|
|
3711
|
+
for _, feature in window_features.iterrows():
|
|
3712
|
+
# Get feature class
|
|
3713
|
+
if class_value_field in feature:
|
|
3714
|
+
class_val = feature[class_value_field]
|
|
3715
|
+
category_id = class_to_id.get(class_val, 1)
|
|
3716
|
+
else:
|
|
3717
|
+
category_id = 1
|
|
3718
|
+
|
|
3719
|
+
# Get geometry bounds
|
|
3720
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
3721
|
+
if not geom.is_empty:
|
|
3722
|
+
# Get bounds in world coordinates
|
|
3723
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3724
|
+
|
|
3725
|
+
# Convert to pixel coordinates
|
|
3726
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3727
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3728
|
+
|
|
3729
|
+
# Ensure coordinates are within tile bounds
|
|
3730
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
|
3731
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
|
3732
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
|
3733
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
|
3734
|
+
|
|
3735
|
+
# Skip if box is too small
|
|
3736
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
3737
|
+
continue
|
|
3738
|
+
|
|
3739
|
+
width = xmax - xmin
|
|
3740
|
+
height = ymax - ymin
|
|
3741
|
+
|
|
3742
|
+
# Add annotation
|
|
3743
|
+
ann_id += 1
|
|
3744
|
+
coco_annotations["annotations"].append(
|
|
3745
|
+
{
|
|
3746
|
+
"id": ann_id,
|
|
3747
|
+
"image_id": image_id,
|
|
3748
|
+
"category_id": category_id,
|
|
3749
|
+
"bbox": [xmin, ymin, width, height],
|
|
3750
|
+
"area": width * height,
|
|
3751
|
+
"iscrowd": 0,
|
|
3752
|
+
}
|
|
3753
|
+
)
|
|
3754
|
+
|
|
3755
|
+
elif metadata_format == "YOLO":
|
|
3756
|
+
# Create YOLO format annotations
|
|
3757
|
+
yolo_annotations = []
|
|
3758
|
+
|
|
3759
|
+
for _, feature in window_features.iterrows():
|
|
3760
|
+
# Get feature class
|
|
3761
|
+
if class_value_field in feature:
|
|
3762
|
+
class_val = feature[class_value_field]
|
|
3763
|
+
# YOLO uses 0-indexed class IDs
|
|
3764
|
+
class_id = class_to_id.get(class_val, 1) - 1
|
|
3765
|
+
else:
|
|
3766
|
+
class_id = 0
|
|
3767
|
+
|
|
3768
|
+
# Get geometry bounds
|
|
3769
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
3770
|
+
if not geom.is_empty:
|
|
3771
|
+
# Get bounds in world coordinates
|
|
3772
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3773
|
+
|
|
3774
|
+
# Convert to pixel coordinates
|
|
3775
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3776
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3777
|
+
|
|
3778
|
+
# Ensure coordinates are within tile bounds
|
|
3779
|
+
xmin = max(0, min(tile_size, col_min))
|
|
3780
|
+
ymin = max(0, min(tile_size, row_min))
|
|
3781
|
+
xmax = max(0, min(tile_size, col_max))
|
|
3782
|
+
ymax = max(0, min(tile_size, row_max))
|
|
3783
|
+
|
|
3784
|
+
# Skip if box is too small
|
|
3785
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
3786
|
+
continue
|
|
3787
|
+
|
|
3788
|
+
# Calculate normalized coordinates (YOLO format)
|
|
3789
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
|
3790
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
|
3791
|
+
width = (xmax - xmin) / tile_size
|
|
3792
|
+
height = (ymax - ymin) / tile_size
|
|
3793
|
+
|
|
3794
|
+
# Add YOLO annotation line
|
|
3795
|
+
yolo_annotations.append(
|
|
3796
|
+
f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
|
3797
|
+
)
|
|
3798
|
+
|
|
3799
|
+
# Save YOLO annotations to text file
|
|
3800
|
+
if yolo_annotations:
|
|
3801
|
+
yolo_path = os.path.join(
|
|
3802
|
+
label_dir, f"tile_{tile_index:06d}.txt"
|
|
3803
|
+
)
|
|
3804
|
+
with open(yolo_path, "w") as f:
|
|
3805
|
+
f.write("\n".join(yolo_annotations))
|
|
3031
3806
|
|
|
3032
3807
|
# Update progress bar
|
|
3033
3808
|
pbar.update(1)
|
|
@@ -3045,6 +3820,39 @@ def export_geotiff_tiles(
|
|
|
3045
3820
|
# Close progress bar
|
|
3046
3821
|
pbar.close()
|
|
3047
3822
|
|
|
3823
|
+
# Save COCO annotations if applicable (only if class data provided)
|
|
3824
|
+
if in_class_data is not None and metadata_format == "COCO":
|
|
3825
|
+
try:
|
|
3826
|
+
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
|
3827
|
+
json.dump(coco_annotations, f, indent=2)
|
|
3828
|
+
if not quiet:
|
|
3829
|
+
print(
|
|
3830
|
+
f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
|
|
3831
|
+
f"{len(coco_annotations['annotations'])} annotations, "
|
|
3832
|
+
f"{len(coco_annotations['categories'])} categories"
|
|
3833
|
+
)
|
|
3834
|
+
except Exception as e:
|
|
3835
|
+
if not quiet:
|
|
3836
|
+
print(f"ERROR saving COCO annotations: {e}")
|
|
3837
|
+
stats["errors"] += 1
|
|
3838
|
+
|
|
3839
|
+
# Save YOLO classes file if applicable (only if class data provided)
|
|
3840
|
+
if in_class_data is not None and metadata_format == "YOLO":
|
|
3841
|
+
try:
|
|
3842
|
+
# Create classes.txt with class names
|
|
3843
|
+
classes_path = os.path.join(out_folder, "classes.txt")
|
|
3844
|
+
# Sort by class ID to ensure correct order
|
|
3845
|
+
sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
|
|
3846
|
+
with open(classes_path, "w") as f:
|
|
3847
|
+
for class_val, _ in sorted_classes:
|
|
3848
|
+
f.write(f"{class_val}\n")
|
|
3849
|
+
if not quiet:
|
|
3850
|
+
print(f"Saved YOLO classes file with {len(class_to_id)} classes")
|
|
3851
|
+
except Exception as e:
|
|
3852
|
+
if not quiet:
|
|
3853
|
+
print(f"ERROR saving YOLO classes file: {e}")
|
|
3854
|
+
stats["errors"] += 1
|
|
3855
|
+
|
|
3048
3856
|
# Create overview image if requested
|
|
3049
3857
|
if create_overview and stats["tile_coordinates"]:
|
|
3050
3858
|
try:
|
|
@@ -3062,13 +3870,14 @@ def export_geotiff_tiles(
|
|
|
3062
3870
|
if not quiet:
|
|
3063
3871
|
print("\n------- Export Summary -------")
|
|
3064
3872
|
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
3065
|
-
|
|
3066
|
-
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
3067
|
-
)
|
|
3068
|
-
if stats["tiles_with_features"] > 0:
|
|
3873
|
+
if in_class_data is not None:
|
|
3069
3874
|
print(
|
|
3070
|
-
f"
|
|
3875
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
3071
3876
|
)
|
|
3877
|
+
if stats["tiles_with_features"] > 0:
|
|
3878
|
+
print(
|
|
3879
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
3880
|
+
)
|
|
3072
3881
|
if stats["errors"] > 0:
|
|
3073
3882
|
print(f"Errors encountered: {stats['errors']}")
|
|
3074
3883
|
print(f"Output saved to: {out_folder}")
|
|
@@ -3077,7 +3886,6 @@ def export_geotiff_tiles(
|
|
|
3077
3886
|
if stats["total_tiles"] > 0:
|
|
3078
3887
|
print("\n------- Georeference Verification -------")
|
|
3079
3888
|
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
|
3080
|
-
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
|
3081
3889
|
|
|
3082
3890
|
if os.path.exists(sample_image):
|
|
3083
3891
|
try:
|
|
@@ -3093,19 +3901,22 @@ def export_geotiff_tiles(
|
|
|
3093
3901
|
except Exception as e:
|
|
3094
3902
|
print(f"Error verifying image georeference: {e}")
|
|
3095
3903
|
|
|
3096
|
-
if
|
|
3097
|
-
|
|
3098
|
-
|
|
3099
|
-
|
|
3100
|
-
|
|
3101
|
-
|
|
3102
|
-
f"Label
|
|
3103
|
-
|
|
3104
|
-
|
|
3105
|
-
|
|
3106
|
-
|
|
3107
|
-
|
|
3108
|
-
|
|
3904
|
+
# Only verify label if class data was provided
|
|
3905
|
+
if in_class_data is not None:
|
|
3906
|
+
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
|
3907
|
+
if os.path.exists(sample_label):
|
|
3908
|
+
try:
|
|
3909
|
+
with rasterio.open(sample_label) as lbl:
|
|
3910
|
+
print(f"Label CRS: {lbl.crs}")
|
|
3911
|
+
print(f"Label transform: {lbl.transform}")
|
|
3912
|
+
print(
|
|
3913
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
3914
|
+
)
|
|
3915
|
+
print(
|
|
3916
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
3917
|
+
)
|
|
3918
|
+
except Exception as e:
|
|
3919
|
+
print(f"Error verifying label georeference: {e}")
|
|
3109
3920
|
|
|
3110
3921
|
# Return statistics dictionary for further processing if needed
|
|
3111
3922
|
return stats
|
|
@@ -3113,8 +3924,9 @@ def export_geotiff_tiles(
|
|
|
3113
3924
|
|
|
3114
3925
|
def export_geotiff_tiles_batch(
|
|
3115
3926
|
images_folder,
|
|
3116
|
-
masks_folder,
|
|
3117
|
-
|
|
3927
|
+
masks_folder=None,
|
|
3928
|
+
masks_file=None,
|
|
3929
|
+
output_folder=None,
|
|
3118
3930
|
tile_size=256,
|
|
3119
3931
|
stride=128,
|
|
3120
3932
|
class_value_field="class",
|
|
@@ -3122,25 +3934,43 @@ def export_geotiff_tiles_batch(
|
|
|
3122
3934
|
max_tiles=None,
|
|
3123
3935
|
quiet=False,
|
|
3124
3936
|
all_touched=True,
|
|
3125
|
-
create_overview=False,
|
|
3126
3937
|
skip_empty_tiles=False,
|
|
3127
3938
|
image_extensions=None,
|
|
3128
3939
|
mask_extensions=None,
|
|
3940
|
+
match_by_name=False,
|
|
3941
|
+
metadata_format="PASCAL_VOC",
|
|
3129
3942
|
) -> Dict[str, Any]:
|
|
3130
3943
|
"""
|
|
3131
|
-
Export georeferenced GeoTIFF tiles from
|
|
3944
|
+
Export georeferenced GeoTIFF tiles from images and optionally masks.
|
|
3945
|
+
|
|
3946
|
+
This function supports four modes:
|
|
3947
|
+
1. Images only (no masks) - when neither masks_file nor masks_folder is provided
|
|
3948
|
+
2. Single vector file covering all images (masks_file parameter)
|
|
3949
|
+
3. Multiple vector files, one per image (masks_folder parameter)
|
|
3950
|
+
4. Multiple raster mask files (masks_folder parameter)
|
|
3132
3951
|
|
|
3133
|
-
|
|
3134
|
-
generating tiles for each pair. All image tiles are saved to a single
|
|
3135
|
-
'images' folder and all mask tiles to a single 'masks' folder.
|
|
3952
|
+
For mode 1 (images only), only image tiles will be exported without labels.
|
|
3136
3953
|
|
|
3137
|
-
|
|
3138
|
-
|
|
3954
|
+
For mode 2 (single vector file), specify masks_file path. The function will
|
|
3955
|
+
use spatial intersection to determine which features apply to each image.
|
|
3956
|
+
|
|
3957
|
+
For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
|
|
3958
|
+
are paired either by matching filenames (match_by_name=True) or by sorted order
|
|
3959
|
+
(match_by_name=False).
|
|
3960
|
+
|
|
3961
|
+
All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
|
|
3962
|
+
to a single 'masks' folder within the output directory.
|
|
3139
3963
|
|
|
3140
3964
|
Args:
|
|
3141
3965
|
images_folder (str): Path to folder containing raster images
|
|
3142
|
-
masks_folder (str): Path to folder containing classification masks/vectors
|
|
3143
|
-
|
|
3966
|
+
masks_folder (str, optional): Path to folder containing classification masks/vectors.
|
|
3967
|
+
Use this for multiple mask files (one per image or raster masks). If not provided
|
|
3968
|
+
and masks_file is also not provided, only image tiles will be exported.
|
|
3969
|
+
masks_file (str, optional): Path to a single vector file covering all images.
|
|
3970
|
+
Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
|
|
3971
|
+
and masks_folder is also not provided, only image tiles will be exported.
|
|
3972
|
+
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
|
3973
|
+
subfolder in images_folder.
|
|
3144
3974
|
tile_size (int): Size of tiles in pixels (square)
|
|
3145
3975
|
stride (int): Step size between tiles
|
|
3146
3976
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -3152,18 +3982,63 @@ def export_geotiff_tiles_batch(
|
|
|
3152
3982
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3153
3983
|
image_extensions (list): List of image file extensions to process (default: common raster formats)
|
|
3154
3984
|
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
|
3985
|
+
match_by_name (bool): If True, match image and mask files by base filename.
|
|
3986
|
+
If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
|
|
3987
|
+
metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
|
|
3988
|
+
Default is "PASCAL_VOC".
|
|
3989
|
+
|
|
3990
|
+
Returns:
|
|
3991
|
+
Dict[str, Any]: Dictionary containing batch processing statistics
|
|
3992
|
+
|
|
3993
|
+
Raises:
|
|
3994
|
+
ValueError: If no images found, or if masks_folder and masks_file are both specified,
|
|
3995
|
+
or if counts don't match when using masks_folder with match_by_name=False.
|
|
3996
|
+
|
|
3997
|
+
Examples:
|
|
3998
|
+
# Images only (no masks)
|
|
3999
|
+
>>> stats = export_geotiff_tiles_batch(
|
|
4000
|
+
... images_folder='data/images',
|
|
4001
|
+
... output_folder='output/tiles'
|
|
4002
|
+
... )
|
|
3155
4003
|
|
|
3156
|
-
|
|
3157
|
-
|
|
4004
|
+
# Single vector file covering all images
|
|
4005
|
+
>>> stats = export_geotiff_tiles_batch(
|
|
4006
|
+
... images_folder='data/images',
|
|
4007
|
+
... masks_file='data/buildings.geojson',
|
|
4008
|
+
... output_folder='output/tiles'
|
|
4009
|
+
... )
|
|
3158
4010
|
|
|
3159
|
-
|
|
3160
|
-
|
|
4011
|
+
# Multiple vector files, matched by filename
|
|
4012
|
+
>>> stats = export_geotiff_tiles_batch(
|
|
4013
|
+
... images_folder='data/images',
|
|
4014
|
+
... masks_folder='data/masks',
|
|
4015
|
+
... output_folder='output/tiles',
|
|
4016
|
+
... match_by_name=True
|
|
4017
|
+
... )
|
|
4018
|
+
|
|
4019
|
+
# Multiple mask files, matched by sorted order
|
|
4020
|
+
>>> stats = export_geotiff_tiles_batch(
|
|
4021
|
+
... images_folder='data/images',
|
|
4022
|
+
... masks_folder='data/masks',
|
|
4023
|
+
... output_folder='output/tiles',
|
|
4024
|
+
... match_by_name=False
|
|
4025
|
+
... )
|
|
3161
4026
|
"""
|
|
3162
4027
|
|
|
3163
4028
|
import logging
|
|
3164
4029
|
|
|
3165
4030
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
|
3166
4031
|
|
|
4032
|
+
# Validate input parameters
|
|
4033
|
+
if masks_folder is not None and masks_file is not None:
|
|
4034
|
+
raise ValueError(
|
|
4035
|
+
"Cannot specify both masks_folder and masks_file. Please use only one."
|
|
4036
|
+
)
|
|
4037
|
+
|
|
4038
|
+
# Default output folder if not specified
|
|
4039
|
+
if output_folder is None:
|
|
4040
|
+
output_folder = os.path.join(images_folder, "tiles")
|
|
4041
|
+
|
|
3167
4042
|
# Default extensions if not provided
|
|
3168
4043
|
if image_extensions is None:
|
|
3169
4044
|
image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
|
|
@@ -3190,9 +4065,37 @@ def export_geotiff_tiles_batch(
|
|
|
3190
4065
|
# Create output folder structure
|
|
3191
4066
|
os.makedirs(output_folder, exist_ok=True)
|
|
3192
4067
|
output_images_dir = os.path.join(output_folder, "images")
|
|
3193
|
-
output_masks_dir = os.path.join(output_folder, "masks")
|
|
3194
4068
|
os.makedirs(output_images_dir, exist_ok=True)
|
|
3195
|
-
|
|
4069
|
+
|
|
4070
|
+
# Only create masks directory if masks are provided
|
|
4071
|
+
output_masks_dir = None
|
|
4072
|
+
if masks_folder is not None or masks_file is not None:
|
|
4073
|
+
output_masks_dir = os.path.join(output_folder, "masks")
|
|
4074
|
+
os.makedirs(output_masks_dir, exist_ok=True)
|
|
4075
|
+
|
|
4076
|
+
# Create annotation directory based on metadata format (only if masks are provided)
|
|
4077
|
+
ann_dir = None
|
|
4078
|
+
if (masks_folder is not None or masks_file is not None) and metadata_format in [
|
|
4079
|
+
"PASCAL_VOC",
|
|
4080
|
+
"COCO",
|
|
4081
|
+
]:
|
|
4082
|
+
ann_dir = os.path.join(output_folder, "annotations")
|
|
4083
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
4084
|
+
|
|
4085
|
+
# Initialize COCO annotations dictionary (only if masks are provided)
|
|
4086
|
+
coco_annotations = None
|
|
4087
|
+
if (
|
|
4088
|
+
masks_folder is not None or masks_file is not None
|
|
4089
|
+
) and metadata_format == "COCO":
|
|
4090
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
4091
|
+
|
|
4092
|
+
# Initialize YOLO class set (only if masks are provided)
|
|
4093
|
+
yolo_classes = (
|
|
4094
|
+
set()
|
|
4095
|
+
if (masks_folder is not None or masks_file is not None)
|
|
4096
|
+
and metadata_format == "YOLO"
|
|
4097
|
+
else None
|
|
4098
|
+
)
|
|
3196
4099
|
|
|
3197
4100
|
# Get list of image files
|
|
3198
4101
|
image_files = []
|
|
@@ -3200,30 +4103,105 @@ def export_geotiff_tiles_batch(
|
|
|
3200
4103
|
pattern = os.path.join(images_folder, f"*{ext}")
|
|
3201
4104
|
image_files.extend(glob.glob(pattern))
|
|
3202
4105
|
|
|
3203
|
-
# Get list of mask files
|
|
3204
|
-
mask_files = []
|
|
3205
|
-
for ext in mask_extensions:
|
|
3206
|
-
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
3207
|
-
mask_files.extend(glob.glob(pattern))
|
|
3208
|
-
|
|
3209
4106
|
# Sort files for consistent processing
|
|
3210
4107
|
image_files.sort()
|
|
3211
|
-
mask_files.sort()
|
|
3212
4108
|
|
|
3213
4109
|
if not image_files:
|
|
3214
4110
|
raise ValueError(
|
|
3215
4111
|
f"No image files found in {images_folder} with extensions {image_extensions}"
|
|
3216
4112
|
)
|
|
3217
4113
|
|
|
3218
|
-
|
|
3219
|
-
|
|
3220
|
-
|
|
3221
|
-
|
|
4114
|
+
# Handle different mask input modes
|
|
4115
|
+
use_single_mask_file = masks_file is not None
|
|
4116
|
+
has_masks = masks_file is not None or masks_folder is not None
|
|
4117
|
+
mask_files = []
|
|
4118
|
+
image_mask_pairs = []
|
|
3222
4119
|
|
|
3223
|
-
if
|
|
3224
|
-
|
|
3225
|
-
|
|
3226
|
-
|
|
4120
|
+
if not has_masks:
|
|
4121
|
+
# Mode 0: No masks - create pairs with None for mask
|
|
4122
|
+
for image_file in image_files:
|
|
4123
|
+
image_mask_pairs.append((image_file, None, None))
|
|
4124
|
+
|
|
4125
|
+
elif use_single_mask_file:
|
|
4126
|
+
# Mode 1: Single vector file covering all images
|
|
4127
|
+
if not os.path.exists(masks_file):
|
|
4128
|
+
raise ValueError(f"Mask file not found: {masks_file}")
|
|
4129
|
+
|
|
4130
|
+
# Load the single mask file once - will be spatially filtered per image
|
|
4131
|
+
single_mask_gdf = gpd.read_file(masks_file)
|
|
4132
|
+
|
|
4133
|
+
if not quiet:
|
|
4134
|
+
print(f"Using single mask file: {masks_file}")
|
|
4135
|
+
print(
|
|
4136
|
+
f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
|
|
4137
|
+
)
|
|
4138
|
+
|
|
4139
|
+
# Create pairs with the same mask file for all images
|
|
4140
|
+
for image_file in image_files:
|
|
4141
|
+
image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
|
|
4142
|
+
|
|
4143
|
+
else:
|
|
4144
|
+
# Mode 2/3: Multiple mask files (vector or raster)
|
|
4145
|
+
# Get list of mask files
|
|
4146
|
+
for ext in mask_extensions:
|
|
4147
|
+
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
4148
|
+
mask_files.extend(glob.glob(pattern))
|
|
4149
|
+
|
|
4150
|
+
# Sort files for consistent processing
|
|
4151
|
+
mask_files.sort()
|
|
4152
|
+
|
|
4153
|
+
if not mask_files:
|
|
4154
|
+
raise ValueError(
|
|
4155
|
+
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
|
4156
|
+
)
|
|
4157
|
+
|
|
4158
|
+
# Match images to masks
|
|
4159
|
+
if match_by_name:
|
|
4160
|
+
# Match by base filename
|
|
4161
|
+
image_dict = {
|
|
4162
|
+
os.path.splitext(os.path.basename(f))[0]: f for f in image_files
|
|
4163
|
+
}
|
|
4164
|
+
mask_dict = {
|
|
4165
|
+
os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
|
|
4166
|
+
}
|
|
4167
|
+
|
|
4168
|
+
# Find matching pairs
|
|
4169
|
+
for img_base, img_path in image_dict.items():
|
|
4170
|
+
if img_base in mask_dict:
|
|
4171
|
+
image_mask_pairs.append((img_path, mask_dict[img_base], None))
|
|
4172
|
+
else:
|
|
4173
|
+
if not quiet:
|
|
4174
|
+
print(f"Warning: No mask found for image {img_base}")
|
|
4175
|
+
|
|
4176
|
+
if not image_mask_pairs:
|
|
4177
|
+
# Provide detailed error message with found files
|
|
4178
|
+
image_bases = list(image_dict.keys())
|
|
4179
|
+
mask_bases = list(mask_dict.keys())
|
|
4180
|
+
error_msg = (
|
|
4181
|
+
"No matching image-mask pairs found when matching by filename. "
|
|
4182
|
+
"Check that image and mask files have matching base names.\n"
|
|
4183
|
+
f"Found {len(image_bases)} image(s): "
|
|
4184
|
+
f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
|
|
4185
|
+
f"{'...' if len(image_bases) > 5 else ''}\n"
|
|
4186
|
+
f"Found {len(mask_bases)} mask(s): "
|
|
4187
|
+
f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
|
|
4188
|
+
f"{'...' if len(mask_bases) > 5 else ''}\n"
|
|
4189
|
+
"Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
|
|
4190
|
+
)
|
|
4191
|
+
raise ValueError(error_msg)
|
|
4192
|
+
|
|
4193
|
+
else:
|
|
4194
|
+
# Match by sorted order
|
|
4195
|
+
if len(image_files) != len(mask_files):
|
|
4196
|
+
raise ValueError(
|
|
4197
|
+
f"Number of image files ({len(image_files)}) does not match "
|
|
4198
|
+
f"number of mask files ({len(mask_files)}) when matching by sorted order. "
|
|
4199
|
+
f"Use match_by_name=True for filename-based matching."
|
|
4200
|
+
)
|
|
4201
|
+
|
|
4202
|
+
# Create pairs by sorted order
|
|
4203
|
+
for image_file, mask_file in zip(image_files, mask_files):
|
|
4204
|
+
image_mask_pairs.append((image_file, mask_file, None))
|
|
3227
4205
|
|
|
3228
4206
|
# Initialize batch statistics
|
|
3229
4207
|
batch_stats = {
|
|
@@ -3237,23 +4215,28 @@ def export_geotiff_tiles_batch(
|
|
|
3237
4215
|
}
|
|
3238
4216
|
|
|
3239
4217
|
if not quiet:
|
|
3240
|
-
|
|
3241
|
-
|
|
3242
|
-
|
|
3243
|
-
|
|
4218
|
+
if not has_masks:
|
|
4219
|
+
print(
|
|
4220
|
+
f"Found {len(image_files)} image files to process (images only, no masks)"
|
|
4221
|
+
)
|
|
4222
|
+
elif use_single_mask_file:
|
|
4223
|
+
print(f"Found {len(image_files)} image files to process")
|
|
4224
|
+
print(f"Using single mask file: {masks_file}")
|
|
4225
|
+
else:
|
|
4226
|
+
print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
|
|
4227
|
+
print(f"Processing batch from {images_folder} and {masks_folder}")
|
|
3244
4228
|
print(f"Output folder: {output_folder}")
|
|
3245
4229
|
print("-" * 60)
|
|
3246
4230
|
|
|
3247
4231
|
# Global tile counter for unique naming
|
|
3248
4232
|
global_tile_counter = 0
|
|
3249
4233
|
|
|
3250
|
-
# Process each image-mask pair
|
|
3251
|
-
for idx, (image_file, mask_file) in enumerate(
|
|
4234
|
+
# Process each image-mask pair
|
|
4235
|
+
for idx, (image_file, mask_file, mask_gdf) in enumerate(
|
|
3252
4236
|
tqdm(
|
|
3253
|
-
|
|
4237
|
+
image_mask_pairs,
|
|
3254
4238
|
desc="Processing image pairs",
|
|
3255
4239
|
disable=quiet,
|
|
3256
|
-
total=len(image_files),
|
|
3257
4240
|
)
|
|
3258
4241
|
):
|
|
3259
4242
|
batch_stats["total_image_pairs"] += 1
|
|
@@ -3265,9 +4248,17 @@ def export_geotiff_tiles_batch(
|
|
|
3265
4248
|
if not quiet:
|
|
3266
4249
|
print(f"\nProcessing: {base_name}")
|
|
3267
4250
|
print(f" Image: {os.path.basename(image_file)}")
|
|
3268
|
-
|
|
4251
|
+
if mask_file is not None:
|
|
4252
|
+
if use_single_mask_file:
|
|
4253
|
+
print(
|
|
4254
|
+
f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
|
|
4255
|
+
)
|
|
4256
|
+
else:
|
|
4257
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
|
4258
|
+
else:
|
|
4259
|
+
print(f" Mask: None (images only)")
|
|
3269
4260
|
|
|
3270
|
-
# Process the image-mask pair
|
|
4261
|
+
# Process the image-mask pair
|
|
3271
4262
|
tiles_generated = _process_image_mask_pair(
|
|
3272
4263
|
image_file=image_file,
|
|
3273
4264
|
mask_file=mask_file,
|
|
@@ -3283,6 +4274,15 @@ def export_geotiff_tiles_batch(
|
|
|
3283
4274
|
all_touched=all_touched,
|
|
3284
4275
|
skip_empty_tiles=skip_empty_tiles,
|
|
3285
4276
|
quiet=quiet,
|
|
4277
|
+
mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
|
|
4278
|
+
use_single_mask_file=use_single_mask_file,
|
|
4279
|
+
metadata_format=metadata_format,
|
|
4280
|
+
ann_dir=(
|
|
4281
|
+
ann_dir
|
|
4282
|
+
if "ann_dir" in locals()
|
|
4283
|
+
and metadata_format in ["PASCAL_VOC", "COCO"]
|
|
4284
|
+
else None
|
|
4285
|
+
),
|
|
3286
4286
|
)
|
|
3287
4287
|
|
|
3288
4288
|
# Update counters
|
|
@@ -3304,6 +4304,23 @@ def export_geotiff_tiles_batch(
|
|
|
3304
4304
|
}
|
|
3305
4305
|
)
|
|
3306
4306
|
|
|
4307
|
+
# Aggregate COCO annotations
|
|
4308
|
+
if metadata_format == "COCO" and "coco_data" in tiles_generated:
|
|
4309
|
+
coco_data = tiles_generated["coco_data"]
|
|
4310
|
+
# Add images and annotations
|
|
4311
|
+
coco_annotations["images"].extend(coco_data.get("images", []))
|
|
4312
|
+
coco_annotations["annotations"].extend(coco_data.get("annotations", []))
|
|
4313
|
+
# Merge categories (avoid duplicates)
|
|
4314
|
+
for cat in coco_data.get("categories", []):
|
|
4315
|
+
if not any(
|
|
4316
|
+
c["id"] == cat["id"] for c in coco_annotations["categories"]
|
|
4317
|
+
):
|
|
4318
|
+
coco_annotations["categories"].append(cat)
|
|
4319
|
+
|
|
4320
|
+
# Aggregate YOLO classes
|
|
4321
|
+
if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
|
|
4322
|
+
yolo_classes.update(tiles_generated["yolo_classes"])
|
|
4323
|
+
|
|
3307
4324
|
except Exception as e:
|
|
3308
4325
|
if not quiet:
|
|
3309
4326
|
print(f"ERROR processing {base_name}: {e}")
|
|
@@ -3312,6 +4329,33 @@ def export_geotiff_tiles_batch(
|
|
|
3312
4329
|
)
|
|
3313
4330
|
batch_stats["errors"] += 1
|
|
3314
4331
|
|
|
4332
|
+
# Save aggregated COCO annotations
|
|
4333
|
+
if metadata_format == "COCO" and coco_annotations:
|
|
4334
|
+
import json
|
|
4335
|
+
|
|
4336
|
+
coco_path = os.path.join(ann_dir, "instances.json")
|
|
4337
|
+
with open(coco_path, "w") as f:
|
|
4338
|
+
json.dump(coco_annotations, f, indent=2)
|
|
4339
|
+
if not quiet:
|
|
4340
|
+
print(f"\nSaved COCO annotations: {coco_path}")
|
|
4341
|
+
print(
|
|
4342
|
+
f" Images: {len(coco_annotations['images'])}, "
|
|
4343
|
+
f"Annotations: {len(coco_annotations['annotations'])}, "
|
|
4344
|
+
f"Categories: {len(coco_annotations['categories'])}"
|
|
4345
|
+
)
|
|
4346
|
+
|
|
4347
|
+
# Save aggregated YOLO classes
|
|
4348
|
+
if metadata_format == "YOLO" and yolo_classes:
|
|
4349
|
+
classes_path = os.path.join(output_folder, "labels", "classes.txt")
|
|
4350
|
+
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
|
|
4351
|
+
sorted_classes = sorted(yolo_classes)
|
|
4352
|
+
with open(classes_path, "w") as f:
|
|
4353
|
+
for cls in sorted_classes:
|
|
4354
|
+
f.write(f"{cls}\n")
|
|
4355
|
+
if not quiet:
|
|
4356
|
+
print(f"\nSaved YOLO classes: {classes_path}")
|
|
4357
|
+
print(f" Total classes: {len(sorted_classes)}")
|
|
4358
|
+
|
|
3315
4359
|
# Print batch summary
|
|
3316
4360
|
if not quiet:
|
|
3317
4361
|
print("\n" + "=" * 60)
|
|
@@ -3334,7 +4378,12 @@ def export_geotiff_tiles_batch(
|
|
|
3334
4378
|
|
|
3335
4379
|
print(f"Output saved to: {output_folder}")
|
|
3336
4380
|
print(f" Images: {output_images_dir}")
|
|
3337
|
-
|
|
4381
|
+
if output_masks_dir is not None:
|
|
4382
|
+
print(f" Masks: {output_masks_dir}")
|
|
4383
|
+
if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
|
|
4384
|
+
print(f" Annotations: {ann_dir}")
|
|
4385
|
+
elif metadata_format == "YOLO":
|
|
4386
|
+
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
|
3338
4387
|
|
|
3339
4388
|
# List failed files if any
|
|
3340
4389
|
if batch_stats["failed_files"]:
|
|
@@ -3360,18 +4409,26 @@ def _process_image_mask_pair(
|
|
|
3360
4409
|
all_touched=True,
|
|
3361
4410
|
skip_empty_tiles=False,
|
|
3362
4411
|
quiet=False,
|
|
4412
|
+
mask_gdf=None,
|
|
4413
|
+
use_single_mask_file=False,
|
|
4414
|
+
metadata_format="PASCAL_VOC",
|
|
4415
|
+
ann_dir=None,
|
|
3363
4416
|
):
|
|
3364
4417
|
"""
|
|
3365
4418
|
Process a single image-mask pair and save tiles directly to output directories.
|
|
3366
4419
|
|
|
4420
|
+
Args:
|
|
4421
|
+
mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
|
|
4422
|
+
use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
|
|
4423
|
+
|
|
3367
4424
|
Returns:
|
|
3368
4425
|
dict: Statistics for this image-mask pair
|
|
3369
4426
|
"""
|
|
3370
4427
|
import warnings
|
|
3371
4428
|
|
|
3372
|
-
# Determine if mask data is raster or vector
|
|
4429
|
+
# Determine if mask data is raster or vector (only if mask_file is provided)
|
|
3373
4430
|
is_class_data_raster = False
|
|
3374
|
-
if isinstance(mask_file, str):
|
|
4431
|
+
if mask_file is not None and isinstance(mask_file, str):
|
|
3375
4432
|
file_ext = Path(mask_file).suffix.lower()
|
|
3376
4433
|
# Common raster extensions
|
|
3377
4434
|
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
@@ -3388,6 +4445,13 @@ def _process_image_mask_pair(
|
|
|
3388
4445
|
"errors": 0,
|
|
3389
4446
|
}
|
|
3390
4447
|
|
|
4448
|
+
# Initialize COCO/YOLO tracking for this image
|
|
4449
|
+
if metadata_format == "COCO":
|
|
4450
|
+
stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
|
|
4451
|
+
coco_ann_id = 0
|
|
4452
|
+
if metadata_format == "YOLO":
|
|
4453
|
+
stats["yolo_classes"] = set()
|
|
4454
|
+
|
|
3391
4455
|
# Open the input raster
|
|
3392
4456
|
with rasterio.open(image_file) as src:
|
|
3393
4457
|
# Calculate number of tiles
|
|
@@ -3398,10 +4462,10 @@ def _process_image_mask_pair(
|
|
|
3398
4462
|
if max_tiles is None:
|
|
3399
4463
|
max_tiles = total_tiles
|
|
3400
4464
|
|
|
3401
|
-
# Process classification data
|
|
4465
|
+
# Process classification data (only if mask_file is provided)
|
|
3402
4466
|
class_to_id = {}
|
|
3403
4467
|
|
|
3404
|
-
if is_class_data_raster:
|
|
4468
|
+
if mask_file is not None and is_class_data_raster:
|
|
3405
4469
|
# Load raster class data
|
|
3406
4470
|
with rasterio.open(mask_file) as class_src:
|
|
3407
4471
|
# Check if raster CRS matches
|
|
@@ -3428,14 +4492,39 @@ def _process_image_mask_pair(
|
|
|
3428
4492
|
|
|
3429
4493
|
# Create class mapping
|
|
3430
4494
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
3431
|
-
|
|
4495
|
+
elif mask_file is not None:
|
|
3432
4496
|
# Load vector class data
|
|
3433
4497
|
try:
|
|
3434
|
-
|
|
4498
|
+
if use_single_mask_file and mask_gdf is not None:
|
|
4499
|
+
# Using pre-loaded single mask file - spatially filter to image bounds
|
|
4500
|
+
# Get image bounds
|
|
4501
|
+
image_bounds = box(*src.bounds)
|
|
4502
|
+
image_gdf = gpd.GeoDataFrame(
|
|
4503
|
+
{"geometry": [image_bounds]}, crs=src.crs
|
|
4504
|
+
)
|
|
3435
4505
|
|
|
3436
|
-
|
|
3437
|
-
|
|
3438
|
-
|
|
4506
|
+
# Reproject mask if needed
|
|
4507
|
+
if mask_gdf.crs != src.crs:
|
|
4508
|
+
mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
|
|
4509
|
+
else:
|
|
4510
|
+
mask_gdf_reprojected = mask_gdf
|
|
4511
|
+
|
|
4512
|
+
# Spatially filter features that intersect with image bounds
|
|
4513
|
+
gdf = mask_gdf_reprojected[
|
|
4514
|
+
mask_gdf_reprojected.intersects(image_bounds)
|
|
4515
|
+
].copy()
|
|
4516
|
+
|
|
4517
|
+
if not quiet and len(gdf) > 0:
|
|
4518
|
+
print(
|
|
4519
|
+
f" Filtered to {len(gdf)} features intersecting image bounds"
|
|
4520
|
+
)
|
|
4521
|
+
else:
|
|
4522
|
+
# Load individual mask file
|
|
4523
|
+
gdf = gpd.read_file(mask_file)
|
|
4524
|
+
|
|
4525
|
+
# Always reproject to match raster CRS
|
|
4526
|
+
if gdf.crs != src.crs:
|
|
4527
|
+
gdf = gdf.to_crs(src.crs)
|
|
3439
4528
|
|
|
3440
4529
|
# Apply buffer if specified
|
|
3441
4530
|
if buffer_radius > 0:
|
|
@@ -3455,9 +4544,6 @@ def _process_image_mask_pair(
|
|
|
3455
4544
|
tile_index = 0
|
|
3456
4545
|
for y in range(num_tiles_y):
|
|
3457
4546
|
for x in range(num_tiles_x):
|
|
3458
|
-
if tile_index >= max_tiles:
|
|
3459
|
-
break
|
|
3460
|
-
|
|
3461
4547
|
# Calculate window coordinates
|
|
3462
4548
|
window_x = x * stride
|
|
3463
4549
|
window_y = y * stride
|
|
@@ -3482,12 +4568,12 @@ def _process_image_mask_pair(
|
|
|
3482
4568
|
|
|
3483
4569
|
window_bounds = box(minx, miny, maxx, maxy)
|
|
3484
4570
|
|
|
3485
|
-
# Create label mask
|
|
4571
|
+
# Create label mask (only if mask_file is provided)
|
|
3486
4572
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
3487
4573
|
has_features = False
|
|
3488
4574
|
|
|
3489
|
-
# Process classification data to create labels
|
|
3490
|
-
if is_class_data_raster:
|
|
4575
|
+
# Process classification data to create labels (only if mask_file is provided)
|
|
4576
|
+
if mask_file is not None and is_class_data_raster:
|
|
3491
4577
|
# For raster class data
|
|
3492
4578
|
with rasterio.open(mask_file) as class_src:
|
|
3493
4579
|
# Get corresponding window in class raster
|
|
@@ -3520,7 +4606,7 @@ def _process_image_mask_pair(
|
|
|
3520
4606
|
if not quiet:
|
|
3521
4607
|
print(f"Error reading class raster window: {e}")
|
|
3522
4608
|
stats["errors"] += 1
|
|
3523
|
-
|
|
4609
|
+
elif mask_file is not None:
|
|
3524
4610
|
# For vector class data
|
|
3525
4611
|
# Find features that intersect with window
|
|
3526
4612
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -3558,11 +4644,14 @@ def _process_image_mask_pair(
|
|
|
3558
4644
|
print(f"Error rasterizing feature {idx}: {e}")
|
|
3559
4645
|
stats["errors"] += 1
|
|
3560
4646
|
|
|
3561
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
3562
|
-
if skip_empty_tiles and not has_features:
|
|
3563
|
-
tile_index += 1
|
|
4647
|
+
# Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
|
|
4648
|
+
if mask_file is not None and skip_empty_tiles and not has_features:
|
|
3564
4649
|
continue
|
|
3565
4650
|
|
|
4651
|
+
# Check if we've reached max_tiles before saving
|
|
4652
|
+
if tile_index >= max_tiles:
|
|
4653
|
+
break
|
|
4654
|
+
|
|
3566
4655
|
# Generate unique tile name
|
|
3567
4656
|
tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
|
|
3568
4657
|
|
|
@@ -3593,29 +4682,225 @@ def _process_image_mask_pair(
|
|
|
3593
4682
|
print(f"ERROR saving image GeoTIFF: {e}")
|
|
3594
4683
|
stats["errors"] += 1
|
|
3595
4684
|
|
|
3596
|
-
#
|
|
3597
|
-
|
|
3598
|
-
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
|
|
3603
|
-
|
|
3604
|
-
|
|
3605
|
-
|
|
4685
|
+
# Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
|
|
4686
|
+
if mask_file is not None and output_masks_dir is not None:
|
|
4687
|
+
# Create profile for label GeoTIFF
|
|
4688
|
+
label_profile = {
|
|
4689
|
+
"driver": "GTiff",
|
|
4690
|
+
"height": tile_size,
|
|
4691
|
+
"width": tile_size,
|
|
4692
|
+
"count": 1,
|
|
4693
|
+
"dtype": "uint8",
|
|
4694
|
+
"crs": src.crs,
|
|
4695
|
+
"transform": window_transform,
|
|
4696
|
+
}
|
|
3606
4697
|
|
|
3607
|
-
|
|
3608
|
-
|
|
3609
|
-
|
|
3610
|
-
|
|
3611
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
|
4698
|
+
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
|
4699
|
+
try:
|
|
4700
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
4701
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
|
3612
4702
|
|
|
3613
|
-
|
|
3614
|
-
|
|
3615
|
-
|
|
3616
|
-
|
|
3617
|
-
|
|
3618
|
-
|
|
4703
|
+
if has_features:
|
|
4704
|
+
stats["tiles_with_features"] += 1
|
|
4705
|
+
except Exception as e:
|
|
4706
|
+
if not quiet:
|
|
4707
|
+
print(f"ERROR saving label GeoTIFF: {e}")
|
|
4708
|
+
stats["errors"] += 1
|
|
4709
|
+
|
|
4710
|
+
# Generate annotation metadata based on format (only if mask_file is provided)
|
|
4711
|
+
if (
|
|
4712
|
+
mask_file is not None
|
|
4713
|
+
and metadata_format == "PASCAL_VOC"
|
|
4714
|
+
and ann_dir
|
|
4715
|
+
):
|
|
4716
|
+
# Create PASCAL VOC XML annotation
|
|
4717
|
+
from lxml import etree as ET
|
|
4718
|
+
|
|
4719
|
+
annotation = ET.Element("annotation")
|
|
4720
|
+
ET.SubElement(annotation, "folder").text = os.path.basename(
|
|
4721
|
+
output_images_dir
|
|
4722
|
+
)
|
|
4723
|
+
ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
|
|
4724
|
+
ET.SubElement(annotation, "path").text = image_path
|
|
4725
|
+
|
|
4726
|
+
source = ET.SubElement(annotation, "source")
|
|
4727
|
+
ET.SubElement(source, "database").text = "GeoAI"
|
|
4728
|
+
|
|
4729
|
+
size = ET.SubElement(annotation, "size")
|
|
4730
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
|
4731
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
|
4732
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
4733
|
+
|
|
4734
|
+
ET.SubElement(annotation, "segmented").text = "1"
|
|
4735
|
+
|
|
4736
|
+
# Find connected components for instance segmentation
|
|
4737
|
+
from scipy import ndimage
|
|
4738
|
+
|
|
4739
|
+
for class_id in np.unique(label_mask):
|
|
4740
|
+
if class_id == 0:
|
|
4741
|
+
continue
|
|
4742
|
+
|
|
4743
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4744
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
|
4745
|
+
|
|
4746
|
+
for instance_id in range(1, num_features + 1):
|
|
4747
|
+
instance_mask = labeled_array == instance_id
|
|
4748
|
+
coords = np.argwhere(instance_mask)
|
|
4749
|
+
|
|
4750
|
+
if len(coords) == 0:
|
|
4751
|
+
continue
|
|
4752
|
+
|
|
4753
|
+
ymin, xmin = coords.min(axis=0)
|
|
4754
|
+
ymax, xmax = coords.max(axis=0)
|
|
4755
|
+
|
|
4756
|
+
obj = ET.SubElement(annotation, "object")
|
|
4757
|
+
class_name = next(
|
|
4758
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
|
4759
|
+
str(class_id),
|
|
4760
|
+
)
|
|
4761
|
+
ET.SubElement(obj, "name").text = str(class_name)
|
|
4762
|
+
ET.SubElement(obj, "pose").text = "Unspecified"
|
|
4763
|
+
ET.SubElement(obj, "truncated").text = "0"
|
|
4764
|
+
ET.SubElement(obj, "difficult").text = "0"
|
|
4765
|
+
|
|
4766
|
+
bndbox = ET.SubElement(obj, "bndbox")
|
|
4767
|
+
ET.SubElement(bndbox, "xmin").text = str(int(xmin))
|
|
4768
|
+
ET.SubElement(bndbox, "ymin").text = str(int(ymin))
|
|
4769
|
+
ET.SubElement(bndbox, "xmax").text = str(int(xmax))
|
|
4770
|
+
ET.SubElement(bndbox, "ymax").text = str(int(ymax))
|
|
4771
|
+
|
|
4772
|
+
# Save XML file
|
|
4773
|
+
xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
|
|
4774
|
+
tree = ET.ElementTree(annotation)
|
|
4775
|
+
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
|
4776
|
+
|
|
4777
|
+
elif mask_file is not None and metadata_format == "COCO":
|
|
4778
|
+
# Add COCO image entry
|
|
4779
|
+
image_id = int(global_tile_counter + tile_index)
|
|
4780
|
+
stats["coco_data"]["images"].append(
|
|
4781
|
+
{
|
|
4782
|
+
"id": image_id,
|
|
4783
|
+
"file_name": f"{tile_name}.tif",
|
|
4784
|
+
"width": int(tile_size),
|
|
4785
|
+
"height": int(tile_size),
|
|
4786
|
+
}
|
|
4787
|
+
)
|
|
4788
|
+
|
|
4789
|
+
# Add COCO categories (only once per unique class)
|
|
4790
|
+
for class_val, class_id in class_to_id.items():
|
|
4791
|
+
if not any(
|
|
4792
|
+
c["id"] == class_id
|
|
4793
|
+
for c in stats["coco_data"]["categories"]
|
|
4794
|
+
):
|
|
4795
|
+
stats["coco_data"]["categories"].append(
|
|
4796
|
+
{
|
|
4797
|
+
"id": int(class_id),
|
|
4798
|
+
"name": str(class_val),
|
|
4799
|
+
"supercategory": "object",
|
|
4800
|
+
}
|
|
4801
|
+
)
|
|
4802
|
+
|
|
4803
|
+
# Add COCO annotations (instance segmentation)
|
|
4804
|
+
from scipy import ndimage
|
|
4805
|
+
from skimage import measure
|
|
4806
|
+
|
|
4807
|
+
for class_id in np.unique(label_mask):
|
|
4808
|
+
if class_id == 0:
|
|
4809
|
+
continue
|
|
4810
|
+
|
|
4811
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4812
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
|
4813
|
+
|
|
4814
|
+
for instance_id in range(1, num_features + 1):
|
|
4815
|
+
instance_mask = (labeled_array == instance_id).astype(
|
|
4816
|
+
np.uint8
|
|
4817
|
+
)
|
|
4818
|
+
coords = np.argwhere(instance_mask)
|
|
4819
|
+
|
|
4820
|
+
if len(coords) == 0:
|
|
4821
|
+
continue
|
|
4822
|
+
|
|
4823
|
+
ymin, xmin = coords.min(axis=0)
|
|
4824
|
+
ymax, xmax = coords.max(axis=0)
|
|
4825
|
+
|
|
4826
|
+
bbox = [
|
|
4827
|
+
int(xmin),
|
|
4828
|
+
int(ymin),
|
|
4829
|
+
int(xmax - xmin),
|
|
4830
|
+
int(ymax - ymin),
|
|
4831
|
+
]
|
|
4832
|
+
area = int(np.sum(instance_mask))
|
|
4833
|
+
|
|
4834
|
+
# Find contours for segmentation
|
|
4835
|
+
contours = measure.find_contours(instance_mask, 0.5)
|
|
4836
|
+
segmentation = []
|
|
4837
|
+
for contour in contours:
|
|
4838
|
+
contour = np.flip(contour, axis=1)
|
|
4839
|
+
segmentation_points = contour.ravel().tolist()
|
|
4840
|
+
if len(segmentation_points) >= 6:
|
|
4841
|
+
segmentation.append(segmentation_points)
|
|
4842
|
+
|
|
4843
|
+
if segmentation:
|
|
4844
|
+
stats["coco_data"]["annotations"].append(
|
|
4845
|
+
{
|
|
4846
|
+
"id": int(coco_ann_id),
|
|
4847
|
+
"image_id": int(image_id),
|
|
4848
|
+
"category_id": int(class_id),
|
|
4849
|
+
"bbox": bbox,
|
|
4850
|
+
"area": area,
|
|
4851
|
+
"segmentation": segmentation,
|
|
4852
|
+
"iscrowd": 0,
|
|
4853
|
+
}
|
|
4854
|
+
)
|
|
4855
|
+
coco_ann_id += 1
|
|
4856
|
+
|
|
4857
|
+
elif mask_file is not None and metadata_format == "YOLO":
|
|
4858
|
+
# Create YOLO labels directory if needed
|
|
4859
|
+
labels_dir = os.path.join(
|
|
4860
|
+
os.path.dirname(output_images_dir), "labels"
|
|
4861
|
+
)
|
|
4862
|
+
os.makedirs(labels_dir, exist_ok=True)
|
|
4863
|
+
|
|
4864
|
+
# Generate YOLO annotation file
|
|
4865
|
+
yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
|
|
4866
|
+
from scipy import ndimage
|
|
4867
|
+
|
|
4868
|
+
with open(yolo_path, "w") as yolo_file:
|
|
4869
|
+
for class_id in np.unique(label_mask):
|
|
4870
|
+
if class_id == 0:
|
|
4871
|
+
continue
|
|
4872
|
+
|
|
4873
|
+
# Track class for classes.txt
|
|
4874
|
+
class_name = next(
|
|
4875
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
|
4876
|
+
str(class_id),
|
|
4877
|
+
)
|
|
4878
|
+
stats["yolo_classes"].add(class_name)
|
|
4879
|
+
|
|
4880
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4881
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
|
4882
|
+
|
|
4883
|
+
for instance_id in range(1, num_features + 1):
|
|
4884
|
+
instance_mask = labeled_array == instance_id
|
|
4885
|
+
coords = np.argwhere(instance_mask)
|
|
4886
|
+
|
|
4887
|
+
if len(coords) == 0:
|
|
4888
|
+
continue
|
|
4889
|
+
|
|
4890
|
+
ymin, xmin = coords.min(axis=0)
|
|
4891
|
+
ymax, xmax = coords.max(axis=0)
|
|
4892
|
+
|
|
4893
|
+
# Convert to YOLO format (normalized center coordinates)
|
|
4894
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
|
4895
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
|
4896
|
+
width = (xmax - xmin) / tile_size
|
|
4897
|
+
height = (ymax - ymin) / tile_size
|
|
4898
|
+
|
|
4899
|
+
# YOLO uses 0-based class indices
|
|
4900
|
+
yolo_class_id = class_id - 1
|
|
4901
|
+
yolo_file.write(
|
|
4902
|
+
f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
|
4903
|
+
)
|
|
3619
4904
|
|
|
3620
4905
|
tile_index += 1
|
|
3621
4906
|
if tile_index >= max_tiles:
|
|
@@ -3627,6 +4912,179 @@ def _process_image_mask_pair(
|
|
|
3627
4912
|
return stats
|
|
3628
4913
|
|
|
3629
4914
|
|
|
4915
|
+
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """Visualize paired image/mask tiles produced by the training-data pipeline.

    Renders up to ``num_tiles`` image tiles on the top row and their matching
    mask tiles on the bottom row. Tiles are read from the ``images`` and
    ``masks`` subdirectories of ``output_dir`` and are paired by filename.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Number of tile pairs to display (default: 6)
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and axes objects

    Raises:
        ValueError: If the images directory is missing or contains no tiles.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    import matplotlib.pyplot as plt

    # Locate the image tiles; fail fast if the expected layout is absent.
    images_dir = os.path.join(output_dir, "images")
    if not os.path.exists(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    tile_names = sorted(os.listdir(images_dir))[:num_tiles]
    if not tile_names:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Never request more columns than there are tiles on disk.
    num_tiles = min(num_tiles, len(tile_names))

    fig, axes = plt.subplots(2, num_tiles, figsize=figsize)
    # plt.subplots returns a 1-D array when num_tiles == 1; normalize to 2-D
    # so axes[row, col] indexing works uniformly below.
    if num_tiles == 1:
        axes = axes.reshape(2, 1)

    for col, name in enumerate(tile_names):
        # Top row: the image tile itself.
        with rasterio.open(os.path.join(output_dir, "images", name)) as src:
            show(src, ax=axes[0, col], title=f"Image {col+1}")

        # Bottom row: the matching mask (same filename), if present.
        mask_file = os.path.join(output_dir, "masks", name)
        if os.path.exists(mask_file):
            with rasterio.open(mask_file) as src:
                show(src, ax=axes[1, col], title=f"Mask {col+1}", cmap=cmap)
        else:
            # Leave a placeholder so the grid layout stays aligned.
            axes[1, col].text(
                0.5,
                0.5,
                "Mask not found",
                ha="center",
                va="center",
                transform=axes[1, col].transAxes,
            )
            axes[1, col].set_title(f"Mask {col+1}")

    plt.tight_layout()

    # Either persist the figure to disk or display it interactively.
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
|
|
4995
|
+
|
|
4996
|
+
|
|
4997
|
+
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """Show a raster image next to the same image with a vector layer overlaid.

    The left panel shows the raw raster; the right panel shows the raster with
    the vector features drawn on top. The vector layer is reprojected to the
    raster's CRS when the two differ.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata

    Raises:
        ValueError: If the raster or vector file does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    import matplotlib.pyplot as plt

    # Fail fast on missing inputs before allocating any figure resources.
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=figsize)

    with rasterio.open(image_path) as src:
        # Left panel: raster only.
        show(src, ax=ax_left, title="Image")

        gdf = gpd.read_file(vector_path)
        # Align the vector layer with the raster's coordinate system so the
        # overlay lines up spatially.
        if gdf.crs != src.crs:
            gdf = gdf.to_crs(src.crs)

        # Right panel: raster with the vector features drawn on top.
        show(
            src,
            ax=ax_right,
            title=f"Image with {len(gdf)} Vector Features",
        )
        gdf.plot(
            ax=ax_right,
            facecolor=vector_facecolor,
            edgecolor=vector_color,
            linewidth=vector_linewidth,
        )

        # Metadata snapshot taken while the dataset handle is still open.
        info = {
            "image_shape": src.shape,
            "image_crs": src.crs,
            "image_bounds": src.bounds,
            "num_features": len(gdf),
            "vector_crs": gdf.crs,
            "vector_bounds": gdf.total_bounds,
        }

    plt.tight_layout()

    # Either persist the figure to disk or display it interactively.
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax_left, ax_right), info
|
|
5086
|
+
|
|
5087
|
+
|
|
3630
5088
|
def create_overview_image(
|
|
3631
5089
|
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
|
3632
5090
|
) -> str:
|
|
@@ -7521,17 +8979,39 @@ def write_colormap(
|
|
|
7521
8979
|
|
|
7522
8980
|
def plot_performance_metrics(
|
|
7523
8981
|
history_path: str,
|
|
7524
|
-
figsize: Tuple[int, int] =
|
|
8982
|
+
figsize: Optional[Tuple[int, int]] = None,
|
|
7525
8983
|
verbose: bool = True,
|
|
7526
8984
|
save_path: Optional[str] = None,
|
|
8985
|
+
csv_path: Optional[str] = None,
|
|
7527
8986
|
kwargs: Optional[Dict] = None,
|
|
7528
|
-
) ->
|
|
7529
|
-
"""Plot performance metrics from a history object.
|
|
8987
|
+
) -> pd.DataFrame:
|
|
8988
|
+
"""Plot performance metrics from a training history object and return as DataFrame.
|
|
8989
|
+
|
|
8990
|
+
This function loads training history, plots available metrics (loss, IoU, F1,
|
|
8991
|
+
precision, recall), optionally exports to CSV, and returns all metrics as a
|
|
8992
|
+
pandas DataFrame for further analysis.
|
|
7530
8993
|
|
|
7531
8994
|
Args:
|
|
7532
|
-
history_path:
|
|
7533
|
-
figsize:
|
|
7534
|
-
|
|
8995
|
+
history_path (str): Path to the saved training history (.pth file).
|
|
8996
|
+
figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
|
|
8997
|
+
automatically determined based on number of metrics.
|
|
8998
|
+
verbose (bool): Whether to print best and final metric values. Defaults to True.
|
|
8999
|
+
save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
|
|
9000
|
+
csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
|
|
9001
|
+
kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
|
|
9002
|
+
|
|
9003
|
+
Returns:
|
|
9004
|
+
pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
|
|
9005
|
+
Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
|
|
9006
|
+
'val_precision', 'val_recall' (depending on availability in history).
|
|
9007
|
+
|
|
9008
|
+
Example:
|
|
9009
|
+
>>> df = plot_performance_metrics(
|
|
9010
|
+
... 'training_history.pth',
|
|
9011
|
+
... save_path='metrics_plot.png',
|
|
9012
|
+
... csv_path='metrics.csv'
|
|
9013
|
+
... )
|
|
9014
|
+
>>> print(df.head())
|
|
7535
9015
|
"""
|
|
7536
9016
|
if kwargs is None:
|
|
7537
9017
|
kwargs = {}
|
|
@@ -7541,65 +9021,135 @@ def plot_performance_metrics(
|
|
|
7541
9021
|
train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
|
|
7542
9022
|
val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
|
|
7543
9023
|
val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
|
|
7544
|
-
|
|
9024
|
+
# Support both new (f1) and old (dice) key formats for backward compatibility
|
|
9025
|
+
val_f1_key = (
|
|
9026
|
+
"val_f1s"
|
|
9027
|
+
if "val_f1s" in history
|
|
9028
|
+
else ("val_dices" if "val_dices" in history else "val_dice")
|
|
9029
|
+
)
|
|
9030
|
+
# Add support for precision and recall
|
|
9031
|
+
val_precision_key = (
|
|
9032
|
+
"val_precisions" if "val_precisions" in history else "val_precision"
|
|
9033
|
+
)
|
|
9034
|
+
val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
|
|
9035
|
+
|
|
9036
|
+
# Collect available metrics for plotting
|
|
9037
|
+
available_metrics = []
|
|
9038
|
+
metric_info = {
|
|
9039
|
+
"Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
|
|
9040
|
+
"IoU": (val_iou_key, None, ["Val IoU"]),
|
|
9041
|
+
"F1": (val_f1_key, None, ["Val F1"]),
|
|
9042
|
+
"Precision": (val_precision_key, None, ["Val Precision"]),
|
|
9043
|
+
"Recall": (val_recall_key, None, ["Val Recall"]),
|
|
9044
|
+
}
|
|
7545
9045
|
|
|
7546
|
-
|
|
7547
|
-
|
|
7548
|
-
|
|
7549
|
-
figsize = (15, 5) if has_dice else (10, 5)
|
|
9046
|
+
for metric_name, (key1, key2, labels) in metric_info.items():
|
|
9047
|
+
if key1 in history or (key2 and key2 in history):
|
|
9048
|
+
available_metrics.append((metric_name, key1, key2, labels))
|
|
7550
9049
|
|
|
7551
|
-
|
|
9050
|
+
# Determine number of subplots and figure size
|
|
9051
|
+
n_plots = len(available_metrics)
|
|
9052
|
+
if figsize is None:
|
|
9053
|
+
figsize = (5 * n_plots, 5)
|
|
7552
9054
|
|
|
7553
|
-
#
|
|
7554
|
-
|
|
9055
|
+
# Create DataFrame for all metrics
|
|
9056
|
+
n_epochs = 0
|
|
9057
|
+
df_data = {}
|
|
9058
|
+
|
|
9059
|
+
# Add epochs
|
|
9060
|
+
if "epochs" in history:
|
|
9061
|
+
df_data["epoch"] = history["epochs"]
|
|
9062
|
+
n_epochs = len(history["epochs"])
|
|
9063
|
+
elif train_loss_key in history:
|
|
9064
|
+
n_epochs = len(history[train_loss_key])
|
|
9065
|
+
df_data["epoch"] = list(range(1, n_epochs + 1))
|
|
9066
|
+
|
|
9067
|
+
# Add all available metrics to DataFrame
|
|
7555
9068
|
if train_loss_key in history:
|
|
7556
|
-
|
|
9069
|
+
df_data["train_loss"] = history[train_loss_key]
|
|
7557
9070
|
if val_loss_key in history:
|
|
7558
|
-
|
|
7559
|
-
plt.title("Loss")
|
|
7560
|
-
plt.xlabel("Epoch")
|
|
7561
|
-
plt.ylabel("Loss")
|
|
7562
|
-
plt.legend()
|
|
7563
|
-
plt.grid(True)
|
|
7564
|
-
|
|
7565
|
-
# Plot IoU
|
|
7566
|
-
plt.subplot(1, n_plots, 2)
|
|
9071
|
+
df_data["val_loss"] = history[val_loss_key]
|
|
7567
9072
|
if val_iou_key in history:
|
|
7568
|
-
|
|
7569
|
-
|
|
7570
|
-
|
|
7571
|
-
|
|
7572
|
-
|
|
7573
|
-
|
|
7574
|
-
|
|
7575
|
-
|
|
7576
|
-
|
|
7577
|
-
|
|
7578
|
-
|
|
7579
|
-
|
|
7580
|
-
|
|
7581
|
-
|
|
7582
|
-
|
|
7583
|
-
|
|
9073
|
+
df_data["val_iou"] = history[val_iou_key]
|
|
9074
|
+
if val_f1_key in history:
|
|
9075
|
+
df_data["val_f1"] = history[val_f1_key]
|
|
9076
|
+
if val_precision_key in history:
|
|
9077
|
+
df_data["val_precision"] = history[val_precision_key]
|
|
9078
|
+
if val_recall_key in history:
|
|
9079
|
+
df_data["val_recall"] = history[val_recall_key]
|
|
9080
|
+
|
|
9081
|
+
# Create DataFrame
|
|
9082
|
+
df = pd.DataFrame(df_data)
|
|
9083
|
+
|
|
9084
|
+
# Export to CSV if requested
|
|
9085
|
+
if csv_path:
|
|
9086
|
+
df.to_csv(csv_path, index=False)
|
|
9087
|
+
if verbose:
|
|
9088
|
+
print(f"Metrics exported to: {csv_path}")
|
|
9089
|
+
|
|
9090
|
+
# Create plots
|
|
9091
|
+
if n_plots > 0:
|
|
9092
|
+
fig, axes = plt.subplots(1, n_plots, figsize=figsize)
|
|
9093
|
+
if n_plots == 1:
|
|
9094
|
+
axes = [axes]
|
|
9095
|
+
|
|
9096
|
+
for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
|
|
9097
|
+
ax = axes[idx]
|
|
9098
|
+
|
|
9099
|
+
if metric_name == "Loss":
|
|
9100
|
+
# Special handling for loss (has both train and val)
|
|
9101
|
+
if key1 in history:
|
|
9102
|
+
ax.plot(history[key1], label=labels[0])
|
|
9103
|
+
if key2 and key2 in history:
|
|
9104
|
+
ax.plot(history[key2], label=labels[1])
|
|
9105
|
+
else:
|
|
9106
|
+
# Single metric plots
|
|
9107
|
+
if key1 in history:
|
|
9108
|
+
ax.plot(history[key1], label=labels[0])
|
|
7584
9109
|
|
|
7585
|
-
|
|
9110
|
+
ax.set_title(metric_name)
|
|
9111
|
+
ax.set_xlabel("Epoch")
|
|
9112
|
+
ax.set_ylabel(metric_name)
|
|
9113
|
+
ax.legend()
|
|
9114
|
+
ax.grid(True)
|
|
7586
9115
|
|
|
7587
|
-
|
|
7588
|
-
if "dpi" not in kwargs:
|
|
7589
|
-
kwargs["dpi"] = 150
|
|
7590
|
-
if "bbox_inches" not in kwargs:
|
|
7591
|
-
kwargs["bbox_inches"] = "tight"
|
|
7592
|
-
plt.savefig(save_path, **kwargs)
|
|
9116
|
+
plt.tight_layout()
|
|
7593
9117
|
|
|
7594
|
-
|
|
9118
|
+
if save_path:
|
|
9119
|
+
if "dpi" not in kwargs:
|
|
9120
|
+
kwargs["dpi"] = 150
|
|
9121
|
+
if "bbox_inches" not in kwargs:
|
|
9122
|
+
kwargs["bbox_inches"] = "tight"
|
|
9123
|
+
plt.savefig(save_path, **kwargs)
|
|
9124
|
+
|
|
9125
|
+
plt.show()
|
|
7595
9126
|
|
|
9127
|
+
# Print summary statistics
|
|
7596
9128
|
if verbose:
|
|
9129
|
+
print("\n=== Performance Metrics Summary ===")
|
|
7597
9130
|
if val_iou_key in history:
|
|
7598
|
-
print(
|
|
7599
|
-
|
|
7600
|
-
|
|
7601
|
-
|
|
7602
|
-
print(
|
|
9131
|
+
print(
|
|
9132
|
+
f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
|
|
9133
|
+
)
|
|
9134
|
+
if val_f1_key in history:
|
|
9135
|
+
print(
|
|
9136
|
+
f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
|
|
9137
|
+
)
|
|
9138
|
+
if val_precision_key in history:
|
|
9139
|
+
print(
|
|
9140
|
+
f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
|
|
9141
|
+
)
|
|
9142
|
+
if val_recall_key in history:
|
|
9143
|
+
print(
|
|
9144
|
+
f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
|
|
9145
|
+
)
|
|
9146
|
+
if val_loss_key in history:
|
|
9147
|
+
print(
|
|
9148
|
+
f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
|
|
9149
|
+
)
|
|
9150
|
+
print("===================================\n")
|
|
9151
|
+
|
|
9152
|
+
return df
|
|
7603
9153
|
|
|
7604
9154
|
|
|
7605
9155
|
def get_device() -> torch.device:
|