geoai-py 0.14.0__py2.py3-none-any.whl → 0.15.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -751,6 +751,12 @@ def view_vector_interactive(
751
751
  },
752
752
  }
753
753
 
754
+ # Make it compatible with binder and JupyterHub
755
+ if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
756
+ os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
757
+ f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
758
+ )
759
+
754
760
  basemap_layer_name = None
755
761
  raster_layer = None
756
762
 
@@ -2609,6 +2615,7 @@ def export_geotiff_tiles(
2609
2615
  all_touched=True,
2610
2616
  create_overview=False,
2611
2617
  skip_empty_tiles=False,
2618
+ metadata_format="PASCAL_VOC",
2612
2619
  ):
2613
2620
  """
2614
2621
  Export georeferenced GeoTIFF tiles and labels from raster and classification data.
@@ -2626,6 +2633,7 @@ def export_geotiff_tiles(
2626
2633
  all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
2627
2634
  create_overview (bool): Whether to create an overview image of all tiles
2628
2635
  skip_empty_tiles (bool): If True, skip tiles with no features
2636
+ metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
2629
2637
  """
2630
2638
 
2631
2639
  import logging
@@ -2638,8 +2646,16 @@ def export_geotiff_tiles(
2638
2646
  os.makedirs(image_dir, exist_ok=True)
2639
2647
  label_dir = os.path.join(out_folder, "labels")
2640
2648
  os.makedirs(label_dir, exist_ok=True)
2641
- ann_dir = os.path.join(out_folder, "annotations")
2642
- os.makedirs(ann_dir, exist_ok=True)
2649
+
2650
+ # Create annotation directory based on metadata format
2651
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
2652
+ ann_dir = os.path.join(out_folder, "annotations")
2653
+ os.makedirs(ann_dir, exist_ok=True)
2654
+
2655
+ # Initialize COCO annotations dictionary
2656
+ if metadata_format == "COCO":
2657
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
2658
+ ann_id = 0
2643
2659
 
2644
2660
  # Determine if class data is raster or vector
2645
2661
  is_class_data_raster = False
@@ -2713,6 +2729,17 @@ def export_geotiff_tiles(
2713
2729
 
2714
2730
  # Create class mapping
2715
2731
  class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
2732
+
2733
+ # Populate COCO categories
2734
+ if metadata_format == "COCO":
2735
+ for cls_val in unique_classes:
2736
+ coco_annotations["categories"].append(
2737
+ {
2738
+ "id": class_to_id[int(cls_val)],
2739
+ "name": str(int(cls_val)),
2740
+ "supercategory": "object",
2741
+ }
2742
+ )
2716
2743
  else:
2717
2744
  # Load vector class data
2718
2745
  try:
@@ -2742,12 +2769,33 @@ def export_geotiff_tiles(
2742
2769
  )
2743
2770
  # Create class mapping
2744
2771
  class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
2772
+
2773
+ # Populate COCO categories
2774
+ if metadata_format == "COCO":
2775
+ for cls_val in unique_classes:
2776
+ coco_annotations["categories"].append(
2777
+ {
2778
+ "id": class_to_id[cls_val],
2779
+ "name": str(cls_val),
2780
+ "supercategory": "object",
2781
+ }
2782
+ )
2745
2783
  else:
2746
2784
  if not quiet:
2747
2785
  print(
2748
2786
  f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
2749
2787
  )
2750
2788
  class_to_id = {1: 1} # Default mapping
2789
+
2790
+ # Populate COCO categories with default
2791
+ if metadata_format == "COCO":
2792
+ coco_annotations["categories"].append(
2793
+ {
2794
+ "id": 1,
2795
+ "name": "object",
2796
+ "supercategory": "object",
2797
+ }
2798
+ )
2751
2799
  except Exception as e:
2752
2800
  raise ValueError(f"Error processing vector data: {e}")
2753
2801
 
@@ -2964,72 +3012,186 @@ def export_geotiff_tiles(
2964
3012
  pbar.write(f"ERROR saving label GeoTIFF: {e}")
2965
3013
  stats["errors"] += 1
2966
3014
 
2967
- # Create XML annotation for object detection if using vector class data
3015
+ # Create annotations for object detection if using vector class data
2968
3016
  if (
2969
3017
  not is_class_data_raster
2970
3018
  and "gdf" in locals()
2971
3019
  and len(window_features) > 0
2972
3020
  ):
2973
- # Create XML annotation
2974
- root = ET.Element("annotation")
2975
- ET.SubElement(root, "folder").text = "images"
2976
- ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
3021
+ if metadata_format == "PASCAL_VOC":
3022
+ # Create XML annotation
3023
+ root = ET.Element("annotation")
3024
+ ET.SubElement(root, "folder").text = "images"
3025
+ ET.SubElement(root, "filename").text = (
3026
+ f"tile_{tile_index:06d}.tif"
3027
+ )
2977
3028
 
2978
- size = ET.SubElement(root, "size")
2979
- ET.SubElement(size, "width").text = str(tile_size)
2980
- ET.SubElement(size, "height").text = str(tile_size)
2981
- ET.SubElement(size, "depth").text = str(image_data.shape[0])
3029
+ size = ET.SubElement(root, "size")
3030
+ ET.SubElement(size, "width").text = str(tile_size)
3031
+ ET.SubElement(size, "height").text = str(tile_size)
3032
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
3033
+
3034
+ # Add georeference information
3035
+ geo = ET.SubElement(root, "georeference")
3036
+ ET.SubElement(geo, "crs").text = str(src.crs)
3037
+ ET.SubElement(geo, "transform").text = str(
3038
+ window_transform
3039
+ ).replace("\n", "")
3040
+ ET.SubElement(geo, "bounds").text = (
3041
+ f"{minx}, {miny}, {maxx}, {maxy}"
3042
+ )
2982
3043
 
2983
- # Add georeference information
2984
- geo = ET.SubElement(root, "georeference")
2985
- ET.SubElement(geo, "crs").text = str(src.crs)
2986
- ET.SubElement(geo, "transform").text = str(
2987
- window_transform
2988
- ).replace("\n", "")
2989
- ET.SubElement(geo, "bounds").text = (
2990
- f"{minx}, {miny}, {maxx}, {maxy}"
2991
- )
3044
+ # Add objects
3045
+ for idx, feature in window_features.iterrows():
3046
+ # Get feature class
3047
+ if class_value_field in feature:
3048
+ class_val = feature[class_value_field]
3049
+ else:
3050
+ class_val = "object"
2992
3051
 
2993
- # Add objects
2994
- for idx, feature in window_features.iterrows():
2995
- # Get feature class
2996
- if class_value_field in feature:
2997
- class_val = feature[class_value_field]
2998
- else:
2999
- class_val = "object"
3052
+ # Get geometry bounds in pixel coordinates
3053
+ geom = feature.geometry.intersection(window_bounds)
3054
+ if not geom.is_empty:
3055
+ # Get bounds in world coordinates
3056
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3057
+
3058
+ # Convert to pixel coordinates
3059
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3060
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3061
+
3062
+ # Ensure coordinates are within tile bounds
3063
+ xmin = max(0, min(tile_size, int(col_min)))
3064
+ ymin = max(0, min(tile_size, int(row_min)))
3065
+ xmax = max(0, min(tile_size, int(col_max)))
3066
+ ymax = max(0, min(tile_size, int(row_max)))
3067
+
3068
+ # Only add if the box has non-zero area
3069
+ if xmax > xmin and ymax > ymin:
3070
+ obj = ET.SubElement(root, "object")
3071
+ ET.SubElement(obj, "name").text = str(class_val)
3072
+ ET.SubElement(obj, "difficult").text = "0"
3073
+
3074
+ bbox = ET.SubElement(obj, "bndbox")
3075
+ ET.SubElement(bbox, "xmin").text = str(xmin)
3076
+ ET.SubElement(bbox, "ymin").text = str(ymin)
3077
+ ET.SubElement(bbox, "xmax").text = str(xmax)
3078
+ ET.SubElement(bbox, "ymax").text = str(ymax)
3079
+
3080
+ # Save XML
3081
+ tree = ET.ElementTree(root)
3082
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3083
+ tree.write(xml_path)
3000
3084
 
3001
- # Get geometry bounds in pixel coordinates
3002
- geom = feature.geometry.intersection(window_bounds)
3003
- if not geom.is_empty:
3004
- # Get bounds in world coordinates
3005
- minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3006
-
3007
- # Convert to pixel coordinates
3008
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
3009
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
3010
-
3011
- # Ensure coordinates are within tile bounds
3012
- xmin = max(0, min(tile_size, int(col_min)))
3013
- ymin = max(0, min(tile_size, int(row_min)))
3014
- xmax = max(0, min(tile_size, int(col_max)))
3015
- ymax = max(0, min(tile_size, int(row_max)))
3016
-
3017
- # Only add if the box has non-zero area
3018
- if xmax > xmin and ymax > ymin:
3019
- obj = ET.SubElement(root, "object")
3020
- ET.SubElement(obj, "name").text = str(class_val)
3021
- ET.SubElement(obj, "difficult").text = "0"
3022
-
3023
- bbox = ET.SubElement(obj, "bndbox")
3024
- ET.SubElement(bbox, "xmin").text = str(xmin)
3025
- ET.SubElement(bbox, "ymin").text = str(ymin)
3026
- ET.SubElement(bbox, "xmax").text = str(xmax)
3027
- ET.SubElement(bbox, "ymax").text = str(ymax)
3085
+ elif metadata_format == "COCO":
3086
+ # Add image info
3087
+ image_id = tile_index
3088
+ coco_annotations["images"].append(
3089
+ {
3090
+ "id": image_id,
3091
+ "file_name": f"tile_{tile_index:06d}.tif",
3092
+ "width": tile_size,
3093
+ "height": tile_size,
3094
+ "crs": str(src.crs),
3095
+ "transform": str(window_transform),
3096
+ }
3097
+ )
3028
3098
 
3029
- # Save XML
3030
- tree = ET.ElementTree(root)
3031
- xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
3032
- tree.write(xml_path)
3099
+ # Add annotations for each feature
3100
+ for _, feature in window_features.iterrows():
3101
+ # Get feature class
3102
+ if class_value_field in feature:
3103
+ class_val = feature[class_value_field]
3104
+ category_id = class_to_id.get(class_val, 1)
3105
+ else:
3106
+ category_id = 1
3107
+
3108
+ # Get geometry bounds
3109
+ geom = feature.geometry.intersection(window_bounds)
3110
+ if not geom.is_empty:
3111
+ # Get bounds in world coordinates
3112
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3113
+
3114
+ # Convert to pixel coordinates
3115
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3116
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3117
+
3118
+ # Ensure coordinates are within tile bounds
3119
+ xmin = max(0, min(tile_size, int(col_min)))
3120
+ ymin = max(0, min(tile_size, int(row_min)))
3121
+ xmax = max(0, min(tile_size, int(col_max)))
3122
+ ymax = max(0, min(tile_size, int(row_max)))
3123
+
3124
+ # Skip if box is too small
3125
+ if xmax - xmin < 1 or ymax - ymin < 1:
3126
+ continue
3127
+
3128
+ width = xmax - xmin
3129
+ height = ymax - ymin
3130
+
3131
+ # Add annotation
3132
+ ann_id += 1
3133
+ coco_annotations["annotations"].append(
3134
+ {
3135
+ "id": ann_id,
3136
+ "image_id": image_id,
3137
+ "category_id": category_id,
3138
+ "bbox": [xmin, ymin, width, height],
3139
+ "area": width * height,
3140
+ "iscrowd": 0,
3141
+ }
3142
+ )
3143
+
3144
+ elif metadata_format == "YOLO":
3145
+ # Create YOLO format annotations
3146
+ yolo_annotations = []
3147
+
3148
+ for _, feature in window_features.iterrows():
3149
+ # Get feature class
3150
+ if class_value_field in feature:
3151
+ class_val = feature[class_value_field]
3152
+ # YOLO uses 0-indexed class IDs
3153
+ class_id = class_to_id.get(class_val, 1) - 1
3154
+ else:
3155
+ class_id = 0
3156
+
3157
+ # Get geometry bounds
3158
+ geom = feature.geometry.intersection(window_bounds)
3159
+ if not geom.is_empty:
3160
+ # Get bounds in world coordinates
3161
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
3162
+
3163
+ # Convert to pixel coordinates
3164
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
3165
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
3166
+
3167
+ # Ensure coordinates are within tile bounds
3168
+ xmin = max(0, min(tile_size, col_min))
3169
+ ymin = max(0, min(tile_size, row_min))
3170
+ xmax = max(0, min(tile_size, col_max))
3171
+ ymax = max(0, min(tile_size, row_max))
3172
+
3173
+ # Skip if box is too small
3174
+ if xmax - xmin < 1 or ymax - ymin < 1:
3175
+ continue
3176
+
3177
+ # Calculate normalized coordinates (YOLO format)
3178
+ x_center = ((xmin + xmax) / 2) / tile_size
3179
+ y_center = ((ymin + ymax) / 2) / tile_size
3180
+ width = (xmax - xmin) / tile_size
3181
+ height = (ymax - ymin) / tile_size
3182
+
3183
+ # Add YOLO annotation line
3184
+ yolo_annotations.append(
3185
+ f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
3186
+ )
3187
+
3188
+ # Save YOLO annotations to text file
3189
+ if yolo_annotations:
3190
+ yolo_path = os.path.join(
3191
+ label_dir, f"tile_{tile_index:06d}.txt"
3192
+ )
3193
+ with open(yolo_path, "w") as f:
3194
+ f.write("\n".join(yolo_annotations))
3033
3195
 
3034
3196
  # Update progress bar
3035
3197
  pbar.update(1)
@@ -3047,6 +3209,39 @@ def export_geotiff_tiles(
3047
3209
  # Close progress bar
3048
3210
  pbar.close()
3049
3211
 
3212
+ # Save COCO annotations if applicable
3213
+ if metadata_format == "COCO":
3214
+ try:
3215
+ with open(os.path.join(ann_dir, "instances.json"), "w") as f:
3216
+ json.dump(coco_annotations, f, indent=2)
3217
+ if not quiet:
3218
+ print(
3219
+ f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
3220
+ f"{len(coco_annotations['annotations'])} annotations, "
3221
+ f"{len(coco_annotations['categories'])} categories"
3222
+ )
3223
+ except Exception as e:
3224
+ if not quiet:
3225
+ print(f"ERROR saving COCO annotations: {e}")
3226
+ stats["errors"] += 1
3227
+
3228
+ # Save YOLO classes file if applicable
3229
+ if metadata_format == "YOLO":
3230
+ try:
3231
+ # Create classes.txt with class names
3232
+ classes_path = os.path.join(out_folder, "classes.txt")
3233
+ # Sort by class ID to ensure correct order
3234
+ sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
3235
+ with open(classes_path, "w") as f:
3236
+ for class_val, _ in sorted_classes:
3237
+ f.write(f"{class_val}\n")
3238
+ if not quiet:
3239
+ print(f"Saved YOLO classes file with {len(class_to_id)} classes")
3240
+ except Exception as e:
3241
+ if not quiet:
3242
+ print(f"ERROR saving YOLO classes file: {e}")
3243
+ stats["errors"] += 1
3244
+
3050
3245
  # Create overview image if requested
3051
3246
  if create_overview and stats["tile_coordinates"]:
3052
3247
  try:
@@ -3125,11 +3320,11 @@ def export_geotiff_tiles_batch(
3125
3320
  max_tiles=None,
3126
3321
  quiet=False,
3127
3322
  all_touched=True,
3128
- create_overview=False,
3129
3323
  skip_empty_tiles=False,
3130
3324
  image_extensions=None,
3131
3325
  mask_extensions=None,
3132
3326
  match_by_name=True,
3327
+ metadata_format="PASCAL_VOC",
3133
3328
  ) -> Dict[str, Any]:
3134
3329
  """
3135
3330
  Export georeferenced GeoTIFF tiles from images and masks.
@@ -3170,6 +3365,8 @@ def export_geotiff_tiles_batch(
3170
3365
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3171
3366
  match_by_name (bool): If True, match image and mask files by base filename.
3172
3367
  If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
3368
+ metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
3369
+ Default is "PASCAL_VOC".
3173
3370
 
3174
3371
  Returns:
3175
3372
  Dict[str, Any]: Dictionary containing batch processing statistics
@@ -3253,6 +3450,19 @@ def export_geotiff_tiles_batch(
3253
3450
  os.makedirs(output_images_dir, exist_ok=True)
3254
3451
  os.makedirs(output_masks_dir, exist_ok=True)
3255
3452
 
3453
+ # Create annotation directory based on metadata format
3454
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3455
+ ann_dir = os.path.join(output_folder, "annotations")
3456
+ os.makedirs(ann_dir, exist_ok=True)
3457
+
3458
+ # Initialize COCO annotations dictionary
3459
+ coco_annotations = None
3460
+ if metadata_format == "COCO":
3461
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
3462
+
3463
+ # Initialize YOLO class set
3464
+ yolo_classes = set() if metadata_format == "YOLO" else None
3465
+
3256
3466
  # Get list of image files
3257
3467
  image_files = []
3258
3468
  for ext in image_extensions:
@@ -3406,6 +3616,13 @@ def export_geotiff_tiles_batch(
3406
3616
  quiet=quiet,
3407
3617
  mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
3408
3618
  use_single_mask_file=use_single_mask_file,
3619
+ metadata_format=metadata_format,
3620
+ ann_dir=(
3621
+ ann_dir
3622
+ if "ann_dir" in locals()
3623
+ and metadata_format in ["PASCAL_VOC", "COCO"]
3624
+ else None
3625
+ ),
3409
3626
  )
3410
3627
 
3411
3628
  # Update counters
@@ -3427,6 +3644,23 @@ def export_geotiff_tiles_batch(
3427
3644
  }
3428
3645
  )
3429
3646
 
3647
+ # Aggregate COCO annotations
3648
+ if metadata_format == "COCO" and "coco_data" in tiles_generated:
3649
+ coco_data = tiles_generated["coco_data"]
3650
+ # Add images and annotations
3651
+ coco_annotations["images"].extend(coco_data.get("images", []))
3652
+ coco_annotations["annotations"].extend(coco_data.get("annotations", []))
3653
+ # Merge categories (avoid duplicates)
3654
+ for cat in coco_data.get("categories", []):
3655
+ if not any(
3656
+ c["id"] == cat["id"] for c in coco_annotations["categories"]
3657
+ ):
3658
+ coco_annotations["categories"].append(cat)
3659
+
3660
+ # Aggregate YOLO classes
3661
+ if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
3662
+ yolo_classes.update(tiles_generated["yolo_classes"])
3663
+
3430
3664
  except Exception as e:
3431
3665
  if not quiet:
3432
3666
  print(f"ERROR processing {base_name}: {e}")
@@ -3435,6 +3669,33 @@ def export_geotiff_tiles_batch(
3435
3669
  )
3436
3670
  batch_stats["errors"] += 1
3437
3671
 
3672
+ # Save aggregated COCO annotations
3673
+ if metadata_format == "COCO" and coco_annotations:
3674
+ import json
3675
+
3676
+ coco_path = os.path.join(ann_dir, "instances.json")
3677
+ with open(coco_path, "w") as f:
3678
+ json.dump(coco_annotations, f, indent=2)
3679
+ if not quiet:
3680
+ print(f"\nSaved COCO annotations: {coco_path}")
3681
+ print(
3682
+ f" Images: {len(coco_annotations['images'])}, "
3683
+ f"Annotations: {len(coco_annotations['annotations'])}, "
3684
+ f"Categories: {len(coco_annotations['categories'])}"
3685
+ )
3686
+
3687
+ # Save aggregated YOLO classes
3688
+ if metadata_format == "YOLO" and yolo_classes:
3689
+ classes_path = os.path.join(output_folder, "labels", "classes.txt")
3690
+ os.makedirs(os.path.dirname(classes_path), exist_ok=True)
3691
+ sorted_classes = sorted(yolo_classes)
3692
+ with open(classes_path, "w") as f:
3693
+ for cls in sorted_classes:
3694
+ f.write(f"{cls}\n")
3695
+ if not quiet:
3696
+ print(f"\nSaved YOLO classes: {classes_path}")
3697
+ print(f" Total classes: {len(sorted_classes)}")
3698
+
3438
3699
  # Print batch summary
3439
3700
  if not quiet:
3440
3701
  print("\n" + "=" * 60)
@@ -3458,6 +3719,10 @@ def export_geotiff_tiles_batch(
3458
3719
  print(f"Output saved to: {output_folder}")
3459
3720
  print(f" Images: {output_images_dir}")
3460
3721
  print(f" Masks: {output_masks_dir}")
3722
+ if metadata_format in ["PASCAL_VOC", "COCO"]:
3723
+ print(f" Annotations: {ann_dir}")
3724
+ elif metadata_format == "YOLO":
3725
+ print(f" Labels: {os.path.join(output_folder, 'labels')}")
3461
3726
 
3462
3727
  # List failed files if any
3463
3728
  if batch_stats["failed_files"]:
@@ -3485,6 +3750,8 @@ def _process_image_mask_pair(
3485
3750
  quiet=False,
3486
3751
  mask_gdf=None,
3487
3752
  use_single_mask_file=False,
3753
+ metadata_format="PASCAL_VOC",
3754
+ ann_dir=None,
3488
3755
  ):
3489
3756
  """
3490
3757
  Process a single image-mask pair and save tiles directly to output directories.
@@ -3517,6 +3784,13 @@ def _process_image_mask_pair(
3517
3784
  "errors": 0,
3518
3785
  }
3519
3786
 
3787
+ # Initialize COCO/YOLO tracking for this image
3788
+ if metadata_format == "COCO":
3789
+ stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
3790
+ coco_ann_id = 0
3791
+ if metadata_format == "YOLO":
3792
+ stats["yolo_classes"] = set()
3793
+
3520
3794
  # Open the input raster
3521
3795
  with rasterio.open(image_file) as src:
3522
3796
  # Calculate number of tiles
@@ -3609,9 +3883,6 @@ def _process_image_mask_pair(
3609
3883
  tile_index = 0
3610
3884
  for y in range(num_tiles_y):
3611
3885
  for x in range(num_tiles_x):
3612
- if tile_index >= max_tiles:
3613
- break
3614
-
3615
3886
  # Calculate window coordinates
3616
3887
  window_x = x * stride
3617
3888
  window_y = y * stride
@@ -3714,9 +3985,12 @@ def _process_image_mask_pair(
3714
3985
 
3715
3986
  # Skip tile if no features and skip_empty_tiles is True
3716
3987
  if skip_empty_tiles and not has_features:
3717
- tile_index += 1
3718
3988
  continue
3719
3989
 
3990
+ # Check if we've reached max_tiles before saving
3991
+ if tile_index >= max_tiles:
3992
+ break
3993
+
3720
3994
  # Generate unique tile name
3721
3995
  tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
3722
3996
 
@@ -3771,6 +4045,197 @@ def _process_image_mask_pair(
3771
4045
  print(f"ERROR saving label GeoTIFF: {e}")
3772
4046
  stats["errors"] += 1
3773
4047
 
4048
+ # Generate annotation metadata based on format
4049
+ if metadata_format == "PASCAL_VOC" and ann_dir:
4050
+ # Create PASCAL VOC XML annotation
4051
+ from lxml import etree as ET
4052
+
4053
+ annotation = ET.Element("annotation")
4054
+ ET.SubElement(annotation, "folder").text = os.path.basename(
4055
+ output_images_dir
4056
+ )
4057
+ ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
4058
+ ET.SubElement(annotation, "path").text = image_path
4059
+
4060
+ source = ET.SubElement(annotation, "source")
4061
+ ET.SubElement(source, "database").text = "GeoAI"
4062
+
4063
+ size = ET.SubElement(annotation, "size")
4064
+ ET.SubElement(size, "width").text = str(tile_size)
4065
+ ET.SubElement(size, "height").text = str(tile_size)
4066
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
4067
+
4068
+ ET.SubElement(annotation, "segmented").text = "1"
4069
+
4070
+ # Find connected components for instance segmentation
4071
+ from scipy import ndimage
4072
+
4073
+ for class_id in np.unique(label_mask):
4074
+ if class_id == 0:
4075
+ continue
4076
+
4077
+ class_mask = (label_mask == class_id).astype(np.uint8)
4078
+ labeled_array, num_features = ndimage.label(class_mask)
4079
+
4080
+ for instance_id in range(1, num_features + 1):
4081
+ instance_mask = labeled_array == instance_id
4082
+ coords = np.argwhere(instance_mask)
4083
+
4084
+ if len(coords) == 0:
4085
+ continue
4086
+
4087
+ ymin, xmin = coords.min(axis=0)
4088
+ ymax, xmax = coords.max(axis=0)
4089
+
4090
+ obj = ET.SubElement(annotation, "object")
4091
+ class_name = next(
4092
+ (k for k, v in class_to_id.items() if v == class_id),
4093
+ str(class_id),
4094
+ )
4095
+ ET.SubElement(obj, "name").text = str(class_name)
4096
+ ET.SubElement(obj, "pose").text = "Unspecified"
4097
+ ET.SubElement(obj, "truncated").text = "0"
4098
+ ET.SubElement(obj, "difficult").text = "0"
4099
+
4100
+ bndbox = ET.SubElement(obj, "bndbox")
4101
+ ET.SubElement(bndbox, "xmin").text = str(int(xmin))
4102
+ ET.SubElement(bndbox, "ymin").text = str(int(ymin))
4103
+ ET.SubElement(bndbox, "xmax").text = str(int(xmax))
4104
+ ET.SubElement(bndbox, "ymax").text = str(int(ymax))
4105
+
4106
+ # Save XML file
4107
+ xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
4108
+ tree = ET.ElementTree(annotation)
4109
+ tree.write(xml_path, pretty_print=True, encoding="utf-8")
4110
+
4111
+ elif metadata_format == "COCO":
4112
+ # Add COCO image entry
4113
+ image_id = int(global_tile_counter + tile_index)
4114
+ stats["coco_data"]["images"].append(
4115
+ {
4116
+ "id": image_id,
4117
+ "file_name": f"{tile_name}.tif",
4118
+ "width": int(tile_size),
4119
+ "height": int(tile_size),
4120
+ }
4121
+ )
4122
+
4123
+ # Add COCO categories (only once per unique class)
4124
+ for class_val, class_id in class_to_id.items():
4125
+ if not any(
4126
+ c["id"] == class_id
4127
+ for c in stats["coco_data"]["categories"]
4128
+ ):
4129
+ stats["coco_data"]["categories"].append(
4130
+ {
4131
+ "id": int(class_id),
4132
+ "name": str(class_val),
4133
+ "supercategory": "object",
4134
+ }
4135
+ )
4136
+
4137
+ # Add COCO annotations (instance segmentation)
4138
+ from scipy import ndimage
4139
+ from skimage import measure
4140
+
4141
+ for class_id in np.unique(label_mask):
4142
+ if class_id == 0:
4143
+ continue
4144
+
4145
+ class_mask = (label_mask == class_id).astype(np.uint8)
4146
+ labeled_array, num_features = ndimage.label(class_mask)
4147
+
4148
+ for instance_id in range(1, num_features + 1):
4149
+ instance_mask = (labeled_array == instance_id).astype(
4150
+ np.uint8
4151
+ )
4152
+ coords = np.argwhere(instance_mask)
4153
+
4154
+ if len(coords) == 0:
4155
+ continue
4156
+
4157
+ ymin, xmin = coords.min(axis=0)
4158
+ ymax, xmax = coords.max(axis=0)
4159
+
4160
+ bbox = [
4161
+ int(xmin),
4162
+ int(ymin),
4163
+ int(xmax - xmin),
4164
+ int(ymax - ymin),
4165
+ ]
4166
+ area = int(np.sum(instance_mask))
4167
+
4168
+ # Find contours for segmentation
4169
+ contours = measure.find_contours(instance_mask, 0.5)
4170
+ segmentation = []
4171
+ for contour in contours:
4172
+ contour = np.flip(contour, axis=1)
4173
+ segmentation_points = contour.ravel().tolist()
4174
+ if len(segmentation_points) >= 6:
4175
+ segmentation.append(segmentation_points)
4176
+
4177
+ if segmentation:
4178
+ stats["coco_data"]["annotations"].append(
4179
+ {
4180
+ "id": int(coco_ann_id),
4181
+ "image_id": int(image_id),
4182
+ "category_id": int(class_id),
4183
+ "bbox": bbox,
4184
+ "area": area,
4185
+ "segmentation": segmentation,
4186
+ "iscrowd": 0,
4187
+ }
4188
+ )
4189
+ coco_ann_id += 1
4190
+
4191
+ elif metadata_format == "YOLO":
4192
+ # Create YOLO labels directory if needed
4193
+ labels_dir = os.path.join(
4194
+ os.path.dirname(output_images_dir), "labels"
4195
+ )
4196
+ os.makedirs(labels_dir, exist_ok=True)
4197
+
4198
+ # Generate YOLO annotation file
4199
+ yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
4200
+ from scipy import ndimage
4201
+
4202
+ with open(yolo_path, "w") as yolo_file:
4203
+ for class_id in np.unique(label_mask):
4204
+ if class_id == 0:
4205
+ continue
4206
+
4207
+ # Track class for classes.txt
4208
+ class_name = next(
4209
+ (k for k, v in class_to_id.items() if v == class_id),
4210
+ str(class_id),
4211
+ )
4212
+ stats["yolo_classes"].add(class_name)
4213
+
4214
+ class_mask = (label_mask == class_id).astype(np.uint8)
4215
+ labeled_array, num_features = ndimage.label(class_mask)
4216
+
4217
+ for instance_id in range(1, num_features + 1):
4218
+ instance_mask = labeled_array == instance_id
4219
+ coords = np.argwhere(instance_mask)
4220
+
4221
+ if len(coords) == 0:
4222
+ continue
4223
+
4224
+ ymin, xmin = coords.min(axis=0)
4225
+ ymax, xmax = coords.max(axis=0)
4226
+
4227
+ # Convert to YOLO format (normalized center coordinates)
4228
+ x_center = ((xmin + xmax) / 2) / tile_size
4229
+ y_center = ((ymin + ymax) / 2) / tile_size
4230
+ width = (xmax - xmin) / tile_size
4231
+ height = (ymax - ymin) / tile_size
4232
+
4233
+ # YOLO uses 0-based class indices
4234
+ yolo_class_id = class_id - 1
4235
+ yolo_file.write(
4236
+ f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
4237
+ )
4238
+
3774
4239
  tile_index += 1
3775
4240
  if tile_index >= max_tiles:
3776
4241
  break
@@ -3781,6 +4246,179 @@ def _process_image_mask_pair(
3781
4246
  return stats
3782
4247
 
3783
4248
 
4249
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    Image tiles are read from ``<output_dir>/images`` in sorted file-name order.
    For each tile, the mask is looked up under the same file name in
    ``<output_dir>/masks``. Hidden files, sub-directories and GDAL sidecar
    files (``.aux.xml``, ``.ovr``, ``.msk``) are ignored so they are never
    handed to ``rasterio.open``.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Number of tile pairs to display (default: 6). Must be >= 1.
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and a 2 x N array of axes objects
            (row 0 = images, row 1 = masks).

    Raises:
        ValueError: If ``num_tiles`` is not positive, the images directory does
            not exist, or it contains no tile files.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    import matplotlib.pyplot as plt

    # A negative value would otherwise slice "all but the last k" files below
    # and then crash inside plt.subplots with a confusing error.
    if num_tiles < 1:
        raise ValueError(f"num_tiles must be a positive integer, got {num_tiles}")

    images_dir = os.path.join(output_dir, "images")
    masks_dir = os.path.join(output_dir, "masks")
    if not os.path.exists(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Auxiliary files GDAL may write next to a raster; they are not tiles.
    sidecar_suffixes = (".aux.xml", ".ovr", ".msk")
    image_tiles = sorted(
        name
        for name in os.listdir(images_dir)
        if not name.startswith(".")
        and not name.lower().endswith(sidecar_suffixes)
        and os.path.isfile(os.path.join(images_dir, name))
    )[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Limit to available tiles
    num_tiles = len(image_tiles)

    # squeeze=False keeps `axes` 2-D with shape (2, num_tiles), even for a single tile.
    fig, axes = plt.subplots(2, num_tiles, figsize=figsize, squeeze=False)

    for idx, tile_name in enumerate(image_tiles):
        # Load and display image tile
        image_path = os.path.join(images_dir, tile_name)
        with rasterio.open(image_path) as src:
            show(src, ax=axes[0, idx], title=f"Image {idx+1}")

        # Load and display the mask tile that shares the image tile's file name
        mask_path = os.path.join(masks_dir, tile_name)
        if os.path.exists(mask_path):
            with rasterio.open(mask_path) as src:
                show(src, ax=axes[1, idx], title=f"Mask {idx+1}", cmap=cmap)
        else:
            axes[1, idx].text(
                0.5,
                0.5,
                "Mask not found",
                ha="center",
                va="center",
                transform=axes[1, idx].transAxes,
            )
            axes[1, idx].set_title(f"Mask {idx+1}")

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
4329
+
4330
+
4331
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Display a raster image alongside the same image with vector overlay.

    The vector layer is reprojected to the raster's CRS when both CRSs are
    known and differ. If the vector file has no CRS, it is assumed to already
    be in the raster's CRS and a warning is emitted. If the raster has no CRS,
    the vectors are drawn unprojected and a warning is emitted.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where axes is ``(ax1, ax2)`` and info_dict
            contains image and vector metadata (``image_shape``, ``image_crs``,
            ``image_bounds``, ``num_features``, ``vector_crs``, ``vector_bounds``).

    Raises:
        ValueError: If either input path does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    import warnings

    import matplotlib.pyplot as plt

    # Validate inputs
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    with rasterio.open(image_path) as src:
        # Left panel: image only
        show(src, ax=ax1, title="Image")

        vector_data = gpd.read_file(vector_path)

        # Align the vector layer with the raster. A bare `crs != src.crs`
        # check would call to_crs() on naive geometries (vector CRS is None)
        # or with a None target (raster CRS is None), and both raise.
        if vector_data.crs is None:
            if src.crs is not None:
                warnings.warn(
                    "Vector data has no CRS; assuming it matches the image CRS.",
                    stacklevel=2,
                )
                vector_data = vector_data.set_crs(src.crs)
        elif src.crs is None:
            warnings.warn(
                "Image has no CRS; vector data is plotted without reprojection.",
                stacklevel=2,
            )
        elif vector_data.crs != src.crs:
            vector_data = vector_data.to_crs(src.crs)

        # Right panel: image with vector overlay
        show(
            src,
            ax=ax2,
            title=f"Image with {len(vector_data)} Vector Features",
        )
        vector_data.plot(
            ax=ax2,
            facecolor=vector_facecolor,
            edgecolor=vector_color,
            linewidth=vector_linewidth,
        )

        # Collect metadata while the dataset is still open
        info = {
            "image_shape": src.shape,
            "image_crs": src.crs,
            "image_bounds": src.bounds,
            "num_features": len(vector_data),
            "vector_crs": vector_data.crs,
            "vector_bounds": vector_data.total_bounds,
        }

    plt.tight_layout()

    # Save or show
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax1, ax2), info
4420
+
4421
+
3784
4422
  def create_overview_image(
3785
4423
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
3786
4424
  ) -> str: