geoai-py 0.18.2__py2.py3-none-any.whl → 0.20.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +23 -1
- geoai/agents/__init__.py +1 -0
- geoai/agents/geo_agents.py +74 -29
- geoai/geoai.py +2 -0
- geoai/landcover_train.py +685 -0
- geoai/landcover_utils.py +383 -0
- geoai/map_widgets.py +556 -0
- geoai/moondream.py +990 -0
- geoai/tools/__init__.py +11 -0
- geoai/tools/sr.py +194 -0
- geoai/utils.py +329 -1881
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/METADATA +3 -1
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/RECORD +17 -13
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.18.2.dist-info → geoai_py-0.20.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
|
@@ -64,7 +64,7 @@ def view_raster(
|
|
|
64
64
|
client_args: Optional[Dict] = {"cors_all": False},
|
|
65
65
|
basemap: Optional[str] = "OpenStreetMap",
|
|
66
66
|
basemap_args: Optional[Dict] = None,
|
|
67
|
-
backend: Optional[str] = "
|
|
67
|
+
backend: Optional[str] = "ipyleaflet",
|
|
68
68
|
**kwargs: Any,
|
|
69
69
|
) -> Any:
|
|
70
70
|
"""
|
|
@@ -87,7 +87,7 @@ def view_raster(
|
|
|
87
87
|
client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
|
|
88
88
|
basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
|
|
89
89
|
basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
|
|
90
|
-
backend (Optional[str], optional): The backend to use. Defaults to "
|
|
90
|
+
backend (Optional[str], optional): The backend to use. Defaults to "ipyleaflet".
|
|
91
91
|
**kwargs (Any): Additional keyword arguments.
|
|
92
92
|
|
|
93
93
|
Returns:
|
|
@@ -123,39 +123,26 @@ def view_raster(
|
|
|
123
123
|
if isinstance(source, dict):
|
|
124
124
|
source = dict_to_image(source)
|
|
125
125
|
|
|
126
|
-
if
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
if colormap is not None:
|
|
147
|
-
kwargs["colormap_name"] = colormap
|
|
148
|
-
if attribution is None:
|
|
149
|
-
attribution = "TiTiler"
|
|
150
|
-
|
|
151
|
-
m.add_cog_layer(
|
|
152
|
-
source,
|
|
153
|
-
name=layer_name,
|
|
154
|
-
opacity=opacity,
|
|
155
|
-
attribution=attribution,
|
|
156
|
-
zoom_to_layer=zoom_to_layer,
|
|
157
|
-
**kwargs,
|
|
158
|
-
)
|
|
126
|
+
if (
|
|
127
|
+
isinstance(source, str)
|
|
128
|
+
and source.lower().endswith(".tif")
|
|
129
|
+
and source.startswith("http")
|
|
130
|
+
):
|
|
131
|
+
if indexes is not None:
|
|
132
|
+
kwargs["bidx"] = indexes
|
|
133
|
+
if colormap is not None:
|
|
134
|
+
kwargs["colormap_name"] = colormap
|
|
135
|
+
if attribution is None:
|
|
136
|
+
attribution = "TiTiler"
|
|
137
|
+
|
|
138
|
+
m.add_cog_layer(
|
|
139
|
+
source,
|
|
140
|
+
name=layer_name,
|
|
141
|
+
opacity=opacity,
|
|
142
|
+
attribution=attribution,
|
|
143
|
+
zoom_to_layer=zoom_to_layer,
|
|
144
|
+
**kwargs,
|
|
145
|
+
)
|
|
159
146
|
else:
|
|
160
147
|
m.add_raster(
|
|
161
148
|
source=source,
|
|
@@ -237,8 +224,6 @@ def view_image(
|
|
|
237
224
|
plt.show()
|
|
238
225
|
plt.close()
|
|
239
226
|
|
|
240
|
-
return ax
|
|
241
|
-
|
|
242
227
|
|
|
243
228
|
def plot_images(
|
|
244
229
|
images: Iterable[torch.Tensor],
|
|
@@ -396,394 +381,6 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
|
|
|
396
381
|
return accum_mean / len(files), accum_std / len(files)
|
|
397
382
|
|
|
398
383
|
|
|
399
|
-
def calc_iou(
|
|
400
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
401
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
402
|
-
num_classes: Optional[int] = None,
|
|
403
|
-
ignore_index: Optional[int] = None,
|
|
404
|
-
smooth: float = 1e-6,
|
|
405
|
-
band: int = 1,
|
|
406
|
-
) -> Union[float, np.ndarray]:
|
|
407
|
-
"""
|
|
408
|
-
Calculate Intersection over Union (IoU) between ground truth and prediction masks.
|
|
409
|
-
|
|
410
|
-
This function computes the IoU metric for segmentation tasks. It supports both
|
|
411
|
-
binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
412
|
-
or file paths to raster files.
|
|
413
|
-
|
|
414
|
-
Args:
|
|
415
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
416
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
417
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
418
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
419
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
420
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
421
|
-
Should have the same shape and format as ground_truth.
|
|
422
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
423
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
424
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
425
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
426
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
427
|
-
Defaults to 1e-6.
|
|
428
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
429
|
-
Only used when input is a file path. Defaults to 1.
|
|
430
|
-
|
|
431
|
-
Returns:
|
|
432
|
-
Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
|
|
433
|
-
For multi-class segmentation, returns an array of IoU scores for each class.
|
|
434
|
-
|
|
435
|
-
Examples:
|
|
436
|
-
>>> # Binary segmentation with arrays
|
|
437
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
438
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
439
|
-
>>> iou = calc_iou(gt, pred)
|
|
440
|
-
>>> print(f"IoU: {iou:.4f}")
|
|
441
|
-
IoU: 0.8333
|
|
442
|
-
|
|
443
|
-
>>> # Multi-class segmentation
|
|
444
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
445
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
446
|
-
>>> iou = calc_iou(gt, pred, num_classes=3)
|
|
447
|
-
>>> print(f"IoU per class: {iou}")
|
|
448
|
-
IoU per class: [0.8333 0.5000 0.5000]
|
|
449
|
-
|
|
450
|
-
>>> # Using PyTorch tensors
|
|
451
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
452
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
453
|
-
>>> iou = calc_iou(gt_tensor, pred_tensor)
|
|
454
|
-
>>> print(f"IoU: {iou:.4f}")
|
|
455
|
-
IoU: 0.8333
|
|
456
|
-
|
|
457
|
-
>>> # Using raster file paths
|
|
458
|
-
>>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
459
|
-
>>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
|
|
460
|
-
Mean IoU: 0.7500
|
|
461
|
-
"""
|
|
462
|
-
# Load from file if string path is provided
|
|
463
|
-
if isinstance(ground_truth, str):
|
|
464
|
-
with rasterio.open(ground_truth) as src:
|
|
465
|
-
ground_truth = src.read(band)
|
|
466
|
-
if isinstance(prediction, str):
|
|
467
|
-
with rasterio.open(prediction) as src:
|
|
468
|
-
prediction = src.read(band)
|
|
469
|
-
|
|
470
|
-
# Convert to numpy if torch tensor
|
|
471
|
-
if isinstance(ground_truth, torch.Tensor):
|
|
472
|
-
ground_truth = ground_truth.cpu().numpy()
|
|
473
|
-
if isinstance(prediction, torch.Tensor):
|
|
474
|
-
prediction = prediction.cpu().numpy()
|
|
475
|
-
|
|
476
|
-
# Ensure inputs have the same shape
|
|
477
|
-
if ground_truth.shape != prediction.shape:
|
|
478
|
-
raise ValueError(
|
|
479
|
-
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
480
|
-
)
|
|
481
|
-
|
|
482
|
-
# Binary segmentation
|
|
483
|
-
if num_classes is None:
|
|
484
|
-
ground_truth = ground_truth.astype(bool)
|
|
485
|
-
prediction = prediction.astype(bool)
|
|
486
|
-
|
|
487
|
-
intersection = np.logical_and(ground_truth, prediction).sum()
|
|
488
|
-
union = np.logical_or(ground_truth, prediction).sum()
|
|
489
|
-
|
|
490
|
-
if union == 0:
|
|
491
|
-
return 1.0 if intersection == 0 else 0.0
|
|
492
|
-
|
|
493
|
-
iou = (intersection + smooth) / (union + smooth)
|
|
494
|
-
return float(iou)
|
|
495
|
-
|
|
496
|
-
# Multi-class segmentation
|
|
497
|
-
else:
|
|
498
|
-
iou_per_class = []
|
|
499
|
-
|
|
500
|
-
for class_idx in range(num_classes):
|
|
501
|
-
# Handle ignored class by appending np.nan
|
|
502
|
-
if ignore_index is not None and class_idx == ignore_index:
|
|
503
|
-
iou_per_class.append(np.nan)
|
|
504
|
-
continue
|
|
505
|
-
|
|
506
|
-
# Create binary masks for current class
|
|
507
|
-
gt_class = (ground_truth == class_idx).astype(bool)
|
|
508
|
-
pred_class = (prediction == class_idx).astype(bool)
|
|
509
|
-
|
|
510
|
-
intersection = np.logical_and(gt_class, pred_class).sum()
|
|
511
|
-
union = np.logical_or(gt_class, pred_class).sum()
|
|
512
|
-
|
|
513
|
-
if union == 0:
|
|
514
|
-
# If class is not present in both gt and pred
|
|
515
|
-
iou_per_class.append(np.nan)
|
|
516
|
-
else:
|
|
517
|
-
iou_per_class.append((intersection + smooth) / (union + smooth))
|
|
518
|
-
|
|
519
|
-
return np.array(iou_per_class)
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
def calc_f1_score(
|
|
523
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
524
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
525
|
-
num_classes: Optional[int] = None,
|
|
526
|
-
ignore_index: Optional[int] = None,
|
|
527
|
-
smooth: float = 1e-6,
|
|
528
|
-
band: int = 1,
|
|
529
|
-
) -> Union[float, np.ndarray]:
|
|
530
|
-
"""
|
|
531
|
-
Calculate F1 score between ground truth and prediction masks.
|
|
532
|
-
|
|
533
|
-
The F1 score is the harmonic mean of precision and recall, computed as:
|
|
534
|
-
F1 = 2 * (precision * recall) / (precision + recall)
|
|
535
|
-
where precision = TP / (TP + FP) and recall = TP / (TP + FN).
|
|
536
|
-
|
|
537
|
-
This function supports both binary and multi-class segmentation, and can handle
|
|
538
|
-
numpy arrays, PyTorch tensors, or file paths to raster files.
|
|
539
|
-
|
|
540
|
-
Args:
|
|
541
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
542
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
543
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
544
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
545
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
546
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
547
|
-
Should have the same shape and format as ground_truth.
|
|
548
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
549
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
550
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
551
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
552
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
553
|
-
Defaults to 1e-6.
|
|
554
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
555
|
-
Only used when input is a file path. Defaults to 1.
|
|
556
|
-
|
|
557
|
-
Returns:
|
|
558
|
-
Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
|
|
559
|
-
For multi-class segmentation, returns an array of F1 scores for each class.
|
|
560
|
-
|
|
561
|
-
Examples:
|
|
562
|
-
>>> # Binary segmentation with arrays
|
|
563
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
564
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
565
|
-
>>> f1 = calc_f1_score(gt, pred)
|
|
566
|
-
>>> print(f"F1 Score: {f1:.4f}")
|
|
567
|
-
F1 Score: 0.8571
|
|
568
|
-
|
|
569
|
-
>>> # Multi-class segmentation
|
|
570
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
571
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
572
|
-
>>> f1 = calc_f1_score(gt, pred, num_classes=3)
|
|
573
|
-
>>> print(f"F1 Score per class: {f1}")
|
|
574
|
-
F1 Score per class: [0.8571 0.6667 0.6667]
|
|
575
|
-
|
|
576
|
-
>>> # Using PyTorch tensors
|
|
577
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
578
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
579
|
-
>>> f1 = calc_f1_score(gt_tensor, pred_tensor)
|
|
580
|
-
>>> print(f"F1 Score: {f1:.4f}")
|
|
581
|
-
F1 Score: 0.8571
|
|
582
|
-
|
|
583
|
-
>>> # Using raster file paths
|
|
584
|
-
>>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
585
|
-
>>> print(f"Mean F1: {np.nanmean(f1):.4f}")
|
|
586
|
-
Mean F1: 0.7302
|
|
587
|
-
"""
|
|
588
|
-
# Load from file if string path is provided
|
|
589
|
-
if isinstance(ground_truth, str):
|
|
590
|
-
with rasterio.open(ground_truth) as src:
|
|
591
|
-
ground_truth = src.read(band)
|
|
592
|
-
if isinstance(prediction, str):
|
|
593
|
-
with rasterio.open(prediction) as src:
|
|
594
|
-
prediction = src.read(band)
|
|
595
|
-
|
|
596
|
-
# Convert to numpy if torch tensor
|
|
597
|
-
if isinstance(ground_truth, torch.Tensor):
|
|
598
|
-
ground_truth = ground_truth.cpu().numpy()
|
|
599
|
-
if isinstance(prediction, torch.Tensor):
|
|
600
|
-
prediction = prediction.cpu().numpy()
|
|
601
|
-
|
|
602
|
-
# Ensure inputs have the same shape
|
|
603
|
-
if ground_truth.shape != prediction.shape:
|
|
604
|
-
raise ValueError(
|
|
605
|
-
f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
|
|
606
|
-
)
|
|
607
|
-
|
|
608
|
-
# Binary segmentation
|
|
609
|
-
if num_classes is None:
|
|
610
|
-
ground_truth = ground_truth.astype(bool)
|
|
611
|
-
prediction = prediction.astype(bool)
|
|
612
|
-
|
|
613
|
-
# Calculate True Positives, False Positives, False Negatives
|
|
614
|
-
tp = np.logical_and(ground_truth, prediction).sum()
|
|
615
|
-
fp = np.logical_and(~ground_truth, prediction).sum()
|
|
616
|
-
fn = np.logical_and(ground_truth, ~prediction).sum()
|
|
617
|
-
|
|
618
|
-
# Calculate precision and recall
|
|
619
|
-
precision = (tp + smooth) / (tp + fp + smooth)
|
|
620
|
-
recall = (tp + smooth) / (tp + fn + smooth)
|
|
621
|
-
|
|
622
|
-
# Calculate F1 score
|
|
623
|
-
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
624
|
-
return float(f1)
|
|
625
|
-
|
|
626
|
-
# Multi-class segmentation
|
|
627
|
-
else:
|
|
628
|
-
f1_per_class = []
|
|
629
|
-
|
|
630
|
-
for class_idx in range(num_classes):
|
|
631
|
-
# Mark ignored class with np.nan
|
|
632
|
-
if ignore_index is not None and class_idx == ignore_index:
|
|
633
|
-
f1_per_class.append(np.nan)
|
|
634
|
-
continue
|
|
635
|
-
|
|
636
|
-
# Create binary masks for current class
|
|
637
|
-
gt_class = (ground_truth == class_idx).astype(bool)
|
|
638
|
-
pred_class = (prediction == class_idx).astype(bool)
|
|
639
|
-
|
|
640
|
-
# Calculate True Positives, False Positives, False Negatives
|
|
641
|
-
tp = np.logical_and(gt_class, pred_class).sum()
|
|
642
|
-
fp = np.logical_and(~gt_class, pred_class).sum()
|
|
643
|
-
fn = np.logical_and(gt_class, ~pred_class).sum()
|
|
644
|
-
|
|
645
|
-
# Calculate precision and recall
|
|
646
|
-
precision = (tp + smooth) / (tp + fp + smooth)
|
|
647
|
-
recall = (tp + smooth) / (tp + fn + smooth)
|
|
648
|
-
|
|
649
|
-
# Calculate F1 score
|
|
650
|
-
if tp + fp + fn == 0:
|
|
651
|
-
# If class is not present in both gt and pred
|
|
652
|
-
f1_per_class.append(np.nan)
|
|
653
|
-
else:
|
|
654
|
-
f1 = 2 * (precision * recall) / (precision + recall + smooth)
|
|
655
|
-
f1_per_class.append(f1)
|
|
656
|
-
|
|
657
|
-
return np.array(f1_per_class)
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
def calc_segmentation_metrics(
|
|
661
|
-
ground_truth: Union[str, np.ndarray, torch.Tensor],
|
|
662
|
-
prediction: Union[str, np.ndarray, torch.Tensor],
|
|
663
|
-
num_classes: Optional[int] = None,
|
|
664
|
-
ignore_index: Optional[int] = None,
|
|
665
|
-
smooth: float = 1e-6,
|
|
666
|
-
metrics: List[str] = ["iou", "f1"],
|
|
667
|
-
band: int = 1,
|
|
668
|
-
) -> Dict[str, Union[float, np.ndarray]]:
|
|
669
|
-
"""
|
|
670
|
-
Calculate multiple segmentation metrics between ground truth and prediction masks.
|
|
671
|
-
|
|
672
|
-
This is a convenient wrapper function that computes multiple metrics at once,
|
|
673
|
-
including IoU (Intersection over Union) and F1 score. It supports both binary
|
|
674
|
-
and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
|
|
675
|
-
or file paths to raster files.
|
|
676
|
-
|
|
677
|
-
Args:
|
|
678
|
-
ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
|
|
679
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
680
|
-
For binary segmentation: shape (H, W) with values {0, 1}.
|
|
681
|
-
For multi-class segmentation: shape (H, W) with class indices.
|
|
682
|
-
prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
|
|
683
|
-
Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
|
|
684
|
-
Should have the same shape and format as ground_truth.
|
|
685
|
-
num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
|
|
686
|
-
If None, assumes binary segmentation. Defaults to None.
|
|
687
|
-
ignore_index (Optional[int], optional): Class index to ignore in computation.
|
|
688
|
-
Useful for ignoring background or unlabeled pixels. Defaults to None.
|
|
689
|
-
smooth (float, optional): Smoothing factor to avoid division by zero.
|
|
690
|
-
Defaults to 1e-6.
|
|
691
|
-
metrics (List[str], optional): List of metrics to calculate.
|
|
692
|
-
Options: "iou", "f1". Defaults to ["iou", "f1"].
|
|
693
|
-
band (int, optional): Band index to read from raster file (1-based indexing).
|
|
694
|
-
Only used when input is a file path. Defaults to 1.
|
|
695
|
-
|
|
696
|
-
Returns:
|
|
697
|
-
Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
|
|
698
|
-
Keys are metric names ("iou", "f1"), values are the metric scores.
|
|
699
|
-
For binary segmentation, values are floats.
|
|
700
|
-
For multi-class segmentation, values are numpy arrays with per-class scores.
|
|
701
|
-
Also includes "mean_iou" and "mean_f1" for multi-class segmentation
|
|
702
|
-
(mean computed over valid classes, ignoring NaN values).
|
|
703
|
-
|
|
704
|
-
Examples:
|
|
705
|
-
>>> # Binary segmentation with arrays
|
|
706
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
707
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
708
|
-
>>> metrics = calc_segmentation_metrics(gt, pred)
|
|
709
|
-
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
710
|
-
IoU: 0.8333, F1: 0.8571
|
|
711
|
-
|
|
712
|
-
>>> # Multi-class segmentation
|
|
713
|
-
>>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
|
|
714
|
-
>>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
|
|
715
|
-
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
|
|
716
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
717
|
-
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
718
|
-
>>> print(f"Per-class IoU: {metrics['iou']}")
|
|
719
|
-
Mean IoU: 0.6111
|
|
720
|
-
Mean F1: 0.7302
|
|
721
|
-
Per-class IoU: [0.8333 0.5000 0.5000]
|
|
722
|
-
|
|
723
|
-
>>> # Calculate only IoU
|
|
724
|
-
>>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
|
|
725
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
726
|
-
Mean IoU: 0.6111
|
|
727
|
-
|
|
728
|
-
>>> # Using PyTorch tensors
|
|
729
|
-
>>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
|
|
730
|
-
>>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
|
|
731
|
-
>>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
|
|
732
|
-
>>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
|
|
733
|
-
IoU: 0.8333, F1: 0.8571
|
|
734
|
-
|
|
735
|
-
>>> # Using raster file paths
|
|
736
|
-
>>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
|
|
737
|
-
>>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
|
|
738
|
-
>>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
|
|
739
|
-
Mean IoU: 0.6111
|
|
740
|
-
Mean F1: 0.7302
|
|
741
|
-
"""
|
|
742
|
-
results = {}
|
|
743
|
-
|
|
744
|
-
# Calculate IoU if requested
|
|
745
|
-
if "iou" in metrics:
|
|
746
|
-
iou = calc_iou(
|
|
747
|
-
ground_truth,
|
|
748
|
-
prediction,
|
|
749
|
-
num_classes=num_classes,
|
|
750
|
-
ignore_index=ignore_index,
|
|
751
|
-
smooth=smooth,
|
|
752
|
-
band=band,
|
|
753
|
-
)
|
|
754
|
-
results["iou"] = iou
|
|
755
|
-
|
|
756
|
-
# Add mean IoU for multi-class
|
|
757
|
-
if num_classes is not None and isinstance(iou, np.ndarray):
|
|
758
|
-
# Calculate mean ignoring NaN values
|
|
759
|
-
valid_ious = iou[~np.isnan(iou)]
|
|
760
|
-
results["mean_iou"] = (
|
|
761
|
-
float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
|
|
762
|
-
)
|
|
763
|
-
|
|
764
|
-
# Calculate F1 score if requested
|
|
765
|
-
if "f1" in metrics:
|
|
766
|
-
f1 = calc_f1_score(
|
|
767
|
-
ground_truth,
|
|
768
|
-
prediction,
|
|
769
|
-
num_classes=num_classes,
|
|
770
|
-
ignore_index=ignore_index,
|
|
771
|
-
smooth=smooth,
|
|
772
|
-
band=band,
|
|
773
|
-
)
|
|
774
|
-
results["f1"] = f1
|
|
775
|
-
|
|
776
|
-
# Add mean F1 for multi-class
|
|
777
|
-
if num_classes is not None and isinstance(f1, np.ndarray):
|
|
778
|
-
# Calculate mean ignoring NaN values
|
|
779
|
-
valid_f1s = f1[~np.isnan(f1)]
|
|
780
|
-
results["mean_f1"] = (
|
|
781
|
-
float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
|
|
782
|
-
)
|
|
783
|
-
|
|
784
|
-
return results
|
|
785
|
-
|
|
786
|
-
|
|
787
384
|
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
|
|
788
385
|
"""Convert a dictionary to a xarray DataArray. The dictionary should contain the
|
|
789
386
|
following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
|
|
@@ -1094,9 +691,8 @@ def view_vector(
|
|
|
1094
691
|
|
|
1095
692
|
def view_vector_interactive(
|
|
1096
693
|
vector_data: Union[str, gpd.GeoDataFrame],
|
|
1097
|
-
layer_name: str = "Vector",
|
|
694
|
+
layer_name: str = "Vector Layer",
|
|
1098
695
|
tiles_args: Optional[Dict] = None,
|
|
1099
|
-
opacity: float = 0.7,
|
|
1100
696
|
**kwargs: Any,
|
|
1101
697
|
) -> Any:
|
|
1102
698
|
"""
|
|
@@ -1111,7 +707,6 @@ def view_vector_interactive(
|
|
|
1111
707
|
layer_name (str, optional): The name of the layer. Defaults to "Vector Layer".
|
|
1112
708
|
tiles_args (dict, optional): Additional arguments for the localtileserver client.
|
|
1113
709
|
get_folium_tile_layer function. Defaults to None.
|
|
1114
|
-
opacity (float, optional): The opacity of the layer. Defaults to 0.7.
|
|
1115
710
|
**kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
|
|
1116
711
|
See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
|
|
1117
712
|
|
|
@@ -1126,8 +721,9 @@ def view_vector_interactive(
|
|
|
1126
721
|
>>> roads = gpd.read_file("roads.shp")
|
|
1127
722
|
>>> view_vector_interactive(roads, figsize=(12, 8))
|
|
1128
723
|
"""
|
|
1129
|
-
|
|
1130
|
-
|
|
724
|
+
import folium
|
|
725
|
+
import folium.plugins as plugins
|
|
726
|
+
from leafmap import cog_tile
|
|
1131
727
|
from localtileserver import TileClient, get_folium_tile_layer
|
|
1132
728
|
|
|
1133
729
|
google_tiles = {
|
|
@@ -1153,17 +749,9 @@ def view_vector_interactive(
|
|
|
1153
749
|
},
|
|
1154
750
|
}
|
|
1155
751
|
|
|
1156
|
-
# Make it compatible with binder and JupyterHub
|
|
1157
|
-
if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
|
|
1158
|
-
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
|
|
1159
|
-
f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
|
|
1160
|
-
)
|
|
1161
|
-
|
|
1162
752
|
basemap_layer_name = None
|
|
1163
753
|
raster_layer = None
|
|
1164
754
|
|
|
1165
|
-
m = Map()
|
|
1166
|
-
|
|
1167
755
|
if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
|
|
1168
756
|
if kwargs["tiles"].title() in google_tiles:
|
|
1169
757
|
basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
|
|
@@ -1174,17 +762,14 @@ def view_vector_interactive(
|
|
|
1174
762
|
tiles_args = {}
|
|
1175
763
|
if kwargs["tiles"].lower().startswith("http"):
|
|
1176
764
|
basemap_layer_name = "Remote Raster"
|
|
1177
|
-
|
|
765
|
+
kwargs["tiles"] = cog_tile(kwargs["tiles"], **tiles_args)
|
|
766
|
+
kwargs["attr"] = "TiTiler"
|
|
1178
767
|
else:
|
|
1179
768
|
basemap_layer_name = "Local Raster"
|
|
1180
769
|
client = TileClient(kwargs["tiles"])
|
|
1181
770
|
raster_layer = get_folium_tile_layer(client, **tiles_args)
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
name=basemap_layer_name,
|
|
1185
|
-
attribution="localtileserver",
|
|
1186
|
-
**tiles_args,
|
|
1187
|
-
)
|
|
771
|
+
kwargs["tiles"] = raster_layer.tiles
|
|
772
|
+
kwargs["attr"] = "localtileserver"
|
|
1188
773
|
|
|
1189
774
|
if "max_zoom" not in kwargs:
|
|
1190
775
|
kwargs["max_zoom"] = 30
|
|
@@ -1199,18 +784,23 @@ def view_vector_interactive(
|
|
|
1199
784
|
if not isinstance(vector_data, gpd.GeoDataFrame):
|
|
1200
785
|
raise TypeError("Input data must be a GeoDataFrame")
|
|
1201
786
|
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
kwargs["legend_position"] = "bottomleft"
|
|
1205
|
-
if "cmap" not in kwargs:
|
|
1206
|
-
kwargs["cmap"] = "viridis"
|
|
1207
|
-
m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
|
|
787
|
+
layer_control = kwargs.pop("layer_control", True)
|
|
788
|
+
fullscreen_control = kwargs.pop("fullscreen_control", True)
|
|
1208
789
|
|
|
1209
|
-
|
|
1210
|
-
|
|
790
|
+
m = vector_data.explore(**kwargs)
|
|
791
|
+
|
|
792
|
+
# Change the layer name
|
|
793
|
+
for layer in m._children.values():
|
|
794
|
+
if isinstance(layer, folium.GeoJson):
|
|
795
|
+
layer.layer_name = layer_name
|
|
796
|
+
if isinstance(layer, folium.TileLayer) and basemap_layer_name:
|
|
797
|
+
layer.layer_name = basemap_layer_name
|
|
798
|
+
|
|
799
|
+
if layer_control:
|
|
800
|
+
m.add_child(folium.LayerControl())
|
|
1211
801
|
|
|
1212
|
-
|
|
1213
|
-
|
|
802
|
+
if fullscreen_control:
|
|
803
|
+
plugins.Fullscreen().add_to(m)
|
|
1214
804
|
|
|
1215
805
|
return m
|
|
1216
806
|
|
|
@@ -3004,86 +2594,10 @@ def batch_vector_to_raster(
|
|
|
3004
2594
|
return output_files
|
|
3005
2595
|
|
|
3006
2596
|
|
|
3007
|
-
def get_default_augmentation_transforms(
|
|
3008
|
-
tile_size: int = 256,
|
|
3009
|
-
include_normalize: bool = False,
|
|
3010
|
-
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
|
|
3011
|
-
std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
|
|
3012
|
-
) -> Any:
|
|
3013
|
-
"""
|
|
3014
|
-
Get default data augmentation transforms for geospatial imagery using albumentations.
|
|
3015
|
-
|
|
3016
|
-
This function returns a composition of augmentation transforms commonly used
|
|
3017
|
-
for remote sensing and geospatial data. The transforms include geometric
|
|
3018
|
-
transformations (flips, rotations) and photometric adjustments (brightness,
|
|
3019
|
-
contrast, saturation).
|
|
3020
|
-
|
|
3021
|
-
Args:
|
|
3022
|
-
tile_size (int): Target size for tiles. Defaults to 256.
|
|
3023
|
-
include_normalize (bool): Whether to include normalization transform.
|
|
3024
|
-
Defaults to False. Set to True if using for training with pretrained models.
|
|
3025
|
-
mean (tuple): Mean values for normalization (RGB). Defaults to ImageNet values.
|
|
3026
|
-
std (tuple): Standard deviation for normalization (RGB). Defaults to ImageNet values.
|
|
3027
|
-
|
|
3028
|
-
Returns:
|
|
3029
|
-
albumentations.Compose: A composition of augmentation transforms.
|
|
3030
|
-
|
|
3031
|
-
Example:
|
|
3032
|
-
>>> import albumentations as A
|
|
3033
|
-
>>> # Get default transforms
|
|
3034
|
-
>>> transform = get_default_augmentation_transforms()
|
|
3035
|
-
>>> # Apply to image and mask
|
|
3036
|
-
>>> augmented = transform(image=image, mask=mask)
|
|
3037
|
-
>>> aug_image = augmented['image']
|
|
3038
|
-
>>> aug_mask = augmented['mask']
|
|
3039
|
-
"""
|
|
3040
|
-
try:
|
|
3041
|
-
import albumentations as A
|
|
3042
|
-
except ImportError:
|
|
3043
|
-
raise ImportError(
|
|
3044
|
-
"albumentations is required for data augmentation. "
|
|
3045
|
-
"Install it with: pip install albumentations"
|
|
3046
|
-
)
|
|
3047
|
-
|
|
3048
|
-
transforms_list = [
|
|
3049
|
-
# Geometric transforms
|
|
3050
|
-
A.HorizontalFlip(p=0.5),
|
|
3051
|
-
A.VerticalFlip(p=0.5),
|
|
3052
|
-
A.RandomRotate90(p=0.5),
|
|
3053
|
-
A.ShiftScaleRotate(
|
|
3054
|
-
shift_limit=0.1,
|
|
3055
|
-
scale_limit=0.1,
|
|
3056
|
-
rotate_limit=45,
|
|
3057
|
-
border_mode=0,
|
|
3058
|
-
p=0.5,
|
|
3059
|
-
),
|
|
3060
|
-
# Photometric transforms
|
|
3061
|
-
A.RandomBrightnessContrast(
|
|
3062
|
-
brightness_limit=0.2,
|
|
3063
|
-
contrast_limit=0.2,
|
|
3064
|
-
p=0.5,
|
|
3065
|
-
),
|
|
3066
|
-
A.HueSaturationValue(
|
|
3067
|
-
hue_shift_limit=10,
|
|
3068
|
-
sat_shift_limit=20,
|
|
3069
|
-
val_shift_limit=10,
|
|
3070
|
-
p=0.3,
|
|
3071
|
-
),
|
|
3072
|
-
A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
|
|
3073
|
-
A.GaussianBlur(blur_limit=(3, 5), p=0.2),
|
|
3074
|
-
]
|
|
3075
|
-
|
|
3076
|
-
# Add normalization if requested
|
|
3077
|
-
if include_normalize:
|
|
3078
|
-
transforms_list.append(A.Normalize(mean=mean, std=std))
|
|
3079
|
-
|
|
3080
|
-
return A.Compose(transforms_list)
|
|
3081
|
-
|
|
3082
|
-
|
|
3083
2597
|
def export_geotiff_tiles(
|
|
3084
2598
|
in_raster,
|
|
3085
2599
|
out_folder,
|
|
3086
|
-
in_class_data
|
|
2600
|
+
in_class_data,
|
|
3087
2601
|
tile_size=256,
|
|
3088
2602
|
stride=128,
|
|
3089
2603
|
class_value_field="class",
|
|
@@ -3093,10 +2607,6 @@ def export_geotiff_tiles(
|
|
|
3093
2607
|
all_touched=True,
|
|
3094
2608
|
create_overview=False,
|
|
3095
2609
|
skip_empty_tiles=False,
|
|
3096
|
-
metadata_format="PASCAL_VOC",
|
|
3097
|
-
apply_augmentation=False,
|
|
3098
|
-
augmentation_count=3,
|
|
3099
|
-
augmentation_transforms=None,
|
|
3100
2610
|
):
|
|
3101
2611
|
"""
|
|
3102
2612
|
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
|
@@ -3104,8 +2614,7 @@ def export_geotiff_tiles(
|
|
|
3104
2614
|
Args:
|
|
3105
2615
|
in_raster (str): Path to input raster image
|
|
3106
2616
|
out_folder (str): Path to output folder
|
|
3107
|
-
in_class_data (str
|
|
3108
|
-
If None, only image tiles will be exported without labels. Defaults to None.
|
|
2617
|
+
in_class_data (str): Path to classification data - can be vector file or raster
|
|
3109
2618
|
tile_size (int): Size of tiles in pixels (square)
|
|
3110
2619
|
stride (int): Step size between tiles
|
|
3111
2620
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -3115,95 +2624,38 @@ def export_geotiff_tiles(
|
|
|
3115
2624
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
3116
2625
|
create_overview (bool): Whether to create an overview image of all tiles
|
|
3117
2626
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3118
|
-
metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
|
|
3119
|
-
apply_augmentation (bool): If True, generate augmented versions of each tile.
|
|
3120
|
-
This will create multiple variants of each tile using data augmentation techniques.
|
|
3121
|
-
Defaults to False.
|
|
3122
|
-
augmentation_count (int): Number of augmented versions to generate per tile
|
|
3123
|
-
(only used if apply_augmentation=True). Defaults to 3.
|
|
3124
|
-
augmentation_transforms (albumentations.Compose, optional): Custom augmentation transforms.
|
|
3125
|
-
If None and apply_augmentation=True, uses default transforms from
|
|
3126
|
-
get_default_augmentation_transforms(). Should be an albumentations.Compose object.
|
|
3127
|
-
Defaults to None.
|
|
3128
|
-
|
|
3129
|
-
Returns:
|
|
3130
|
-
None: Tiles and labels are saved to out_folder.
|
|
3131
|
-
|
|
3132
|
-
Example:
|
|
3133
|
-
>>> # Export tiles without augmentation
|
|
3134
|
-
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif')
|
|
3135
|
-
>>>
|
|
3136
|
-
>>> # Export tiles with default augmentation (3 augmented versions per tile)
|
|
3137
|
-
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
|
|
3138
|
-
... apply_augmentation=True)
|
|
3139
|
-
>>>
|
|
3140
|
-
>>> # Export with custom augmentation
|
|
3141
|
-
>>> import albumentations as A
|
|
3142
|
-
>>> custom_transform = A.Compose([
|
|
3143
|
-
... A.HorizontalFlip(p=0.5),
|
|
3144
|
-
... A.RandomBrightnessContrast(p=0.5),
|
|
3145
|
-
... ])
|
|
3146
|
-
>>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
|
|
3147
|
-
... apply_augmentation=True,
|
|
3148
|
-
... augmentation_count=5,
|
|
3149
|
-
... augmentation_transforms=custom_transform)
|
|
3150
2627
|
"""
|
|
3151
2628
|
|
|
3152
2629
|
import logging
|
|
3153
2630
|
|
|
3154
2631
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
|
3155
2632
|
|
|
3156
|
-
# Initialize augmentation transforms if needed
|
|
3157
|
-
if apply_augmentation:
|
|
3158
|
-
if augmentation_transforms is None:
|
|
3159
|
-
augmentation_transforms = get_default_augmentation_transforms(
|
|
3160
|
-
tile_size=tile_size
|
|
3161
|
-
)
|
|
3162
|
-
if not quiet:
|
|
3163
|
-
print(
|
|
3164
|
-
f"Data augmentation enabled: generating {augmentation_count} augmented versions per tile"
|
|
3165
|
-
)
|
|
3166
|
-
|
|
3167
2633
|
# Create output directories
|
|
3168
2634
|
os.makedirs(out_folder, exist_ok=True)
|
|
3169
2635
|
image_dir = os.path.join(out_folder, "images")
|
|
3170
2636
|
os.makedirs(image_dir, exist_ok=True)
|
|
2637
|
+
label_dir = os.path.join(out_folder, "labels")
|
|
2638
|
+
os.makedirs(label_dir, exist_ok=True)
|
|
2639
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
|
2640
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
3171
2641
|
|
|
3172
|
-
#
|
|
3173
|
-
if in_class_data is not None:
|
|
3174
|
-
label_dir = os.path.join(out_folder, "labels")
|
|
3175
|
-
os.makedirs(label_dir, exist_ok=True)
|
|
3176
|
-
|
|
3177
|
-
# Create annotation directory based on metadata format
|
|
3178
|
-
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
|
3179
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
3180
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
3181
|
-
|
|
3182
|
-
# Initialize COCO annotations dictionary
|
|
3183
|
-
if metadata_format == "COCO":
|
|
3184
|
-
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
3185
|
-
ann_id = 0
|
|
3186
|
-
|
|
3187
|
-
# Determine if class data is raster or vector (only if class data provided)
|
|
2642
|
+
# Determine if class data is raster or vector
|
|
3188
2643
|
is_class_data_raster = False
|
|
3189
|
-
if in_class_data
|
|
3190
|
-
|
|
3191
|
-
|
|
3192
|
-
|
|
3193
|
-
|
|
3194
|
-
|
|
3195
|
-
|
|
3196
|
-
is_class_data_raster = True
|
|
3197
|
-
if not quiet:
|
|
3198
|
-
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
3199
|
-
print(f"Raster CRS: {src.crs}")
|
|
3200
|
-
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
3201
|
-
except Exception:
|
|
3202
|
-
is_class_data_raster = False
|
|
2644
|
+
if isinstance(in_class_data, str):
|
|
2645
|
+
file_ext = Path(in_class_data).suffix.lower()
|
|
2646
|
+
# Common raster extensions
|
|
2647
|
+
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
2648
|
+
try:
|
|
2649
|
+
with rasterio.open(in_class_data) as src:
|
|
2650
|
+
is_class_data_raster = True
|
|
3203
2651
|
if not quiet:
|
|
3204
|
-
print(
|
|
3205
|
-
|
|
3206
|
-
)
|
|
2652
|
+
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
2653
|
+
print(f"Raster CRS: {src.crs}")
|
|
2654
|
+
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
2655
|
+
except Exception:
|
|
2656
|
+
is_class_data_raster = False
|
|
2657
|
+
if not quiet:
|
|
2658
|
+
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
|
3207
2659
|
|
|
3208
2660
|
# Open the input raster
|
|
3209
2661
|
with rasterio.open(in_raster) as src:
|
|
@@ -3223,10 +2675,10 @@ def export_geotiff_tiles(
|
|
|
3223
2675
|
if max_tiles is None:
|
|
3224
2676
|
max_tiles = total_tiles
|
|
3225
2677
|
|
|
3226
|
-
# Process classification data
|
|
2678
|
+
# Process classification data
|
|
3227
2679
|
class_to_id = {}
|
|
3228
2680
|
|
|
3229
|
-
if
|
|
2681
|
+
if is_class_data_raster:
|
|
3230
2682
|
# Load raster class data
|
|
3231
2683
|
with rasterio.open(in_class_data) as class_src:
|
|
3232
2684
|
# Check if raster CRS matches
|
|
@@ -3259,18 +2711,7 @@ def export_geotiff_tiles(
|
|
|
3259
2711
|
|
|
3260
2712
|
# Create class mapping
|
|
3261
2713
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
3262
|
-
|
|
3263
|
-
# Populate COCO categories
|
|
3264
|
-
if metadata_format == "COCO":
|
|
3265
|
-
for cls_val in unique_classes:
|
|
3266
|
-
coco_annotations["categories"].append(
|
|
3267
|
-
{
|
|
3268
|
-
"id": class_to_id[int(cls_val)],
|
|
3269
|
-
"name": str(int(cls_val)),
|
|
3270
|
-
"supercategory": "object",
|
|
3271
|
-
}
|
|
3272
|
-
)
|
|
3273
|
-
elif in_class_data is not None:
|
|
2714
|
+
else:
|
|
3274
2715
|
# Load vector class data
|
|
3275
2716
|
try:
|
|
3276
2717
|
gdf = gpd.read_file(in_class_data)
|
|
@@ -3299,33 +2740,12 @@ def export_geotiff_tiles(
|
|
|
3299
2740
|
)
|
|
3300
2741
|
# Create class mapping
|
|
3301
2742
|
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
3302
|
-
|
|
3303
|
-
# Populate COCO categories
|
|
3304
|
-
if metadata_format == "COCO":
|
|
3305
|
-
for cls_val in unique_classes:
|
|
3306
|
-
coco_annotations["categories"].append(
|
|
3307
|
-
{
|
|
3308
|
-
"id": class_to_id[cls_val],
|
|
3309
|
-
"name": str(cls_val),
|
|
3310
|
-
"supercategory": "object",
|
|
3311
|
-
}
|
|
3312
|
-
)
|
|
3313
2743
|
else:
|
|
3314
2744
|
if not quiet:
|
|
3315
2745
|
print(
|
|
3316
2746
|
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
|
3317
2747
|
)
|
|
3318
2748
|
class_to_id = {1: 1} # Default mapping
|
|
3319
|
-
|
|
3320
|
-
# Populate COCO categories with default
|
|
3321
|
-
if metadata_format == "COCO":
|
|
3322
|
-
coco_annotations["categories"].append(
|
|
3323
|
-
{
|
|
3324
|
-
"id": 1,
|
|
3325
|
-
"name": "object",
|
|
3326
|
-
"supercategory": "object",
|
|
3327
|
-
}
|
|
3328
|
-
)
|
|
3329
2749
|
except Exception as e:
|
|
3330
2750
|
raise ValueError(f"Error processing vector data: {e}")
|
|
3331
2751
|
|
|
@@ -3392,8 +2812,8 @@ def export_geotiff_tiles(
|
|
|
3392
2812
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
3393
2813
|
has_features = False
|
|
3394
2814
|
|
|
3395
|
-
# Process classification data to create labels
|
|
3396
|
-
if
|
|
2815
|
+
# Process classification data to create labels
|
|
2816
|
+
if is_class_data_raster:
|
|
3397
2817
|
# For raster class data
|
|
3398
2818
|
with rasterio.open(in_class_data) as class_src:
|
|
3399
2819
|
# Calculate window in class raster
|
|
@@ -3443,7 +2863,7 @@ def export_geotiff_tiles(
|
|
|
3443
2863
|
except Exception as e:
|
|
3444
2864
|
pbar.write(f"Error reading class raster window: {e}")
|
|
3445
2865
|
stats["errors"] += 1
|
|
3446
|
-
|
|
2866
|
+
else:
|
|
3447
2867
|
# For vector class data
|
|
3448
2868
|
# Find features that intersect with window
|
|
3449
2869
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -3486,8 +2906,8 @@ def export_geotiff_tiles(
|
|
|
3486
2906
|
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
3487
2907
|
stats["errors"] += 1
|
|
3488
2908
|
|
|
3489
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
3490
|
-
if
|
|
2909
|
+
# Skip tile if no features and skip_empty_tiles is True
|
|
2910
|
+
if skip_empty_tiles and not has_features:
|
|
3491
2911
|
pbar.update(1)
|
|
3492
2912
|
tile_index += 1
|
|
3493
2913
|
continue
|
|
@@ -3495,316 +2915,119 @@ def export_geotiff_tiles(
|
|
|
3495
2915
|
# Read image data
|
|
3496
2916
|
image_data = src.read(window=window)
|
|
3497
2917
|
|
|
3498
|
-
#
|
|
3499
|
-
|
|
3500
|
-
img_data,
|
|
3501
|
-
lbl_mask,
|
|
3502
|
-
tile_id,
|
|
3503
|
-
img_profile,
|
|
3504
|
-
window_trans,
|
|
3505
|
-
is_augmented=False,
|
|
3506
|
-
):
|
|
3507
|
-
"""Save a single image and label tile."""
|
|
3508
|
-
# Export image as GeoTIFF
|
|
3509
|
-
image_path = os.path.join(image_dir, f"tile_{tile_id:06d}.tif")
|
|
3510
|
-
|
|
3511
|
-
# Update profile
|
|
3512
|
-
img_profile_copy = img_profile.copy()
|
|
3513
|
-
img_profile_copy.update(
|
|
3514
|
-
{
|
|
3515
|
-
"height": tile_size,
|
|
3516
|
-
"width": tile_size,
|
|
3517
|
-
"count": img_data.shape[0],
|
|
3518
|
-
"transform": window_trans,
|
|
3519
|
-
}
|
|
3520
|
-
)
|
|
3521
|
-
|
|
3522
|
-
# Save image as GeoTIFF
|
|
3523
|
-
try:
|
|
3524
|
-
with rasterio.open(image_path, "w", **img_profile_copy) as dst:
|
|
3525
|
-
dst.write(img_data)
|
|
3526
|
-
stats["total_tiles"] += 1
|
|
3527
|
-
except Exception as e:
|
|
3528
|
-
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
|
3529
|
-
stats["errors"] += 1
|
|
3530
|
-
return
|
|
3531
|
-
|
|
3532
|
-
# Export label as GeoTIFF (only if class data provided)
|
|
3533
|
-
if in_class_data is not None:
|
|
3534
|
-
# Create profile for label GeoTIFF
|
|
3535
|
-
label_profile = {
|
|
3536
|
-
"driver": "GTiff",
|
|
3537
|
-
"height": tile_size,
|
|
3538
|
-
"width": tile_size,
|
|
3539
|
-
"count": 1,
|
|
3540
|
-
"dtype": "uint8",
|
|
3541
|
-
"crs": src.crs,
|
|
3542
|
-
"transform": window_trans,
|
|
3543
|
-
}
|
|
3544
|
-
|
|
3545
|
-
label_path = os.path.join(label_dir, f"tile_{tile_id:06d}.tif")
|
|
3546
|
-
try:
|
|
3547
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
3548
|
-
dst.write(lbl_mask.astype(np.uint8), 1)
|
|
3549
|
-
|
|
3550
|
-
if not is_augmented and np.any(lbl_mask > 0):
|
|
3551
|
-
stats["tiles_with_features"] += 1
|
|
3552
|
-
stats["feature_pixels"] += np.count_nonzero(lbl_mask)
|
|
3553
|
-
except Exception as e:
|
|
3554
|
-
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
|
3555
|
-
stats["errors"] += 1
|
|
2918
|
+
# Export image as GeoTIFF
|
|
2919
|
+
image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
|
|
3556
2920
|
|
|
3557
|
-
#
|
|
3558
|
-
|
|
3559
|
-
|
|
3560
|
-
|
|
3561
|
-
|
|
3562
|
-
|
|
3563
|
-
|
|
3564
|
-
|
|
2921
|
+
# Create profile for image GeoTIFF
|
|
2922
|
+
image_profile = src.profile.copy()
|
|
2923
|
+
image_profile.update(
|
|
2924
|
+
{
|
|
2925
|
+
"height": tile_size,
|
|
2926
|
+
"width": tile_size,
|
|
2927
|
+
"count": image_data.shape[0],
|
|
2928
|
+
"transform": window_transform,
|
|
2929
|
+
}
|
|
3565
2930
|
)
|
|
3566
2931
|
|
|
3567
|
-
#
|
|
3568
|
-
|
|
3569
|
-
|
|
3570
|
-
|
|
3571
|
-
|
|
3572
|
-
|
|
3573
|
-
|
|
3574
|
-
|
|
3575
|
-
if not np.issubdtype(img_for_aug.dtype, np.uint8):
|
|
3576
|
-
# If image is float, scale to 0-255 and convert to uint8
|
|
3577
|
-
if np.issubdtype(img_for_aug.dtype, np.floating):
|
|
3578
|
-
img_for_aug = (
|
|
3579
|
-
(img_for_aug * 255).clip(0, 255).astype(np.uint8)
|
|
3580
|
-
)
|
|
3581
|
-
else:
|
|
3582
|
-
img_for_aug = img_for_aug.astype(np.uint8)
|
|
3583
|
-
|
|
3584
|
-
# Apply augmentation
|
|
3585
|
-
try:
|
|
3586
|
-
if in_class_data is not None:
|
|
3587
|
-
# Augment both image and mask
|
|
3588
|
-
augmented = augmentation_transforms(
|
|
3589
|
-
image=img_for_aug, mask=label_mask
|
|
3590
|
-
)
|
|
3591
|
-
aug_image = augmented["image"]
|
|
3592
|
-
aug_mask = augmented["mask"]
|
|
3593
|
-
else:
|
|
3594
|
-
# Augment only image
|
|
3595
|
-
augmented = augmentation_transforms(image=img_for_aug)
|
|
3596
|
-
aug_image = augmented["image"]
|
|
3597
|
-
aug_mask = label_mask
|
|
3598
|
-
|
|
3599
|
-
# Convert back from HWC to CHW
|
|
3600
|
-
aug_image = np.transpose(aug_image, (2, 0, 1))
|
|
3601
|
-
|
|
3602
|
-
# Ensure correct dtype for saving
|
|
3603
|
-
aug_image = aug_image.astype(image_data.dtype)
|
|
2932
|
+
# Save image as GeoTIFF
|
|
2933
|
+
try:
|
|
2934
|
+
with rasterio.open(image_path, "w", **image_profile) as dst:
|
|
2935
|
+
dst.write(image_data)
|
|
2936
|
+
stats["total_tiles"] += 1
|
|
2937
|
+
except Exception as e:
|
|
2938
|
+
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
|
2939
|
+
stats["errors"] += 1
|
|
3604
2940
|
|
|
3605
|
-
|
|
3606
|
-
|
|
3607
|
-
|
|
3608
|
-
|
|
3609
|
-
|
|
2941
|
+
# Create profile for label GeoTIFF
|
|
2942
|
+
label_profile = {
|
|
2943
|
+
"driver": "GTiff",
|
|
2944
|
+
"height": tile_size,
|
|
2945
|
+
"width": tile_size,
|
|
2946
|
+
"count": 1,
|
|
2947
|
+
"dtype": "uint8",
|
|
2948
|
+
"crs": src.crs,
|
|
2949
|
+
"transform": window_transform,
|
|
2950
|
+
}
|
|
3610
2951
|
|
|
3611
|
-
|
|
3612
|
-
|
|
3613
|
-
|
|
3614
|
-
|
|
3615
|
-
|
|
3616
|
-
src.profile,
|
|
3617
|
-
window_transform,
|
|
3618
|
-
is_augmented=True,
|
|
3619
|
-
)
|
|
2952
|
+
# Export label as GeoTIFF
|
|
2953
|
+
label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
|
|
2954
|
+
try:
|
|
2955
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
2956
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
|
3620
2957
|
|
|
3621
|
-
|
|
3622
|
-
|
|
3623
|
-
|
|
3624
|
-
|
|
3625
|
-
|
|
2958
|
+
if has_features:
|
|
2959
|
+
stats["tiles_with_features"] += 1
|
|
2960
|
+
stats["feature_pixels"] += np.count_nonzero(label_mask)
|
|
2961
|
+
except Exception as e:
|
|
2962
|
+
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
|
2963
|
+
stats["errors"] += 1
|
|
3626
2964
|
|
|
3627
|
-
# Create
|
|
2965
|
+
# Create XML annotation for object detection if using vector class data
|
|
3628
2966
|
if (
|
|
3629
|
-
|
|
3630
|
-
and not is_class_data_raster
|
|
2967
|
+
not is_class_data_raster
|
|
3631
2968
|
and "gdf" in locals()
|
|
3632
2969
|
and len(window_features) > 0
|
|
3633
2970
|
):
|
|
3634
|
-
|
|
3635
|
-
|
|
3636
|
-
|
|
3637
|
-
|
|
3638
|
-
ET.SubElement(root, "filename").text = (
|
|
3639
|
-
f"tile_{tile_index:06d}.tif"
|
|
3640
|
-
)
|
|
3641
|
-
|
|
3642
|
-
size = ET.SubElement(root, "size")
|
|
3643
|
-
ET.SubElement(size, "width").text = str(tile_size)
|
|
3644
|
-
ET.SubElement(size, "height").text = str(tile_size)
|
|
3645
|
-
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
3646
|
-
|
|
3647
|
-
# Add georeference information
|
|
3648
|
-
geo = ET.SubElement(root, "georeference")
|
|
3649
|
-
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
3650
|
-
ET.SubElement(geo, "transform").text = str(
|
|
3651
|
-
window_transform
|
|
3652
|
-
).replace("\n", "")
|
|
3653
|
-
ET.SubElement(geo, "bounds").text = (
|
|
3654
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
3655
|
-
)
|
|
3656
|
-
|
|
3657
|
-
# Add objects
|
|
3658
|
-
for idx, feature in window_features.iterrows():
|
|
3659
|
-
# Get feature class
|
|
3660
|
-
if class_value_field in feature:
|
|
3661
|
-
class_val = feature[class_value_field]
|
|
3662
|
-
else:
|
|
3663
|
-
class_val = "object"
|
|
3664
|
-
|
|
3665
|
-
# Get geometry bounds in pixel coordinates
|
|
3666
|
-
geom = feature.geometry.intersection(window_bounds)
|
|
3667
|
-
if not geom.is_empty:
|
|
3668
|
-
# Get bounds in world coordinates
|
|
3669
|
-
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3670
|
-
|
|
3671
|
-
# Convert to pixel coordinates
|
|
3672
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3673
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3674
|
-
|
|
3675
|
-
# Ensure coordinates are within tile bounds
|
|
3676
|
-
xmin = max(0, min(tile_size, int(col_min)))
|
|
3677
|
-
ymin = max(0, min(tile_size, int(row_min)))
|
|
3678
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
3679
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
3680
|
-
|
|
3681
|
-
# Only add if the box has non-zero area
|
|
3682
|
-
if xmax > xmin and ymax > ymin:
|
|
3683
|
-
obj = ET.SubElement(root, "object")
|
|
3684
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
|
3685
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
3686
|
-
|
|
3687
|
-
bbox = ET.SubElement(obj, "bndbox")
|
|
3688
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3689
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3690
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3691
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3692
|
-
|
|
3693
|
-
# Save XML
|
|
3694
|
-
tree = ET.ElementTree(root)
|
|
3695
|
-
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
3696
|
-
tree.write(xml_path)
|
|
3697
|
-
|
|
3698
|
-
elif metadata_format == "COCO":
|
|
3699
|
-
# Add image info
|
|
3700
|
-
image_id = tile_index
|
|
3701
|
-
coco_annotations["images"].append(
|
|
3702
|
-
{
|
|
3703
|
-
"id": image_id,
|
|
3704
|
-
"file_name": f"tile_{tile_index:06d}.tif",
|
|
3705
|
-
"width": tile_size,
|
|
3706
|
-
"height": tile_size,
|
|
3707
|
-
"crs": str(src.crs),
|
|
3708
|
-
"transform": str(window_transform),
|
|
3709
|
-
}
|
|
3710
|
-
)
|
|
3711
|
-
|
|
3712
|
-
# Add annotations for each feature
|
|
3713
|
-
for _, feature in window_features.iterrows():
|
|
3714
|
-
# Get feature class
|
|
3715
|
-
if class_value_field in feature:
|
|
3716
|
-
class_val = feature[class_value_field]
|
|
3717
|
-
category_id = class_to_id.get(class_val, 1)
|
|
3718
|
-
else:
|
|
3719
|
-
category_id = 1
|
|
2971
|
+
# Create XML annotation
|
|
2972
|
+
root = ET.Element("annotation")
|
|
2973
|
+
ET.SubElement(root, "folder").text = "images"
|
|
2974
|
+
ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
|
|
3720
2975
|
|
|
3721
|
-
|
|
3722
|
-
|
|
3723
|
-
|
|
3724
|
-
|
|
3725
|
-
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3726
|
-
|
|
3727
|
-
# Convert to pixel coordinates
|
|
3728
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3729
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3730
|
-
|
|
3731
|
-
# Ensure coordinates are within tile bounds
|
|
3732
|
-
xmin = max(0, min(tile_size, int(col_min)))
|
|
3733
|
-
ymin = max(0, min(tile_size, int(row_min)))
|
|
3734
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
3735
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
3736
|
-
|
|
3737
|
-
# Skip if box is too small
|
|
3738
|
-
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
3739
|
-
continue
|
|
3740
|
-
|
|
3741
|
-
width = xmax - xmin
|
|
3742
|
-
height = ymax - ymin
|
|
3743
|
-
|
|
3744
|
-
# Add annotation
|
|
3745
|
-
ann_id += 1
|
|
3746
|
-
coco_annotations["annotations"].append(
|
|
3747
|
-
{
|
|
3748
|
-
"id": ann_id,
|
|
3749
|
-
"image_id": image_id,
|
|
3750
|
-
"category_id": category_id,
|
|
3751
|
-
"bbox": [xmin, ymin, width, height],
|
|
3752
|
-
"area": width * height,
|
|
3753
|
-
"iscrowd": 0,
|
|
3754
|
-
}
|
|
3755
|
-
)
|
|
2976
|
+
size = ET.SubElement(root, "size")
|
|
2977
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
|
2978
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
|
2979
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
3756
2980
|
|
|
3757
|
-
|
|
3758
|
-
|
|
3759
|
-
|
|
2981
|
+
# Add georeference information
|
|
2982
|
+
geo = ET.SubElement(root, "georeference")
|
|
2983
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
2984
|
+
ET.SubElement(geo, "transform").text = str(
|
|
2985
|
+
window_transform
|
|
2986
|
+
).replace("\n", "")
|
|
2987
|
+
ET.SubElement(geo, "bounds").text = (
|
|
2988
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
2989
|
+
)
|
|
3760
2990
|
|
|
3761
|
-
|
|
3762
|
-
|
|
3763
|
-
|
|
3764
|
-
|
|
3765
|
-
|
|
3766
|
-
|
|
3767
|
-
|
|
3768
|
-
class_id = 0
|
|
2991
|
+
# Add objects
|
|
2992
|
+
for idx, feature in window_features.iterrows():
|
|
2993
|
+
# Get feature class
|
|
2994
|
+
if class_value_field in feature:
|
|
2995
|
+
class_val = feature[class_value_field]
|
|
2996
|
+
else:
|
|
2997
|
+
class_val = "object"
|
|
3769
2998
|
|
|
3770
|
-
|
|
3771
|
-
|
|
3772
|
-
|
|
3773
|
-
|
|
3774
|
-
|
|
3775
|
-
|
|
3776
|
-
|
|
3777
|
-
|
|
3778
|
-
|
|
3779
|
-
|
|
3780
|
-
|
|
3781
|
-
|
|
3782
|
-
|
|
3783
|
-
|
|
3784
|
-
|
|
3785
|
-
|
|
3786
|
-
|
|
3787
|
-
|
|
3788
|
-
|
|
3789
|
-
|
|
3790
|
-
|
|
3791
|
-
|
|
3792
|
-
|
|
3793
|
-
|
|
3794
|
-
|
|
3795
|
-
|
|
3796
|
-
|
|
3797
|
-
yolo_annotations.append(
|
|
3798
|
-
f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
|
3799
|
-
)
|
|
2999
|
+
# Get geometry bounds in pixel coordinates
|
|
3000
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
3001
|
+
if not geom.is_empty:
|
|
3002
|
+
# Get bounds in world coordinates
|
|
3003
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
3004
|
+
|
|
3005
|
+
# Convert to pixel coordinates
|
|
3006
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
3007
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
3008
|
+
|
|
3009
|
+
# Ensure coordinates are within tile bounds
|
|
3010
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
|
3011
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
|
3012
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
|
3013
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
|
3014
|
+
|
|
3015
|
+
# Only add if the box has non-zero area
|
|
3016
|
+
if xmax > xmin and ymax > ymin:
|
|
3017
|
+
obj = ET.SubElement(root, "object")
|
|
3018
|
+
ET.SubElement(obj, "name").text = str(class_val)
|
|
3019
|
+
ET.SubElement(obj, "difficult").text = "0"
|
|
3020
|
+
|
|
3021
|
+
bbox = ET.SubElement(obj, "bndbox")
|
|
3022
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
3023
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
3024
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
3025
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
3800
3026
|
|
|
3801
|
-
|
|
3802
|
-
|
|
3803
|
-
|
|
3804
|
-
|
|
3805
|
-
)
|
|
3806
|
-
with open(yolo_path, "w") as f:
|
|
3807
|
-
f.write("\n".join(yolo_annotations))
|
|
3027
|
+
# Save XML
|
|
3028
|
+
tree = ET.ElementTree(root)
|
|
3029
|
+
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
3030
|
+
tree.write(xml_path)
|
|
3808
3031
|
|
|
3809
3032
|
# Update progress bar
|
|
3810
3033
|
pbar.update(1)
|
|
@@ -3822,39 +3045,6 @@ def export_geotiff_tiles(
|
|
|
3822
3045
|
# Close progress bar
|
|
3823
3046
|
pbar.close()
|
|
3824
3047
|
|
|
3825
|
-
# Save COCO annotations if applicable (only if class data provided)
|
|
3826
|
-
if in_class_data is not None and metadata_format == "COCO":
|
|
3827
|
-
try:
|
|
3828
|
-
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
|
3829
|
-
json.dump(coco_annotations, f, indent=2)
|
|
3830
|
-
if not quiet:
|
|
3831
|
-
print(
|
|
3832
|
-
f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
|
|
3833
|
-
f"{len(coco_annotations['annotations'])} annotations, "
|
|
3834
|
-
f"{len(coco_annotations['categories'])} categories"
|
|
3835
|
-
)
|
|
3836
|
-
except Exception as e:
|
|
3837
|
-
if not quiet:
|
|
3838
|
-
print(f"ERROR saving COCO annotations: {e}")
|
|
3839
|
-
stats["errors"] += 1
|
|
3840
|
-
|
|
3841
|
-
# Save YOLO classes file if applicable (only if class data provided)
|
|
3842
|
-
if in_class_data is not None and metadata_format == "YOLO":
|
|
3843
|
-
try:
|
|
3844
|
-
# Create classes.txt with class names
|
|
3845
|
-
classes_path = os.path.join(out_folder, "classes.txt")
|
|
3846
|
-
# Sort by class ID to ensure correct order
|
|
3847
|
-
sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
|
|
3848
|
-
with open(classes_path, "w") as f:
|
|
3849
|
-
for class_val, _ in sorted_classes:
|
|
3850
|
-
f.write(f"{class_val}\n")
|
|
3851
|
-
if not quiet:
|
|
3852
|
-
print(f"Saved YOLO classes file with {len(class_to_id)} classes")
|
|
3853
|
-
except Exception as e:
|
|
3854
|
-
if not quiet:
|
|
3855
|
-
print(f"ERROR saving YOLO classes file: {e}")
|
|
3856
|
-
stats["errors"] += 1
|
|
3857
|
-
|
|
3858
3048
|
# Create overview image if requested
|
|
3859
3049
|
if create_overview and stats["tile_coordinates"]:
|
|
3860
3050
|
try:
|
|
@@ -3872,14 +3062,13 @@ def export_geotiff_tiles(
|
|
|
3872
3062
|
if not quiet:
|
|
3873
3063
|
print("\n------- Export Summary -------")
|
|
3874
3064
|
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
3875
|
-
|
|
3065
|
+
print(
|
|
3066
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
3067
|
+
)
|
|
3068
|
+
if stats["tiles_with_features"] > 0:
|
|
3876
3069
|
print(
|
|
3877
|
-
f"
|
|
3070
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
3878
3071
|
)
|
|
3879
|
-
if stats["tiles_with_features"] > 0:
|
|
3880
|
-
print(
|
|
3881
|
-
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
3882
|
-
)
|
|
3883
3072
|
if stats["errors"] > 0:
|
|
3884
3073
|
print(f"Errors encountered: {stats['errors']}")
|
|
3885
3074
|
print(f"Output saved to: {out_folder}")
|
|
@@ -3888,6 +3077,7 @@ def export_geotiff_tiles(
|
|
|
3888
3077
|
if stats["total_tiles"] > 0:
|
|
3889
3078
|
print("\n------- Georeference Verification -------")
|
|
3890
3079
|
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
|
3080
|
+
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
|
3891
3081
|
|
|
3892
3082
|
if os.path.exists(sample_image):
|
|
3893
3083
|
try:
|
|
@@ -3903,22 +3093,19 @@ def export_geotiff_tiles(
|
|
|
3903
3093
|
except Exception as e:
|
|
3904
3094
|
print(f"Error verifying image georeference: {e}")
|
|
3905
3095
|
|
|
3906
|
-
|
|
3907
|
-
|
|
3908
|
-
|
|
3909
|
-
|
|
3910
|
-
|
|
3911
|
-
|
|
3912
|
-
|
|
3913
|
-
|
|
3914
|
-
|
|
3915
|
-
|
|
3916
|
-
|
|
3917
|
-
|
|
3918
|
-
|
|
3919
|
-
)
|
|
3920
|
-
except Exception as e:
|
|
3921
|
-
print(f"Error verifying label georeference: {e}")
|
|
3096
|
+
if os.path.exists(sample_label):
|
|
3097
|
+
try:
|
|
3098
|
+
with rasterio.open(sample_label) as lbl:
|
|
3099
|
+
print(f"Label CRS: {lbl.crs}")
|
|
3100
|
+
print(f"Label transform: {lbl.transform}")
|
|
3101
|
+
print(
|
|
3102
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
3103
|
+
)
|
|
3104
|
+
print(
|
|
3105
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
3106
|
+
)
|
|
3107
|
+
except Exception as e:
|
|
3108
|
+
print(f"Error verifying label georeference: {e}")
|
|
3922
3109
|
|
|
3923
3110
|
# Return statistics dictionary for further processing if needed
|
|
3924
3111
|
return stats
|
|
@@ -3926,9 +3113,8 @@ def export_geotiff_tiles(
|
|
|
3926
3113
|
|
|
3927
3114
|
def export_geotiff_tiles_batch(
|
|
3928
3115
|
images_folder,
|
|
3929
|
-
masks_folder
|
|
3930
|
-
|
|
3931
|
-
output_folder=None,
|
|
3116
|
+
masks_folder,
|
|
3117
|
+
output_folder,
|
|
3932
3118
|
tile_size=256,
|
|
3933
3119
|
stride=128,
|
|
3934
3120
|
class_value_field="class",
|
|
@@ -3936,43 +3122,25 @@ def export_geotiff_tiles_batch(
|
|
|
3936
3122
|
max_tiles=None,
|
|
3937
3123
|
quiet=False,
|
|
3938
3124
|
all_touched=True,
|
|
3125
|
+
create_overview=False,
|
|
3939
3126
|
skip_empty_tiles=False,
|
|
3940
3127
|
image_extensions=None,
|
|
3941
3128
|
mask_extensions=None,
|
|
3942
|
-
match_by_name=False,
|
|
3943
|
-
metadata_format="PASCAL_VOC",
|
|
3944
3129
|
) -> Dict[str, Any]:
|
|
3945
3130
|
"""
|
|
3946
|
-
Export georeferenced GeoTIFF tiles from images and
|
|
3947
|
-
|
|
3948
|
-
This function supports four modes:
|
|
3949
|
-
1. Images only (no masks) - when neither masks_file nor masks_folder is provided
|
|
3950
|
-
2. Single vector file covering all images (masks_file parameter)
|
|
3951
|
-
3. Multiple vector files, one per image (masks_folder parameter)
|
|
3952
|
-
4. Multiple raster mask files (masks_folder parameter)
|
|
3953
|
-
|
|
3954
|
-
For mode 1 (images only), only image tiles will be exported without labels.
|
|
3131
|
+
Export georeferenced GeoTIFF tiles from folders of images and masks.
|
|
3955
3132
|
|
|
3956
|
-
|
|
3957
|
-
|
|
3133
|
+
This function processes multiple image-mask pairs from input folders,
|
|
3134
|
+
generating tiles for each pair. All image tiles are saved to a single
|
|
3135
|
+
'images' folder and all mask tiles to a single 'masks' folder.
|
|
3958
3136
|
|
|
3959
|
-
|
|
3960
|
-
|
|
3961
|
-
(match_by_name=False).
|
|
3962
|
-
|
|
3963
|
-
All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
|
|
3964
|
-
to a single 'masks' folder within the output directory.
|
|
3137
|
+
Images and masks are paired by their sorted order (alphabetically), not by
|
|
3138
|
+
filename matching. The number of images and masks must be equal.
|
|
3965
3139
|
|
|
3966
3140
|
Args:
|
|
3967
3141
|
images_folder (str): Path to folder containing raster images
|
|
3968
|
-
masks_folder (str
|
|
3969
|
-
|
|
3970
|
-
and masks_file is also not provided, only image tiles will be exported.
|
|
3971
|
-
masks_file (str, optional): Path to a single vector file covering all images.
|
|
3972
|
-
Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
|
|
3973
|
-
and masks_folder is also not provided, only image tiles will be exported.
|
|
3974
|
-
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
|
3975
|
-
subfolder in images_folder.
|
|
3142
|
+
masks_folder (str): Path to folder containing classification masks/vectors
|
|
3143
|
+
output_folder (str): Path to output folder
|
|
3976
3144
|
tile_size (int): Size of tiles in pixels (square)
|
|
3977
3145
|
stride (int): Step size between tiles
|
|
3978
3146
|
class_value_field (str): Field containing class values (for vector data)
|
|
@@ -3981,66 +3149,21 @@ def export_geotiff_tiles_batch(
|
|
|
3981
3149
|
quiet (bool): If True, suppress non-essential output
|
|
3982
3150
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
3983
3151
|
create_overview (bool): Whether to create an overview image of all tiles
|
|
3984
|
-
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3985
|
-
image_extensions (list): List of image file extensions to process (default: common raster formats)
|
|
3986
|
-
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
|
3987
|
-
match_by_name (bool): If True, match image and mask files by base filename.
|
|
3988
|
-
If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
|
|
3989
|
-
metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
|
|
3990
|
-
Default is "PASCAL_VOC".
|
|
3991
|
-
|
|
3992
|
-
Returns:
|
|
3993
|
-
Dict[str, Any]: Dictionary containing batch processing statistics
|
|
3994
|
-
|
|
3995
|
-
Raises:
|
|
3996
|
-
ValueError: If no images found, or if masks_folder and masks_file are both specified,
|
|
3997
|
-
or if counts don't match when using masks_folder with match_by_name=False.
|
|
3998
|
-
|
|
3999
|
-
Examples:
|
|
4000
|
-
# Images only (no masks)
|
|
4001
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
4002
|
-
... images_folder='data/images',
|
|
4003
|
-
... output_folder='output/tiles'
|
|
4004
|
-
... )
|
|
4005
|
-
|
|
4006
|
-
# Single vector file covering all images
|
|
4007
|
-
>>> stats = export_geotiff_tiles_batch(
|
|
4008
|
-
... images_folder='data/images',
|
|
4009
|
-
... masks_file='data/buildings.geojson',
|
|
4010
|
-
... output_folder='output/tiles'
|
|
4011
|
-
... )
|
|
3152
|
+
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
3153
|
+
image_extensions (list): List of image file extensions to process (default: common raster formats)
|
|
3154
|
+
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
|
4012
3155
|
|
|
4013
|
-
|
|
4014
|
-
|
|
4015
|
-
... images_folder='data/images',
|
|
4016
|
-
... masks_folder='data/masks',
|
|
4017
|
-
... output_folder='output/tiles',
|
|
4018
|
-
... match_by_name=True
|
|
4019
|
-
... )
|
|
3156
|
+
Returns:
|
|
3157
|
+
Dict[str, Any]: Dictionary containing batch processing statistics
|
|
4020
3158
|
|
|
4021
|
-
|
|
4022
|
-
|
|
4023
|
-
... images_folder='data/images',
|
|
4024
|
-
... masks_folder='data/masks',
|
|
4025
|
-
... output_folder='output/tiles',
|
|
4026
|
-
... match_by_name=False
|
|
4027
|
-
... )
|
|
3159
|
+
Raises:
|
|
3160
|
+
ValueError: If no images or masks found, or if counts don't match
|
|
4028
3161
|
"""
|
|
4029
3162
|
|
|
4030
3163
|
import logging
|
|
4031
3164
|
|
|
4032
3165
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
|
4033
3166
|
|
|
4034
|
-
# Validate input parameters
|
|
4035
|
-
if masks_folder is not None and masks_file is not None:
|
|
4036
|
-
raise ValueError(
|
|
4037
|
-
"Cannot specify both masks_folder and masks_file. Please use only one."
|
|
4038
|
-
)
|
|
4039
|
-
|
|
4040
|
-
# Default output folder if not specified
|
|
4041
|
-
if output_folder is None:
|
|
4042
|
-
output_folder = os.path.join(images_folder, "tiles")
|
|
4043
|
-
|
|
4044
3167
|
# Default extensions if not provided
|
|
4045
3168
|
if image_extensions is None:
|
|
4046
3169
|
image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
|
|
@@ -4067,37 +3190,9 @@ def export_geotiff_tiles_batch(
|
|
|
4067
3190
|
# Create output folder structure
|
|
4068
3191
|
os.makedirs(output_folder, exist_ok=True)
|
|
4069
3192
|
output_images_dir = os.path.join(output_folder, "images")
|
|
3193
|
+
output_masks_dir = os.path.join(output_folder, "masks")
|
|
4070
3194
|
os.makedirs(output_images_dir, exist_ok=True)
|
|
4071
|
-
|
|
4072
|
-
# Only create masks directory if masks are provided
|
|
4073
|
-
output_masks_dir = None
|
|
4074
|
-
if masks_folder is not None or masks_file is not None:
|
|
4075
|
-
output_masks_dir = os.path.join(output_folder, "masks")
|
|
4076
|
-
os.makedirs(output_masks_dir, exist_ok=True)
|
|
4077
|
-
|
|
4078
|
-
# Create annotation directory based on metadata format (only if masks are provided)
|
|
4079
|
-
ann_dir = None
|
|
4080
|
-
if (masks_folder is not None or masks_file is not None) and metadata_format in [
|
|
4081
|
-
"PASCAL_VOC",
|
|
4082
|
-
"COCO",
|
|
4083
|
-
]:
|
|
4084
|
-
ann_dir = os.path.join(output_folder, "annotations")
|
|
4085
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
4086
|
-
|
|
4087
|
-
# Initialize COCO annotations dictionary (only if masks are provided)
|
|
4088
|
-
coco_annotations = None
|
|
4089
|
-
if (
|
|
4090
|
-
masks_folder is not None or masks_file is not None
|
|
4091
|
-
) and metadata_format == "COCO":
|
|
4092
|
-
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
4093
|
-
|
|
4094
|
-
# Initialize YOLO class set (only if masks are provided)
|
|
4095
|
-
yolo_classes = (
|
|
4096
|
-
set()
|
|
4097
|
-
if (masks_folder is not None or masks_file is not None)
|
|
4098
|
-
and metadata_format == "YOLO"
|
|
4099
|
-
else None
|
|
4100
|
-
)
|
|
3195
|
+
os.makedirs(output_masks_dir, exist_ok=True)
|
|
4101
3196
|
|
|
4102
3197
|
# Get list of image files
|
|
4103
3198
|
image_files = []
|
|
@@ -4105,105 +3200,30 @@ def export_geotiff_tiles_batch(
|
|
|
4105
3200
|
pattern = os.path.join(images_folder, f"*{ext}")
|
|
4106
3201
|
image_files.extend(glob.glob(pattern))
|
|
4107
3202
|
|
|
3203
|
+
# Get list of mask files
|
|
3204
|
+
mask_files = []
|
|
3205
|
+
for ext in mask_extensions:
|
|
3206
|
+
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
3207
|
+
mask_files.extend(glob.glob(pattern))
|
|
3208
|
+
|
|
4108
3209
|
# Sort files for consistent processing
|
|
4109
3210
|
image_files.sort()
|
|
3211
|
+
mask_files.sort()
|
|
4110
3212
|
|
|
4111
3213
|
if not image_files:
|
|
4112
3214
|
raise ValueError(
|
|
4113
3215
|
f"No image files found in {images_folder} with extensions {image_extensions}"
|
|
4114
3216
|
)
|
|
4115
3217
|
|
|
4116
|
-
|
|
4117
|
-
|
|
4118
|
-
|
|
4119
|
-
|
|
4120
|
-
image_mask_pairs = []
|
|
4121
|
-
|
|
4122
|
-
if not has_masks:
|
|
4123
|
-
# Mode 0: No masks - create pairs with None for mask
|
|
4124
|
-
for image_file in image_files:
|
|
4125
|
-
image_mask_pairs.append((image_file, None, None))
|
|
4126
|
-
|
|
4127
|
-
elif use_single_mask_file:
|
|
4128
|
-
# Mode 1: Single vector file covering all images
|
|
4129
|
-
if not os.path.exists(masks_file):
|
|
4130
|
-
raise ValueError(f"Mask file not found: {masks_file}")
|
|
4131
|
-
|
|
4132
|
-
# Load the single mask file once - will be spatially filtered per image
|
|
4133
|
-
single_mask_gdf = gpd.read_file(masks_file)
|
|
4134
|
-
|
|
4135
|
-
if not quiet:
|
|
4136
|
-
print(f"Using single mask file: {masks_file}")
|
|
4137
|
-
print(
|
|
4138
|
-
f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
|
|
4139
|
-
)
|
|
4140
|
-
|
|
4141
|
-
# Create pairs with the same mask file for all images
|
|
4142
|
-
for image_file in image_files:
|
|
4143
|
-
image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
|
|
4144
|
-
|
|
4145
|
-
else:
|
|
4146
|
-
# Mode 2/3: Multiple mask files (vector or raster)
|
|
4147
|
-
# Get list of mask files
|
|
4148
|
-
for ext in mask_extensions:
|
|
4149
|
-
pattern = os.path.join(masks_folder, f"*{ext}")
|
|
4150
|
-
mask_files.extend(glob.glob(pattern))
|
|
4151
|
-
|
|
4152
|
-
# Sort files for consistent processing
|
|
4153
|
-
mask_files.sort()
|
|
4154
|
-
|
|
4155
|
-
if not mask_files:
|
|
4156
|
-
raise ValueError(
|
|
4157
|
-
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
|
4158
|
-
)
|
|
4159
|
-
|
|
4160
|
-
# Match images to masks
|
|
4161
|
-
if match_by_name:
|
|
4162
|
-
# Match by base filename
|
|
4163
|
-
image_dict = {
|
|
4164
|
-
os.path.splitext(os.path.basename(f))[0]: f for f in image_files
|
|
4165
|
-
}
|
|
4166
|
-
mask_dict = {
|
|
4167
|
-
os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
|
|
4168
|
-
}
|
|
4169
|
-
|
|
4170
|
-
# Find matching pairs
|
|
4171
|
-
for img_base, img_path in image_dict.items():
|
|
4172
|
-
if img_base in mask_dict:
|
|
4173
|
-
image_mask_pairs.append((img_path, mask_dict[img_base], None))
|
|
4174
|
-
else:
|
|
4175
|
-
if not quiet:
|
|
4176
|
-
print(f"Warning: No mask found for image {img_base}")
|
|
4177
|
-
|
|
4178
|
-
if not image_mask_pairs:
|
|
4179
|
-
# Provide detailed error message with found files
|
|
4180
|
-
image_bases = list(image_dict.keys())
|
|
4181
|
-
mask_bases = list(mask_dict.keys())
|
|
4182
|
-
error_msg = (
|
|
4183
|
-
"No matching image-mask pairs found when matching by filename. "
|
|
4184
|
-
"Check that image and mask files have matching base names.\n"
|
|
4185
|
-
f"Found {len(image_bases)} image(s): "
|
|
4186
|
-
f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
|
|
4187
|
-
f"{'...' if len(image_bases) > 5 else ''}\n"
|
|
4188
|
-
f"Found {len(mask_bases)} mask(s): "
|
|
4189
|
-
f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
|
|
4190
|
-
f"{'...' if len(mask_bases) > 5 else ''}\n"
|
|
4191
|
-
"Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
|
|
4192
|
-
)
|
|
4193
|
-
raise ValueError(error_msg)
|
|
4194
|
-
|
|
4195
|
-
else:
|
|
4196
|
-
# Match by sorted order
|
|
4197
|
-
if len(image_files) != len(mask_files):
|
|
4198
|
-
raise ValueError(
|
|
4199
|
-
f"Number of image files ({len(image_files)}) does not match "
|
|
4200
|
-
f"number of mask files ({len(mask_files)}) when matching by sorted order. "
|
|
4201
|
-
f"Use match_by_name=True for filename-based matching."
|
|
4202
|
-
)
|
|
3218
|
+
if not mask_files:
|
|
3219
|
+
raise ValueError(
|
|
3220
|
+
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
|
3221
|
+
)
|
|
4203
3222
|
|
|
4204
|
-
|
|
4205
|
-
|
|
4206
|
-
|
|
3223
|
+
if len(image_files) != len(mask_files):
|
|
3224
|
+
raise ValueError(
|
|
3225
|
+
f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
|
|
3226
|
+
)
|
|
4207
3227
|
|
|
4208
3228
|
# Initialize batch statistics
|
|
4209
3229
|
batch_stats = {
|
|
@@ -4217,28 +3237,23 @@ def export_geotiff_tiles_batch(
|
|
|
4217
3237
|
}
|
|
4218
3238
|
|
|
4219
3239
|
if not quiet:
|
|
4220
|
-
|
|
4221
|
-
|
|
4222
|
-
|
|
4223
|
-
|
|
4224
|
-
elif use_single_mask_file:
|
|
4225
|
-
print(f"Found {len(image_files)} image files to process")
|
|
4226
|
-
print(f"Using single mask file: {masks_file}")
|
|
4227
|
-
else:
|
|
4228
|
-
print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
|
|
4229
|
-
print(f"Processing batch from {images_folder} and {masks_folder}")
|
|
3240
|
+
print(
|
|
3241
|
+
f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
|
|
3242
|
+
)
|
|
3243
|
+
print(f"Processing batch from {images_folder} and {masks_folder}")
|
|
4230
3244
|
print(f"Output folder: {output_folder}")
|
|
4231
3245
|
print("-" * 60)
|
|
4232
3246
|
|
|
4233
3247
|
# Global tile counter for unique naming
|
|
4234
3248
|
global_tile_counter = 0
|
|
4235
3249
|
|
|
4236
|
-
# Process each image-mask pair
|
|
4237
|
-
for idx, (image_file, mask_file
|
|
3250
|
+
# Process each image-mask pair by sorted order
|
|
3251
|
+
for idx, (image_file, mask_file) in enumerate(
|
|
4238
3252
|
tqdm(
|
|
4239
|
-
|
|
3253
|
+
zip(image_files, mask_files),
|
|
4240
3254
|
desc="Processing image pairs",
|
|
4241
3255
|
disable=quiet,
|
|
3256
|
+
total=len(image_files),
|
|
4242
3257
|
)
|
|
4243
3258
|
):
|
|
4244
3259
|
batch_stats["total_image_pairs"] += 1
|
|
@@ -4250,17 +3265,9 @@ def export_geotiff_tiles_batch(
|
|
|
4250
3265
|
if not quiet:
|
|
4251
3266
|
print(f"\nProcessing: {base_name}")
|
|
4252
3267
|
print(f" Image: {os.path.basename(image_file)}")
|
|
4253
|
-
|
|
4254
|
-
if use_single_mask_file:
|
|
4255
|
-
print(
|
|
4256
|
-
f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
|
|
4257
|
-
)
|
|
4258
|
-
else:
|
|
4259
|
-
print(f" Mask: {os.path.basename(mask_file)}")
|
|
4260
|
-
else:
|
|
4261
|
-
print(f" Mask: None (images only)")
|
|
3268
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
|
4262
3269
|
|
|
4263
|
-
# Process the image-mask pair
|
|
3270
|
+
# Process the image-mask pair manually to get direct control over tile saving
|
|
4264
3271
|
tiles_generated = _process_image_mask_pair(
|
|
4265
3272
|
image_file=image_file,
|
|
4266
3273
|
mask_file=mask_file,
|
|
@@ -4276,15 +3283,6 @@ def export_geotiff_tiles_batch(
|
|
|
4276
3283
|
all_touched=all_touched,
|
|
4277
3284
|
skip_empty_tiles=skip_empty_tiles,
|
|
4278
3285
|
quiet=quiet,
|
|
4279
|
-
mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
|
|
4280
|
-
use_single_mask_file=use_single_mask_file,
|
|
4281
|
-
metadata_format=metadata_format,
|
|
4282
|
-
ann_dir=(
|
|
4283
|
-
ann_dir
|
|
4284
|
-
if "ann_dir" in locals()
|
|
4285
|
-
and metadata_format in ["PASCAL_VOC", "COCO"]
|
|
4286
|
-
else None
|
|
4287
|
-
),
|
|
4288
3286
|
)
|
|
4289
3287
|
|
|
4290
3288
|
# Update counters
|
|
@@ -4306,23 +3304,6 @@ def export_geotiff_tiles_batch(
|
|
|
4306
3304
|
}
|
|
4307
3305
|
)
|
|
4308
3306
|
|
|
4309
|
-
# Aggregate COCO annotations
|
|
4310
|
-
if metadata_format == "COCO" and "coco_data" in tiles_generated:
|
|
4311
|
-
coco_data = tiles_generated["coco_data"]
|
|
4312
|
-
# Add images and annotations
|
|
4313
|
-
coco_annotations["images"].extend(coco_data.get("images", []))
|
|
4314
|
-
coco_annotations["annotations"].extend(coco_data.get("annotations", []))
|
|
4315
|
-
# Merge categories (avoid duplicates)
|
|
4316
|
-
for cat in coco_data.get("categories", []):
|
|
4317
|
-
if not any(
|
|
4318
|
-
c["id"] == cat["id"] for c in coco_annotations["categories"]
|
|
4319
|
-
):
|
|
4320
|
-
coco_annotations["categories"].append(cat)
|
|
4321
|
-
|
|
4322
|
-
# Aggregate YOLO classes
|
|
4323
|
-
if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
|
|
4324
|
-
yolo_classes.update(tiles_generated["yolo_classes"])
|
|
4325
|
-
|
|
4326
3307
|
except Exception as e:
|
|
4327
3308
|
if not quiet:
|
|
4328
3309
|
print(f"ERROR processing {base_name}: {e}")
|
|
@@ -4331,33 +3312,6 @@ def export_geotiff_tiles_batch(
|
|
|
4331
3312
|
)
|
|
4332
3313
|
batch_stats["errors"] += 1
|
|
4333
3314
|
|
|
4334
|
-
# Save aggregated COCO annotations
|
|
4335
|
-
if metadata_format == "COCO" and coco_annotations:
|
|
4336
|
-
import json
|
|
4337
|
-
|
|
4338
|
-
coco_path = os.path.join(ann_dir, "instances.json")
|
|
4339
|
-
with open(coco_path, "w") as f:
|
|
4340
|
-
json.dump(coco_annotations, f, indent=2)
|
|
4341
|
-
if not quiet:
|
|
4342
|
-
print(f"\nSaved COCO annotations: {coco_path}")
|
|
4343
|
-
print(
|
|
4344
|
-
f" Images: {len(coco_annotations['images'])}, "
|
|
4345
|
-
f"Annotations: {len(coco_annotations['annotations'])}, "
|
|
4346
|
-
f"Categories: {len(coco_annotations['categories'])}"
|
|
4347
|
-
)
|
|
4348
|
-
|
|
4349
|
-
# Save aggregated YOLO classes
|
|
4350
|
-
if metadata_format == "YOLO" and yolo_classes:
|
|
4351
|
-
classes_path = os.path.join(output_folder, "labels", "classes.txt")
|
|
4352
|
-
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
|
|
4353
|
-
sorted_classes = sorted(yolo_classes)
|
|
4354
|
-
with open(classes_path, "w") as f:
|
|
4355
|
-
for cls in sorted_classes:
|
|
4356
|
-
f.write(f"{cls}\n")
|
|
4357
|
-
if not quiet:
|
|
4358
|
-
print(f"\nSaved YOLO classes: {classes_path}")
|
|
4359
|
-
print(f" Total classes: {len(sorted_classes)}")
|
|
4360
|
-
|
|
4361
3315
|
# Print batch summary
|
|
4362
3316
|
if not quiet:
|
|
4363
3317
|
print("\n" + "=" * 60)
|
|
@@ -4380,12 +3334,7 @@ def export_geotiff_tiles_batch(
|
|
|
4380
3334
|
|
|
4381
3335
|
print(f"Output saved to: {output_folder}")
|
|
4382
3336
|
print(f" Images: {output_images_dir}")
|
|
4383
|
-
|
|
4384
|
-
print(f" Masks: {output_masks_dir}")
|
|
4385
|
-
if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
|
|
4386
|
-
print(f" Annotations: {ann_dir}")
|
|
4387
|
-
elif metadata_format == "YOLO":
|
|
4388
|
-
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
|
3337
|
+
print(f" Masks: {output_masks_dir}")
|
|
4389
3338
|
|
|
4390
3339
|
# List failed files if any
|
|
4391
3340
|
if batch_stats["failed_files"]:
|
|
@@ -4411,26 +3360,18 @@ def _process_image_mask_pair(
|
|
|
4411
3360
|
all_touched=True,
|
|
4412
3361
|
skip_empty_tiles=False,
|
|
4413
3362
|
quiet=False,
|
|
4414
|
-
mask_gdf=None,
|
|
4415
|
-
use_single_mask_file=False,
|
|
4416
|
-
metadata_format="PASCAL_VOC",
|
|
4417
|
-
ann_dir=None,
|
|
4418
3363
|
):
|
|
4419
3364
|
"""
|
|
4420
3365
|
Process a single image-mask pair and save tiles directly to output directories.
|
|
4421
3366
|
|
|
4422
|
-
Args:
|
|
4423
|
-
mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
|
|
4424
|
-
use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
|
|
4425
|
-
|
|
4426
3367
|
Returns:
|
|
4427
3368
|
dict: Statistics for this image-mask pair
|
|
4428
3369
|
"""
|
|
4429
3370
|
import warnings
|
|
4430
3371
|
|
|
4431
|
-
# Determine if mask data is raster or vector
|
|
3372
|
+
# Determine if mask data is raster or vector
|
|
4432
3373
|
is_class_data_raster = False
|
|
4433
|
-
if
|
|
3374
|
+
if isinstance(mask_file, str):
|
|
4434
3375
|
file_ext = Path(mask_file).suffix.lower()
|
|
4435
3376
|
# Common raster extensions
|
|
4436
3377
|
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
@@ -4447,13 +3388,6 @@ def _process_image_mask_pair(
|
|
|
4447
3388
|
"errors": 0,
|
|
4448
3389
|
}
|
|
4449
3390
|
|
|
4450
|
-
# Initialize COCO/YOLO tracking for this image
|
|
4451
|
-
if metadata_format == "COCO":
|
|
4452
|
-
stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
|
|
4453
|
-
coco_ann_id = 0
|
|
4454
|
-
if metadata_format == "YOLO":
|
|
4455
|
-
stats["yolo_classes"] = set()
|
|
4456
|
-
|
|
4457
3391
|
# Open the input raster
|
|
4458
3392
|
with rasterio.open(image_file) as src:
|
|
4459
3393
|
# Calculate number of tiles
|
|
@@ -4464,10 +3398,10 @@ def _process_image_mask_pair(
|
|
|
4464
3398
|
if max_tiles is None:
|
|
4465
3399
|
max_tiles = total_tiles
|
|
4466
3400
|
|
|
4467
|
-
# Process classification data
|
|
3401
|
+
# Process classification data
|
|
4468
3402
|
class_to_id = {}
|
|
4469
3403
|
|
|
4470
|
-
if
|
|
3404
|
+
if is_class_data_raster:
|
|
4471
3405
|
# Load raster class data
|
|
4472
3406
|
with rasterio.open(mask_file) as class_src:
|
|
4473
3407
|
# Check if raster CRS matches
|
|
@@ -4494,39 +3428,14 @@ def _process_image_mask_pair(
|
|
|
4494
3428
|
|
|
4495
3429
|
# Create class mapping
|
|
4496
3430
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
4497
|
-
|
|
3431
|
+
else:
|
|
4498
3432
|
# Load vector class data
|
|
4499
3433
|
try:
|
|
4500
|
-
|
|
4501
|
-
# Using pre-loaded single mask file - spatially filter to image bounds
|
|
4502
|
-
# Get image bounds
|
|
4503
|
-
image_bounds = box(*src.bounds)
|
|
4504
|
-
image_gdf = gpd.GeoDataFrame(
|
|
4505
|
-
{"geometry": [image_bounds]}, crs=src.crs
|
|
4506
|
-
)
|
|
4507
|
-
|
|
4508
|
-
# Reproject mask if needed
|
|
4509
|
-
if mask_gdf.crs != src.crs:
|
|
4510
|
-
mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
|
|
4511
|
-
else:
|
|
4512
|
-
mask_gdf_reprojected = mask_gdf
|
|
4513
|
-
|
|
4514
|
-
# Spatially filter features that intersect with image bounds
|
|
4515
|
-
gdf = mask_gdf_reprojected[
|
|
4516
|
-
mask_gdf_reprojected.intersects(image_bounds)
|
|
4517
|
-
].copy()
|
|
4518
|
-
|
|
4519
|
-
if not quiet and len(gdf) > 0:
|
|
4520
|
-
print(
|
|
4521
|
-
f" Filtered to {len(gdf)} features intersecting image bounds"
|
|
4522
|
-
)
|
|
4523
|
-
else:
|
|
4524
|
-
# Load individual mask file
|
|
4525
|
-
gdf = gpd.read_file(mask_file)
|
|
3434
|
+
gdf = gpd.read_file(mask_file)
|
|
4526
3435
|
|
|
4527
|
-
|
|
4528
|
-
|
|
4529
|
-
|
|
3436
|
+
# Always reproject to match raster CRS
|
|
3437
|
+
if gdf.crs != src.crs:
|
|
3438
|
+
gdf = gdf.to_crs(src.crs)
|
|
4530
3439
|
|
|
4531
3440
|
# Apply buffer if specified
|
|
4532
3441
|
if buffer_radius > 0:
|
|
@@ -4546,6 +3455,9 @@ def _process_image_mask_pair(
|
|
|
4546
3455
|
tile_index = 0
|
|
4547
3456
|
for y in range(num_tiles_y):
|
|
4548
3457
|
for x in range(num_tiles_x):
|
|
3458
|
+
if tile_index >= max_tiles:
|
|
3459
|
+
break
|
|
3460
|
+
|
|
4549
3461
|
# Calculate window coordinates
|
|
4550
3462
|
window_x = x * stride
|
|
4551
3463
|
window_y = y * stride
|
|
@@ -4570,12 +3482,12 @@ def _process_image_mask_pair(
|
|
|
4570
3482
|
|
|
4571
3483
|
window_bounds = box(minx, miny, maxx, maxy)
|
|
4572
3484
|
|
|
4573
|
-
# Create label mask
|
|
3485
|
+
# Create label mask
|
|
4574
3486
|
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
4575
3487
|
has_features = False
|
|
4576
3488
|
|
|
4577
|
-
# Process classification data to create labels
|
|
4578
|
-
if
|
|
3489
|
+
# Process classification data to create labels
|
|
3490
|
+
if is_class_data_raster:
|
|
4579
3491
|
# For raster class data
|
|
4580
3492
|
with rasterio.open(mask_file) as class_src:
|
|
4581
3493
|
# Get corresponding window in class raster
|
|
@@ -4608,7 +3520,7 @@ def _process_image_mask_pair(
|
|
|
4608
3520
|
if not quiet:
|
|
4609
3521
|
print(f"Error reading class raster window: {e}")
|
|
4610
3522
|
stats["errors"] += 1
|
|
4611
|
-
|
|
3523
|
+
else:
|
|
4612
3524
|
# For vector class data
|
|
4613
3525
|
# Find features that intersect with window
|
|
4614
3526
|
window_features = gdf[gdf.intersects(window_bounds)]
|
|
@@ -4646,14 +3558,11 @@ def _process_image_mask_pair(
|
|
|
4646
3558
|
print(f"Error rasterizing feature {idx}: {e}")
|
|
4647
3559
|
stats["errors"] += 1
|
|
4648
3560
|
|
|
4649
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
4650
|
-
if
|
|
3561
|
+
# Skip tile if no features and skip_empty_tiles is True
|
|
3562
|
+
if skip_empty_tiles and not has_features:
|
|
3563
|
+
tile_index += 1
|
|
4651
3564
|
continue
|
|
4652
3565
|
|
|
4653
|
-
# Check if we've reached max_tiles before saving
|
|
4654
|
-
if tile_index >= max_tiles:
|
|
4655
|
-
break
|
|
4656
|
-
|
|
4657
3566
|
# Generate unique tile name
|
|
4658
3567
|
tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
|
|
4659
3568
|
|
|
@@ -4684,225 +3593,29 @@ def _process_image_mask_pair(
|
|
|
4684
3593
|
print(f"ERROR saving image GeoTIFF: {e}")
|
|
4685
3594
|
stats["errors"] += 1
|
|
4686
3595
|
|
|
4687
|
-
#
|
|
4688
|
-
|
|
4689
|
-
|
|
4690
|
-
|
|
4691
|
-
|
|
4692
|
-
|
|
4693
|
-
|
|
4694
|
-
|
|
4695
|
-
|
|
4696
|
-
|
|
4697
|
-
"transform": window_transform,
|
|
4698
|
-
}
|
|
4699
|
-
|
|
4700
|
-
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
|
4701
|
-
try:
|
|
4702
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
4703
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
|
4704
|
-
|
|
4705
|
-
if has_features:
|
|
4706
|
-
stats["tiles_with_features"] += 1
|
|
4707
|
-
except Exception as e:
|
|
4708
|
-
if not quiet:
|
|
4709
|
-
print(f"ERROR saving label GeoTIFF: {e}")
|
|
4710
|
-
stats["errors"] += 1
|
|
4711
|
-
|
|
4712
|
-
# Generate annotation metadata based on format (only if mask_file is provided)
|
|
4713
|
-
if (
|
|
4714
|
-
mask_file is not None
|
|
4715
|
-
and metadata_format == "PASCAL_VOC"
|
|
4716
|
-
and ann_dir
|
|
4717
|
-
):
|
|
4718
|
-
# Create PASCAL VOC XML annotation
|
|
4719
|
-
from lxml import etree as ET
|
|
4720
|
-
|
|
4721
|
-
annotation = ET.Element("annotation")
|
|
4722
|
-
ET.SubElement(annotation, "folder").text = os.path.basename(
|
|
4723
|
-
output_images_dir
|
|
4724
|
-
)
|
|
4725
|
-
ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
|
|
4726
|
-
ET.SubElement(annotation, "path").text = image_path
|
|
4727
|
-
|
|
4728
|
-
source = ET.SubElement(annotation, "source")
|
|
4729
|
-
ET.SubElement(source, "database").text = "GeoAI"
|
|
4730
|
-
|
|
4731
|
-
size = ET.SubElement(annotation, "size")
|
|
4732
|
-
ET.SubElement(size, "width").text = str(tile_size)
|
|
4733
|
-
ET.SubElement(size, "height").text = str(tile_size)
|
|
4734
|
-
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
4735
|
-
|
|
4736
|
-
ET.SubElement(annotation, "segmented").text = "1"
|
|
4737
|
-
|
|
4738
|
-
# Find connected components for instance segmentation
|
|
4739
|
-
from scipy import ndimage
|
|
4740
|
-
|
|
4741
|
-
for class_id in np.unique(label_mask):
|
|
4742
|
-
if class_id == 0:
|
|
4743
|
-
continue
|
|
4744
|
-
|
|
4745
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4746
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4747
|
-
|
|
4748
|
-
for instance_id in range(1, num_features + 1):
|
|
4749
|
-
instance_mask = labeled_array == instance_id
|
|
4750
|
-
coords = np.argwhere(instance_mask)
|
|
4751
|
-
|
|
4752
|
-
if len(coords) == 0:
|
|
4753
|
-
continue
|
|
4754
|
-
|
|
4755
|
-
ymin, xmin = coords.min(axis=0)
|
|
4756
|
-
ymax, xmax = coords.max(axis=0)
|
|
4757
|
-
|
|
4758
|
-
obj = ET.SubElement(annotation, "object")
|
|
4759
|
-
class_name = next(
|
|
4760
|
-
(k for k, v in class_to_id.items() if v == class_id),
|
|
4761
|
-
str(class_id),
|
|
4762
|
-
)
|
|
4763
|
-
ET.SubElement(obj, "name").text = str(class_name)
|
|
4764
|
-
ET.SubElement(obj, "pose").text = "Unspecified"
|
|
4765
|
-
ET.SubElement(obj, "truncated").text = "0"
|
|
4766
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
4767
|
-
|
|
4768
|
-
bndbox = ET.SubElement(obj, "bndbox")
|
|
4769
|
-
ET.SubElement(bndbox, "xmin").text = str(int(xmin))
|
|
4770
|
-
ET.SubElement(bndbox, "ymin").text = str(int(ymin))
|
|
4771
|
-
ET.SubElement(bndbox, "xmax").text = str(int(xmax))
|
|
4772
|
-
ET.SubElement(bndbox, "ymax").text = str(int(ymax))
|
|
4773
|
-
|
|
4774
|
-
# Save XML file
|
|
4775
|
-
xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
|
|
4776
|
-
tree = ET.ElementTree(annotation)
|
|
4777
|
-
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
|
4778
|
-
|
|
4779
|
-
elif mask_file is not None and metadata_format == "COCO":
|
|
4780
|
-
# Add COCO image entry
|
|
4781
|
-
image_id = int(global_tile_counter + tile_index)
|
|
4782
|
-
stats["coco_data"]["images"].append(
|
|
4783
|
-
{
|
|
4784
|
-
"id": image_id,
|
|
4785
|
-
"file_name": f"{tile_name}.tif",
|
|
4786
|
-
"width": int(tile_size),
|
|
4787
|
-
"height": int(tile_size),
|
|
4788
|
-
}
|
|
4789
|
-
)
|
|
4790
|
-
|
|
4791
|
-
# Add COCO categories (only once per unique class)
|
|
4792
|
-
for class_val, class_id in class_to_id.items():
|
|
4793
|
-
if not any(
|
|
4794
|
-
c["id"] == class_id
|
|
4795
|
-
for c in stats["coco_data"]["categories"]
|
|
4796
|
-
):
|
|
4797
|
-
stats["coco_data"]["categories"].append(
|
|
4798
|
-
{
|
|
4799
|
-
"id": int(class_id),
|
|
4800
|
-
"name": str(class_val),
|
|
4801
|
-
"supercategory": "object",
|
|
4802
|
-
}
|
|
4803
|
-
)
|
|
4804
|
-
|
|
4805
|
-
# Add COCO annotations (instance segmentation)
|
|
4806
|
-
from scipy import ndimage
|
|
4807
|
-
from skimage import measure
|
|
4808
|
-
|
|
4809
|
-
for class_id in np.unique(label_mask):
|
|
4810
|
-
if class_id == 0:
|
|
4811
|
-
continue
|
|
4812
|
-
|
|
4813
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4814
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4815
|
-
|
|
4816
|
-
for instance_id in range(1, num_features + 1):
|
|
4817
|
-
instance_mask = (labeled_array == instance_id).astype(
|
|
4818
|
-
np.uint8
|
|
4819
|
-
)
|
|
4820
|
-
coords = np.argwhere(instance_mask)
|
|
4821
|
-
|
|
4822
|
-
if len(coords) == 0:
|
|
4823
|
-
continue
|
|
4824
|
-
|
|
4825
|
-
ymin, xmin = coords.min(axis=0)
|
|
4826
|
-
ymax, xmax = coords.max(axis=0)
|
|
4827
|
-
|
|
4828
|
-
bbox = [
|
|
4829
|
-
int(xmin),
|
|
4830
|
-
int(ymin),
|
|
4831
|
-
int(xmax - xmin),
|
|
4832
|
-
int(ymax - ymin),
|
|
4833
|
-
]
|
|
4834
|
-
area = int(np.sum(instance_mask))
|
|
4835
|
-
|
|
4836
|
-
# Find contours for segmentation
|
|
4837
|
-
contours = measure.find_contours(instance_mask, 0.5)
|
|
4838
|
-
segmentation = []
|
|
4839
|
-
for contour in contours:
|
|
4840
|
-
contour = np.flip(contour, axis=1)
|
|
4841
|
-
segmentation_points = contour.ravel().tolist()
|
|
4842
|
-
if len(segmentation_points) >= 6:
|
|
4843
|
-
segmentation.append(segmentation_points)
|
|
4844
|
-
|
|
4845
|
-
if segmentation:
|
|
4846
|
-
stats["coco_data"]["annotations"].append(
|
|
4847
|
-
{
|
|
4848
|
-
"id": int(coco_ann_id),
|
|
4849
|
-
"image_id": int(image_id),
|
|
4850
|
-
"category_id": int(class_id),
|
|
4851
|
-
"bbox": bbox,
|
|
4852
|
-
"area": area,
|
|
4853
|
-
"segmentation": segmentation,
|
|
4854
|
-
"iscrowd": 0,
|
|
4855
|
-
}
|
|
4856
|
-
)
|
|
4857
|
-
coco_ann_id += 1
|
|
4858
|
-
|
|
4859
|
-
elif mask_file is not None and metadata_format == "YOLO":
|
|
4860
|
-
# Create YOLO labels directory if needed
|
|
4861
|
-
labels_dir = os.path.join(
|
|
4862
|
-
os.path.dirname(output_images_dir), "labels"
|
|
4863
|
-
)
|
|
4864
|
-
os.makedirs(labels_dir, exist_ok=True)
|
|
4865
|
-
|
|
4866
|
-
# Generate YOLO annotation file
|
|
4867
|
-
yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
|
|
4868
|
-
from scipy import ndimage
|
|
4869
|
-
|
|
4870
|
-
with open(yolo_path, "w") as yolo_file:
|
|
4871
|
-
for class_id in np.unique(label_mask):
|
|
4872
|
-
if class_id == 0:
|
|
4873
|
-
continue
|
|
4874
|
-
|
|
4875
|
-
# Track class for classes.txt
|
|
4876
|
-
class_name = next(
|
|
4877
|
-
(k for k, v in class_to_id.items() if v == class_id),
|
|
4878
|
-
str(class_id),
|
|
4879
|
-
)
|
|
4880
|
-
stats["yolo_classes"].add(class_name)
|
|
4881
|
-
|
|
4882
|
-
class_mask = (label_mask == class_id).astype(np.uint8)
|
|
4883
|
-
labeled_array, num_features = ndimage.label(class_mask)
|
|
4884
|
-
|
|
4885
|
-
for instance_id in range(1, num_features + 1):
|
|
4886
|
-
instance_mask = labeled_array == instance_id
|
|
4887
|
-
coords = np.argwhere(instance_mask)
|
|
4888
|
-
|
|
4889
|
-
if len(coords) == 0:
|
|
4890
|
-
continue
|
|
4891
|
-
|
|
4892
|
-
ymin, xmin = coords.min(axis=0)
|
|
4893
|
-
ymax, xmax = coords.max(axis=0)
|
|
3596
|
+
# Create profile for label GeoTIFF
|
|
3597
|
+
label_profile = {
|
|
3598
|
+
"driver": "GTiff",
|
|
3599
|
+
"height": tile_size,
|
|
3600
|
+
"width": tile_size,
|
|
3601
|
+
"count": 1,
|
|
3602
|
+
"dtype": "uint8",
|
|
3603
|
+
"crs": src.crs,
|
|
3604
|
+
"transform": window_transform,
|
|
3605
|
+
}
|
|
4894
3606
|
|
|
4895
|
-
|
|
4896
|
-
|
|
4897
|
-
|
|
4898
|
-
|
|
4899
|
-
|
|
3607
|
+
# Export label as GeoTIFF
|
|
3608
|
+
label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
|
|
3609
|
+
try:
|
|
3610
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
3611
|
+
dst.write(label_mask.astype(np.uint8), 1)
|
|
4900
3612
|
|
|
4901
|
-
|
|
4902
|
-
|
|
4903
|
-
|
|
4904
|
-
|
|
4905
|
-
|
|
3613
|
+
if has_features:
|
|
3614
|
+
stats["tiles_with_features"] += 1
|
|
3615
|
+
except Exception as e:
|
|
3616
|
+
if not quiet:
|
|
3617
|
+
print(f"ERROR saving label GeoTIFF: {e}")
|
|
3618
|
+
stats["errors"] += 1
|
|
4906
3619
|
|
|
4907
3620
|
tile_index += 1
|
|
4908
3621
|
if tile_index >= max_tiles:
|
|
@@ -4914,179 +3627,6 @@ def _process_image_mask_pair(
|
|
|
4914
3627
|
return stats
|
|
4915
3628
|
|
|
4916
3629
|
|
|
4917
|
-
def display_training_tiles(
|
|
4918
|
-
output_dir,
|
|
4919
|
-
num_tiles=6,
|
|
4920
|
-
figsize=(18, 6),
|
|
4921
|
-
cmap="gray",
|
|
4922
|
-
save_path=None,
|
|
4923
|
-
):
|
|
4924
|
-
"""
|
|
4925
|
-
Display image and mask tile pairs from training data output.
|
|
4926
|
-
|
|
4927
|
-
Args:
|
|
4928
|
-
output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
|
|
4929
|
-
num_tiles (int): Number of tile pairs to display (default: 6)
|
|
4930
|
-
figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
|
|
4931
|
-
cmap (str): Colormap for mask display (default: 'gray')
|
|
4932
|
-
save_path (str, optional): If provided, save figure to this path instead of displaying
|
|
4933
|
-
|
|
4934
|
-
Returns:
|
|
4935
|
-
tuple: (fig, axes) matplotlib figure and axes objects
|
|
4936
|
-
|
|
4937
|
-
Example:
|
|
4938
|
-
>>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
|
|
4939
|
-
>>> # Or save to file
|
|
4940
|
-
>>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
|
|
4941
|
-
"""
|
|
4942
|
-
import matplotlib.pyplot as plt
|
|
4943
|
-
|
|
4944
|
-
# Get list of image tiles
|
|
4945
|
-
images_dir = os.path.join(output_dir, "images")
|
|
4946
|
-
if not os.path.exists(images_dir):
|
|
4947
|
-
raise ValueError(f"Images directory not found: {images_dir}")
|
|
4948
|
-
|
|
4949
|
-
image_tiles = sorted(os.listdir(images_dir))[:num_tiles]
|
|
4950
|
-
|
|
4951
|
-
if not image_tiles:
|
|
4952
|
-
raise ValueError(f"No image tiles found in {images_dir}")
|
|
4953
|
-
|
|
4954
|
-
# Limit to available tiles
|
|
4955
|
-
num_tiles = min(num_tiles, len(image_tiles))
|
|
4956
|
-
|
|
4957
|
-
# Create figure with subplots
|
|
4958
|
-
fig, axes = plt.subplots(2, num_tiles, figsize=figsize)
|
|
4959
|
-
|
|
4960
|
-
# Handle case where num_tiles is 1
|
|
4961
|
-
if num_tiles == 1:
|
|
4962
|
-
axes = axes.reshape(2, 1)
|
|
4963
|
-
|
|
4964
|
-
for idx, tile_name in enumerate(image_tiles):
|
|
4965
|
-
# Load and display image tile
|
|
4966
|
-
image_path = os.path.join(output_dir, "images", tile_name)
|
|
4967
|
-
with rasterio.open(image_path) as src:
|
|
4968
|
-
show(src, ax=axes[0, idx], title=f"Image {idx+1}")
|
|
4969
|
-
|
|
4970
|
-
# Load and display mask tile
|
|
4971
|
-
mask_path = os.path.join(output_dir, "masks", tile_name)
|
|
4972
|
-
if os.path.exists(mask_path):
|
|
4973
|
-
with rasterio.open(mask_path) as src:
|
|
4974
|
-
show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
|
|
4975
|
-
else:
|
|
4976
|
-
axes[1, idx].text(
|
|
4977
|
-
0.5,
|
|
4978
|
-
0.5,
|
|
4979
|
-
"Mask not found",
|
|
4980
|
-
ha="center",
|
|
4981
|
-
va="center",
|
|
4982
|
-
transform=axes[1, idx].transAxes,
|
|
4983
|
-
)
|
|
4984
|
-
axes[1, idx].set_title(f"Mask {idx+1}")
|
|
4985
|
-
|
|
4986
|
-
plt.tight_layout()
|
|
4987
|
-
|
|
4988
|
-
# Save or show
|
|
4989
|
-
if save_path:
|
|
4990
|
-
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
|
4991
|
-
plt.close(fig)
|
|
4992
|
-
print(f"Figure saved to: {save_path}")
|
|
4993
|
-
else:
|
|
4994
|
-
plt.show()
|
|
4995
|
-
|
|
4996
|
-
return fig, axes
|
|
4997
|
-
|
|
4998
|
-
|
|
4999
|
-
def display_image_with_vector(
|
|
5000
|
-
image_path,
|
|
5001
|
-
vector_path,
|
|
5002
|
-
figsize=(16, 8),
|
|
5003
|
-
vector_color="red",
|
|
5004
|
-
vector_linewidth=1,
|
|
5005
|
-
vector_facecolor="none",
|
|
5006
|
-
save_path=None,
|
|
5007
|
-
):
|
|
5008
|
-
"""
|
|
5009
|
-
Display a raster image alongside the same image with vector overlay.
|
|
5010
|
-
|
|
5011
|
-
Args:
|
|
5012
|
-
image_path (str): Path to raster image file
|
|
5013
|
-
vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
|
|
5014
|
-
figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
|
|
5015
|
-
vector_color (str): Edge color for vector features (default: 'red')
|
|
5016
|
-
vector_linewidth (float): Line width for vector features (default: 1)
|
|
5017
|
-
vector_facecolor (str): Fill color for vector features (default: 'none')
|
|
5018
|
-
save_path (str, optional): If provided, save figure to this path instead of displaying
|
|
5019
|
-
|
|
5020
|
-
Returns:
|
|
5021
|
-
tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata
|
|
5022
|
-
|
|
5023
|
-
Example:
|
|
5024
|
-
>>> fig, axes, info = display_image_with_vector(
|
|
5025
|
-
... 'image.tif',
|
|
5026
|
-
... 'buildings.geojson',
|
|
5027
|
-
... vector_color='blue'
|
|
5028
|
-
... )
|
|
5029
|
-
>>> print(f"Number of features: {info['num_features']}")
|
|
5030
|
-
"""
|
|
5031
|
-
import matplotlib.pyplot as plt
|
|
5032
|
-
|
|
5033
|
-
# Validate inputs
|
|
5034
|
-
if not os.path.exists(image_path):
|
|
5035
|
-
raise ValueError(f"Image file not found: {image_path}")
|
|
5036
|
-
if not os.path.exists(vector_path):
|
|
5037
|
-
raise ValueError(f"Vector file not found: {vector_path}")
|
|
5038
|
-
|
|
5039
|
-
# Create figure
|
|
5040
|
-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
|
|
5041
|
-
|
|
5042
|
-
# Load and display image
|
|
5043
|
-
with rasterio.open(image_path) as src:
|
|
5044
|
-
# Plot image only
|
|
5045
|
-
show(src, ax=ax1, title="Image")
|
|
5046
|
-
|
|
5047
|
-
# Load vector data
|
|
5048
|
-
vector_data = gpd.read_file(vector_path)
|
|
5049
|
-
|
|
5050
|
-
# Reproject to image CRS if needed
|
|
5051
|
-
if vector_data.crs != src.crs:
|
|
5052
|
-
vector_data = vector_data.to_crs(src.crs)
|
|
5053
|
-
|
|
5054
|
-
# Plot image with vector overlay
|
|
5055
|
-
show(
|
|
5056
|
-
src,
|
|
5057
|
-
ax=ax2,
|
|
5058
|
-
title=f"Image with {len(vector_data)} Vector Features",
|
|
5059
|
-
)
|
|
5060
|
-
vector_data.plot(
|
|
5061
|
-
ax=ax2,
|
|
5062
|
-
facecolor=vector_facecolor,
|
|
5063
|
-
edgecolor=vector_color,
|
|
5064
|
-
linewidth=vector_linewidth,
|
|
5065
|
-
)
|
|
5066
|
-
|
|
5067
|
-
# Collect metadata
|
|
5068
|
-
info = {
|
|
5069
|
-
"image_shape": src.shape,
|
|
5070
|
-
"image_crs": src.crs,
|
|
5071
|
-
"image_bounds": src.bounds,
|
|
5072
|
-
"num_features": len(vector_data),
|
|
5073
|
-
"vector_crs": vector_data.crs,
|
|
5074
|
-
"vector_bounds": vector_data.total_bounds,
|
|
5075
|
-
}
|
|
5076
|
-
|
|
5077
|
-
plt.tight_layout()
|
|
5078
|
-
|
|
5079
|
-
# Save or show
|
|
5080
|
-
if save_path:
|
|
5081
|
-
plt.savefig(save_path, dpi=150, bbox_inches="tight")
|
|
5082
|
-
plt.close(fig)
|
|
5083
|
-
print(f"Figure saved to: {save_path}")
|
|
5084
|
-
else:
|
|
5085
|
-
plt.show()
|
|
5086
|
-
|
|
5087
|
-
return fig, (ax1, ax2), info
|
|
5088
|
-
|
|
5089
|
-
|
|
5090
3630
|
def create_overview_image(
|
|
5091
3631
|
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
|
5092
3632
|
) -> str:
|
|
@@ -8981,39 +7521,17 @@ def write_colormap(
|
|
|
8981
7521
|
|
|
8982
7522
|
def plot_performance_metrics(
|
|
8983
7523
|
history_path: str,
|
|
8984
|
-
figsize:
|
|
7524
|
+
figsize: Tuple[int, int] = (15, 5),
|
|
8985
7525
|
verbose: bool = True,
|
|
8986
7526
|
save_path: Optional[str] = None,
|
|
8987
|
-
csv_path: Optional[str] = None,
|
|
8988
7527
|
kwargs: Optional[Dict] = None,
|
|
8989
|
-
) ->
|
|
8990
|
-
"""Plot performance metrics from a
|
|
8991
|
-
|
|
8992
|
-
This function loads training history, plots available metrics (loss, IoU, F1,
|
|
8993
|
-
precision, recall), optionally exports to CSV, and returns all metrics as a
|
|
8994
|
-
pandas DataFrame for further analysis.
|
|
7528
|
+
) -> None:
|
|
7529
|
+
"""Plot performance metrics from a history object.
|
|
8995
7530
|
|
|
8996
7531
|
Args:
|
|
8997
|
-
history_path
|
|
8998
|
-
figsize
|
|
8999
|
-
|
|
9000
|
-
verbose (bool): Whether to print best and final metric values. Defaults to True.
|
|
9001
|
-
save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
|
|
9002
|
-
csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
|
|
9003
|
-
kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
|
|
9004
|
-
|
|
9005
|
-
Returns:
|
|
9006
|
-
pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
|
|
9007
|
-
Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
|
|
9008
|
-
'val_precision', 'val_recall' (depending on availability in history).
|
|
9009
|
-
|
|
9010
|
-
Example:
|
|
9011
|
-
>>> df = plot_performance_metrics(
|
|
9012
|
-
... 'training_history.pth',
|
|
9013
|
-
... save_path='metrics_plot.png',
|
|
9014
|
-
... csv_path='metrics.csv'
|
|
9015
|
-
... )
|
|
9016
|
-
>>> print(df.head())
|
|
7532
|
+
history_path: The history object to plot.
|
|
7533
|
+
figsize: The figure size.
|
|
7534
|
+
verbose: Whether to print the best and final metrics.
|
|
9017
7535
|
"""
|
|
9018
7536
|
if kwargs is None:
|
|
9019
7537
|
kwargs = {}
|
|
@@ -9023,135 +7541,65 @@ def plot_performance_metrics(
|
|
|
9023
7541
|
train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
|
|
9024
7542
|
val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
|
|
9025
7543
|
val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
|
|
9026
|
-
|
|
9027
|
-
val_f1_key = (
|
|
9028
|
-
"val_f1s"
|
|
9029
|
-
if "val_f1s" in history
|
|
9030
|
-
else ("val_dices" if "val_dices" in history else "val_dice")
|
|
9031
|
-
)
|
|
9032
|
-
# Add support for precision and recall
|
|
9033
|
-
val_precision_key = (
|
|
9034
|
-
"val_precisions" if "val_precisions" in history else "val_precision"
|
|
9035
|
-
)
|
|
9036
|
-
val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
|
|
9037
|
-
|
|
9038
|
-
# Collect available metrics for plotting
|
|
9039
|
-
available_metrics = []
|
|
9040
|
-
metric_info = {
|
|
9041
|
-
"Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
|
|
9042
|
-
"IoU": (val_iou_key, None, ["Val IoU"]),
|
|
9043
|
-
"F1": (val_f1_key, None, ["Val F1"]),
|
|
9044
|
-
"Precision": (val_precision_key, None, ["Val Precision"]),
|
|
9045
|
-
"Recall": (val_recall_key, None, ["Val Recall"]),
|
|
9046
|
-
}
|
|
7544
|
+
val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
|
|
9047
7545
|
|
|
9048
|
-
|
|
9049
|
-
|
|
9050
|
-
|
|
7546
|
+
# Determine number of subplots based on available metrics
|
|
7547
|
+
has_dice = val_dice_key in history
|
|
7548
|
+
n_plots = 3 if has_dice else 2
|
|
7549
|
+
figsize = (15, 5) if has_dice else (10, 5)
|
|
9051
7550
|
|
|
9052
|
-
|
|
9053
|
-
n_plots = len(available_metrics)
|
|
9054
|
-
if figsize is None:
|
|
9055
|
-
figsize = (5 * n_plots, 5)
|
|
7551
|
+
plt.figure(figsize=figsize)
|
|
9056
7552
|
|
|
9057
|
-
#
|
|
9058
|
-
|
|
9059
|
-
df_data = {}
|
|
9060
|
-
|
|
9061
|
-
# Add epochs
|
|
9062
|
-
if "epochs" in history:
|
|
9063
|
-
df_data["epoch"] = history["epochs"]
|
|
9064
|
-
n_epochs = len(history["epochs"])
|
|
9065
|
-
elif train_loss_key in history:
|
|
9066
|
-
n_epochs = len(history[train_loss_key])
|
|
9067
|
-
df_data["epoch"] = list(range(1, n_epochs + 1))
|
|
9068
|
-
|
|
9069
|
-
# Add all available metrics to DataFrame
|
|
7553
|
+
# Plot loss
|
|
7554
|
+
plt.subplot(1, n_plots, 1)
|
|
9070
7555
|
if train_loss_key in history:
|
|
9071
|
-
|
|
7556
|
+
plt.plot(history[train_loss_key], label="Train Loss")
|
|
9072
7557
|
if val_loss_key in history:
|
|
9073
|
-
|
|
7558
|
+
plt.plot(history[val_loss_key], label="Val Loss")
|
|
7559
|
+
plt.title("Loss")
|
|
7560
|
+
plt.xlabel("Epoch")
|
|
7561
|
+
plt.ylabel("Loss")
|
|
7562
|
+
plt.legend()
|
|
7563
|
+
plt.grid(True)
|
|
7564
|
+
|
|
7565
|
+
# Plot IoU
|
|
7566
|
+
plt.subplot(1, n_plots, 2)
|
|
9074
7567
|
if val_iou_key in history:
|
|
9075
|
-
|
|
9076
|
-
|
|
9077
|
-
|
|
9078
|
-
|
|
9079
|
-
|
|
9080
|
-
|
|
9081
|
-
|
|
9082
|
-
|
|
9083
|
-
|
|
9084
|
-
|
|
9085
|
-
|
|
9086
|
-
|
|
9087
|
-
|
|
9088
|
-
|
|
9089
|
-
|
|
9090
|
-
|
|
9091
|
-
|
|
9092
|
-
# Create plots
|
|
9093
|
-
if n_plots > 0:
|
|
9094
|
-
fig, axes = plt.subplots(1, n_plots, figsize=figsize)
|
|
9095
|
-
if n_plots == 1:
|
|
9096
|
-
axes = [axes]
|
|
9097
|
-
|
|
9098
|
-
for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
|
|
9099
|
-
ax = axes[idx]
|
|
9100
|
-
|
|
9101
|
-
if metric_name == "Loss":
|
|
9102
|
-
# Special handling for loss (has both train and val)
|
|
9103
|
-
if key1 in history:
|
|
9104
|
-
ax.plot(history[key1], label=labels[0])
|
|
9105
|
-
if key2 and key2 in history:
|
|
9106
|
-
ax.plot(history[key2], label=labels[1])
|
|
9107
|
-
else:
|
|
9108
|
-
# Single metric plots
|
|
9109
|
-
if key1 in history:
|
|
9110
|
-
ax.plot(history[key1], label=labels[0])
|
|
9111
|
-
|
|
9112
|
-
ax.set_title(metric_name)
|
|
9113
|
-
ax.set_xlabel("Epoch")
|
|
9114
|
-
ax.set_ylabel(metric_name)
|
|
9115
|
-
ax.legend()
|
|
9116
|
-
ax.grid(True)
|
|
7568
|
+
plt.plot(history[val_iou_key], label="Val IoU")
|
|
7569
|
+
plt.title("IoU Score")
|
|
7570
|
+
plt.xlabel("Epoch")
|
|
7571
|
+
plt.ylabel("IoU")
|
|
7572
|
+
plt.legend()
|
|
7573
|
+
plt.grid(True)
|
|
7574
|
+
|
|
7575
|
+
# Plot Dice if available
|
|
7576
|
+
if has_dice:
|
|
7577
|
+
plt.subplot(1, n_plots, 3)
|
|
7578
|
+
plt.plot(history[val_dice_key], label="Val Dice")
|
|
7579
|
+
plt.title("Dice Score")
|
|
7580
|
+
plt.xlabel("Epoch")
|
|
7581
|
+
plt.ylabel("Dice")
|
|
7582
|
+
plt.legend()
|
|
7583
|
+
plt.grid(True)
|
|
9117
7584
|
|
|
9118
|
-
|
|
7585
|
+
plt.tight_layout()
|
|
9119
7586
|
|
|
9120
|
-
|
|
9121
|
-
|
|
9122
|
-
|
|
9123
|
-
|
|
9124
|
-
|
|
9125
|
-
|
|
7587
|
+
if save_path:
|
|
7588
|
+
if "dpi" not in kwargs:
|
|
7589
|
+
kwargs["dpi"] = 150
|
|
7590
|
+
if "bbox_inches" not in kwargs:
|
|
7591
|
+
kwargs["bbox_inches"] = "tight"
|
|
7592
|
+
plt.savefig(save_path, **kwargs)
|
|
9126
7593
|
|
|
9127
|
-
|
|
7594
|
+
plt.show()
|
|
9128
7595
|
|
|
9129
|
-
# Print summary statistics
|
|
9130
7596
|
if verbose:
|
|
9131
|
-
print("\n=== Performance Metrics Summary ===")
|
|
9132
7597
|
if val_iou_key in history:
|
|
9133
|
-
print(
|
|
9134
|
-
|
|
9135
|
-
|
|
9136
|
-
|
|
9137
|
-
print(
|
|
9138
|
-
f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
|
|
9139
|
-
)
|
|
9140
|
-
if val_precision_key in history:
|
|
9141
|
-
print(
|
|
9142
|
-
f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
|
|
9143
|
-
)
|
|
9144
|
-
if val_recall_key in history:
|
|
9145
|
-
print(
|
|
9146
|
-
f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
|
|
9147
|
-
)
|
|
9148
|
-
if val_loss_key in history:
|
|
9149
|
-
print(
|
|
9150
|
-
f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
|
|
9151
|
-
)
|
|
9152
|
-
print("===================================\n")
|
|
9153
|
-
|
|
9154
|
-
return df
|
|
7598
|
+
print(f"Best IoU: {max(history[val_iou_key]):.4f}")
|
|
7599
|
+
print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
|
|
7600
|
+
if val_dice_key in history:
|
|
7601
|
+
print(f"Best Dice: {max(history[val_dice_key]):.4f}")
|
|
7602
|
+
print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
|
|
9155
7603
|
|
|
9156
7604
|
|
|
9157
7605
|
def get_device() -> torch.device:
|