geoai-py 0.13.2__py2.py3-none-any.whl → 0.15.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +21 -1
- geoai/timm_segment.py +1097 -0
- geoai/timm_train.py +658 -0
- geoai/train.py +224 -77
- geoai/utils.py +893 -103
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/METADATA +16 -5
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/RECORD +11 -9
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/licenses/LICENSE +1 -2
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.13.2.dist-info → geoai_py-0.15.0.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
@@ -751,6 +751,12 @@ def view_vector_interactive(
|
|
751
751
|
},
|
752
752
|
}
|
753
753
|
|
754
|
+
# Make it compatible with binder and JupyterHub
|
755
|
+
if os.environ.get("JUPYTERHUB_SERVICE_PREFIX") is not None:
|
756
|
+
os.environ["LOCALTILESERVER_CLIENT_PREFIX"] = (
|
757
|
+
f"{os.environ['JUPYTERHUB_SERVICE_PREFIX'].lstrip('/')}/proxy/{{port}}"
|
758
|
+
)
|
759
|
+
|
754
760
|
basemap_layer_name = None
|
755
761
|
raster_layer = None
|
756
762
|
|
@@ -2609,6 +2615,7 @@ def export_geotiff_tiles(
|
|
2609
2615
|
all_touched=True,
|
2610
2616
|
create_overview=False,
|
2611
2617
|
skip_empty_tiles=False,
|
2618
|
+
metadata_format="PASCAL_VOC",
|
2612
2619
|
):
|
2613
2620
|
"""
|
2614
2621
|
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
@@ -2626,6 +2633,7 @@ def export_geotiff_tiles(
|
|
2626
2633
|
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
2627
2634
|
create_overview (bool): Whether to create an overview image of all tiles
|
2628
2635
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
2636
|
+
metadata_format (str): Output metadata format (PASCAL_VOC, COCO, YOLO). Default: PASCAL_VOC
|
2629
2637
|
"""
|
2630
2638
|
|
2631
2639
|
import logging
|
@@ -2638,8 +2646,16 @@ def export_geotiff_tiles(
|
|
2638
2646
|
os.makedirs(image_dir, exist_ok=True)
|
2639
2647
|
label_dir = os.path.join(out_folder, "labels")
|
2640
2648
|
os.makedirs(label_dir, exist_ok=True)
|
2641
|
-
|
2642
|
-
|
2649
|
+
|
2650
|
+
# Create annotation directory based on metadata format
|
2651
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
2652
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
2653
|
+
os.makedirs(ann_dir, exist_ok=True)
|
2654
|
+
|
2655
|
+
# Initialize COCO annotations dictionary
|
2656
|
+
if metadata_format == "COCO":
|
2657
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
2658
|
+
ann_id = 0
|
2643
2659
|
|
2644
2660
|
# Determine if class data is raster or vector
|
2645
2661
|
is_class_data_raster = False
|
@@ -2713,6 +2729,17 @@ def export_geotiff_tiles(
|
|
2713
2729
|
|
2714
2730
|
# Create class mapping
|
2715
2731
|
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
2732
|
+
|
2733
|
+
# Populate COCO categories
|
2734
|
+
if metadata_format == "COCO":
|
2735
|
+
for cls_val in unique_classes:
|
2736
|
+
coco_annotations["categories"].append(
|
2737
|
+
{
|
2738
|
+
"id": class_to_id[int(cls_val)],
|
2739
|
+
"name": str(int(cls_val)),
|
2740
|
+
"supercategory": "object",
|
2741
|
+
}
|
2742
|
+
)
|
2716
2743
|
else:
|
2717
2744
|
# Load vector class data
|
2718
2745
|
try:
|
@@ -2742,12 +2769,33 @@ def export_geotiff_tiles(
|
|
2742
2769
|
)
|
2743
2770
|
# Create class mapping
|
2744
2771
|
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
2772
|
+
|
2773
|
+
# Populate COCO categories
|
2774
|
+
if metadata_format == "COCO":
|
2775
|
+
for cls_val in unique_classes:
|
2776
|
+
coco_annotations["categories"].append(
|
2777
|
+
{
|
2778
|
+
"id": class_to_id[cls_val],
|
2779
|
+
"name": str(cls_val),
|
2780
|
+
"supercategory": "object",
|
2781
|
+
}
|
2782
|
+
)
|
2745
2783
|
else:
|
2746
2784
|
if not quiet:
|
2747
2785
|
print(
|
2748
2786
|
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
2749
2787
|
)
|
2750
2788
|
class_to_id = {1: 1} # Default mapping
|
2789
|
+
|
2790
|
+
# Populate COCO categories with default
|
2791
|
+
if metadata_format == "COCO":
|
2792
|
+
coco_annotations["categories"].append(
|
2793
|
+
{
|
2794
|
+
"id": 1,
|
2795
|
+
"name": "object",
|
2796
|
+
"supercategory": "object",
|
2797
|
+
}
|
2798
|
+
)
|
2751
2799
|
except Exception as e:
|
2752
2800
|
raise ValueError(f"Error processing vector data: {e}")
|
2753
2801
|
|
@@ -2964,72 +3012,186 @@ def export_geotiff_tiles(
|
|
2964
3012
|
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
2965
3013
|
stats["errors"] += 1
|
2966
3014
|
|
2967
|
-
# Create
|
3015
|
+
# Create annotations for object detection if using vector class data
|
2968
3016
|
if (
|
2969
3017
|
not is_class_data_raster
|
2970
3018
|
and "gdf" in locals()
|
2971
3019
|
and len(window_features) > 0
|
2972
3020
|
):
|
2973
|
-
|
2974
|
-
|
2975
|
-
|
2976
|
-
|
3021
|
+
if metadata_format == "PASCAL_VOC":
|
3022
|
+
# Create XML annotation
|
3023
|
+
root = ET.Element("annotation")
|
3024
|
+
ET.SubElement(root, "folder").text = "images"
|
3025
|
+
ET.SubElement(root, "filename").text = (
|
3026
|
+
f"tile_{tile_index:06d}.tif"
|
3027
|
+
)
|
2977
3028
|
|
2978
|
-
|
2979
|
-
|
2980
|
-
|
2981
|
-
|
3029
|
+
size = ET.SubElement(root, "size")
|
3030
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
3031
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
3032
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
3033
|
+
|
3034
|
+
# Add georeference information
|
3035
|
+
geo = ET.SubElement(root, "georeference")
|
3036
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
3037
|
+
ET.SubElement(geo, "transform").text = str(
|
3038
|
+
window_transform
|
3039
|
+
).replace("\n", "")
|
3040
|
+
ET.SubElement(geo, "bounds").text = (
|
3041
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
3042
|
+
)
|
2982
3043
|
|
2983
|
-
|
2984
|
-
|
2985
|
-
|
2986
|
-
|
2987
|
-
|
2988
|
-
|
2989
|
-
|
2990
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
2991
|
-
)
|
3044
|
+
# Add objects
|
3045
|
+
for idx, feature in window_features.iterrows():
|
3046
|
+
# Get feature class
|
3047
|
+
if class_value_field in feature:
|
3048
|
+
class_val = feature[class_value_field]
|
3049
|
+
else:
|
3050
|
+
class_val = "object"
|
2992
3051
|
|
2993
|
-
|
2994
|
-
|
2995
|
-
|
2996
|
-
|
2997
|
-
|
2998
|
-
|
2999
|
-
|
3052
|
+
# Get geometry bounds in pixel coordinates
|
3053
|
+
geom = feature.geometry.intersection(window_bounds)
|
3054
|
+
if not geom.is_empty:
|
3055
|
+
# Get bounds in world coordinates
|
3056
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3057
|
+
|
3058
|
+
# Convert to pixel coordinates
|
3059
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3060
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3061
|
+
|
3062
|
+
# Ensure coordinates are within tile bounds
|
3063
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
3064
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
3065
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
3066
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
3067
|
+
|
3068
|
+
# Only add if the box has non-zero area
|
3069
|
+
if xmax > xmin and ymax > ymin:
|
3070
|
+
obj = ET.SubElement(root, "object")
|
3071
|
+
ET.SubElement(obj, "name").text = str(class_val)
|
3072
|
+
ET.SubElement(obj, "difficult").text = "0"
|
3073
|
+
|
3074
|
+
bbox = ET.SubElement(obj, "bndbox")
|
3075
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
3076
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
3077
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
3078
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
3079
|
+
|
3080
|
+
# Save XML
|
3081
|
+
tree = ET.ElementTree(root)
|
3082
|
+
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
3083
|
+
tree.write(xml_path)
|
3000
3084
|
|
3001
|
-
|
3002
|
-
|
3003
|
-
|
3004
|
-
|
3005
|
-
|
3006
|
-
|
3007
|
-
|
3008
|
-
|
3009
|
-
|
3010
|
-
|
3011
|
-
|
3012
|
-
|
3013
|
-
|
3014
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
3015
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
3016
|
-
|
3017
|
-
# Only add if the box has non-zero area
|
3018
|
-
if xmax > xmin and ymax > ymin:
|
3019
|
-
obj = ET.SubElement(root, "object")
|
3020
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
3021
|
-
ET.SubElement(obj, "difficult").text = "0"
|
3022
|
-
|
3023
|
-
bbox = ET.SubElement(obj, "bndbox")
|
3024
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
3025
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
3026
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
3027
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
3085
|
+
elif metadata_format == "COCO":
|
3086
|
+
# Add image info
|
3087
|
+
image_id = tile_index
|
3088
|
+
coco_annotations["images"].append(
|
3089
|
+
{
|
3090
|
+
"id": image_id,
|
3091
|
+
"file_name": f"tile_{tile_index:06d}.tif",
|
3092
|
+
"width": tile_size,
|
3093
|
+
"height": tile_size,
|
3094
|
+
"crs": str(src.crs),
|
3095
|
+
"transform": str(window_transform),
|
3096
|
+
}
|
3097
|
+
)
|
3028
3098
|
|
3029
|
-
|
3030
|
-
|
3031
|
-
|
3032
|
-
|
3099
|
+
# Add annotations for each feature
|
3100
|
+
for _, feature in window_features.iterrows():
|
3101
|
+
# Get feature class
|
3102
|
+
if class_value_field in feature:
|
3103
|
+
class_val = feature[class_value_field]
|
3104
|
+
category_id = class_to_id.get(class_val, 1)
|
3105
|
+
else:
|
3106
|
+
category_id = 1
|
3107
|
+
|
3108
|
+
# Get geometry bounds
|
3109
|
+
geom = feature.geometry.intersection(window_bounds)
|
3110
|
+
if not geom.is_empty:
|
3111
|
+
# Get bounds in world coordinates
|
3112
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3113
|
+
|
3114
|
+
# Convert to pixel coordinates
|
3115
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3116
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3117
|
+
|
3118
|
+
# Ensure coordinates are within tile bounds
|
3119
|
+
xmin = max(0, min(tile_size, int(col_min)))
|
3120
|
+
ymin = max(0, min(tile_size, int(row_min)))
|
3121
|
+
xmax = max(0, min(tile_size, int(col_max)))
|
3122
|
+
ymax = max(0, min(tile_size, int(row_max)))
|
3123
|
+
|
3124
|
+
# Skip if box is too small
|
3125
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
3126
|
+
continue
|
3127
|
+
|
3128
|
+
width = xmax - xmin
|
3129
|
+
height = ymax - ymin
|
3130
|
+
|
3131
|
+
# Add annotation
|
3132
|
+
ann_id += 1
|
3133
|
+
coco_annotations["annotations"].append(
|
3134
|
+
{
|
3135
|
+
"id": ann_id,
|
3136
|
+
"image_id": image_id,
|
3137
|
+
"category_id": category_id,
|
3138
|
+
"bbox": [xmin, ymin, width, height],
|
3139
|
+
"area": width * height,
|
3140
|
+
"iscrowd": 0,
|
3141
|
+
}
|
3142
|
+
)
|
3143
|
+
|
3144
|
+
elif metadata_format == "YOLO":
|
3145
|
+
# Create YOLO format annotations
|
3146
|
+
yolo_annotations = []
|
3147
|
+
|
3148
|
+
for _, feature in window_features.iterrows():
|
3149
|
+
# Get feature class
|
3150
|
+
if class_value_field in feature:
|
3151
|
+
class_val = feature[class_value_field]
|
3152
|
+
# YOLO uses 0-indexed class IDs
|
3153
|
+
class_id = class_to_id.get(class_val, 1) - 1
|
3154
|
+
else:
|
3155
|
+
class_id = 0
|
3156
|
+
|
3157
|
+
# Get geometry bounds
|
3158
|
+
geom = feature.geometry.intersection(window_bounds)
|
3159
|
+
if not geom.is_empty:
|
3160
|
+
# Get bounds in world coordinates
|
3161
|
+
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
3162
|
+
|
3163
|
+
# Convert to pixel coordinates
|
3164
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
3165
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
3166
|
+
|
3167
|
+
# Ensure coordinates are within tile bounds
|
3168
|
+
xmin = max(0, min(tile_size, col_min))
|
3169
|
+
ymin = max(0, min(tile_size, row_min))
|
3170
|
+
xmax = max(0, min(tile_size, col_max))
|
3171
|
+
ymax = max(0, min(tile_size, row_max))
|
3172
|
+
|
3173
|
+
# Skip if box is too small
|
3174
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
3175
|
+
continue
|
3176
|
+
|
3177
|
+
# Calculate normalized coordinates (YOLO format)
|
3178
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
3179
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
3180
|
+
width = (xmax - xmin) / tile_size
|
3181
|
+
height = (ymax - ymin) / tile_size
|
3182
|
+
|
3183
|
+
# Add YOLO annotation line
|
3184
|
+
yolo_annotations.append(
|
3185
|
+
f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
3186
|
+
)
|
3187
|
+
|
3188
|
+
# Save YOLO annotations to text file
|
3189
|
+
if yolo_annotations:
|
3190
|
+
yolo_path = os.path.join(
|
3191
|
+
label_dir, f"tile_{tile_index:06d}.txt"
|
3192
|
+
)
|
3193
|
+
with open(yolo_path, "w") as f:
|
3194
|
+
f.write("\n".join(yolo_annotations))
|
3033
3195
|
|
3034
3196
|
# Update progress bar
|
3035
3197
|
pbar.update(1)
|
@@ -3047,6 +3209,39 @@ def export_geotiff_tiles(
|
|
3047
3209
|
# Close progress bar
|
3048
3210
|
pbar.close()
|
3049
3211
|
|
3212
|
+
# Save COCO annotations if applicable
|
3213
|
+
if metadata_format == "COCO":
|
3214
|
+
try:
|
3215
|
+
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
3216
|
+
json.dump(coco_annotations, f, indent=2)
|
3217
|
+
if not quiet:
|
3218
|
+
print(
|
3219
|
+
f"Saved COCO annotations: {len(coco_annotations['images'])} images, "
|
3220
|
+
f"{len(coco_annotations['annotations'])} annotations, "
|
3221
|
+
f"{len(coco_annotations['categories'])} categories"
|
3222
|
+
)
|
3223
|
+
except Exception as e:
|
3224
|
+
if not quiet:
|
3225
|
+
print(f"ERROR saving COCO annotations: {e}")
|
3226
|
+
stats["errors"] += 1
|
3227
|
+
|
3228
|
+
# Save YOLO classes file if applicable
|
3229
|
+
if metadata_format == "YOLO":
|
3230
|
+
try:
|
3231
|
+
# Create classes.txt with class names
|
3232
|
+
classes_path = os.path.join(out_folder, "classes.txt")
|
3233
|
+
# Sort by class ID to ensure correct order
|
3234
|
+
sorted_classes = sorted(class_to_id.items(), key=lambda x: x[1])
|
3235
|
+
with open(classes_path, "w") as f:
|
3236
|
+
for class_val, _ in sorted_classes:
|
3237
|
+
f.write(f"{class_val}\n")
|
3238
|
+
if not quiet:
|
3239
|
+
print(f"Saved YOLO classes file with {len(class_to_id)} classes")
|
3240
|
+
except Exception as e:
|
3241
|
+
if not quiet:
|
3242
|
+
print(f"ERROR saving YOLO classes file: {e}")
|
3243
|
+
stats["errors"] += 1
|
3244
|
+
|
3050
3245
|
# Create overview image if requested
|
3051
3246
|
if create_overview and stats["tile_coordinates"]:
|
3052
3247
|
try:
|
@@ -3115,8 +3310,9 @@ def export_geotiff_tiles(
|
|
3115
3310
|
|
3116
3311
|
def export_geotiff_tiles_batch(
|
3117
3312
|
images_folder,
|
3118
|
-
masks_folder,
|
3119
|
-
|
3313
|
+
masks_folder=None,
|
3314
|
+
masks_file=None,
|
3315
|
+
output_folder=None,
|
3120
3316
|
tile_size=256,
|
3121
3317
|
stride=128,
|
3122
3318
|
class_value_field="class",
|
@@ -3124,25 +3320,38 @@ def export_geotiff_tiles_batch(
|
|
3124
3320
|
max_tiles=None,
|
3125
3321
|
quiet=False,
|
3126
3322
|
all_touched=True,
|
3127
|
-
create_overview=False,
|
3128
3323
|
skip_empty_tiles=False,
|
3129
3324
|
image_extensions=None,
|
3130
3325
|
mask_extensions=None,
|
3326
|
+
match_by_name=True,
|
3327
|
+
metadata_format="PASCAL_VOC",
|
3131
3328
|
) -> Dict[str, Any]:
|
3132
3329
|
"""
|
3133
|
-
Export georeferenced GeoTIFF tiles from
|
3330
|
+
Export georeferenced GeoTIFF tiles from images and masks.
|
3331
|
+
|
3332
|
+
This function supports three mask input modes:
|
3333
|
+
1. Single vector file covering all images (masks_file parameter)
|
3334
|
+
2. Multiple vector files, one per image (masks_folder parameter)
|
3335
|
+
3. Multiple raster mask files (masks_folder parameter)
|
3336
|
+
|
3337
|
+
For mode 1 (single vector file), specify masks_file path. The function will
|
3338
|
+
use spatial intersection to determine which features apply to each image.
|
3134
3339
|
|
3135
|
-
|
3136
|
-
|
3137
|
-
|
3340
|
+
For mode 2/3 (multiple mask files), specify masks_folder path. Images and masks
|
3341
|
+
are paired either by matching filenames (match_by_name=True) or by sorted order
|
3342
|
+
(match_by_name=False).
|
3138
3343
|
|
3139
|
-
|
3140
|
-
|
3344
|
+
All image tiles are saved to a single 'images' folder and all mask tiles to a
|
3345
|
+
single 'masks' folder within the output directory.
|
3141
3346
|
|
3142
3347
|
Args:
|
3143
3348
|
images_folder (str): Path to folder containing raster images
|
3144
|
-
masks_folder (str): Path to folder containing classification masks/vectors
|
3145
|
-
|
3349
|
+
masks_folder (str, optional): Path to folder containing classification masks/vectors.
|
3350
|
+
Use this for multiple mask files (one per image or raster masks).
|
3351
|
+
masks_file (str, optional): Path to a single vector file covering all images.
|
3352
|
+
Use this for a single GeoJSON/Shapefile that covers multiple images.
|
3353
|
+
output_folder (str, optional): Path to output folder. If None, creates 'tiles'
|
3354
|
+
subfolder in images_folder.
|
3146
3355
|
tile_size (int): Size of tiles in pixels (square)
|
3147
3356
|
stride (int): Step size between tiles
|
3148
3357
|
class_value_field (str): Field containing class values (for vector data)
|
@@ -3154,18 +3363,63 @@ def export_geotiff_tiles_batch(
|
|
3154
3363
|
skip_empty_tiles (bool): If True, skip tiles with no features
|
3155
3364
|
image_extensions (list): List of image file extensions to process (default: common raster formats)
|
3156
3365
|
mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
|
3366
|
+
match_by_name (bool): If True, match image and mask files by base filename.
|
3367
|
+
If False, match by sorted order (alphabetically). Only applies when masks_folder is used.
|
3368
|
+
metadata_format (str): Annotation format - "PASCAL_VOC" (XML), "COCO" (JSON), or "YOLO" (TXT).
|
3369
|
+
Default is "PASCAL_VOC".
|
3157
3370
|
|
3158
3371
|
Returns:
|
3159
3372
|
Dict[str, Any]: Dictionary containing batch processing statistics
|
3160
3373
|
|
3161
3374
|
Raises:
|
3162
|
-
ValueError: If no images
|
3375
|
+
ValueError: If no images found, or if masks_folder and masks_file are both specified,
|
3376
|
+
or if neither is specified, or if counts don't match when using masks_folder with
|
3377
|
+
match_by_name=False.
|
3378
|
+
|
3379
|
+
Examples:
|
3380
|
+
# Single vector file covering all images
|
3381
|
+
>>> stats = export_geotiff_tiles_batch(
|
3382
|
+
... images_folder='data/images',
|
3383
|
+
... masks_file='data/buildings.geojson',
|
3384
|
+
... output_folder='output/tiles'
|
3385
|
+
... )
|
3386
|
+
|
3387
|
+
# Multiple vector files, matched by filename
|
3388
|
+
>>> stats = export_geotiff_tiles_batch(
|
3389
|
+
... images_folder='data/images',
|
3390
|
+
... masks_folder='data/masks',
|
3391
|
+
... output_folder='output/tiles',
|
3392
|
+
... match_by_name=True
|
3393
|
+
... )
|
3394
|
+
|
3395
|
+
# Multiple mask files, matched by sorted order
|
3396
|
+
>>> stats = export_geotiff_tiles_batch(
|
3397
|
+
... images_folder='data/images',
|
3398
|
+
... masks_folder='data/masks',
|
3399
|
+
... output_folder='output/tiles',
|
3400
|
+
... match_by_name=False
|
3401
|
+
... )
|
3163
3402
|
"""
|
3164
3403
|
|
3165
3404
|
import logging
|
3166
3405
|
|
3167
3406
|
logging.getLogger("rasterio").setLevel(logging.ERROR)
|
3168
3407
|
|
3408
|
+
# Validate input parameters
|
3409
|
+
if masks_folder is not None and masks_file is not None:
|
3410
|
+
raise ValueError(
|
3411
|
+
"Cannot specify both masks_folder and masks_file. Please use only one."
|
3412
|
+
)
|
3413
|
+
|
3414
|
+
if masks_folder is None and masks_file is None:
|
3415
|
+
raise ValueError(
|
3416
|
+
"Must specify either masks_folder or masks_file for mask data source."
|
3417
|
+
)
|
3418
|
+
|
3419
|
+
# Default output folder if not specified
|
3420
|
+
if output_folder is None:
|
3421
|
+
output_folder = os.path.join(images_folder, "tiles")
|
3422
|
+
|
3169
3423
|
# Default extensions if not provided
|
3170
3424
|
if image_extensions is None:
|
3171
3425
|
image_extensions = [".tif", ".tiff", ".jpg", ".jpeg", ".png", ".jp2", ".img"]
|
@@ -3196,36 +3450,107 @@ def export_geotiff_tiles_batch(
|
|
3196
3450
|
os.makedirs(output_images_dir, exist_ok=True)
|
3197
3451
|
os.makedirs(output_masks_dir, exist_ok=True)
|
3198
3452
|
|
3453
|
+
# Create annotation directory based on metadata format
|
3454
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
3455
|
+
ann_dir = os.path.join(output_folder, "annotations")
|
3456
|
+
os.makedirs(ann_dir, exist_ok=True)
|
3457
|
+
|
3458
|
+
# Initialize COCO annotations dictionary
|
3459
|
+
coco_annotations = None
|
3460
|
+
if metadata_format == "COCO":
|
3461
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
3462
|
+
|
3463
|
+
# Initialize YOLO class set
|
3464
|
+
yolo_classes = set() if metadata_format == "YOLO" else None
|
3465
|
+
|
3199
3466
|
# Get list of image files
|
3200
3467
|
image_files = []
|
3201
3468
|
for ext in image_extensions:
|
3202
3469
|
pattern = os.path.join(images_folder, f"*{ext}")
|
3203
3470
|
image_files.extend(glob.glob(pattern))
|
3204
3471
|
|
3205
|
-
# Get list of mask files
|
3206
|
-
mask_files = []
|
3207
|
-
for ext in mask_extensions:
|
3208
|
-
pattern = os.path.join(masks_folder, f"*{ext}")
|
3209
|
-
mask_files.extend(glob.glob(pattern))
|
3210
|
-
|
3211
3472
|
# Sort files for consistent processing
|
3212
3473
|
image_files.sort()
|
3213
|
-
mask_files.sort()
|
3214
3474
|
|
3215
3475
|
if not image_files:
|
3216
3476
|
raise ValueError(
|
3217
3477
|
f"No image files found in {images_folder} with extensions {image_extensions}"
|
3218
3478
|
)
|
3219
3479
|
|
3220
|
-
|
3221
|
-
|
3222
|
-
|
3223
|
-
|
3480
|
+
# Handle different mask input modes
|
3481
|
+
use_single_mask_file = masks_file is not None
|
3482
|
+
mask_files = []
|
3483
|
+
image_mask_pairs = []
|
3224
3484
|
|
3225
|
-
if
|
3226
|
-
|
3227
|
-
|
3228
|
-
|
3485
|
+
if use_single_mask_file:
|
3486
|
+
# Mode 1: Single vector file covering all images
|
3487
|
+
if not os.path.exists(masks_file):
|
3488
|
+
raise ValueError(f"Mask file not found: {masks_file}")
|
3489
|
+
|
3490
|
+
# Load the single mask file once - will be spatially filtered per image
|
3491
|
+
single_mask_gdf = gpd.read_file(masks_file)
|
3492
|
+
|
3493
|
+
if not quiet:
|
3494
|
+
print(f"Using single mask file: {masks_file}")
|
3495
|
+
print(
|
3496
|
+
f"Mask contains {len(single_mask_gdf)} features in CRS: {single_mask_gdf.crs}"
|
3497
|
+
)
|
3498
|
+
|
3499
|
+
# Create pairs with the same mask file for all images
|
3500
|
+
for image_file in image_files:
|
3501
|
+
image_mask_pairs.append((image_file, masks_file, single_mask_gdf))
|
3502
|
+
|
3503
|
+
else:
|
3504
|
+
# Mode 2/3: Multiple mask files (vector or raster)
|
3505
|
+
# Get list of mask files
|
3506
|
+
for ext in mask_extensions:
|
3507
|
+
pattern = os.path.join(masks_folder, f"*{ext}")
|
3508
|
+
mask_files.extend(glob.glob(pattern))
|
3509
|
+
|
3510
|
+
# Sort files for consistent processing
|
3511
|
+
mask_files.sort()
|
3512
|
+
|
3513
|
+
if not mask_files:
|
3514
|
+
raise ValueError(
|
3515
|
+
f"No mask files found in {masks_folder} with extensions {mask_extensions}"
|
3516
|
+
)
|
3517
|
+
|
3518
|
+
# Match images to masks
|
3519
|
+
if match_by_name:
|
3520
|
+
# Match by base filename
|
3521
|
+
image_dict = {
|
3522
|
+
os.path.splitext(os.path.basename(f))[0]: f for f in image_files
|
3523
|
+
}
|
3524
|
+
mask_dict = {
|
3525
|
+
os.path.splitext(os.path.basename(f))[0]: f for f in mask_files
|
3526
|
+
}
|
3527
|
+
|
3528
|
+
# Find matching pairs
|
3529
|
+
for img_base, img_path in image_dict.items():
|
3530
|
+
if img_base in mask_dict:
|
3531
|
+
image_mask_pairs.append((img_path, mask_dict[img_base], None))
|
3532
|
+
else:
|
3533
|
+
if not quiet:
|
3534
|
+
print(f"Warning: No mask found for image {img_base}")
|
3535
|
+
|
3536
|
+
if not image_mask_pairs:
|
3537
|
+
raise ValueError(
|
3538
|
+
"No matching image-mask pairs found when matching by filename. "
|
3539
|
+
"Check that image and mask files have matching base names."
|
3540
|
+
)
|
3541
|
+
|
3542
|
+
else:
|
3543
|
+
# Match by sorted order
|
3544
|
+
if len(image_files) != len(mask_files):
|
3545
|
+
raise ValueError(
|
3546
|
+
f"Number of image files ({len(image_files)}) does not match "
|
3547
|
+
f"number of mask files ({len(mask_files)}) when matching by sorted order. "
|
3548
|
+
f"Use match_by_name=True for filename-based matching."
|
3549
|
+
)
|
3550
|
+
|
3551
|
+
# Create pairs by sorted order
|
3552
|
+
for image_file, mask_file in zip(image_files, mask_files):
|
3553
|
+
image_mask_pairs.append((image_file, mask_file, None))
|
3229
3554
|
|
3230
3555
|
# Initialize batch statistics
|
3231
3556
|
batch_stats = {
|
@@ -3239,23 +3564,24 @@ def export_geotiff_tiles_batch(
|
|
3239
3564
|
}
|
3240
3565
|
|
3241
3566
|
if not quiet:
|
3242
|
-
|
3243
|
-
f"Found {len(image_files)} image files
|
3244
|
-
|
3245
|
-
|
3567
|
+
if use_single_mask_file:
|
3568
|
+
print(f"Found {len(image_files)} image files to process")
|
3569
|
+
print(f"Using single mask file: {masks_file}")
|
3570
|
+
else:
|
3571
|
+
print(f"Found {len(image_mask_pairs)} matching image-mask pairs to process")
|
3572
|
+
print(f"Processing batch from {images_folder} and {masks_folder}")
|
3246
3573
|
print(f"Output folder: {output_folder}")
|
3247
3574
|
print("-" * 60)
|
3248
3575
|
|
3249
3576
|
# Global tile counter for unique naming
|
3250
3577
|
global_tile_counter = 0
|
3251
3578
|
|
3252
|
-
# Process each image-mask pair
|
3253
|
-
for idx, (image_file, mask_file) in enumerate(
|
3579
|
+
# Process each image-mask pair
|
3580
|
+
for idx, (image_file, mask_file, mask_gdf) in enumerate(
|
3254
3581
|
tqdm(
|
3255
|
-
|
3582
|
+
image_mask_pairs,
|
3256
3583
|
desc="Processing image pairs",
|
3257
3584
|
disable=quiet,
|
3258
|
-
total=len(image_files),
|
3259
3585
|
)
|
3260
3586
|
):
|
3261
3587
|
batch_stats["total_image_pairs"] += 1
|
@@ -3267,9 +3593,12 @@ def export_geotiff_tiles_batch(
|
|
3267
3593
|
if not quiet:
|
3268
3594
|
print(f"\nProcessing: {base_name}")
|
3269
3595
|
print(f" Image: {os.path.basename(image_file)}")
|
3270
|
-
|
3596
|
+
if use_single_mask_file:
|
3597
|
+
print(f" Mask: {os.path.basename(mask_file)} (spatially filtered)")
|
3598
|
+
else:
|
3599
|
+
print(f" Mask: {os.path.basename(mask_file)}")
|
3271
3600
|
|
3272
|
-
# Process the image-mask pair
|
3601
|
+
# Process the image-mask pair
|
3273
3602
|
tiles_generated = _process_image_mask_pair(
|
3274
3603
|
image_file=image_file,
|
3275
3604
|
mask_file=mask_file,
|
@@ -3285,6 +3614,15 @@ def export_geotiff_tiles_batch(
|
|
3285
3614
|
all_touched=all_touched,
|
3286
3615
|
skip_empty_tiles=skip_empty_tiles,
|
3287
3616
|
quiet=quiet,
|
3617
|
+
mask_gdf=mask_gdf, # Pass pre-loaded GeoDataFrame if using single mask
|
3618
|
+
use_single_mask_file=use_single_mask_file,
|
3619
|
+
metadata_format=metadata_format,
|
3620
|
+
ann_dir=(
|
3621
|
+
ann_dir
|
3622
|
+
if "ann_dir" in locals()
|
3623
|
+
and metadata_format in ["PASCAL_VOC", "COCO"]
|
3624
|
+
else None
|
3625
|
+
),
|
3288
3626
|
)
|
3289
3627
|
|
3290
3628
|
# Update counters
|
@@ -3306,6 +3644,23 @@ def export_geotiff_tiles_batch(
|
|
3306
3644
|
}
|
3307
3645
|
)
|
3308
3646
|
|
3647
|
+
# Aggregate COCO annotations
|
3648
|
+
if metadata_format == "COCO" and "coco_data" in tiles_generated:
|
3649
|
+
coco_data = tiles_generated["coco_data"]
|
3650
|
+
# Add images and annotations
|
3651
|
+
coco_annotations["images"].extend(coco_data.get("images", []))
|
3652
|
+
coco_annotations["annotations"].extend(coco_data.get("annotations", []))
|
3653
|
+
# Merge categories (avoid duplicates)
|
3654
|
+
for cat in coco_data.get("categories", []):
|
3655
|
+
if not any(
|
3656
|
+
c["id"] == cat["id"] for c in coco_annotations["categories"]
|
3657
|
+
):
|
3658
|
+
coco_annotations["categories"].append(cat)
|
3659
|
+
|
3660
|
+
# Aggregate YOLO classes
|
3661
|
+
if metadata_format == "YOLO" and "yolo_classes" in tiles_generated:
|
3662
|
+
yolo_classes.update(tiles_generated["yolo_classes"])
|
3663
|
+
|
3309
3664
|
except Exception as e:
|
3310
3665
|
if not quiet:
|
3311
3666
|
print(f"ERROR processing {base_name}: {e}")
|
@@ -3314,6 +3669,33 @@ def export_geotiff_tiles_batch(
|
|
3314
3669
|
)
|
3315
3670
|
batch_stats["errors"] += 1
|
3316
3671
|
|
3672
|
+
# Save aggregated COCO annotations
|
3673
|
+
if metadata_format == "COCO" and coco_annotations:
|
3674
|
+
import json
|
3675
|
+
|
3676
|
+
coco_path = os.path.join(ann_dir, "instances.json")
|
3677
|
+
with open(coco_path, "w") as f:
|
3678
|
+
json.dump(coco_annotations, f, indent=2)
|
3679
|
+
if not quiet:
|
3680
|
+
print(f"\nSaved COCO annotations: {coco_path}")
|
3681
|
+
print(
|
3682
|
+
f" Images: {len(coco_annotations['images'])}, "
|
3683
|
+
f"Annotations: {len(coco_annotations['annotations'])}, "
|
3684
|
+
f"Categories: {len(coco_annotations['categories'])}"
|
3685
|
+
)
|
3686
|
+
|
3687
|
+
# Save aggregated YOLO classes
|
3688
|
+
if metadata_format == "YOLO" and yolo_classes:
|
3689
|
+
classes_path = os.path.join(output_folder, "labels", "classes.txt")
|
3690
|
+
os.makedirs(os.path.dirname(classes_path), exist_ok=True)
|
3691
|
+
sorted_classes = sorted(yolo_classes)
|
3692
|
+
with open(classes_path, "w") as f:
|
3693
|
+
for cls in sorted_classes:
|
3694
|
+
f.write(f"{cls}\n")
|
3695
|
+
if not quiet:
|
3696
|
+
print(f"\nSaved YOLO classes: {classes_path}")
|
3697
|
+
print(f" Total classes: {len(sorted_classes)}")
|
3698
|
+
|
3317
3699
|
# Print batch summary
|
3318
3700
|
if not quiet:
|
3319
3701
|
print("\n" + "=" * 60)
|
@@ -3337,6 +3719,10 @@ def export_geotiff_tiles_batch(
|
|
3337
3719
|
print(f"Output saved to: {output_folder}")
|
3338
3720
|
print(f" Images: {output_images_dir}")
|
3339
3721
|
print(f" Masks: {output_masks_dir}")
|
3722
|
+
if metadata_format in ["PASCAL_VOC", "COCO"]:
|
3723
|
+
print(f" Annotations: {ann_dir}")
|
3724
|
+
elif metadata_format == "YOLO":
|
3725
|
+
print(f" Labels: {os.path.join(output_folder, 'labels')}")
|
3340
3726
|
|
3341
3727
|
# List failed files if any
|
3342
3728
|
if batch_stats["failed_files"]:
|
@@ -3362,10 +3748,18 @@ def _process_image_mask_pair(
|
|
3362
3748
|
all_touched=True,
|
3363
3749
|
skip_empty_tiles=False,
|
3364
3750
|
quiet=False,
|
3751
|
+
mask_gdf=None,
|
3752
|
+
use_single_mask_file=False,
|
3753
|
+
metadata_format="PASCAL_VOC",
|
3754
|
+
ann_dir=None,
|
3365
3755
|
):
|
3366
3756
|
"""
|
3367
3757
|
Process a single image-mask pair and save tiles directly to output directories.
|
3368
3758
|
|
3759
|
+
Args:
|
3760
|
+
mask_gdf (GeoDataFrame, optional): Pre-loaded GeoDataFrame when using single mask file
|
3761
|
+
use_single_mask_file (bool): If True, spatially filter mask_gdf to image bounds
|
3762
|
+
|
3369
3763
|
Returns:
|
3370
3764
|
dict: Statistics for this image-mask pair
|
3371
3765
|
"""
|
@@ -3390,6 +3784,13 @@ def _process_image_mask_pair(
|
|
3390
3784
|
"errors": 0,
|
3391
3785
|
}
|
3392
3786
|
|
3787
|
+
# Initialize COCO/YOLO tracking for this image
|
3788
|
+
if metadata_format == "COCO":
|
3789
|
+
stats["coco_data"] = {"images": [], "annotations": [], "categories": []}
|
3790
|
+
coco_ann_id = 0
|
3791
|
+
if metadata_format == "YOLO":
|
3792
|
+
stats["yolo_classes"] = set()
|
3793
|
+
|
3393
3794
|
# Open the input raster
|
3394
3795
|
with rasterio.open(image_file) as src:
|
3395
3796
|
# Calculate number of tiles
|
@@ -3433,11 +3834,36 @@ def _process_image_mask_pair(
|
|
3433
3834
|
else:
|
3434
3835
|
# Load vector class data
|
3435
3836
|
try:
|
3436
|
-
|
3837
|
+
if use_single_mask_file and mask_gdf is not None:
|
3838
|
+
# Using pre-loaded single mask file - spatially filter to image bounds
|
3839
|
+
# Get image bounds
|
3840
|
+
image_bounds = box(*src.bounds)
|
3841
|
+
image_gdf = gpd.GeoDataFrame(
|
3842
|
+
{"geometry": [image_bounds]}, crs=src.crs
|
3843
|
+
)
|
3437
3844
|
|
3438
|
-
|
3439
|
-
|
3440
|
-
|
3845
|
+
# Reproject mask if needed
|
3846
|
+
if mask_gdf.crs != src.crs:
|
3847
|
+
mask_gdf_reprojected = mask_gdf.to_crs(src.crs)
|
3848
|
+
else:
|
3849
|
+
mask_gdf_reprojected = mask_gdf
|
3850
|
+
|
3851
|
+
# Spatially filter features that intersect with image bounds
|
3852
|
+
gdf = mask_gdf_reprojected[
|
3853
|
+
mask_gdf_reprojected.intersects(image_bounds)
|
3854
|
+
].copy()
|
3855
|
+
|
3856
|
+
if not quiet and len(gdf) > 0:
|
3857
|
+
print(
|
3858
|
+
f" Filtered to {len(gdf)} features intersecting image bounds"
|
3859
|
+
)
|
3860
|
+
else:
|
3861
|
+
# Load individual mask file
|
3862
|
+
gdf = gpd.read_file(mask_file)
|
3863
|
+
|
3864
|
+
# Always reproject to match raster CRS
|
3865
|
+
if gdf.crs != src.crs:
|
3866
|
+
gdf = gdf.to_crs(src.crs)
|
3441
3867
|
|
3442
3868
|
# Apply buffer if specified
|
3443
3869
|
if buffer_radius > 0:
|
@@ -3457,9 +3883,6 @@ def _process_image_mask_pair(
|
|
3457
3883
|
tile_index = 0
|
3458
3884
|
for y in range(num_tiles_y):
|
3459
3885
|
for x in range(num_tiles_x):
|
3460
|
-
if tile_index >= max_tiles:
|
3461
|
-
break
|
3462
|
-
|
3463
3886
|
# Calculate window coordinates
|
3464
3887
|
window_x = x * stride
|
3465
3888
|
window_y = y * stride
|
@@ -3562,9 +3985,12 @@ def _process_image_mask_pair(
|
|
3562
3985
|
|
3563
3986
|
# Skip tile if no features and skip_empty_tiles is True
|
3564
3987
|
if skip_empty_tiles and not has_features:
|
3565
|
-
tile_index += 1
|
3566
3988
|
continue
|
3567
3989
|
|
3990
|
+
# Check if we've reached max_tiles before saving
|
3991
|
+
if tile_index >= max_tiles:
|
3992
|
+
break
|
3993
|
+
|
3568
3994
|
# Generate unique tile name
|
3569
3995
|
tile_name = f"{base_name}_{global_tile_counter + tile_index:06d}"
|
3570
3996
|
|
@@ -3619,6 +4045,197 @@ def _process_image_mask_pair(
|
|
3619
4045
|
print(f"ERROR saving label GeoTIFF: {e}")
|
3620
4046
|
stats["errors"] += 1
|
3621
4047
|
|
4048
|
+
# Generate annotation metadata based on format
|
4049
|
+
if metadata_format == "PASCAL_VOC" and ann_dir:
|
4050
|
+
# Create PASCAL VOC XML annotation
|
4051
|
+
from lxml import etree as ET
|
4052
|
+
|
4053
|
+
annotation = ET.Element("annotation")
|
4054
|
+
ET.SubElement(annotation, "folder").text = os.path.basename(
|
4055
|
+
output_images_dir
|
4056
|
+
)
|
4057
|
+
ET.SubElement(annotation, "filename").text = f"{tile_name}.tif"
|
4058
|
+
ET.SubElement(annotation, "path").text = image_path
|
4059
|
+
|
4060
|
+
source = ET.SubElement(annotation, "source")
|
4061
|
+
ET.SubElement(source, "database").text = "GeoAI"
|
4062
|
+
|
4063
|
+
size = ET.SubElement(annotation, "size")
|
4064
|
+
ET.SubElement(size, "width").text = str(tile_size)
|
4065
|
+
ET.SubElement(size, "height").text = str(tile_size)
|
4066
|
+
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
4067
|
+
|
4068
|
+
ET.SubElement(annotation, "segmented").text = "1"
|
4069
|
+
|
4070
|
+
# Find connected components for instance segmentation
|
4071
|
+
from scipy import ndimage
|
4072
|
+
|
4073
|
+
for class_id in np.unique(label_mask):
|
4074
|
+
if class_id == 0:
|
4075
|
+
continue
|
4076
|
+
|
4077
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4078
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4079
|
+
|
4080
|
+
for instance_id in range(1, num_features + 1):
|
4081
|
+
instance_mask = labeled_array == instance_id
|
4082
|
+
coords = np.argwhere(instance_mask)
|
4083
|
+
|
4084
|
+
if len(coords) == 0:
|
4085
|
+
continue
|
4086
|
+
|
4087
|
+
ymin, xmin = coords.min(axis=0)
|
4088
|
+
ymax, xmax = coords.max(axis=0)
|
4089
|
+
|
4090
|
+
obj = ET.SubElement(annotation, "object")
|
4091
|
+
class_name = next(
|
4092
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
4093
|
+
str(class_id),
|
4094
|
+
)
|
4095
|
+
ET.SubElement(obj, "name").text = str(class_name)
|
4096
|
+
ET.SubElement(obj, "pose").text = "Unspecified"
|
4097
|
+
ET.SubElement(obj, "truncated").text = "0"
|
4098
|
+
ET.SubElement(obj, "difficult").text = "0"
|
4099
|
+
|
4100
|
+
bndbox = ET.SubElement(obj, "bndbox")
|
4101
|
+
ET.SubElement(bndbox, "xmin").text = str(int(xmin))
|
4102
|
+
ET.SubElement(bndbox, "ymin").text = str(int(ymin))
|
4103
|
+
ET.SubElement(bndbox, "xmax").text = str(int(xmax))
|
4104
|
+
ET.SubElement(bndbox, "ymax").text = str(int(ymax))
|
4105
|
+
|
4106
|
+
# Save XML file
|
4107
|
+
xml_path = os.path.join(ann_dir, f"{tile_name}.xml")
|
4108
|
+
tree = ET.ElementTree(annotation)
|
4109
|
+
tree.write(xml_path, pretty_print=True, encoding="utf-8")
|
4110
|
+
|
4111
|
+
elif metadata_format == "COCO":
|
4112
|
+
# Add COCO image entry
|
4113
|
+
image_id = int(global_tile_counter + tile_index)
|
4114
|
+
stats["coco_data"]["images"].append(
|
4115
|
+
{
|
4116
|
+
"id": image_id,
|
4117
|
+
"file_name": f"{tile_name}.tif",
|
4118
|
+
"width": int(tile_size),
|
4119
|
+
"height": int(tile_size),
|
4120
|
+
}
|
4121
|
+
)
|
4122
|
+
|
4123
|
+
# Add COCO categories (only once per unique class)
|
4124
|
+
for class_val, class_id in class_to_id.items():
|
4125
|
+
if not any(
|
4126
|
+
c["id"] == class_id
|
4127
|
+
for c in stats["coco_data"]["categories"]
|
4128
|
+
):
|
4129
|
+
stats["coco_data"]["categories"].append(
|
4130
|
+
{
|
4131
|
+
"id": int(class_id),
|
4132
|
+
"name": str(class_val),
|
4133
|
+
"supercategory": "object",
|
4134
|
+
}
|
4135
|
+
)
|
4136
|
+
|
4137
|
+
# Add COCO annotations (instance segmentation)
|
4138
|
+
from scipy import ndimage
|
4139
|
+
from skimage import measure
|
4140
|
+
|
4141
|
+
for class_id in np.unique(label_mask):
|
4142
|
+
if class_id == 0:
|
4143
|
+
continue
|
4144
|
+
|
4145
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4146
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4147
|
+
|
4148
|
+
for instance_id in range(1, num_features + 1):
|
4149
|
+
instance_mask = (labeled_array == instance_id).astype(
|
4150
|
+
np.uint8
|
4151
|
+
)
|
4152
|
+
coords = np.argwhere(instance_mask)
|
4153
|
+
|
4154
|
+
if len(coords) == 0:
|
4155
|
+
continue
|
4156
|
+
|
4157
|
+
ymin, xmin = coords.min(axis=0)
|
4158
|
+
ymax, xmax = coords.max(axis=0)
|
4159
|
+
|
4160
|
+
bbox = [
|
4161
|
+
int(xmin),
|
4162
|
+
int(ymin),
|
4163
|
+
int(xmax - xmin),
|
4164
|
+
int(ymax - ymin),
|
4165
|
+
]
|
4166
|
+
area = int(np.sum(instance_mask))
|
4167
|
+
|
4168
|
+
# Find contours for segmentation
|
4169
|
+
contours = measure.find_contours(instance_mask, 0.5)
|
4170
|
+
segmentation = []
|
4171
|
+
for contour in contours:
|
4172
|
+
contour = np.flip(contour, axis=1)
|
4173
|
+
segmentation_points = contour.ravel().tolist()
|
4174
|
+
if len(segmentation_points) >= 6:
|
4175
|
+
segmentation.append(segmentation_points)
|
4176
|
+
|
4177
|
+
if segmentation:
|
4178
|
+
stats["coco_data"]["annotations"].append(
|
4179
|
+
{
|
4180
|
+
"id": int(coco_ann_id),
|
4181
|
+
"image_id": int(image_id),
|
4182
|
+
"category_id": int(class_id),
|
4183
|
+
"bbox": bbox,
|
4184
|
+
"area": area,
|
4185
|
+
"segmentation": segmentation,
|
4186
|
+
"iscrowd": 0,
|
4187
|
+
}
|
4188
|
+
)
|
4189
|
+
coco_ann_id += 1
|
4190
|
+
|
4191
|
+
elif metadata_format == "YOLO":
|
4192
|
+
# Create YOLO labels directory if needed
|
4193
|
+
labels_dir = os.path.join(
|
4194
|
+
os.path.dirname(output_images_dir), "labels"
|
4195
|
+
)
|
4196
|
+
os.makedirs(labels_dir, exist_ok=True)
|
4197
|
+
|
4198
|
+
# Generate YOLO annotation file
|
4199
|
+
yolo_path = os.path.join(labels_dir, f"{tile_name}.txt")
|
4200
|
+
from scipy import ndimage
|
4201
|
+
|
4202
|
+
with open(yolo_path, "w") as yolo_file:
|
4203
|
+
for class_id in np.unique(label_mask):
|
4204
|
+
if class_id == 0:
|
4205
|
+
continue
|
4206
|
+
|
4207
|
+
# Track class for classes.txt
|
4208
|
+
class_name = next(
|
4209
|
+
(k for k, v in class_to_id.items() if v == class_id),
|
4210
|
+
str(class_id),
|
4211
|
+
)
|
4212
|
+
stats["yolo_classes"].add(class_name)
|
4213
|
+
|
4214
|
+
class_mask = (label_mask == class_id).astype(np.uint8)
|
4215
|
+
labeled_array, num_features = ndimage.label(class_mask)
|
4216
|
+
|
4217
|
+
for instance_id in range(1, num_features + 1):
|
4218
|
+
instance_mask = labeled_array == instance_id
|
4219
|
+
coords = np.argwhere(instance_mask)
|
4220
|
+
|
4221
|
+
if len(coords) == 0:
|
4222
|
+
continue
|
4223
|
+
|
4224
|
+
ymin, xmin = coords.min(axis=0)
|
4225
|
+
ymax, xmax = coords.max(axis=0)
|
4226
|
+
|
4227
|
+
# Convert to YOLO format (normalized center coordinates)
|
4228
|
+
x_center = ((xmin + xmax) / 2) / tile_size
|
4229
|
+
y_center = ((ymin + ymax) / 2) / tile_size
|
4230
|
+
width = (xmax - xmin) / tile_size
|
4231
|
+
height = (ymax - ymin) / tile_size
|
4232
|
+
|
4233
|
+
# YOLO uses 0-based class indices
|
4234
|
+
yolo_class_id = class_id - 1
|
4235
|
+
yolo_file.write(
|
4236
|
+
f"{yolo_class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
4237
|
+
)
|
4238
|
+
|
3622
4239
|
tile_index += 1
|
3623
4240
|
if tile_index >= max_tiles:
|
3624
4241
|
break
|
@@ -3629,6 +4246,179 @@ def _process_image_mask_pair(
|
|
3629
4246
|
return stats
|
3630
4247
|
|
3631
4248
|
|
4249
|
+
def display_training_tiles(
    output_dir,
    num_tiles=6,
    figsize=(18, 6),
    cmap="gray",
    save_path=None,
):
    """
    Display image and mask tile pairs from training data output.

    Image tiles are taken from ``<output_dir>/images`` in sorted filename order.
    Each one is paired with the file of the same name in ``<output_dir>/masks``.
    Images are drawn on the top row and masks on the bottom row; when a mask
    file is missing, a "Mask not found" placeholder is drawn instead.

    Args:
        output_dir (str): Path to output directory containing 'images' and 'masks' subdirectories
        num_tiles (int): Maximum number of tile pairs to display (default: 6). Must be at least 1.
        figsize (tuple): Figure size as (width, height) in inches (default: (18, 6))
        cmap (str): Colormap for mask display (default: 'gray')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes) matplotlib figure and a 2 x N array of axes, where N is
            the number of tiles actually drawn. When ``save_path`` is given the
            figure has already been closed before it is returned.

    Raises:
        ValueError: If ``num_tiles`` is less than 1, if the images directory does
            not exist, or if it contains no files.

    Example:
        >>> fig, axes = display_training_tiles('output/tiles', num_tiles=6)
        >>> # Or save to file
        >>> display_training_tiles('output/tiles', num_tiles=4, save_path='tiles_preview.png')
    """
    # A non-positive count would otherwise slice the listing from the wrong end
    # (negative) or be reported as "no tiles found" (zero).
    if num_tiles < 1:
        raise ValueError(f"num_tiles must be at least 1, got {num_tiles}")

    # Get list of image tiles
    images_dir = os.path.join(output_dir, "images")
    if not os.path.isdir(images_dir):
        raise ValueError(f"Images directory not found: {images_dir}")

    # Only regular files can be tiles; nested directories are ignored.
    image_tiles = sorted(
        name
        for name in os.listdir(images_dir)
        if os.path.isfile(os.path.join(images_dir, name))
    )[:num_tiles]

    if not image_tiles:
        raise ValueError(f"No image tiles found in {images_dir}")

    # Imported after validation so bad arguments are reported without first
    # loading the plotting backend.
    import matplotlib.pyplot as plt

    # Limit to available tiles
    num_tiles = len(image_tiles)
    masks_dir = os.path.join(output_dir, "masks")

    # squeeze=False keeps ``axes`` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, num_tiles, figsize=figsize, squeeze=False)

    try:
        for idx, tile_name in enumerate(image_tiles):
            # Load and display image tile
            image_path = os.path.join(images_dir, tile_name)
            with rasterio.open(image_path) as src:
                show(src, ax=axes[0, idx], title=f"Image {idx+1}")

            # Load and display mask tile
            mask_path = os.path.join(masks_dir, tile_name)
            mask_ax = axes[1, idx]
            if os.path.exists(mask_path):
                with rasterio.open(mask_path) as src:
                    show(src, ax=mask_ax, title=f"Mask {idx+1}", cmap=cmap)
            else:
                mask_ax.text(
                    0.5,
                    0.5,
                    "Mask not found",
                    ha="center",
                    va="center",
                    transform=mask_ax.transAxes,
                )
                mask_ax.set_title(f"Mask {idx+1}")

        plt.tight_layout()

        # Save or show
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    if save_path:
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, axes
|
4329
|
+
|
4330
|
+
|
4331
|
+
def display_image_with_vector(
    image_path,
    vector_path,
    figsize=(16, 8),
    vector_color="red",
    vector_linewidth=1,
    vector_facecolor="none",
    save_path=None,
):
    """
    Display a raster image alongside the same image with vector overlay.

    The left panel shows the raster on its own; the right panel shows the raster
    with the vector features drawn on top. When both datasets define a CRS and
    the two differ, the vector data is reprojected to the raster CRS first.

    Args:
        image_path (str): Path to raster image file
        vector_path (str): Path to vector file (GeoJSON, Shapefile, etc.)
        figsize (tuple): Figure size as (width, height) in inches (default: (16, 8))
        vector_color (str): Edge color for vector features (default: 'red')
        vector_linewidth (float): Line width for vector features (default: 1)
        vector_facecolor (str): Fill color for vector features (default: 'none')
        save_path (str, optional): If provided, save figure to this path instead of displaying

    Returns:
        tuple: (fig, axes, info_dict) where info_dict contains image and vector metadata
            under the keys ``image_shape``, ``image_crs``, ``image_bounds``,
            ``num_features``, ``vector_crs`` and ``vector_bounds``. The vector
            entries describe the data as plotted, i.e. after any reprojection.
            When ``save_path`` is given the figure has already been closed
            before it is returned.

    Raises:
        ValueError: If ``image_path`` or ``vector_path`` does not exist.

    Example:
        >>> fig, axes, info = display_image_with_vector(
        ...     'image.tif',
        ...     'buildings.geojson',
        ...     vector_color='blue'
        ... )
        >>> print(f"Number of features: {info['num_features']}")
    """
    # Validate inputs
    if not os.path.exists(image_path):
        raise ValueError(f"Image file not found: {image_path}")
    if not os.path.exists(vector_path):
        raise ValueError(f"Vector file not found: {vector_path}")

    # Imported after validation so bad paths are reported without first loading
    # the plotting backend.
    import warnings

    import matplotlib.pyplot as plt

    # Create figure
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    try:
        # Load and display image
        with rasterio.open(image_path) as src:
            # Plot image only
            show(src, ax=ax1, title="Image")

            # Load vector data
            vector_data = gpd.read_file(vector_path)

            # Reproject to image CRS if needed. to_crs() raises when either side
            # has no CRS, so in that case the overlay is drawn as-is.
            if vector_data.crs is None or src.crs is None:
                warnings.warn(
                    "Image or vector data has no CRS; drawing the overlay "
                    "without reprojection, so features may be misaligned.",
                    stacklevel=2,
                )
            elif vector_data.crs != src.crs:
                vector_data = vector_data.to_crs(src.crs)

            # Plot image with vector overlay
            show(
                src,
                ax=ax2,
                title=f"Image with {len(vector_data)} Vector Features",
            )
            vector_data.plot(
                ax=ax2,
                facecolor=vector_facecolor,
                edgecolor=vector_color,
                linewidth=vector_linewidth,
            )

            # Collect metadata
            info = {
                "image_shape": src.shape,
                "image_crs": src.crs,
                "image_bounds": src.bounds,
                "num_features": len(vector_data),
                "vector_crs": vector_data.crs,
                "vector_bounds": vector_data.total_bounds,
            }

        plt.tight_layout()

        # Save or show
        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches="tight")
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise

    if save_path:
        plt.close(fig)
        print(f"Figure saved to: {save_path}")
    else:
        plt.show()

    return fig, (ax1, ax2), info
|
4420
|
+
|
4421
|
+
|
3632
4422
|
def create_overview_image(
|
3633
4423
|
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
3634
4424
|
) -> str:
|