geoai-py 0.18.1__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +23 -1
- geoai/agents/__init__.py +1 -0
- geoai/agents/geo_agents.py +74 -29
- geoai/geoai.py +2 -0
- geoai/landcover_train.py +685 -0
- geoai/landcover_utils.py +383 -0
- geoai/map_widgets.py +556 -0
- geoai/moondream.py +990 -0
- geoai/tools/__init__.py +11 -0
- geoai/tools/sr.py +194 -0
- geoai/train.py +22 -0
- geoai/utils.py +304 -1654
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/METADATA +3 -1
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/RECORD +18 -14
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.18.1.dist-info → geoai_py-0.19.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
|
@@ -64,7 +64,7 @@ def view_raster(
|
|
|
64
64
|
client_args: Optional[Dict] = {"cors_all": False},
|
|
65
65
|
basemap: Optional[str] = "OpenStreetMap",
|
|
66
66
|
basemap_args: Optional[Dict] = None,
|
|
67
|
-
backend: Optional[str] = "
|
|
67
|
+
backend: Optional[str] = "ipyleaflet",
|
|
68
68
|
**kwargs: Any,
|
|
69
69
|
) -> Any:
|
|
70
70
|
"""
|
|
@@ -87,7 +87,7 @@ def view_raster(
|
|
|
87
87
|
client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
|
|
88
88
|
basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
|
|
89
89
|
basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
|
|
90
|
-
backend (Optional[str], optional): The backend to use. Defaults to "
|
|
90
|
+
backend (Optional[str], optional): The backend to use. Defaults to "ipyleaflet".
|
|
91
91
|
**kwargs (Any): Additional keyword arguments.
|
|
92
92
|
|
|
93
93
|
Returns:
|
|
@@ -123,39 +123,26 @@ def view_raster(
|
|
|
123
123
|
if isinstance(source, dict):
|
|
124
124
|
source = dict_to_image(source)
|
|
125
125
|
|
|
126
|
-
if
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
if colormap is not None:
|
|
147
|
-
kwargs["colormap_name"] = colormap
|
|
148
|
-
if attribution is None:
|
|
149
|
-
attribution = "TiTiler"
|
|
150
|
-
|
|
151
|
-
m.add_cog_layer(
|
|
152
|
-
source,
|
|
153
|
-
name=layer_name,
|
|
154
|
-
opacity=opacity,
|
|
155
|
-
attribution=attribution,
|
|
156
|
-
zoom_to_layer=zoom_to_layer,
|
|
157
|
-
**kwargs,
|
|
158
|
-
)
|
|
126
|
+
if (
|
|
127
|
+
isinstance(source, str)
|
|
128
|
+
and source.lower().endswith(".tif")
|
|
129
|
+
and source.startswith("http")
|
|
130
|
+
):
|
|
131
|
+
if indexes is not None:
|
|
132
|
+
kwargs["bidx"] = indexes
|
|
133
|
+
if colormap is not None:
|
|
134
|
+
kwargs["colormap_name"] = colormap
|
|
135
|
+
if attribution is None:
|
|
136
|
+
attribution = "TiTiler"
|
|
137
|
+
|
|
138
|
+
m.add_cog_layer(
|
|
139
|
+
source,
|
|
140
|
+
name=layer_name,
|
|
141
|
+
opacity=opacity,
|
|
142
|
+
attribution=attribution,
|
|
143
|
+
zoom_to_layer=zoom_to_layer,
|
|
144
|
+
**kwargs,
|
|
145
|
+
)
|
|
159
146
|
else:
|
|
160
147
|
m.add_raster(
|
|
161
148
|
source=source,
|
|
@@ -237,8 +224,6 @@ def view_image(
|
|
|
237
224
|
plt.show()
|
|
238
225
|
plt.close()
|
|
239
226
|
|
|
240
|
-
return ax
|
|
241
|
-
|
|
242
227
|
|
|
243
228
|
def plot_images(
|
|
244
229
|
images: Iterable[torch.Tensor],
|
|
@@ -396,394 +381,6 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
|
|
|
396
381
|
return accum_mean / len(files), accum_std / len(files)
|
|
397
382
|
|
|
398
383
|
|
|
399
|
-
def calc_iou(
|
|
400
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
401
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
402
|
-
num_classes: Optional[int] = None,
|
|
403
|
-
ignore_index: Optional[int] = None,
|
|
404
|
-
smooth: float = 1e-6,
|
|
405
|
-
band: int = 1,
|
|
406
|
-
) -> Union[float, np.ndarray]:
|
|
407
|
-
"""
|
|
408
|
-
Calculate Intersection over Union (IoU) between ground truth and prediction masks.
|
|
409
|
-
|
|
410
|
-
This function computes the IoU metric for segmentation tasks. It supports both
|
|
411
|
-
binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
412
|
-
or file paths to raster files.
|
|
413
|
-
|
|
414
|
-
Args:
|
|
415
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
416
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
417
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
418
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
419
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
420
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
421
|
-
Should have the same shape and format as ground_truth.
|
|
422
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
423
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
424
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
425
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
426
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
427
|
-
Defaults to 1e-6.
|
|
428
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
429
|
-
Only used when input is a file path. Defaults to 1.
|
|
430
|
-
|
|
431
|
-
Returns:
|
|
432
|
-
Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
|
|
433
|
-
For multi-class segmentation, returns an array of IoU scores for each class.
|
|
434
|
-
|
|
435
|
-
Examples:
|
|
436
|
-
>>> # Binary segmentation with arrays
|
|
437
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
438
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
439
|
-
>>> iou = calc_iou(gt, pred)
|
|
440
|
-
>>> print(f"IoU: {iou:.4f}")
|
|
441
|
-
IoU: 0.8333
|
|
442
|
-
|
|
443
|
-
>>> # Multi-class segmentation
|
|
444
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
445
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
446
|
-
>>> iou = calc_iou(gt, pred, num_classes=3)
|
|
447
|
-
>>> print(f"IoU per class: {iou}")
|
|
448
|
-
IoU per class: [0.8333 0.5000 0.5000]
|
|
449
|
-
|
|
450
|
-
>>> # Using PyTorch tensors
|
|
451
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
452
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
453
|
-
>>> iou = calc_iou(gt_tensor, pred_tensor)
|
|
454
|
-
>>> print(f"IoU: {iou:.4f}")
|
|
455
|
-
IoU: 0.8333
|
|
456
|
-
|
|
457
|
-
>>> # Using raster file paths
|
|
458
|
-
>>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
459
|
-
>>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
|
|
460
|
-
Mean IoU: 0.7500
|
|
461
|
-
"""
|
|
462
|
-
# Load from file if string path is provided
|
|
463
|
-
if isinstance(ground_truth, str):
|
|
464
|
-
with rasterio.open(ground_truth) as src:
|
|
465
|
-
ground_truth = src.read(band)
|
|
466
|
-
if isinstance(prediction, str):
|
|
467
|
-
with rasterio.open(prediction) as src:
|
|
468
|
-
prediction = src.read(band)
|
|
469
|
-
|
|
470
|
-
# Convert to numpy if torch tensor
|
|
471
|
-
if isinstance(ground_truth, torch.Tensor):
|
|
472
|
-
ground_truth = ground_truth.cpu().numpy()
|
|
473
|
-
if isinstance(prediction, torch.Tensor):
|
|
474
|
-
prediction = prediction.cpu().numpy()
|
|
475
|
-
|
|
476
|
-
# Ensure inputs have the same shape
|
|
477
|
-
if ground_truth.shape != prediction.shape:
|
|
478
|
-
raise ValueError(
|
|
479
|
-
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
480
|
-
)
|
|
481
|
-
|
|
482
|
-
# Binary segmentation
|
|
483
|
-
if num_classes is None:
|
|
484
|
-
ground_truth = ground_truth.astype(bool)
|
|
485
|
-
prediction = prediction.astype(bool)
|
|
486
|
-
|
|
487
|
-
intersection = np.logical_and(ground_truth, prediction).sum()
|
|
488
|
-
union = np.logical_or(ground_truth, prediction).sum()
|
|
489
|
-
|
|
490
|
-
if union == 0:
|
|
491
|
-
return 1.0 if intersection == 0 else 0.0
|
|
492
|
-
|
|
493
|
-
iou = (intersection + smooth) / (union + smooth)
|
|
494
|
-
return float(iou)
|
|
495
|
-
|
|
496
|
-
# Multi-class segmentation
|
|
497
|
-
else:
|
|
498
|
-
iou_per_class = []
|
|
499
|
-
|
|
500
|
-
for class_idx in range(num_classes):
|
|
501
|
-
# Handle ignored class by appending np.nan
|
|
502
|
-
if ignore_index is not None and class_idx == ignore_index:
|
|
503
|
-
iou_per_class.append(np.nan)
|
|
504
|
-
continue
|
|
505
|
-
|
|
506
|
-
# Create binary masks for current class
|
|
507
|
-
gt_class = (ground_truth == class_idx).astype(bool)
|
|
508
|
-
pred_class = (prediction == class_idx).astype(bool)
|
|
509
|
-
|
|
510
|
-
intersection = np.logical_and(gt_class, pred_class).sum()
|
|
511
|
-
union = np.logical_or(gt_class, pred_class).sum()
|
|
512
|
-
|
|
513
|
-
if union == 0:
|
|
514
|
-
# If class is not present in both gt and pred
|
|
515
|
-
iou_per_class.append(np.nan)
|
|
516
|
-
else:
|
|
517
|
-
iou_per_class.append((intersection + smooth) / (union + smooth))
|
|
518
|
-
|
|
519
|
-
return np.array(iou_per_class)
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
def calc_f1_score(
|
|
523
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
524
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
525
|
-
num_classes: Optional[int] = None,
|
|
526
|
-
ignore_index: Optional[int] = None,
|
|
527
|
-
smooth: float = 1e-6,
|
|
528
|
-
band: int = 1,
|
|
529
|
-
) -> Union[float, np.ndarray]:
|
|
530
|
-
"""
|
|
531
|
-
Calculate F1 score between ground truth and prediction masks.
|
|
532
|
-
|
|
533
|
-
The F1 score is the harmonic mean of precision and recall, computed as:
|
|
534
|
-
F1 = 2 * (precision * recall) / (precision + recall)
|
|
535
|
-
where precision = TP / (TP + FP) and recall = TP / (TP + FN).
|
|
536
|
-
|
|
537
|
-
This function supports both binary and multi-class segmentation, and can handle
|
|
538
|
-
numpy arrays, PyTorch tensors, or file paths to raster files.
|
|
539
|
-
|
|
540
|
-
Args:
|
|
541
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
542
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
543
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
544
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
545
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
546
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
547
|
-
Should have the same shape and format as ground_truth.
|
|
548
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
549
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
550
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
551
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
552
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
553
|
-
Defaults to 1e-6.
|
|
554
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
555
|
-
Only used when input is a file path. Defaults to 1.
|
|
556
|
-
|
|
557
|
-
Returns:
|
|
558
|
-
Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
|
|
559
|
-
For multi-class segmentation, returns an array of F1 scores for each class.
|
|
560
|
-
|
|
561
|
-
Examples:
|
|
562
|
-
>>> # Binary segmentation with arrays
|
|
563
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
564
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
565
|
-
>>> f1 = calc_f1_score(gt, pred)
|
|
566
|
-
>>> print(f"F1 Score: {f1:.4f}")
|
|
567
|
-
F1 Score: 0.8571
|
|
568
|
-
|
|
569
|
-
>>> # Multi-class segmentation
|
|
570
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
571
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
572
|
-
>>> f1 = calc_f1_score(gt, pred, num_classes=3)
|
|
573
|
-
>>> print(f"F1 Score per class: {f1}")
|
|
574
|
-
F1 Score per class: [0.8571 0.6667 0.6667]
|
|
575
|
-
|
|
576
|
-
>>> # Using PyTorch tensors
|
|
577
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
578
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
579
|
-
>>> f1 = calc_f1_score(gt_tensor, pred_tensor)
|
|
580
|
-
>>> print(f"F1 Score: {f1:.4f}")
|
|
581
|
-
F1 Score: 0.8571
|
|
582
|
-
|
|
583
|
-
>>> # Using raster file paths
|
|
584
|
-
>>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
585
|
-
>>> print(f"Mean F1: {np.nanmean(f1):.4f}")
|
|
586
|
-
Mean F1: 0.7302
|
|
587
|
-
"""
|
|
588
|
-
# Load from file if string path is provided
|
|
589
|
-
if isinstance(ground_truth, str):
|
|
590
|
-
with rasterio.open(ground_truth) as src:
|
|
591
|
-
ground_truth = src.read(band)
|
|
592
|
-
if isinstance(prediction, str):
|
|
593
|
-
with rasterio.open(prediction) as src:
|
|
594
|
-
prediction = src.read(band)
|
|
595
|
-
|
|
596
|
-
# Convert to numpy if torch tensor
|
|
597
|
-
if isinstance(ground_truth, torch.Tensor):
|
|
598
|
-
ground_truth = ground_truth.cpu().numpy()
|
|
599
|
-
if isinstance(prediction, torch.Tensor):
|
|
600
|
-
prediction = prediction.cpu().numpy()
|
|
601
|
-
|
|
602
|
-
# Ensure inputs have the same shape
|
|
603
|
-
if ground_truth.shape != prediction.shape:
|
|
604
|
-
raise ValueError(
|
|
605
|
-
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
606
|
-
)
|
|
607
|
-
|
|
608
|
-
# Binary segmentation
|
|
609
|
-
if num_classes is None:
|
|
610
|
-
ground_truth = ground_truth.astype(bool)
|
|
611
|
-
prediction = prediction.astype(bool)
|
|
612
|
-
|
|
613
|
-
# Calculate True Positives, False Positives, False Negatives
|
|
614
|
-
tp = np.logical_and(ground_truth, prediction).sum()
|
|
615
|
-
fp = np.logical_and(~ground_truth, prediction).sum()
|
|
616
|
-
fn = np.logical_and(ground_truth, ~prediction).sum()
|
|
617
|
-
|
|
618
|
-
# Calculate precision and recall
|
|
619
|
-
precision = (tp + smooth) / (tp + fp + smooth)
|
|
620
|
-
recall = (tp + smooth) / (tp + fn + smooth)
|
|
621
|
-
|
|
622
|
-
# Calculate F1 score
|
|
623
|
-
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
624
|
-
return float(f1)
|
|
625
|
-
|
|
626
|
-
# Multi-class segmentation
|
|
627
|
-
else:
|
|
628
|
-
f1_per_class = []
|
|
629
|
-
|
|
630
|
-
for class_idx in range(num_classes):
|
|
631
|
-
# Mark ignored class with np.nan
|
|
632
|
-
if ignore_index is not None and class_idx == ignore_index:
|
|
633
|
-
f1_per_class.append(np.nan)
|
|
634
|
-
continue
|
|
635
|
-
|
|
636
|
-
# Create binary masks for current class
|
|
637
|
-
gt_class = (ground_truth == class_idx).astype(bool)
|
|
638
|
-
pred_class = (prediction == class_idx).astype(bool)
|
|
639
|
-
|
|
640
|
-
# Calculate True Positives, False Positives, False Negatives
|
|
641
|
-
tp = np.logical_and(gt_class, pred_class).sum()
|
|
642
|
-
fp = np.logical_and(~gt_class, pred_class).sum()
|
|
643
|
-
fn = np.logical_and(gt_class, ~pred_class).sum()
|
|
644
|
-
|
|
645
|
-
# Calculate precision and recall
|
|
646
|
-
precision = (tp + smooth) / (tp + fp + smooth)
|
|
647
|
-
recall = (tp + smooth) / (tp + fn + smooth)
|
|
648
|
-
|
|
649
|
-
# Calculate F1 score
|
|
650
|
-
if tp + fp + fn == 0:
|
|
651
|
-
# If class is not present in both gt and pred
|
|
652
|
-
f1_per_class.append(np.nan)
|
|
653
|
-
else:
|
|
654
|
-
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
655
|
-
f1_per_class.append(f1)
|
|
656
|
-
|
|
657
|
-
return np.array(f1_per_class)
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def calc_segmentation_metrics(
|
|
661
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
662
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
663
|
-
num_classes: Optional[int] = None,
|
|
664
|
-
ignore_index: Optional[int] = None,
|
|
665
|
-
smooth: float = 1e-6,
|
|
666
|
-
metrics: List[str] = ["iou", "f1"],
|
|
667
|
-
band: int = 1,
|
|
668
|
-
) -> Dict[str, Union[float, np.ndarray]]:
|
|
669
|
-
"""
|
|
670
|
-
Calculate multiple segmentation metrics between ground truth and prediction masks.
|
|
671
|
-
|
|
672
|
-
This is a convenient wrapper function that computes multiple metrics at once,
|
|
673
|
-
including IoU (Intersection over Union) and F1 score. It supports both binary
|
|
674
|
-
and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
675
|
-
or file paths to raster files.
|
|
676
|
-
|
|
677
|
-
Args:
|
|
678
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
679
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
680
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
681
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
682
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
683
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
684
|
-
Should have the same shape and format as ground_truth.
|
|
685
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
686
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
687
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
688
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
689
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
690
|
-
Defaults to 1e-6.
|
|
691
|
-
metrics (List[str], optional): List of metrics to calculate.
|
|
692
|
-
Options: "iou", "f1". Defaults to ["iou", "f1"].
|
|
693
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
694
|
-
Only used when input is a file path. Defaults to 1.
|
|
695
|
-
|
|
696
|
-
Returns:
|
|
697
|
-
Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
|
|
698
|
-
Keys are metric names ("iou", "f1"), values are the metric scores.
|
|
699
|
-
For binary segmentation, values are floats.
|
|
700
|
-
For multi-class segmentation, values are numpy arrays with per-class scores.
|
|
701
|
-
Also includes "mean_iou" and "mean_f1" for multi-class segmentation
|
|
702
|
-
(mean computed over valid classes, ignoring NaN values).
|
|
703
|
-
|
|
704
|
-
Examples:
|
|
705
|
-
>>> # Binary segmentation with arrays
|
|
706
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
707
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
708
|
-
>>> metrics = calc_segmentation_metrics(gt, pred)
|
|
709
|
-
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
710
|
-
IoU: 0.8333, F1: 0.8571
|
|
711
|
-
|
|
712
|
-
>>> # Multi-class segmentation
|
|
713
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
714
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
715
|
-
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
|
|
716
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
717
|
-
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
718
|
-
>>> print(f"Per-class IoU: {metrics['iou']}")
|
|
719
|
-
Mean IoU: 0.6111
|
|
720
|
-
Mean F1: 0.7302
|
|
721
|
-
Per-class IoU: [0.8333 0.5000 0.5000]
|
|
722
|
-
|
|
723
|
-
>>> # Calculate only IoU
|
|
724
|
-
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
|
|
725
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
726
|
-
Mean IoU: 0.6111
|
|
727
|
-
|
|
728
|
-
>>> # Using PyTorch tensors
|
|
729
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
730
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
731
|
-
>>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
|
|
732
|
-
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
733
|
-
IoU: 0.8333, F1: 0.8571
|
|
734
|
-
|
|
735
|
-
>>> # Using raster file paths
|
|
736
|
-
>>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
737
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
738
|
-
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
739
|
-
Mean IoU: 0.6111
|
|
740
|
-
Mean F1: 0.7302
|
|
741
|
-
"""
|
|
742
|
-
results = {}
|
|
743
|
-
|
|
744
|
-
# Calculate IoU if requested
|
|
745
|
-
if "iou" in metrics:
|
|
746
|
-
iou = calc_iou(
|
|
747
|
-
ground_truth,
|
|
748
|
-
prediction,
|
|
749
|
-
num_classes=num_classes,
|
|
750
|
-
ignore_index=ignore_index,
|
|
751
|
-
smooth=smooth,
|
|
752
|
-
band=band,
|
|
753
|
-
)
|
|
754
|
-
results["iou"] = iou
|
|
755
|
-
|
|
756
|
-
# Add mean IoU for multi-class
|
|
757
|
-
if num_classes is not None and isinstance(iou, np.ndarray):
|
|
758
|
-
# Calculate mean ignoring NaN values
|
|
759
|
-
valid_ious = iou[~np.isnan(iou)]
|
|
760
|
-
results["mean_iou"] = (
|
|
761
|
-
float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
|
|
762
|
-
)
|
|
763
|
-
|
|
764
|
-
# Calculate F1 score if requested
|
|
765
|
-
if "f1" in metrics:
|
|
766
|
-
f1 = calc_f1_score(
|
|
767
|
-
ground_truth,
|
|
768
|
-
prediction,
|
|
769
|
-
num_classes=num_classes,
|
|
770
|
-
ignore_index=ignore_index,
|
|
771
|
-
smooth=smooth,
|
|
772
|
-
band=band,
|
|
773
|
-
)
|
|
774
|
-
results["f1"] = f1
|
|
775
|
-
|
|
776
|
-
# Add mean F1 for multi-class
|
|
777
|
-
if num_classes is not None and isinstance(f1, np.ndarray):
|
|
778
|
-
# Calculate mean ignoring NaN values
|
|
779
|
-
valid_f1s = f1[~np.isnan(f1)]
|
|
780
|
-
results["mean_f1"] = (
|
|
781
|
-
float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
|
|
782
|
-
)
|
|
783
|
-
|
|
784
|
-
return results
|
|
785
|
-
|
|
786
|
-
|
|
787
384
|
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
|
|
788
385
|
"""Convert a dictionary to a xarray DataArray. The dictionary should contain the
|
|
789
386
|
following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
|
|
@@ -1094,9 +691,8 @@ def view_vector(
|
|
|
1094
691
|
|
|
1095
692
|
def view_vector_interactive(
|
|
1096
693
|
vector_data: Union[str, gpd.GeoDataFrame],
|
|
1097
|
-
layer_name: str = "Vector",
|
|
694
|
+
layer_name: str = "Vector Layer",
|
|
1098
695
|
tiles_args: Optional[Dict] = None,
|
|
1099
|
-
opacity: float = 0.7,
|
|
1100
696
|
**kwargs: Any,
|
|
1101
697
|
) -> Any:
|
|
1102
698
|
"""
|
|
@@ -1111,7 +707,6 @@ def view_vector_interactive(
|
|
|
1111
707
|
layer_name (str, optional): The name of the layer. Defaults to "Vector Layer".
|
|
1112
708
|
tiles_args (dict, optional): Additional arguments for the localtileserver client.
|
|
1113
709
|
get_folium_tile_layer function. Defaults to None.
|
|
1114
|
-
opacity (float, optional): The opacity of the layer. Defaults to 0.7.
|
|
1115
710
|
**kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
|
|
1116
711
|
See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
|
|
1117
712
|
|
|
@@ -1126,8 +721,9 @@ def view_vector_interactive(
|
|
|
1126
721
|
>>> roads = gpd.read_file("roads.shp")
|
|
1127
722
|
>>> view_vector_interactive(roads, figsize=(12, 8))
|
|
1128
723
|
"""
|
|
1129
|
-
|
|
1130
|
-
|
|
724
|
+
import folium
|
|
725
|
+
import folium.plugins as plugins
|
|
726
|
+
from leafmap import cog_tile
|
|
1131
727
|
from localtileserver import TileClient, get_folium_tile_layer
|
|
1132
728
|
|
|
1133
729
|
google_tiles = {
|
|
@@ -1153,17 +749,9 @@ def view_vector_interactive(
|
|
|
1153
749
|
},
|
|
1154
750
|
}
|
|
1155
751
|
|
|
1156
|
-
# Make it compatible with binder and JupyterHub
|
|
1157
|
-
if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
|
|
1158
|
-
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
|
|
1159
|
-
f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
|
|
1160
|
-
)
|
|
1161
|
-
|
|
1162
752
|
basemap_layer_name = None
|
|
1163
753
|
raster_layer = None
|
|
1164
754
|
|
|
1165
|
-
m = Map()
|
|
1166
|
-
|
|
1167
755
|
if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
|
|
1168
756
|
if kwargs["tiles"].title() in google_tiles:
|
|
1169
757
|
basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
|
|
@@ -1174,17 +762,14 @@ def view_vector_interactive(
|
|
|
1174
762
|
tiles_args = {}
|
|
1175
763
|
if kwargs["tiles"].lower().startswith("http"):
|
|
1176
764
|
basemap_layer_name = "Remote Raster"
|
|
1177
|
-
|
|
765
|
+
kwargs["tiles"] = cog_tile(kwargs["tiles"], **tiles_args)
|
|
766
|
+
kwargs["attr"] = "TiTiler"
|
|
1178
767
|
else:
|
|
1179
768
|
basemap_layer_name = "Local Raster"
|
|
1180
769
|
client = TileClient(kwargs["tiles"])
|
|
1181
770
|
raster_layer = get_folium_tile_layer(client, **tiles_args)
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
name=basemap_layer_name,
|
|
1185
|
-
attribution="localtileserver",
|
|
1186
|
-
**tiles_args,
|
|
1187
|
-
)
|
|
771
|
+
kwargs["tiles"] = raster_layer.tiles
|
|
772
|
+
kwargs["attr"] = "localtileserver"
|
|
1188
773
|
|
|
1189
774
|
if "max_zoom" not in kwargs:
|
|
1190
775
|
kwargs["max_zoom"] = 30
|
|
@@ -1199,18 +784,23 @@ def view_vector_interactive(
|
|
|
1199
784
|
if not isinstance(vector_data, gpd.GeoDataFrame):
|
|
1200
785
|
raise TypeError("Input data must be a GeoDataFrame")
|
|
1201
786
|
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
kwargs["legend_position"] = "bottomleft"
|
|
1205
|
-
if "cmap" not in kwargs:
|
|
1206
|
-
kwargs["cmap"] = "viridis"
|
|
1207
|
-
m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
|
|
787
|
+
layer_control = kwargs.pop("layer_control", True)
|
|
788
|
+
fullscreen_control = kwargs.pop("fullscreen_control", True)
|
|
1208
789
|
|
|
1209
|
-
|
|
1210
|
-
m.add_gdf(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
|
|
790
|
+
m = vector_data.explore(**kwargs)
|
|
1211
791
|
|
|
1212
|
-
|
|
1213
|
-
m.
|
|
792
|
+
# Change the layer name
|
|
793
|
+
for layer in m._children.values():
|
|
794
|
+
if isinstance(layer, folium.GeoJson):
|
|
795
|
+
layer.layer_name = layer_name
|
|
796
|
+
if isinstance(layer, folium.TileLayer) and basemap_layer_name:
|
|
797
|
+
layer.layer_name = basemap_layer_name
|
|
798
|
+
|
|
799
|
+
if layer_control:
|
|
800
|
+
m.add_child(folium.LayerControl())
|
|
801
|
+
|
|
802
|
+
if fullscreen_control:
|
|
803
|
+
plugins.Fullscreen().add_to(m)
|
|
1214
804
|
|
|
1215
805
|
return m
|
|
1216
806
|
|
|
@@ -3007,7 +2597,7 @@ def batch_vector_to_raster(
|
|
|
3007
2597
|
def export_geotiff_tiles(
|
|
3008
2598
|
in_raster,
|
|
3009
2599
|
out_folder,
|
|
3010
|
-
in_class_data
|
|
2600
|
+
in_class_data,
|
|
3011
2601
|
tile_size=256,
|
|
3012
2602
|
stride=128,
|
|
3013
2603
|
class_value_field="class",
|
|
@@ -3017,7 +2607,6 @@ def export_geotiff_tiles(
|
|
|
3017
2607
|
all_touched=True,
|
|
3018
2608
|
create_overview=False,
|
|
3019
2609
|
skip_empty_tiles=False,
|
|
3020
|
-
metadata_format="PASCAL_VOC",
|
|
3021
2610
|
):
|
|
3022
2611
|
"""
|
|
3023
2612
|
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
|
@@ -3025,8 +2614,7 @@ def export_geotiff_tiles(
|
|
|
3025
2614
|
Args:
|
|
3026
2615
|
in_raster (str): Path to input raster image
|
|
3027
2616
|
out_folder (str): Path to output folder
|
|
3028
|
-
in_class_data (str
|
|
3029
|
-
If None, only image tiles will be exported without labels. Defaults to None.
|
|
2617
|
+
in_class_data (str): Path to classification data - can be vector file or raster
|
|
3030
2618
|
tile_size (int): Size of tiles in pixels (square)
|
|
3031
2619
|
stride (int): Step size between tiles
|
|
3032
2620
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -3036,7 +2624,6 @@ def export_geotiff_tiles(
|
|
|
3036
2624
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
3037
2625
|
create_overview (bool): Whether to create an overview image of all tiles
|
|
3038
2626
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3039
|
-
metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
|
|
3040
2627
|
"""
|
|
3041
2628
|
|
|
3042
2629
|
import logging
|
|
@@ -3047,42 +2634,28 @@ def export_geotiff_tiles(
|
|
|
3047
2634
|
os.makedirs(out_folder, exist_ok=True)
|
|
3048
2635
|
image_dir = os.path.join(out_folder, "images")
|
|
3049
2636
|
os.makedirs(image_dir, exist_ok=True)
|
|
2637
|
+
label_dir = os.path.join(out_folder, "labels")
|
|
2638
|
+
os.makedirs(label_dir, exist_ok=True)
|
|
2639
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
|
2640
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
3050
2641
|
|
|
3051
|
-
#
|
|
3052
|
-
if in_class_data is not None:
|
|
3053
|
-
label_dir = os.path.join(out_folder, "labels")
|
|
3054
|
-
os.makedirs(label_dir, exist_ok=True)
|
|
3055
|
-
|
|
3056
|
-
# Create annotation directory based on metadata format
|
|
3057
|
-
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
|
3058
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
3059
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
3060
|
-
|
|
3061
|
-
# Initialize COCO annotations dictionary
|
|
3062
|
-
if metadata_format == "COCO":
|
|
3063
|
-
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
3064
|
-
ann_id = 0
|
|
3065
|
-
|
|
3066
|
-
# Determine if class data is raster or vector (only if class data provided)
|
|
2642
|
+
# Determine if class data is raster or vector
|
|
3067
2643
|
is_class_data_raster = False
|
|
3068
|
-
if in_class_data
|
|
3069
|
-
|
|
3070
|
-
|
|
3071
|
-
|
|
3072
|
-
|
|
3073
|
-
|
|
3074
|
-
|
|
3075
|
-
is_class_data_raster = True
|
|
3076
|
-
if not quiet:
|
|
3077
|
-
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
3078
|
-
print(f"Raster CRS: {src.crs}")
|
|
3079
|
-
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
3080
|
-
except Exception:
|
|
3081
|
-
is_class_data_raster = False
|
|
2644
|
+
if isinstance(in_class_data, str):
|
|
2645
|
+
file_ext = Path(in_class_data).suffix.lower()
|
|
2646
|
+
# Common raster extensions
|
|
2647
|
+
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
2648
|
+
try:
|
|
2649
|
+
with rasterio.open(in_class_data) as src:
|
|
2650
|
+
is_class_data_raster = True
|
|
3082
2651
|
if not quiet:
|
|
3083
|
-
print(
|
|
3084
|
-
|
|
3085
|
-
)
|
|
2652
|
+
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
2653
|
+
print(f"Raster CRS: {src.crs}")
|
|
2654
|
+
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
2655
|
+
except Exception:
|
|
2656
|
+
is_class_data_raster = False
|
|
2657
|
+
if not quiet:
|
|
2658
|
+
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
|
3086
2659
|
|
|
3087
2660
|
# Open the input raster
|
|
3088
2661
|
with rasterio.open(in_raster) as src:
|
|
@@ -3102,10 +2675,10 @@ def export_geotiff_tiles(
|
|
|
3102
2675
|
if max_tiles is None:
|
|
3103
2676
|
max_tiles = total_tiles
|
|
3104
2677
|
|
|
3105
|
-
# Process classification data
|
|
2678
|
+
# Process classification data
|
|
3106
2679
|
class_to_id = {}
|
|
3107
2680
|
|
|
3108
|
-
if
|
|
2681
|
+
if is_class_data_raster:
|
|
3109
2682
|
# Load raster class data
|
|
3110
2683
|
with rasterio.open(in_class_data) as class_src:
|
|
3111
2684
|
# Check if raster CRS matches
|
|
@@ -3138,18 +2711,7 @@ def export_geotiff_tiles(
|
|
|
3138
2711
|
|
|
3139
2712
|
# Create class mapping
|
|
3140
2713
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
3141
|
-
|
|
3142
|
-
# Populate COCO categories
|
|
3143
|
-
if metadata_format == "COCO":
|
|
3144
|
-
for cls_val in unique_classes:
|
|
3145
|
-
coco_annotations["categories"].append(
|
|
3146
|
-
{
|
|
3147
|
-
"id": class_to_id[int(cls_val)],
|
|
3148
|
-
"name": str(int(cls_val)),
|
|
3149
|
-
"supercategory": "object",
|
|
3150
|
-
}
|
|
3151
|
-
)
|
|
3152
|
-
elif in_class_data is not None:
|
|
2714
|
+
else:
|
|
3153
2715
|
# Load vector class data
|
|
3154
2716
|
try:
|
|
3155
2717
|
gdf = gpd.read_file(in_class_data)
|
|
@@ -3178,33 +2740,12 @@ def export_geotiff_tiles(
|
|
|
3178
2740
|
)
|
|
3179
2741
|
# Create class mapping
|
|
3180
2742
|
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
3181
|
-
|
|
3182
|
-
# Populate COCO categories
|
|
3183
|
-
if metadata_format == "COCO":
|
|
3184
|
-
for cls_val in unique_classes:
|
|
3185
|
-
coco_annotations["categories"].append(
|
|
3186
|
-
{
|
|
3187
|
-
"id": class_to_id[cls_val],
|
|
3188
|
-
"name": str(cls_val),
|
|
3189
|
-
"supercategory": "object",
|
|
3190
|
-
}
|
|
3191
|
-
)
|
|
3192
2743
|
else:
|
|
3193
2744
|
if not quiet:
|
|
3194
2745
|
print(
|
|
3195
2746
|
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
|
3196
2747
|
)
|
|
3197
2748
|
class_to_id = {1: 1} # Default mapping
|
|
3198
|
-
|
|
3199
|
-
# Populate COCO categories with default
|
|
3200
|
-
if metadata_format == "COCO":
|
|
3201
|
-
coco_annotations["categories"].append(
|
|
3202
|
-
{
|
|
3203
|
-
"id": 1,
|
|
3204
|
-
"name": "object",
|
|
3205
|
-
"supercategory": "object",
|
|
3206
|
-
}
|
|
3207
|
-
)
|
|
3208
2749
|
except Exception as e:
|
|
3209
2750
|
raise ValueError(f"Error processing vector data: {e}")
|
|
3210
2751
|
|
|
@@ -3271,8 +2812,8 @@ def export_geotiff_tiles(
|
|
|
3271
2812
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
3272
2813
|
has_features = False
|
|
3273
2814
|
|
|
3274
|
-
# Process classification data to create labels
|
|
3275
|
-
if
|
|
2815
|
+
# Process classification data to create labels
|
|
2816
|
+
if is_class_data_raster:
|
|
3276
2817
|
# For raster class data
|
|
3277
2818
|
with rasterio.open(in_class_data) as class_src:
|
|
3278
2819
|
# Calculate window in class raster
|
|
@@ -3322,7 +2863,7 @@ def export_geotiff_tiles(
|
|
|
3322
2863
|
except Exception as e:
|
|
3323
2864
|
pbar.write(f"Error reading class raster window: {e}")
|
|
3324
2865
|
stats["errors"] += 1
|
|
3325
|
-
|
|
2866
|
+
else:
|
|
3326
2867
|
# For vector class data
|
|
3327
2868
|
# Find features that intersect with window
|
|
3328
2869
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -3365,8 +2906,8 @@ def export_geotiff_tiles(
|
|
|
3365
2906
|
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
3366
2907
|
stats["errors"] += 1
|
|
3367
2908
|
|
|
3368
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
3369
|
-
if
|
|
2909
|
+
# Skip tile if no features and skip_empty_tiles is True
|
|
2910
|
+
if skip_empty_tiles and not has_features:
|
|
3370
2911
|
pbar.update(1)
|
|
3371
2912
|
tile_index += 1
|
|
3372
2913
|
continue
|
|
@@ -3397,212 +2938,96 @@ def export_geotiff_tiles(
|
|
|
3397
2938
|
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
|
3398
2939
|
stats["errors"] += 1
|
|
3399
2940
|
|
|
3400
|
-
#
|
|
3401
|
-
|
|
3402
|
-
|
|
3403
|
-
|
|
3404
|
-
|
|
3405
|
-
|
|
3406
|
-
|
|
3407
|
-
|
|
3408
|
-
|
|
3409
|
-
|
|
3410
|
-
"transform": window_transform,
|
|
3411
|
-
}
|
|
2941
|
+
# Create profile for label GeoTIFF
|
|
2942
|
+
label_profile = {
|
|
2943
|
+
"driver": "GTiff",
|
|
2944
|
+
"height": tile_size,
|
|
2945
|
+
"width": tile_size,
|
|
2946
|
+
"count": 1,
|
|
2947
|
+
"dtype": "uint8",
|
|
2948
|
+
"crs": src.crs,
|
|
2949
|
+
"transform": window_transform,
|
|
2950
|
+
}
|
|
3412
2951
|
|
|
3413
|
-
|
|
3414
|
-
|
|
3415
|
-
|
|
3416
|
-
|
|
2952
|
+
# Export label as GeoTIFF
|
|
2953
|
+
label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
|
|
2954
|
+
try:
|
|
2955
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
2956
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
|
3417
2957
|
|
|
3418
|
-
|
|
3419
|
-
|
|
3420
|
-
|
|
3421
|
-
|
|
3422
|
-
|
|
3423
|
-
|
|
2958
|
+
if has_features:
|
|
2959
|
+
stats["tiles_with_features"] += 1
|
|
2960
|
+
stats["feature_pixels"] += np.count_nonzero(label_mask)
|
|
2961
|
+
except Exception as e:
|
|
2962
|
+
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
|
2963
|
+
stats["errors"] += 1
|
|
3424
2964
|
|
|
3425
|
-
# Create
|
|
2965
|
+
# Create XML annotation for object detection if using vector class data
|
|
3426
2966
|
if (
|
|
3427
|
-
|
|
3428
|
-
and not is_class_data_raster
|
|
2967
|
+
not is_class_data_raster
|
|
3429
2968
|
and "gdf" in locals()
|
|
3430
2969
|
and len(window_features) > 0
|
|
3431
2970
|
):
|
|
3432
|
-
|
|
3433
|
-
|
|
3434
|
-
|
|
3435
|
-
|
|
3436
|
-
ET.SubElement(root, "filename").text = (
|
|
3437
|
-
f"tile_{tile_index:06d}.tif"
|
|
3438
|
-
)
|
|
3439
|
-
|
|
3440
|
-
size = ET.SubElement(root, "size")
|
|
3441
|
-
ET.SubElement(size, "width").text = str(tile_size)
|
|
3442
|
-
ET.SubElement(size, "height").text = str(tile_size)
|
|
3443
|
-
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
3444
|
-
|
|
3445
|
-
# Add georeference information
|
|
3446
|
-
geo = ET.SubElement(root, "georeference")
|
|
3447
|
-
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
3448
|
-
ET.SubElement(geo, "transform").text = str(
|
|
3449
|
-
window_transform
|
|
3450
|
-
).replace("\n", "")
|
|
3451
|
-
ET.SubElement(geo, "bounds").text = (
|
|
3452
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
3453
|
-
)
|
|
3454
|
-
|
|
3455
|
-
# Add objects
|
|
3456
|
-
for idx, feature in window_features.iterrows():
|
|
3457
|
-
# Get feature class
|
|
3458
|
-
if class_value_field in feature:
|
|
3459
|
-
class_val = feature[class_value_field]
|
|
3460
|
-
else:
|
|
3461
|
-
class_val = "object"
|
|
3462
|
-
|
|
3463
|
-
# Get geometry bounds in pixel coordinates
|
|
3464
|
-
geom = feature.geometry.intersection(window_bounds)
|
|
3465
|
-
if not geom.is_empty:
|
|
3466
|
-
# Get bounds in world coordinates
|
|
3467
|
-
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3468
|
-
|
|
3469
|
-
# Convert to pixel coordinates
|
|
3470
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3471
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3472
|
-
|
|
3473
|
-
# Ensure coordinates are within tile bounds
|
|
3474
|
-
xmin = max(0, min(tile_size, int(col_min)))
|
|
3475
|
-
ymin = max(0, min(tile_size, int(row_min)))
|
|
3476
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
3477
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
3478
|
-
|
|
3479
|
-
# Only add if the box has non-zero area
|
|
3480
|
-
if xmax > xmin and ymax > ymin:
|
|
3481
|
-
obj = ET.SubElement(root, "object")
|
|
3482
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
|
3483
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
3484
|
-
|
|
3485
|
-
bbox = ET.SubElement(obj, "bndbox")
|
|
3486
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3487
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3488
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3489
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3490
|
-
|
|
3491
|
-
# Save XML
|
|
3492
|
-
tree = ET.ElementTree(root)
|
|
3493
|
-
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
3494
|
-
tree.write(xml_path)
|
|
3495
|
-
|
|
3496
|
-
elif metadata_format == "COCO":
|
|
3497
|
-
# Add image info
|
|
3498
|
-
image_id = tile_index
|
|
3499
|
-
coco_annotations["images"].append(
|
|
3500
|
-
{
|
|
3501
|
-
"id": image_id,
|
|
3502
|
-
"file_name": f"tile_{tile_index:06d}.tif",
|
|
3503
|
-
"width": tile_size,
|
|
3504
|
-
"height": tile_size,
|
|
3505
|
-
"crs": str(src.crs),
|
|
3506
|
-
"transform": str(window_transform),
|
|
3507
|
-
}
|
|
3508
|
-
)
|
|
3509
|
-
|
|
3510
|
-
# Add annotations for each feature
|
|
3511
|
-
for _, feature in window_features.iterrows():
|
|
3512
|
-
# Get feature class
|
|
3513
|
-
if class_value_field in feature:
|
|
3514
|
-
class_val = feature[class_value_field]
|
|
3515
|
-
category_id = class_to_id.get(class_val, 1)
|
|
3516
|
-
else:
|
|
3517
|
-
category_id = 1
|
|
2971
|
+
# Create XML annotation
|
|
2972
|
+
root = ET.Element("annotation")
|
|
2973
|
+
ET.SubElement(root, "folder").text = "images"
|
|
2974
|
+
ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
|
|
3518
2975
|
|
|
3519
|
-
|
|
3520
|
-
|
|
3521
|
-
|
|
3522
|
-
|
|
3523
|
-
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3524
|
-
|
|
3525
|
-
# Convert to pixel coordinates
|
|
3526
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3527
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3528
|
-
|
|
3529
|
-
# Ensure coordinates are within tile bounds
|
|
3530
|
-
xmin = max(0, min(tile_size, int(col_min)))
|
|
3531
|
-
ymin = max(0, min(tile_size, int(row_min)))
|
|
3532
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
3533
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
3534
|
-
|
|
3535
|
-
# Skip if box is too small
|
|
3536
|
-
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
3537
|
-
continue
|
|
3538
|
-
|
|
3539
|
-
width = xmax - xmin
|
|
3540
|
-
height = ymax - ymin
|
|
3541
|
-
|
|
3542
|
-
# Add annotation
|
|
3543
|
-
ann_id += 1
|
|
3544
|
-
coco_annotations["annotations"].append(
|
|
3545
|
-
{
|
|
3546
|
-
"id": ann_id,
|
|
3547
|
-
"image_id": image_id,
|
|
3548
|
-
"category_id": category_id,
|
|
3549
|
-
"bbox": [xmin, ymin, width, height],
|
|
3550
|
-
"area": width * height,
|
|
3551
|
-
"iscrowd": 0,
|
|
3552
|
-
}
|
|
3553
|
-
)
|
|
2976
|
+
size = ET.SubElement(root, "size")
|
|
2977
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
|
2978
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
|
2979
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
3554
2980
|
|
|
3555
|
-
|
|
3556
|
-
|
|
3557
|
-
|
|
2981
|
+
# Add georeference information
|
|
2982
|
+
geo = ET.SubElement(root, "georeference")
|
|
2983
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
2984
|
+
ET.SubElement(geo, "transform").text = str(
|
|
2985
|
+
window_transform
|
|
2986
|
+
).replace("\n", "")
|
|
2987
|
+
ET.SubElement(geo, "bounds").text = (
|
|
2988
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
2989
|
+
)
|
|
3558
2990
|
|
|
3559
|
-
|
|
3560
|
-
|
|
3561
|
-
|
|
3562
|
-
|
|
3563
|
-
|
|
3564
|
-
|
|
3565
|
-
|
|
3566
|
-
class_id = 0
|
|
2991
|
+
# Add objects
|
|
2992
|
+
for idx, feature in window_features.iterrows():
|
|
2993
|
+
# Get feature class
|
|
2994
|
+
if class_value_field in feature:
|
|
2995
|
+
class_val = feature[class_value_field]
|
|
2996
|
+
else:
|
|
2997
|
+
class_val = "object"
|
|
3567
2998
|
|
|
3568
|
-
|
|
3569
|
-
|
|
3570
|
-
|
|
3571
|
-
|
|
3572
|
-
|
|
3573
|
-
|
|
3574
|
-
|
|
3575
|
-
|
|
3576
|
-
|
|
3577
|
-
|
|
3578
|
-
|
|
3579
|
-
|
|
3580
|
-
|
|
3581
|
-
|
|
3582
|
-
|
|
3583
|
-
|
|
3584
|
-
|
|
3585
|
-
|
|
3586
|
-
|
|
3587
|
-
|
|
3588
|
-
|
|
3589
|
-
|
|
3590
|
-
|
|
3591
|
-
|
|
3592
|
-
|
|
3593
|
-
|
|
3594
|
-
|
|
3595
|
-
yolo_annotations.append(
|
|
3596
|
-
f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
|
3597
|
-
)
|
|
2999
|
+
# Get geometry bounds in pixel coordinates
|
|
3000
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
3001
|
+
if not geom.is_empty:
|
|
3002
|
+
# Get bounds in world coordinates
|
|
3003
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3004
|
+
|
|
3005
|
+
# Convert to pixel coordinates
|
|
3006
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3007
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3008
|
+
|
|
3009
|
+
# Ensure coordinates are within tile bounds
|
|
3010
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
|
3011
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
|
3012
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
|
3013
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
|
3014
|
+
|
|
3015
|
+
# Only add if the box has non-zero area
|
|
3016
|
+
if xmax > xmin and ymax > ymin:
|
|
3017
|
+
obj = ET.SubElement(root, "object")
|
|
3018
|
+
ET.SubElement(obj, "name").text = str(class_val)
|
|
3019
|
+
ET.SubElement(obj, "difficult").text = "0"
|
|
3020
|
+
|
|
3021
|
+
bbox = ET.SubElement(obj, "bndbox")
|
|
3022
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3023
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3024
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3025
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3598
3026
|
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
|
|
3603
|
-
)
|
|
3604
|
-
with open(yolo_path, "w") as f:
|
|
3605
|
-
f.write("\n".join(yolo_annotations))
|
|
3027
|
+
# Save XML
|
|
3028
|
+
tree = ET.ElementTree(root)
|
|
3029
|
+
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
3030
|
+
tree.write(xml_path)
|
|
3606
3031
|
|
|
3607
3032
|
# Update progress bar
|
|
3608
3033
|
pbar.update(1)
|
|
@@ -3620,39 +3045,6 @@ def export_geotiff_tiles(
|
|
|
3620
3045
|
# Close progress bar
|
|
3621
3046
|
pbar.close()
|
|
3622
3047
|
|
|
3623
|
-
# Save COCO annotations if applicable (only if class data provided)
|
|
3624
|
-
if in_class_data is not None and metadata_format == "COCO":
|
|
3625
|
-
try:
|
|
3626
|
-
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
|
3627
|
-
json.dump(coco_annotations, f, indent=2)
|
|
3628
|
-
if not quiet:
|
|
3629
|
-
print(
|
|
3630
|
-
f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
|
|
3631
|
-
f"{len(coco_annotations['annotations'])} annotations, "
|
|
3632
|
-
f"{len(coco_annotations['categories'])} categories"
|
|
3633
|
-
)
|
|
3634
|
-
except Exception as e:
|
|
3635
|
-
if not quiet:
|
|
3636
|
-
print(f"ERROR saving COCO annotations: {e}")
|
|
3637
|
-
stats["errors"] += 1
|
|
3638
|
-
|
|
3639
|
-
# Save YOLO classes file if applicable (only if class data provided)
|
|
3640
|
-
if in_class_data is not None and metadata_format == "YOLO":
|
|
3641
|
-
try:
|
|
3642
|
-
# Create classes.txt with class names
|
|
3643
|
-
classes_path = os.path.join(out_folder, "classes.txt")
|
|
3644
|
-
# Sort by class ID to ensure correct order
|
|
3645
|
-
sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
|
|
3646
|
-
with open(classes_path, "w") as f:
|
|
3647
|
-
for class_val, _ in sorted_classes:
|
|
3648
|
-
f.write(f"{class_val}\n")
|
|
3649
|
-
if not quiet:
|
|
3650
|
-
print(f"Saved YOLO classes file with {len(class_to_id)} classes")
|
|
3651
|
-
except Exception as e:
|
|
3652
|
-
if not quiet:
|
|
3653
|
-
print(f"ERROR saving YOLO classes file: {e}")
|
|
3654
|
-
stats["errors"] += 1
|
|
3655
|
-
|
|
3656
3048
|
# Create overview image if requested
|
|
3657
3049
|
if create_overview and stats["tile_coordinates"]:
|
|
3658
3050
|
try:
|
|
@@ -3670,14 +3062,13 @@ def export_geotiff_tiles(
|
|
|
3670
3062
|
if not quiet:
|
|
3671
3063
|
print("\n------- Export Summary -------")
|
|
3672
3064
|
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
3673
|
-
|
|
3065
|
+
print(
|
|
3066
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
3067
|
+
)
|
|
3068
|
+
if stats["tiles_with_features"] > 0:
|
|
3674
3069
|
print(
|
|
3675
|
-
f"
|
|
3070
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
3676
3071
|
)
|
|
3677
|
-
if stats["tiles_with_features"] > 0:
|
|
3678
|
-
print(
|
|
3679
|
-
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
3680
|
-
)
|
|
3681
3072
|
if stats["errors"] > 0:
|
|
3682
3073
|
print(f"Errors encountered: {stats['errors']}")
|
|
3683
3074
|
print(f"Output saved to: {out_folder}")
|
|
@@ -3686,6 +3077,7 @@ def export_geotiff_tiles(
|
|
|
3686
3077
|
if stats["total_tiles"] > 0:
|
|
3687
3078
|
print("\n------- Georeference Verification -------")
|
|
3688
3079
|
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
|
3080
|
+
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
|
3689
3081
|
|
|
3690
3082
|
if os.path.exists(sample_image):
|
|
3691
3083
|
try:
|
|
@@ -3701,22 +3093,19 @@ def export_geotiff_tiles(
|
|
|
3701
3093
|
except Exception as e:
|
|
3702
3094
|
print(f"Error verifying image georeference: {e}")
|
|
3703
3095
|
|
|
3704
|
-
|
|
3705
|
-
|
|
3706
|
-
|
|
3707
|
-
|
|
3708
|
-
|
|
3709
|
-
|
|
3710
|
-
|
|
3711
|
-
|
|
3712
|
-
|
|
3713
|
-
|
|
3714
|
-
|
|
3715
|
-
|
|
3716
|
-
|
|
3717
|
-
)
|
|
3718
|
-
except Exception as e:
|
|
3719
|
-
print(f"Error verifying label georeference: {e}")
|
|
3096
|
+
if os.path.exists(sample_label):
|
|
3097
|
+
try:
|
|
3098
|
+
with rasterio.open(sample_label) as lbl:
|
|
3099
|
+
print(f"Label CRS: {lbl.crs}")
|
|
3100
|
+
print(f"Label transform: {lbl.transform}")
|
|
3101
|
+
print(
|
|
3102
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
3103
|
+
)
|
|
3104
|
+
print(
|
|
3105
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
3106
|
+
)
|
|
3107
|
+
except Exception as e:
|
|
3108
|
+
print(f"Error verifying label georeference: {e}")
|
|
3720
3109
|
|
|
3721
3110
|
# Return statistics dictionary for further processing if needed
|
|
3722
3111
|
return stats
|
|
@@ -3724,9 +3113,8 @@ def export_geotiff_tiles(
|
|
|
3724
3113
|
|
|
3725
3114
|
def export_geotiff_tiles_batch(
|
|
3726
3115
|
images_folder,
|
|
3727
|
-
masks_folder
|
|
3728
|
-
|
|
3729
|
-
output_folder=None,
|
|
3116
|
+
masks_folder,
|
|
3117
|
+
output_folder,
|
|
3730
3118
|
tile_size=256,
|
|
3731
3119
|
stride=128,
|
|
3732
3120
|
class_value_field="class",
|
|
@@ -3734,43 +3122,25 @@ def export_geotiff_tiles_batch(
|
|
|
3734
3122
|
max_tiles=None,
|
|
3735
3123
|
quiet=False,
|
|
3736
3124
|
all_touched=True,
|
|
3125
|
+
create_overview=False,
|
|
3737
3126
|
skip_empty_tiles=False,
|
|
3738
3127
|
image_extensions=None,
|
|
3739
3128
|
mask_extensions=None,
|
|
3740
|
-
match_by_name=False,
|
|
3741
|
-
metadata_format="PASCAL_VOC",
|
|
3742
3129
|
) -> Dict[str, Any]:
|
|
3743
3130
|
"""
|
|
3744
|
-
Export georeferenced GeoTIFF tiles from images and
|
|
3745
|
-
|
|
3746
|
-
This function supports four modes:
|
|
3747
|
-
1. Images only (no masks) - when neither masks_file nor masks_folder is provided
|
|
3748
|
-
2. Single vector file covering all images (masks_file parameter)
|
|
3749
|
-
3. Multiple vector files, one per image (masks_folder parameter)
|
|
3750
|
-
4. Multiple raster mask files (masks_folder parameter)
|
|
3131
|
+
Export georeferenced GeoTIFF tiles from folders of images and masks.
|
|
3751
3132
|
|
|
3752
|
-
|
|
3133
|
+
This function processes multiple image-mask pairs from input folders,
|
|
3134
|
+
generating tiles for each pair. All image tiles are saved to a single
|
|
3135
|
+
'images' folder and all mask tiles to a single 'masks' folder.
|
|
3753
3136
|
|
|
3754
|
-
|
|
3755
|
-
|
|
3756
|
-
|
|
3757
|
-
For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
|
|
3758
|
-
are paired either by matching filenames (match_by_name=True) or by sorted order
|
|
3759
|
-
(match_by_name=False).
|
|
3760
|
-
|
|
3761
|
-
All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
|
|
3762
|
-
to a single 'masks' folder within the output directory.
|
|
3137
|
+
Images and masks are paired by their sorted order (alphabetically), not by
|
|
3138
|
+
filename matching. The number of images and masks must be equal.
|
|
3763
3139
|
|
|
3764
3140
|
Args:
|
|
3765
3141
|
images_folder (str): Path to folder containing raster images
|
|
3766
|
-
masks_folder (str
|
|
3767
|
-
|
|
3768
|
-
and masks_file is also not provided, only image tiles will be exported.
|
|
3769
|
-
masks_file (str, optional): Path to a single vector file covering all images.
|
|
3770
|
-
Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
|
|
3771
|
-
and masks_folder is also not provided, only image tiles will be exported.
|
|
3772
|
-
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
|
3773
|
-
subfolder in images_folder.
|
|
3142
|
+
masks_folder (str): Path to folder containing classification masks/vectors
|
|
3143
|
+
output_folder (str): Path to output folder
|
|
3774
3144
|
tile_size (int): Size of tiles in pixels (square)
|
|
3775
3145
|
stride (int): Step size between tiles
|
|
3776
3146
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -3782,63 +3152,18 @@ def export_geotiff_tiles_batch(
|
|
|
3782
3152
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3783
3153
|
image_extensions (list): List of image file extensions to process (default: common raster formats)
|
|
3784
3154
|
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
|
3785
|
-
match_by_name (bool): If True, match image and mask files by base filename.
|
|
3786
|
-
If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
|
|
3787
|
-
metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
|
|
3788
|
-
Default is "PASCAL_VOC".
|
|
3789
3155
|
|
|
3790
3156
|
Returns:
|
|
3791
3157
|
Dict[str, Any]: Dictionary containing batch processing statistics
|
|
3792
3158
|
|
|
3793
3159
|
Raises:
|
|
3794
|
-
ValueError: If no images found, or if
|
|
3795
|
-
or if counts don't match when using masks_folder with match_by_name=False.
|
|
3796
|
-
|
|
3797
|
-
Examples:
|
|
3798
|
-
# Images only (no masks)
|
|
3799
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
3800
|
-
... images_folder='data/images',
|
|
3801
|
-
... output_folder='output/tiles'
|
|
3802
|
-
... )
|
|
3803
|
-
|
|
3804
|
-
# Single vector file covering all images
|
|
3805
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
3806
|
-
... images_folder='data/images',
|
|
3807
|
-
... masks_file='data/buildings.geojson',
|
|
3808
|
-
... output_folder='output/tiles'
|
|
3809
|
-
... )
|
|
3810
|
-
|
|
3811
|
-
# Multiple vector files, matched by filename
|
|
3812
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
3813
|
-
... images_folder='data/images',
|
|
3814
|
-
... masks_folder='data/masks',
|
|
3815
|
-
... output_folder='output/tiles',
|
|
3816
|
-
... match_by_name=True
|
|
3817
|
-
... )
|
|
3818
|
-
|
|
3819
|
-
# Multiple mask files, matched by sorted order
|
|
3820
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
3821
|
-
... images_folder='data/images',
|
|
3822
|
-
... masks_folder='data/masks',
|
|
3823
|
-
... output_folder='output/tiles',
|
|
3824
|
-
... match_by_name=False
|
|
3825
|
-
... )
|
|
3160
|
+
ValueError: If no images or masks found, or if counts don't match
|
|
3826
3161
|
"""
|
|
3827
3162
|
|
|
3828
3163
|
import logging
|
|
3829
3164
|
|
|
3830
3165
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
|
3831
3166
|
|
|
3832
|
-
# Validate input parameters
|
|
3833
|
-
if masks_folder is not None and masks_file is not None:
|
|
3834
|
-
raise ValueError(
|
|
3835
|
-
"Cannot specify both masks_folder and masks_file. Please use only one."
|
|
3836
|
-
)
|
|
3837
|
-
|
|
3838
|
-
# Default output folder if not specified
|
|
3839
|
-
if output_folder is None:
|
|
3840
|
-
output_folder = os.path.join(images_folder, "tiles")
|
|
3841
|
-
|
|
3842
3167
|
# Default extensions if not provided
|
|
3843
3168
|
if image_extensions is None:
|
|
3844
3169
|
image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
|
|
@@ -3865,37 +3190,9 @@ def export_geotiff_tiles_batch(
|
|
|
3865
3190
|
# Create output folder structure
|
|
3866
3191
|
os.makedirs(output_folder, exist_ok=True)
|
|
3867
3192
|
output_images_dir = os.path.join(output_folder, "images")
|
|
3193
|
+
output_masks_dir = os.path.join(output_folder, "masks")
|
|
3868
3194
|
os.makedirs(output_images_dir, exist_ok=True)
|
|
3869
|
-
|
|
3870
|
-
# Only create masks directory if masks are provided
|
|
3871
|
-
output_masks_dir = None
|
|
3872
|
-
if masks_folder is not None or masks_file is not None:
|
|
3873
|
-
output_masks_dir = os.path.join(output_folder, "masks")
|
|
3874
|
-
os.makedirs(output_masks_dir, exist_ok=True)
|
|
3875
|
-
|
|
3876
|
-
# Create annotation directory based on metadata format (only if masks are provided)
|
|
3877
|
-
ann_dir = None
|
|
3878
|
-
if (masks_folder is not None or masks_file is not None) and metadata_format in [
|
|
3879
|
-
"PASCAL_VOC",
|
|
3880
|
-
"COCO",
|
|
3881
|
-
]:
|
|
3882
|
-
ann_dir = os.path.join(output_folder, "annotations")
|
|
3883
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
3884
|
-
|
|
3885
|
-
# Initialize COCO annotations dictionary (only if masks are provided)
|
|
3886
|
-
coco_annotations = None
|
|
3887
|
-
if (
|
|
3888
|
-
masks_folder is not None or masks_file is not None
|
|
3889
|
-
) and metadata_format == "COCO":
|
|
3890
|
-
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
3891
|
-
|
|
3892
|
-
# Initialize YOLO class set (only if masks are provided)
|
|
3893
|
-
yolo_classes = (
|
|
3894
|
-
set()
|
|
3895
|
-
if (masks_folder is not None or masks_file is not None)
|
|
3896
|
-
and metadata_format == "YOLO"
|
|
3897
|
-
else None
|
|
3898
|
-
)
|
|
3195
|
+
os.makedirs(output_masks_dir, exist_ok=True)
|
|
3899
3196
|
|
|
3900
3197
|
# Get list of image files
|
|
3901
3198
|
image_files = []
|
|
@@ -3903,105 +3200,30 @@ def export_geotiff_tiles_batch(
|
|
|
3903
3200
|
pattern = os.path.join(images_folder, f"*{ext}")
|
|
3904
3201
|
image_files.extend(glob.glob(pattern))
|
|
3905
3202
|
|
|
3203
|
+
# Get list of mask files
|
|
3204
|
+
mask_files = []
|
|
3205
|
+
for ext in mask_extensions:
|
|
3206
|
+
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
3207
|
+
mask_files.extend(glob.glob(pattern))
|
|
3208
|
+
|
|
3906
3209
|
# Sort files for consistent processing
|
|
3907
3210
|
image_files.sort()
|
|
3211
|
+
mask_files.sort()
|
|
3908
3212
|
|
|
3909
3213
|
if not image_files:
|
|
3910
3214
|
raise ValueError(
|
|
3911
3215
|
f"No image files found in {images_folder} with extensions {image_extensions}"
|
|
3912
3216
|
)
|
|
3913
3217
|
|
|
3914
|
-
|
|
3915
|
-
|
|
3916
|
-
|
|
3917
|
-
|
|
3918
|
-
image_mask_pairs = []
|
|
3919
|
-
|
|
3920
|
-
if not has_masks:
|
|
3921
|
-
# Mode 0: No masks - create pairs with None for mask
|
|
3922
|
-
for image_file in image_files:
|
|
3923
|
-
image_mask_pairs.append((image_file, None, None))
|
|
3924
|
-
|
|
3925
|
-
elif use_single_mask_file:
|
|
3926
|
-
# Mode 1: Single vector file covering all images
|
|
3927
|
-
if not os.path.exists(masks_file):
|
|
3928
|
-
raise ValueError(f"Mask file not found: {masks_file}")
|
|
3929
|
-
|
|
3930
|
-
# Load the single mask file once - will be spatially filtered per image
|
|
3931
|
-
single_mask_gdf = gpd.read_file(masks_file)
|
|
3932
|
-
|
|
3933
|
-
if not quiet:
|
|
3934
|
-
print(f"Using single mask file: {masks_file}")
|
|
3935
|
-
print(
|
|
3936
|
-
f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
|
|
3937
|
-
)
|
|
3938
|
-
|
|
3939
|
-
# Create pairs with the same mask file for all images
|
|
3940
|
-
for image_file in image_files:
|
|
3941
|
-
image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
|
|
3942
|
-
|
|
3943
|
-
else:
|
|
3944
|
-
# Mode 2/3: Multiple mask files (vector or raster)
|
|
3945
|
-
# Get list of mask files
|
|
3946
|
-
for ext in mask_extensions:
|
|
3947
|
-
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
3948
|
-
mask_files.extend(glob.glob(pattern))
|
|
3949
|
-
|
|
3950
|
-
# Sort files for consistent processing
|
|
3951
|
-
mask_files.sort()
|
|
3952
|
-
|
|
3953
|
-
if not mask_files:
|
|
3954
|
-
raise ValueError(
|
|
3955
|
-
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
|
3956
|
-
)
|
|
3957
|
-
|
|
3958
|
-
# Match images to masks
|
|
3959
|
-
if match_by_name:
|
|
3960
|
-
# Match by base filename
|
|
3961
|
-
image_dict = {
|
|
3962
|
-
os.path.splitext(os.path.basename(f))[0]: f for f in image_files
|
|
3963
|
-
}
|
|
3964
|
-
mask_dict = {
|
|
3965
|
-
os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
|
|
3966
|
-
}
|
|
3967
|
-
|
|
3968
|
-
# Find matching pairs
|
|
3969
|
-
for img_base, img_path in image_dict.items():
|
|
3970
|
-
if img_base in mask_dict:
|
|
3971
|
-
image_mask_pairs.append((img_path, mask_dict[img_base], None))
|
|
3972
|
-
else:
|
|
3973
|
-
if not quiet:
|
|
3974
|
-
print(f"Warning: No mask found for image {img_base}")
|
|
3975
|
-
|
|
3976
|
-
if not image_mask_pairs:
|
|
3977
|
-
# Provide detailed error message with found files
|
|
3978
|
-
image_bases = list(image_dict.keys())
|
|
3979
|
-
mask_bases = list(mask_dict.keys())
|
|
3980
|
-
error_msg = (
|
|
3981
|
-
"No matching image-mask pairs found when matching by filename. "
|
|
3982
|
-
"Check that image and mask files have matching base names.\n"
|
|
3983
|
-
f"Found {len(image_bases)} image(s): "
|
|
3984
|
-
f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
|
|
3985
|
-
f"{'...' if len(image_bases) > 5 else ''}\n"
|
|
3986
|
-
f"Found {len(mask_bases)} mask(s): "
|
|
3987
|
-
f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
|
|
3988
|
-
f"{'...' if len(mask_bases) > 5 else ''}\n"
|
|
3989
|
-
"Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
|
|
3990
|
-
)
|
|
3991
|
-
raise ValueError(error_msg)
|
|
3992
|
-
|
|
3993
|
-
else:
|
|
3994
|
-
# Match by sorted order
|
|
3995
|
-
if len(image_files) != len(mask_files):
|
|
3996
|
-
raise ValueError(
|
|
3997
|
-
f"Number of image files ({len(image_files)}) does not match "
|
|
3998
|
-
f"number of mask files ({len(mask_files)}) when matching by sorted order. "
|
|
3999
|
-
f"Use match_by_name=True for filename-based matching."
|
|
4000
|
-
)
|
|
3218
|
+
if not mask_files:
|
|
3219
|
+
raise ValueError(
|
|
3220
|
+
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
|
3221
|
+
)
|
|
4001
3222
|
|
|
4002
|
-
|
|
4003
|
-
|
|
4004
|
-
|
|
3223
|
+
if len(image_files) != len(mask_files):
|
|
3224
|
+
raise ValueError(
|
|
3225
|
+
f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
|
|
3226
|
+
)
|
|
4005
3227
|
|
|
4006
3228
|
# Initialize batch statistics
|
|
4007
3229
|
batch_stats = {
|
|
@@ -4015,28 +3237,23 @@ def export_geotiff_tiles_batch(
|
|
|
4015
3237
|
}
|
|
4016
3238
|
|
|
4017
3239
|
if not quiet:
|
|
4018
|
-
|
|
4019
|
-
|
|
4020
|
-
|
|
4021
|
-
|
|
4022
|
-
elif use_single_mask_file:
|
|
4023
|
-
print(f"Found {len(image_files)} image files to process")
|
|
4024
|
-
print(f"Using single mask file: {masks_file}")
|
|
4025
|
-
else:
|
|
4026
|
-
print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
|
|
4027
|
-
print(f"Processing batch from {images_folder} and {masks_folder}")
|
|
3240
|
+
print(
|
|
3241
|
+
f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
|
|
3242
|
+
)
|
|
3243
|
+
print(f"Processing batch from {images_folder} and {masks_folder}")
|
|
4028
3244
|
print(f"Output folder: {output_folder}")
|
|
4029
3245
|
print("-" * 60)
|
|
4030
3246
|
|
|
4031
3247
|
# Global tile counter for unique naming
|
|
4032
3248
|
global_tile_counter = 0
|
|
4033
3249
|
|
|
4034
|
-
# Process each image-mask pair
|
|
4035
|
-
for idx, (image_file, mask_file
|
|
3250
|
+
# Process each image-mask pair by sorted order
|
|
3251
|
+
for idx, (image_file, mask_file) in enumerate(
|
|
4036
3252
|
tqdm(
|
|
4037
|
-
|
|
3253
|
+
zip(image_files, mask_files),
|
|
4038
3254
|
desc="Processing image pairs",
|
|
4039
3255
|
disable=quiet,
|
|
3256
|
+
total=len(image_files),
|
|
4040
3257
|
)
|
|
4041
3258
|
):
|
|
4042
3259
|
batch_stats["total_image_pairs"] += 1
|
|
@@ -4048,17 +3265,9 @@ def export_geotiff_tiles_batch(
|
|
|
4048
3265
|
if not quiet:
|
|
4049
3266
|
print(f"\nProcessing: {base_name}")
|
|
4050
3267
|
print(f" Image: {os.path.basename(image_file)}")
|
|
4051
|
-
|
|
4052
|
-
if use_single_mask_file:
|
|
4053
|
-
print(
|
|
4054
|
-
f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
|
|
4055
|
-
)
|
|
4056
|
-
else:
|
|
4057
|
-
print(f" Mask: {os.path.basename(mask_file)}")
|
|
4058
|
-
else:
|
|
4059
|
-
print(f" Mask: None (images only)")
|
|
3268
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
|
4060
3269
|
|
|
4061
|
-
# Process the image-mask pair
|
|
3270
|
+
# Process the image-mask pair manually to get direct control over tile saving
|
|
4062
3271
|
tiles_generated = _process_image_mask_pair(
|
|
4063
3272
|
image_file=image_file,
|
|
4064
3273
|
mask_file=mask_file,
|
|
@@ -4074,15 +3283,6 @@ def export_geotiff_tiles_batch(
|
|
|
4074
3283
|
all_touched=all_touched,
|
|
4075
3284
|
skip_empty_tiles=skip_empty_tiles,
|
|
4076
3285
|
quiet=quiet,
|
|
4077
|
-
mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
|
|
4078
|
-
use_single_mask_file=use_single_mask_file,
|
|
4079
|
-
metadata_format=metadata_format,
|
|
4080
|
-
ann_dir=(
|
|
4081
|
-
ann_dir
|
|
4082
|
-
if "ann_dir" in locals()
|
|
4083
|
-
and metadata_format in ["PASCAL_VOC", "COCO"]
|
|
4084
|
-
else None
|
|
4085
|
-
),
|
|
4086
3286
|
)
|
|
4087
3287
|
|
|
4088
3288
|
# Update counters
|
|
@@ -4104,23 +3304,6 @@ def export_geotiff_tiles_batch(
|
|
|
4104
3304
|
}
|
|
4105
3305
|
)
|
|
4106
3306
|
|
|
4107
|
-
# Aggregate COCO annotations
|
|
4108
|
-
if metadata_format == "COCO" and "coco_data" in tiles_generated:
|
|
4109
|
-
coco_data = tiles_generated["coco_data"]
|
|
4110
|
-
# Add images and annotations
|
|
4111
|
-
coco_annotations["images"].extend(coco_data.get("images", []))
|
|
4112
|
-
coco_annotations["annotations"].extend(coco_data.get("annotations", []))
|
|
4113
|
-
# Merge categories (avoid duplicates)
|
|
4114
|
-
for cat in coco_data.get("categories", []):
|
|
4115
|
-
if not any(
|
|
4116
|
-
c["id"] == cat["id"] for c in coco_annotations["categories"]
|
|
4117
|
-
):
|
|
4118
|
-
coco_annotations["categories"].append(cat)
|
|
4119
|
-
|
|
4120
|
-
# Aggregate YOLO classes
|
|
4121
|
-
if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
|
|
4122
|
-
yolo_classes.update(tiles_generated["yolo_classes"])
|
|
4123
|
-
|
|
4124
3307
|
except Exception as e:
|
|
4125
3308
|
if not quiet:
|
|
4126
3309
|
print(f"ERROR processing {base_name}: {e}")
|
|
@@ -4129,33 +3312,6 @@ def export_geotiff_tiles_batch(
|
|
|
4129
3312
|
)
|
|
4130
3313
|
batch_stats["errors"] += 1
|
|
4131
3314
|
|
|
4132
|
-
# Save aggregated COCO annotations
|
|
4133
|
-
if metadata_format == "COCO" and coco_annotations:
|
|
4134
|
-
import json
|
|
4135
|
-
|
|
4136
|
-
coco_path = os.path.join(ann_dir, "instances.json")
|
|
4137
|
-
with open(coco_path, "w") as f:
|
|
4138
|
-
json.dump(coco_annotations, f, indent=2)
|
|
4139
|
-
if not quiet:
|
|
4140
|
-
print(f"\nSaved COCO annotations: {coco_path}")
|
|
4141
|
-
print(
|
|
4142
|
-
f" Images: {len(coco_annotations['images'])}, "
|
|
4143
|
-
f"Annotations: {len(coco_annotations['annotations'])}, "
|
|
4144
|
-
f"Categories: {len(coco_annotations['categories'])}"
|
|
4145
|
-
)
|
|
4146
|
-
|
|
4147
|
-
# Save aggregated YOLO classes
|
|
4148
|
-
if metadata_format == "YOLO" and yolo_classes:
|
|
4149
|
-
classes_path = os.path.join(output_folder, "labels", "classes.txt")
|
|
4150
|
-
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
|
|
4151
|
-
sorted_classes = sorted(yolo_classes)
|
|
4152
|
-
with open(classes_path, "w") as f:
|
|
4153
|
-
for cls in sorted_classes:
|
|
4154
|
-
f.write(f"{cls}\n")
|
|
4155
|
-
if not quiet:
|
|
4156
|
-
print(f"\nSaved YOLO classes: {classes_path}")
|
|
4157
|
-
print(f" Total classes: {len(sorted_classes)}")
|
|
4158
|
-
|
|
4159
3315
|
# Print batch summary
|
|
4160
3316
|
if not quiet:
|
|
4161
3317
|
print("\n" + "=" * 60)
|
|
@@ -4178,12 +3334,7 @@ def export_geotiff_tiles_batch(
|
|
|
4178
3334
|
|
|
4179
3335
|
print(f"Output saved to: {output_folder}")
|
|
4180
3336
|
print(f" Images: {output_images_dir}")
|
|
4181
|
-
|
|
4182
|
-
print(f" Masks: {output_masks_dir}")
|
|
4183
|
-
if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
|
|
4184
|
-
print(f" Annotations: {ann_dir}")
|
|
4185
|
-
elif metadata_format == "YOLO":
|
|
4186
|
-
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
|
3337
|
+
print(f" Masks: {output_masks_dir}")
|
|
4187
3338
|
|
|
4188
3339
|
# List failed files if any
|
|
4189
3340
|
if batch_stats["failed_files"]:
|
|
@@ -4209,26 +3360,18 @@ def _process_image_mask_pair(
|
|
|
4209
3360
|
all_touched=True,
|
|
4210
3361
|
skip_empty_tiles=False,
|
|
4211
3362
|
quiet=False,
|
|
4212
|
-
mask_gdf=None,
|
|
4213
|
-
use_single_mask_file=False,
|
|
4214
|
-
metadata_format="PASCAL_VOC",
|
|
4215
|
-
ann_dir=None,
|
|
4216
3363
|
):
|
|
4217
3364
|
"""
|
|
4218
3365
|
Process a single image-mask pair and save tiles directly to output directories.
|
|
4219
3366
|
|
|
4220
|
-
Args:
|
|
4221
|
-
mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
|
|
4222
|
-
use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
|
|
4223
|
-
|
|
4224
3367
|
Returns:
|
|
4225
3368
|
dict: Statistics for this image-mask pair
|
|
4226
3369
|
"""
|
|
4227
3370
|
import warnings
|
|
4228
3371
|
|
|
4229
|
-
# Determine if mask data is raster or vector
|
|
3372
|
+
# Determine if mask data is raster or vector
|
|
4230
3373
|
is_class_data_raster = False
|
|
4231
|
-
if
|
|
3374
|
+
if isinstance(mask_file, str):
|
|
4232
3375
|
file_ext = Path(mask_file).suffix.lower()
|
|
4233
3376
|
# Common raster extensions
|
|
4234
3377
|
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
@@ -4245,13 +3388,6 @@ def _process_image_mask_pair(
|
|
|
4245
3388
|
"errors": 0,
|
|
4246
3389
|
}
|
|
4247
3390
|
|
|
4248
|
-
# Initialize COCO/YOLO tracking for this image
|
|
4249
|
-
if metadata_format == "COCO":
|
|
4250
|
-
stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
|
|
4251
|
-
coco_ann_id = 0
|
|
4252
|
-
if metadata_format == "YOLO":
|
|
4253
|
-
stats["yolo_classes"] = set()
|
|
4254
|
-
|
|
4255
3391
|
# Open the input raster
|
|
4256
3392
|
with rasterio.open(image_file) as src:
|
|
4257
3393
|
# Calculate number of tiles
|
|
@@ -4262,10 +3398,10 @@ def _process_image_mask_pair(
|
|
|
4262
3398
|
if max_tiles is None:
|
|
4263
3399
|
max_tiles = total_tiles
|
|
4264
3400
|
|
|
4265
|
-
# Process classification data
|
|
3401
|
+
# Process classification data
|
|
4266
3402
|
class_to_id = {}
|
|
4267
3403
|
|
|
4268
|
-
if
|
|
3404
|
+
if is_class_data_raster:
|
|
4269
3405
|
# Load raster class data
|
|
4270
3406
|
with rasterio.open(mask_file) as class_src:
|
|
4271
3407
|
# Check if raster CRS matches
|
|
@@ -4292,39 +3428,14 @@ def _process_image_mask_pair(
|
|
|
4292
3428
|
|
|
4293
3429
|
# Create class mapping
|
|
4294
3430
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
4295
|
-
|
|
3431
|
+
else:
|
|
4296
3432
|
# Load vector class data
|
|
4297
3433
|
try:
|
|
4298
|
-
|
|
4299
|
-
# Using pre-loaded single mask file - spatially filter to image bounds
|
|
4300
|
-
# Get image bounds
|
|
4301
|
-
image_bounds = box(*src.bounds)
|
|
4302
|
-
image_gdf = gpd.GeoDataFrame(
|
|
4303
|
-
{"geometry": [image_bounds]}, crs=src.crs
|
|
4304
|
-
)
|
|
4305
|
-
|
|
4306
|
-
# Reproject mask if needed
|
|
4307
|
-
if mask_gdf.crs != src.crs:
|
|
4308
|
-
mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
|
|
4309
|
-
else:
|
|
4310
|
-
mask_gdf_reprojected = mask_gdf
|
|
4311
|
-
|
|
4312
|
-
# Spatially filter features that intersect with image bounds
|
|
4313
|
-
gdf = mask_gdf_reprojected[
|
|
4314
|
-
mask_gdf_reprojected.intersects(image_bounds)
|
|
4315
|
-
].copy()
|
|
4316
|
-
|
|
4317
|
-
if not quiet and len(gdf) > 0:
|
|
4318
|
-
print(
|
|
4319
|
-
f" Filtered to {len(gdf)} features intersecting image bounds"
|
|
4320
|
-
)
|
|
4321
|
-
else:
|
|
4322
|
-
# Load individual mask file
|
|
4323
|
-
gdf = gpd.read_file(mask_file)
|
|
3434
|
+
gdf = gpd.read_file(mask_file)
|
|
4324
3435
|
|
|
4325
|
-
|
|
4326
|
-
|
|
4327
|
-
|
|
3436
|
+
# Always reproject to match raster CRS
|
|
3437
|
+
if gdf.crs != src.crs:
|
|
3438
|
+
gdf = gdf.to_crs(src.crs)
|
|
4328
3439
|
|
|
4329
3440
|
# Apply buffer if specified
|
|
4330
3441
|
if buffer_radius > 0:
|
|
@@ -4344,6 +3455,9 @@ def _process_image_mask_pair(
|
|
|
4344
3455
|
tile_index = 0
|
|
4345
3456
|
for y in range(num_tiles_y):
|
|
4346
3457
|
for x in range(num_tiles_x):
|
|
3458
|
+
if tile_index >= max_tiles:
|
|
3459
|
+
break
|
|
3460
|
+
|
|
4347
3461
|
# Calculate window coordinates
|
|
4348
3462
|
window_x = x * stride
|
|
4349
3463
|
window_y = y * stride
|
|
@@ -4368,12 +3482,12 @@ def _process_image_mask_pair(
|
|
|
4368
3482
|
|
|
4369
3483
|
window_bounds = box(minx, miny, maxx, maxy)
|
|
4370
3484
|
|
|
4371
|
-
# Create label mask
|
|
3485
|
+
# Create label mask
|
|
4372
3486
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
4373
3487
|
has_features = False
|
|
4374
3488
|
|
|
4375
|
-
# Process classification data to create labels
|
|
4376
|
-
if
|
|
3489
|
+
# Process classification data to create labels
|
|
3490
|
+
if is_class_data_raster:
|
|
4377
3491
|
# For raster class data
|
|
4378
3492
|
with rasterio.open(mask_file) as class_src:
|
|
4379
3493
|
# Get corresponding window in class raster
|
|
@@ -4406,7 +3520,7 @@ def _process_image_mask_pair(
|
|
|
4406
3520
|
if not quiet:
|
|
4407
3521
|
print(f"Error reading class raster window: {e}")
|
|
4408
3522
|
stats["errors"] += 1
|
|
4409
|
-
|
|
3523
|
+
else:
|
|
4410
3524
|
# For vector class data
|
|
4411
3525
|
# Find features that intersect with window
|
|
4412
3526
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -4444,14 +3558,11 @@ def _process_image_mask_pair(
|
|
|
4444
3558
|
print(f"Error rasterizing feature {idx}: {e}")
|
|
4445
3559
|
stats["errors"] += 1
|
|
4446
3560
|
|
|
4447
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
4448
|
-
if
|
|
3561
|
+
# Skip tile if no features and skip_empty_tiles is True
|
|
3562
|
+
if skip_empty_tiles and not has_features:
|
|
3563
|
+
tile_index += 1
|
|
4449
3564
|
continue
|
|
4450
3565
|
|
|
4451
|
-
# Check if we've reached max_tiles before saving
|
|
4452
|
-
if tile_index >= max_tiles:
|
|
4453
|
-
break
|
|
4454
|
-
|
|
4455
3566
|
# Generate unique tile name
|
|
4456
3567
|
tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
|
|
4457
3568
|
|
|
@@ -4482,225 +3593,29 @@ def _process_image_mask_pair(
|
|
|
4482
3593
|
print(f"ERROR saving image GeoTIFF: {e}")
|
|
4483
3594
|
stats["errors"] += 1
|
|
4484
3595
|
|
|
4485
|
-
#
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4490
|
-
|
|
4491
|
-
|
|
4492
|
-
|
|
4493
|
-
|
|
4494
|
-
|
|
4495
|
-
"transform": window_transform,
|
|
4496
|
-
}
|
|
4497
|
-
|
|
4498
|
-
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
|
4499
|
-
try:
|
|
4500
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
4501
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
|
4502
|
-
|
|
4503
|
-
if has_features:
|
|
4504
|
-
stats["tiles_with_features"] += 1
|
|
4505
|
-
except Exception as e:
|
|
4506
|
-
if not quiet:
|
|
4507
|
-
print(f"ERROR saving label GeoTIFF: {e}")
|
|
4508
|
-
stats["errors"] += 1
|
|
4509
|
-
|
|
4510
|
-
# Generate annotation metadata based on format (only if mask_file is provided)
|
|
4511
|
-
if (
|
|
4512
|
-
mask_file is not None
|
|
4513
|
-
and metadata_format == "PASCAL_VOC"
|
|
4514
|
-
and ann_dir
|
|
4515
|
-
):
|
|
4516
|
-
# Create PASCAL VOC XML annotation
|
|
4517
|
-
from lxml import etree as ET
|
|
4518
|
-
|
|
4519
|
-
annotation = ET.Element("annotation")
|
|
4520
|
-
ET.SubElement(annotation, "folder").text = os.path.basename(
|
|
4521
|
-
output_images_dir
|
|
4522
|
-
)
|
|
4523
|
-
ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
|
|
4524
|
-
ET.SubElement(annotation, "path").text = image_path
|
|
4525
|
-
|
|
4526
|
-
source = ET.SubElement(annotation, "source")
|
|
4527
|
-
ET.SubElement(source, "database").text = "GeoAI"
|
|
4528
|
-
|
|
4529
|
-
size = ET.SubElement(annotation, "size")
|
|
4530
|
-
ET.SubElement(size, "width").text = str(tile_size)
|
|
4531
|
-
ET.SubElement(size, "height").text = str(tile_size)
|
|
4532
|
-
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
4533
|
-
|
|
4534
|
-
ET.SubElement(annotation, "segmented").text = "1"
|
|
4535
|
-
|
|
4536
|
-
# Find connected components for instance segmentation
|
|
4537
|
-
from scipy import ndimage
|
|
4538
|
-
|
|
4539
|
-
for class_id in np.unique(label_mask):
|
|
4540
|
-
if class_id == 0:
|
|
4541
|
-
continue
|
|
4542
|
-
|
|
4543
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4544
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4545
|
-
|
|
4546
|
-
for instance_id in range(1, num_features + 1):
|
|
4547
|
-
instance_mask = labeled_array == instance_id
|
|
4548
|
-
coords = np.argwhere(instance_mask)
|
|
4549
|
-
|
|
4550
|
-
if len(coords) == 0:
|
|
4551
|
-
continue
|
|
4552
|
-
|
|
4553
|
-
ymin, xmin = coords.min(axis=0)
|
|
4554
|
-
ymax, xmax = coords.max(axis=0)
|
|
4555
|
-
|
|
4556
|
-
obj = ET.SubElement(annotation, "object")
|
|
4557
|
-
class_name = next(
|
|
4558
|
-
(k for k, v in class_to_id.items() if v == class_id),
|
|
4559
|
-
str(class_id),
|
|
4560
|
-
)
|
|
4561
|
-
ET.SubElement(obj, "name").text = str(class_name)
|
|
4562
|
-
ET.SubElement(obj, "pose").text = "Unspecified"
|
|
4563
|
-
ET.SubElement(obj, "truncated").text = "0"
|
|
4564
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
4565
|
-
|
|
4566
|
-
bndbox = ET.SubElement(obj, "bndbox")
|
|
4567
|
-
ET.SubElement(bndbox, "xmin").text = str(int(xmin))
|
|
4568
|
-
ET.SubElement(bndbox, "ymin").text = str(int(ymin))
|
|
4569
|
-
ET.SubElement(bndbox, "xmax").text = str(int(xmax))
|
|
4570
|
-
ET.SubElement(bndbox, "ymax").text = str(int(ymax))
|
|
4571
|
-
|
|
4572
|
-
# Save XML file
|
|
4573
|
-
xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
|
|
4574
|
-
tree = ET.ElementTree(annotation)
|
|
4575
|
-
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
|
4576
|
-
|
|
4577
|
-
elif mask_file is not None and metadata_format == "COCO":
|
|
4578
|
-
# Add COCO image entry
|
|
4579
|
-
image_id = int(global_tile_counter + tile_index)
|
|
4580
|
-
stats["coco_data"]["images"].append(
|
|
4581
|
-
{
|
|
4582
|
-
"id": image_id,
|
|
4583
|
-
"file_name": f"{tile_name}.tif",
|
|
4584
|
-
"width": int(tile_size),
|
|
4585
|
-
"height": int(tile_size),
|
|
4586
|
-
}
|
|
4587
|
-
)
|
|
4588
|
-
|
|
4589
|
-
# Add COCO categories (only once per unique class)
|
|
4590
|
-
for class_val, class_id in class_to_id.items():
|
|
4591
|
-
if not any(
|
|
4592
|
-
c["id"] == class_id
|
|
4593
|
-
for c in stats["coco_data"]["categories"]
|
|
4594
|
-
):
|
|
4595
|
-
stats["coco_data"]["categories"].append(
|
|
4596
|
-
{
|
|
4597
|
-
"id": int(class_id),
|
|
4598
|
-
"name": str(class_val),
|
|
4599
|
-
"supercategory": "object",
|
|
4600
|
-
}
|
|
4601
|
-
)
|
|
4602
|
-
|
|
4603
|
-
# Add COCO annotations (instance segmentation)
|
|
4604
|
-
from scipy import ndimage
|
|
4605
|
-
from skimage import measure
|
|
4606
|
-
|
|
4607
|
-
for class_id in np.unique(label_mask):
|
|
4608
|
-
if class_id == 0:
|
|
4609
|
-
continue
|
|
4610
|
-
|
|
4611
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4612
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4613
|
-
|
|
4614
|
-
for instance_id in range(1, num_features + 1):
|
|
4615
|
-
instance_mask = (labeled_array == instance_id).astype(
|
|
4616
|
-
np.uint8
|
|
4617
|
-
)
|
|
4618
|
-
coords = np.argwhere(instance_mask)
|
|
4619
|
-
|
|
4620
|
-
if len(coords) == 0:
|
|
4621
|
-
continue
|
|
4622
|
-
|
|
4623
|
-
ymin, xmin = coords.min(axis=0)
|
|
4624
|
-
ymax, xmax = coords.max(axis=0)
|
|
4625
|
-
|
|
4626
|
-
bbox = [
|
|
4627
|
-
int(xmin),
|
|
4628
|
-
int(ymin),
|
|
4629
|
-
int(xmax - xmin),
|
|
4630
|
-
int(ymax - ymin),
|
|
4631
|
-
]
|
|
4632
|
-
area = int(np.sum(instance_mask))
|
|
4633
|
-
|
|
4634
|
-
# Find contours for segmentation
|
|
4635
|
-
contours = measure.find_contours(instance_mask, 0.5)
|
|
4636
|
-
segmentation = []
|
|
4637
|
-
for contour in contours:
|
|
4638
|
-
contour = np.flip(contour, axis=1)
|
|
4639
|
-
segmentation_points = contour.ravel().tolist()
|
|
4640
|
-
if len(segmentation_points) >= 6:
|
|
4641
|
-
segmentation.append(segmentation_points)
|
|
4642
|
-
|
|
4643
|
-
if segmentation:
|
|
4644
|
-
stats["coco_data"]["annotations"].append(
|
|
4645
|
-
{
|
|
4646
|
-
"id": int(coco_ann_id),
|
|
4647
|
-
"image_id": int(image_id),
|
|
4648
|
-
"category_id": int(class_id),
|
|
4649
|
-
"bbox": bbox,
|
|
4650
|
-
"area": area,
|
|
4651
|
-
"segmentation": segmentation,
|
|
4652
|
-
"iscrowd": 0,
|
|
4653
|
-
}
|
|
4654
|
-
)
|
|
4655
|
-
coco_ann_id += 1
|
|
4656
|
-
|
|
4657
|
-
elif mask_file is not None and metadata_format == "YOLO":
|
|
4658
|
-
# Create YOLO labels directory if needed
|
|
4659
|
-
labels_dir = os.path.join(
|
|
4660
|
-
os.path.dirname(output_images_dir), "labels"
|
|
4661
|
-
)
|
|
4662
|
-
os.makedirs(labels_dir, exist_ok=True)
|
|
4663
|
-
|
|
4664
|
-
# Generate YOLO annotation file
|
|
4665
|
-
yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
|
|
4666
|
-
from scipy import ndimage
|
|
4667
|
-
|
|
4668
|
-
with open(yolo_path, "w") as yolo_file:
|
|
4669
|
-
for class_id in np.unique(label_mask):
|
|
4670
|
-
if class_id == 0:
|
|
4671
|
-
continue
|
|
4672
|
-
|
|
4673
|
-
# Track class for classes.txt
|
|
4674
|
-
class_name = next(
|
|
4675
|
-
(k for k, v in class_to_id.items() if v == class_id),
|
|
4676
|
-
str(class_id),
|
|
4677
|
-
)
|
|
4678
|
-
stats["yolo_classes"].add(class_name)
|
|
4679
|
-
|
|
4680
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4681
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4682
|
-
|
|
4683
|
-
for instance_id in range(1, num_features + 1):
|
|
4684
|
-
instance_mask = labeled_array == instance_id
|
|
4685
|
-
coords = np.argwhere(instance_mask)
|
|
4686
|
-
|
|
4687
|
-
if len(coords) == 0:
|
|
4688
|
-
continue
|
|
4689
|
-
|
|
4690
|
-
ymin, xmin = coords.min(axis=0)
|
|
4691
|
-
ymax, xmax = coords.max(axis=0)
|
|
3596
|
+
# Create profile for label GeoTIFF
|
|
3597
|
+
label_profile = {
|
|
3598
|
+
"driver": "GTiff",
|
|
3599
|
+
"height": tile_size,
|
|
3600
|
+
"width": tile_size,
|
|
3601
|
+
"count": 1,
|
|
3602
|
+
"dtype": "uint8",
|
|
3603
|
+
"crs": src.crs,
|
|
3604
|
+
"transform": window_transform,
|
|
3605
|
+
}
|
|
4692
3606
|
|
|
4693
|
-
|
|
4694
|
-
|
|
4695
|
-
|
|
4696
|
-
|
|
4697
|
-
|
|
3607
|
+
# Export label as GeoTIFF
|
|
3608
|
+
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
|
3609
|
+
try:
|
|
3610
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
3611
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
|
4698
3612
|
|
|
4699
|
-
|
|
4700
|
-
|
|
4701
|
-
|
|
4702
|
-
|
|
4703
|
-
|
|
3613
|
+
if has_features:
|
|
3614
|
+
stats["tiles_with_features"] += 1
|
|
3615
|
+
except Exception as e:
|
|
3616
|
+
if not quiet:
|
|
3617
|
+
print(f"ERROR saving label GeoTIFF: {e}")
|
|
3618
|
+
stats["errors"] += 1
|
|
4704
3619
|
|
|
4705
3620
|
tile_index += 1
|
|
4706
3621
|
if tile_index >= max_tiles:
|
|
@@ -4712,179 +3627,6 @@ def _process_image_mask_pair(
|
|
|
4712
3627
|
return stats
|
|
4713
3628
|
|
|
4714
3629
|
|
|
4715
|
-
def display_training_tiles(
|
|
4716
|
-
output_dir,
|
|
4717
|
-
num_tiles=6,
|
|
4718
|
-
figsize=(18, 6),
|
|
4719
|
-
cmap="gray",
|
|
4720
|
-
save_path=None,
|
|
4721
|
-
):
|
|
4722
|
-
"""
|
|
4723
|
-
Display image and mask tile pairs from training data output.
|
|
4724
|
-
|
|
4725
|
-
Args:
|
|
4726
|
-
output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
|
|
4727
|
-
num_tiles (int): Number of tile pairs to display (default: 6)
|
|
4728
|
-
figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
|
|
4729
|
-
cmap (str): Colormap for mask display (default: 'gray')
|
|
4730
|
-
save_path (str, optional): If provided, save figure to this path instead of displaying
|
|
4731
|
-
|
|
4732
|
-
Returns:
|
|
4733
|
-
tuple: (fig, axes) matplotlib figure and axes objects
|
|
4734
|
-
|
|
4735
|
-
Example:
|
|
4736
|
-
>>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
|
|
4737
|
-
>>> # Or save to file
|
|
4738
|
-
>>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
|
|
4739
|
-
"""
|
|
4740
|
-
import matplotlib.pyplot as plt
|
|
4741
|
-
|
|
4742
|
-
# Get list of image tiles
|
|
4743
|
-
images_dir = os.path.join(output_dir, "images")
|
|
4744
|
-
if not os.path.exists(images_dir):
|
|
4745
|
-
raise ValueError(f"Images directory not found: {images_dir}")
|
|
4746
|
-
|
|
4747
|
-
image_tiles = sorted(os.listdir(images_dir))[:num_tiles]
|
|
4748
|
-
|
|
4749
|
-
if not image_tiles:
|
|
4750
|
-
raise ValueError(f"No image tiles found in {images_dir}")
|
|
4751
|
-
|
|
4752
|
-
# Limit to available tiles
|
|
4753
|
-
num_tiles = min(num_tiles, len(image_tiles))
|
|
4754
|
-
|
|
4755
|
-
# Create figure with subplots
|
|
4756
|
-
fig, axes = plt.subplots(2, num_tiles, figsize=figsize)
|
|
4757
|
-
|
|
4758
|
-
# Handle case where num_tiles is 1
|
|
4759
|
-
if num_tiles == 1:
|
|
4760
|
-
axes = axes.reshape(2, 1)
|
|
4761
|
-
|
|
4762
|
-
for idx, tile_name in enumerate(image_tiles):
|
|
4763
|
-
# Load and display image tile
|
|
4764
|
-
image_path = os.path.join(output_dir, "images", tile_name)
|
|
4765
|
-
with rasterio.open(image_path) as src:
|
|
4766
|
-
show(src, ax=axes[0, idx], title=f"Image {idx+1}")
|
|
4767
|
-
|
|
4768
|
-
# Load and display mask tile
|
|
4769
|
-
mask_path = os.path.join(output_dir, "masks", tile_name)
|
|
4770
|
-
if os.path.exists(mask_path):
|
|
4771
|
-
with rasterio.open(mask_path) as src:
|
|
4772
|
-
show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
|
|
4773
|
-
else:
|
|
4774
|
-
axes[1, idx].text(
|
|
4775
|
-
0.5,
|
|
4776
|
-
0.5,
|
|
4777
|
-
"Mask not found",
|
|
4778
|
-
ha="center",
|
|
4779
|
-
va="center",
|
|
4780
|
-
transform=axes[1, idx].transAxes,
|
|
4781
|
-
)
|
|
4782
|
-
axes[1, idx].set_title(f"Mask {idx+1}")
|
|
4783
|
-
|
|
4784
|
-
plt.tight_layout()
|
|
4785
|
-
|
|
4786
|
-
# Save or show
|
|
4787
|
-
if save_path:
|
|
4788
|
-
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
|
4789
|
-
plt.close(fig)
|
|
4790
|
-
print(f"Figure saved to: {save_path}")
|
|
4791
|
-
else:
|
|
4792
|
-
plt.show()
|
|
4793
|
-
|
|
4794
|
-
return fig, axes
|
|
4795
|
-
|
|
4796
|
-
|
|
4797
|
-
def display_image_with_vector(
|
|
4798
|
-
image_path,
|
|
4799
|
-
vector_path,
|
|
4800
|
-
figsize=(16, 8),
|
|
4801
|
-
vector_color="red",
|
|
4802
|
-
vector_linewidth=1,
|
|
4803
|
-
vector_facecolor="none",
|
|
4804
|
-
save_path=None,
|
|
4805
|
-
):
|
|
4806
|
-
"""
|
|
4807
|
-
Display a raster image alongside the same image with vector overlay.
|
|
4808
|
-
|
|
4809
|
-
Args:
|
|
4810
|
-
image_path (str): Path to raster image file
|
|
4811
|
-
vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
|
|
4812
|
-
figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
|
|
4813
|
-
vector_color (str): Edge color for vector features (default: 'red')
|
|
4814
|
-
vector_linewidth (float): Line width for vector features (default: 1)
|
|
4815
|
-
vector_facecolor (str): Fill color for vector features (default: 'none')
|
|
4816
|
-
save_path (str, optional): If provided, save figure to this path instead of displaying
|
|
4817
|
-
|
|
4818
|
-
Returns:
|
|
4819
|
-
tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata
|
|
4820
|
-
|
|
4821
|
-
Example:
|
|
4822
|
-
>>> fig, axes, info = display_image_with_vector(
|
|
4823
|
-
... 'image.tif',
|
|
4824
|
-
... 'buildings.geojson',
|
|
4825
|
-
... vector_color='blue'
|
|
4826
|
-
... )
|
|
4827
|
-
>>> print(f"Number of features: {info['num_features']}")
|
|
4828
|
-
"""
|
|
4829
|
-
import matplotlib.pyplot as plt
|
|
4830
|
-
|
|
4831
|
-
# Validate inputs
|
|
4832
|
-
if not os.path.exists(image_path):
|
|
4833
|
-
raise ValueError(f"Image file not found: {image_path}")
|
|
4834
|
-
if not os.path.exists(vector_path):
|
|
4835
|
-
raise ValueError(f"Vector file not found: {vector_path}")
|
|
4836
|
-
|
|
4837
|
-
# Create figure
|
|
4838
|
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
|
|
4839
|
-
|
|
4840
|
-
# Load and display image
|
|
4841
|
-
with rasterio.open(image_path) as src:
|
|
4842
|
-
# Plot image only
|
|
4843
|
-
show(src, ax=ax1, title="Image")
|
|
4844
|
-
|
|
4845
|
-
# Load vector data
|
|
4846
|
-
vector_data = gpd.read_file(vector_path)
|
|
4847
|
-
|
|
4848
|
-
# Reproject to image CRS if needed
|
|
4849
|
-
if vector_data.crs != src.crs:
|
|
4850
|
-
vector_data = vector_data.to_crs(src.crs)
|
|
4851
|
-
|
|
4852
|
-
# Plot image with vector overlay
|
|
4853
|
-
show(
|
|
4854
|
-
src,
|
|
4855
|
-
ax=ax2,
|
|
4856
|
-
title=f"Image with {len(vector_data)} Vector Features",
|
|
4857
|
-
)
|
|
4858
|
-
vector_data.plot(
|
|
4859
|
-
ax=ax2,
|
|
4860
|
-
facecolor=vector_facecolor,
|
|
4861
|
-
edgecolor=vector_color,
|
|
4862
|
-
linewidth=vector_linewidth,
|
|
4863
|
-
)
|
|
4864
|
-
|
|
4865
|
-
# Collect metadata
|
|
4866
|
-
info = {
|
|
4867
|
-
"image_shape": src.shape,
|
|
4868
|
-
"image_crs": src.crs,
|
|
4869
|
-
"image_bounds": src.bounds,
|
|
4870
|
-
"num_features": len(vector_data),
|
|
4871
|
-
"vector_crs": vector_data.crs,
|
|
4872
|
-
"vector_bounds": vector_data.total_bounds,
|
|
4873
|
-
}
|
|
4874
|
-
|
|
4875
|
-
plt.tight_layout()
|
|
4876
|
-
|
|
4877
|
-
# Save or show
|
|
4878
|
-
if save_path:
|
|
4879
|
-
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
|
4880
|
-
plt.close(fig)
|
|
4881
|
-
print(f"Figure saved to: {save_path}")
|
|
4882
|
-
else:
|
|
4883
|
-
plt.show()
|
|
4884
|
-
|
|
4885
|
-
return fig, (ax1, ax2), info
|
|
4886
|
-
|
|
4887
|
-
|
|
4888
3630
|
def create_overview_image(
|
|
4889
3631
|
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
|
4890
3632
|
) -> str:
|
|
@@ -8779,39 +7521,17 @@ def write_colormap(
|
|
|
8779
7521
|
|
|
8780
7522
|
def plot_performance_metrics(
|
|
8781
7523
|
history_path: str,
|
|
8782
|
-
figsize:
|
|
7524
|
+
figsize: Tuple[int, int] = (15, 5),
|
|
8783
7525
|
verbose: bool = True,
|
|
8784
7526
|
save_path: Optional[str] = None,
|
|
8785
|
-
csv_path: Optional[str] = None,
|
|
8786
7527
|
kwargs: Optional[Dict] = None,
|
|
8787
|
-
) ->
|
|
8788
|
-
"""Plot performance metrics from a
|
|
8789
|
-
|
|
8790
|
-
This function loads training history, plots available metrics (loss, IoU, F1,
|
|
8791
|
-
precision, recall), optionally exports to CSV, and returns all metrics as a
|
|
8792
|
-
pandas DataFrame for further analysis.
|
|
7528
|
+
) -> None:
|
|
7529
|
+
"""Plot performance metrics from a history object.
|
|
8793
7530
|
|
|
8794
7531
|
Args:
|
|
8795
|
-
history_path
|
|
8796
|
-
figsize
|
|
8797
|
-
|
|
8798
|
-
verbose (bool): Whether to print best and final metric values. Defaults to True.
|
|
8799
|
-
save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
|
|
8800
|
-
csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
|
|
8801
|
-
kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
|
|
8802
|
-
|
|
8803
|
-
Returns:
|
|
8804
|
-
pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
|
|
8805
|
-
Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
|
|
8806
|
-
'val_precision', 'val_recall' (depending on availability in history).
|
|
8807
|
-
|
|
8808
|
-
Example:
|
|
8809
|
-
>>> df = plot_performance_metrics(
|
|
8810
|
-
... 'training_history.pth',
|
|
8811
|
-
... save_path='metrics_plot.png',
|
|
8812
|
-
... csv_path='metrics.csv'
|
|
8813
|
-
... )
|
|
8814
|
-
>>> print(df.head())
|
|
7532
|
+
history_path: The history object to plot.
|
|
7533
|
+
figsize: The figure size.
|
|
7534
|
+
verbose: Whether to print the best and final metrics.
|
|
8815
7535
|
"""
|
|
8816
7536
|
if kwargs is None:
|
|
8817
7537
|
kwargs = {}
|
|
@@ -8821,135 +7541,65 @@ def plot_performance_metrics(
|
|
|
8821
7541
|
train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
|
|
8822
7542
|
val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
|
|
8823
7543
|
val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
|
|
8824
|
-
|
|
8825
|
-
val_f1_key = (
|
|
8826
|
-
"val_f1s"
|
|
8827
|
-
if "val_f1s" in history
|
|
8828
|
-
else ("val_dices" if "val_dices" in history else "val_dice")
|
|
8829
|
-
)
|
|
8830
|
-
# Add support for precision and recall
|
|
8831
|
-
val_precision_key = (
|
|
8832
|
-
"val_precisions" if "val_precisions" in history else "val_precision"
|
|
8833
|
-
)
|
|
8834
|
-
val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
|
|
8835
|
-
|
|
8836
|
-
# Collect available metrics for plotting
|
|
8837
|
-
available_metrics = []
|
|
8838
|
-
metric_info = {
|
|
8839
|
-
"Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
|
|
8840
|
-
"IoU": (val_iou_key, None, ["Val IoU"]),
|
|
8841
|
-
"F1": (val_f1_key, None, ["Val F1"]),
|
|
8842
|
-
"Precision": (val_precision_key, None, ["Val Precision"]),
|
|
8843
|
-
"Recall": (val_recall_key, None, ["Val Recall"]),
|
|
8844
|
-
}
|
|
8845
|
-
|
|
8846
|
-
for metric_name, (key1, key2, labels) in metric_info.items():
|
|
8847
|
-
if key1 in history or (key2 and key2 in history):
|
|
8848
|
-
available_metrics.append((metric_name, key1, key2, labels))
|
|
8849
|
-
|
|
8850
|
-
# Determine number of subplots and figure size
|
|
8851
|
-
n_plots = len(available_metrics)
|
|
8852
|
-
if figsize is None:
|
|
8853
|
-
figsize = (5 * n_plots, 5)
|
|
7544
|
+
val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
|
|
8854
7545
|
|
|
8855
|
-
#
|
|
8856
|
-
|
|
8857
|
-
|
|
7546
|
+
# Determine number of subplots based on available metrics
|
|
7547
|
+
has_dice = val_dice_key in history
|
|
7548
|
+
n_plots = 3 if has_dice else 2
|
|
7549
|
+
figsize = (15, 5) if has_dice else (10, 5)
|
|
8858
7550
|
|
|
8859
|
-
|
|
8860
|
-
if "epochs" in history:
|
|
8861
|
-
df_data["epoch"] = history["epochs"]
|
|
8862
|
-
n_epochs = len(history["epochs"])
|
|
8863
|
-
elif train_loss_key in history:
|
|
8864
|
-
n_epochs = len(history[train_loss_key])
|
|
8865
|
-
df_data["epoch"] = list(range(1, n_epochs + 1))
|
|
7551
|
+
plt.figure(figsize=figsize)
|
|
8866
7552
|
|
|
8867
|
-
#
|
|
7553
|
+
# Plot loss
|
|
7554
|
+
plt.subplot(1, n_plots, 1)
|
|
8868
7555
|
if train_loss_key in history:
|
|
8869
|
-
|
|
7556
|
+
plt.plot(history[train_loss_key], label="Train Loss")
|
|
8870
7557
|
if val_loss_key in history:
|
|
8871
|
-
|
|
7558
|
+
plt.plot(history[val_loss_key], label="Val Loss")
|
|
7559
|
+
plt.title("Loss")
|
|
7560
|
+
plt.xlabel("Epoch")
|
|
7561
|
+
plt.ylabel("Loss")
|
|
7562
|
+
plt.legend()
|
|
7563
|
+
plt.grid(True)
|
|
7564
|
+
|
|
7565
|
+
# Plot IoU
|
|
7566
|
+
plt.subplot(1, n_plots, 2)
|
|
8872
7567
|
if val_iou_key in history:
|
|
8873
|
-
|
|
8874
|
-
|
|
8875
|
-
|
|
8876
|
-
|
|
8877
|
-
|
|
8878
|
-
|
|
8879
|
-
|
|
8880
|
-
|
|
8881
|
-
|
|
8882
|
-
|
|
8883
|
-
|
|
8884
|
-
|
|
8885
|
-
|
|
8886
|
-
|
|
8887
|
-
|
|
8888
|
-
|
|
8889
|
-
|
|
8890
|
-
# Create plots
|
|
8891
|
-
if n_plots > 0:
|
|
8892
|
-
fig, axes = plt.subplots(1, n_plots, figsize=figsize)
|
|
8893
|
-
if n_plots == 1:
|
|
8894
|
-
axes = [axes]
|
|
8895
|
-
|
|
8896
|
-
for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
|
|
8897
|
-
ax = axes[idx]
|
|
8898
|
-
|
|
8899
|
-
if metric_name == "Loss":
|
|
8900
|
-
# Special handling for loss (has both train and val)
|
|
8901
|
-
if key1 in history:
|
|
8902
|
-
ax.plot(history[key1], label=labels[0])
|
|
8903
|
-
if key2 and key2 in history:
|
|
8904
|
-
ax.plot(history[key2], label=labels[1])
|
|
8905
|
-
else:
|
|
8906
|
-
# Single metric plots
|
|
8907
|
-
if key1 in history:
|
|
8908
|
-
ax.plot(history[key1], label=labels[0])
|
|
8909
|
-
|
|
8910
|
-
ax.set_title(metric_name)
|
|
8911
|
-
ax.set_xlabel("Epoch")
|
|
8912
|
-
ax.set_ylabel(metric_name)
|
|
8913
|
-
ax.legend()
|
|
8914
|
-
ax.grid(True)
|
|
7568
|
+
plt.plot(history[val_iou_key], label="Val IoU")
|
|
7569
|
+
plt.title("IoU Score")
|
|
7570
|
+
plt.xlabel("Epoch")
|
|
7571
|
+
plt.ylabel("IoU")
|
|
7572
|
+
plt.legend()
|
|
7573
|
+
plt.grid(True)
|
|
7574
|
+
|
|
7575
|
+
# Plot Dice if available
|
|
7576
|
+
if has_dice:
|
|
7577
|
+
plt.subplot(1, n_plots, 3)
|
|
7578
|
+
plt.plot(history[val_dice_key], label="Val Dice")
|
|
7579
|
+
plt.title("Dice Score")
|
|
7580
|
+
plt.xlabel("Epoch")
|
|
7581
|
+
plt.ylabel("Dice")
|
|
7582
|
+
plt.legend()
|
|
7583
|
+
plt.grid(True)
|
|
8915
7584
|
|
|
8916
|
-
|
|
7585
|
+
plt.tight_layout()
|
|
8917
7586
|
|
|
8918
|
-
|
|
8919
|
-
|
|
8920
|
-
|
|
8921
|
-
|
|
8922
|
-
|
|
8923
|
-
|
|
7587
|
+
if save_path:
|
|
7588
|
+
if "dpi" not in kwargs:
|
|
7589
|
+
kwargs["dpi"] = 150
|
|
7590
|
+
if "bbox_inches" not in kwargs:
|
|
7591
|
+
kwargs["bbox_inches"] = "tight"
|
|
7592
|
+
plt.savefig(save_path, **kwargs)
|
|
8924
7593
|
|
|
8925
|
-
|
|
7594
|
+
plt.show()
|
|
8926
7595
|
|
|
8927
|
-
# Print summary statistics
|
|
8928
7596
|
if verbose:
|
|
8929
|
-
print("\n=== Performance Metrics Summary ===")
|
|
8930
7597
|
if val_iou_key in history:
|
|
8931
|
-
print(
|
|
8932
|
-
|
|
8933
|
-
|
|
8934
|
-
|
|
8935
|
-
print(
|
|
8936
|
-
f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
|
|
8937
|
-
)
|
|
8938
|
-
if val_precision_key in history:
|
|
8939
|
-
print(
|
|
8940
|
-
f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
|
|
8941
|
-
)
|
|
8942
|
-
if val_recall_key in history:
|
|
8943
|
-
print(
|
|
8944
|
-
f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
|
|
8945
|
-
)
|
|
8946
|
-
if val_loss_key in history:
|
|
8947
|
-
print(
|
|
8948
|
-
f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
|
|
8949
|
-
)
|
|
8950
|
-
print("===================================\n")
|
|
8951
|
-
|
|
8952
|
-
return df
|
|
7598
|
+
print(f"Best IoU: {max(history[val_iou_key]):.4f}")
|
|
7599
|
+
print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
|
|
7600
|
+
if val_dice_key in history:
|
|
7601
|
+
print(f"Best Dice: {max(history[val_dice_key]):.4f}")
|
|
7602
|
+
print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
|
|
8953
7603
|
|
|
8954
7604
|
|
|
8955
7605
|
def get_device() -> torch.device:
|