geoai-py 0.18.2__py2.py3-none-any.whl → 0.19.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -64,7 +64,7 @@ def view_raster(
64
64
  client_args: Optional[Dict] = {"cors_all": False},
65
65
  basemap: Optional[str] = "OpenStreetMap",
66
66
  basemap_args: Optional[Dict] = None,
67
- backend: Optional[str] = "folium",
67
+ backend: Optional[str] = "ipyleaflet",
68
68
  **kwargs: Any,
69
69
  ) -> Any:
70
70
  """
@@ -87,7 +87,7 @@ def view_raster(
87
87
  client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
88
88
  basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
89
89
  basemap_args (Optional[Dict], optional): Additional arguments for the basemap. Defaults to None.
90
- backend (Optional[str], optional): The backend to use. Defaults to "folium".
90
+ backend (Optional[str], optional): The backend to use. Defaults to "ipyleaflet".
91
91
  **kwargs (Any): Additional keyword arguments.
92
92
 
93
93
  Returns:
@@ -123,39 +123,26 @@ def view_raster(
123
123
  if isinstance(source, dict):
124
124
  source = dict_to_image(source)
125
125
 
126
- if isinstance(source, str) and source.startswith("http"):
127
- if backend == "folium":
128
-
129
- m.add_geotiff(
130
- url=source,
131
- name=layer_name,
132
- opacity=opacity,
133
- attribution=attribution,
134
- fit_bounds=zoom_to_layer,
135
- palette=colormap,
136
- vmin=vmin,
137
- vmax=vmax,
138
- **kwargs,
139
- )
140
- m.add_layer_control()
141
- m.add_opacity_control()
142
-
143
- else:
144
- if indexes is not None:
145
- kwargs["bidx"] = indexes
146
- if colormap is not None:
147
- kwargs["colormap_name"] = colormap
148
- if attribution is None:
149
- attribution = "TiTiler"
150
-
151
- m.add_cog_layer(
152
- source,
153
- name=layer_name,
154
- opacity=opacity,
155
- attribution=attribution,
156
- zoom_to_layer=zoom_to_layer,
157
- **kwargs,
158
- )
126
+ if (
127
+ isinstance(source, str)
128
+ and source.lower().endswith(".tif")
129
+ and source.startswith("http")
130
+ ):
131
+ if indexes is not None:
132
+ kwargs["bidx"] = indexes
133
+ if colormap is not None:
134
+ kwargs["colormap_name"] = colormap
135
+ if attribution is None:
136
+ attribution = "TiTiler"
137
+
138
+ m.add_cog_layer(
139
+ source,
140
+ name=layer_name,
141
+ opacity=opacity,
142
+ attribution=attribution,
143
+ zoom_to_layer=zoom_to_layer,
144
+ **kwargs,
145
+ )
159
146
  else:
160
147
  m.add_raster(
161
148
  source=source,
@@ -237,8 +224,6 @@ def view_image(
237
224
  plt.show()
238
225
  plt.close()
239
226
 
240
- return ax
241
-
242
227
 
243
228
  def plot_images(
244
229
  images: Iterable[torch.Tensor],
@@ -396,394 +381,6 @@ def calc_stats(dataset, divide_by: float = 1.0) -> Tuple[np.ndarray, np.ndarray]
396
381
  return accum_mean / len(files), accum_std / len(files)
397
382
 
398
383
 
399
- def calc_iou(
400
- ground_truth: Union[str, np.ndarray, torch.Tensor],
401
- prediction: Union[str, np.ndarray, torch.Tensor],
402
- num_classes: Optional[int] = None,
403
- ignore_index: Optional[int] = None,
404
- smooth: float = 1e-6,
405
- band: int = 1,
406
- ) -> Union[float, np.ndarray]:
407
- """
408
- Calculate Intersection over Union (IoU) between ground truth and prediction masks.
409
-
410
- This function computes the IoU metric for segmentation tasks. It supports both
411
- binary and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
412
- or file paths to raster files.
413
-
414
- Args:
415
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
416
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
417
- For binary segmentation: shape (H, W) with values {0, 1}.
418
- For multi-class segmentation: shape (H, W) with class indices.
419
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
420
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
421
- Should have the same shape and format as ground_truth.
422
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
423
- If None, assumes binary segmentation. Defaults to None.
424
- ignore_index (Optional[int], optional): Class index to ignore in computation.
425
- Useful for ignoring background or unlabeled pixels. Defaults to None.
426
- smooth (float, optional): Smoothing factor to avoid division by zero.
427
- Defaults to 1e-6.
428
- band (int, optional): Band index to read from raster file (1-based indexing).
429
- Only used when input is a file path. Defaults to 1.
430
-
431
- Returns:
432
- Union[float, np.ndarray]: For binary segmentation, returns a single float IoU score.
433
- For multi-class segmentation, returns an array of IoU scores for each class.
434
-
435
- Examples:
436
- >>> # Binary segmentation with arrays
437
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
438
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
439
- >>> iou = calc_iou(gt, pred)
440
- >>> print(f"IoU: {iou:.4f}")
441
- IoU: 0.8333
442
-
443
- >>> # Multi-class segmentation
444
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
445
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
446
- >>> iou = calc_iou(gt, pred, num_classes=3)
447
- >>> print(f"IoU per class: {iou}")
448
- IoU per class: [0.8333 0.5000 0.5000]
449
-
450
- >>> # Using PyTorch tensors
451
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
452
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
453
- >>> iou = calc_iou(gt_tensor, pred_tensor)
454
- >>> print(f"IoU: {iou:.4f}")
455
- IoU: 0.8333
456
-
457
- >>> # Using raster file paths
458
- >>> iou = calc_iou("ground_truth.tif", "prediction.tif", num_classes=3)
459
- >>> print(f"Mean IoU: {np.nanmean(iou):.4f}")
460
- Mean IoU: 0.7500
461
- """
462
- # Load from file if string path is provided
463
- if isinstance(ground_truth, str):
464
- with rasterio.open(ground_truth) as src:
465
- ground_truth = src.read(band)
466
- if isinstance(prediction, str):
467
- with rasterio.open(prediction) as src:
468
- prediction = src.read(band)
469
-
470
- # Convert to numpy if torch tensor
471
- if isinstance(ground_truth, torch.Tensor):
472
- ground_truth = ground_truth.cpu().numpy()
473
- if isinstance(prediction, torch.Tensor):
474
- prediction = prediction.cpu().numpy()
475
-
476
- # Ensure inputs have the same shape
477
- if ground_truth.shape != prediction.shape:
478
- raise ValueError(
479
- f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
480
- )
481
-
482
- # Binary segmentation
483
- if num_classes is None:
484
- ground_truth = ground_truth.astype(bool)
485
- prediction = prediction.astype(bool)
486
-
487
- intersection = np.logical_and(ground_truth, prediction).sum()
488
- union = np.logical_or(ground_truth, prediction).sum()
489
-
490
- if union == 0:
491
- return 1.0 if intersection == 0 else 0.0
492
-
493
- iou = (intersection + smooth) / (union + smooth)
494
- return float(iou)
495
-
496
- # Multi-class segmentation
497
- else:
498
- iou_per_class = []
499
-
500
- for class_idx in range(num_classes):
501
- # Handle ignored class by appending np.nan
502
- if ignore_index is not None and class_idx == ignore_index:
503
- iou_per_class.append(np.nan)
504
- continue
505
-
506
- # Create binary masks for current class
507
- gt_class = (ground_truth == class_idx).astype(bool)
508
- pred_class = (prediction == class_idx).astype(bool)
509
-
510
- intersection = np.logical_and(gt_class, pred_class).sum()
511
- union = np.logical_or(gt_class, pred_class).sum()
512
-
513
- if union == 0:
514
- # If class is not present in both gt and pred
515
- iou_per_class.append(np.nan)
516
- else:
517
- iou_per_class.append((intersection + smooth) / (union + smooth))
518
-
519
- return np.array(iou_per_class)
520
-
521
-
522
- def calc_f1_score(
523
- ground_truth: Union[str, np.ndarray, torch.Tensor],
524
- prediction: Union[str, np.ndarray, torch.Tensor],
525
- num_classes: Optional[int] = None,
526
- ignore_index: Optional[int] = None,
527
- smooth: float = 1e-6,
528
- band: int = 1,
529
- ) -> Union[float, np.ndarray]:
530
- """
531
- Calculate F1 score between ground truth and prediction masks.
532
-
533
- The F1 score is the harmonic mean of precision and recall, computed as:
534
- F1 = 2 * (precision * recall) / (precision + recall)
535
- where precision = TP / (TP + FP) and recall = TP / (TP + FN).
536
-
537
- This function supports both binary and multi-class segmentation, and can handle
538
- numpy arrays, PyTorch tensors, or file paths to raster files.
539
-
540
- Args:
541
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
542
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
543
- For binary segmentation: shape (H, W) with values {0, 1}.
544
- For multi-class segmentation: shape (H, W) with class indices.
545
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
546
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
547
- Should have the same shape and format as ground_truth.
548
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
549
- If None, assumes binary segmentation. Defaults to None.
550
- ignore_index (Optional[int], optional): Class index to ignore in computation.
551
- Useful for ignoring background or unlabeled pixels. Defaults to None.
552
- smooth (float, optional): Smoothing factor to avoid division by zero.
553
- Defaults to 1e-6.
554
- band (int, optional): Band index to read from raster file (1-based indexing).
555
- Only used when input is a file path. Defaults to 1.
556
-
557
- Returns:
558
- Union[float, np.ndarray]: For binary segmentation, returns a single float F1 score.
559
- For multi-class segmentation, returns an array of F1 scores for each class.
560
-
561
- Examples:
562
- >>> # Binary segmentation with arrays
563
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
564
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
565
- >>> f1 = calc_f1_score(gt, pred)
566
- >>> print(f"F1 Score: {f1:.4f}")
567
- F1 Score: 0.8571
568
-
569
- >>> # Multi-class segmentation
570
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
571
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
572
- >>> f1 = calc_f1_score(gt, pred, num_classes=3)
573
- >>> print(f"F1 Score per class: {f1}")
574
- F1 Score per class: [0.8571 0.6667 0.6667]
575
-
576
- >>> # Using PyTorch tensors
577
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
578
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
579
- >>> f1 = calc_f1_score(gt_tensor, pred_tensor)
580
- >>> print(f"F1 Score: {f1:.4f}")
581
- F1 Score: 0.8571
582
-
583
- >>> # Using raster file paths
584
- >>> f1 = calc_f1_score("ground_truth.tif", "prediction.tif", num_classes=3)
585
- >>> print(f"Mean F1: {np.nanmean(f1):.4f}")
586
- Mean F1: 0.7302
587
- """
588
- # Load from file if string path is provided
589
- if isinstance(ground_truth, str):
590
- with rasterio.open(ground_truth) as src:
591
- ground_truth = src.read(band)
592
- if isinstance(prediction, str):
593
- with rasterio.open(prediction) as src:
594
- prediction = src.read(band)
595
-
596
- # Convert to numpy if torch tensor
597
- if isinstance(ground_truth, torch.Tensor):
598
- ground_truth = ground_truth.cpu().numpy()
599
- if isinstance(prediction, torch.Tensor):
600
- prediction = prediction.cpu().numpy()
601
-
602
- # Ensure inputs have the same shape
603
- if ground_truth.shape != prediction.shape:
604
- raise ValueError(
605
- f"Shape mismatch: ground_truth {ground_truth.shape} vs prediction {prediction.shape}"
606
- )
607
-
608
- # Binary segmentation
609
- if num_classes is None:
610
- ground_truth = ground_truth.astype(bool)
611
- prediction = prediction.astype(bool)
612
-
613
- # Calculate True Positives, False Positives, False Negatives
614
- tp = np.logical_and(ground_truth, prediction).sum()
615
- fp = np.logical_and(~ground_truth, prediction).sum()
616
- fn = np.logical_and(ground_truth, ~prediction).sum()
617
-
618
- # Calculate precision and recall
619
- precision = (tp + smooth) / (tp + fp + smooth)
620
- recall = (tp + smooth) / (tp + fn + smooth)
621
-
622
- # Calculate F1 score
623
- f1 = 2 * (precision * recall) / (precision + recall + smooth)
624
- return float(f1)
625
-
626
- # Multi-class segmentation
627
- else:
628
- f1_per_class = []
629
-
630
- for class_idx in range(num_classes):
631
- # Mark ignored class with np.nan
632
- if ignore_index is not None and class_idx == ignore_index:
633
- f1_per_class.append(np.nan)
634
- continue
635
-
636
- # Create binary masks for current class
637
- gt_class = (ground_truth == class_idx).astype(bool)
638
- pred_class = (prediction == class_idx).astype(bool)
639
-
640
- # Calculate True Positives, False Positives, False Negatives
641
- tp = np.logical_and(gt_class, pred_class).sum()
642
- fp = np.logical_and(~gt_class, pred_class).sum()
643
- fn = np.logical_and(gt_class, ~pred_class).sum()
644
-
645
- # Calculate precision and recall
646
- precision = (tp + smooth) / (tp + fp + smooth)
647
- recall = (tp + smooth) / (tp + fn + smooth)
648
-
649
- # Calculate F1 score
650
- if tp + fp + fn == 0:
651
- # If class is not present in both gt and pred
652
- f1_per_class.append(np.nan)
653
- else:
654
- f1 = 2 * (precision * recall) / (precision + recall + smooth)
655
- f1_per_class.append(f1)
656
-
657
- return np.array(f1_per_class)
658
-
659
-
660
- def calc_segmentation_metrics(
661
- ground_truth: Union[str, np.ndarray, torch.Tensor],
662
- prediction: Union[str, np.ndarray, torch.Tensor],
663
- num_classes: Optional[int] = None,
664
- ignore_index: Optional[int] = None,
665
- smooth: float = 1e-6,
666
- metrics: List[str] = ["iou", "f1"],
667
- band: int = 1,
668
- ) -> Dict[str, Union[float, np.ndarray]]:
669
- """
670
- Calculate multiple segmentation metrics between ground truth and prediction masks.
671
-
672
- This is a convenient wrapper function that computes multiple metrics at once,
673
- including IoU (Intersection over Union) and F1 score. It supports both binary
674
- and multi-class segmentation, and can handle numpy arrays, PyTorch tensors,
675
- or file paths to raster files.
676
-
677
- Args:
678
- ground_truth (Union[str, np.ndarray, torch.Tensor]): Ground truth segmentation mask.
679
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
680
- For binary segmentation: shape (H, W) with values {0, 1}.
681
- For multi-class segmentation: shape (H, W) with class indices.
682
- prediction (Union[str, np.ndarray, torch.Tensor]): Predicted segmentation mask.
683
- Can be a file path (str) to a raster file, numpy array, or PyTorch tensor.
684
- Should have the same shape and format as ground_truth.
685
- num_classes (Optional[int], optional): Number of classes for multi-class segmentation.
686
- If None, assumes binary segmentation. Defaults to None.
687
- ignore_index (Optional[int], optional): Class index to ignore in computation.
688
- Useful for ignoring background or unlabeled pixels. Defaults to None.
689
- smooth (float, optional): Smoothing factor to avoid division by zero.
690
- Defaults to 1e-6.
691
- metrics (List[str], optional): List of metrics to calculate.
692
- Options: "iou", "f1". Defaults to ["iou", "f1"].
693
- band (int, optional): Band index to read from raster file (1-based indexing).
694
- Only used when input is a file path. Defaults to 1.
695
-
696
- Returns:
697
- Dict[str, Union[float, np.ndarray]]: Dictionary containing the computed metrics.
698
- Keys are metric names ("iou", "f1"), values are the metric scores.
699
- For binary segmentation, values are floats.
700
- For multi-class segmentation, values are numpy arrays with per-class scores.
701
- Also includes "mean_iou" and "mean_f1" for multi-class segmentation
702
- (mean computed over valid classes, ignoring NaN values).
703
-
704
- Examples:
705
- >>> # Binary segmentation with arrays
706
- >>> gt = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])
707
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 1, 1]])
708
- >>> metrics = calc_segmentation_metrics(gt, pred)
709
- >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
710
- IoU: 0.8333, F1: 0.8571
711
-
712
- >>> # Multi-class segmentation
713
- >>> gt = np.array([[0, 0, 1, 1], [0, 2, 2, 1]])
714
- >>> pred = np.array([[0, 0, 1, 1], [0, 0, 2, 2]])
715
- >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3)
716
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
717
- >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
718
- >>> print(f"Per-class IoU: {metrics['iou']}")
719
- Mean IoU: 0.6111
720
- Mean F1: 0.7302
721
- Per-class IoU: [0.8333 0.5000 0.5000]
722
-
723
- >>> # Calculate only IoU
724
- >>> metrics = calc_segmentation_metrics(gt, pred, num_classes=3, metrics=["iou"])
725
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
726
- Mean IoU: 0.6111
727
-
728
- >>> # Using PyTorch tensors
729
- >>> gt_tensor = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
730
- >>> pred_tensor = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
731
- >>> metrics = calc_segmentation_metrics(gt_tensor, pred_tensor)
732
- >>> print(f"IoU: {metrics['iou']:.4f}, F1: {metrics['f1']:.4f}")
733
- IoU: 0.8333, F1: 0.8571
734
-
735
- >>> # Using raster file paths
736
- >>> metrics = calc_segmentation_metrics("ground_truth.tif", "prediction.tif", num_classes=3)
737
- >>> print(f"Mean IoU: {metrics['mean_iou']:.4f}")
738
- >>> print(f"Mean F1: {metrics['mean_f1']:.4f}")
739
- Mean IoU: 0.6111
740
- Mean F1: 0.7302
741
- """
742
- results = {}
743
-
744
- # Calculate IoU if requested
745
- if "iou" in metrics:
746
- iou = calc_iou(
747
- ground_truth,
748
- prediction,
749
- num_classes=num_classes,
750
- ignore_index=ignore_index,
751
- smooth=smooth,
752
- band=band,
753
- )
754
- results["iou"] = iou
755
-
756
- # Add mean IoU for multi-class
757
- if num_classes is not None and isinstance(iou, np.ndarray):
758
- # Calculate mean ignoring NaN values
759
- valid_ious = iou[~np.isnan(iou)]
760
- results["mean_iou"] = (
761
- float(np.mean(valid_ious)) if len(valid_ious) > 0 else 0.0
762
- )
763
-
764
- # Calculate F1 score if requested
765
- if "f1" in metrics:
766
- f1 = calc_f1_score(
767
- ground_truth,
768
- prediction,
769
- num_classes=num_classes,
770
- ignore_index=ignore_index,
771
- smooth=smooth,
772
- band=band,
773
- )
774
- results["f1"] = f1
775
-
776
- # Add mean F1 for multi-class
777
- if num_classes is not None and isinstance(f1, np.ndarray):
778
- # Calculate mean ignoring NaN values
779
- valid_f1s = f1[~np.isnan(f1)]
780
- results["mean_f1"] = (
781
- float(np.mean(valid_f1s)) if len(valid_f1s) > 0 else 0.0
782
- )
783
-
784
- return results
785
-
786
-
787
384
  def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
788
385
  """Convert a dictionary to a xarray DataArray. The dictionary should contain the
789
386
  following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
@@ -1094,9 +691,8 @@ def view_vector(
1094
691
 
1095
692
  def view_vector_interactive(
1096
693
  vector_data: Union[str, gpd.GeoDataFrame],
1097
- layer_name: str = "Vector",
694
+ layer_name: str = "Vector Layer",
1098
695
  tiles_args: Optional[Dict] = None,
1099
- opacity: float = 0.7,
1100
696
  **kwargs: Any,
1101
697
  ) -> Any:
1102
698
  """
@@ -1111,7 +707,6 @@ def view_vector_interactive(
1111
707
  layer_name (str, optional): The name of the layer. Defaults to "Vector Layer".
1112
708
  tiles_args (dict, optional): Additional arguments for the localtileserver client.
1113
709
  get_folium_tile_layer function. Defaults to None.
1114
- opacity (float, optional): The opacity of the layer. Defaults to 0.7.
1115
710
  **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
1116
711
  See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
1117
712
 
@@ -1126,8 +721,9 @@ def view_vector_interactive(
1126
721
  >>> roads = gpd.read_file("roads.shp")
1127
722
  >>> view_vector_interactive(roads, figsize=(12, 8))
1128
723
  """
1129
-
1130
- from leafmap.foliumap import Map
724
+ import folium
725
+ import folium.plugins as plugins
726
+ from leafmap import cog_tile
1131
727
  from localtileserver import TileClient, get_folium_tile_layer
1132
728
 
1133
729
  google_tiles = {
@@ -1153,17 +749,9 @@ def view_vector_interactive(
1153
749
  },
1154
750
  }
1155
751
 
1156
- # Make it compatible with binder and JupyterHub
1157
- if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
1158
- os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
1159
- f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
1160
- )
1161
-
1162
752
  basemap_layer_name = None
1163
753
  raster_layer = None
1164
754
 
1165
- m = Map()
1166
-
1167
755
  if "tiles" in kwargs and isinstance(kwargs["tiles"], str):
1168
756
  if kwargs["tiles"].title() in google_tiles:
1169
757
  basemap_layer_name = google_tiles[kwargs["tiles"].title()]["name"]
@@ -1174,17 +762,14 @@ def view_vector_interactive(
1174
762
  tiles_args = {}
1175
763
  if kwargs["tiles"].lower().startswith("http"):
1176
764
  basemap_layer_name = "Remote Raster"
1177
- m.add_geotiff(kwargs["tiles"], name=basemap_layer_name, **tiles_args)
765
+ kwargs["tiles"] = cog_tile(kwargs["tiles"], **tiles_args)
766
+ kwargs["attr"] = "TiTiler"
1178
767
  else:
1179
768
  basemap_layer_name = "Local Raster"
1180
769
  client = TileClient(kwargs["tiles"])
1181
770
  raster_layer = get_folium_tile_layer(client, **tiles_args)
1182
- m.add_tile_layer(
1183
- raster_layer.tiles,
1184
- name=basemap_layer_name,
1185
- attribution="localtileserver",
1186
- **tiles_args,
1187
- )
771
+ kwargs["tiles"] = raster_layer.tiles
772
+ kwargs["attr"] = "localtileserver"
1188
773
 
1189
774
  if "max_zoom" not in kwargs:
1190
775
  kwargs["max_zoom"] = 30
@@ -1199,18 +784,23 @@ def view_vector_interactive(
1199
784
  if not isinstance(vector_data, gpd.GeoDataFrame):
1200
785
  raise TypeError("Input data must be a GeoDataFrame")
1201
786
 
1202
- if "column" in kwargs:
1203
- if "legend_position" not in kwargs:
1204
- kwargs["legend_position"] = "bottomleft"
1205
- if "cmap" not in kwargs:
1206
- kwargs["cmap"] = "viridis"
1207
- m.add_data(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
787
+ layer_control = kwargs.pop("layer_control", True)
788
+ fullscreen_control = kwargs.pop("fullscreen_control", True)
1208
789
 
1209
- else:
1210
- m.add_gdf(vector_data, layer_name=layer_name, opacity=opacity, **kwargs)
790
+ m = vector_data.explore(**kwargs)
791
+
792
+ # Change the layer name
793
+ for layer in m._children.values():
794
+ if isinstance(layer, folium.GeoJson):
795
+ layer.layer_name = layer_name
796
+ if isinstance(layer, folium.TileLayer) and basemap_layer_name:
797
+ layer.layer_name = basemap_layer_name
798
+
799
+ if layer_control:
800
+ m.add_child(folium.LayerControl())
1211
801
 
1212
- m.add_layer_control()
1213
- m.add_opacity_control()
802
+ if fullscreen_control:
803
+ plugins.Fullscreen().add_to(m)
1214
804
 
1215
805
  return m
1216
806
 
@@ -3004,86 +2594,10 @@ def batch_vector_to_raster(
3004
2594
  return output_files
3005
2595
 
3006
2596
 
3007
- def get_default_augmentation_transforms(
3008
- tile_size: int = 256,
3009
- include_normalize: bool = False,
3010
- mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
3011
- std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
3012
- ) -> Any:
3013
- """
3014
- Get default data augmentation transforms for geospatial imagery using albumentations.
3015
-
3016
- This function returns a composition of augmentation transforms commonly used
3017
- for remote sensing and geospatial data. The transforms include geometric
3018
- transformations (flips, rotations) and photometric adjustments (brightness,
3019
- contrast, saturation).
3020
-
3021
- Args:
3022
- tile_size (int): Target size for tiles. Defaults to 256.
3023
- include_normalize (bool): Whether to include normalization transform.
3024
- Defaults to False. Set to True if using for training with pretrained models.
3025
- mean (tuple): Mean values for normalization (RGB). Defaults to ImageNet values.
3026
- std (tuple): Standard deviation for normalization (RGB). Defaults to ImageNet values.
3027
-
3028
- Returns:
3029
- albumentations.Compose: A composition of augmentation transforms.
3030
-
3031
- Example:
3032
- >>> import albumentations as A
3033
- >>> # Get default transforms
3034
- >>> transform = get_default_augmentation_transforms()
3035
- >>> # Apply to image and mask
3036
- >>> augmented = transform(image=image, mask=mask)
3037
- >>> aug_image = augmented['image']
3038
- >>> aug_mask = augmented['mask']
3039
- """
3040
- try:
3041
- import albumentations as A
3042
- except ImportError:
3043
- raise ImportError(
3044
- "albumentations is required for data augmentation. "
3045
- "Install it with: pip install albumentations"
3046
- )
3047
-
3048
- transforms_list = [
3049
- # Geometric transforms
3050
- A.HorizontalFlip(p=0.5),
3051
- A.VerticalFlip(p=0.5),
3052
- A.RandomRotate90(p=0.5),
3053
- A.ShiftScaleRotate(
3054
- shift_limit=0.1,
3055
- scale_limit=0.1,
3056
- rotate_limit=45,
3057
- border_mode=0,
3058
- p=0.5,
3059
- ),
3060
- # Photometric transforms
3061
- A.RandomBrightnessContrast(
3062
- brightness_limit=0.2,
3063
- contrast_limit=0.2,
3064
- p=0.5,
3065
- ),
3066
- A.HueSaturationValue(
3067
- hue_shift_limit=10,
3068
- sat_shift_limit=20,
3069
- val_shift_limit=10,
3070
- p=0.3,
3071
- ),
3072
- A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
3073
- A.GaussianBlur(blur_limit=(3, 5), p=0.2),
3074
- ]
3075
-
3076
- # Add normalization if requested
3077
- if include_normalize:
3078
- transforms_list.append(A.Normalize(mean=mean, std=std))
3079
-
3080
- return A.Compose(transforms_list)
3081
-
3082
-
3083
2597
  def export_geotiff_tiles(
3084
2598
  in_raster,
3085
2599
  out_folder,
3086
- in_class_data=None,
2600
+ in_class_data,
3087
2601
  tile_size=256,
3088
2602
  stride=128,
3089
2603
  class_value_field="class",
@@ -3093,10 +2607,6 @@ def export_geotiff_tiles(
3093
2607
  all_touched=True,
3094
2608
  create_overview=False,
3095
2609
  skip_empty_tiles=False,
3096
- metadata_format="PASCAL_VOC",
3097
- apply_augmentation=False,
3098
- augmentation_count=3,
3099
- augmentation_transforms=None,
3100
2610
  ):
3101
2611
  """
3102
2612
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -3104,8 +2614,7 @@ def export_geotiff_tiles(
3104
2614
  Args:
3105
2615
  in_raster (str): Path to input raster image
3106
2616
  out_folder (str): Path to output folder
3107
- in_class_data (str, optional): Path to classification data - can be vector file or raster.
3108
- If None, only image tiles will be exported without labels. Defaults to None.
2617
+ in_class_data (str): Path to classification data - can be vector file or raster
3109
2618
  tile_size (int): Size of tiles in pixels (square)
3110
2619
  stride (int): Step size between tiles
3111
2620
  class_value_field (str): Field containing class values (for vector data)
@@ -3115,95 +2624,38 @@ def export_geotiff_tiles(
3115
2624
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
3116
2625
  create_overview (bool): Whether to create an overview image of all tiles
3117
2626
  skip_empty_tiles (bool): If True, skip tiles with no features
3118
- metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
3119
- apply_augmentation (bool): If True, generate augmented versions of each tile.
3120
- This will create multiple variants of each tile using data augmentation techniques.
3121
- Defaults to False.
3122
- augmentation_count (int): Number of augmented versions to generate per tile
3123
- (only used if apply_augmentation=True). Defaults to 3.
3124
- augmentation_transforms (albumentations.Compose, optional): Custom augmentation transforms.
3125
- If None and apply_augmentation=True, uses default transforms from
3126
- get_default_augmentation_transforms(). Should be an albumentations.Compose object.
3127
- Defaults to None.
3128
-
3129
- Returns:
3130
- None: Tiles and labels are saved to out_folder.
3131
-
3132
- Example:
3133
- >>> # Export tiles without augmentation
3134
- >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif')
3135
- >>>
3136
- >>> # Export tiles with default augmentation (3 augmented versions per tile)
3137
- >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
3138
- ... apply_augmentation=True)
3139
- >>>
3140
- >>> # Export with custom augmentation
3141
- >>> import albumentations as A
3142
- >>> custom_transform = A.Compose([
3143
- ... A.HorizontalFlip(p=0.5),
3144
- ... A.RandomBrightnessContrast(p=0.5),
3145
- ... ])
3146
- >>> export_geotiff_tiles('image.tif', 'output/', 'labels.tif',
3147
- ... apply_augmentation=True,
3148
- ... augmentation_count=5,
3149
- ... augmentation_transforms=custom_transform)
3150
2627
  """
3151
2628
 
3152
2629
  import logging
3153
2630
 
3154
2631
  logging.getLogger("rasterio").setLevel(logging.ERROR)
3155
2632
 
3156
- # Initialize augmentation transforms if needed
3157
- if apply_augmentation:
3158
- if augmentation_transforms is None:
3159
- augmentation_transforms = get_default_augmentation_transforms(
3160
- tile_size=tile_size
3161
- )
3162
- if not quiet:
3163
- print(
3164
- f"Data augmentation enabled: generating {augmentation_count} augmented versions per tile"
3165
- )
3166
-
3167
2633
  # Create output directories
3168
2634
  os.makedirs(out_folder, exist_ok=True)
3169
2635
  image_dir = os.path.join(out_folder, "images")
3170
2636
  os.makedirs(image_dir, exist_ok=True)
2637
+ label_dir = os.path.join(out_folder, "labels")
2638
+ os.makedirs(label_dir, exist_ok=True)
2639
+ ann_dir = os.path.join(out_folder, "annotations")
2640
+ os.makedirs(ann_dir, exist_ok=True)
3171
2641
 
3172
- # Only create label and annotation directories if class data is provided
3173
- if in_class_data is not None:
3174
- label_dir = os.path.join(out_folder, "labels")
3175
- os.makedirs(label_dir, exist_ok=True)
3176
-
3177
- # Create annotation directory based on metadata format
3178
- if metadata_format in ["PASCAL_VOC", "COCO"]:
3179
- ann_dir = os.path.join(out_folder, "annotations")
3180
- os.makedirs(ann_dir, exist_ok=True)
3181
-
3182
- # Initialize COCO annotations dictionary
3183
- if metadata_format == "COCO":
3184
- coco_annotations = {"images": [], "annotations": [], "categories": []}
3185
- ann_id = 0
3186
-
3187
- # Determine if class data is raster or vector (only if class data provided)
2642
+ # Determine if class data is raster or vector
3188
2643
  is_class_data_raster = False
3189
- if in_class_data is not None:
3190
- if isinstance(in_class_data, str):
3191
- file_ext = Path(in_class_data).suffix.lower()
3192
- # Common raster extensions
3193
- if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
3194
- try:
3195
- with rasterio.open(in_class_data) as src:
3196
- is_class_data_raster = True
3197
- if not quiet:
3198
- print(f"Detected in_class_data as raster: {in_class_data}")
3199
- print(f"Raster CRS: {src.crs}")
3200
- print(f"Raster dimensions: {src.width} x {src.height}")
3201
- except Exception:
3202
- is_class_data_raster = False
2644
+ if isinstance(in_class_data, str):
2645
+ file_ext = Path(in_class_data).suffix.lower()
2646
+ # Common raster extensions
2647
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2648
+ try:
2649
+ with rasterio.open(in_class_data) as src:
2650
+ is_class_data_raster = True
3203
2651
  if not quiet:
3204
- print(
3205
- f"Unable to open {in_class_data} as raster, trying as vector"
3206
- )
2652
+ print(f"Detected in_class_data as raster: {in_class_data}")
2653
+ print(f"Raster CRS: {src.crs}")
2654
+ print(f"Raster dimensions: {src.width} x {src.height}")
2655
+ except Exception:
2656
+ is_class_data_raster = False
2657
+ if not quiet:
2658
+ print(f"Unable to open {in_class_data} as raster, trying as vector")
3207
2659
 
3208
2660
  # Open the input raster
3209
2661
  with rasterio.open(in_raster) as src:
@@ -3223,10 +2675,10 @@ def export_geotiff_tiles(
3223
2675
  if max_tiles is None:
3224
2676
  max_tiles = total_tiles
3225
2677
 
3226
- # Process classification data (only if class data provided)
2678
+ # Process classification data
3227
2679
  class_to_id = {}
3228
2680
 
3229
- if in_class_data is not None and is_class_data_raster:
2681
+ if is_class_data_raster:
3230
2682
  # Load raster class data
3231
2683
  with rasterio.open(in_class_data) as class_src:
3232
2684
  # Check if raster CRS matches
@@ -3259,18 +2711,7 @@ def export_geotiff_tiles(
3259
2711
 
3260
2712
  # Create class mapping
3261
2713
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
3262
-
3263
- # Populate COCO categories
3264
- if metadata_format == "COCO":
3265
- for cls_val in unique_classes:
3266
- coco_annotations["categories"].append(
3267
- {
3268
- "id": class_to_id[int(cls_val)],
3269
- "name": str(int(cls_val)),
3270
- "supercategory": "object",
3271
- }
3272
- )
3273
- elif in_class_data is not None:
2714
+ else:
3274
2715
  # Load vector class data
3275
2716
  try:
3276
2717
  gdf = gpd.read_file(in_class_data)
@@ -3299,33 +2740,12 @@ def export_geotiff_tiles(
3299
2740
  )
3300
2741
  # Create class mapping
3301
2742
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
3302
-
3303
- # Populate COCO categories
3304
- if metadata_format == "COCO":
3305
- for cls_val in unique_classes:
3306
- coco_annotations["categories"].append(
3307
- {
3308
- "id": class_to_id[cls_val],
3309
- "name": str(cls_val),
3310
- "supercategory": "object",
3311
- }
3312
- )
3313
2743
  else:
3314
2744
  if not quiet:
3315
2745
  print(
3316
2746
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
3317
2747
  )
3318
2748
  class_to_id = {1: 1} # Default mapping
3319
-
3320
- # Populate COCO categories with default
3321
- if metadata_format == "COCO":
3322
- coco_annotations["categories"].append(
3323
- {
3324
- "id": 1,
3325
- "name": "object",
3326
- "supercategory": "object",
3327
- }
3328
- )
3329
2749
  except Exception as e:
3330
2750
  raise ValueError(f"Error processing vector data: {e}")
3331
2751
 
@@ -3392,8 +2812,8 @@ def export_geotiff_tiles(
3392
2812
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
3393
2813
  has_features = False
3394
2814
 
3395
- # Process classification data to create labels (only if class data provided)
3396
- if in_class_data is not None and is_class_data_raster:
2815
+ # Process classification data to create labels
2816
+ if is_class_data_raster:
3397
2817
  # For raster class data
3398
2818
  with rasterio.open(in_class_data) as class_src:
3399
2819
  # Calculate window in class raster
@@ -3443,7 +2863,7 @@ def export_geotiff_tiles(
3443
2863
  except Exception as e:
3444
2864
  pbar.write(f"Error reading class raster window: {e}")
3445
2865
  stats["errors"] += 1
3446
- elif in_class_data is not None:
2866
+ else:
3447
2867
  # For vector class data
3448
2868
  # Find features that intersect with window
3449
2869
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -3486,8 +2906,8 @@ def export_geotiff_tiles(
3486
2906
  pbar.write(f"Error rasterizing feature {idx}: {e}")
3487
2907
  stats["errors"] += 1
3488
2908
 
3489
- # Skip tile if no features and skip_empty_tiles is True (only when class data provided)
3490
- if in_class_data is not None and skip_empty_tiles and not has_features:
2909
+ # Skip tile if no features and skip_empty_tiles is True
2910
+ if skip_empty_tiles and not has_features:
3491
2911
  pbar.update(1)
3492
2912
  tile_index += 1
3493
2913
  continue
@@ -3495,316 +2915,119 @@ def export_geotiff_tiles(
3495
2915
  # Read image data
3496
2916
  image_data = src.read(window=window)
3497
2917
 
3498
- # Helper function to save a single tile (original or augmented)
3499
- def save_tile(
3500
- img_data,
3501
- lbl_mask,
3502
- tile_id,
3503
- img_profile,
3504
- window_trans,
3505
- is_augmented=False,
3506
- ):
3507
- """Save a single image and label tile."""
3508
- # Export image as GeoTIFF
3509
- image_path = os.path.join(image_dir, f"tile_{tile_id:06d}.tif")
3510
-
3511
- # Update profile
3512
- img_profile_copy = img_profile.copy()
3513
- img_profile_copy.update(
3514
- {
3515
- "height": tile_size,
3516
- "width": tile_size,
3517
- "count": img_data.shape[0],
3518
- "transform": window_trans,
3519
- }
3520
- )
3521
-
3522
- # Save image as GeoTIFF
3523
- try:
3524
- with rasterio.open(image_path, "w", **img_profile_copy) as dst:
3525
- dst.write(img_data)
3526
- stats["total_tiles"] += 1
3527
- except Exception as e:
3528
- pbar.write(f"ERROR saving image GeoTIFF: {e}")
3529
- stats["errors"] += 1
3530
- return
3531
-
3532
- # Export label as GeoTIFF (only if class data provided)
3533
- if in_class_data is not None:
3534
- # Create profile for label GeoTIFF
3535
- label_profile = {
3536
- "driver": "GTiff",
3537
- "height": tile_size,
3538
- "width": tile_size,
3539
- "count": 1,
3540
- "dtype": "uint8",
3541
- "crs": src.crs,
3542
- "transform": window_trans,
3543
- }
3544
-
3545
- label_path = os.path.join(label_dir, f"tile_{tile_id:06d}.tif")
3546
- try:
3547
- with rasterio.open(label_path, "w", **label_profile) as dst:
3548
- dst.write(lbl_mask.astype(np.uint8), 1)
3549
-
3550
- if not is_augmented and np.any(lbl_mask > 0):
3551
- stats["tiles_with_features"] += 1
3552
- stats["feature_pixels"] += np.count_nonzero(lbl_mask)
3553
- except Exception as e:
3554
- pbar.write(f"ERROR saving label GeoTIFF: {e}")
3555
- stats["errors"] += 1
2918
+ # Export image as GeoTIFF
2919
+ image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
3556
2920
 
3557
- # Save original tile
3558
- save_tile(
3559
- image_data,
3560
- label_mask,
3561
- tile_index,
3562
- src.profile,
3563
- window_transform,
3564
- is_augmented=False,
2921
+ # Create profile for image GeoTIFF
2922
+ image_profile = src.profile.copy()
2923
+ image_profile.update(
2924
+ {
2925
+ "height": tile_size,
2926
+ "width": tile_size,
2927
+ "count": image_data.shape[0],
2928
+ "transform": window_transform,
2929
+ }
3565
2930
  )
3566
2931
 
3567
- # Generate and save augmented tiles if enabled
3568
- if apply_augmentation:
3569
- for aug_idx in range(augmentation_count):
3570
- # Prepare image for augmentation (convert from CHW to HWC)
3571
- img_for_aug = np.transpose(image_data, (1, 2, 0))
3572
-
3573
- # Ensure uint8 data type for albumentations
3574
- # Albumentations expects uint8 for most transforms
3575
- if not np.issubdtype(img_for_aug.dtype, np.uint8):
3576
- # If image is float, scale to 0-255 and convert to uint8
3577
- if np.issubdtype(img_for_aug.dtype, np.floating):
3578
- img_for_aug = (
3579
- (img_for_aug * 255).clip(0, 255).astype(np.uint8)
3580
- )
3581
- else:
3582
- img_for_aug = img_for_aug.astype(np.uint8)
3583
-
3584
- # Apply augmentation
3585
- try:
3586
- if in_class_data is not None:
3587
- # Augment both image and mask
3588
- augmented = augmentation_transforms(
3589
- image=img_for_aug, mask=label_mask
3590
- )
3591
- aug_image = augmented["image"]
3592
- aug_mask = augmented["mask"]
3593
- else:
3594
- # Augment only image
3595
- augmented = augmentation_transforms(image=img_for_aug)
3596
- aug_image = augmented["image"]
3597
- aug_mask = label_mask
3598
-
3599
- # Convert back from HWC to CHW
3600
- aug_image = np.transpose(aug_image, (2, 0, 1))
3601
-
3602
- # Ensure correct dtype for saving
3603
- aug_image = aug_image.astype(image_data.dtype)
2932
+ # Save image as GeoTIFF
2933
+ try:
2934
+ with rasterio.open(image_path, "w", **image_profile) as dst:
2935
+ dst.write(image_data)
2936
+ stats["total_tiles"] += 1
2937
+ except Exception as e:
2938
+ pbar.write(f"ERROR saving image GeoTIFF: {e}")
2939
+ stats["errors"] += 1
3604
2940
 
3605
- # Generate unique tile ID for augmented version
3606
- # Use a collision-free numbering scheme: (tile_index * (augmentation_count + 1)) + aug_idx + 1
3607
- aug_tile_id = (
3608
- (tile_index * (augmentation_count + 1)) + aug_idx + 1
3609
- )
2941
+ # Create profile for label GeoTIFF
2942
+ label_profile = {
2943
+ "driver": "GTiff",
2944
+ "height": tile_size,
2945
+ "width": tile_size,
2946
+ "count": 1,
2947
+ "dtype": "uint8",
2948
+ "crs": src.crs,
2949
+ "transform": window_transform,
2950
+ }
3610
2951
 
3611
- # Save augmented tile
3612
- save_tile(
3613
- aug_image,
3614
- aug_mask,
3615
- aug_tile_id,
3616
- src.profile,
3617
- window_transform,
3618
- is_augmented=True,
3619
- )
2952
+ # Export label as GeoTIFF
2953
+ label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
2954
+ try:
2955
+ with rasterio.open(label_path, "w", **label_profile) as dst:
2956
+ dst.write(label_mask.astype(np.uint8), 1)
3620
2957
 
3621
- except Exception as e:
3622
- pbar.write(
3623
- f"ERROR applying augmentation {aug_idx} to tile {tile_index}: {e}"
3624
- )
3625
- stats["errors"] += 1
2958
+ if has_features:
2959
+ stats["tiles_with_features"] += 1
2960
+ stats["feature_pixels"] += np.count_nonzero(label_mask)
2961
+ except Exception as e:
2962
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
2963
+ stats["errors"] += 1
3626
2964
 
3627
- # Create annotations for object detection if using vector class data
2965
+ # Create XML annotation for object detection if using vector class data
3628
2966
  if (
3629
- in_class_data is not None
3630
- and not is_class_data_raster
2967
+ not is_class_data_raster
3631
2968
  and "gdf" in locals()
3632
2969
  and len(window_features) > 0
3633
2970
  ):
3634
- if metadata_format == "PASCAL_VOC":
3635
- # Create XML annotation
3636
- root = ET.Element("annotation")
3637
- ET.SubElement(root, "folder").text = "images"
3638
- ET.SubElement(root, "filename").text = (
3639
- f"tile_{tile_index:06d}.tif"
3640
- )
3641
-
3642
- size = ET.SubElement(root, "size")
3643
- ET.SubElement(size, "width").text = str(tile_size)
3644
- ET.SubElement(size, "height").text = str(tile_size)
3645
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3646
-
3647
- # Add georeference information
3648
- geo = ET.SubElement(root, "georeference")
3649
- ET.SubElement(geo, "crs").text = str(src.crs)
3650
- ET.SubElement(geo, "transform").text = str(
3651
- window_transform
3652
- ).replace("\n", "")
3653
- ET.SubElement(geo, "bounds").text = (
3654
- f"{minx}, {miny}, {maxx}, {maxy}"
3655
- )
3656
-
3657
- # Add objects
3658
- for idx, feature in window_features.iterrows():
3659
- # Get feature class
3660
- if class_value_field in feature:
3661
- class_val = feature[class_value_field]
3662
- else:
3663
- class_val = "object"
3664
-
3665
- # Get geometry bounds in pixel coordinates
3666
- geom = feature.geometry.intersection(window_bounds)
3667
- if not geom.is_empty:
3668
- # Get bounds in world coordinates
3669
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3670
-
3671
- # Convert to pixel coordinates
3672
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3673
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3674
-
3675
- # Ensure coordinates are within tile bounds
3676
- xmin = max(0, min(tile_size, int(col_min)))
3677
- ymin = max(0, min(tile_size, int(row_min)))
3678
- xmax = max(0, min(tile_size, int(col_max)))
3679
- ymax = max(0, min(tile_size, int(row_max)))
3680
-
3681
- # Only add if the box has non-zero area
3682
- if xmax > xmin and ymax > ymin:
3683
- obj = ET.SubElement(root, "object")
3684
- ET.SubElement(obj, "name").text = str(class_val)
3685
- ET.SubElement(obj, "difficult").text = "0"
3686
-
3687
- bbox = ET.SubElement(obj, "bndbox")
3688
- ET.SubElement(bbox, "xmin").text = str(xmin)
3689
- ET.SubElement(bbox, "ymin").text = str(ymin)
3690
- ET.SubElement(bbox, "xmax").text = str(xmax)
3691
- ET.SubElement(bbox, "ymax").text = str(ymax)
3692
-
3693
- # Save XML
3694
- tree = ET.ElementTree(root)
3695
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3696
- tree.write(xml_path)
3697
-
3698
- elif metadata_format == "COCO":
3699
- # Add image info
3700
- image_id = tile_index
3701
- coco_annotations["images"].append(
3702
- {
3703
- "id": image_id,
3704
- "file_name": f"tile_{tile_index:06d}.tif",
3705
- "width": tile_size,
3706
- "height": tile_size,
3707
- "crs": str(src.crs),
3708
- "transform": str(window_transform),
3709
- }
3710
- )
3711
-
3712
- # Add annotations for each feature
3713
- for _, feature in window_features.iterrows():
3714
- # Get feature class
3715
- if class_value_field in feature:
3716
- class_val = feature[class_value_field]
3717
- category_id = class_to_id.get(class_val, 1)
3718
- else:
3719
- category_id = 1
2971
+ # Create XML annotation
2972
+ root = ET.Element("annotation")
2973
+ ET.SubElement(root, "folder").text = "images"
2974
+ ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3720
2975
 
3721
- # Get geometry bounds
3722
- geom = feature.geometry.intersection(window_bounds)
3723
- if not geom.is_empty:
3724
- # Get bounds in world coordinates
3725
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3726
-
3727
- # Convert to pixel coordinates
3728
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3729
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3730
-
3731
- # Ensure coordinates are within tile bounds
3732
- xmin = max(0, min(tile_size, int(col_min)))
3733
- ymin = max(0, min(tile_size, int(row_min)))
3734
- xmax = max(0, min(tile_size, int(col_max)))
3735
- ymax = max(0, min(tile_size, int(row_max)))
3736
-
3737
- # Skip if box is too small
3738
- if xmax - xmin < 1 or ymax - ymin < 1:
3739
- continue
3740
-
3741
- width = xmax - xmin
3742
- height = ymax - ymin
3743
-
3744
- # Add annotation
3745
- ann_id += 1
3746
- coco_annotations["annotations"].append(
3747
- {
3748
- "id": ann_id,
3749
- "image_id": image_id,
3750
- "category_id": category_id,
3751
- "bbox": [xmin, ymin, width, height],
3752
- "area": width * height,
3753
- "iscrowd": 0,
3754
- }
3755
- )
2976
+ size = ET.SubElement(root, "size")
2977
+ ET.SubElement(size, "width").text = str(tile_size)
2978
+ ET.SubElement(size, "height").text = str(tile_size)
2979
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3756
2980
 
3757
- elif metadata_format == "YOLO":
3758
- # Create YOLO format annotations
3759
- yolo_annotations = []
2981
+ # Add georeference information
2982
+ geo = ET.SubElement(root, "georeference")
2983
+ ET.SubElement(geo, "crs").text = str(src.crs)
2984
+ ET.SubElement(geo, "transform").text = str(
2985
+ window_transform
2986
+ ).replace("\n", "")
2987
+ ET.SubElement(geo, "bounds").text = (
2988
+ f"{minx}, {miny}, {maxx}, {maxy}"
2989
+ )
3760
2990
 
3761
- for _, feature in window_features.iterrows():
3762
- # Get feature class
3763
- if class_value_field in feature:
3764
- class_val = feature[class_value_field]
3765
- # YOLO uses 0-indexed class IDs
3766
- class_id = class_to_id.get(class_val, 1) - 1
3767
- else:
3768
- class_id = 0
2991
+ # Add objects
2992
+ for idx, feature in window_features.iterrows():
2993
+ # Get feature class
2994
+ if class_value_field in feature:
2995
+ class_val = feature[class_value_field]
2996
+ else:
2997
+ class_val = "object"
3769
2998
 
3770
- # Get geometry bounds
3771
- geom = feature.geometry.intersection(window_bounds)
3772
- if not geom.is_empty:
3773
- # Get bounds in world coordinates
3774
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3775
-
3776
- # Convert to pixel coordinates
3777
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3778
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3779
-
3780
- # Ensure coordinates are within tile bounds
3781
- xmin = max(0, min(tile_size, col_min))
3782
- ymin = max(0, min(tile_size, row_min))
3783
- xmax = max(0, min(tile_size, col_max))
3784
- ymax = max(0, min(tile_size, row_max))
3785
-
3786
- # Skip if box is too small
3787
- if xmax - xmin < 1 or ymax - ymin < 1:
3788
- continue
3789
-
3790
- # Calculate normalized coordinates (YOLO format)
3791
- x_center = ((xmin + xmax) / 2) / tile_size
3792
- y_center = ((ymin + ymax) / 2) / tile_size
3793
- width = (xmax - xmin) / tile_size
3794
- height = (ymax - ymin) / tile_size
3795
-
3796
- # Add YOLO annotation line
3797
- yolo_annotations.append(
3798
- f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3799
- )
2999
+ # Get geometry bounds in pixel coordinates
3000
+ geom = feature.geometry.intersection(window_bounds)
3001
+ if not geom.is_empty:
3002
+ # Get bounds in world coordinates
3003
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3004
+
3005
+ # Convert to pixel coordinates
3006
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3007
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3008
+
3009
+ # Ensure coordinates are within tile bounds
3010
+ xmin = max(0, min(tile_size, int(col_min)))
3011
+ ymin = max(0, min(tile_size, int(row_min)))
3012
+ xmax = max(0, min(tile_size, int(col_max)))
3013
+ ymax = max(0, min(tile_size, int(row_max)))
3014
+
3015
+ # Only add if the box has non-zero area
3016
+ if xmax > xmin and ymax > ymin:
3017
+ obj = ET.SubElement(root, "object")
3018
+ ET.SubElement(obj, "name").text = str(class_val)
3019
+ ET.SubElement(obj, "difficult").text = "0"
3020
+
3021
+ bbox = ET.SubElement(obj, "bndbox")
3022
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3023
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3024
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3025
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3800
3026
 
3801
- # Save YOLO annotations to text file
3802
- if yolo_annotations:
3803
- yolo_path = os.path.join(
3804
- label_dir, f"tile_{tile_index:06d}.txt"
3805
- )
3806
- with open(yolo_path, "w") as f:
3807
- f.write("\n".join(yolo_annotations))
3027
+ # Save XML
3028
+ tree = ET.ElementTree(root)
3029
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3030
+ tree.write(xml_path)
3808
3031
 
3809
3032
  # Update progress bar
3810
3033
  pbar.update(1)
@@ -3822,39 +3045,6 @@ def export_geotiff_tiles(
3822
3045
  # Close progress bar
3823
3046
  pbar.close()
3824
3047
 
3825
- # Save COCO annotations if applicable (only if class data provided)
3826
- if in_class_data is not None and metadata_format == "COCO":
3827
- try:
3828
- with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3829
- json.dump(coco_annotations, f, indent=2)
3830
- if not quiet:
3831
- print(
3832
- f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3833
- f"{len(coco_annotations['annotations'])} annotations, "
3834
- f"{len(coco_annotations['categories'])} categories"
3835
- )
3836
- except Exception as e:
3837
- if not quiet:
3838
- print(f"ERROR saving COCO annotations: {e}")
3839
- stats["errors"] += 1
3840
-
3841
- # Save YOLO classes file if applicable (only if class data provided)
3842
- if in_class_data is not None and metadata_format == "YOLO":
3843
- try:
3844
- # Create classes.txt with class names
3845
- classes_path = os.path.join(out_folder, "classes.txt")
3846
- # Sort by class ID to ensure correct order
3847
- sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3848
- with open(classes_path, "w") as f:
3849
- for class_val, _ in sorted_classes:
3850
- f.write(f"{class_val}\n")
3851
- if not quiet:
3852
- print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3853
- except Exception as e:
3854
- if not quiet:
3855
- print(f"ERROR saving YOLO classes file: {e}")
3856
- stats["errors"] += 1
3857
-
3858
3048
  # Create overview image if requested
3859
3049
  if create_overview and stats["tile_coordinates"]:
3860
3050
  try:
@@ -3872,14 +3062,13 @@ def export_geotiff_tiles(
3872
3062
  if not quiet:
3873
3063
  print("\n------- Export Summary -------")
3874
3064
  print(f"Total tiles exported: {stats['total_tiles']}")
3875
- if in_class_data is not None:
3065
+ print(
3066
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3067
+ )
3068
+ if stats["tiles_with_features"] > 0:
3876
3069
  print(
3877
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
3070
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3878
3071
  )
3879
- if stats["tiles_with_features"] > 0:
3880
- print(
3881
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
3882
- )
3883
3072
  if stats["errors"] > 0:
3884
3073
  print(f"Errors encountered: {stats['errors']}")
3885
3074
  print(f"Output saved to: {out_folder}")
@@ -3888,6 +3077,7 @@ def export_geotiff_tiles(
3888
3077
  if stats["total_tiles"] > 0:
3889
3078
  print("\n------- Georeference Verification -------")
3890
3079
  sample_image = os.path.join(image_dir, f"tile_0.tif")
3080
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
3891
3081
 
3892
3082
  if os.path.exists(sample_image):
3893
3083
  try:
@@ -3903,22 +3093,19 @@ def export_geotiff_tiles(
3903
3093
  except Exception as e:
3904
3094
  print(f"Error verifying image georeference: {e}")
3905
3095
 
3906
- # Only verify label if class data was provided
3907
- if in_class_data is not None:
3908
- sample_label = os.path.join(label_dir, f"tile_0.tif")
3909
- if os.path.exists(sample_label):
3910
- try:
3911
- with rasterio.open(sample_label) as lbl:
3912
- print(f"Label CRS: {lbl.crs}")
3913
- print(f"Label transform: {lbl.transform}")
3914
- print(
3915
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3916
- )
3917
- print(
3918
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3919
- )
3920
- except Exception as e:
3921
- print(f"Error verifying label georeference: {e}")
3096
+ if os.path.exists(sample_label):
3097
+ try:
3098
+ with rasterio.open(sample_label) as lbl:
3099
+ print(f"Label CRS: {lbl.crs}")
3100
+ print(f"Label transform: {lbl.transform}")
3101
+ print(
3102
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
3103
+ )
3104
+ print(
3105
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
3106
+ )
3107
+ except Exception as e:
3108
+ print(f"Error verifying label georeference: {e}")
3922
3109
 
3923
3110
  # Return statistics dictionary for further processing if needed
3924
3111
  return stats
@@ -3926,9 +3113,8 @@ def export_geotiff_tiles(
3926
3113
 
3927
3114
  def export_geotiff_tiles_batch(
3928
3115
  images_folder,
3929
- masks_folder=None,
3930
- masks_file=None,
3931
- output_folder=None,
3116
+ masks_folder,
3117
+ output_folder,
3932
3118
  tile_size=256,
3933
3119
  stride=128,
3934
3120
  class_value_field="class",
@@ -3936,43 +3122,25 @@ def export_geotiff_tiles_batch(
3936
3122
  max_tiles=None,
3937
3123
  quiet=False,
3938
3124
  all_touched=True,
3125
+ create_overview=False,
3939
3126
  skip_empty_tiles=False,
3940
3127
  image_extensions=None,
3941
3128
  mask_extensions=None,
3942
- match_by_name=False,
3943
- metadata_format="PASCAL_VOC",
3944
3129
  ) -> Dict[str, Any]:
3945
3130
  """
3946
- Export georeferenced GeoTIFF tiles from images and optionally masks.
3947
-
3948
- This function supports four modes:
3949
- 1. Images only (no masks) - when neither masks_file nor masks_folder is provided
3950
- 2. Single vector file covering all images (masks_file parameter)
3951
- 3. Multiple vector files, one per image (masks_folder parameter)
3952
- 4. Multiple raster mask files (masks_folder parameter)
3953
-
3954
- For mode 1 (images only), only image tiles will be exported without labels.
3131
+ Export georeferenced GeoTIFF tiles from folders of images and masks.
3955
3132
 
3956
- For mode 2 (single vector file), specify masks_file path. The function will
3957
- use spatial intersection to determine which features apply to each image.
3133
+ This function processes multiple image-mask pairs from input folders,
3134
+ generating tiles for each pair. All image tiles are saved to a single
3135
+ 'images' folder and all mask tiles to a single 'masks' folder.
3958
3136
 
3959
- For mode 3/4 (multiple mask files), specify masks_folder path. Images and masks
3960
- are paired either by matching filenames (match_by_name=True) or by sorted order
3961
- (match_by_name=False).
3962
-
3963
- All image tiles are saved to a single 'images' folder and all mask tiles (if provided)
3964
- to a single 'masks' folder within the output directory.
3137
+ Images and masks are paired by their sorted order (alphabetically), not by
3138
+ filename matching. The number of images and masks must be equal.
3965
3139
 
3966
3140
  Args:
3967
3141
  images_folder (str): Path to folder containing raster images
3968
- masks_folder (str, optional): Path to folder containing classification masks/vectors.
3969
- Use this for multiple mask files (one per image or raster masks). If not provided
3970
- and masks_file is also not provided, only image tiles will be exported.
3971
- masks_file (str, optional): Path to a single vector file covering all images.
3972
- Use this for a single GeoJSON/Shapefile that covers multiple images. If not provided
3973
- and masks_folder is also not provided, only image tiles will be exported.
3974
- output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3975
- subfolder in images_folder.
3142
+ masks_folder (str): Path to folder containing classification masks/vectors
3143
+ output_folder (str): Path to output folder
3976
3144
  tile_size (int): Size of tiles in pixels (square)
3977
3145
  stride (int): Step size between tiles
3978
3146
  class_value_field (str): Field containing class values (for vector data)
@@ -3981,66 +3149,21 @@ def export_geotiff_tiles_batch(
3981
3149
  quiet (bool): If True, suppress non-essential output
3982
3150
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
3983
3151
  create_overview (bool): Whether to create an overview image of all tiles
3984
- skip_empty_tiles (bool): If True, skip tiles with no features
3985
- image_extensions (list): List of image file extensions to process (default: common raster formats)
3986
- mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3987
- match_by_name (bool): If True, match image and mask files by base filename.
3988
- If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3989
- metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3990
- Default is "PASCAL_VOC".
3991
-
3992
- Returns:
3993
- Dict[str, Any]: Dictionary containing batch processing statistics
3994
-
3995
- Raises:
3996
- ValueError: If no images found, or if masks_folder and masks_file are both specified,
3997
- or if counts don't match when using masks_folder with match_by_name=False.
3998
-
3999
- Examples:
4000
- # Images only (no masks)
4001
- >>> stats = export_geotiff_tiles_batch(
4002
- ... images_folder='data/images',
4003
- ... output_folder='output/tiles'
4004
- ... )
4005
-
4006
- # Single vector file covering all images
4007
- >>> stats = export_geotiff_tiles_batch(
4008
- ... images_folder='data/images',
4009
- ... masks_file='data/buildings.geojson',
4010
- ... output_folder='output/tiles'
4011
- ... )
3152
+ skip_empty_tiles (bool): If True, skip tiles with no features
3153
+ image_extensions (list): List of image file extensions to process (default: common raster formats)
3154
+ mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
4012
3155
 
4013
- # Multiple vector files, matched by filename
4014
- >>> stats = export_geotiff_tiles_batch(
4015
- ... images_folder='data/images',
4016
- ... masks_folder='data/masks',
4017
- ... output_folder='output/tiles',
4018
- ... match_by_name=True
4019
- ... )
3156
+ Returns:
3157
+ Dict[str, Any]: Dictionary containing batch processing statistics
4020
3158
 
4021
- # Multiple mask files, matched by sorted order
4022
- >>> stats = export_geotiff_tiles_batch(
4023
- ... images_folder='data/images',
4024
- ... masks_folder='data/masks',
4025
- ... output_folder='output/tiles',
4026
- ... match_by_name=False
4027
- ... )
3159
+ Raises:
3160
+ ValueError: If no images or masks found, or if counts don't match
4028
3161
  """
4029
3162
 
4030
3163
  import logging
4031
3164
 
4032
3165
  logging.getLogger("rasterio").setLevel(logging.ERROR)
4033
3166
 
4034
- # Validate input parameters
4035
- if masks_folder is not None and masks_file is not None:
4036
- raise ValueError(
4037
- "Cannot specify both masks_folder and masks_file. Please use only one."
4038
- )
4039
-
4040
- # Default output folder if not specified
4041
- if output_folder is None:
4042
- output_folder = os.path.join(images_folder, "tiles")
4043
-
4044
3167
  # Default extensions if not provided
4045
3168
  if image_extensions is None:
4046
3169
  image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
@@ -4067,37 +3190,9 @@ def export_geotiff_tiles_batch(
4067
3190
  # Create output folder structure
4068
3191
  os.makedirs(output_folder, exist_ok=True)
4069
3192
  output_images_dir = os.path.join(output_folder, "images")
3193
+ output_masks_dir = os.path.join(output_folder, "masks")
4070
3194
  os.makedirs(output_images_dir, exist_ok=True)
4071
-
4072
- # Only create masks directory if masks are provided
4073
- output_masks_dir = None
4074
- if masks_folder is not None or masks_file is not None:
4075
- output_masks_dir = os.path.join(output_folder, "masks")
4076
- os.makedirs(output_masks_dir, exist_ok=True)
4077
-
4078
- # Create annotation directory based on metadata format (only if masks are provided)
4079
- ann_dir = None
4080
- if (masks_folder is not None or masks_file is not None) and metadata_format in [
4081
- "PASCAL_VOC",
4082
- "COCO",
4083
- ]:
4084
- ann_dir = os.path.join(output_folder, "annotations")
4085
- os.makedirs(ann_dir, exist_ok=True)
4086
-
4087
- # Initialize COCO annotations dictionary (only if masks are provided)
4088
- coco_annotations = None
4089
- if (
4090
- masks_folder is not None or masks_file is not None
4091
- ) and metadata_format == "COCO":
4092
- coco_annotations = {"images": [], "annotations": [], "categories": []}
4093
-
4094
- # Initialize YOLO class set (only if masks are provided)
4095
- yolo_classes = (
4096
- set()
4097
- if (masks_folder is not None or masks_file is not None)
4098
- and metadata_format == "YOLO"
4099
- else None
4100
- )
3195
+ os.makedirs(output_masks_dir, exist_ok=True)
4101
3196
 
4102
3197
  # Get list of image files
4103
3198
  image_files = []
@@ -4105,105 +3200,30 @@ def export_geotiff_tiles_batch(
4105
3200
  pattern = os.path.join(images_folder, f"*{ext}")
4106
3201
  image_files.extend(glob.glob(pattern))
4107
3202
 
3203
+ # Get list of mask files
3204
+ mask_files = []
3205
+ for ext in mask_extensions:
3206
+ pattern = os.path.join(masks_folder, f"*{ext}")
3207
+ mask_files.extend(glob.glob(pattern))
3208
+
4108
3209
  # Sort files for consistent processing
4109
3210
  image_files.sort()
3211
+ mask_files.sort()
4110
3212
 
4111
3213
  if not image_files:
4112
3214
  raise ValueError(
4113
3215
  f"No image files found in {images_folder} with extensions {image_extensions}"
4114
3216
  )
4115
3217
 
4116
- # Handle different mask input modes
4117
- use_single_mask_file = masks_file is not None
4118
- has_masks = masks_file is not None or masks_folder is not None
4119
- mask_files = []
4120
- image_mask_pairs = []
4121
-
4122
- if not has_masks:
4123
- # Mode 0: No masks - create pairs with None for mask
4124
- for image_file in image_files:
4125
- image_mask_pairs.append((image_file, None, None))
4126
-
4127
- elif use_single_mask_file:
4128
- # Mode 1: Single vector file covering all images
4129
- if not os.path.exists(masks_file):
4130
- raise ValueError(f"Mask file not found: {masks_file}")
4131
-
4132
- # Load the single mask file once - will be spatially filtered per image
4133
- single_mask_gdf = gpd.read_file(masks_file)
4134
-
4135
- if not quiet:
4136
- print(f"Using single mask file: {masks_file}")
4137
- print(
4138
- f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
4139
- )
4140
-
4141
- # Create pairs with the same mask file for all images
4142
- for image_file in image_files:
4143
- image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
4144
-
4145
- else:
4146
- # Mode 2/3: Multiple mask files (vector or raster)
4147
- # Get list of mask files
4148
- for ext in mask_extensions:
4149
- pattern = os.path.join(masks_folder, f"*{ext}")
4150
- mask_files.extend(glob.glob(pattern))
4151
-
4152
- # Sort files for consistent processing
4153
- mask_files.sort()
4154
-
4155
- if not mask_files:
4156
- raise ValueError(
4157
- f"No mask files found in {masks_folder} with extensions {mask_extensions}"
4158
- )
4159
-
4160
- # Match images to masks
4161
- if match_by_name:
4162
- # Match by base filename
4163
- image_dict = {
4164
- os.path.splitext(os.path.basename(f))[0]: f for f in image_files
4165
- }
4166
- mask_dict = {
4167
- os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
4168
- }
4169
-
4170
- # Find matching pairs
4171
- for img_base, img_path in image_dict.items():
4172
- if img_base in mask_dict:
4173
- image_mask_pairs.append((img_path, mask_dict[img_base], None))
4174
- else:
4175
- if not quiet:
4176
- print(f"Warning: No mask found for image {img_base}")
4177
-
4178
- if not image_mask_pairs:
4179
- # Provide detailed error message with found files
4180
- image_bases = list(image_dict.keys())
4181
- mask_bases = list(mask_dict.keys())
4182
- error_msg = (
4183
- "No matching image-mask pairs found when matching by filename. "
4184
- "Check that image and mask files have matching base names.\n"
4185
- f"Found {len(image_bases)} image(s): "
4186
- f"{', '.join(image_bases[:5]) if image_bases else 'None found'}"
4187
- f"{'...' if len(image_bases) > 5 else ''}\n"
4188
- f"Found {len(mask_bases)} mask(s): "
4189
- f"{', '.join(mask_bases[:5]) if mask_bases else 'None found'}"
4190
- f"{'...' if len(mask_bases) > 5 else ''}\n"
4191
- "Tip: Set match_by_name=False to match by sorted order, or ensure filenames match."
4192
- )
4193
- raise ValueError(error_msg)
4194
-
4195
- else:
4196
- # Match by sorted order
4197
- if len(image_files) != len(mask_files):
4198
- raise ValueError(
4199
- f"Number of image files ({len(image_files)}) does not match "
4200
- f"number of mask files ({len(mask_files)}) when matching by sorted order. "
4201
- f"Use match_by_name=True for filename-based matching."
4202
- )
3218
+ if not mask_files:
3219
+ raise ValueError(
3220
+ f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3221
+ )
4203
3222
 
4204
- # Create pairs by sorted order
4205
- for image_file, mask_file in zip(image_files, mask_files):
4206
- image_mask_pairs.append((image_file, mask_file, None))
3223
+ if len(image_files) != len(mask_files):
3224
+ raise ValueError(
3225
+ f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
3226
+ )
4207
3227
 
4208
3228
  # Initialize batch statistics
4209
3229
  batch_stats = {
@@ -4217,28 +3237,23 @@ def export_geotiff_tiles_batch(
4217
3237
  }
4218
3238
 
4219
3239
  if not quiet:
4220
- if not has_masks:
4221
- print(
4222
- f"Found {len(image_files)} image files to process (images only, no masks)"
4223
- )
4224
- elif use_single_mask_file:
4225
- print(f"Found {len(image_files)} image files to process")
4226
- print(f"Using single mask file: {masks_file}")
4227
- else:
4228
- print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
4229
- print(f"Processing batch from {images_folder} and {masks_folder}")
3240
+ print(
3241
+ f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
3242
+ )
3243
+ print(f"Processing batch from {images_folder} and {masks_folder}")
4230
3244
  print(f"Output folder: {output_folder}")
4231
3245
  print("-" * 60)
4232
3246
 
4233
3247
  # Global tile counter for unique naming
4234
3248
  global_tile_counter = 0
4235
3249
 
4236
- # Process each image-mask pair
4237
- for idx, (image_file, mask_file, mask_gdf) in enumerate(
3250
+ # Process each image-mask pair by sorted order
3251
+ for idx, (image_file, mask_file) in enumerate(
4238
3252
  tqdm(
4239
- image_mask_pairs,
3253
+ zip(image_files, mask_files),
4240
3254
  desc="Processing image pairs",
4241
3255
  disable=quiet,
3256
+ total=len(image_files),
4242
3257
  )
4243
3258
  ):
4244
3259
  batch_stats["total_image_pairs"] += 1
@@ -4250,17 +3265,9 @@ def export_geotiff_tiles_batch(
4250
3265
  if not quiet:
4251
3266
  print(f"\nProcessing: {base_name}")
4252
3267
  print(f" Image: {os.path.basename(image_file)}")
4253
- if mask_file is not None:
4254
- if use_single_mask_file:
4255
- print(
4256
- f" Mask: {os.path.basename(mask_file)} (spatially filtered)"
4257
- )
4258
- else:
4259
- print(f" Mask: {os.path.basename(mask_file)}")
4260
- else:
4261
- print(f" Mask: None (images only)")
3268
+ print(f" Mask: {os.path.basename(mask_file)}")
4262
3269
 
4263
- # Process the image-mask pair
3270
+ # Process the image-mask pair manually to get direct control over tile saving
4264
3271
  tiles_generated = _process_image_mask_pair(
4265
3272
  image_file=image_file,
4266
3273
  mask_file=mask_file,
@@ -4276,15 +3283,6 @@ def export_geotiff_tiles_batch(
4276
3283
  all_touched=all_touched,
4277
3284
  skip_empty_tiles=skip_empty_tiles,
4278
3285
  quiet=quiet,
4279
- mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
4280
- use_single_mask_file=use_single_mask_file,
4281
- metadata_format=metadata_format,
4282
- ann_dir=(
4283
- ann_dir
4284
- if "ann_dir" in locals()
4285
- and metadata_format in ["PASCAL_VOC", "COCO"]
4286
- else None
4287
- ),
4288
3286
  )
4289
3287
 
4290
3288
  # Update counters
@@ -4306,23 +3304,6 @@ def export_geotiff_tiles_batch(
4306
3304
  }
4307
3305
  )
4308
3306
 
4309
- # Aggregate COCO annotations
4310
- if metadata_format == "COCO" and "coco_data" in tiles_generated:
4311
- coco_data = tiles_generated["coco_data"]
4312
- # Add images and annotations
4313
- coco_annotations["images"].extend(coco_data.get("images", []))
4314
- coco_annotations["annotations"].extend(coco_data.get("annotations", []))
4315
- # Merge categories (avoid duplicates)
4316
- for cat in coco_data.get("categories", []):
4317
- if not any(
4318
- c["id"] == cat["id"] for c in coco_annotations["categories"]
4319
- ):
4320
- coco_annotations["categories"].append(cat)
4321
-
4322
- # Aggregate YOLO classes
4323
- if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
4324
- yolo_classes.update(tiles_generated["yolo_classes"])
4325
-
4326
3307
  except Exception as e:
4327
3308
  if not quiet:
4328
3309
  print(f"ERROR processing {base_name}: {e}")
@@ -4331,33 +3312,6 @@ def export_geotiff_tiles_batch(
4331
3312
  )
4332
3313
  batch_stats["errors"] += 1
4333
3314
 
4334
- # Save aggregated COCO annotations
4335
- if metadata_format == "COCO" and coco_annotations:
4336
- import json
4337
-
4338
- coco_path = os.path.join(ann_dir, "instances.json")
4339
- with open(coco_path, "w") as f:
4340
- json.dump(coco_annotations, f, indent=2)
4341
- if not quiet:
4342
- print(f"\nSaved COCO annotations: {coco_path}")
4343
- print(
4344
- f" Images: {len(coco_annotations['images'])}, "
4345
- f"Annotations: {len(coco_annotations['annotations'])}, "
4346
- f"Categories: {len(coco_annotations['categories'])}"
4347
- )
4348
-
4349
- # Save aggregated YOLO classes
4350
- if metadata_format == "YOLO" and yolo_classes:
4351
- classes_path = os.path.join(output_folder, "labels", "classes.txt")
4352
- os.makedirs(os.path.dirname(classes_path), exist_ok=True)
4353
- sorted_classes = sorted(yolo_classes)
4354
- with open(classes_path, "w") as f:
4355
- for cls in sorted_classes:
4356
- f.write(f"{cls}\n")
4357
- if not quiet:
4358
- print(f"\nSaved YOLO classes: {classes_path}")
4359
- print(f" Total classes: {len(sorted_classes)}")
4360
-
4361
3315
  # Print batch summary
4362
3316
  if not quiet:
4363
3317
  print("\n" + "=" * 60)
@@ -4380,12 +3334,7 @@ def export_geotiff_tiles_batch(
4380
3334
 
4381
3335
  print(f"Output saved to: {output_folder}")
4382
3336
  print(f" Images: {output_images_dir}")
4383
- if output_masks_dir is not None:
4384
- print(f" Masks: {output_masks_dir}")
4385
- if metadata_format in ["PASCAL_VOC", "COCO"] and ann_dir is not None:
4386
- print(f" Annotations: {ann_dir}")
4387
- elif metadata_format == "YOLO":
4388
- print(f" Labels: {os.path.join(output_folder, 'labels')}")
3337
+ print(f" Masks: {output_masks_dir}")
4389
3338
 
4390
3339
  # List failed files if any
4391
3340
  if batch_stats["failed_files"]:
@@ -4411,26 +3360,18 @@ def _process_image_mask_pair(
4411
3360
  all_touched=True,
4412
3361
  skip_empty_tiles=False,
4413
3362
  quiet=False,
4414
- mask_gdf=None,
4415
- use_single_mask_file=False,
4416
- metadata_format="PASCAL_VOC",
4417
- ann_dir=None,
4418
3363
  ):
4419
3364
  """
4420
3365
  Process a single image-mask pair and save tiles directly to output directories.
4421
3366
 
4422
- Args:
4423
- mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
4424
- use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
4425
-
4426
3367
  Returns:
4427
3368
  dict: Statistics for this image-mask pair
4428
3369
  """
4429
3370
  import warnings
4430
3371
 
4431
- # Determine if mask data is raster or vector (only if mask_file is provided)
3372
+ # Determine if mask data is raster or vector
4432
3373
  is_class_data_raster = False
4433
- if mask_file is not None and isinstance(mask_file, str):
3374
+ if isinstance(mask_file, str):
4434
3375
  file_ext = Path(mask_file).suffix.lower()
4435
3376
  # Common raster extensions
4436
3377
  if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
@@ -4447,13 +3388,6 @@ def _process_image_mask_pair(
4447
3388
  "errors": 0,
4448
3389
  }
4449
3390
 
4450
- # Initialize COCO/YOLO tracking for this image
4451
- if metadata_format == "COCO":
4452
- stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
4453
- coco_ann_id = 0
4454
- if metadata_format == "YOLO":
4455
- stats["yolo_classes"] = set()
4456
-
4457
3391
  # Open the input raster
4458
3392
  with rasterio.open(image_file) as src:
4459
3393
  # Calculate number of tiles
@@ -4464,10 +3398,10 @@ def _process_image_mask_pair(
4464
3398
  if max_tiles is None:
4465
3399
  max_tiles = total_tiles
4466
3400
 
4467
- # Process classification data (only if mask_file is provided)
3401
+ # Process classification data
4468
3402
  class_to_id = {}
4469
3403
 
4470
- if mask_file is not None and is_class_data_raster:
3404
+ if is_class_data_raster:
4471
3405
  # Load raster class data
4472
3406
  with rasterio.open(mask_file) as class_src:
4473
3407
  # Check if raster CRS matches
@@ -4494,39 +3428,14 @@ def _process_image_mask_pair(
4494
3428
 
4495
3429
  # Create class mapping
4496
3430
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
4497
- elif mask_file is not None:
3431
+ else:
4498
3432
  # Load vector class data
4499
3433
  try:
4500
- if use_single_mask_file and mask_gdf is not None:
4501
- # Using pre-loaded single mask file - spatially filter to image bounds
4502
- # Get image bounds
4503
- image_bounds = box(*src.bounds)
4504
- image_gdf = gpd.GeoDataFrame(
4505
- {"geometry": [image_bounds]}, crs=src.crs
4506
- )
4507
-
4508
- # Reproject mask if needed
4509
- if mask_gdf.crs != src.crs:
4510
- mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
4511
- else:
4512
- mask_gdf_reprojected = mask_gdf
4513
-
4514
- # Spatially filter features that intersect with image bounds
4515
- gdf = mask_gdf_reprojected[
4516
- mask_gdf_reprojected.intersects(image_bounds)
4517
- ].copy()
4518
-
4519
- if not quiet and len(gdf) > 0:
4520
- print(
4521
- f" Filtered to {len(gdf)} features intersecting image bounds"
4522
- )
4523
- else:
4524
- # Load individual mask file
4525
- gdf = gpd.read_file(mask_file)
3434
+ gdf = gpd.read_file(mask_file)
4526
3435
 
4527
- # Always reproject to match raster CRS
4528
- if gdf.crs != src.crs:
4529
- gdf = gdf.to_crs(src.crs)
3436
+ # Always reproject to match raster CRS
3437
+ if gdf.crs != src.crs:
3438
+ gdf = gdf.to_crs(src.crs)
4530
3439
 
4531
3440
  # Apply buffer if specified
4532
3441
  if buffer_radius > 0:
@@ -4546,6 +3455,9 @@ def _process_image_mask_pair(
4546
3455
  tile_index = 0
4547
3456
  for y in range(num_tiles_y):
4548
3457
  for x in range(num_tiles_x):
3458
+ if tile_index >= max_tiles:
3459
+ break
3460
+
4549
3461
  # Calculate window coordinates
4550
3462
  window_x = x * stride
4551
3463
  window_y = y * stride
@@ -4570,12 +3482,12 @@ def _process_image_mask_pair(
4570
3482
 
4571
3483
  window_bounds = box(minx, miny, maxx, maxy)
4572
3484
 
4573
- # Create label mask (only if mask_file is provided)
3485
+ # Create label mask
4574
3486
  label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
4575
3487
  has_features = False
4576
3488
 
4577
- # Process classification data to create labels (only if mask_file is provided)
4578
- if mask_file is not None and is_class_data_raster:
3489
+ # Process classification data to create labels
3490
+ if is_class_data_raster:
4579
3491
  # For raster class data
4580
3492
  with rasterio.open(mask_file) as class_src:
4581
3493
  # Get corresponding window in class raster
@@ -4608,7 +3520,7 @@ def _process_image_mask_pair(
4608
3520
  if not quiet:
4609
3521
  print(f"Error reading class raster window: {e}")
4610
3522
  stats["errors"] += 1
4611
- elif mask_file is not None:
3523
+ else:
4612
3524
  # For vector class data
4613
3525
  # Find features that intersect with window
4614
3526
  window_features = gdf[gdf.intersects(window_bounds)]
@@ -4646,14 +3558,11 @@ def _process_image_mask_pair(
4646
3558
  print(f"Error rasterizing feature {idx}: {e}")
4647
3559
  stats["errors"] += 1
4648
3560
 
4649
- # Skip tile if no features and skip_empty_tiles is True (only applies when masks are provided)
4650
- if mask_file is not None and skip_empty_tiles and not has_features:
3561
+ # Skip tile if no features and skip_empty_tiles is True
3562
+ if skip_empty_tiles and not has_features:
3563
+ tile_index += 1
4651
3564
  continue
4652
3565
 
4653
- # Check if we've reached max_tiles before saving
4654
- if tile_index >= max_tiles:
4655
- break
4656
-
4657
3566
  # Generate unique tile name
4658
3567
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
4659
3568
 
@@ -4684,225 +3593,29 @@ def _process_image_mask_pair(
4684
3593
  print(f"ERROR saving image GeoTIFF: {e}")
4685
3594
  stats["errors"] += 1
4686
3595
 
4687
- # Export label as GeoTIFF (only if mask_file and output_masks_dir are provided)
4688
- if mask_file is not None and output_masks_dir is not None:
4689
- # Create profile for label GeoTIFF
4690
- label_profile = {
4691
- "driver": "GTiff",
4692
- "height": tile_size,
4693
- "width": tile_size,
4694
- "count": 1,
4695
- "dtype": "uint8",
4696
- "crs": src.crs,
4697
- "transform": window_transform,
4698
- }
4699
-
4700
- label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
4701
- try:
4702
- with rasterio.open(label_path, "w", **label_profile) as dst:
4703
- dst.write(label_mask.astype(np.uint8), 1)
4704
-
4705
- if has_features:
4706
- stats["tiles_with_features"] += 1
4707
- except Exception as e:
4708
- if not quiet:
4709
- print(f"ERROR saving label GeoTIFF: {e}")
4710
- stats["errors"] += 1
4711
-
4712
- # Generate annotation metadata based on format (only if mask_file is provided)
4713
- if (
4714
- mask_file is not None
4715
- and metadata_format == "PASCAL_VOC"
4716
- and ann_dir
4717
- ):
4718
- # Create PASCAL VOC XML annotation
4719
- from lxml import etree as ET
4720
-
4721
- annotation = ET.Element("annotation")
4722
- ET.SubElement(annotation, "folder").text = os.path.basename(
4723
- output_images_dir
4724
- )
4725
- ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4726
- ET.SubElement(annotation, "path").text = image_path
4727
-
4728
- source = ET.SubElement(annotation, "source")
4729
- ET.SubElement(source, "database").text = "GeoAI"
4730
-
4731
- size = ET.SubElement(annotation, "size")
4732
- ET.SubElement(size, "width").text = str(tile_size)
4733
- ET.SubElement(size, "height").text = str(tile_size)
4734
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
4735
-
4736
- ET.SubElement(annotation, "segmented").text = "1"
4737
-
4738
- # Find connected components for instance segmentation
4739
- from scipy import ndimage
4740
-
4741
- for class_id in np.unique(label_mask):
4742
- if class_id == 0:
4743
- continue
4744
-
4745
- class_mask = (label_mask == class_id).astype(np.uint8)
4746
- labeled_array, num_features = ndimage.label(class_mask)
4747
-
4748
- for instance_id in range(1, num_features + 1):
4749
- instance_mask = labeled_array == instance_id
4750
- coords = np.argwhere(instance_mask)
4751
-
4752
- if len(coords) == 0:
4753
- continue
4754
-
4755
- ymin, xmin = coords.min(axis=0)
4756
- ymax, xmax = coords.max(axis=0)
4757
-
4758
- obj = ET.SubElement(annotation, "object")
4759
- class_name = next(
4760
- (k for k, v in class_to_id.items() if v == class_id),
4761
- str(class_id),
4762
- )
4763
- ET.SubElement(obj, "name").text = str(class_name)
4764
- ET.SubElement(obj, "pose").text = "Unspecified"
4765
- ET.SubElement(obj, "truncated").text = "0"
4766
- ET.SubElement(obj, "difficult").text = "0"
4767
-
4768
- bndbox = ET.SubElement(obj, "bndbox")
4769
- ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4770
- ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4771
- ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4772
- ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4773
-
4774
- # Save XML file
4775
- xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4776
- tree = ET.ElementTree(annotation)
4777
- tree.write(xml_path, pretty_print=True, encoding="utf-8")
4778
-
4779
- elif mask_file is not None and metadata_format == "COCO":
4780
- # Add COCO image entry
4781
- image_id = int(global_tile_counter + tile_index)
4782
- stats["coco_data"]["images"].append(
4783
- {
4784
- "id": image_id,
4785
- "file_name": f"{tile_name}.tif",
4786
- "width": int(tile_size),
4787
- "height": int(tile_size),
4788
- }
4789
- )
4790
-
4791
- # Add COCO categories (only once per unique class)
4792
- for class_val, class_id in class_to_id.items():
4793
- if not any(
4794
- c["id"] == class_id
4795
- for c in stats["coco_data"]["categories"]
4796
- ):
4797
- stats["coco_data"]["categories"].append(
4798
- {
4799
- "id": int(class_id),
4800
- "name": str(class_val),
4801
- "supercategory": "object",
4802
- }
4803
- )
4804
-
4805
- # Add COCO annotations (instance segmentation)
4806
- from scipy import ndimage
4807
- from skimage import measure
4808
-
4809
- for class_id in np.unique(label_mask):
4810
- if class_id == 0:
4811
- continue
4812
-
4813
- class_mask = (label_mask == class_id).astype(np.uint8)
4814
- labeled_array, num_features = ndimage.label(class_mask)
4815
-
4816
- for instance_id in range(1, num_features + 1):
4817
- instance_mask = (labeled_array == instance_id).astype(
4818
- np.uint8
4819
- )
4820
- coords = np.argwhere(instance_mask)
4821
-
4822
- if len(coords) == 0:
4823
- continue
4824
-
4825
- ymin, xmin = coords.min(axis=0)
4826
- ymax, xmax = coords.max(axis=0)
4827
-
4828
- bbox = [
4829
- int(xmin),
4830
- int(ymin),
4831
- int(xmax - xmin),
4832
- int(ymax - ymin),
4833
- ]
4834
- area = int(np.sum(instance_mask))
4835
-
4836
- # Find contours for segmentation
4837
- contours = measure.find_contours(instance_mask, 0.5)
4838
- segmentation = []
4839
- for contour in contours:
4840
- contour = np.flip(contour, axis=1)
4841
- segmentation_points = contour.ravel().tolist()
4842
- if len(segmentation_points) >= 6:
4843
- segmentation.append(segmentation_points)
4844
-
4845
- if segmentation:
4846
- stats["coco_data"]["annotations"].append(
4847
- {
4848
- "id": int(coco_ann_id),
4849
- "image_id": int(image_id),
4850
- "category_id": int(class_id),
4851
- "bbox": bbox,
4852
- "area": area,
4853
- "segmentation": segmentation,
4854
- "iscrowd": 0,
4855
- }
4856
- )
4857
- coco_ann_id += 1
4858
-
4859
- elif mask_file is not None and metadata_format == "YOLO":
4860
- # Create YOLO labels directory if needed
4861
- labels_dir = os.path.join(
4862
- os.path.dirname(output_images_dir), "labels"
4863
- )
4864
- os.makedirs(labels_dir, exist_ok=True)
4865
-
4866
- # Generate YOLO annotation file
4867
- yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4868
- from scipy import ndimage
4869
-
4870
- with open(yolo_path, "w") as yolo_file:
4871
- for class_id in np.unique(label_mask):
4872
- if class_id == 0:
4873
- continue
4874
-
4875
- # Track class for classes.txt
4876
- class_name = next(
4877
- (k for k, v in class_to_id.items() if v == class_id),
4878
- str(class_id),
4879
- )
4880
- stats["yolo_classes"].add(class_name)
4881
-
4882
- class_mask = (label_mask == class_id).astype(np.uint8)
4883
- labeled_array, num_features = ndimage.label(class_mask)
4884
-
4885
- for instance_id in range(1, num_features + 1):
4886
- instance_mask = labeled_array == instance_id
4887
- coords = np.argwhere(instance_mask)
4888
-
4889
- if len(coords) == 0:
4890
- continue
4891
-
4892
- ymin, xmin = coords.min(axis=0)
4893
- ymax, xmax = coords.max(axis=0)
3596
+ # Create profile for label GeoTIFF
3597
+ label_profile = {
3598
+ "driver": "GTiff",
3599
+ "height": tile_size,
3600
+ "width": tile_size,
3601
+ "count": 1,
3602
+ "dtype": "uint8",
3603
+ "crs": src.crs,
3604
+ "transform": window_transform,
3605
+ }
4894
3606
 
4895
- # Convert to YOLO format (normalized center coordinates)
4896
- x_center = ((xmin + xmax) / 2) / tile_size
4897
- y_center = ((ymin + ymax) / 2) / tile_size
4898
- width = (xmax - xmin) / tile_size
4899
- height = (ymax - ymin) / tile_size
3607
+ # Export label as GeoTIFF
3608
+ label_path = os.path.join(output_masks_dir, f"{tile_name}.tif")
3609
+ try:
3610
+ with rasterio.open(label_path, "w", **label_profile) as dst:
3611
+ dst.write(label_mask.astype(np.uint8), 1)
4900
3612
 
4901
- # YOLO uses 0-based class indices
4902
- yolo_class_id = class_id - 1
4903
- yolo_file.write(
4904
- f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4905
- )
3613
+ if has_features:
3614
+ stats["tiles_with_features"] += 1
3615
+ except Exception as e:
3616
+ if not quiet:
3617
+ print(f"ERROR saving label GeoTIFF: {e}")
3618
+ stats["errors"] += 1
4906
3619
 
4907
3620
  tile_index += 1
4908
3621
  if tile_index >= max_tiles:
@@ -4914,179 +3627,6 @@ def _process_image_mask_pair(
4914
3627
  return stats
4915
3628
 
4916
3629
 
4917
- def display_training_tiles(
4918
- output_dir,
4919
- num_tiles=6,
4920
- figsize=(18, 6),
4921
- cmap="gray",
4922
- save_path=None,
4923
- ):
4924
- """
4925
- Display image and mask tile pairs from training data output.
4926
-
4927
- Args:
4928
- output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
4929
- num_tiles (int): Number of tile pairs to display (default: 6)
4930
- figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
4931
- cmap (str): Colormap for mask display (default: 'gray')
4932
- save_path (str, optional): If provided, save figure to this path instead of displaying
4933
-
4934
- Returns:
4935
- tuple: (fig, axes) matplotlib figure and axes objects
4936
-
4937
- Example:
4938
- >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
4939
- >>> # Or save to file
4940
- >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
4941
- """
4942
- import matplotlib.pyplot as plt
4943
-
4944
- # Get list of image tiles
4945
- images_dir = os.path.join(output_dir, "images")
4946
- if not os.path.exists(images_dir):
4947
- raise ValueError(f"Images directory not found: {images_dir}")
4948
-
4949
- image_tiles = sorted(os.listdir(images_dir))[:num_tiles]
4950
-
4951
- if not image_tiles:
4952
- raise ValueError(f"No image tiles found in {images_dir}")
4953
-
4954
- # Limit to available tiles
4955
- num_tiles = min(num_tiles, len(image_tiles))
4956
-
4957
- # Create figure with subplots
4958
- fig, axes = plt.subplots(2, num_tiles, figsize=figsize)
4959
-
4960
- # Handle case where num_tiles is 1
4961
- if num_tiles == 1:
4962
- axes = axes.reshape(2, 1)
4963
-
4964
- for idx, tile_name in enumerate(image_tiles):
4965
- # Load and display image tile
4966
- image_path = os.path.join(output_dir, "images", tile_name)
4967
- with rasterio.open(image_path) as src:
4968
- show(src, ax=axes[0, idx], title=f"Image {idx+1}")
4969
-
4970
- # Load and display mask tile
4971
- mask_path = os.path.join(output_dir, "masks", tile_name)
4972
- if os.path.exists(mask_path):
4973
- with rasterio.open(mask_path) as src:
4974
- show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
4975
- else:
4976
- axes[1, idx].text(
4977
- 0.5,
4978
- 0.5,
4979
- "Mask not found",
4980
- ha="center",
4981
- va="center",
4982
- transform=axes[1, idx].transAxes,
4983
- )
4984
- axes[1, idx].set_title(f"Mask {idx+1}")
4985
-
4986
- plt.tight_layout()
4987
-
4988
- # Save or show
4989
- if save_path:
4990
- plt.savefig(save_path, dpi=150, bbox_inches="tight")
4991
- plt.close(fig)
4992
- print(f"Figure saved to: {save_path}")
4993
- else:
4994
- plt.show()
4995
-
4996
- return fig, axes
4997
-
4998
-
4999
- def display_image_with_vector(
5000
- image_path,
5001
- vector_path,
5002
- figsize=(16, 8),
5003
- vector_color="red",
5004
- vector_linewidth=1,
5005
- vector_facecolor="none",
5006
- save_path=None,
5007
- ):
5008
- """
5009
- Display a raster image alongside the same image with vector overlay.
5010
-
5011
- Args:
5012
- image_path (str): Path to raster image file
5013
- vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
5014
- figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
5015
- vector_color (str): Edge color for vector features (default: 'red')
5016
- vector_linewidth (float): Line width for vector features (default: 1)
5017
- vector_facecolor (str): Fill color for vector features (default: 'none')
5018
- save_path (str, optional): If provided, save figure to this path instead of displaying
5019
-
5020
- Returns:
5021
- tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata
5022
-
5023
- Example:
5024
- >>> fig, axes, info = display_image_with_vector(
5025
- ... 'image.tif',
5026
- ... 'buildings.geojson',
5027
- ... vector_color='blue'
5028
- ... )
5029
- >>> print(f"Number of features: {info['num_features']}")
5030
- """
5031
- import matplotlib.pyplot as plt
5032
-
5033
- # Validate inputs
5034
- if not os.path.exists(image_path):
5035
- raise ValueError(f"Image file not found: {image_path}")
5036
- if not os.path.exists(vector_path):
5037
- raise ValueError(f"Vector file not found: {vector_path}")
5038
-
5039
- # Create figure
5040
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
5041
-
5042
- # Load and display image
5043
- with rasterio.open(image_path) as src:
5044
- # Plot image only
5045
- show(src, ax=ax1, title="Image")
5046
-
5047
- # Load vector data
5048
- vector_data = gpd.read_file(vector_path)
5049
-
5050
- # Reproject to image CRS if needed
5051
- if vector_data.crs != src.crs:
5052
- vector_data = vector_data.to_crs(src.crs)
5053
-
5054
- # Plot image with vector overlay
5055
- show(
5056
- src,
5057
- ax=ax2,
5058
- title=f"Image with {len(vector_data)} Vector Features",
5059
- )
5060
- vector_data.plot(
5061
- ax=ax2,
5062
- facecolor=vector_facecolor,
5063
- edgecolor=vector_color,
5064
- linewidth=vector_linewidth,
5065
- )
5066
-
5067
- # Collect metadata
5068
- info = {
5069
- "image_shape": src.shape,
5070
- "image_crs": src.crs,
5071
- "image_bounds": src.bounds,
5072
- "num_features": len(vector_data),
5073
- "vector_crs": vector_data.crs,
5074
- "vector_bounds": vector_data.total_bounds,
5075
- }
5076
-
5077
- plt.tight_layout()
5078
-
5079
- # Save or show
5080
- if save_path:
5081
- plt.savefig(save_path, dpi=150, bbox_inches="tight")
5082
- plt.close(fig)
5083
- print(f"Figure saved to: {save_path}")
5084
- else:
5085
- plt.show()
5086
-
5087
- return fig, (ax1, ax2), info
5088
-
5089
-
5090
3630
  def create_overview_image(
5091
3631
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
5092
3632
  ) -> str:
@@ -8981,39 +7521,17 @@ def write_colormap(
8981
7521
 
8982
7522
  def plot_performance_metrics(
8983
7523
  history_path: str,
8984
- figsize: Optional[Tuple[int, int]] = None,
7524
+ figsize: Tuple[int, int] = (15, 5),
8985
7525
  verbose: bool = True,
8986
7526
  save_path: Optional[str] = None,
8987
- csv_path: Optional[str] = None,
8988
7527
  kwargs: Optional[Dict] = None,
8989
- ) -> pd.DataFrame:
8990
- """Plot performance metrics from a training history object and return as DataFrame.
8991
-
8992
- This function loads training history, plots available metrics (loss, IoU, F1,
8993
- precision, recall), optionally exports to CSV, and returns all metrics as a
8994
- pandas DataFrame for further analysis.
7528
+ ) -> None:
7529
+ """Plot performance metrics from a history object.
8995
7530
 
8996
7531
  Args:
8997
- history_path (str): Path to the saved training history (.pth file).
8998
- figsize (Optional[Tuple[int, int]]): Figure size in inches. If None,
8999
- automatically determined based on number of metrics.
9000
- verbose (bool): Whether to print best and final metric values. Defaults to True.
9001
- save_path (Optional[str]): Path to save the plot image. If None, plot is not saved.
9002
- csv_path (Optional[str]): Path to export metrics as CSV. If None, CSV is not exported.
9003
- kwargs (Optional[Dict]): Additional keyword arguments for plt.savefig().
9004
-
9005
- Returns:
9006
- pd.DataFrame: DataFrame containing all metrics with columns for epoch and each metric.
9007
- Columns include: 'epoch', 'train_loss', 'val_loss', 'val_iou', 'val_f1',
9008
- 'val_precision', 'val_recall' (depending on availability in history).
9009
-
9010
- Example:
9011
- >>> df = plot_performance_metrics(
9012
- ... 'training_history.pth',
9013
- ... save_path='metrics_plot.png',
9014
- ... csv_path='metrics.csv'
9015
- ... )
9016
- >>> print(df.head())
7532
+ history_path: The history object to plot.
7533
+ figsize: The figure size.
7534
+ verbose: Whether to print the best and final metrics.
9017
7535
  """
9018
7536
  if kwargs is None:
9019
7537
  kwargs = {}
@@ -9023,135 +7541,65 @@ def plot_performance_metrics(
9023
7541
  train_loss_key = "train_losses" if "train_losses" in history else "train_loss"
9024
7542
  val_loss_key = "val_losses" if "val_losses" in history else "val_loss"
9025
7543
  val_iou_key = "val_ious" if "val_ious" in history else "val_iou"
9026
- # Support both new (f1) and old (dice) key formats for backward compatibility
9027
- val_f1_key = (
9028
- "val_f1s"
9029
- if "val_f1s" in history
9030
- else ("val_dices" if "val_dices" in history else "val_dice")
9031
- )
9032
- # Add support for precision and recall
9033
- val_precision_key = (
9034
- "val_precisions" if "val_precisions" in history else "val_precision"
9035
- )
9036
- val_recall_key = "val_recalls" if "val_recalls" in history else "val_recall"
9037
-
9038
- # Collect available metrics for plotting
9039
- available_metrics = []
9040
- metric_info = {
9041
- "Loss": (train_loss_key, val_loss_key, ["Train Loss", "Val Loss"]),
9042
- "IoU": (val_iou_key, None, ["Val IoU"]),
9043
- "F1": (val_f1_key, None, ["Val F1"]),
9044
- "Precision": (val_precision_key, None, ["Val Precision"]),
9045
- "Recall": (val_recall_key, None, ["Val Recall"]),
9046
- }
7544
+ val_dice_key = "val_dices" if "val_dices" in history else "val_dice"
9047
7545
 
9048
- for metric_name, (key1, key2, labels) in metric_info.items():
9049
- if key1 in history or (key2 and key2 in history):
9050
- available_metrics.append((metric_name, key1, key2, labels))
7546
+ # Determine number of subplots based on available metrics
7547
+ has_dice = val_dice_key in history
7548
+ n_plots = 3 if has_dice else 2
7549
+ figsize = (15, 5) if has_dice else (10, 5)
9051
7550
 
9052
- # Determine number of subplots and figure size
9053
- n_plots = len(available_metrics)
9054
- if figsize is None:
9055
- figsize = (5 * n_plots, 5)
7551
+ plt.figure(figsize=figsize)
9056
7552
 
9057
- # Create DataFrame for all metrics
9058
- n_epochs = 0
9059
- df_data = {}
9060
-
9061
- # Add epochs
9062
- if "epochs" in history:
9063
- df_data["epoch"] = history["epochs"]
9064
- n_epochs = len(history["epochs"])
9065
- elif train_loss_key in history:
9066
- n_epochs = len(history[train_loss_key])
9067
- df_data["epoch"] = list(range(1, n_epochs + 1))
9068
-
9069
- # Add all available metrics to DataFrame
7553
+ # Plot loss
7554
+ plt.subplot(1, n_plots, 1)
9070
7555
  if train_loss_key in history:
9071
- df_data["train_loss"] = history[train_loss_key]
7556
+ plt.plot(history[train_loss_key], label="Train Loss")
9072
7557
  if val_loss_key in history:
9073
- df_data["val_loss"] = history[val_loss_key]
7558
+ plt.plot(history[val_loss_key], label="Val Loss")
7559
+ plt.title("Loss")
7560
+ plt.xlabel("Epoch")
7561
+ plt.ylabel("Loss")
7562
+ plt.legend()
7563
+ plt.grid(True)
7564
+
7565
+ # Plot IoU
7566
+ plt.subplot(1, n_plots, 2)
9074
7567
  if val_iou_key in history:
9075
- df_data["val_iou"] = history[val_iou_key]
9076
- if val_f1_key in history:
9077
- df_data["val_f1"] = history[val_f1_key]
9078
- if val_precision_key in history:
9079
- df_data["val_precision"] = history[val_precision_key]
9080
- if val_recall_key in history:
9081
- df_data["val_recall"] = history[val_recall_key]
9082
-
9083
- # Create DataFrame
9084
- df = pd.DataFrame(df_data)
9085
-
9086
- # Export to CSV if requested
9087
- if csv_path:
9088
- df.to_csv(csv_path, index=False)
9089
- if verbose:
9090
- print(f"Metrics exported to: {csv_path}")
9091
-
9092
- # Create plots
9093
- if n_plots > 0:
9094
- fig, axes = plt.subplots(1, n_plots, figsize=figsize)
9095
- if n_plots == 1:
9096
- axes = [axes]
9097
-
9098
- for idx, (metric_name, key1, key2, labels) in enumerate(available_metrics):
9099
- ax = axes[idx]
9100
-
9101
- if metric_name == "Loss":
9102
- # Special handling for loss (has both train and val)
9103
- if key1 in history:
9104
- ax.plot(history[key1], label=labels[0])
9105
- if key2 and key2 in history:
9106
- ax.plot(history[key2], label=labels[1])
9107
- else:
9108
- # Single metric plots
9109
- if key1 in history:
9110
- ax.plot(history[key1], label=labels[0])
9111
-
9112
- ax.set_title(metric_name)
9113
- ax.set_xlabel("Epoch")
9114
- ax.set_ylabel(metric_name)
9115
- ax.legend()
9116
- ax.grid(True)
7568
+ plt.plot(history[val_iou_key], label="Val IoU")
7569
+ plt.title("IoU Score")
7570
+ plt.xlabel("Epoch")
7571
+ plt.ylabel("IoU")
7572
+ plt.legend()
7573
+ plt.grid(True)
7574
+
7575
+ # Plot Dice if available
7576
+ if has_dice:
7577
+ plt.subplot(1, n_plots, 3)
7578
+ plt.plot(history[val_dice_key], label="Val Dice")
7579
+ plt.title("Dice Score")
7580
+ plt.xlabel("Epoch")
7581
+ plt.ylabel("Dice")
7582
+ plt.legend()
7583
+ plt.grid(True)
9117
7584
 
9118
- plt.tight_layout()
7585
+ plt.tight_layout()
9119
7586
 
9120
- if save_path:
9121
- if "dpi" not in kwargs:
9122
- kwargs["dpi"] = 150
9123
- if "bbox_inches" not in kwargs:
9124
- kwargs["bbox_inches"] = "tight"
9125
- plt.savefig(save_path, **kwargs)
7587
+ if save_path:
7588
+ if "dpi" not in kwargs:
7589
+ kwargs["dpi"] = 150
7590
+ if "bbox_inches" not in kwargs:
7591
+ kwargs["bbox_inches"] = "tight"
7592
+ plt.savefig(save_path, **kwargs)
9126
7593
 
9127
- plt.show()
7594
+ plt.show()
9128
7595
 
9129
- # Print summary statistics
9130
7596
  if verbose:
9131
- print("\n=== Performance Metrics Summary ===")
9132
7597
  if val_iou_key in history:
9133
- print(
9134
- f"IoU - Best: {max(history[val_iou_key]):.4f} | Final: {history[val_iou_key][-1]:.4f}"
9135
- )
9136
- if val_f1_key in history:
9137
- print(
9138
- f"F1 - Best: {max(history[val_f1_key]):.4f} | Final: {history[val_f1_key][-1]:.4f}"
9139
- )
9140
- if val_precision_key in history:
9141
- print(
9142
- f"Precision - Best: {max(history[val_precision_key]):.4f} | Final: {history[val_precision_key][-1]:.4f}"
9143
- )
9144
- if val_recall_key in history:
9145
- print(
9146
- f"Recall - Best: {max(history[val_recall_key]):.4f} | Final: {history[val_recall_key][-1]:.4f}"
9147
- )
9148
- if val_loss_key in history:
9149
- print(
9150
- f"Val Loss - Best: {min(history[val_loss_key]):.4f} | Final: {history[val_loss_key][-1]:.4f}"
9151
- )
9152
- print("===================================\n")
9153
-
9154
- return df
7598
+ print(f"Best IoU: {max(history[val_iou_key]):.4f}")
7599
+ print(f"Final IoU: {history[val_iou_key][-1]:.4f}")
7600
+ if val_dice_key in history:
7601
+ print(f"Best Dice: {max(history[val_dice_key]):.4f}")
7602
+ print(f"Final Dice: {history[val_dice_key][-1]:.4f}")
9155
7603
 
9156
7604
 
9157
7605
  def get_device() -> torch.device: