geoai-py 0.18.1__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -64,7 +64,7 @@ def view_raster(
64
64
  client_args: Optional[Dict] = {"cors_all": False},
65
65
  basemap: Optional[str] = "OpenStreetMap",
66
66
  basemap_args: Optional[Dict] = None,
67
- backend: Optional[str] = "folium",
67
+ backend: Optional[str] = "ipyleaflet",
68
68
  **kwargs: Any,
69
69
  ) -> Any:
70
70
  """
@@ -87,7 +87,7 @@ def view_raster(
87
87
  client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
88
88
  basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
89
89
  basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
90
- backend (Optional[str], optional): The backend to use. Defaults to "folium".
90
+ backend (Optional[str], optional): The backend to use. Defaults to "ipyleaflet".
91
91
  **kwargs (Any): Additional keyword arguments.
92
92
 
93
93
  Returns:
@@ -123,39 +123,26 @@ def view_raster(
123
123
  if isinstance(source, dict):
124
124
  source = dict_to_image(source)
125
125
 
126
- if isinstance(source, str) and source.startswith("http"):
127
- if backend == "folium":
128
-
129
- m.add_geotiff(
130
- url=source,
131
- name=layer_name,
132
- opacity=opacity,
133
- attribution=attribution,
134
- fit_bounds=zoom_to_layer,
135
- palette=colormap,
136
- vmin=vmin,
137
- vmax=vmax,
138
- **kwargs,
139
- )
140
- m.add_layer_control()
141
- m.add_opacity_control()
142
-
143
- else:
144
- if indexes is not None:
145
- kwargs["bidx"] = indexes
146
- if colormap is not None:
147
- kwargs["colormap_name"] = colormap
148
- if attribution is None:
149
- attribution = "TiTiler"
150
-
151
- m.add_cog_layer(
152
- source,
153
- name=layer_name,
154
- opacity=opacity,
155
- attribution=attribution,
156
- zoom_to_layer=zoom_to_layer,
157
- **kwargs,
158
- )
126
+ if (
127
+ isinstance(source, str)
128
+ and source.lower().endswith(".tif")
129
+ and source.startswith("http")
130
+ ):
131
+ if indexes is not None:
132
+ kwargs["bidx"] = indexes
133
+ if colormap is not None:
134
+ kwargs["colormap_name"] = colormap
135
+ if attribution is None:
136
+ attribution = "TiTiler"
137
+
138
+ m.add_cog_layer(
139
+ source,
140
+ name=layer_name,
141
+ opacity=opacity,
142
+ attribution=attribution,
143
+ zoom_to_layer=zoom_to_layer,
144
+ **kwargs,
145
+ )
159
146
  else:
160
147
  m.add_raster(
161
148
  source=source,
@@ -237,8 +224,6 @@ def view_image(
237
224
  plt.show()
238
225
  plt.close()
239
226
 
240
- return ax
241
-
242
227
 
243
228
  def plot_images(
244
229
  images: Iterable[torch.Tensor],
@@ -396,394 +381,6 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
396
381
  return accum_mean / len(files), accum_std / len(files)
397
382
 
398
383
 
399
- def calc_iou(
400
- ground_truth: Union[str, np.ndarray, torch.Tensor],
401
- prediction: Union[str, np.ndarray, torch.Tensor],
402
- num_classes: Optional[int] = None,
403
- ignore_index: Optional[int] = None,
404
- smooth: float = 1e-6,
405
- band: int = 1,
406
- ) -> Union[float, np.ndarray]:
407
- """
408
- Calculate Intersection over Union (IoU) between ground truth and prediction masks.
409
-
410
- This function computes the IoU metric for segmentation tasks. It supports both
411
- binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
412
- or file paths to raster files.
413
-
414
- Args:
415
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
416
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
417
- For binary segmentation: shape (H, W) with values {0, 1}.
418
- For multi-class segmentation: shape (H, W) with class indices.
419
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
420
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
421
- Should have the same shape and format as ground_truth.
422
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
423
- If None, assumes binary segmentation. Defaults to None.
424
- ignore_index (Optional[int], optional): Class index to ignore in computation.
425
- Useful for ignoring background or unlabeled pixels. Defaults to None.
426
- smooth (float, optional): Smoothing factor to avoid division by zero.
427
- Defaults to 1e-6.
428
- band (int, optional): Band index to read from raster file (1-based indexing).
429
- Only used when input is a file path. Defaults to 1.
430
-
431
- Returns:
432
- Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
433
- For multi-class segmentation, returns an array of IoU scores for each class.
434
-
435
- Examples:
436
- >>> # Binary segmentation with arrays
437
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
438
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
439
- >>> iou = calc_iou(gt, pred)
440
- >>> print(f"IoU: {iou:.4f}")
441
- IoU: 0.8333
442
-
443
- >>> # Multi-class segmentation
444
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
445
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
446
- >>> iou = calc_iou(gt, pred, num_classes=3)
447
- >>> print(f"IoU per class: {iou}")
448
- IoU per class: [0.8333 0.5000 0.5000]
449
-
450
- >>> # Using PyTorch tensors
451
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
452
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
453
- >>> iou = calc_iou(gt_tensor, pred_tensor)
454
- >>> print(f"IoU: {iou:.4f}")
455
- IoU: 0.8333
456
-
457
- >>> # Using raster file paths
458
- >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
459
- >>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
460
- Mean IoU: 0.7500
461
- """
462
- # Load from file if string path is provided
463
- if isinstance(ground_truth, str):
464
- with rasterio.open(ground_truth) as src:
465
- ground_truth = src.read(band)
466
- if isinstance(prediction, str):
467
- with rasterio.open(prediction) as src:
468
- prediction = src.read(band)
469
-
470
- # Convert to numpy if torch tensor
471
- if isinstance(ground_truth, torch.Tensor):
472
- ground_truth = ground_truth.cpu().numpy()
473
- if isinstance(prediction, torch.Tensor):
474
- prediction = prediction.cpu().numpy()
475
-
476
- # Ensure inputs have the same shape
477
- if ground_truth.shape != prediction.shape:
478
- raise ValueError(
479
- f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
480
- )
481
-
482
- # Binary segmentation
483
- if num_classes is None:
484
- ground_truth = ground_truth.astype(bool)
485
- prediction = prediction.astype(bool)
486
-
487
- intersection = np.logical_and(ground_truth, prediction).sum()
488
- union = np.logical_or(ground_truth, prediction).sum()
489
-
490
- if union == 0:
491
- return 1.0 if intersection == 0 else 0.0
492
-
493
- iou = (intersection + smooth) / (union + smooth)
494
- return float(iou)
495
-
496
- # Multi-class segmentation
497
- else:
498
- iou_per_class = []
499
-
500
- for class_idx in range(num_classes):
501
- # Handle ignored class by appending np.nan
502
- if ignore_index is not None and class_idx == ignore_index:
503
- iou_per_class.append(np.nan)
504
- continue
505
-
506
- # Create binary masks for current class
507
- gt_class = (ground_truth == class_idx).astype(bool)
508
- pred_class = (prediction == class_idx).astype(bool)
509
-
510
- intersection = np.logical_and(gt_class, pred_class).sum()
511
- union = np.logical_or(gt_class, pred_class).sum()
512
-
513
- if union == 0:
514
- # If class is not present in both gt and pred
515
- iou_per_class.append(np.nan)
516
- else:
517
- iou_per_class.append((intersection + smooth) / (union + smooth))
518
-
519
- return np.array(iou_per_class)
520
-
521
-
522
- def calc_f1_score(
523
- ground_truth: Union[str, np.ndarray, torch.Tensor],
524
- prediction: Union[str, np.ndarray, torch.Tensor],
525
- num_classes: Optional[int] = None,
526
- ignore_index: Optional[int] = None,
527
- smooth: float = 1e-6,
528
- band: int = 1,
529
- ) -> Union[float, np.ndarray]:
530
- """
531
- Calculate F1 score between ground truth and prediction masks.
532
-
533
- The F1 score is the harmonic mean of precision and recall, computed as:
534
- F1 = 2 * (precision * recall) / (precision + recall)
535
- where precision = TP / (TP + FP) and recall = TP / (TP + FN).
536
-
537
- This function supports both binary and multi-class segmentation, and can handle
538
- numpy arrays, PyTorch tensors, or file paths to raster files.
539
-
540
- Args:
541
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
542
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
543
- For binary segmentation: shape (H, W) with values {0, 1}.
544
- For multi-class segmentation: shape (H, W) with class indices.
545
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
546
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
547
- Should have the same shape and format as ground_truth.
548
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
549
- If None, assumes binary segmentation. Defaults to None.
550
- ignore_index (Optional[int], optional): Class index to ignore in computation.
551
- Useful for ignoring background or unlabeled pixels. Defaults to None.
552
- smooth (float, optional): Smoothing factor to avoid division by zero.
553
- Defaults to 1e-6.
554
- band (int, optional): Band index to read from raster file (1-based indexing).
555
- Only used when input is a file path. Defaults to 1.
556
-
557
- Returns:
558
- Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
559
- For multi-class segmentation, returns an array of F1 scores for each class.
560
-
561
- Examples:
562
- >>> # Binary segmentation with arrays
563
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
564
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
565
- >>> f1 = calc_f1_score(gt, pred)
566
- >>> print(f"F1 Score: {f1:.4f}")
567
- F1 Score: 0.8571
568
-
569
- >>> # Multi-class segmentation
570
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
571
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
572
- >>> f1 = calc_f1_score(gt, pred, num_classes=3)
573
- >>> print(f"F1 Score per class: {f1}")
574
- F1 Score per class: [0.8571 0.6667 0.6667]
575
-
576
- >>> # Using PyTorch tensors
577
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
578
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
579
- >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
580
- >>> print(f"F1 Score: {f1:.4f}")
581
- F1 Score: 0.8571
582
-
583
- >>> # Using raster file paths
584
- >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
585
- >>> print(f"Mean F1: {np.nanmean(f1):.4f}")
586
- Mean F1: 0.7302
587
- """
588
- # Load from file if string path is provided
589
- if isinstance(ground_truth, str):
590
- with rasterio.open(ground_truth) as src:
591
- ground_truth = src.read(band)
592
- if isinstance(prediction, str):
593
- with rasterio.open(prediction) as src:
594
- prediction = src.read(band)
595
-
596
- # Convert to numpy if torch tensor
597
- if isinstance(ground_truth, torch.Tensor):
598
- ground_truth = ground_truth.cpu().numpy()
599
- if isinstance(prediction, torch.Tensor):
600
- prediction = prediction.cpu().numpy()
601
-
602
- # Ensure inputs have the same shape
603
- if ground_truth.shape != prediction.shape:
604
- raise ValueError(
605
- f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
606
- )
607
-
608
- # Binary segmentation
609
- if num_classes is None:
610
- ground_truth = ground_truth.astype(bool)
611
- prediction = prediction.astype(bool)
612
-
613
- # Calculate True Positives, False Positives, False Negatives
614
- tp = np.logical_and(ground_truth, prediction).sum()
615
- fp = np.logical_and(~ground_truth, prediction).sum()
616
- fn = np.logical_and(ground_truth, ~prediction).sum()
617
-
618
- # Calculate precision and recall
619
- precision = (tp + smooth) / (tp + fp + smooth)
620
- recall = (tp + smooth) / (tp + fn + smooth)
621
-
622
- # Calculate F1 score
623
- f1 = 2 * (precision * recall) / (precision + recall + smooth)
624
- return float(f1)
625
-
626
- # Multi-class segmentation
627
- else:
628
- f1_per_class = []
629
-
630
- for class_idx in range(num_classes):
631
- # Mark ignored class with np.nan
632
- if ignore_index is not None and class_idx == ignore_index:
633
- f1_per_class.append(np.nan)
634
- continue
635
-
636
- # Create binary masks for current class
637
- gt_class = (ground_truth == class_idx).astype(bool)
638
- pred_class = (prediction == class_idx).astype(bool)
639
-
640
- # Calculate True Positives, False Positives, False Negatives
641
- tp = np.logical_and(gt_class, pred_class).sum()
642
- fp = np.logical_and(~gt_class, pred_class).sum()
643
- fn = np.logical_and(gt_class, ~pred_class).sum()
644
-
645
- # Calculate precision and recall
646
- precision = (tp + smooth) / (tp + fp + smooth)
647
- recall = (tp + smooth) / (tp + fn + smooth)
648
-
649
- # Calculate F1 score
650
- if tp + fp + fn == 0:
651
- # If class is not present in both gt and pred
652
- f1_per_class.append(np.nan)
653
- else:
654
- f1 = 2 * (precision * recall) / (precision + recall + smooth)
655
- f1_per_class.append(f1)
656
-
657
- return np.array(f1_per_class)
658
-
659
-
660
- def calc_segmentation_metrics(
661
- ground_truth: Union[str, np.ndarray, torch.Tensor],
662
- prediction: Union[str, np.ndarray, torch.Tensor],
663
- num_classes: Optional[int] = None,
664
- ignore_index: Optional[int] = None,
665
- smooth: float = 1e-6,
666
- metrics: List[str] = ["iou", "f1"],
667
- band: int = 1,
668
- ) -> Dict[str, Union[float, np.ndarray]]:
669
- """
670
- Calculate multiple segmentation metrics between ground truth and prediction masks.
671
-
672
- This is a convenient wrapper function that computes multiple metrics at once,
673
- including IoU (Intersection over Union) and F1 score. It supports both binary
674
- and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
675
- or file paths to raster files.
676
-
677
- Args:
678
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
679
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
680
- For binary segmentation: shape (H, W) with values {0, 1}.
681
- For multi-class segmentation: shape (H, W) with class indices.
682
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
683
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
684
- Should have the same shape and format as ground_truth.
685
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
686
- If None, assumes binary segmentation. Defaults to None.
687
- ignore_index (Optional[int], optional): Class index to ignore in computation.
688
- Useful for ignoring background or unlabeled pixels. Defaults to None.
689
- smooth (float, optional): Smoothing factor to avoid division by zero.
690
- Defaults to 1e-6.
691
- metrics (List[str], optional): List of metrics to calculate.
692
- Options: "iou", "f1". Defaults to ["iou", "f1"].
693
- band (int, optional): Band index to read from raster file (1-based indexing).
694
- Only used when input is a file path. Defaults to 1.
695
-
696
- Returns:
697
- Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
698
- Keys are metric names ("iou", "f1"), values are the metric scores.
699
- For binary segmentation, values are floats.
700
- For multi-class segmentation, values are numpy arrays with per-class scores.
701
- Also includes "mean_iou" and "mean_f1" for multi-class segmentation
702
- (mean computed over valid classes, ignoring NaN values).
703
-
704
- Examples:
705
- >>> # Binary segmentation with arrays
706
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
707
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
708
- >>> metrics = calc_segmentation_metrics(gt, pred)
709
- >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
710
- IoU: 0.8333, F1: 0.8571
711
-
712
- >>> # Multi-class segmentation
713
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
714
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
715
- >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
716
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
717
- >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
718
- >>> print(f"Per-class IoU: {metrics['iou']}")
719
- Mean IoU: 0.6111
720
- Mean F1: 0.7302
721
- Per-class IoU: [0.8333 0.5000 0.5000]
722
-
723
- >>> # Calculate only IoU
724
- >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
725
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
726
- Mean IoU: 0.6111
727
-
728
- >>> # Using PyTorch tensors
729
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
730
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
731
- >>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
732
- >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
733
- IoU: 0.8333, F1: 0.8571
734
-
735
- >>> # Using raster file paths
736
- >>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
737
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
738
- >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
739
- Mean IoU: 0.6111
740
- Mean F1: 0.7302
741
- """
742
- results = {}
743
-
744
- # Calculate IoU if requested
745
- if "iou" in metrics:
746
- iou = calc_iou(
747
- ground_truth,
748
- prediction,
749
- num_classes=num_classes,
750
- ignore_index=ignore_index,
751
- smooth=smooth,
752
- band=band,
753
- )
754
- results["iou"] = iou
755
-
756
- # Add mean IoU for multi-class
757
- if num_classes is not None and isinstance(iou, np.ndarray):
758
- # Calculate mean ignoring NaN values
759
- valid_ious = iou[~np.isnan(iou)]
760
- results["mean_iou"] = (
761
- float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
762
- )
763
-
764
- # Calculate F1 score if requested
765
- if "f1" in metrics:
766
- f1 = calc_f1_score(
767
- ground_truth,
768
- prediction,
769
- num_classes=num_classes,
770
- ignore_index=ignore_index,
771
- smooth=smooth,
772
- band=band,
773
- )
774
- results["f1"] = f1
775
-
776
- # Add mean F1 for multi-class
777
- if num_classes is not None and isinstance(f1, np.ndarray):
778
- # Calculate mean ignoring NaN values
779
- valid_f1s = f1[~np.isnan(f1)]
780
- results["mean_f1"] = (
781
- float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
782
- )
783
-
784
- return results
785
-
786
-
787
384
  def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
788
385
  """Convert a dictionary to a xarray DataArray. The dictionary should contain the
789
386
  following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
@@ -1094,9 +691,8 @@ def view_vector(
1094
691
 
1095
692
  def view_vector_interactive(
1096
693
  vector_data: Union[str, gpd.GeoDataFrame],
1097
- layer_name: str = "Vector",
694
+ layer_name: str = "Vector Layer",
1098
695
  tiles_args: Optional[Dict] = None,
1099
- opacity: float = 0.7,
1100
696
  **kwargs: Any,
1101
697
  ) -> Any:
1102
698
  """
@@ -1111,7 +707,6 @@ def view_vector_interactive(
1111
707
  layer_name (str, optional): The name of the layer. Defaults to "Vector Layer".
1112
708
  tiles_args (dict, optional): Additional arguments for the localtileserver client.
1113
709
  get_folium_tile_layer function. Defaults to None.
1114
- opacity (float, optional): The opacity of the layer. Defaults to 0.7.
1115
710
  **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
1116
711
  See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
1117
712
 
@@ -1126,8 +721,9 @@ def view_vector_interactive(
1126
721
  >>> roads = gpd.read_file("roads.shp")
1127
722
  >>> view_vector_interactive(roads, figsize=(12, 8))
1128
723
  """
1129
-
1130
- from leafmap.foliumap import Map
724
+ import folium
725
+ import folium.plugins as plugins
726
+ from leafmap import cog_tile
1131
727
  from localtileserver import TileClient, get_folium_tile_layer
1132
728
 
1133
729
  google_tiles = {
@@ -1153,17 +749,9 @@ def view_vector_interactive(
1153
749
  },
1154
750
  }
1155
751
 
1156
- # Make it compatible with binder and JupyterHub
1157
- if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
1158
- os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
1159
- f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
1160
- )
1161
-
1162
752
  basemap_layer_name = None
1163
753
  raster_layer = None
1164
754
 
1165
- m = Map()
1166
-
1167
755
  if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
1168
756
  if kwargs["tiles"].title() in google_tiles:
1169
757
  basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
@@ -1174,17 +762,14 @@ def view_vector_interactive(
1174
762
  tiles_args = {}
1175
763
  if kwargs["tiles"].lower().startswith("http"):
1176
764
  basemap_layer_name = "Remote Raster"
1177
- m.add_geotiff(kwargs["tiles"], name=basemap_layer_name, **tiles_args)
765
+ kwargs["tiles"] = cog_tile(kwargs["tiles"], **tiles_args)
766
+ kwargs["attr"] = "TiTiler"
1178
767
  else:
1179
768
  basemap_layer_name = "Local Raster"
1180
769
  client = TileClient(kwargs["tiles"])
1181
770
  raster_layer = get_folium_tile_layer(client, **tiles_args)
1182
- m.add_tile_layer(
1183
- raster_layer.tiles,
1184
- name=basemap_layer_name,
1185
- attribution="localtileserver",
1186
- **tiles_args,
1187
- )
771
+ kwargs["tiles"] = raster_layer.tiles
772
+ kwargs["attr"] = "localtileserver"
1188
773
 
1189
774
  if "max_zoom" not in kwargs:
1190
775
  kwargs["max_zoom"] = 30
@@ -1199,18 +784,23 @@ def view_vector_interactive(
1199
784
  if not isinstance(vector_data, gpd.GeoDataFrame):
1200
785
  raise TypeError("Input data must be a GeoDataFrame")
1201
786
 
1202
- if "column" in kwargs:
1203
- if "legend_position" not in kwargs:
1204
- kwargs["legend_position"] = "bottomleft"
1205
- if "cmap" not in kwargs:
1206
- kwargs["cmap"] = "viridis"
1207
- m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
787
+ layer_control = kwargs.pop("layer_control", True)
788
+ fullscreen_control = kwargs.pop("fullscreen_control", True)
1208
789
 
1209
- else:
1210
- m.add_gdf(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
790
+ m = vector_data.explore(**kwargs)
1211
791
 
1212
- m.add_layer_control()
1213
- m.add_opacity_control()
792
+ # Change the layer name
793
+ for layer in m._children.values():
794
+ if isinstance(layer, folium.GeoJson):
795
+ layer.layer_name = layer_name
796
+ if isinstance(layer, folium.TileLayer) and basemap_layer_name:
797
+ layer.layer_name = basemap_layer_name
798
+
799
+ if layer_control:
800
+ m.add_child(folium.LayerControl())
801
+
802
+ if fullscreen_control:
803
+ plugins.Fullscreen().add_to(m)
1214
804
 
1215
805
  return m
1216
806
 
@@ -3007,7 +2597,7 @@ def batch_vector_to_raster(
3007
2597
  def export_geotiff_tiles(
3008
2598
  in_raster,
3009
2599
  out_folder,
3010
- in_class_data=None,
2600
+ in_class_data,
3011
2601
  tile_size=256,
3012
2602
  stride=128,
3013
2603
  class_value_field="class",
@@ -3017,7 +2607,6 @@ def export_geotiff_tiles(
3017
2607
  all_touched=True,
3018
2608
  create_overview=False,
3019
2609
  skip_empty_tiles=False,
3020
- metadata_format="PASCAL_VOC",
3021
2610
  ):
3022
2611
  """
3023
2612
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -3025,8 +2614,7 @@ def export_geotiff_tiles(
3025
2614
  Args:
3026
2615
  in_raster (str): Path to input raster image
3027
2616
  out_folder (str): Path to output folder
3028
- in_class_data (str, optional): Path to classification data - can be vector file or raster.
3029
- If None, only image tiles will be exported without labels. Defaults to None.
2617
+ in_class_data (str): Path to classification data - can be vector file or raster
3030
2618
  tile_size (int): Size of tiles in pixels (square)
3031
2619
  stride (int): Step size between tiles
3032
2620
  class_value_field (str): Field containing class values (for vector data)
@@ -3036,7 +2624,6 @@ def export_geotiff_tiles(
3036
2624
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
3037
2625
  create_overview (bool): Whether to create an overview image of all tiles
3038
2626
  skip_empty_tiles (bool): If True, skip tiles with no features
3039
- metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
3040
2627
  """
3041
2628
 
3042
2629
  import logging
@@ -3047,42 +2634,28 @@ def export_geotiff_tiles(
3047
2634
  os.makedirs(out_folder, exist_ok=True)
3048
2635
  image_dir = os.path.join(out_folder, "images")
3049
2636
  os.makedirs(image_dir, exist_ok=True)
2637
+ label_dir = os.path.join(out_folder, "labels")
2638
+ os.makedirs(label_dir, exist_ok=True)
2639
+ ann_dir = os.path.join(out_folder, "annotations")
2640
+ os.makedirs(ann_dir, exist_ok=True)
3050
2641
 
3051
- # Only create label and annotation directories if class data is provided
3052
- if in_class_data is not None:
3053
- label_dir = os.path.join(out_folder, "labels")
3054
- os.makedirs(label_dir, exist_ok=True)
3055
-
3056
- # Create annotation directory based on metadata format
3057
- if metadata_format in ["PASCAL_VOC", "COCO"]:
3058
- ann_dir = os.path.join(out_folder, "annotations")
3059
- os.makedirs(ann_dir, exist_ok=True)
3060
-
3061
- # Initialize COCO annotations dictionary
3062
- if metadata_format == "COCO":
3063
- coco_annotations = {"images": [], "annotations": [], "categories": []}
3064
- ann_id = 0
3065
-
3066
- # Determine if class data is raster or vector (only if class data provided)
2642
+ # Determine if class data is raster or vector
3067
2643
  is_class_data_raster = False
3068
- if in_class_data is not None:
3069
- if isinstance(in_class_data, str):
3070
- file_ext = Path(in_class_data).suffix.lower()
3071
- # Common raster extensions
3072
- if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
3073
- try:
3074
- with rasterio.open(in_class_data) as src:
3075
- is_class_data_raster = True
3076
- if not quiet:
3077
- print(f"Detected in_class_data as raster: {in_class_data}")
3078
- print(f"Raster CRS: {src.crs}")
3079
- print(f"Raster dimensions: {src.width} x {src.height}")
3080
- except Exception:
3081
- is_class_data_raster = False
2644
+ if isinstance(in_class_data, str):
2645
+ file_ext = Path(in_class_data).suffix.lower()
2646
+ # Common raster extensions
2647
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2648
+ try:
2649
+ with rasterio.open(in_class_data) as src:
2650
+ is_class_data_raster = True
3082
2651
  if not quiet:
3083
- print(
3084
- f"Unable to open {in_class_data} as raster, trying as vector"
3085
- )
2652
+ print(f"Detected in_class_data as raster: {in_class_data}")
2653
+ print(f"Raster CRS: {src.crs}")
2654
+ print(f"Raster dimensions: {src.width} x {src.height}")
2655
+ except Exception:
2656
+ is_class_data_raster = False
2657
+ if not quiet:
2658
+ print(f"Unable to open {in_class_data} as raster, trying as vector")
3086
2659
 
3087
2660
  # Open the input raster
3088
2661
  with rasterio.open(in_raster) as src:
@@ -3102,10 +2675,10 @@ def export_geotiff_tiles(
3102
2675
  if max_tiles is None:
3103
2676
  max_tiles = total_tiles
3104
2677
 
3105
- # Process classification data (only if class data provided)
2678
+ # Process classification data
3106
2679
  class_to_id = {}
3107
2680
 
3108
- if in_class_data is not None and is_class_data_raster:
2681
+ if is_class_data_raster:
3109
2682
  # Load raster class data
3110
2683
  with rasterio.open(in_class_data) as class_src:
3111
2684
  # Check if raster CRS matches
@@ -3138,18 +2711,7 @@ def export_geotiff_tiles(
3138
2711
 
3139
2712
  # Create class mapping
3140
2713
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
3141
-
3142
- # Populate COCO categories
3143
- if metadata_format == "COCO":
3144
- for cls_val in unique_classes:
3145
- coco_annotations["categories"].append(
3146
- {
3147
- "id": class_to_id[int(cls_val)],
3148
- "name": str(int(cls_val)),
3149
- "supercategory": "object",
3150
- }
3151
- )
3152
- elif in_class_data is not None:
2714
+ else:
3153
2715
  # Load vector class data
3154
2716
  try:
3155
2717
  gdf = gpd.read_file(in_class_data)
@@ -3178,33 +2740,12 @@ def export_geotiff_tiles(
3178
2740
  )
3179
2741
  # Create class mapping
3180
2742
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
3181
-
3182
- # Populate COCO categories
3183
- if metadata_format == "COCO":
3184
- for cls_val in unique_classes:
3185
- coco_annotations["categories"].append(
3186
- {
3187
- "id": class_to_id[cls_val],
3188
- "name": str(cls_val),
3189
- "supercategory": "object",
3190
- }
3191
- )
3192
2743
  else:
3193
2744
  if not quiet:
3194
2745
  print(
3195
2746
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
3196
2747
  )
3197
2748
  class_to_id = {1: 1} # Default mapping
3198
-
3199
- # Populate COCO categories with default
3200
- if metadata_format == "COCO":
3201
- coco_annotations["categories"].append(
3202
- {
3203
- "id": 1,
3204
- "name": "object",
3205
- "supercategory": "object",
3206
- }
3207
- )
3208
2749
  except Exception as e:
3209
2750
  raise ValueError(f"Error processing vector data: {e}")
3210
2751
 
@@ -3271,8 +2812,8 @@ def export_geotiff_tiles(
3271
2812
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
3272
2813
  has_features = False
3273
2814
 
3274
- # Process classification data to create labels (only if class data provided)
3275
- if in_class_data is not None and is_class_data_raster:
2815
+ # Process classification data to create labels
2816
+ if is_class_data_raster:
3276
2817
  # For raster class data
3277
2818
  with rasterio.open(in_class_data) as class_src:
3278
2819
  # Calculate window in class raster
@@ -3322,7 +2863,7 @@ def export_geotiff_tiles(
3322
2863
  except Exception as e:
3323
2864
  pbar.write(f"Error reading class raster window: {e}")
3324
2865
  stats["errors"] += 1
3325
- elif in_class_data is not None:
2866
+ else:
3326
2867
  # For vector class data
3327
2868
  # Find features that intersect with window
3328
2869
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -3365,8 +2906,8 @@ def export_geotiff_tiles(
3365
2906
  pbar.write(f"Error rasterizing feature {idx}: {e}")
3366
2907
  stats["errors"] += 1
3367
2908
 
3368
- # Skip tile if no features and skip_empty_tiles is True (only when class data provided)
3369
- if in_class_data is not None and skip_empty_tiles and not has_features:
2909
+ # Skip tile if no features and skip_empty_tiles is True
2910
+ if skip_empty_tiles and not has_features:
3370
2911
  pbar.update(1)
3371
2912
  tile_index += 1
3372
2913
  continue
@@ -3397,212 +2938,96 @@ def export_geotiff_tiles(
3397
2938
  pbar.write(f"ERROR saving image GeoTIFF: {e}")
3398
2939
  stats["errors"] += 1
3399
2940
 
3400
- # Export label as GeoTIFF (only if class data provided)
3401
- if in_class_data is not None:
3402
- # Create profile for label GeoTIFF
3403
- label_profile = {
3404
- "driver": "GTiff",
3405
- "height": tile_size,
3406
- "width": tile_size,
3407
- "count": 1,
3408
- "dtype": "uint8",
3409
- "crs": src.crs,
3410
- "transform": window_transform,
3411
- }
2941
+ # Create profile for label GeoTIFF
2942
+ label_profile = {
2943
+ "driver": "GTiff",
2944
+ "height": tile_size,
2945
+ "width": tile_size,
2946
+ "count": 1,
2947
+ "dtype": "uint8",
2948
+ "crs": src.crs,
2949
+ "transform": window_transform,
2950
+ }
3412
2951
 
3413
- label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
3414
- try:
3415
- with rasterio.open(label_path, "w", **label_profile) as dst:
3416
- dst.write(label_mask.astype(np.uint8), 1)
2952
+ # Export label as GeoTIFF
2953
+ label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
2954
+ try:
2955
+ with rasterio.open(label_path, "w", **label_profile) as dst:
2956
+ dst.write(label_mask.astype(np.uint8), 1)
3417
2957
 
3418
- if has_features:
3419
- stats["tiles_with_features"] += 1
3420
- stats["feature_pixels"] += np.count_nonzero(label_mask)
3421
- except Exception as e:
3422
- pbar.write(f"ERROR saving label GeoTIFF: {e}")
3423
- stats["errors"] += 1
2958
+ if has_features:
2959
+ stats["tiles_with_features"] += 1
2960
+ stats["feature_pixels"] += np.count_nonzero(label_mask)
2961
+ except Exception as e:
2962
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
2963
+ stats["errors"] += 1
3424
2964
 
3425
- # Create annotations for object detection if using vector class data
2965
+ # Create XML annotation for object detection if using vector class data
3426
2966
  if (
3427
- in_class_data is not None
3428
- and not is_class_data_raster
2967
+ not is_class_data_raster
3429
2968
  and "gdf" in locals()
3430
2969
  and len(window_features) > 0
3431
2970
  ):
3432
- if metadata_format == "PASCAL_VOC":
3433
- # Create XML annotation
3434
- root = ET.Element("annotation")
3435
- ET.SubElement(root, "folder").text = "images"
3436
- ET.SubElement(root, "filename").text = (
3437
- f"tile_{tile_index:06d}.tif"
3438
- )
3439
-
3440
- size = ET.SubElement(root, "size")
3441
- ET.SubElement(size, "width").text = str(tile_size)
3442
- ET.SubElement(size, "height").text = str(tile_size)
3443
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3444
-
3445
- # Add georeference information
3446
- geo = ET.SubElement(root, "georeference")
3447
- ET.SubElement(geo, "crs").text = str(src.crs)
3448
- ET.SubElement(geo, "transform").text = str(
3449
- window_transform
3450
- ).replace("\n", "")
3451
- ET.SubElement(geo, "bounds").text = (
3452
- f"{minx}, {miny}, {maxx}, {maxy}"
3453
- )
3454
-
3455
- # Add objects
3456
- for idx, feature in window_features.iterrows():
3457
- # Get feature class
3458
- if class_value_field in feature:
3459
- class_val = feature[class_value_field]
3460
- else:
3461
- class_val = "object"
3462
-
3463
- # Get geometry bounds in pixel coordinates
3464
- geom = feature.geometry.intersection(window_bounds)
3465
- if not geom.is_empty:
3466
- # Get bounds in world coordinates
3467
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3468
-
3469
- # Convert to pixel coordinates
3470
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3471
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3472
-
3473
- # Ensure coordinates are within tile bounds
3474
- xmin = max(0, min(tile_size, int(col_min)))
3475
- ymin = max(0, min(tile_size, int(row_min)))
3476
- xmax = max(0, min(tile_size, int(col_max)))
3477
- ymax = max(0, min(tile_size, int(row_max)))
3478
-
3479
- # Only add if the box has non-zero area
3480
- if xmax > xmin and ymax > ymin:
3481
- obj = ET.SubElement(root, "object")
3482
- ET.SubElement(obj, "name").text = str(class_val)
3483
- ET.SubElement(obj, "difficult").text = "0"
3484
-
3485
- bbox = ET.SubElement(obj, "bndbox")
3486
- ET.SubElement(bbox, "xmin").text = str(xmin)
3487
- ET.SubElement(bbox, "ymin").text = str(ymin)
3488
- ET.SubElement(bbox, "xmax").text = str(xmax)
3489
- ET.SubElement(bbox, "ymax").text = str(ymax)
3490
-
3491
- # Save XML
3492
- tree = ET.ElementTree(root)
3493
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3494
- tree.write(xml_path)
3495
-
3496
- elif metadata_format == "COCO":
3497
- # Add image info
3498
- image_id = tile_index
3499
- coco_annotations["images"].append(
3500
- {
3501
- "id": image_id,
3502
- "file_name": f"tile_{tile_index:06d}.tif",
3503
- "width": tile_size,
3504
- "height": tile_size,
3505
- "crs": str(src.crs),
3506
- "transform": str(window_transform),
3507
- }
3508
- )
3509
-
3510
- # Add annotations for each feature
3511
- for _, feature in window_features.iterrows():
3512
- # Get feature class
3513
- if class_value_field in feature:
3514
- class_val = feature[class_value_field]
3515
- category_id = class_to_id.get(class_val, 1)
3516
- else:
3517
- category_id = 1
2971
+ # Create XML annotation
2972
+ root = ET.Element("annotation")
2973
+ ET.SubElement(root, "folder").text = "images"
2974
+ ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3518
2975
 
3519
- # Get geometry bounds
3520
- geom = feature.geometry.intersection(window_bounds)
3521
- if not geom.is_empty:
3522
- # Get bounds in world coordinates
3523
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3524
-
3525
- # Convert to pixel coordinates
3526
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3527
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3528
-
3529
- # Ensure coordinates are within tile bounds
3530
- xmin = max(0, min(tile_size, int(col_min)))
3531
- ymin = max(0, min(tile_size, int(row_min)))
3532
- xmax = max(0, min(tile_size, int(col_max)))
3533
- ymax = max(0, min(tile_size, int(row_max)))
3534
-
3535
- # Skip if box is too small
3536
- if xmax - xmin < 1 or ymax - ymin < 1:
3537
- continue
3538
-
3539
- width = xmax - xmin
3540
- height = ymax - ymin
3541
-
3542
- # Add annotation
3543
- ann_id += 1
3544
- coco_annotations["annotations"].append(
3545
- {
3546
- "id": ann_id,
3547
- "image_id": image_id,
3548
- "category_id": category_id,
3549
- "bbox": [xmin, ymin, width, height],
3550
- "area": width * height,
3551
- "iscrowd": 0,
3552
- }
3553
- )
2976
+ size = ET.SubElement(root, "size")
2977
+ ET.SubElement(size, "width").text = str(tile_size)
2978
+ ET.SubElement(size, "height").text = str(tile_size)
2979
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3554
2980
 
3555
- elif metadata_format == "YOLO":
3556
- # Create YOLO format annotations
3557
- yolo_annotations = []
2981
+ # Add georeference information
2982
+ geo = ET.SubElement(root, "georeference")
2983
+ ET.SubElement(geo, "crs").text = str(src.crs)
2984
+ ET.SubElement(geo, "transform").text = str(
2985
+ window_transform
2986
+ ).replace("\n", "")
2987
+ ET.SubElement(geo, "bounds").text = (
2988
+ f"{minx}, {miny}, {maxx}, {maxy}"
2989
+ )
3558
2990
 
3559
- for _, feature in window_features.iterrows():
3560
- # Get feature class
3561
- if class_value_field in feature:
3562
- class_val = feature[class_value_field]
3563
- # YOLO uses 0-indexed class IDs
3564
- class_id = class_to_id.get(class_val, 1) - 1
3565
- else:
3566
- class_id = 0
2991
+ # Add objects
2992
+ for idx, feature in window_features.iterrows():
2993
+ # Get feature class
2994
+ if class_value_field in feature:
2995
+ class_val = feature[class_value_field]
2996
+ else:
2997
+ class_val = "object"
3567
2998
 
3568
- # Get geometry bounds
3569
- geom = feature.geometry.intersection(window_bounds)
3570
- if not geom.is_empty:
3571
- # Get bounds in world coordinates
3572
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3573
-
3574
- # Convert to pixel coordinates
3575
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3576
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3577
-
3578
- # Ensure coordinates are within tile bounds
3579
- xmin = max(0, min(tile_size, col_min))
3580
- ymin = max(0, min(tile_size, row_min))
3581
- xmax = max(0, min(tile_size, col_max))
3582
- ymax = max(0, min(tile_size, row_max))
3583
-
3584
- # Skip if box is too small
3585
- if xmax - xmin < 1 or ymax - ymin < 1:
3586
- continue
3587
-
3588
- # Calculate normalized coordinates (YOLO format)
3589
- x_center = ((xmin + xmax) / 2) / tile_size
3590
- y_center = ((ymin + ymax) / 2) / tile_size
3591
- width = (xmax - xmin) / tile_size
3592
- height = (ymax - ymin) / tile_size
3593
-
3594
- # Add YOLO annotation line
3595
- yolo_annotations.append(
3596
- f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3597
- )
2999
+ # Get geometry bounds in pixel coordinates
3000
+ geom = feature.geometry.intersection(window_bounds)
3001
+ if not geom.is_empty:
3002
+ # Get bounds in world coordinates
3003
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3004
+
3005
+ # Convert to pixel coordinates
3006
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3007
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3008
+
3009
+ # Ensure coordinates are within tile bounds
3010
+ xmin = max(0, min(tile_size, int(col_min)))
3011
+ ymin = max(0, min(tile_size, int(row_min)))
3012
+ xmax = max(0, min(tile_size, int(col_max)))
3013
+ ymax = max(0, min(tile_size, int(row_max)))
3014
+
3015
+ # Only add if the box has non-zero area
3016
+ if xmax > xmin and ymax > ymin:
3017
+ obj = ET.SubElement(root, "object")
3018
+ ET.SubElement(obj, "name").text = str(class_val)
3019
+ ET.SubElement(obj, "difficult").text = "0"
3020
+
3021
+ bbox = ET.SubElement(obj, "bndbox")
3022
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3023
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3024
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3025
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3598
3026
 
3599
- # Save YOLO annotations to text file
3600
- if yolo_annotations:
3601
- yolo_path = os.path.join(
3602
- label_dir, f"tile_{tile_index:06d}.txt"
3603
- )
3604
- with open(yolo_path, "w") as f:
3605
- f.write("\n".join(yolo_annotations))
3027
+ # Save XML
3028
+ tree = ET.ElementTree(root)
3029
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3030
+ tree.write(xml_path)
3606
3031
 
3607
3032
  # Update progress bar
3608
3033
  pbar.update(1)
@@ -3620,39 +3045,6 @@ def export_geotiff_tiles(
3620
3045
  # Close progress bar
3621
3046
  pbar.close()
3622
3047
 
3623
- # Save COCO annotations if applicable (only if class data provided)
3624
- if in_class_data is not None and metadata_format == "COCO":
3625
- try:
3626
- with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3627
- json.dump(coco_annotations, f, indent=2)
3628
- if not quiet:
3629
- print(
3630
- f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3631
- f"{len(coco_annotations['annotations'])} annotations, "
3632
- f"{len(coco_annotations['categories'])} categories"
3633
- )
3634
- except Exception as e:
3635
- if not quiet:
3636
- print(f"ERROR saving COCO annotations: {e}")
3637
- stats["errors"] += 1
3638
-
3639
- # Save YOLO classes file if applicable (only if class data provided)
3640
- if in_class_data is not None and metadata_format == "YOLO":
3641
- try:
3642
- # Create classes.txt with class names
3643
- classes_path = os.path.join(out_folder, "classes.txt")
3644
- # Sort by class ID to ensure correct order
3645
- sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3646
- with open(classes_path, "w") as f:
3647
- for class_val, _ in sorted_classes:
3648
- f.write(f"{class_val}\n")
3649
- if not quiet:
3650
- print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3651
- except Exception as e:
3652
- if not quiet:
3653
- print(f"ERROR saving YOLO classes file: {e}")
3654
- stats["errors"] += 1
3655
-
3656
3048
  # Create overview image if requested
3657
3049
  if create_overview and stats["tile_coordinates"]:
3658
3050
  try:
@@ -3670,14 +3062,13 @@ def export_geotiff_tiles(
3670
3062
  if not quiet:
3671
3063
  print("\n------- Export Summary -------")
3672
3064
  print(f"Total tiles exported: {stats['total_tiles']}")
3673
- if in_class_data is not None:
3065
+ print(
3066
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3067
+ )
3068
+ if stats["tiles_with_features"] > 0:
3674
3069
  print(
3675
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3070
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3676
3071
  )
3677
- if stats["tiles_with_features"] > 0:
3678
- print(
3679
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3680
- )
3681
3072
  if stats["errors"] > 0:
3682
3073
  print(f"Errors encountered: {stats['errors']}")
3683
3074
  print(f"Output saved to: {out_folder}")
@@ -3686,6 +3077,7 @@ def export_geotiff_tiles(
3686
3077
  if stats["total_tiles"] > 0:
3687
3078
  print("\n------- Georeference Verification -------")
3688
3079
  sample_image = os.path.join(image_dir, f"tile_0.tif")
3080
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
3689
3081
 
3690
3082
  if os.path.exists(sample_image):
3691
3083
  try:
@@ -3701,22 +3093,19 @@ def export_geotiff_tiles(
3701
3093
  except Exception as e:
3702
3094
  print(f"Error verifying image georeference: {e}")
3703
3095
 
3704
- # Only verify label if class data was provided
3705
- if in_class_data is not None:
3706
- sample_label = os.path.join(label_dir, f"tile_0.tif")
3707
- if os.path.exists(sample_label):
3708
- try:
3709
- with rasterio.open(sample_label) as lbl:
3710
- print(f"Label CRS: {lbl.crs}")
3711
- print(f"Label transform: {lbl.transform}")
3712
- print(
3713
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3714
- )
3715
- print(
3716
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3717
- )
3718
- except Exception as e:
3719
- print(f"Error verifying label georeference: {e}")
3096
+ if os.path.exists(sample_label):
3097
+ try:
3098
+ with rasterio.open(sample_label) as lbl:
3099
+ print(f"Label CRS: {lbl.crs}")
3100
+ print(f"Label transform: {lbl.transform}")
3101
+ print(
3102
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3103
+ )
3104
+ print(
3105
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3106
+ )
3107
+ except Exception as e:
3108
+ print(f"Error verifying label georeference: {e}")
3720
3109
 
3721
3110
  # Return statistics dictionary for further processing if needed
3722
3111
  return stats
@@ -3724,9 +3113,8 @@ def export_geotiff_tiles(
3724
3113
 
3725
3114
  def export_geotiff_tiles_batch(
3726
3115
  images_folder,
3727
- masks_folder=None,
3728
- masks_file=None,
3729
- output_folder=None,
3116
+ masks_folder,
3117
+ output_folder,
3730
3118
  tile_size=256,
3731
3119
  stride=128,
3732
3120
  class_value_field="class",
@@ -3734,43 +3122,25 @@ def export_geotiff_tiles_batch(
3734
3122
  max_tiles=None,
3735
3123
  quiet=False,
3736
3124
  all_touched=True,
3125
+ create_overview=False,
3737
3126
  skip_empty_tiles=False,
3738
3127
  image_extensions=None,
3739
3128
  mask_extensions=None,
3740
- match_by_name=False,
3741
- metadata_format="PASCAL_VOC",
3742
3129
  ) -> Dict[str, Any]:
3743
3130
  """
3744
- Export georeferenced GeoTIFF tiles from images and optionally masks.
3745
-
3746
- This function supports four modes:
3747
- 1. Images only (no masks) - when neither masks_file nor masks_folder is provided
3748
- 2. Single vector file covering all images (masks_file parameter)
3749
- 3. Multiple vector files, one per image (masks_folder parameter)
3750
- 4. Multiple raster mask files (masks_folder parameter)
3131
+ Export georeferenced GeoTIFF tiles from folders of images and masks.
3751
3132
 
3752
- For mode 1 (images only), only image tiles will be exported without labels.
3133
+ This function processes multiple image-mask pairs from input folders,
3134
+ generating tiles for each pair. All image tiles are saved to a single
3135
+ 'images' folder and all mask tiles to a single 'masks' folder.
3753
3136
 
3754
- For mode 2 (single vector file), specify masks_file path. The function will
3755
- use spatial intersection to determine which features apply to each image.
3756
-
3757
- For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
3758
- are paired either by matching filenames (match_by_name=True) or by sorted order
3759
- (match_by_name=False).
3760
-
3761
- All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
3762
- to a single 'masks' folder within the output directory.
3137
+ Images and masks are paired by their sorted order (alphabetically), not by
3138
+ filename matching. The number of images and masks must be equal.
3763
3139
 
3764
3140
  Args:
3765
3141
  images_folder (str): Path to folder containing raster images
3766
- masks_folder (str, optional): Path to folder containing classification masks/vectors.
3767
- Use this for multiple mask files (one per image or raster masks). If not provided
3768
- and masks_file is also not provided, only image tiles will be exported.
3769
- masks_file (str, optional): Path to a single vector file covering all images.
3770
- Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
3771
- and masks_folder is also not provided, only image tiles will be exported.
3772
- output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3773
- subfolder in images_folder.
3142
+ masks_folder (str): Path to folder containing classification masks/vectors
3143
+ output_folder (str): Path to output folder
3774
3144
  tile_size (int): Size of tiles in pixels (square)
3775
3145
  stride (int): Step size between tiles
3776
3146
  class_value_field (str): Field containing class values (for vector data)
@@ -3782,63 +3152,18 @@ def export_geotiff_tiles_batch(
3782
3152
  skip_empty_tiles (bool): If True, skip tiles with no features
3783
3153
  image_extensions (list): List of image file extensions to process (default: common raster formats)
3784
3154
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3785
- match_by_name (bool): If True, match image and mask files by base filename.
3786
- If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3787
- metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3788
- Default is "PASCAL_VOC".
3789
3155
 
3790
3156
  Returns:
3791
3157
  Dict[str, Any]: Dictionary containing batch processing statistics
3792
3158
 
3793
3159
  Raises:
3794
- ValueError: If no images found, or if masks_folder and masks_file are both specified,
3795
- or if counts don't match when using masks_folder with match_by_name=False.
3796
-
3797
- Examples:
3798
- # Images only (no masks)
3799
- >>> stats = export_geotiff_tiles_batch(
3800
- ... images_folder='data/images',
3801
- ... output_folder='output/tiles'
3802
- ... )
3803
-
3804
- # Single vector file covering all images
3805
- >>> stats = export_geotiff_tiles_batch(
3806
- ... images_folder='data/images',
3807
- ... masks_file='data/buildings.geojson',
3808
- ... output_folder='output/tiles'
3809
- ... )
3810
-
3811
- # Multiple vector files, matched by filename
3812
- >>> stats = export_geotiff_tiles_batch(
3813
- ... images_folder='data/images',
3814
- ... masks_folder='data/masks',
3815
- ... output_folder='output/tiles',
3816
- ... match_by_name=True
3817
- ... )
3818
-
3819
- # Multiple mask files, matched by sorted order
3820
- >>> stats = export_geotiff_tiles_batch(
3821
- ... images_folder='data/images',
3822
- ... masks_folder='data/masks',
3823
- ... output_folder='output/tiles',
3824
- ... match_by_name=False
3825
- ... )
3160
+ ValueError: If no images or masks found, or if counts don't match
3826
3161
  """
3827
3162
 
3828
3163
  import logging
3829
3164
 
3830
3165
  logging.getLogger("rasterio").setLevel(logging.ERROR)
3831
3166
 
3832
- # Validate input parameters
3833
- if masks_folder is not None and masks_file is not None:
3834
- raise ValueError(
3835
- "Cannot specify both masks_folder and masks_file. Please use only one."
3836
- )
3837
-
3838
- # Default output folder if not specified
3839
- if output_folder is None:
3840
- output_folder = os.path.join(images_folder, "tiles")
3841
-
3842
3167
  # Default extensions if not provided
3843
3168
  if image_extensions is None:
3844
3169
  image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
@@ -3865,37 +3190,9 @@ def export_geotiff_tiles_batch(
3865
3190
  # Create output folder structure
3866
3191
  os.makedirs(output_folder, exist_ok=True)
3867
3192
  output_images_dir = os.path.join(output_folder, "images")
3193
+ output_masks_dir = os.path.join(output_folder, "masks")
3868
3194
  os.makedirs(output_images_dir, exist_ok=True)
3869
-
3870
- # Only create masks directory if masks are provided
3871
- output_masks_dir = None
3872
- if masks_folder is not None or masks_file is not None:
3873
- output_masks_dir = os.path.join(output_folder, "masks")
3874
- os.makedirs(output_masks_dir, exist_ok=True)
3875
-
3876
- # Create annotation directory based on metadata format (only if masks are provided)
3877
- ann_dir = None
3878
- if (masks_folder is not None or masks_file is not None) and metadata_format in [
3879
- "PASCAL_VOC",
3880
- "COCO",
3881
- ]:
3882
- ann_dir = os.path.join(output_folder, "annotations")
3883
- os.makedirs(ann_dir, exist_ok=True)
3884
-
3885
- # Initialize COCO annotations dictionary (only if masks are provided)
3886
- coco_annotations = None
3887
- if (
3888
- masks_folder is not None or masks_file is not None
3889
- ) and metadata_format == "COCO":
3890
- coco_annotations = {"images": [], "annotations": [], "categories": []}
3891
-
3892
- # Initialize YOLO class set (only if masks are provided)
3893
- yolo_classes = (
3894
- set()
3895
- if (masks_folder is not None or masks_file is not None)
3896
- and metadata_format == "YOLO"
3897
- else None
3898
- )
3195
+ os.makedirs(output_masks_dir, exist_ok=True)
3899
3196
 
3900
3197
  # Get list of image files
3901
3198
  image_files = []
@@ -3903,105 +3200,30 @@ def export_geotiff_tiles_batch(
3903
3200
  pattern = os.path.join(images_folder, f"*{ext}")
3904
3201
  image_files.extend(glob.glob(pattern))
3905
3202
 
3203
+ # Get list of mask files
3204
+ mask_files = []
3205
+ for ext in mask_extensions:
3206
+ pattern = os.path.join(masks_folder, f"*{ext}")
3207
+ mask_files.extend(glob.glob(pattern))
3208
+
3906
3209
  # Sort files for consistent processing
3907
3210
  image_files.sort()
3211
+ mask_files.sort()
3908
3212
 
3909
3213
  if not image_files:
3910
3214
  raise ValueError(
3911
3215
  f"No image files found in {images_folder} with extensions {image_extensions}"
3912
3216
  )
3913
3217
 
3914
- # Handle different mask input modes
3915
- use_single_mask_file = masks_file is not None
3916
- has_masks = masks_file is not None or masks_folder is not None
3917
- mask_files = []
3918
- image_mask_pairs = []
3919
-
3920
- if not has_masks:
3921
- # Mode 0: No masks - create pairs with None for mask
3922
- for image_file in image_files:
3923
- image_mask_pairs.append((image_file, None, None))
3924
-
3925
- elif use_single_mask_file:
3926
- # Mode 1: Single vector file covering all images
3927
- if not os.path.exists(masks_file):
3928
- raise ValueError(f"Mask file not found: {masks_file}")
3929
-
3930
- # Load the single mask file once - will be spatially filtered per image
3931
- single_mask_gdf = gpd.read_file(masks_file)
3932
-
3933
- if not quiet:
3934
- print(f"Using single mask file: {masks_file}")
3935
- print(
3936
- f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
3937
- )
3938
-
3939
- # Create pairs with the same mask file for all images
3940
- for image_file in image_files:
3941
- image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
3942
-
3943
- else:
3944
- # Mode 2/3: Multiple mask files (vector or raster)
3945
- # Get list of mask files
3946
- for ext in mask_extensions:
3947
- pattern = os.path.join(masks_folder, f"*{ext}")
3948
- mask_files.extend(glob.glob(pattern))
3949
-
3950
- # Sort files for consistent processing
3951
- mask_files.sort()
3952
-
3953
- if not mask_files:
3954
- raise ValueError(
3955
- f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3956
- )
3957
-
3958
- # Match images to masks
3959
- if match_by_name:
3960
- # Match by base filename
3961
- image_dict = {
3962
- os.path.splitext(os.path.basename(f))[0]: f for f in image_files
3963
- }
3964
- mask_dict = {
3965
- os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
3966
- }
3967
-
3968
- # Find matching pairs
3969
- for img_base, img_path in image_dict.items():
3970
- if img_base in mask_dict:
3971
- image_mask_pairs.append((img_path, mask_dict[img_base], None))
3972
- else:
3973
- if not quiet:
3974
- print(f"Warning: No mask found for image {img_base}")
3975
-
3976
- if not image_mask_pairs:
3977
- # Provide detailed error message with found files
3978
- image_bases = list(image_dict.keys())
3979
- mask_bases = list(mask_dict.keys())
3980
- error_msg = (
3981
- "No matching image-mask pairs found when matching by filename. "
3982
- "Check that image and mask files have matching base names.\n"
3983
- f"Found {len(image_bases)} image(s): "
3984
- f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
3985
- f"{'...' if len(image_bases) > 5 else ''}\n"
3986
- f"Found {len(mask_bases)} mask(s): "
3987
- f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
3988
- f"{'...' if len(mask_bases) > 5 else ''}\n"
3989
- "Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
3990
- )
3991
- raise ValueError(error_msg)
3992
-
3993
- else:
3994
- # Match by sorted order
3995
- if len(image_files) != len(mask_files):
3996
- raise ValueError(
3997
- f"Number of image files ({len(image_files)}) does not match "
3998
- f"number of mask files ({len(mask_files)}) when matching by sorted order. "
3999
- f"Use match_by_name=True for filename-based matching."
4000
- )
3218
+ if not mask_files:
3219
+ raise ValueError(
3220
+ f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3221
+ )
4001
3222
 
4002
- # Create pairs by sorted order
4003
- for image_file, mask_file in zip(image_files, mask_files):
4004
- image_mask_pairs.append((image_file, mask_file, None))
3223
+ if len(image_files) != len(mask_files):
3224
+ raise ValueError(
3225
+ f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
3226
+ )
4005
3227
 
4006
3228
  # Initialize batch statistics
4007
3229
  batch_stats = {
@@ -4015,28 +3237,23 @@ def export_geotiff_tiles_batch(
4015
3237
  }
4016
3238
 
4017
3239
  if not quiet:
4018
- if not has_masks:
4019
- print(
4020
- f"Found {len(image_files)} image files to process (images only, no masks)"
4021
- )
4022
- elif use_single_mask_file:
4023
- print(f"Found {len(image_files)} image files to process")
4024
- print(f"Using single mask file: {masks_file}")
4025
- else:
4026
- print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
4027
- print(f"Processing batch from {images_folder} and {masks_folder}")
3240
+ print(
3241
+ f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
3242
+ )
3243
+ print(f"Processing batch from {images_folder} and {masks_folder}")
4028
3244
  print(f"Output folder: {output_folder}")
4029
3245
  print("-" * 60)
4030
3246
 
4031
3247
  # Global tile counter for unique naming
4032
3248
  global_tile_counter = 0
4033
3249
 
4034
- # Process each image-mask pair
4035
- for idx, (image_file, mask_file, mask_gdf) in enumerate(
3250
+ # Process each image-mask pair by sorted order
3251
+ for idx, (image_file, mask_file) in enumerate(
4036
3252
  tqdm(
4037
- image_mask_pairs,
3253
+ zip(image_files, mask_files),
4038
3254
  desc="Processing image pairs",
4039
3255
  disable=quiet,
3256
+ total=len(image_files),
4040
3257
  )
4041
3258
  ):
4042
3259
  batch_stats["total_image_pairs"] += 1
@@ -4048,17 +3265,9 @@ def export_geotiff_tiles_batch(
4048
3265
  if not quiet:
4049
3266
  print(f"\nProcessing: {base_name}")
4050
3267
  print(f" Image: {os.path.basename(image_file)}")
4051
- if mask_file is not None:
4052
- if use_single_mask_file:
4053
- print(
4054
- f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
4055
- )
4056
- else:
4057
- print(f" Mask: {os.path.basename(mask_file)}")
4058
- else:
4059
- print(f" Mask: None (images only)")
3268
+ print(f" Mask: {os.path.basename(mask_file)}")
4060
3269
 
4061
- # Process the image-mask pair
3270
+ # Process the image-mask pair manually to get direct control over tile saving
4062
3271
  tiles_generated = _process_image_mask_pair(
4063
3272
  image_file=image_file,
4064
3273
  mask_file=mask_file,
@@ -4074,15 +3283,6 @@ def export_geotiff_tiles_batch(
4074
3283
  all_touched=all_touched,
4075
3284
  skip_empty_tiles=skip_empty_tiles,
4076
3285
  quiet=quiet,
4077
- mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
4078
- use_single_mask_file=use_single_mask_file,
4079
- metadata_format=metadata_format,
4080
- ann_dir=(
4081
- ann_dir
4082
- if "ann_dir" in locals()
4083
- and metadata_format in ["PASCAL_VOC", "COCO"]
4084
- else None
4085
- ),
4086
3286
  )
4087
3287
 
4088
3288
  # Update counters
@@ -4104,23 +3304,6 @@ def export_geotiff_tiles_batch(
4104
3304
  }
4105
3305
  )
4106
3306
 
4107
- # Aggregate COCO annotations
4108
- if metadata_format == "COCO" and "coco_data" in tiles_generated:
4109
- coco_data = tiles_generated["coco_data"]
4110
- # Add images and annotations
4111
- coco_annotations["images"].extend(coco_data.get("images", []))
4112
- coco_annotations["annotations"].extend(coco_data.get("annotations", []))
4113
- # Merge categories (avoid duplicates)
4114
- for cat in coco_data.get("categories", []):
4115
- if not any(
4116
- c["id"] == cat["id"] for c in coco_annotations["categories"]
4117
- ):
4118
- coco_annotations["categories"].append(cat)
4119
-
4120
- # Aggregate YOLO classes
4121
- if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
4122
- yolo_classes.update(tiles_generated["yolo_classes"])
4123
-
4124
3307
  except Exception as e:
4125
3308
  if not quiet:
4126
3309
  print(f"ERROR processing {base_name}: {e}")
@@ -4129,33 +3312,6 @@ def export_geotiff_tiles_batch(
4129
3312
  )
4130
3313
  batch_stats["errors"] += 1
4131
3314
 
4132
- # Save aggregated COCO annotations
4133
- if metadata_format == "COCO" and coco_annotations:
4134
- import json
4135
-
4136
- coco_path = os.path.join(ann_dir, "instances.json")
4137
- with open(coco_path, "w") as f:
4138
- json.dump(coco_annotations, f, indent=2)
4139
- if not quiet:
4140
- print(f"\nSaved COCO annotations: {coco_path}")
4141
- print(
4142
- f" Images: {len(coco_annotations['images'])}, "
4143
- f"Annotations: {len(coco_annotations['annotations'])}, "
4144
- f"Categories: {len(coco_annotations['categories'])}"
4145
- )
4146
-
4147
- # Save aggregated YOLO classes
4148
- if metadata_format == "YOLO" and yolo_classes:
4149
- classes_path = os.path.join(output_folder, "labels", "classes.txt")
4150
- os.makedirs(os.path.dirname(classes_path), exist_ok=True)
4151
- sorted_classes = sorted(yolo_classes)
4152
- with open(classes_path, "w") as f:
4153
- for cls in sorted_classes:
4154
- f.write(f"{cls}\n")
4155
- if not quiet:
4156
- print(f"\nSaved YOLO classes: {classes_path}")
4157
- print(f" Total classes: {len(sorted_classes)}")
4158
-
4159
3315
  # Print batch summary
4160
3316
  if not quiet:
4161
3317
  print("\n" + "=" * 60)
@@ -4178,12 +3334,7 @@ def export_geotiff_tiles_batch(
4178
3334
 
4179
3335
  print(f"Output saved to: {output_folder}")
4180
3336
  print(f" Images: {output_images_dir}")
4181
- if output_masks_dir is not None:
4182
- print(f" Masks: {output_masks_dir}")
4183
- if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
4184
- print(f" Annotations: {ann_dir}")
4185
- elif metadata_format == "YOLO":
4186
- print(f" Labels: {os.path.join(output_folder, 'labels')}")
3337
+ print(f" Masks: {output_masks_dir}")
4187
3338
 
4188
3339
  # List failed files if any
4189
3340
  if batch_stats["failed_files"]:
@@ -4209,26 +3360,18 @@ def _process_image_mask_pair(
4209
3360
  all_touched=True,
4210
3361
  skip_empty_tiles=False,
4211
3362
  quiet=False,
4212
- mask_gdf=None,
4213
- use_single_mask_file=False,
4214
- metadata_format="PASCAL_VOC",
4215
- ann_dir=None,
4216
3363
  ):
4217
3364
  """
4218
3365
  Process a single image-mask pair and save tiles directly to output directories.
4219
3366
 
4220
- Args:
4221
- mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
4222
- use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
4223
-
4224
3367
  Returns:
4225
3368
  dict: Statistics for this image-mask pair
4226
3369
  """
4227
3370
  import warnings
4228
3371
 
4229
- # Determine if mask data is raster or vector (only if mask_file is provided)
3372
+ # Determine if mask data is raster or vector
4230
3373
  is_class_data_raster = False
4231
- if mask_file is not None and isinstance(mask_file, str):
3374
+ if isinstance(mask_file, str):
4232
3375
  file_ext = Path(mask_file).suffix.lower()
4233
3376
  # Common raster extensions
4234
3377
  if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
@@ -4245,13 +3388,6 @@ def _process_image_mask_pair(
4245
3388
  "errors": 0,
4246
3389
  }
4247
3390
 
4248
- # Initialize COCO/YOLO tracking for this image
4249
- if metadata_format == "COCO":
4250
- stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
4251
- coco_ann_id = 0
4252
- if metadata_format == "YOLO":
4253
- stats["yolo_classes"] = set()
4254
-
4255
3391
  # Open the input raster
4256
3392
  with rasterio.open(image_file) as src:
4257
3393
  # Calculate number of tiles
@@ -4262,10 +3398,10 @@ def _process_image_mask_pair(
4262
3398
  if max_tiles is None:
4263
3399
  max_tiles = total_tiles
4264
3400
 
4265
- # Process classification data (only if mask_file is provided)
3401
+ # Process classification data
4266
3402
  class_to_id = {}
4267
3403
 
4268
- if mask_file is not None and is_class_data_raster:
3404
+ if is_class_data_raster:
4269
3405
  # Load raster class data
4270
3406
  with rasterio.open(mask_file) as class_src:
4271
3407
  # Check if raster CRS matches
@@ -4292,39 +3428,14 @@ def _process_image_mask_pair(
4292
3428
 
4293
3429
  # Create class mapping
4294
3430
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
4295
- elif mask_file is not None:
3431
+ else:
4296
3432
  # Load vector class data
4297
3433
  try:
4298
- if use_single_mask_file and mask_gdf is not None:
4299
- # Using pre-loaded single mask file - spatially filter to image bounds
4300
- # Get image bounds
4301
- image_bounds = box(*src.bounds)
4302
- image_gdf = gpd.GeoDataFrame(
4303
- {"geometry": [image_bounds]}, crs=src.crs
4304
- )
4305
-
4306
- # Reproject mask if needed
4307
- if mask_gdf.crs != src.crs:
4308
- mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
4309
- else:
4310
- mask_gdf_reprojected = mask_gdf
4311
-
4312
- # Spatially filter features that intersect with image bounds
4313
- gdf = mask_gdf_reprojected[
4314
- mask_gdf_reprojected.intersects(image_bounds)
4315
- ].copy()
4316
-
4317
- if not quiet and len(gdf) > 0:
4318
- print(
4319
- f" Filtered to {len(gdf)} features intersecting image bounds"
4320
- )
4321
- else:
4322
- # Load individual mask file
4323
- gdf = gpd.read_file(mask_file)
3434
+ gdf = gpd.read_file(mask_file)
4324
3435
 
4325
- # Always reproject to match raster CRS
4326
- if gdf.crs != src.crs:
4327
- gdf = gdf.to_crs(src.crs)
3436
+ # Always reproject to match raster CRS
3437
+ if gdf.crs != src.crs:
3438
+ gdf = gdf.to_crs(src.crs)
4328
3439
 
4329
3440
  # Apply buffer if specified
4330
3441
  if buffer_radius > 0:
@@ -4344,6 +3455,9 @@ def _process_image_mask_pair(
4344
3455
  tile_index = 0
4345
3456
  for y in range(num_tiles_y):
4346
3457
  for x in range(num_tiles_x):
3458
+ if tile_index >= max_tiles:
3459
+ break
3460
+
4347
3461
  # Calculate window coordinates
4348
3462
  window_x = x * stride
4349
3463
  window_y = y * stride
@@ -4368,12 +3482,12 @@ def _process_image_mask_pair(
4368
3482
 
4369
3483
  window_bounds = box(minx, miny, maxx, maxy)
4370
3484
 
4371
- # Create label mask (only if mask_file is provided)
3485
+ # Create label mask
4372
3486
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
4373
3487
  has_features = False
4374
3488
 
4375
- # Process classification data to create labels (only if mask_file is provided)
4376
- if mask_file is not None and is_class_data_raster:
3489
+ # Process classification data to create labels
3490
+ if is_class_data_raster:
4377
3491
  # For raster class data
4378
3492
  with rasterio.open(mask_file) as class_src:
4379
3493
  # Get corresponding window in class raster
@@ -4406,7 +3520,7 @@ def _process_image_mask_pair(
4406
3520
  if not quiet:
4407
3521
  print(f"Error reading class raster window: {e}")
4408
3522
  stats["errors"] += 1
4409
- elif mask_file is not None:
3523
+ else:
4410
3524
  # For vector class data
4411
3525
  # Find features that intersect with window
4412
3526
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -4444,14 +3558,11 @@ def _process_image_mask_pair(
4444
3558
  print(f"Error rasterizing feature {idx}: {e}")
4445
3559
  stats["errors"] += 1
4446
3560
 
4447
- # Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
4448
- if mask_file is not None and skip_empty_tiles and not has_features:
3561
+ # Skip tile if no features and skip_empty_tiles is True
3562
+ if skip_empty_tiles and not has_features:
3563
+ tile_index += 1
4449
3564
  continue
4450
3565
 
4451
- # Check if we've reached max_tiles before saving
4452
- if tile_index >= max_tiles:
4453
- break
4454
-
4455
3566
  # Generate unique tile name
4456
3567
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
4457
3568
 
@@ -4482,225 +3593,29 @@ def _process_image_mask_pair(
4482
3593
  print(f"ERROR saving image GeoTIFF: {e}")
4483
3594
  stats["errors"] += 1
4484
3595
 
4485
- # Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
4486
- if mask_file is not None and output_masks_dir is not None:
4487
- # Create profile for label GeoTIFF
4488
- label_profile = {
4489
- "driver": "GTiff",
4490
- "height": tile_size,
4491
- "width": tile_size,
4492
- "count": 1,
4493
- "dtype": "uint8",
4494
- "crs": src.crs,
4495
- "transform": window_transform,
4496
- }
4497
-
4498
- label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4499
- try:
4500
- with rasterio.open(label_path, "w", **label_profile) as dst:
4501
- dst.write(label_mask.astype(np.uint8), 1)
4502
-
4503
- if has_features:
4504
- stats["tiles_with_features"] += 1
4505
- except Exception as e:
4506
- if not quiet:
4507
- print(f"ERROR saving label GeoTIFF: {e}")
4508
- stats["errors"] += 1
4509
-
4510
- # Generate annotation metadata based on format (only if mask_file is provided)
4511
- if (
4512
- mask_file is not None
4513
- and metadata_format == "PASCAL_VOC"
4514
- and ann_dir
4515
- ):
4516
- # Create PASCAL VOC XML annotation
4517
- from lxml import etree as ET
4518
-
4519
- annotation = ET.Element("annotation")
4520
- ET.SubElement(annotation, "folder").text = os.path.basename(
4521
- output_images_dir
4522
- )
4523
- ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4524
- ET.SubElement(annotation, "path").text = image_path
4525
-
4526
- source = ET.SubElement(annotation, "source")
4527
- ET.SubElement(source, "database").text = "GeoAI"
4528
-
4529
- size = ET.SubElement(annotation, "size")
4530
- ET.SubElement(size, "width").text = str(tile_size)
4531
- ET.SubElement(size, "height").text = str(tile_size)
4532
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
4533
-
4534
- ET.SubElement(annotation, "segmented").text = "1"
4535
-
4536
- # Find connected components for instance segmentation
4537
- from scipy import ndimage
4538
-
4539
- for class_id in np.unique(label_mask):
4540
- if class_id == 0:
4541
- continue
4542
-
4543
- class_mask = (label_mask == class_id).astype(np.uint8)
4544
- labeled_array, num_features = ndimage.label(class_mask)
4545
-
4546
- for instance_id in range(1, num_features + 1):
4547
- instance_mask = labeled_array == instance_id
4548
- coords = np.argwhere(instance_mask)
4549
-
4550
- if len(coords) == 0:
4551
- continue
4552
-
4553
- ymin, xmin = coords.min(axis=0)
4554
- ymax, xmax = coords.max(axis=0)
4555
-
4556
- obj = ET.SubElement(annotation, "object")
4557
- class_name = next(
4558
- (k for k, v in class_to_id.items() if v == class_id),
4559
- str(class_id),
4560
- )
4561
- ET.SubElement(obj, "name").text = str(class_name)
4562
- ET.SubElement(obj, "pose").text = "Unspecified"
4563
- ET.SubElement(obj, "truncated").text = "0"
4564
- ET.SubElement(obj, "difficult").text = "0"
4565
-
4566
- bndbox = ET.SubElement(obj, "bndbox")
4567
- ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4568
- ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4569
- ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4570
- ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4571
-
4572
- # Save XML file
4573
- xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4574
- tree = ET.ElementTree(annotation)
4575
- tree.write(xml_path, pretty_print=True, encoding="utf-8")
4576
-
4577
- elif mask_file is not None and metadata_format == "COCO":
4578
- # Add COCO image entry
4579
- image_id = int(global_tile_counter + tile_index)
4580
- stats["coco_data"]["images"].append(
4581
- {
4582
- "id": image_id,
4583
- "file_name": f"{tile_name}.tif",
4584
- "width": int(tile_size),
4585
- "height": int(tile_size),
4586
- }
4587
- )
4588
-
4589
- # Add COCO categories (only once per unique class)
4590
- for class_val, class_id in class_to_id.items():
4591
- if not any(
4592
- c["id"] == class_id
4593
- for c in stats["coco_data"]["categories"]
4594
- ):
4595
- stats["coco_data"]["categories"].append(
4596
- {
4597
- "id": int(class_id),
4598
- "name": str(class_val),
4599
- "supercategory": "object",
4600
- }
4601
- )
4602
-
4603
- # Add COCO annotations (instance segmentation)
4604
- from scipy import ndimage
4605
- from skimage import measure
4606
-
4607
- for class_id in np.unique(label_mask):
4608
- if class_id == 0:
4609
- continue
4610
-
4611
- class_mask = (label_mask == class_id).astype(np.uint8)
4612
- labeled_array, num_features = ndimage.label(class_mask)
4613
-
4614
- for instance_id in range(1, num_features + 1):
4615
- instance_mask = (labeled_array == instance_id).astype(
4616
- np.uint8
4617
- )
4618
- coords = np.argwhere(instance_mask)
4619
-
4620
- if len(coords) == 0:
4621
- continue
4622
-
4623
- ymin, xmin = coords.min(axis=0)
4624
- ymax, xmax = coords.max(axis=0)
4625
-
4626
- bbox = [
4627
- int(xmin),
4628
- int(ymin),
4629
- int(xmax - xmin),
4630
- int(ymax - ymin),
4631
- ]
4632
- area = int(np.sum(instance_mask))
4633
-
4634
- # Find contours for segmentation
4635
- contours = measure.find_contours(instance_mask, 0.5)
4636
- segmentation = []
4637
- for contour in contours:
4638
- contour = np.flip(contour, axis=1)
4639
- segmentation_points = contour.ravel().tolist()
4640
- if len(segmentation_points) >= 6:
4641
- segmentation.append(segmentation_points)
4642
-
4643
- if segmentation:
4644
- stats["coco_data"]["annotations"].append(
4645
- {
4646
- "id": int(coco_ann_id),
4647
- "image_id": int(image_id),
4648
- "category_id": int(class_id),
4649
- "bbox": bbox,
4650
- "area": area,
4651
- "segmentation": segmentation,
4652
- "iscrowd": 0,
4653
- }
4654
- )
4655
- coco_ann_id += 1
4656
-
4657
- elif mask_file is not None and metadata_format == "YOLO":
4658
- # Create YOLO labels directory if needed
4659
- labels_dir = os.path.join(
4660
- os.path.dirname(output_images_dir), "labels"
4661
- )
4662
- os.makedirs(labels_dir, exist_ok=True)
4663
-
4664
- # Generate YOLO annotation file
4665
- yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4666
- from scipy import ndimage
4667
-
4668
- with open(yolo_path, "w") as yolo_file:
4669
- for class_id in np.unique(label_mask):
4670
- if class_id == 0:
4671
- continue
4672
-
4673
- # Track class for classes.txt
4674
- class_name = next(
4675
- (k for k, v in class_to_id.items() if v == class_id),
4676
- str(class_id),
4677
- )
4678
- stats["yolo_classes"].add(class_name)
4679
-
4680
- class_mask = (label_mask == class_id).astype(np.uint8)
4681
- labeled_array, num_features = ndimage.label(class_mask)
4682
-
4683
- for instance_id in range(1, num_features + 1):
4684
- instance_mask = labeled_array == instance_id
4685
- coords = np.argwhere(instance_mask)
4686
-
4687
- if len(coords) == 0:
4688
- continue
4689
-
4690
- ymin, xmin = coords.min(axis=0)
4691
- ymax, xmax = coords.max(axis=0)
3596
+ # Create profile for label GeoTIFF
3597
+ label_profile = {
3598
+ "driver": "GTiff",
3599
+ "height": tile_size,
3600
+ "width": tile_size,
3601
+ "count": 1,
3602
+ "dtype": "uint8",
3603
+ "crs": src.crs,
3604
+ "transform": window_transform,
3605
+ }
4692
3606
 
4693
- # Convert to YOLO format (normalized center coordinates)
4694
- x_center = ((xmin + xmax) / 2) / tile_size
4695
- y_center = ((ymin + ymax) / 2) / tile_size
4696
- width = (xmax - xmin) / tile_size
4697
- height = (ymax - ymin) / tile_size
3607
+ # Export label as GeoTIFF
3608
+ label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
3609
+ try:
3610
+ with rasterio.open(label_path, "w", **label_profile) as dst:
3611
+ dst.write(label_mask.astype(np.uint8), 1)
4698
3612
 
4699
- # YOLO uses 0-based class indices
4700
- yolo_class_id = class_id - 1
4701
- yolo_file.write(
4702
- f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4703
- )
3613
+ if has_features:
3614
+ stats["tiles_with_features"] += 1
3615
+ except Exception as e:
3616
+ if not quiet:
3617
+ print(f"ERROR saving label GeoTIFF: {e}")
3618
+ stats["errors"] += 1
4704
3619
 
4705
3620
  tile_index += 1
4706
3621
  if tile_index >= max_tiles:
@@ -4712,179 +3627,6 @@ def _process_image_mask_pair(
4712
3627
  return stats
4713
3628
 
4714
3629
 
4715
- def display_training_tiles(
4716
- output_dir,
4717
- num_tiles=6,
4718
- figsize=(18, 6),
4719
- cmap="gray",
4720
- save_path=None,
4721
- ):
4722
- """
4723
- Display image and mask tile pairs from training data output.
4724
-
4725
- Args:
4726
- output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
4727
- num_tiles (int): Number of tile pairs to display (default: 6)
4728
- figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
4729
- cmap (str): Colormap for mask display (default: 'gray')
4730
- save_path (str, optional): If provided, save figure to this path instead of displaying
4731
-
4732
- Returns:
4733
- tuple: (fig, axes) matplotlib figure and axes objects
4734
-
4735
- Example:
4736
- >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
4737
- >>> # Or save to file
4738
- >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
4739
- """
4740
- import matplotlib.pyplot as plt
4741
-
4742
- # Get list of image tiles
4743
- images_dir = os.path.join(output_dir, "images")
4744
- if not os.path.exists(images_dir):
4745
- raise ValueError(f"Images directory not found: {images_dir}")
4746
-
4747
- image_tiles = sorted(os.listdir(images_dir))[:num_tiles]
4748
-
4749
- if not image_tiles:
4750
- raise ValueError(f"No image tiles found in {images_dir}")
4751
-
4752
- # Limit to available tiles
4753
- num_tiles = min(num_tiles, len(image_tiles))
4754
-
4755
- # Create figure with subplots
4756
- fig, axes = plt.subplots(2, num_tiles, figsize=figsize)
4757
-
4758
- # Handle case where num_tiles is 1
4759
- if num_tiles == 1:
4760
- axes = axes.reshape(2, 1)
4761
-
4762
- for idx, tile_name in enumerate(image_tiles):
4763
- # Load and display image tile
4764
- image_path = os.path.join(output_dir, "images", tile_name)
4765
- with rasterio.open(image_path) as src:
4766
- show(src, ax=axes[0, idx], title=f"Image {idx+1}")
4767
-
4768
- # Load and display mask tile
4769
- mask_path = os.path.join(output_dir, "masks", tile_name)
4770
- if os.path.exists(mask_path):
4771
- with rasterio.open(mask_path) as src:
4772
- show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
4773
- else:
4774
- axes[1, idx].text(
4775
- 0.5,
4776
- 0.5,
4777
- "Mask not found",
4778
- ha="center",
4779
- va="center",
4780
- transform=axes[1, idx].transAxes,
4781
- )
4782
- axes[1, idx].set_title(f"Mask {idx+1}")
4783
-
4784
- plt.tight_layout()
4785
-
4786
- # Save or show
4787
- if save_path:
4788
- plt.savefig(save_path, dpi=150, bbox_inches="tight")
4789
- plt.close(fig)
4790
- print(f"Figure saved to: {save_path}")
4791
- else:
4792
- plt.show()
4793
-
4794
- return fig, axes
4795
-
4796
-
4797
- def display_image_with_vector(
4798
- image_path,
4799
- vector_path,
4800
- figsize=(16, 8),
4801
- vector_color="red",
4802
- vector_linewidth=1,
4803
- vector_facecolor="none",
4804
- save_path=None,
4805
- ):
4806
- """
4807
- Display a raster image alongside the same image with vector overlay.
4808
-
4809
- Args:
4810
- image_path (str): Path to raster image file
4811
- vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
4812
- figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
4813
- vector_color (str): Edge color for vector features (default: 'red')
4814
- vector_linewidth (float): Line width for vector features (default: 1)
4815
- vector_facecolor (str): Fill color for vector features (default: 'none')
4816
- save_path (str, optional): If provided, save figure to this path instead of displaying
4817
-
4818
- Returns:
4819
- tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata
4820
-
4821
- Example:
4822
- >>> fig, axes, info = display_image_with_vector(
4823
- ... 'image.tif',
4824
- ... 'buildings.geojson',
4825
- ... vector_color='blue'
4826
- ... )
4827
- >>> print(f"Number of features: {info['num_features']}")
4828
- """
4829
- import matplotlib.pyplot as plt
4830
-
4831
- # Validate inputs
4832
- if not os.path.exists(image_path):
4833
- raise ValueError(f"Image file not found: {image_path}")
4834
- if not os.path.exists(vector_path):
4835
- raise ValueError(f"Vector file not found: {vector_path}")
4836
-
4837
- # Create figure
4838
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
4839
-
4840
- # Load and display image
4841
- with rasterio.open(image_path) as src:
4842
- # Plot image only
4843
- show(src, ax=ax1, title="Image")
4844
-
4845
- # Load vector data
4846
- vector_data = gpd.read_file(vector_path)
4847
-
4848
- # Reproject to image CRS if needed
4849
- if vector_data.crs != src.crs:
4850
- vector_data = vector_data.to_crs(src.crs)
4851
-
4852
- # Plot image with vector overlay
4853
- show(
4854
- src,
4855
- ax=ax2,
4856
- title=f"Image with {len(vector_data)} Vector Features",
4857
- )
4858
- vector_data.plot(
4859
- ax=ax2,
4860
- facecolor=vector_facecolor,
4861
- edgecolor=vector_color,
4862
- linewidth=vector_linewidth,
4863
- )
4864
-
4865
- # Collect metadata
4866
- info = {
4867
- "image_shape": src.shape,
4868
- "image_crs": src.crs,
4869
- "image_bounds": src.bounds,
4870
- "num_features": len(vector_data),
4871
- "vector_crs": vector_data.crs,
4872
- "vector_bounds": vector_data.total_bounds,
4873
- }
4874
-
4875
- plt.tight_layout()
4876
-
4877
- # Save or show
4878
- if save_path:
4879
- plt.savefig(save_path, dpi=150, bbox_inches="tight")
4880
- plt.close(fig)
4881
- print(f"Figure saved to: {save_path}")
4882
- else:
4883
- plt.show()
4884
-
4885
- return fig, (ax1, ax2), info
4886
-
4887
-
4888
3630
  def create_overview_image(
4889
3631
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
4890
3632
  ) -> str:
@@ -8779,39 +7521,17 @@ def write_colormap(
8779
7521
 
8780
7522
  def plot_performance_metrics(
8781
7523
  history_path: str,
8782
- figsize: Optional[Tuple[int, int]] = None,
7524
+ figsize: Tuple[int, int] = (15, 5),
8783
7525
  verbose: bool = True,
8784
7526
  save_path: Optional[str] = None,
8785
- csv_path: Optional[str] = None,
8786
7527
  kwargs: Optional[Dict] = None,
8787
- ) -> pd.DataFrame:
8788
- """Plot performance metrics from a training history object and return as DataFrame.
8789
-
8790
- This function loads training history, plots available metrics (loss, IoU, F1,
8791
- precision, recall), optionally exports to CSV, and returns all metrics as a
8792
- pandas DataFrame for further analysis.
7528
+ ) -> None:
7529
+ """Plot performance metrics from a history object.
8793
7530
 
8794
7531
  Args:
8795
- history_path (str): Path to the saved training history (.pth file).
8796
- figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
8797
- automatically determined based on number of metrics.
8798
- verbose (bool): Whether to print best and final metric values. Defaults to True.
8799
- save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
8800
- csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
8801
- kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
8802
-
8803
- Returns:
8804
- pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
8805
- Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
8806
- 'val_precision', 'val_recall' (depending on availability in history).
8807
-
8808
- Example:
8809
- >>> df = plot_performance_metrics(
8810
- ... 'training_history.pth',
8811
- ... save_path='metrics_plot.png',
8812
- ... csv_path='metrics.csv'
8813
- ... )
8814
- >>> print(df.head())
7532
+ history_path: The history object to plot.
7533
+ figsize: The figure size.
7534
+ verbose: Whether to print the best and final metrics.
8815
7535
  """
8816
7536
  if kwargs is None:
8817
7537
  kwargs = {}
@@ -8821,135 +7541,65 @@ def plot_performance_metrics(
8821
7541
  train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
8822
7542
  val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
8823
7543
  val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
8824
- # Support both new (f1) and old (dice) key formats for backward compatibility
8825
- val_f1_key = (
8826
- "val_f1s"
8827
- if "val_f1s" in history
8828
- else ("val_dices" if "val_dices" in history else "val_dice")
8829
- )
8830
- # Add support for precision and recall
8831
- val_precision_key = (
8832
- "val_precisions" if "val_precisions" in history else "val_precision"
8833
- )
8834
- val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
8835
-
8836
- # Collect available metrics for plotting
8837
- available_metrics = []
8838
- metric_info = {
8839
- "Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
8840
- "IoU": (val_iou_key, None, ["Val IoU"]),
8841
- "F1": (val_f1_key, None, ["Val F1"]),
8842
- "Precision": (val_precision_key, None, ["Val Precision"]),
8843
- "Recall": (val_recall_key, None, ["Val Recall"]),
8844
- }
8845
-
8846
- for metric_name, (key1, key2, labels) in metric_info.items():
8847
- if key1 in history or (key2 and key2 in history):
8848
- available_metrics.append((metric_name, key1, key2, labels))
8849
-
8850
- # Determine number of subplots and figure size
8851
- n_plots = len(available_metrics)
8852
- if figsize is None:
8853
- figsize = (5 * n_plots, 5)
7544
+ val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
8854
7545
 
8855
- # Create DataFrame for all metrics
8856
- n_epochs = 0
8857
- df_data = {}
7546
+ # Determine number of subplots based on available metrics
7547
+ has_dice = val_dice_key in history
7548
+ n_plots = 3 if has_dice else 2
7549
+ figsize = (15, 5) if has_dice else (10, 5)
8858
7550
 
8859
- # Add epochs
8860
- if "epochs" in history:
8861
- df_data["epoch"] = history["epochs"]
8862
- n_epochs = len(history["epochs"])
8863
- elif train_loss_key in history:
8864
- n_epochs = len(history[train_loss_key])
8865
- df_data["epoch"] = list(range(1, n_epochs + 1))
7551
+ plt.figure(figsize=figsize)
8866
7552
 
8867
- # Add all available metrics to DataFrame
7553
+ # Plot loss
7554
+ plt.subplot(1, n_plots, 1)
8868
7555
  if train_loss_key in history:
8869
- df_data["train_loss"] = history[train_loss_key]
7556
+ plt.plot(history[train_loss_key], label="Train Loss")
8870
7557
  if val_loss_key in history:
8871
- df_data["val_loss"] = history[val_loss_key]
7558
+ plt.plot(history[val_loss_key], label="Val Loss")
7559
+ plt.title("Loss")
7560
+ plt.xlabel("Epoch")
7561
+ plt.ylabel("Loss")
7562
+ plt.legend()
7563
+ plt.grid(True)
7564
+
7565
+ # Plot IoU
7566
+ plt.subplot(1, n_plots, 2)
8872
7567
  if val_iou_key in history:
8873
- df_data["val_iou"] = history[val_iou_key]
8874
- if val_f1_key in history:
8875
- df_data["val_f1"] = history[val_f1_key]
8876
- if val_precision_key in history:
8877
- df_data["val_precision"] = history[val_precision_key]
8878
- if val_recall_key in history:
8879
- df_data["val_recall"] = history[val_recall_key]
8880
-
8881
- # Create DataFrame
8882
- df = pd.DataFrame(df_data)
8883
-
8884
- # Export to CSV if requested
8885
- if csv_path:
8886
- df.to_csv(csv_path, index=False)
8887
- if verbose:
8888
- print(f"Metrics exported to: {csv_path}")
8889
-
8890
- # Create plots
8891
- if n_plots > 0:
8892
- fig, axes = plt.subplots(1, n_plots, figsize=figsize)
8893
- if n_plots == 1:
8894
- axes = [axes]
8895
-
8896
- for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
8897
- ax = axes[idx]
8898
-
8899
- if metric_name == "Loss":
8900
- # Special handling for loss (has both train and val)
8901
- if key1 in history:
8902
- ax.plot(history[key1], label=labels[0])
8903
- if key2 and key2 in history:
8904
- ax.plot(history[key2], label=labels[1])
8905
- else:
8906
- # Single metric plots
8907
- if key1 in history:
8908
- ax.plot(history[key1], label=labels[0])
8909
-
8910
- ax.set_title(metric_name)
8911
- ax.set_xlabel("Epoch")
8912
- ax.set_ylabel(metric_name)
8913
- ax.legend()
8914
- ax.grid(True)
7568
+ plt.plot(history[val_iou_key], label="Val IoU")
7569
+ plt.title("IoU Score")
7570
+ plt.xlabel("Epoch")
7571
+ plt.ylabel("IoU")
7572
+ plt.legend()
7573
+ plt.grid(True)
7574
+
7575
+ # Plot Dice if available
7576
+ if has_dice:
7577
+ plt.subplot(1, n_plots, 3)
7578
+ plt.plot(history[val_dice_key], label="Val Dice")
7579
+ plt.title("Dice Score")
7580
+ plt.xlabel("Epoch")
7581
+ plt.ylabel("Dice")
7582
+ plt.legend()
7583
+ plt.grid(True)
8915
7584
 
8916
- plt.tight_layout()
7585
+ plt.tight_layout()
8917
7586
 
8918
- if save_path:
8919
- if "dpi" not in kwargs:
8920
- kwargs["dpi"] = 150
8921
- if "bbox_inches" not in kwargs:
8922
- kwargs["bbox_inches"] = "tight"
8923
- plt.savefig(save_path, **kwargs)
7587
+ if save_path:
7588
+ if "dpi" not in kwargs:
7589
+ kwargs["dpi"] = 150
7590
+ if "bbox_inches" not in kwargs:
7591
+ kwargs["bbox_inches"] = "tight"
7592
+ plt.savefig(save_path, **kwargs)
8924
7593
 
8925
- plt.show()
7594
+ plt.show()
8926
7595
 
8927
- # Print summary statistics
8928
7596
  if verbose:
8929
- print("\n=== Performance Metrics Summary ===")
8930
7597
  if val_iou_key in history:
8931
- print(
8932
- f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
8933
- )
8934
- if val_f1_key in history:
8935
- print(
8936
- f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
8937
- )
8938
- if val_precision_key in history:
8939
- print(
8940
- f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
8941
- )
8942
- if val_recall_key in history:
8943
- print(
8944
- f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
8945
- )
8946
- if val_loss_key in history:
8947
- print(
8948
- f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
8949
- )
8950
- print("===================================\n")
8951
-
8952
- return df
7598
+ print(f"Best IoU: {max(history[val_iou_key]):.4f}")
7599
+ print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
7600
+ if val_dice_key in history:
7601
+ print(f"Best Dice: {max(history[val_dice_key]):.4f}")
7602
+ print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
8953
7603
 
8954
7604
 
8955
7605
  def get_device() -> torch.device: