geoai-py 0.13.2__py2.py3-none-any.whl → 0.15.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -751,6 +751,12 @@ def view_vector_interactive(
751
751
  },
752
752
  }
753
753
 
754
+ # Make it compatible with binder and JupyterHub
755
+ if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
756
+ os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
757
+ f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
758
+ )
759
+
754
760
  basemap_layer_name = None
755
761
  raster_layer = None
756
762
 
@@ -2609,6 +2615,7 @@ def export_geotiff_tiles(
2609
2615
  all_touched=True,
2610
2616
  create_overview=False,
2611
2617
  skip_empty_tiles=False,
2618
+ metadata_format="PASCAL_VOC",
2612
2619
  ):
2613
2620
  """
2614
2621
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -2626,6 +2633,7 @@ def export_geotiff_tiles(
2626
2633
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
2627
2634
  create_overview (bool): Whether to create an overview image of all tiles
2628
2635
  skip_empty_tiles (bool): If True, skip tiles with no features
2636
+ metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
2629
2637
  """
2630
2638
 
2631
2639
  import logging
@@ -2638,8 +2646,16 @@ def export_geotiff_tiles(
2638
2646
  os.makedirs(image_dir, exist_ok=True)
2639
2647
  label_dir = os.path.join(out_folder, "labels")
2640
2648
  os.makedirs(label_dir, exist_ok=True)
2641
- ann_dir = os.path.join(out_folder, "annotations")
2642
- os.makedirs(ann_dir, exist_ok=True)
2649
+
2650
+ # Create annotation directory based on metadata format
2651
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
2652
+ ann_dir = os.path.join(out_folder, "annotations")
2653
+ os.makedirs(ann_dir, exist_ok=True)
2654
+
2655
+ # Initialize COCO annotations dictionary
2656
+ if metadata_format == "COCO":
2657
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
2658
+ ann_id = 0
2643
2659
 
2644
2660
  # Determine if class data is raster or vector
2645
2661
  is_class_data_raster = False
@@ -2713,6 +2729,17 @@ def export_geotiff_tiles(
2713
2729
 
2714
2730
  # Create class mapping
2715
2731
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
2732
+
2733
+ # Populate COCO categories
2734
+ if metadata_format == "COCO":
2735
+ for cls_val in unique_classes:
2736
+ coco_annotations["categories"].append(
2737
+ {
2738
+ "id": class_to_id[int(cls_val)],
2739
+ "name": str(int(cls_val)),
2740
+ "supercategory": "object",
2741
+ }
2742
+ )
2716
2743
  else:
2717
2744
  # Load vector class data
2718
2745
  try:
@@ -2742,12 +2769,33 @@ def export_geotiff_tiles(
2742
2769
  )
2743
2770
  # Create class mapping
2744
2771
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
2772
+
2773
+ # Populate COCO categories
2774
+ if metadata_format == "COCO":
2775
+ for cls_val in unique_classes:
2776
+ coco_annotations["categories"].append(
2777
+ {
2778
+ "id": class_to_id[cls_val],
2779
+ "name": str(cls_val),
2780
+ "supercategory": "object",
2781
+ }
2782
+ )
2745
2783
  else:
2746
2784
  if not quiet:
2747
2785
  print(
2748
2786
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
2749
2787
  )
2750
2788
  class_to_id = {1: 1} # Default mapping
2789
+
2790
+ # Populate COCO categories with default
2791
+ if metadata_format == "COCO":
2792
+ coco_annotations["categories"].append(
2793
+ {
2794
+ "id": 1,
2795
+ "name": "object",
2796
+ "supercategory": "object",
2797
+ }
2798
+ )
2751
2799
  except Exception as e:
2752
2800
  raise ValueError(f"Error processing vector data: {e}")
2753
2801
 
@@ -2964,72 +3012,186 @@ def export_geotiff_tiles(
2964
3012
  pbar.write(f"ERROR saving label GeoTIFF: {e}")
2965
3013
  stats["errors"] += 1
2966
3014
 
2967
- # Create XML annotation for object detection if using vector class data
3015
+ # Create annotations for object detection if using vector class data
2968
3016
  if (
2969
3017
  not is_class_data_raster
2970
3018
  and "gdf" in locals()
2971
3019
  and len(window_features) > 0
2972
3020
  ):
2973
- # Create XML annotation
2974
- root = ET.Element("annotation")
2975
- ET.SubElement(root, "folder").text = "images"
2976
- ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3021
+ if metadata_format == "PASCAL_VOC":
3022
+ # Create XML annotation
3023
+ root = ET.Element("annotation")
3024
+ ET.SubElement(root, "folder").text = "images"
3025
+ ET.SubElement(root, "filename").text = (
3026
+ f"tile_{tile_index:06d}.tif"
3027
+ )
2977
3028
 
2978
- size = ET.SubElement(root, "size")
2979
- ET.SubElement(size, "width").text = str(tile_size)
2980
- ET.SubElement(size, "height").text = str(tile_size)
2981
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3029
+ size = ET.SubElement(root, "size")
3030
+ ET.SubElement(size, "width").text = str(tile_size)
3031
+ ET.SubElement(size, "height").text = str(tile_size)
3032
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3033
+
3034
+ # Add georeference information
3035
+ geo = ET.SubElement(root, "georeference")
3036
+ ET.SubElement(geo, "crs").text = str(src.crs)
3037
+ ET.SubElement(geo, "transform").text = str(
3038
+ window_transform
3039
+ ).replace("\n", "")
3040
+ ET.SubElement(geo, "bounds").text = (
3041
+ f"{minx}, {miny}, {maxx}, {maxy}"
3042
+ )
2982
3043
 
2983
- # Add georeference information
2984
- geo = ET.SubElement(root, "georeference")
2985
- ET.SubElement(geo, "crs").text = str(src.crs)
2986
- ET.SubElement(geo, "transform").text = str(
2987
- window_transform
2988
- ).replace("\n", "")
2989
- ET.SubElement(geo, "bounds").text = (
2990
- f"{minx}, {miny}, {maxx}, {maxy}"
2991
- )
3044
+ # Add objects
3045
+ for idx, feature in window_features.iterrows():
3046
+ # Get feature class
3047
+ if class_value_field in feature:
3048
+ class_val = feature[class_value_field]
3049
+ else:
3050
+ class_val = "object"
2992
3051
 
2993
- # Add objects
2994
- for idx, feature in window_features.iterrows():
2995
- # Get feature class
2996
- if class_value_field in feature:
2997
- class_val = feature[class_value_field]
2998
- else:
2999
- class_val = "object"
3052
+ # Get geometry bounds in pixel coordinates
3053
+ geom = feature.geometry.intersection(window_bounds)
3054
+ if not geom.is_empty:
3055
+ # Get bounds in world coordinates
3056
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3057
+
3058
+ # Convert to pixel coordinates
3059
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3060
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3061
+
3062
+ # Ensure coordinates are within tile bounds
3063
+ xmin = max(0, min(tile_size, int(col_min)))
3064
+ ymin = max(0, min(tile_size, int(row_min)))
3065
+ xmax = max(0, min(tile_size, int(col_max)))
3066
+ ymax = max(0, min(tile_size, int(row_max)))
3067
+
3068
+ # Only add if the box has non-zero area
3069
+ if xmax > xmin and ymax > ymin:
3070
+ obj = ET.SubElement(root, "object")
3071
+ ET.SubElement(obj, "name").text = str(class_val)
3072
+ ET.SubElement(obj, "difficult").text = "0"
3073
+
3074
+ bbox = ET.SubElement(obj, "bndbox")
3075
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3076
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3077
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3078
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3079
+
3080
+ # Save XML
3081
+ tree = ET.ElementTree(root)
3082
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3083
+ tree.write(xml_path)
3000
3084
 
3001
- # Get geometry bounds in pixel coordinates
3002
- geom = feature.geometry.intersection(window_bounds)
3003
- if not geom.is_empty:
3004
- # Get bounds in world coordinates
3005
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3006
-
3007
- # Convert to pixel coordinates
3008
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3009
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3010
-
3011
- # Ensure coordinates are within tile bounds
3012
- xmin = max(0, min(tile_size, int(col_min)))
3013
- ymin = max(0, min(tile_size, int(row_min)))
3014
- xmax = max(0, min(tile_size, int(col_max)))
3015
- ymax = max(0, min(tile_size, int(row_max)))
3016
-
3017
- # Only add if the box has non-zero area
3018
- if xmax > xmin and ymax > ymin:
3019
- obj = ET.SubElement(root, "object")
3020
- ET.SubElement(obj, "name").text = str(class_val)
3021
- ET.SubElement(obj, "difficult").text = "0"
3022
-
3023
- bbox = ET.SubElement(obj, "bndbox")
3024
- ET.SubElement(bbox, "xmin").text = str(xmin)
3025
- ET.SubElement(bbox, "ymin").text = str(ymin)
3026
- ET.SubElement(bbox, "xmax").text = str(xmax)
3027
- ET.SubElement(bbox, "ymax").text = str(ymax)
3085
+ elif metadata_format == "COCO":
3086
+ # Add image info
3087
+ image_id = tile_index
3088
+ coco_annotations["images"].append(
3089
+ {
3090
+ "id": image_id,
3091
+ "file_name": f"tile_{tile_index:06d}.tif",
3092
+ "width": tile_size,
3093
+ "height": tile_size,
3094
+ "crs": str(src.crs),
3095
+ "transform": str(window_transform),
3096
+ }
3097
+ )
3028
3098
 
3029
- # Save XML
3030
- tree = ET.ElementTree(root)
3031
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3032
- tree.write(xml_path)
3099
+ # Add annotations for each feature
3100
+ for _, feature in window_features.iterrows():
3101
+ # Get feature class
3102
+ if class_value_field in feature:
3103
+ class_val = feature[class_value_field]
3104
+ category_id = class_to_id.get(class_val, 1)
3105
+ else:
3106
+ category_id = 1
3107
+
3108
+ # Get geometry bounds
3109
+ geom = feature.geometry.intersection(window_bounds)
3110
+ if not geom.is_empty:
3111
+ # Get bounds in world coordinates
3112
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3113
+
3114
+ # Convert to pixel coordinates
3115
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3116
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3117
+
3118
+ # Ensure coordinates are within tile bounds
3119
+ xmin = max(0, min(tile_size, int(col_min)))
3120
+ ymin = max(0, min(tile_size, int(row_min)))
3121
+ xmax = max(0, min(tile_size, int(col_max)))
3122
+ ymax = max(0, min(tile_size, int(row_max)))
3123
+
3124
+ # Skip if box is too small
3125
+ if xmax - xmin < 1 or ymax - ymin < 1:
3126
+ continue
3127
+
3128
+ width = xmax - xmin
3129
+ height = ymax - ymin
3130
+
3131
+ # Add annotation
3132
+ ann_id += 1
3133
+ coco_annotations["annotations"].append(
3134
+ {
3135
+ "id": ann_id,
3136
+ "image_id": image_id,
3137
+ "category_id": category_id,
3138
+ "bbox": [xmin, ymin, width, height],
3139
+ "area": width * height,
3140
+ "iscrowd": 0,
3141
+ }
3142
+ )
3143
+
3144
+ elif metadata_format == "YOLO":
3145
+ # Create YOLO format annotations
3146
+ yolo_annotations = []
3147
+
3148
+ for _, feature in window_features.iterrows():
3149
+ # Get feature class
3150
+ if class_value_field in feature:
3151
+ class_val = feature[class_value_field]
3152
+ # YOLO uses 0-indexed class IDs
3153
+ class_id = class_to_id.get(class_val, 1) - 1
3154
+ else:
3155
+ class_id = 0
3156
+
3157
+ # Get geometry bounds
3158
+ geom = feature.geometry.intersection(window_bounds)
3159
+ if not geom.is_empty:
3160
+ # Get bounds in world coordinates
3161
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3162
+
3163
+ # Convert to pixel coordinates
3164
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3165
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3166
+
3167
+ # Ensure coordinates are within tile bounds
3168
+ xmin = max(0, min(tile_size, col_min))
3169
+ ymin = max(0, min(tile_size, row_min))
3170
+ xmax = max(0, min(tile_size, col_max))
3171
+ ymax = max(0, min(tile_size, row_max))
3172
+
3173
+ # Skip if box is too small
3174
+ if xmax - xmin < 1 or ymax - ymin < 1:
3175
+ continue
3176
+
3177
+ # Calculate normalized coordinates (YOLO format)
3178
+ x_center = ((xmin + xmax) / 2) / tile_size
3179
+ y_center = ((ymin + ymax) / 2) / tile_size
3180
+ width = (xmax - xmin) / tile_size
3181
+ height = (ymax - ymin) / tile_size
3182
+
3183
+ # Add YOLO annotation line
3184
+ yolo_annotations.append(
3185
+ f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3186
+ )
3187
+
3188
+ # Save YOLO annotations to text file
3189
+ if yolo_annotations:
3190
+ yolo_path = os.path.join(
3191
+ label_dir, f"tile_{tile_index:06d}.txt"
3192
+ )
3193
+ with open(yolo_path, "w") as f:
3194
+ f.write("\n".join(yolo_annotations))
3033
3195
 
3034
3196
  # Update progress bar
3035
3197
  pbar.update(1)
@@ -3047,6 +3209,39 @@ def export_geotiff_tiles(
3047
3209
  # Close progress bar
3048
3210
  pbar.close()
3049
3211
 
3212
+ # Save COCO annotations if applicable
3213
+ if metadata_format == "COCO":
3214
+ try:
3215
+ with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3216
+ json.dump(coco_annotations, f, indent=2)
3217
+ if not quiet:
3218
+ print(
3219
+ f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3220
+ f"{len(coco_annotations['annotations'])} annotations, "
3221
+ f"{len(coco_annotations['categories'])} categories"
3222
+ )
3223
+ except Exception as e:
3224
+ if not quiet:
3225
+ print(f"ERROR saving COCO annotations: {e}")
3226
+ stats["errors"] += 1
3227
+
3228
+ # Save YOLO classes file if applicable
3229
+ if metadata_format == "YOLO":
3230
+ try:
3231
+ # Create classes.txt with class names
3232
+ classes_path = os.path.join(out_folder, "classes.txt")
3233
+ # Sort by class ID to ensure correct order
3234
+ sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3235
+ with open(classes_path, "w") as f:
3236
+ for class_val, _ in sorted_classes:
3237
+ f.write(f"{class_val}\n")
3238
+ if not quiet:
3239
+ print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3240
+ except Exception as e:
3241
+ if not quiet:
3242
+ print(f"ERROR saving YOLO classes file: {e}")
3243
+ stats["errors"] += 1
3244
+
3050
3245
  # Create overview image if requested
3051
3246
  if create_overview and stats["tile_coordinates"]:
3052
3247
  try:
@@ -3115,8 +3310,9 @@ def export_geotiff_tiles(
3115
3310
 
3116
3311
  def export_geotiff_tiles_batch(
3117
3312
  images_folder,
3118
- masks_folder,
3119
- output_folder,
3313
+ masks_folder=None,
3314
+ masks_file=None,
3315
+ output_folder=None,
3120
3316
  tile_size=256,
3121
3317
  stride=128,
3122
3318
  class_value_field="class",
@@ -3124,25 +3320,38 @@ def export_geotiff_tiles_batch(
3124
3320
  max_tiles=None,
3125
3321
  quiet=False,
3126
3322
  all_touched=True,
3127
- create_overview=False,
3128
3323
  skip_empty_tiles=False,
3129
3324
  image_extensions=None,
3130
3325
  mask_extensions=None,
3326
+ match_by_name=True,
3327
+ metadata_format="PASCAL_VOC",
3131
3328
  ) -> Dict[str, Any]:
3132
3329
  """
3133
- Export georeferenced GeoTIFF tiles from folders of images and masks.
3330
+ Export georeferenced GeoTIFF tiles from images and masks.
3331
+
3332
+ This function supports three mask input modes:
3333
+ 1. Single vector file covering all images (masks_file parameter)
3334
+ 2. Multiple vector files, one per image (masks_folder parameter)
3335
+ 3. Multiple raster mask files (masks_folder parameter)
3336
+
3337
+ For mode 1 (single vector file), specify masks_file path. The function will
3338
+ use spatial intersection to determine which features apply to each image.
3134
3339
 
3135
- This function processes multiple image-mask pairs from input folders,
3136
- generating tiles for each pair. All image tiles are saved to a single
3137
- 'images' folder and all mask tiles to a single 'masks' folder.
3340
+ For mode 2/3 (multiple mask files), specify masks_folder path. Images and masks
3341
+ are paired either by matching filenames (match_by_name=True) or by sorted order
3342
+ (match_by_name=False).
3138
3343
 
3139
- Images and masks are paired by their sorted order (alphabetically), not by
3140
- filename matching. The number of images and masks must be equal.
3344
+ All image tiles are saved to a single 'images' folder and all mask tiles to a
3345
+ single 'masks' folder within the output directory.
3141
3346
 
3142
3347
  Args:
3143
3348
  images_folder (str): Path to folder containing raster images
3144
- masks_folder (str): Path to folder containing classification masks/vectors
3145
- output_folder (str): Path to output folder
3349
+ masks_folder (str, optional): Path to folder containing classification masks/vectors.
3350
+ Use this for multiple mask files (one per image or raster masks).
3351
+ masks_file (str, optional): Path to a single vector file covering all images.
3352
+ Use this for a single GeoJSON/Shapefile that covers multiple images.
3353
+ output_folder (str, optional): Path to output folder. If None, creates 'tiles'
3354
+ subfolder in images_folder.
3146
3355
  tile_size (int): Size of tiles in pixels (square)
3147
3356
  stride (int): Step size between tiles
3148
3357
  class_value_field (str): Field containing class values (for vector data)
@@ -3154,18 +3363,63 @@ def export_geotiff_tiles_batch(
3154
3363
  skip_empty_tiles (bool): If True, skip tiles with no features
3155
3364
  image_extensions (list): List of image file extensions to process (default: common raster formats)
3156
3365
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3366
+ match_by_name (bool): If True, match image and mask files by base filename.
3367
+ If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3368
+ metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3369
+ Default is "PASCAL_VOC".
3157
3370
 
3158
3371
  Returns:
3159
3372
  Dict[str, Any]: Dictionary containing batch processing statistics
3160
3373
 
3161
3374
  Raises:
3162
- ValueError: If no images or masks found, or if counts don't match
3375
+ ValueError: If no images found, or if masks_folder and masks_file are both specified,
3376
+ or if neither is specified, or if counts don't match when using masks_folder with
3377
+ match_by_name=False.
3378
+
3379
+ Examples:
3380
+ # Single vector file covering all images
3381
+ >>> stats = export_geotiff_tiles_batch(
3382
+ ... images_folder='data/images',
3383
+ ... masks_file='data/buildings.geojson',
3384
+ ... output_folder='output/tiles'
3385
+ ... )
3386
+
3387
+ # Multiple vector files, matched by filename
3388
+ >>> stats = export_geotiff_tiles_batch(
3389
+ ... images_folder='data/images',
3390
+ ... masks_folder='data/masks',
3391
+ ... output_folder='output/tiles',
3392
+ ... match_by_name=True
3393
+ ... )
3394
+
3395
+ # Multiple mask files, matched by sorted order
3396
+ >>> stats = export_geotiff_tiles_batch(
3397
+ ... images_folder='data/images',
3398
+ ... masks_folder='data/masks',
3399
+ ... output_folder='output/tiles',
3400
+ ... match_by_name=False
3401
+ ... )
3163
3402
  """
3164
3403
 
3165
3404
  import logging
3166
3405
 
3167
3406
  logging.getLogger("rasterio").setLevel(logging.ERROR)
3168
3407
 
3408
+ # Validate input parameters
3409
+ if masks_folder is not None and masks_file is not None:
3410
+ raise ValueError(
3411
+ "Cannot specify both masks_folder and masks_file. Please use only one."
3412
+ )
3413
+
3414
+ if masks_folder is None and masks_file is None:
3415
+ raise ValueError(
3416
+ "Must specify either masks_folder or masks_file for mask data source."
3417
+ )
3418
+
3419
+ # Default output folder if not specified
3420
+ if output_folder is None:
3421
+ output_folder = os.path.join(images_folder, "tiles")
3422
+
3169
3423
  # Default extensions if not provided
3170
3424
  if image_extensions is None:
3171
3425
  image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
@@ -3196,36 +3450,107 @@ def export_geotiff_tiles_batch(
3196
3450
  os.makedirs(output_images_dir, exist_ok=True)
3197
3451
  os.makedirs(output_masks_dir, exist_ok=True)
3198
3452
 
3453
+ # Create annotation directory based on metadata format
3454
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3455
+ ann_dir = os.path.join(output_folder, "annotations")
3456
+ os.makedirs(ann_dir, exist_ok=True)
3457
+
3458
+ # Initialize COCO annotations dictionary
3459
+ coco_annotations = None
3460
+ if metadata_format == "COCO":
3461
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3462
+
3463
+ # Initialize YOLO class set
3464
+ yolo_classes = set() if metadata_format == "YOLO" else None
3465
+
3199
3466
  # Get list of image files
3200
3467
  image_files = []
3201
3468
  for ext in image_extensions:
3202
3469
  pattern = os.path.join(images_folder, f"*{ext}")
3203
3470
  image_files.extend(glob.glob(pattern))
3204
3471
 
3205
- # Get list of mask files
3206
- mask_files = []
3207
- for ext in mask_extensions:
3208
- pattern = os.path.join(masks_folder, f"*{ext}")
3209
- mask_files.extend(glob.glob(pattern))
3210
-
3211
3472
  # Sort files for consistent processing
3212
3473
  image_files.sort()
3213
- mask_files.sort()
3214
3474
 
3215
3475
  if not image_files:
3216
3476
  raise ValueError(
3217
3477
  f"No image files found in {images_folder} with extensions {image_extensions}"
3218
3478
  )
3219
3479
 
3220
- if not mask_files:
3221
- raise ValueError(
3222
- f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3223
- )
3480
+ # Handle different mask input modes
3481
+ use_single_mask_file = masks_file is not None
3482
+ mask_files = []
3483
+ image_mask_pairs = []
3224
3484
 
3225
- if len(image_files) != len(mask_files):
3226
- raise ValueError(
3227
- f"Number of image files ({len(image_files)}) does not match number of mask files ({len(mask_files)})"
3228
- )
3485
+ if use_single_mask_file:
3486
+ # Mode 1: Single vector file covering all images
3487
+ if not os.path.exists(masks_file):
3488
+ raise ValueError(f"Mask file not found: {masks_file}")
3489
+
3490
+ # Load the single mask file once - will be spatially filtered per image
3491
+ single_mask_gdf = gpd.read_file(masks_file)
3492
+
3493
+ if not quiet:
3494
+ print(f"Using single mask file: {masks_file}")
3495
+ print(
3496
+ f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
3497
+ )
3498
+
3499
+ # Create pairs with the same mask file for all images
3500
+ for image_file in image_files:
3501
+ image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
3502
+
3503
+ else:
3504
+ # Mode 2/3: Multiple mask files (vector or raster)
3505
+ # Get list of mask files
3506
+ for ext in mask_extensions:
3507
+ pattern = os.path.join(masks_folder, f"*{ext}")
3508
+ mask_files.extend(glob.glob(pattern))
3509
+
3510
+ # Sort files for consistent processing
3511
+ mask_files.sort()
3512
+
3513
+ if not mask_files:
3514
+ raise ValueError(
3515
+ f"No mask files found in {masks_folder} with extensions {mask_extensions}"
3516
+ )
3517
+
3518
+ # Match images to masks
3519
+ if match_by_name:
3520
+ # Match by base filename
3521
+ image_dict = {
3522
+ os.path.splitext(os.path.basename(f))[0]: f for f in image_files
3523
+ }
3524
+ mask_dict = {
3525
+ os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
3526
+ }
3527
+
3528
+ # Find matching pairs
3529
+ for img_base, img_path in image_dict.items():
3530
+ if img_base in mask_dict:
3531
+ image_mask_pairs.append((img_path, mask_dict[img_base], None))
3532
+ else:
3533
+ if not quiet:
3534
+ print(f"Warning: No mask found for image {img_base}")
3535
+
3536
+ if not image_mask_pairs:
3537
+ raise ValueError(
3538
+ "No matching image-mask pairs found when matching by filename. "
3539
+ "Check that image and mask files have matching base names."
3540
+ )
3541
+
3542
+ else:
3543
+ # Match by sorted order
3544
+ if len(image_files) != len(mask_files):
3545
+ raise ValueError(
3546
+ f"Number of image files ({len(image_files)}) does not match "
3547
+ f"number of mask files ({len(mask_files)}) when matching by sorted order. "
3548
+ f"Use match_by_name=True for filename-based matching."
3549
+ )
3550
+
3551
+ # Create pairs by sorted order
3552
+ for image_file, mask_file in zip(image_files, mask_files):
3553
+ image_mask_pairs.append((image_file, mask_file, None))
3229
3554
 
3230
3555
  # Initialize batch statistics
3231
3556
  batch_stats = {
@@ -3239,23 +3564,24 @@ def export_geotiff_tiles_batch(
3239
3564
  }
3240
3565
 
3241
3566
  if not quiet:
3242
- print(
3243
- f"Found {len(image_files)} image files and {len(mask_files)} mask files to process"
3244
- )
3245
- print(f"Processing batch from {images_folder} and {masks_folder}")
3567
+ if use_single_mask_file:
3568
+ print(f"Found {len(image_files)} image files to process")
3569
+ print(f"Using single mask file: {masks_file}")
3570
+ else:
3571
+ print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
3572
+ print(f"Processing batch from {images_folder} and {masks_folder}")
3246
3573
  print(f"Output folder: {output_folder}")
3247
3574
  print("-" * 60)
3248
3575
 
3249
3576
  # Global tile counter for unique naming
3250
3577
  global_tile_counter = 0
3251
3578
 
3252
- # Process each image-mask pair by sorted order
3253
- for idx, (image_file, mask_file) in enumerate(
3579
+ # Process each image-mask pair
3580
+ for idx, (image_file, mask_file, mask_gdf) in enumerate(
3254
3581
  tqdm(
3255
- zip(image_files, mask_files),
3582
+ image_mask_pairs,
3256
3583
  desc="Processing image pairs",
3257
3584
  disable=quiet,
3258
- total=len(image_files),
3259
3585
  )
3260
3586
  ):
3261
3587
  batch_stats["total_image_pairs"] += 1
@@ -3267,9 +3593,12 @@ def export_geotiff_tiles_batch(
3267
3593
  if not quiet:
3268
3594
  print(f"\nProcessing: {base_name}")
3269
3595
  print(f" Image: {os.path.basename(image_file)}")
3270
- print(f" Mask: {os.path.basename(mask_file)}")
3596
+ if use_single_mask_file:
3597
+ print(f" Mask: {os.path.basename(mask_file)} (spatially filtered)")
3598
+ else:
3599
+ print(f" Mask: {os.path.basename(mask_file)}")
3271
3600
 
3272
- # Process the image-mask pair manually to get direct control over tile saving
3601
+ # Process the image-mask pair
3273
3602
  tiles_generated = _process_image_mask_pair(
3274
3603
  image_file=image_file,
3275
3604
  mask_file=mask_file,
@@ -3285,6 +3614,15 @@ def export_geotiff_tiles_batch(
3285
3614
  all_touched=all_touched,
3286
3615
  skip_empty_tiles=skip_empty_tiles,
3287
3616
  quiet=quiet,
3617
+ mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
3618
+ use_single_mask_file=use_single_mask_file,
3619
+ metadata_format=metadata_format,
3620
+ ann_dir=(
3621
+ ann_dir
3622
+ if "ann_dir" in locals()
3623
+ and metadata_format in ["PASCAL_VOC", "COCO"]
3624
+ else None
3625
+ ),
3288
3626
  )
3289
3627
 
3290
3628
  # Update counters
@@ -3306,6 +3644,23 @@ def export_geotiff_tiles_batch(
3306
3644
  }
3307
3645
  )
3308
3646
 
3647
+ # Aggregate COCO annotations
3648
+ if metadata_format == "COCO" and "coco_data" in tiles_generated:
3649
+ coco_data = tiles_generated["coco_data"]
3650
+ # Add images and annotations
3651
+ coco_annotations["images"].extend(coco_data.get("images", []))
3652
+ coco_annotations["annotations"].extend(coco_data.get("annotations", []))
3653
+ # Merge categories (avoid duplicates)
3654
+ for cat in coco_data.get("categories", []):
3655
+ if not any(
3656
+ c["id"] == cat["id"] for c in coco_annotations["categories"]
3657
+ ):
3658
+ coco_annotations["categories"].append(cat)
3659
+
3660
+ # Aggregate YOLO classes
3661
+ if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
3662
+ yolo_classes.update(tiles_generated["yolo_classes"])
3663
+
3309
3664
  except Exception as e:
3310
3665
  if not quiet:
3311
3666
  print(f"ERROR processing {base_name}: {e}")
@@ -3314,6 +3669,33 @@ def export_geotiff_tiles_batch(
3314
3669
  )
3315
3670
  batch_stats["errors"] += 1
3316
3671
 
3672
+ # Save aggregated COCO annotations
3673
+ if metadata_format == "COCO" and coco_annotations:
3674
+ import json
3675
+
3676
+ coco_path = os.path.join(ann_dir, "instances.json")
3677
+ with open(coco_path, "w") as f:
3678
+ json.dump(coco_annotations, f, indent=2)
3679
+ if not quiet:
3680
+ print(f"\nSaved COCO annotations: {coco_path}")
3681
+ print(
3682
+ f" Images: {len(coco_annotations['images'])}, "
3683
+ f"Annotations: {len(coco_annotations['annotations'])}, "
3684
+ f"Categories: {len(coco_annotations['categories'])}"
3685
+ )
3686
+
3687
+ # Save aggregated YOLO classes
3688
+ if metadata_format == "YOLO" and yolo_classes:
3689
+ classes_path = os.path.join(output_folder, "labels", "classes.txt")
3690
+ os.makedirs(os.path.dirname(classes_path), exist_ok=True)
3691
+ sorted_classes = sorted(yolo_classes)
3692
+ with open(classes_path, "w") as f:
3693
+ for cls in sorted_classes:
3694
+ f.write(f"{cls}\n")
3695
+ if not quiet:
3696
+ print(f"\nSaved YOLO classes: {classes_path}")
3697
+ print(f" Total classes: {len(sorted_classes)}")
3698
+
3317
3699
  # Print batch summary
3318
3700
  if not quiet:
3319
3701
  print("\n" + "=" * 60)
@@ -3337,6 +3719,10 @@ def export_geotiff_tiles_batch(
3337
3719
  print(f"Output saved to: {output_folder}")
3338
3720
  print(f" Images: {output_images_dir}")
3339
3721
  print(f" Masks: {output_masks_dir}")
3722
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3723
+ print(f" Annotations: {ann_dir}")
3724
+ elif metadata_format == "YOLO":
3725
+ print(f" Labels: {os.path.join(output_folder, 'labels')}")
3340
3726
 
3341
3727
  # List failed files if any
3342
3728
  if batch_stats["failed_files"]:
@@ -3362,10 +3748,18 @@ def _process_image_mask_pair(
3362
3748
  all_touched=True,
3363
3749
  skip_empty_tiles=False,
3364
3750
  quiet=False,
3751
+ mask_gdf=None,
3752
+ use_single_mask_file=False,
3753
+ metadata_format="PASCAL_VOC",
3754
+ ann_dir=None,
3365
3755
  ):
3366
3756
  """
3367
3757
  Process a single image-mask pair and save tiles directly to output directories.
3368
3758
 
3759
+ Args:
3760
+ mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
3761
+ use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
3762
+
3369
3763
  Returns:
3370
3764
  dict: Statistics for this image-mask pair
3371
3765
  """
@@ -3390,6 +3784,13 @@ def _process_image_mask_pair(
3390
3784
  "errors": 0,
3391
3785
  }
3392
3786
 
3787
+ # Initialize COCO/YOLO tracking for this image
3788
+ if metadata_format == "COCO":
3789
+ stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
3790
+ coco_ann_id = 0
3791
+ if metadata_format == "YOLO":
3792
+ stats["yolo_classes"] = set()
3793
+
3393
3794
  # Open the input raster
3394
3795
  with rasterio.open(image_file) as src:
3395
3796
  # Calculate number of tiles
@@ -3433,11 +3834,36 @@ def _process_image_mask_pair(
3433
3834
  else:
3434
3835
  # Load vector class data
3435
3836
  try:
3436
- gdf = gpd.read_file(mask_file)
3837
+ if use_single_mask_file and mask_gdf is not None:
3838
+ # Using pre-loaded single mask file - spatially filter to image bounds
3839
+ # Get image bounds
3840
+ image_bounds = box(*src.bounds)
3841
+ image_gdf = gpd.GeoDataFrame(
3842
+ {"geometry": [image_bounds]}, crs=src.crs
3843
+ )
3437
3844
 
3438
- # Always reproject to match raster CRS
3439
- if gdf.crs != src.crs:
3440
- gdf = gdf.to_crs(src.crs)
3845
+ # Reproject mask if needed
3846
+ if mask_gdf.crs != src.crs:
3847
+ mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
3848
+ else:
3849
+ mask_gdf_reprojected = mask_gdf
3850
+
3851
+ # Spatially filter features that intersect with image bounds
3852
+ gdf = mask_gdf_reprojected[
3853
+ mask_gdf_reprojected.intersects(image_bounds)
3854
+ ].copy()
3855
+
3856
+ if not quiet and len(gdf) > 0:
3857
+ print(
3858
+ f" Filtered to {len(gdf)} features intersecting image bounds"
3859
+ )
3860
+ else:
3861
+ # Load individual mask file
3862
+ gdf = gpd.read_file(mask_file)
3863
+
3864
+ # Always reproject to match raster CRS
3865
+ if gdf.crs != src.crs:
3866
+ gdf = gdf.to_crs(src.crs)
3441
3867
 
3442
3868
  # Apply buffer if specified
3443
3869
  if buffer_radius > 0:
@@ -3457,9 +3883,6 @@ def _process_image_mask_pair(
3457
3883
  tile_index = 0
3458
3884
  for y in range(num_tiles_y):
3459
3885
  for x in range(num_tiles_x):
3460
- if tile_index >= max_tiles:
3461
- break
3462
-
3463
3886
  # Calculate window coordinates
3464
3887
  window_x = x * stride
3465
3888
  window_y = y * stride
@@ -3562,9 +3985,12 @@ def _process_image_mask_pair(
3562
3985
 
3563
3986
  # Skip tile if no features and skip_empty_tiles is True
3564
3987
  if skip_empty_tiles and not has_features:
3565
- tile_index += 1
3566
3988
  continue
3567
3989
 
3990
+ # Check if we've reached max_tiles before saving
3991
+ if tile_index >= max_tiles:
3992
+ break
3993
+
3568
3994
  # Generate unique tile name
3569
3995
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
3570
3996
 
@@ -3619,6 +4045,197 @@ def _process_image_mask_pair(
3619
4045
  print(f"ERROR saving label GeoTIFF: {e}")
3620
4046
  stats["errors"] += 1
3621
4047
 
4048
+ # Generate annotation metadata based on format
4049
+ if metadata_format == "PASCAL_VOC" and ann_dir:
4050
+ # Create PASCAL VOC XML annotation
4051
+ from lxml import etree as ET
4052
+
4053
+ annotation = ET.Element("annotation")
4054
+ ET.SubElement(annotation, "folder").text = os.path.basename(
4055
+ output_images_dir
4056
+ )
4057
+ ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4058
+ ET.SubElement(annotation, "path").text = image_path
4059
+
4060
+ source = ET.SubElement(annotation, "source")
4061
+ ET.SubElement(source, "database").text = "GeoAI"
4062
+
4063
+ size = ET.SubElement(annotation, "size")
4064
+ ET.SubElement(size, "width").text = str(tile_size)
4065
+ ET.SubElement(size, "height").text = str(tile_size)
4066
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
4067
+
4068
+ ET.SubElement(annotation, "segmented").text = "1"
4069
+
4070
+ # Find connected components for instance segmentation
4071
+ from scipy import ndimage
4072
+
4073
+ for class_id in np.unique(label_mask):
4074
+ if class_id == 0:
4075
+ continue
4076
+
4077
+ class_mask = (label_mask == class_id).astype(np.uint8)
4078
+ labeled_array, num_features = ndimage.label(class_mask)
4079
+
4080
+ for instance_id in range(1, num_features + 1):
4081
+ instance_mask = labeled_array == instance_id
4082
+ coords = np.argwhere(instance_mask)
4083
+
4084
+ if len(coords) == 0:
4085
+ continue
4086
+
4087
+ ymin, xmin = coords.min(axis=0)
4088
+ ymax, xmax = coords.max(axis=0)
4089
+
4090
+ obj = ET.SubElement(annotation, "object")
4091
+ class_name = next(
4092
+ (k for k, v in class_to_id.items() if v == class_id),
4093
+ str(class_id),
4094
+ )
4095
+ ET.SubElement(obj, "name").text = str(class_name)
4096
+ ET.SubElement(obj, "pose").text = "Unspecified"
4097
+ ET.SubElement(obj, "truncated").text = "0"
4098
+ ET.SubElement(obj, "difficult").text = "0"
4099
+
4100
+ bndbox = ET.SubElement(obj, "bndbox")
4101
+ ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4102
+ ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4103
+ ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4104
+ ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4105
+
4106
+ # Save XML file
4107
+ xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4108
+ tree = ET.ElementTree(annotation)
4109
+ tree.write(xml_path, pretty_print=True, encoding="utf-8")
4110
+
4111
+ elif metadata_format == "COCO":
4112
+ # Add COCO image entry
4113
+ image_id = int(global_tile_counter + tile_index)
4114
+ stats["coco_data"]["images"].append(
4115
+ {
4116
+ "id": image_id,
4117
+ "file_name": f"{tile_name}.tif",
4118
+ "width": int(tile_size),
4119
+ "height": int(tile_size),
4120
+ }
4121
+ )
4122
+
4123
+ # Add COCO categories (only once per unique class)
4124
+ for class_val, class_id in class_to_id.items():
4125
+ if not any(
4126
+ c["id"] == class_id
4127
+ for c in stats["coco_data"]["categories"]
4128
+ ):
4129
+ stats["coco_data"]["categories"].append(
4130
+ {
4131
+ "id": int(class_id),
4132
+ "name": str(class_val),
4133
+ "supercategory": "object",
4134
+ }
4135
+ )
4136
+
4137
+ # Add COCO annotations (instance segmentation)
4138
+ from scipy import ndimage
4139
+ from skimage import measure
4140
+
4141
+ for class_id in np.unique(label_mask):
4142
+ if class_id == 0:
4143
+ continue
4144
+
4145
+ class_mask = (label_mask == class_id).astype(np.uint8)
4146
+ labeled_array, num_features = ndimage.label(class_mask)
4147
+
4148
+ for instance_id in range(1, num_features + 1):
4149
+ instance_mask = (labeled_array == instance_id).astype(
4150
+ np.uint8
4151
+ )
4152
+ coords = np.argwhere(instance_mask)
4153
+
4154
+ if len(coords) == 0:
4155
+ continue
4156
+
4157
+ ymin, xmin = coords.min(axis=0)
4158
+ ymax, xmax = coords.max(axis=0)
4159
+
4160
+ bbox = [
4161
+ int(xmin),
4162
+ int(ymin),
4163
+ int(xmax - xmin),
4164
+ int(ymax - ymin),
4165
+ ]
4166
+ area = int(np.sum(instance_mask))
4167
+
4168
+ # Find contours for segmentation
4169
+ contours = measure.find_contours(instance_mask, 0.5)
4170
+ segmentation = []
4171
+ for contour in contours:
4172
+ contour = np.flip(contour, axis=1)
4173
+ segmentation_points = contour.ravel().tolist()
4174
+ if len(segmentation_points) >= 6:
4175
+ segmentation.append(segmentation_points)
4176
+
4177
+ if segmentation:
4178
+ stats["coco_data"]["annotations"].append(
4179
+ {
4180
+ "id": int(coco_ann_id),
4181
+ "image_id": int(image_id),
4182
+ "category_id": int(class_id),
4183
+ "bbox": bbox,
4184
+ "area": area,
4185
+ "segmentation": segmentation,
4186
+ "iscrowd": 0,
4187
+ }
4188
+ )
4189
+ coco_ann_id += 1
4190
+
4191
+ elif metadata_format == "YOLO":
4192
+ # Create YOLO labels directory if needed
4193
+ labels_dir = os.path.join(
4194
+ os.path.dirname(output_images_dir), "labels"
4195
+ )
4196
+ os.makedirs(labels_dir, exist_ok=True)
4197
+
4198
+ # Generate YOLO annotation file
4199
+ yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4200
+ from scipy import ndimage
4201
+
4202
+ with open(yolo_path, "w") as yolo_file:
4203
+ for class_id in np.unique(label_mask):
4204
+ if class_id == 0:
4205
+ continue
4206
+
4207
+ # Track class for classes.txt
4208
+ class_name = next(
4209
+ (k for k, v in class_to_id.items() if v == class_id),
4210
+ str(class_id),
4211
+ )
4212
+ stats["yolo_classes"].add(class_name)
4213
+
4214
+ class_mask = (label_mask == class_id).astype(np.uint8)
4215
+ labeled_array, num_features = ndimage.label(class_mask)
4216
+
4217
+ for instance_id in range(1, num_features + 1):
4218
+ instance_mask = labeled_array == instance_id
4219
+ coords = np.argwhere(instance_mask)
4220
+
4221
+ if len(coords) == 0:
4222
+ continue
4223
+
4224
+ ymin, xmin = coords.min(axis=0)
4225
+ ymax, xmax = coords.max(axis=0)
4226
+
4227
+ # Convert to YOLO format (normalized center coordinates)
4228
+ x_center = ((xmin + xmax) / 2) / tile_size
4229
+ y_center = ((ymin + ymax) / 2) / tile_size
4230
+ width = (xmax - xmin) / tile_size
4231
+ height = (ymax - ymin) / tile_size
4232
+
4233
+ # YOLO uses 0-based class indices
4234
+ yolo_class_id = class_id - 1
4235
+ yolo_file.write(
4236
+ f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4237
+ )
4238
+
3622
4239
  tile_index += 1
3623
4240
  if tile_index >= max_tiles:
3624
4241
  break
@@ -3629,6 +4246,179 @@ def _process_image_mask_pair(
3629
4246
  return stats
3630
4247
 
3631
4248
 
4249
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    The first ``num_tiles`` files (in sorted filename order) found in
    ``<output_dir>/images`` are drawn in the top row. The file with the same
    name in ``<output_dir>/masks`` is drawn directly beneath each one. Hidden
    entries (names starting with ".", e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) and sub-directories are ignored, so they are never
    handed to ``rasterio.open``.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Number of tile pairs to display (default: 6). Must be >= 1.
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and axes objects. ``axes`` is
        always a 2 x N array, where N is the number of tiles actually shown.

    Raises:
        ValueError: If ``num_tiles`` is less than 1, the images directory does
            not exist, or it contains no tile files.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    if num_tiles < 1:
        raise ValueError(f"num_tiles must be >= 1, got {num_tiles}")

    images_dir = os.path.join(output_dir, "images")
    if not os.path.exists(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Hidden files/dirs sort before real tile names and are not rasters, so
    # exclude them (and any sub-directory) before picking the first N tiles.
    tile_files = [
        name
        for name in sorted(os.listdir(images_dir))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(images_dir, name))
    ]
    image_tiles = tile_files[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Imported lazily so matplotlib is only needed once plotting starts.
    import matplotlib.pyplot as plt

    n_cols = len(image_tiles)
    # squeeze=False keeps ``axes`` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(2, n_cols, figsize=figsize, squeeze=False)

    masks_dir = os.path.join(output_dir, "masks")
    for idx, tile_name in enumerate(image_tiles):
        # Top row: the image tile.
        with rasterio.open(os.path.join(images_dir, tile_name)) as src:
            show(src, ax=axes[0, idx], title=f"Image {idx+1}")

        # Bottom row: the mask with the same file name, if present.
        mask_ax = axes[1, idx]
        mask_path = os.path.join(masks_dir, tile_name)
        if os.path.exists(mask_path):
            with rasterio.open(mask_path) as src:
                show(src, ax=mask_ax, title=f"Mask {idx+1}", cmap=cmap)
        else:
            mask_ax.text(
                0.5,
                0.5,
                "Mask not found",
                ha="center",
                va="center",
                transform=mask_ax.transAxes,
            )
            mask_ax.set_title(f"Mask {idx+1}")

    fig.tight_layout()

    # Save or show
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
4329
+
4330
+
4331
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Display a raster image alongside the same image with vector overlay.

    The left panel shows the raster alone. The right panel shows the raster
    with the vector features drawn on top. The vector layer is reprojected
    to the raster's CRS when the two differ.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata

    Raises:
        ValueError: If ``image_path`` or ``vector_path`` does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    # Validate inputs before importing or allocating anything.
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    import matplotlib.pyplot as plt

    with rasterio.open(image_path) as src:
        # Read and reproject the vector layer before creating the figure, so
        # a read/reprojection failure does not leave an orphaned pyplot figure.
        vector_data = gpd.read_file(vector_path)
        if vector_data.crs != src.crs:
            vector_data = vector_data.to_crs(src.crs)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        try:
            # Left: image only.
            show(src, ax=ax1, title="Image")

            # Right: image with vector overlay.
            show(
                src,
                ax=ax2,
                title=f"Image with {len(vector_data)} Vector Features",
            )
            vector_data.plot(
                ax=ax2,
                facecolor=vector_facecolor,
                edgecolor=vector_color,
                linewidth=vector_linewidth,
            )
        except Exception:
            # Release the half-drawn figure, then propagate the original error.
            plt.close(fig)
            raise

        # Collect metadata while the dataset is still open.
        info = {
            "image_shape": src.shape,
            "image_crs": src.crs,
            "image_bounds": src.bounds,
            "num_features": len(vector_data),
            "vector_crs": vector_data.crs,
            "vector_bounds": vector_data.total_bounds,
        }

    fig.tight_layout()

    # Save or show
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax1, ax2), info
4420
+
4421
+
3632
4422
  def create_overview_image(
3633
4423
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
3634
4424
  ) -> str: