geoai-py 0.1.7__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/preprocess.py ADDED
@@ -0,0 +1,2008 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ import warnings
7
+ import xml.etree.ElementTree as ET
8
+ import numpy as np
9
+ import rasterio
10
+ import geopandas as gpd
11
+ import pandas as pd
12
+ from rasterio.windows import Window
13
+ from rasterio import features
14
+ from shapely.geometry import box, shape
15
+ import matplotlib.pyplot as plt
16
+ from tqdm import tqdm
17
+ from torchvision.transforms import RandomRotation
18
+ from shapely.affinity import rotate
19
+ import torchgeo
20
+ import torch
21
+
22
+
23
def raster_to_vector(
    raster_path,
    output_path=None,
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    plot_result=False,
):
    """
    Convert a raster label mask to vector polygons.

    Only the first band of the raster is vectorized. Each connected region of
    a class mask becomes one polygon.

    Args:
        raster_path (str): Path to the input raster file (e.g., GeoTIFF).
        output_path (str): Path to save the output vector file. If None, returns
            the GeoDataFrame without saving.
        threshold (int/float): Pixel values greater than this threshold are
            vectorized. Ignored when ``class_values`` is given.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification.
            None for no simplification.
        class_values (list): Specific pixel values to vectorize, each as its
            own class. If None, all values > threshold are vectorized as
            class 1.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for the output file - 'geojson',
            'shapefile' or 'gpkg' (case-insensitive).
        plot_result (bool): Whether to plot the resulting polygons overlaid on
            the raster.

    Returns:
        geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized
        polygons, with one ``attribute_name`` column holding the class value.

    Raises:
        ValueError: If ``output_path`` is given and ``output_format`` is not
            supported.
    """
    # Output driver for each supported output format (keys are lower-case).
    drivers = {"geojson": "GeoJSON", "shapefile": "ESRI Shapefile", "gpkg": "GPKG"}

    with rasterio.open(raster_path) as src:
        data = src.read(1)
        transform = src.transform
        crs = src.crs
        raster_bounds = src.bounds
        # All bands are only needed for the optional plot; read them while the
        # dataset is still open.
        raster_img = src.read() if plot_result else None

    # One boolean mask per class. Without explicit class values, everything
    # above the threshold is treated as a single class labelled 1.
    if class_values is not None:
        masks = {val: data == val for val in class_values}
    else:
        masks = {1: data > threshold}

    all_features = []
    for class_val, mask in masks.items():
        # mask= restricts polygonization to True pixels, so only the class
        # regions are yielded, not the background.
        for geom_json, _ in features.shapes(
            mask.astype(np.uint8), mask=mask, transform=transform
        ):
            geom = shape(geom_json)
            if geom.area < min_area:
                continue
            if simplify_tolerance is not None:
                geom = geom.simplify(simplify_tolerance)
            all_features.append({"geometry": geom, attribute_name: class_val})

    if all_features:
        gdf = gpd.GeoDataFrame(all_features, geometry="geometry", crs=crs)
    else:
        print("Warning: No features were extracted from the raster.")
        # Keep the attribute column so empty results share the same schema.
        gdf = gpd.GeoDataFrame({attribute_name: []}, geometry=[], crs=crs)

    if output_path is not None:
        driver = drivers.get(output_format.lower())
        # Validate before touching the filesystem.
        if driver is None:
            raise ValueError(f"Unsupported output format: {output_format}")

        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        gdf.to_file(output_path, driver=driver)
        print(f"Vectorized data saved to {output_path}")

    if plot_result:
        _, ax = plt.subplots(figsize=(12, 12))

        # Draw the raster in map coordinates so the polygons (which are in map
        # coordinates) line up with it.
        extent = (
            raster_bounds.left,
            raster_bounds.right,
            raster_bounds.bottom,
            raster_bounds.top,
        )
        if raster_img.shape[0] < 3:
            ax.imshow(raster_img[0], cmap="viridis", alpha=0.7, extent=extent)
        else:
            # Use the first 3 bands for RGB display, scaled to [0, 1].
            rgb = raster_img[:3].transpose(1, 2, 0).astype(float)
            peak = rgb.max()
            if peak > 0:
                rgb = rgb / peak
            ax.imshow(np.clip(rgb, 0, 1), extent=extent)

        if not gdf.empty:
            gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)

        ax.set_title("Raster with Vectorized Boundaries")
        ax.set_axis_off()
        plt.tight_layout()
        plt.show()

    return gdf
144
+
145
+
146
def batch_raster_to_vector(
    input_dir,
    output_dir,
    pattern="*.tif",
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    merge_output=False,
    merge_filename="merged_vectors",
):
    """
    Batch convert multiple raster files to vector polygons.

    Each matching raster is vectorized with ``raster_to_vector``. Rasters are
    processed in sorted filename order so that repeated runs (and the row
    order of a merged output) are reproducible.

    Args:
        input_dir (str): Directory containing input raster files.
        output_dir (str): Directory to save output vector files.
        pattern (str): Glob pattern to match raster files (e.g., '*.tif').
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
        merge_output (bool): Whether to merge all output vectors into a single file.
            Each merged feature gets a ``source_file`` column with the raster's base name.
        merge_filename (str): Filename for the merged output (without extension).

    Returns:
        geopandas.GeoDataFrame or None: The merged GeoDataFrame if merge_output is True
        and at least one feature was extracted, otherwise None.

    Raises:
        ValueError: If matching files exist and ``output_format`` is unsupported.
    """
    import glob

    # File extension and output driver per supported format (lower-case keys).
    extensions = {"geojson": ".geojson", "shapefile": ".shp", "gpkg": ".gpkg"}
    drivers = {"geojson": "GeoJSON", "shapefile": "ESRI Shapefile", "gpkg": "GPKG"}

    os.makedirs(output_dir, exist_ok=True)

    # glob returns files in arbitrary order; sort for deterministic output.
    raster_files = sorted(glob.glob(os.path.join(input_dir, pattern)))

    if not raster_files:
        print(f"No files matching pattern '{pattern}' found in {input_dir}")
        return None

    print(f"Found {len(raster_files)} raster files to process")

    fmt = output_format.lower()
    if fmt not in extensions:
        raise ValueError(f"Unsupported output format: {output_format}")

    gdfs = []
    for raster_file in tqdm(raster_files, desc="Processing rasters"):
        base_name = os.path.splitext(os.path.basename(raster_file))[0]

        if merge_output:
            # Individual files are not written when merging.
            gdf = raster_to_vector(
                raster_file,
                output_path=None,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
            )
            if not gdf.empty:
                gdf["source_file"] = base_name
                gdfs.append(gdf)
        else:
            raster_to_vector(
                raster_file,
                output_path=os.path.join(output_dir, base_name + extensions[fmt]),
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
                output_format=output_format,
            )

    if not (merge_output and gdfs):
        return None

    # Rasters may come in different CRSs; concatenation of GeoDataFrames with
    # mismatched CRSs fails, so bring every part into the first part's CRS.
    target_crs = gdfs[0].crs
    if target_crs is not None:
        gdfs = [
            part.to_crs(target_crs)
            if part.crs is not None and part.crs != target_crs
            else part
            for part in gdfs
        ]

    merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))
    if merged_gdf.crs is None and target_crs is not None:
        merged_gdf = merged_gdf.set_crs(target_crs)

    merged_file = os.path.join(output_dir, merge_filename + extensions[fmt])
    merged_gdf.to_file(merged_file, driver=drivers[fmt])
    print(f"Merged vector data saved to {merged_file}")
    return merged_gdf
259
+
260
+
261
+ # # Example usage
262
+ # if __name__ == "__main__":
263
+ # # Single file conversion example
264
+ # gdf = raster_to_vector(
265
+ # raster_path="output/labels/tile_000001.tif",
266
+ # output_path="output/labels/tile_000001.geojson",
267
+ # threshold=0,
268
+ # min_area=10,
269
+ # simplify_tolerance=0.5,
270
+ # class_values=[1], # For a binary mask, use [1]
271
+ # attribute_name='class',
272
+ # plot_result=True
273
+ # )
274
+
275
+ # Batch conversion example
276
+ # batch_raster_to_vector(
277
+ # input_dir="path/to/labels",
278
+ # output_dir="path/to/vectors",
279
+ # pattern="*.tif",
280
+ # threshold=0,
281
+ # min_area=10,
282
+ # class_values=[1, 2, 3], # For a multiclass mask
283
+ # merge_output=True
284
+ # )
285
+
286
+
287
def vector_to_raster(
    vector_path,
    output_path=None,
    reference_raster=None,
    attribute_field=None,
    output_shape=None,
    transform=None,
    pixel_size=None,
    bounds=None,
    crs=None,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
    plot_result=False,
):
    """
    Convert vector data to a raster.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_path (str): Path to save the output raster file. If None, nothing is written.
        reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
            When given, it overrides ``transform``, ``output_shape`` and ``crs``.
        attribute_field (str): Field name in the vector data to use for pixel values.
            If None (or not present in the data), all features are burned with value 1.
        output_shape (tuple): Shape of the output raster as (height, width).
            Calculated from ``bounds`` and ``pixel_size`` if not provided.
        transform (affine.Affine): Affine transformation matrix.
            Calculated from ``bounds`` and ``pixel_size`` if not provided.
        pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
            Used to calculate transform/output_shape when they are not provided.
        bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
            Used to calculate transform/output_shape when they are not provided.
        crs (str or CRS): Coordinate reference system of the output raster.
            Defaults to the CRS of the vector data when no reference raster is given.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
            If False, only pixels whose center is within the geometry will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster. Defaults to the
            reference raster's nodata when a reference raster is given.
        plot_result (bool): Whether to plot the resulting raster.

    Returns:
        numpy.ndarray: The rasterized data array. It is returned whether or not
        it is also written to ``output_path``.

    Raises:
        ValueError: If the output grid (transform/shape) or CRS cannot be
            determined, or if ``bounds``/``pixel_size`` describe an empty grid.
    """
    if isinstance(vector_path, gpd.GeoDataFrame):
        gdf = vector_path
    else:
        gdf = gpd.read_file(vector_path)

    if gdf.empty:
        warnings.warn("The input vector data is empty. Creating an empty raster.")

    if reference_raster is not None:
        # The reference raster fully defines the output grid and CRS.
        with rasterio.open(reference_raster) as src:
            transform = src.transform
            output_shape = src.shape
            crs = src.crs
            if nodata is None:
                nodata = src.nodata
    else:
        if crs is None:
            crs = gdf.crs

        if transform is None or output_shape is None:
            if pixel_size is None or bounds is None:
                if transform is None:
                    raise ValueError(
                        "Either reference_raster, transform, or both pixel_size and bounds must be provided."
                    )
                raise ValueError(
                    "output_shape must be provided if reference_raster is not provided and "
                    "cannot be calculated from bounds and pixel_size."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size

            left, bottom, right, top = bounds
            width = int((right - left) / x_res)
            height = int((top - bottom) / abs(y_res))
            # from_bounds divides by the grid size, so reject empty grids with a
            # clear message instead of a ZeroDivisionError.
            if width <= 0 or height <= 0:
                raise ValueError(
                    f"bounds {bounds} and pixel_size {pixel_size} yield an empty "
                    f"raster ({height} x {width} pixels)."
                )

            if transform is None:
                transform = rasterio.transform.from_bounds(
                    left, bottom, right, top, width, height
                )
            if output_shape is None:
                output_shape = (height, width)

    if crs is None:
        raise ValueError(
            "CRS must be provided either directly, from reference_raster, or from input vector data."
        )

    if gdf.crs is None:
        # Naive geometries cannot be reprojected; assume they already match.
        warnings.warn(
            "Input vector data has no CRS; assuming it is already in the output CRS."
        )
        gdf = gdf.set_crs(crs)
    elif gdf.crs != crs:
        print(f"Reprojecting vector data from {gdf.crs} to {crs}")
        gdf = gdf.to_crs(crs)

    raster_data = np.full(output_shape, fill_value, dtype=dtype)

    # rasterize() rejects null/empty geometries and raises when it receives no
    # valid shapes at all, so drop them up front.
    valid = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]

    if not valid.empty:
        if attribute_field is not None and attribute_field in valid.columns:
            shapes = list(zip(valid.geometry, valid[attribute_field]))
        else:
            if attribute_field is not None:
                warnings.warn(
                    f"Attribute field '{attribute_field}' not found in vector data; "
                    "burning all features with value 1."
                )
            shapes = [(geom, 1) for geom in valid.geometry]

        raster_data = features.rasterize(
            shapes=shapes,
            out_shape=output_shape,
            transform=transform,
            fill=fill_value,
            all_touched=all_touched,
            dtype=dtype,
        )

    if output_path is not None:
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        metadata = {
            "driver": "GTiff",
            "height": output_shape[0],
            "width": output_shape[1],
            "count": 1,
            "dtype": raster_data.dtype,
            "crs": crs,
            "transform": transform,
        }
        if nodata is not None:
            metadata["nodata"] = nodata

        with rasterio.open(output_path, "w", **metadata) as dst:
            dst.write(raster_data, 1)

        print(f"Rasterized data saved to {output_path}")

    if plot_result:
        _, ax = plt.subplots(figsize=(10, 10))

        # Raster extent in map coordinates from the grid corners.
        height, width = output_shape
        left, top = transform * (0, 0)
        right, bottom = transform * (width, height)

        # Pass extent so the raster is drawn in the same (map) coordinates as
        # the vector outlines below.
        im = ax.imshow(raster_data, cmap="viridis", extent=(left, right, bottom, top))
        plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")

        if not valid.empty:
            raster_bbox = box(
                min(left, right), min(bottom, top), max(left, right), max(bottom, top)
            )
            # Clip vectors to the raster extent for clarity in the plot.
            gdf_clipped = gpd.clip(valid, raster_bbox)
            if not gdf_clipped.empty:
                gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)

        plt.title("Rasterized Vector Data")
        plt.tight_layout()
        plt.show()

    return raster_data
494
+
495
+
496
def batch_vector_to_raster(
    vector_path,
    output_dir,
    attribute_field=None,
    reference_rasters=None,
    bounds_list=None,
    output_filename_pattern="{vector_name}_{index}",
    pixel_size=1.0,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
):
    """
    Batch convert vector data to multiple rasters based on different extents or reference rasters.

    Args:
        vector_path (str, os.PathLike or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_dir (str): Directory to save output raster files.
        attribute_field (str): Field name in the vector data to use for pixel values.
        reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
            Takes precedence over ``bounds_list`` when both are given.
        bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
        output_filename_pattern (str): Pattern for output filenames.
            Can include {vector_name} and {index} placeholders. ``.tif`` is
            appended unless the name already ends in ``.tif``/``.tiff``.
        pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.

    Returns:
        list: List of paths to the created raster files.

    Raises:
        ValueError: If neither reference_rasters nor bounds_list is provided.
    """
    # Validate before reading data or creating directories.
    if reference_rasters is None and bounds_list is None:
        raise ValueError("Either reference_rasters or bounds_list must be provided.")

    os.makedirs(output_dir, exist_ok=True)

    # Accept pathlib.Path as well as str; anything else is taken as a GeoDataFrame.
    if isinstance(vector_path, (str, os.PathLike)):
        gdf = gpd.read_file(vector_path)
        vector_name = os.path.splitext(os.path.basename(vector_path))[0]
    else:
        gdf = vector_path
        vector_name = "vector"

    use_reference = reference_rasters is not None
    sources = reference_rasters if use_reference else bounds_list

    output_files = []
    for i, source in enumerate(tqdm(sources, desc="Processing")):
        output_filename = output_filename_pattern.format(
            vector_name=vector_name, index=i
        )
        if not output_filename.lower().endswith((".tif", ".tiff")):
            output_filename += ".tif"
        output_path = os.path.join(output_dir, output_filename)

        # Each source defines the output grid either via a reference raster or
        # via bounds combined with the shared pixel size.
        if use_reference:
            grid_kwargs = {"reference_raster": source}
        else:
            grid_kwargs = {"bounds": source, "pixel_size": pixel_size}

        vector_to_raster(
            vector_path=gdf,
            output_path=output_path,
            attribute_field=attribute_field,
            all_touched=all_touched,
            fill_value=fill_value,
            dtype=dtype,
            nodata=nodata,
            **grid_kwargs,
        )

        output_files.append(output_path)

    return output_files
594
+
595
+
596
+ # # Example usage
597
+ # if __name__ == "__main__":
598
+ # # Single file conversion example
599
+ # raster_data = vector_to_raster(
600
+ # vector_path="buildings_train.geojson",
601
+ # output_path="buildings_train.tif",
602
+ # reference_raster="naip_train.tif", # Optional, can use other parameters instead
603
+ # # attribute_field="class", # Optional, uses field values for pixel values
604
+ # all_touched=True, # Ensures small features are captured
605
+ # plot_result=True
606
+ # )
607
+
608
+ # Example with custom dimensions
609
+ # raster_data = vector_to_raster(
610
+ # vector_path="path/to/buildings.geojson",
611
+ # output_path="path/to/rasterized_buildings.tif",
612
+ # pixel_size=0.5, # 0.5 meter resolution
613
+ # bounds=(454780, 5277567, 456282, 5278242), # from original data
614
+ # crs="EPSG:26911",
615
+ # output_shape=(1350, 3000), # custom dimensions
616
+ # attribute_field="class"
617
+ # )
618
+
619
+ # Batch conversion example
620
+ # output_files = batch_vector_to_raster(
621
+ # vector_path="path/to/buildings.geojson",
622
+ # output_dir="path/to/output",
623
+ # reference_rasters=["path/to/ref1.tif", "path/to/ref2.tif"],
624
+ # attribute_field="class",
625
+ # all_touched=True
626
+ # )
627
+
628
+
629
+ def export_geotiff_tiles(
630
+ in_raster,
631
+ out_folder,
632
+ in_class_data,
633
+ tile_size=256,
634
+ stride=128,
635
+ class_value_field="class",
636
+ buffer_radius=0,
637
+ max_tiles=None,
638
+ quiet=False,
639
+ all_touched=True,
640
+ create_overview=False,
641
+ skip_empty_tiles=False,
642
+ ):
643
+ """
644
+ Export georeferenced GeoTIFF tiles and labels from raster and classification data.
645
+
646
+ Args:
647
+ in_raster (str): Path to input raster image
648
+ out_folder (str): Path to output folder
649
+ in_class_data (str): Path to classification data - can be vector file or raster
650
+ tile_size (int): Size of tiles in pixels (square)
651
+ stride (int): Step size between tiles
652
+ class_value_field (str): Field containing class values (for vector data)
653
+ buffer_radius (float): Buffer to add around features (in units of the CRS)
654
+ max_tiles (int): Maximum number of tiles to process (None for all)
655
+ quiet (bool): If True, suppress non-essential output
656
+ all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
657
+ create_overview (bool): Whether to create an overview image of all tiles
658
+ skip_empty_tiles (bool): If True, skip tiles with no features
659
+ """
660
+ # Create output directories
661
+ os.makedirs(out_folder, exist_ok=True)
662
+ image_dir = os.path.join(out_folder, "images")
663
+ os.makedirs(image_dir, exist_ok=True)
664
+ label_dir = os.path.join(out_folder, "labels")
665
+ os.makedirs(label_dir, exist_ok=True)
666
+ ann_dir = os.path.join(out_folder, "annotations")
667
+ os.makedirs(ann_dir, exist_ok=True)
668
+
669
+ # Determine if class data is raster or vector
670
+ is_class_data_raster = False
671
+ if isinstance(in_class_data, str):
672
+ file_ext = Path(in_class_data).suffix.lower()
673
+ # Common raster extensions
674
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
675
+ try:
676
+ with rasterio.open(in_class_data) as src:
677
+ is_class_data_raster = True
678
+ if not quiet:
679
+ print(f"Detected in_class_data as raster: {in_class_data}")
680
+ print(f"Raster CRS: {src.crs}")
681
+ print(f"Raster dimensions: {src.width} x {src.height}")
682
+ except Exception:
683
+ is_class_data_raster = False
684
+ if not quiet:
685
+ print(f"Unable to open {in_class_data} as raster, trying as vector")
686
+
687
+ # Open the input raster
688
+ with rasterio.open(in_raster) as src:
689
+ if not quiet:
690
+ print(f"\nRaster info for {in_raster}:")
691
+ print(f" CRS: {src.crs}")
692
+ print(f" Dimensions: {src.width} x {src.height}")
693
+ print(f" Bounds: {src.bounds}")
694
+
695
+ # Calculate number of tiles
696
+ num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
697
+ num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
698
+ total_tiles = num_tiles_x * num_tiles_y
699
+
700
+ if max_tiles is None:
701
+ max_tiles = total_tiles
702
+
703
+ # Process classification data
704
+ class_to_id = {}
705
+
706
+ if is_class_data_raster:
707
+ # Load raster class data
708
+ with rasterio.open(in_class_data) as class_src:
709
+ # Check if raster CRS matches
710
+ if class_src.crs != src.crs:
711
+ warnings.warn(
712
+ f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
713
+ f"Results may be misaligned."
714
+ )
715
+
716
+ # Get unique values from raster
717
+ # Sample to avoid loading huge rasters
718
+ sample_data = class_src.read(
719
+ 1,
720
+ out_shape=(
721
+ 1,
722
+ min(class_src.height, 1000),
723
+ min(class_src.width, 1000),
724
+ ),
725
+ )
726
+
727
+ unique_classes = np.unique(sample_data)
728
+ unique_classes = unique_classes[
729
+ unique_classes > 0
730
+ ] # Remove 0 as it's typically background
731
+
732
+ if not quiet:
733
+ print(
734
+ f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
735
+ )
736
+
737
+ # Create class mapping
738
+ class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
739
+ else:
740
+ # Load vector class data
741
+ try:
742
+ gdf = gpd.read_file(in_class_data)
743
+ if not quiet:
744
+ print(f"Loaded {len(gdf)} features from {in_class_data}")
745
+ print(f"Vector CRS: {gdf.crs}")
746
+
747
+ # Always reproject to match raster CRS
748
+ if gdf.crs != src.crs:
749
+ if not quiet:
750
+ print(f"Reprojecting features from {gdf.crs} to {src.crs}")
751
+ gdf = gdf.to_crs(src.crs)
752
+
753
+ # Apply buffer if specified
754
+ if buffer_radius > 0:
755
+ gdf["geometry"] = gdf.buffer(buffer_radius)
756
+ if not quiet:
757
+ print(f"Applied buffer of {buffer_radius} units")
758
+
759
+ # Check if class_value_field exists
760
+ if class_value_field in gdf.columns:
761
+ unique_classes = gdf[class_value_field].unique()
762
+ if not quiet:
763
+ print(
764
+ f"Found {len(unique_classes)} unique classes: {unique_classes}"
765
+ )
766
+ # Create class mapping
767
+ class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
768
+ else:
769
+ if not quiet:
770
+ print(
771
+ f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
772
+ )
773
+ class_to_id = {1: 1} # Default mapping
774
+ except Exception as e:
775
+ raise ValueError(f"Error processing vector data: {e}")
776
+
777
+ # Create progress bar
778
+ pbar = tqdm(
779
+ total=min(total_tiles, max_tiles),
780
+ desc="Generating tiles",
781
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
782
+ )
783
+
784
+ # Track statistics for summary
785
+ stats = {
786
+ "total_tiles": 0,
787
+ "tiles_with_features": 0,
788
+ "feature_pixels": 0,
789
+ "errors": 0,
790
+ "tile_coordinates": [], # For overview image
791
+ }
792
+
793
+ # Process tiles
794
+ tile_index = 0
795
+ for y in range(num_tiles_y):
796
+ for x in range(num_tiles_x):
797
+ if tile_index >= max_tiles:
798
+ break
799
+
800
+ # Calculate window coordinates
801
+ window_x = x * stride
802
+ window_y = y * stride
803
+
804
+ # Adjust for edge cases
805
+ if window_x + tile_size > src.width:
806
+ window_x = src.width - tile_size
807
+ if window_y + tile_size > src.height:
808
+ window_y = src.height - tile_size
809
+
810
+ # Define window
811
+ window = Window(window_x, window_y, tile_size, tile_size)
812
+
813
+ # Get window transform and bounds
814
+ window_transform = src.window_transform(window)
815
+
816
+ # Calculate window bounds
817
+ minx = window_transform[2] # Upper left x
818
+ maxy = window_transform[5] # Upper left y
819
+ maxx = minx + tile_size * window_transform[0] # Add width
820
+ miny = maxy + tile_size * window_transform[4] # Add height
821
+
822
+ window_bounds = box(minx, miny, maxx, maxy)
823
+
824
+ # Store tile coordinates for overview
825
+ if create_overview:
826
+ stats["tile_coordinates"].append(
827
+ {
828
+ "index": tile_index,
829
+ "x": window_x,
830
+ "y": window_y,
831
+ "bounds": [minx, miny, maxx, maxy],
832
+ "has_features": False,
833
+ }
834
+ )
835
+
836
+ # Create label mask
837
+ label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
838
+ has_features = False
839
+
840
+ # Process classification data to create labels
841
+ if is_class_data_raster:
842
+ # For raster class data
843
+ with rasterio.open(in_class_data) as class_src:
844
+ # Calculate window in class raster
845
+ src_bounds = src.bounds
846
+ class_bounds = class_src.bounds
847
+
848
+ # Check if windows overlap
849
+ if (
850
+ src_bounds.left > class_bounds.right
851
+ or src_bounds.right < class_bounds.left
852
+ or src_bounds.bottom > class_bounds.top
853
+ or src_bounds.top < class_bounds.bottom
854
+ ):
855
+ warnings.warn(
856
+ "Class raster and input raster do not overlap."
857
+ )
858
+ else:
859
+ # Get corresponding window in class raster
860
+ window_class = rasterio.windows.from_bounds(
861
+ minx, miny, maxx, maxy, class_src.transform
862
+ )
863
+
864
+ # Read label data
865
+ try:
866
+ label_data = class_src.read(
867
+ 1,
868
+ window=window_class,
869
+ boundless=True,
870
+ out_shape=(tile_size, tile_size),
871
+ )
872
+
873
+ # Remap class values if needed
874
+ if class_to_id:
875
+ remapped_data = np.zeros_like(label_data)
876
+ for orig_val, new_val in class_to_id.items():
877
+ remapped_data[label_data == orig_val] = new_val
878
+ label_mask = remapped_data
879
+ else:
880
+ label_mask = label_data
881
+
882
+ # Check if we have any features
883
+ if np.any(label_mask > 0):
884
+ has_features = True
885
+ stats["feature_pixels"] += np.count_nonzero(
886
+ label_mask
887
+ )
888
+ except Exception as e:
889
+ pbar.write(f"Error reading class raster window: {e}")
890
+ stats["errors"] += 1
891
+ else:
892
+ # For vector class data
893
+ # Find features that intersect with window
894
+ window_features = gdf[gdf.intersects(window_bounds)]
895
+
896
+ if len(window_features) > 0:
897
+ for idx, feature in window_features.iterrows():
898
+ # Get class value
899
+ if class_value_field in feature:
900
+ class_val = feature[class_value_field]
901
+ class_id = class_to_id.get(class_val, 1)
902
+ else:
903
+ class_id = 1
904
+
905
+ # Get geometry in window coordinates
906
+ geom = feature.geometry.intersection(window_bounds)
907
+ if not geom.is_empty:
908
+ try:
909
+ # Rasterize feature
910
+ feature_mask = features.rasterize(
911
+ [(geom, class_id)],
912
+ out_shape=(tile_size, tile_size),
913
+ transform=window_transform,
914
+ fill=0,
915
+ all_touched=all_touched,
916
+ )
917
+
918
+ # Add to label mask
919
+ label_mask = np.maximum(label_mask, feature_mask)
920
+
921
+ # Check if the feature was actually rasterized
922
+ if np.any(feature_mask):
923
+ has_features = True
924
+ if create_overview and tile_index < len(
925
+ stats["tile_coordinates"]
926
+ ):
927
+ stats["tile_coordinates"][tile_index][
928
+ "has_features"
929
+ ] = True
930
+ except Exception as e:
931
+ pbar.write(f"Error rasterizing feature {idx}: {e}")
932
+ stats["errors"] += 1
933
+
934
+ # Skip tile if no features and skip_empty_tiles is True
935
+ if skip_empty_tiles and not has_features:
936
+ pbar.update(1)
937
+ tile_index += 1
938
+ continue
939
+
940
+ # Read image data
941
+ image_data = src.read(window=window)
942
+
943
+ # Export image as GeoTIFF
944
+ image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
945
+
946
+ # Create profile for image GeoTIFF
947
+ image_profile = src.profile.copy()
948
+ image_profile.update(
949
+ {
950
+ "height": tile_size,
951
+ "width": tile_size,
952
+ "count": image_data.shape[0],
953
+ "transform": window_transform,
954
+ }
955
+ )
956
+
957
+ # Save image as GeoTIFF
958
+ try:
959
+ with rasterio.open(image_path, "w", **image_profile) as dst:
960
+ dst.write(image_data)
961
+ stats["total_tiles"] += 1
962
+ except Exception as e:
963
+ pbar.write(f"ERROR saving image GeoTIFF: {e}")
964
+ stats["errors"] += 1
965
+
966
+ # Create profile for label GeoTIFF
967
+ label_profile = {
968
+ "driver": "GTiff",
969
+ "height": tile_size,
970
+ "width": tile_size,
971
+ "count": 1,
972
+ "dtype": "uint8",
973
+ "crs": src.crs,
974
+ "transform": window_transform,
975
+ }
976
+
977
+ # Export label as GeoTIFF
978
+ label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
979
+ try:
980
+ with rasterio.open(label_path, "w", **label_profile) as dst:
981
+ dst.write(label_mask.astype(np.uint8), 1)
982
+
983
+ if has_features:
984
+ stats["tiles_with_features"] += 1
985
+ stats["feature_pixels"] += np.count_nonzero(label_mask)
986
+ except Exception as e:
987
+ pbar.write(f"ERROR saving label GeoTIFF: {e}")
988
+ stats["errors"] += 1
989
+
990
+ # Create XML annotation for object detection if using vector class data
991
+ if (
992
+ not is_class_data_raster
993
+ and "gdf" in locals()
994
+ and len(window_features) > 0
995
+ ):
996
+ # Create XML annotation
997
+ root = ET.Element("annotation")
998
+ ET.SubElement(root, "folder").text = "images"
999
+ ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
1000
+
1001
+ size = ET.SubElement(root, "size")
1002
+ ET.SubElement(size, "width").text = str(tile_size)
1003
+ ET.SubElement(size, "height").text = str(tile_size)
1004
+ ET.SubElement(size, "depth").text = str(image_data.shape[0])
1005
+
1006
+ # Add georeference information
1007
+ geo = ET.SubElement(root, "georeference")
1008
+ ET.SubElement(geo, "crs").text = str(src.crs)
1009
+ ET.SubElement(geo, "transform").text = str(
1010
+ window_transform
1011
+ ).replace("\n", "")
1012
+ ET.SubElement(geo, "bounds").text = (
1013
+ f"{minx}, {miny}, {maxx}, {maxy}"
1014
+ )
1015
+
1016
+ # Add objects
1017
+ for idx, feature in window_features.iterrows():
1018
+ # Get feature class
1019
+ if class_value_field in feature:
1020
+ class_val = feature[class_value_field]
1021
+ else:
1022
+ class_val = "object"
1023
+
1024
+ # Get geometry bounds in pixel coordinates
1025
+ geom = feature.geometry.intersection(window_bounds)
1026
+ if not geom.is_empty:
1027
+ # Get bounds in world coordinates
1028
+ minx_f, miny_f, maxx_f, maxy_f = geom.bounds
1029
+
1030
+ # Convert to pixel coordinates
1031
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
1032
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
1033
+
1034
+ # Ensure coordinates are within tile bounds
1035
+ xmin = max(0, min(tile_size, int(col_min)))
1036
+ ymin = max(0, min(tile_size, int(row_min)))
1037
+ xmax = max(0, min(tile_size, int(col_max)))
1038
+ ymax = max(0, min(tile_size, int(row_max)))
1039
+
1040
+ # Only add if the box has non-zero area
1041
+ if xmax > xmin and ymax > ymin:
1042
+ obj = ET.SubElement(root, "object")
1043
+ ET.SubElement(obj, "name").text = str(class_val)
1044
+ ET.SubElement(obj, "difficult").text = "0"
1045
+
1046
+ bbox = ET.SubElement(obj, "bndbox")
1047
+ ET.SubElement(bbox, "xmin").text = str(xmin)
1048
+ ET.SubElement(bbox, "ymin").text = str(ymin)
1049
+ ET.SubElement(bbox, "xmax").text = str(xmax)
1050
+ ET.SubElement(bbox, "ymax").text = str(ymax)
1051
+
1052
+ # Save XML
1053
+ tree = ET.ElementTree(root)
1054
+ xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
1055
+ tree.write(xml_path)
1056
+
1057
+ # Update progress bar
1058
+ pbar.update(1)
1059
+ pbar.set_description(
1060
+ f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
1061
+ )
1062
+
1063
+ tile_index += 1
1064
+ if tile_index >= max_tiles:
1065
+ break
1066
+
1067
+ if tile_index >= max_tiles:
1068
+ break
1069
+
1070
+ # Close progress bar
1071
+ pbar.close()
1072
+
1073
+ # Create overview image if requested
1074
+ if create_overview and stats["tile_coordinates"]:
1075
+ try:
1076
+ create_overview_image(
1077
+ src,
1078
+ stats["tile_coordinates"],
1079
+ os.path.join(out_folder, "overview.png"),
1080
+ tile_size,
1081
+ stride,
1082
+ )
1083
+ except Exception as e:
1084
+ print(f"Failed to create overview image: {e}")
1085
+
1086
+ # Report results
1087
+ if not quiet:
1088
+ print("\n------- Export Summary -------")
1089
+ print(f"Total tiles exported: {stats['total_tiles']}")
1090
+ print(
1091
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
1092
+ )
1093
+ if stats["tiles_with_features"] > 0:
1094
+ print(
1095
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
1096
+ )
1097
+ if stats["errors"] > 0:
1098
+ print(f"Errors encountered: {stats['errors']}")
1099
+ print(f"Output saved to: {out_folder}")
1100
+
1101
+ # Verify georeference in a sample image and label
1102
+ if stats["total_tiles"] > 0:
1103
+ print("\n------- Georeference Verification -------")
1104
+ sample_image = os.path.join(image_dir, f"tile_0.tif")
1105
+ sample_label = os.path.join(label_dir, f"tile_0.tif")
1106
+
1107
+ if os.path.exists(sample_image):
1108
+ try:
1109
+ with rasterio.open(sample_image) as img:
1110
+ print(f"Image CRS: {img.crs}")
1111
+ print(f"Image transform: {img.transform}")
1112
+ print(
1113
+ f"Image has georeference: {img.crs is not None and img.transform is not None}"
1114
+ )
1115
+ print(
1116
+ f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
1117
+ )
1118
+ except Exception as e:
1119
+ print(f"Error verifying image georeference: {e}")
1120
+
1121
+ if os.path.exists(sample_label):
1122
+ try:
1123
+ with rasterio.open(sample_label) as lbl:
1124
+ print(f"Label CRS: {lbl.crs}")
1125
+ print(f"Label transform: {lbl.transform}")
1126
+ print(
1127
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
1128
+ )
1129
+ print(
1130
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
1131
+ )
1132
+ except Exception as e:
1133
+ print(f"Error verifying label georeference: {e}")
1134
+
1135
+ # Return statistics dictionary for further processing if needed
1136
+ return stats
1137
+
1138
+
1139
def create_overview_image(src, tile_coordinates, output_path, tile_size, stride):
    """Create an overview image showing all tiles and their status.

    The source raster is read at reduced resolution, each displayed band is
    contrast-stretched to the 2nd-98th percentile of its positive pixels, and
    every tile footprint is outlined on top: lime for tiles with features,
    red for empty tiles. The figure is saved to ``output_path``.

    Args:
        src: An open rasterio dataset used as the background image.
        tile_coordinates (list[dict]): One dict per tile. The keys read here are
            ``"x"`` and ``"y"`` (treated as the tile's upper-left pixel offset in
            ``src``), ``"index"`` (label drawn inside the tile) and
            ``"has_features"`` (selects the outline colour).
        output_path (str): Destination file passed to ``savefig``.
        tile_size (int): Tile edge length in source pixels.
        stride (int): Not used by this implementation; accepted so existing
            call sites keep working.
    """
    # Downsample so the longest side is at most ~2000 px. ceil() (rather than
    # floor) keeps the overview below that cap; max(1, ...) guards tiny rasters.
    overview_scale = max(1, math.ceil(max(src.width, src.height) / 2000))
    overview_width = max(1, src.width // overview_scale)
    overview_height = max(1, src.height // overview_scale)

    # Read downsampled image
    overview_data = src.read(
        out_shape=(src.count, overview_height, overview_width),
        resampling=rasterio.enums.Resampling.average,
    )

    # Build an RGB array in float32: the [0, 1] stretch below would be
    # truncated to 0/1 if written back into an integer raster dtype.
    if overview_data.shape[0] >= 3:
        rgb = np.moveaxis(overview_data[:3], 0, -1).astype(np.float32)
    else:
        # Single (or two-band) raster: replicate the first band as grayscale.
        rgb = np.repeat(
            overview_data[0][..., np.newaxis], 3, axis=-1
        ).astype(np.float32)

    # Per-band percentile stretch over positive pixels (zeros treated as nodata).
    for i in range(rgb.shape[-1]):
        band = rgb[..., i]
        positive = band[band > 0]
        if positive.size == 0:
            continue
        p2, p98 = np.percentile(positive, (2, 98))
        if p98 > p2:
            rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
        else:
            # Near-constant band: avoid dividing by zero; p98 > 0 here.
            rgb[..., i] = np.clip(band / p98, 0, 1)

    # imshow expects float RGB in [0, 1]; NaNs (e.g. float nodata) become black.
    rgb = np.nan_to_num(np.clip(rgb, 0, 1))

    # Tile footprint in overview pixels is the same for every tile.
    width = int(tile_size / overview_scale)
    height = int(tile_size / overview_scale)

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        ax.imshow(rgb)

        # Draw tile boundaries
        for tile in tile_coordinates:
            x_min = int(tile["x"] / overview_scale)
            y_min = int(tile["y"] / overview_scale)

            color = "lime" if tile["has_features"] else "red"
            rect = plt.Rectangle(
                (x_min, y_min),
                width,
                height,
                fill=False,
                edgecolor=color,
                linewidth=0.5,
            )
            ax.add_patch(rect)

            # Add tile number only when tiles are large enough to stay legible.
            if width > 20 and height > 20:
                ax.text(
                    x_min + width / 2,
                    y_min + height / 2,
                    str(tile["index"]),
                    color="white",
                    ha="center",
                    va="center",
                    fontsize=8,
                )

        ax.set_title("Tile Overview (Green = Contains Features, Red = Empty)")
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Overview image saved to {output_path}")
1209
+
1210
+
1211
+ # # Example usage
1212
+ # if __name__ == "__main__":
1213
+ # # Try to install tqdm if not available
1214
+ # try:
1215
+ # import tqdm
1216
+ # except ImportError:
1217
+ # print("Installing tqdm progress bar library...")
1218
+ # import sys
1219
+ # import subprocess
1220
+
1221
+ # subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
1222
+ # import tqdm
1223
+
1224
+ # # Example with vector class data
1225
+ # export_geotiff_tiles(
1226
+ # in_raster="naip_train.tif",
1227
+ # out_folder="geotiff_output_vector",
1228
+ # in_class_data="buildings_train.geojson",
1229
+ # tile_size=256,
1230
+ # stride=128,
1231
+ # class_value_field="class",
1232
+ # buffer_radius=2,
1233
+ # create_overview=True,
1234
+ # )
1235
+
1236
+ # # Example with raster class data
1237
+ # export_geotiff_tiles(
1238
+ # in_raster="naip_train.tif",
1239
+ # out_folder="geotiff_output_raster",
1240
+ # in_class_data="buildings_train.tif", # This would be a raster mask
1241
+ # tile_size=256,
1242
+ # stride=128,
1243
+ # create_overview=True,
1244
+ # skip_empty_tiles=True,
1245
+ # )
1246
+
1247
+
1248
+ def export_training_data(
1249
+ in_raster,
1250
+ out_folder,
1251
+ in_class_data,
1252
+ image_chip_format="GEOTIFF",
1253
+ tile_size_x=256,
1254
+ tile_size_y=256,
1255
+ stride_x=None,
1256
+ stride_y=None,
1257
+ output_nofeature_tiles=True,
1258
+ metadata_format="PASCAL_VOC",
1259
+ start_index=0,
1260
+ class_value_field="class",
1261
+ buffer_radius=0,
1262
+ in_mask_polygons=None,
1263
+ rotation_angle=0,
1264
+ reference_system=None,
1265
+ blacken_around_feature=False,
1266
+ crop_mode="FIXED_SIZE", # Implemented but not fully used yet
1267
+ in_raster2=None,
1268
+ in_instance_data=None,
1269
+ instance_class_value_field=None, # Implemented but not fully used yet
1270
+ min_polygon_overlap_ratio=0.0,
1271
+ all_touched=True,
1272
+ save_geotiff=True,
1273
+ quiet=False,
1274
+ ):
1275
+ """
1276
+ Export training data for deep learning using TorchGeo with progress bar.
1277
+
1278
+ Args:
1279
+ in_raster (str): Path to input raster image.
1280
+ out_folder (str): Output folder path where chips and labels will be saved.
1281
+ in_class_data (str): Path to vector file containing class polygons.
1282
+ image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
1283
+ tile_size_x (int): Width of image chips in pixels.
1284
+ tile_size_y (int): Height of image chips in pixels.
1285
+ stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
1286
+ stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
1287
+ output_nofeature_tiles (bool): Whether to export chips without features.
1288
+ metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
1289
+ start_index (int): Starting index for chip filenames.
1290
+ class_value_field (str): Field name in in_class_data containing class values.
1291
+ buffer_radius (float): Buffer radius around features (in CRS units).
1292
+ in_mask_polygons (str): Path to vector file containing mask polygons.
1293
+ rotation_angle (float): Rotation angle in degrees.
1294
+ reference_system (str): Reference system code.
1295
+ blacken_around_feature (bool): Whether to mask areas outside of features.
1296
+ crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
1297
+ in_raster2 (str): Path to secondary raster image.
1298
+ in_instance_data (str): Path to vector file containing instance polygons.
1299
+ instance_class_value_field (str): Field name in in_instance_data for instance classes.
1300
+ min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
1301
+ all_touched (bool): Whether to use all_touched=True in rasterization.
1302
+ save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
1303
+ quiet (bool): If True, suppress most output messages.
1304
+ """
1305
+ # Create output directories
1306
+ image_dir = os.path.join(out_folder, "images")
1307
+ os.makedirs(image_dir, exist_ok=True)
1308
+
1309
+ label_dir = os.path.join(out_folder, "labels")
1310
+ os.makedirs(label_dir, exist_ok=True)
1311
+
1312
+ # Define annotation directories based on metadata format
1313
+ if metadata_format == "PASCAL_VOC":
1314
+ ann_dir = os.path.join(out_folder, "annotations")
1315
+ os.makedirs(ann_dir, exist_ok=True)
1316
+ elif metadata_format == "COCO":
1317
+ ann_dir = os.path.join(out_folder, "annotations")
1318
+ os.makedirs(ann_dir, exist_ok=True)
1319
+ # Initialize COCO annotations dictionary
1320
+ coco_annotations = {"images": [], "annotations": [], "categories": []}
1321
+
1322
+ # Initialize statistics dictionary
1323
+ stats = {
1324
+ "total_tiles": 0,
1325
+ "tiles_with_features": 0,
1326
+ "feature_pixels": 0,
1327
+ "errors": 0,
1328
+ }
1329
+
1330
+ # Open raster
1331
+ with rasterio.open(in_raster) as src:
1332
+ if not quiet:
1333
+ print(f"\nRaster info for {in_raster}:")
1334
+ print(f" CRS: {src.crs}")
1335
+ print(f" Dimensions: {src.width} x {src.height}")
1336
+ print(f" Bounds: {src.bounds}")
1337
+
1338
+ # Set defaults for stride if not provided
1339
+ if stride_x is None:
1340
+ stride_x = tile_size_x
1341
+ if stride_y is None:
1342
+ stride_y = tile_size_y
1343
+
1344
+ # Calculate number of tiles in x and y directions
1345
+ num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
1346
+ num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
1347
+ total_tiles = num_tiles_x * num_tiles_y
1348
+
1349
+ # Read class data
1350
+ gdf = gpd.read_file(in_class_data)
1351
+ if not quiet:
1352
+ print(f"Loaded {len(gdf)} features from {in_class_data}")
1353
+ print(f"Available columns: {gdf.columns.tolist()}")
1354
+ print(f"GeoJSON CRS: {gdf.crs}")
1355
+
1356
+ # Check if class_value_field exists
1357
+ if class_value_field not in gdf.columns:
1358
+ if not quiet:
1359
+ print(
1360
+ f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
1361
+ )
1362
+ # Add a default class column
1363
+ gdf[class_value_field] = 1
1364
+ unique_classes = [1]
1365
+ else:
1366
+ # Print unique classes for debugging
1367
+ unique_classes = gdf[class_value_field].unique()
1368
+ if not quiet:
1369
+ print(f"Found {len(unique_classes)} unique classes: {unique_classes}")
1370
+
1371
+ # CRITICAL: Always reproject to match raster CRS to ensure proper alignment
1372
+ if gdf.crs != src.crs:
1373
+ if not quiet:
1374
+ print(f"Reprojecting features from {gdf.crs} to {src.crs}")
1375
+ gdf = gdf.to_crs(src.crs)
1376
+ elif reference_system and gdf.crs != reference_system:
1377
+ if not quiet:
1378
+ print(
1379
+ f"Reprojecting features to specified reference system {reference_system}"
1380
+ )
1381
+ gdf = gdf.to_crs(reference_system)
1382
+
1383
+ # Check overlap between raster and vector data
1384
+ raster_bounds = box(*src.bounds)
1385
+ vector_bounds = box(*gdf.total_bounds)
1386
+ if not raster_bounds.intersects(vector_bounds):
1387
+ if not quiet:
1388
+ print(
1389
+ "WARNING: The vector data doesn't intersect with the raster extent!"
1390
+ )
1391
+ print(f"Raster bounds: {src.bounds}")
1392
+ print(f"Vector bounds: {gdf.total_bounds}")
1393
+ else:
1394
+ overlap = (
1395
+ raster_bounds.intersection(vector_bounds).area / vector_bounds.area
1396
+ )
1397
+ if not quiet:
1398
+ print(f"Overlap between raster and vector: {overlap:.2%}")
1399
+
1400
+ # Apply buffer if specified
1401
+ if buffer_radius > 0:
1402
+ gdf["geometry"] = gdf.buffer(buffer_radius)
1403
+
1404
+ # Initialize class mapping (ensure all classes are mapped to non-zero values)
1405
+ class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
1406
+
1407
+ # Store category info for COCO format
1408
+ if metadata_format == "COCO":
1409
+ for cls_val in unique_classes:
1410
+ coco_annotations["categories"].append(
1411
+ {
1412
+ "id": class_to_id[cls_val],
1413
+ "name": str(cls_val),
1414
+ "supercategory": "object",
1415
+ }
1416
+ )
1417
+
1418
+ # Load mask polygons if provided
1419
+ mask_gdf = None
1420
+ if in_mask_polygons:
1421
+ mask_gdf = gpd.read_file(in_mask_polygons)
1422
+ if reference_system:
1423
+ mask_gdf = mask_gdf.to_crs(reference_system)
1424
+ elif mask_gdf.crs != src.crs:
1425
+ mask_gdf = mask_gdf.to_crs(src.crs)
1426
+
1427
+ # Process instance data if provided
1428
+ instance_gdf = None
1429
+ if in_instance_data:
1430
+ instance_gdf = gpd.read_file(in_instance_data)
1431
+ if reference_system:
1432
+ instance_gdf = instance_gdf.to_crs(reference_system)
1433
+ elif instance_gdf.crs != src.crs:
1434
+ instance_gdf = instance_gdf.to_crs(src.crs)
1435
+
1436
+ # Load secondary raster if provided
1437
+ src2 = None
1438
+ if in_raster2:
1439
+ src2 = rasterio.open(in_raster2)
1440
+
1441
+ # Set up augmentation if rotation is specified
1442
+ augmentation = None
1443
+ if rotation_angle != 0:
1444
+ # Fixed: Added data_keys parameter to AugmentationSequential
1445
+ augmentation = torchgeo.transforms.AugmentationSequential(
1446
+ torch.nn.ModuleList([RandomRotation(rotation_angle)]),
1447
+ data_keys=["image"], # Add data_keys parameter
1448
+ )
1449
+
1450
+ # Initialize annotation ID for COCO format
1451
+ ann_id = 0
1452
+
1453
+ # Create progress bar
1454
+ pbar = tqdm(
1455
+ total=total_tiles,
1456
+ desc=f"Generating tiles (with features: 0)",
1457
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
1458
+ )
1459
+
1460
+ # Generate tiles
1461
+ chip_index = start_index
1462
+ for y in range(num_tiles_y):
1463
+ for x in range(num_tiles_x):
1464
+ # Calculate window coordinates
1465
+ window_x = x * stride_x
1466
+ window_y = y * stride_y
1467
+
1468
+ # Adjust for edge cases
1469
+ if window_x + tile_size_x > src.width:
1470
+ window_x = src.width - tile_size_x
1471
+ if window_y + tile_size_y > src.height:
1472
+ window_y = src.height - tile_size_y
1473
+
1474
+ # Adjust window based on crop_mode
1475
+ if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
1476
+ # Find the nearest feature to the center of this window
1477
+ window_center_x = window_x + tile_size_x // 2
1478
+ window_center_y = window_y + tile_size_y // 2
1479
+
1480
+ # Convert center to world coordinates
1481
+ center_x, center_y = src.xy(window_center_y, window_center_x)
1482
+ center_point = gpd.points_from_xy([center_x], [center_y])[0]
1483
+
1484
+ # Find nearest feature
1485
+ distances = gdf.geometry.distance(center_point)
1486
+ nearest_idx = distances.idxmin()
1487
+ nearest_feature = gdf.iloc[nearest_idx]
1488
+
1489
+ # Get centroid of nearest feature
1490
+ feature_centroid = nearest_feature.geometry.centroid
1491
+
1492
+ # Convert feature centroid to pixel coordinates
1493
+ feature_row, feature_col = src.index(
1494
+ feature_centroid.x, feature_centroid.y
1495
+ )
1496
+
1497
+ # Adjust window to center on feature
1498
+ window_x = max(
1499
+ 0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
1500
+ )
1501
+ window_y = max(
1502
+ 0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
1503
+ )
1504
+
1505
+ # Define window
1506
+ window = Window(window_x, window_y, tile_size_x, tile_size_y)
1507
+
1508
+ # Get window transform and bounds in source CRS
1509
+ window_transform = src.window_transform(window)
1510
+
1511
+ # Calculate window bounds more explicitly and accurately
1512
+ minx = window_transform[2] # Upper left x
1513
+ maxy = window_transform[5] # Upper left y
1514
+ maxx = minx + tile_size_x * window_transform[0] # Add width
1515
+ miny = (
1516
+ maxy + tile_size_y * window_transform[4]
1517
+ ) # Add height (note: transform[4] is typically negative)
1518
+
1519
+ window_bounds = box(minx, miny, maxx, maxy)
1520
+
1521
+ # Apply rotation if specified
1522
+ if rotation_angle != 0:
1523
+ window_bounds = rotate(
1524
+ window_bounds, rotation_angle, origin="center"
1525
+ )
1526
+
1527
+ # Find features that intersect with window
1528
+ window_features = gdf[gdf.intersects(window_bounds)]
1529
+
1530
+ # Process instance data if provided
1531
+ window_instances = None
1532
+ if instance_gdf is not None and instance_class_value_field is not None:
1533
+ window_instances = instance_gdf[
1534
+ instance_gdf.intersects(window_bounds)
1535
+ ]
1536
+ if len(window_instances) > 0:
1537
+ if not quiet:
1538
+ pbar.write(
1539
+ f"Found {len(window_instances)} instances in tile {chip_index}"
1540
+ )
1541
+
1542
+ # Skip if no features and output_nofeature_tiles is False
1543
+ if not output_nofeature_tiles and len(window_features) == 0:
1544
+ pbar.update(1) # Still update progress bar
1545
+ continue
1546
+
1547
+ # Check polygon overlap ratio if specified
1548
+ if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
1549
+ valid_features = []
1550
+ for _, feature in window_features.iterrows():
1551
+ overlap_ratio = (
1552
+ feature.geometry.intersection(window_bounds).area
1553
+ / feature.geometry.area
1554
+ )
1555
+ if overlap_ratio >= min_polygon_overlap_ratio:
1556
+ valid_features.append(feature)
1557
+
1558
+ if len(valid_features) > 0:
1559
+ window_features = gpd.GeoDataFrame(valid_features)
1560
+ elif not output_nofeature_tiles:
1561
+ pbar.update(1) # Still update progress bar
1562
+ continue
1563
+
1564
+ # Apply mask if provided
1565
+ if mask_gdf is not None:
1566
+ mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
1567
+ if len(mask_features) == 0:
1568
+ pbar.update(1) # Still update progress bar
1569
+ continue
1570
+
1571
+ # Read image data - keep original for GeoTIFF export
1572
+ orig_image_data = src.read(window=window)
1573
+
1574
+ # Create a copy for processing
1575
+ image_data = orig_image_data.copy().astype(np.float32)
1576
+
1577
+ # Normalize image data for processing
1578
+ for band in range(image_data.shape[0]):
1579
+ band_min, band_max = np.percentile(image_data[band], (1, 99))
1580
+ if band_max > band_min:
1581
+ image_data[band] = np.clip(
1582
+ (image_data[band] - band_min) / (band_max - band_min), 0, 1
1583
+ )
1584
+
1585
+ # Read secondary image data if provided
1586
+ if src2:
1587
+ image_data2 = src2.read(window=window)
1588
+ # Stack the two images
1589
+ image_data = np.vstack((image_data, image_data2))
1590
+
1591
+ # Apply blacken_around_feature if needed
1592
+ if blacken_around_feature and len(window_features) > 0:
1593
+ mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
1594
+ for _, feature in window_features.iterrows():
1595
+ # Project feature to pixel coordinates
1596
+ feature_pixels = features.rasterize(
1597
+ [(feature.geometry, 1)],
1598
+ out_shape=(tile_size_y, tile_size_x),
1599
+ transform=window_transform,
1600
+ )
1601
+ mask = np.logical_or(mask, feature_pixels.astype(bool))
1602
+
1603
+ # Apply mask to image
1604
+ for band in range(image_data.shape[0]):
1605
+ temp = image_data[band, :, :]
1606
+ temp[~mask] = 0
1607
+ image_data[band, :, :] = temp
1608
+
1609
+ # Apply rotation if specified
1610
+ if augmentation:
1611
+ # Convert to torch tensor for augmentation
1612
+ image_tensor = torch.from_numpy(image_data).unsqueeze(
1613
+ 0
1614
+ ) # Add batch dimension
1615
+ # Apply augmentation with proper data format
1616
+ augmented = augmentation({"image": image_tensor})
1617
+ image_data = (
1618
+ augmented["image"].squeeze(0).numpy()
1619
+ ) # Remove batch dimension
1620
+
1621
+ # Create a processed version for regular image formats
1622
+ processed_image = (image_data * 255).astype(np.uint8)
1623
+
1624
+ # Create label mask
1625
+ label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
1626
+ has_features = False
1627
+
1628
+ if len(window_features) > 0:
1629
+ for idx, feature in window_features.iterrows():
1630
+ # Get class value
1631
+ class_val = (
1632
+ feature[class_value_field]
1633
+ if class_value_field in feature
1634
+ else 1
1635
+ )
1636
+ if isinstance(class_val, str):
1637
+ # If class is a string, use its position in the unique classes list
1638
+ class_id = class_to_id.get(class_val, 1)
1639
+ else:
1640
+ # If class is already a number, use it directly
1641
+ class_id = int(class_val) if class_val > 0 else 1
1642
+
1643
+ # Get the geometry in pixel coordinates
1644
+ geom = feature.geometry.intersection(window_bounds)
1645
+ if not geom.is_empty:
1646
+ try:
1647
+ # Rasterize the feature
1648
+ feature_mask = features.rasterize(
1649
+ [(geom, class_id)],
1650
+ out_shape=(tile_size_y, tile_size_x),
1651
+ transform=window_transform,
1652
+ fill=0,
1653
+ all_touched=all_touched,
1654
+ )
1655
+
1656
+ # Update mask with higher class values taking precedence
1657
+ label_mask = np.maximum(label_mask, feature_mask)
1658
+
1659
+ # Check if any pixels were added
1660
+ if np.any(feature_mask):
1661
+ has_features = True
1662
+ except Exception as e:
1663
+ if not quiet:
1664
+ pbar.write(f"Error rasterizing feature {idx}: {e}")
1665
+ stats["errors"] += 1
1666
+
1667
+ # Save as GeoTIFF if requested
1668
+ if save_geotiff or image_chip_format.upper() in [
1669
+ "TIFF",
1670
+ "TIF",
1671
+ "GEOTIFF",
1672
+ ]:
1673
+ # Standardize extension to .tif for GeoTIFF files
1674
+ image_filename = f"tile_{chip_index:06d}.tif"
1675
+ image_path = os.path.join(image_dir, image_filename)
1676
+
1677
+ # Create profile for the GeoTIFF
1678
+ profile = src.profile.copy()
1679
+ profile.update(
1680
+ {
1681
+ "height": tile_size_y,
1682
+ "width": tile_size_x,
1683
+ "count": orig_image_data.shape[0],
1684
+ "transform": window_transform,
1685
+ }
1686
+ )
1687
+
1688
+ # Save the GeoTIFF with original data
1689
+ try:
1690
+ with rasterio.open(image_path, "w", **profile) as dst:
1691
+ dst.write(orig_image_data)
1692
+ stats["total_tiles"] += 1
1693
+ except Exception as e:
1694
+ if not quiet:
1695
+ pbar.write(
1696
+ f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
1697
+ )
1698
+ stats["errors"] += 1
1699
+ else:
1700
+ # For non-GeoTIFF formats, use PIL to save the image
1701
+ image_filename = (
1702
+ f"tile_{chip_index:06d}.{image_chip_format.lower()}"
1703
+ )
1704
+ image_path = os.path.join(image_dir, image_filename)
1705
+
1706
+ # Create PIL image for saving
1707
+ if processed_image.shape[0] == 1:
1708
+ img = Image.fromarray(processed_image[0])
1709
+ elif processed_image.shape[0] == 3:
1710
+ # For RGB, need to transpose and make sure it's the right data type
1711
+ rgb_data = np.transpose(processed_image, (1, 2, 0))
1712
+ img = Image.fromarray(rgb_data)
1713
+ else:
1714
+ # For multiband images, save only RGB or first three bands
1715
+ rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
1716
+ img = Image.fromarray(rgb_data)
1717
+
1718
+ # Save image
1719
+ try:
1720
+ img.save(image_path)
1721
+ stats["total_tiles"] += 1
1722
+ except Exception as e:
1723
+ if not quiet:
1724
+ pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
1725
+ stats["errors"] += 1
1726
+
1727
+ # Save label as GeoTIFF
1728
+ label_filename = f"tile_{chip_index:06d}.tif"
1729
+ label_path = os.path.join(label_dir, label_filename)
1730
+
1731
+ # Create profile for label GeoTIFF
1732
+ label_profile = {
1733
+ "driver": "GTiff",
1734
+ "height": tile_size_y,
1735
+ "width": tile_size_x,
1736
+ "count": 1,
1737
+ "dtype": "uint8",
1738
+ "crs": src.crs,
1739
+ "transform": window_transform,
1740
+ }
1741
+
1742
+ # Save label GeoTIFF
1743
+ try:
1744
+ with rasterio.open(label_path, "w", **label_profile) as dst:
1745
+ dst.write(label_mask, 1)
1746
+
1747
+ if has_features:
1748
+ pixel_count = np.count_nonzero(label_mask)
1749
+ stats["tiles_with_features"] += 1
1750
+ stats["feature_pixels"] += pixel_count
1751
+ except Exception as e:
1752
+ if not quiet:
1753
+ pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
1754
+ stats["errors"] += 1
1755
+
1756
+ # Also save a PNG version for easy visualization if requested
1757
+ if metadata_format == "PASCAL_VOC":
1758
+ try:
1759
+ # Ensure correct data type for PIL
1760
+ png_label = label_mask.astype(np.uint8)
1761
+ label_img = Image.fromarray(png_label)
1762
+ label_png_path = os.path.join(
1763
+ label_dir, f"tile_{chip_index:06d}.png"
1764
+ )
1765
+ label_img.save(label_png_path)
1766
+ except Exception as e:
1767
+ if not quiet:
1768
+ pbar.write(
1769
+ f"ERROR saving PNG label for tile {chip_index}: {e}"
1770
+ )
1771
+ pbar.write(
1772
+ f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
1773
+ )
1774
+ # Try again with explicit conversion
1775
+ try:
1776
+ # Alternative approach for problematic arrays
1777
+ png_data = np.zeros(
1778
+ (tile_size_y, tile_size_x), dtype=np.uint8
1779
+ )
1780
+ np.copyto(png_data, label_mask, casting="unsafe")
1781
+ label_img = Image.fromarray(png_data)
1782
+ label_img.save(label_png_path)
1783
+ pbar.write(
1784
+ f" Succeeded using alternative conversion method"
1785
+ )
1786
+ except Exception as e2:
1787
+ pbar.write(f" Second attempt also failed: {e2}")
1788
+ stats["errors"] += 1
1789
+
1790
+ # Generate annotations
1791
+ if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
1792
+ # Create XML annotation
1793
+ root = ET.Element("annotation")
1794
+ ET.SubElement(root, "folder").text = "images"
1795
+ ET.SubElement(root, "filename").text = image_filename
1796
+
1797
+ size = ET.SubElement(root, "size")
1798
+ ET.SubElement(size, "width").text = str(tile_size_x)
1799
+ ET.SubElement(size, "height").text = str(tile_size_y)
1800
+ ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))
1801
+
1802
+ # Add georeference information
1803
+ geo = ET.SubElement(root, "georeference")
1804
+ ET.SubElement(geo, "crs").text = str(src.crs)
1805
+ ET.SubElement(geo, "transform").text = str(
1806
+ window_transform
1807
+ ).replace("\n", "")
1808
+ ET.SubElement(geo, "bounds").text = (
1809
+ f"{minx}, {miny}, {maxx}, {maxy}"
1810
+ )
1811
+
1812
+ for _, feature in window_features.iterrows():
1813
+ # Convert feature geometry to pixel coordinates
1814
+ feature_bounds = feature.geometry.intersection(window_bounds)
1815
+ if feature_bounds.is_empty:
1816
+ continue
1817
+
1818
+ # Get pixel coordinates of bounds
1819
+ minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
1820
+
1821
+ # Convert to pixel coordinates
1822
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
1823
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
1824
+
1825
+ # Ensure coordinates are within bounds
1826
+ xmin = max(0, min(tile_size_x, int(col_min)))
1827
+ ymin = max(0, min(tile_size_y, int(row_min)))
1828
+ xmax = max(0, min(tile_size_x, int(col_max)))
1829
+ ymax = max(0, min(tile_size_y, int(row_max)))
1830
+
1831
+ # Skip if box is too small
1832
+ if xmax - xmin < 1 or ymax - ymin < 1:
1833
+ continue
1834
+
1835
+ obj = ET.SubElement(root, "object")
1836
+ ET.SubElement(obj, "name").text = str(
1837
+ feature[class_value_field]
1838
+ )
1839
+ ET.SubElement(obj, "difficult").text = "0"
1840
+
1841
+ bbox = ET.SubElement(obj, "bndbox")
1842
+ ET.SubElement(bbox, "xmin").text = str(xmin)
1843
+ ET.SubElement(bbox, "ymin").text = str(ymin)
1844
+ ET.SubElement(bbox, "xmax").text = str(xmax)
1845
+ ET.SubElement(bbox, "ymax").text = str(ymax)
1846
+
1847
+ # Save XML
1848
+ try:
1849
+ tree = ET.ElementTree(root)
1850
+ xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
1851
+ tree.write(xml_path)
1852
+ except Exception as e:
1853
+ if not quiet:
1854
+ pbar.write(
1855
+ f"ERROR saving XML annotation for tile {chip_index}: {e}"
1856
+ )
1857
+ stats["errors"] += 1
1858
+
1859
+ elif metadata_format == "COCO" and len(window_features) > 0:
1860
+ # Add image info
1861
+ image_id = chip_index
1862
+ coco_annotations["images"].append(
1863
+ {
1864
+ "id": image_id,
1865
+ "file_name": image_filename,
1866
+ "width": tile_size_x,
1867
+ "height": tile_size_y,
1868
+ "crs": str(src.crs),
1869
+ "transform": str(window_transform),
1870
+ }
1871
+ )
1872
+
1873
+ # Add annotations for each feature
1874
+ for _, feature in window_features.iterrows():
1875
+ feature_bounds = feature.geometry.intersection(window_bounds)
1876
+ if feature_bounds.is_empty:
1877
+ continue
1878
+
1879
+ # Get pixel coordinates of bounds
1880
+ minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
1881
+
1882
+ # Convert to pixel coordinates
1883
+ col_min, row_min = ~window_transform * (minx_f, maxy_f)
1884
+ col_max, row_max = ~window_transform * (maxx_f, miny_f)
1885
+
1886
+ # Ensure coordinates are within bounds
1887
+ xmin = max(0, min(tile_size_x, int(col_min)))
1888
+ ymin = max(0, min(tile_size_y, int(row_min)))
1889
+ xmax = max(0, min(tile_size_x, int(col_max)))
1890
+ ymax = max(0, min(tile_size_y, int(row_max)))
1891
+
1892
+ # Skip if box is too small
1893
+ if xmax - xmin < 1 or ymax - ymin < 1:
1894
+ continue
1895
+
1896
+ width = xmax - xmin
1897
+ height = ymax - ymin
1898
+
1899
+ # Add annotation
1900
+ ann_id += 1
1901
+ category_id = class_to_id[feature[class_value_field]]
1902
+
1903
+ coco_annotations["annotations"].append(
1904
+ {
1905
+ "id": ann_id,
1906
+ "image_id": image_id,
1907
+ "category_id": category_id,
1908
+ "bbox": [xmin, ymin, width, height],
1909
+ "area": width * height,
1910
+ "iscrowd": 0,
1911
+ }
1912
+ )
1913
+
1914
+ # Update progress bar
1915
+ pbar.update(1)
1916
+ pbar.set_description(
1917
+ f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
1918
+ )
1919
+
1920
+ chip_index += 1
1921
+
1922
+ # Close progress bar
1923
+ pbar.close()
1924
+
1925
+ # Save COCO annotations if applicable
1926
+ if metadata_format == "COCO":
1927
+ try:
1928
+ with open(os.path.join(ann_dir, "instances.json"), "w") as f:
1929
+ json.dump(coco_annotations, f)
1930
+ except Exception as e:
1931
+ if not quiet:
1932
+ print(f"ERROR saving COCO annotations: {e}")
1933
+ stats["errors"] += 1
1934
+
1935
+ # Close secondary raster if opened
1936
+ if src2:
1937
+ src2.close()
1938
+
1939
+ # Print summary
1940
+ if not quiet:
1941
+ print("\n------- Export Summary -------")
1942
+ print(f"Total tiles exported: {stats['total_tiles']}")
1943
+ print(
1944
+ f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
1945
+ )
1946
+ if stats["tiles_with_features"] > 0:
1947
+ print(
1948
+ f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
1949
+ )
1950
+ if stats["errors"] > 0:
1951
+ print(f"Errors encountered: {stats['errors']}")
1952
+ print(f"Output saved to: {out_folder}")
1953
+
1954
+ # Verify georeference in a sample image and label
1955
+ if stats["total_tiles"] > 0:
1956
+ print("\n------- Georeference Verification -------")
1957
+ sample_image = os.path.join(image_dir, f"tile_{start_index}.tif")
1958
+ sample_label = os.path.join(label_dir, f"tile_{start_index}.tif")
1959
+
1960
+ if os.path.exists(sample_image):
1961
+ try:
1962
+ with rasterio.open(sample_image) as img:
1963
+ print(f"Image CRS: {img.crs}")
1964
+ print(f"Image transform: {img.transform}")
1965
+ print(
1966
+ f"Image has georeference: {img.crs is not None and img.transform is not None}"
1967
+ )
1968
+ print(
1969
+ f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
1970
+ )
1971
+ except Exception as e:
1972
+ print(f"Error verifying image georeference: {e}")
1973
+
1974
+ if os.path.exists(sample_label):
1975
+ try:
1976
+ with rasterio.open(sample_label) as lbl:
1977
+ print(f"Label CRS: {lbl.crs}")
1978
+ print(f"Label transform: {lbl.transform}")
1979
+ print(
1980
+ f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
1981
+ )
1982
+ print(
1983
+ f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
1984
+ )
1985
+ except Exception as e:
1986
+ print(f"Error verifying label georeference: {e}")
1987
+
1988
+ # Return statistics
1989
+ return stats, out_folder
1990
+
1991
+
1992
+ # if __name__ == "__main__":
1993
+ # # Example parameters
1994
+ # export_training_data(
1995
+ # in_raster="naip_train.tif",
1996
+ # out_folder="output",
1997
+ # in_class_data="buildings_train.geojson",
1998
+ # image_chip_format="GEOTIFF", # Use GeoTIFF format to preserve georeference
1999
+ # tile_size_x=256,
2000
+ # tile_size_y=256,
2001
+ # stride_x=128, # Use overlapping tiles to increase chance of capturing features
2002
+ # stride_y=128,
2003
+ # metadata_format="PASCAL_VOC",
2004
+ # class_value_field="class",
2005
+ # buffer_radius=2, # Add small buffer to buildings to ensure they're captured
2006
+ # all_touched=True, # Ensure small features are rasterized
2007
+ # save_geotiff=True, # Always save as GeoTIFF regardless of image_chip_format
2008
+ # )