geoai-py 0.1.7__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +28 -1
- geoai/common.py +158 -1
- geoai/download.py +9 -0
- geoai/extract.py +832 -0
- geoai/preprocess.py +2008 -0
- geoai_py-0.2.1.dist-info/METADATA +136 -0
- geoai_py-0.2.1.dist-info/RECORD +13 -0
- geoai_py-0.1.7.dist-info/METADATA +0 -51
- geoai_py-0.1.7.dist-info/RECORD +0 -11
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/LICENSE +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/top_level.txt +0 -0
geoai/preprocess.py
ADDED
|
@@ -0,0 +1,2008 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import math
|
|
3
|
+
import os
|
|
4
|
+
from PIL import Image
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import warnings
|
|
7
|
+
import xml.etree.ElementTree as ET
|
|
8
|
+
import numpy as np
|
|
9
|
+
import rasterio
|
|
10
|
+
import geopandas as gpd
|
|
11
|
+
import pandas as pd
|
|
12
|
+
from rasterio.windows import Window
|
|
13
|
+
from rasterio import features
|
|
14
|
+
from shapely.geometry import box, shape
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
from torchvision.transforms import RandomRotation
|
|
18
|
+
from shapely.affinity import rotate
|
|
19
|
+
import torchgeo
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def raster_to_vector(
    raster_path,
    output_path=None,
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    plot_result=False,
):
    """
    Convert a raster label mask to vector polygons.

    Args:
        raster_path (str): Path to the input raster file (e.g., GeoTIFF).
        output_path (str): Path to save the output vector file. If None, returns
            GeoDataFrame without saving.
        threshold (int/float): Pixel values greater than this threshold will be
            vectorized (only used when class_values is None).
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification.
            None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all
            values > threshold are vectorized as a single class (value 1).
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
        plot_result (bool): Whether to plot the resulting polygons overlaid on
            the raster.

    Returns:
        geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized polygons.

    Raises:
        ValueError: If output_format is not one of the supported formats.
    """
    with rasterio.open(raster_path) as src:
        # Read the first band; label masks are expected to be single-band.
        data = src.read(1)
        transform = src.transform
        crs = src.crs

        # Build one boolean mask per class value to vectorize.
        if class_values is not None:
            masks = {val: (data == val) for val in class_values}
        else:
            # Single binary mask: everything above the threshold is class 1.
            masks = {1: (data > threshold)}
            class_values = [1]  # Default class

        all_features = []

        for class_val in class_values:
            mask = masks[class_val]

            # features.shapes yields (geometry, value) pairs for connected
            # regions; mask= restricts extraction to foreground pixels.
            for geom, _value in features.shapes(
                mask.astype(np.uint8), mask=mask, transform=transform
            ):
                geom = shape(geom)

                # Drop polygons smaller than min_area (in square CRS units).
                if geom.area < min_area:
                    continue

                if simplify_tolerance is not None:
                    geom = geom.simplify(simplify_tolerance)

                all_features.append({"geometry": geom, attribute_name: class_val})

        if all_features:
            gdf = gpd.GeoDataFrame(all_features, crs=crs)
        else:
            print("Warning: No features were extracted from the raster.")
            # Return an empty GeoDataFrame that still carries the raster CRS.
            gdf = gpd.GeoDataFrame([], geometry=[], crs=crs)

        # Save to file if requested.
        if output_path is not None:
            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

            fmt = output_format.lower()
            if fmt == "geojson":
                gdf.to_file(output_path, driver="GeoJSON")
            elif fmt == "shapefile":
                gdf.to_file(output_path)
            elif fmt == "gpkg":
                gdf.to_file(output_path, driver="GPKG")
            else:
                raise ValueError(f"Unsupported output format: {output_format}")

            print(f"Vectorized data saved to {output_path}")

        # Plot result if requested.
        if plot_result:
            fig, ax = plt.subplots(figsize=(12, 12))

            # Fix: draw the raster in map coordinates (extent=) so the vector
            # overlay, which gdf.plot() renders in CRS units, lines up with it.
            extent = [
                src.bounds.left,
                src.bounds.right,
                src.bounds.bottom,
                src.bounds.top,
            ]

            raster_img = src.read()
            if raster_img.shape[0] == 1:
                ax.imshow(raster_img[0], cmap="viridis", alpha=0.7, extent=extent)
            else:
                # Use the first 3 bands for an RGB display.
                rgb = raster_img[:3].transpose(1, 2, 0)
                # Normalize for display; guard against an all-zero raster
                # (the original divided by rgb.max() unconditionally).
                max_val = rgb.max()
                if max_val > 0:
                    rgb = np.clip(rgb / max_val, 0, 1)
                ax.imshow(rgb, extent=extent)

            # Overlay vector boundaries.
            if not gdf.empty:
                gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)

            ax.set_title("Raster with Vectorized Boundaries")
            ax.axis("off")
            plt.tight_layout()
            plt.show()

    return gdf
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def batch_raster_to_vector(
    input_dir,
    output_dir,
    pattern="*.tif",
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    merge_output=False,
    merge_filename="merged_vectors",
):
    """
    Batch convert multiple raster files to vector polygons.

    Args:
        input_dir (str): Directory containing input raster files.
        output_dir (str): Directory to save output vector files.
        pattern (str): Pattern to match raster files (e.g., '*.tif').
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
        merge_output (bool): Whether to merge all output vectors into a single file.
        merge_filename (str): Filename for the merged output (without extension).

    Returns:
        geopandas.GeoDataFrame or None: If merge_output is True and features were
        extracted, the merged GeoDataFrame; otherwise None.

    Raises:
        ValueError: If output_format is not one of the supported formats.
    """
    import glob

    # Map each supported format to (extension, OGR driver). A driver of None
    # lets GeoPandas infer it from the extension (the shapefile default).
    format_info = {
        "geojson": (".geojson", "GeoJSON"),
        "shapefile": (".shp", None),
        "gpkg": (".gpkg", "GPKG"),
    }
    fmt = output_format.lower()
    # Fix: validate the format up front. The original only raised inside the
    # per-file loop, and its merge-save branch had no else-clause at all, so an
    # unsupported format there would have surfaced as an UnboundLocalError on
    # `merged_file` rather than a clear error.
    if fmt not in format_info:
        raise ValueError(f"Unsupported output format: {output_format}")
    extension, driver = format_info[fmt]

    # Create output directory if it doesn't exist.
    os.makedirs(output_dir, exist_ok=True)

    # Get list of raster files.
    raster_files = glob.glob(os.path.join(input_dir, pattern))

    if not raster_files:
        print(f"No files matching pattern '{pattern}' found in {input_dir}")
        return None

    print(f"Found {len(raster_files)} raster files to process")

    # Process each raster file.
    gdfs = []
    for raster_file in tqdm(raster_files, desc="Processing rasters"):
        base_name = os.path.splitext(os.path.basename(raster_file))[0]
        out_file = os.path.join(output_dir, f"{base_name}{extension}")

        if merge_output:
            # Don't save individual files if merging; collect in memory.
            gdf = raster_to_vector(
                raster_file,
                output_path=None,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
            )

            if not gdf.empty:
                # Record provenance so merged features can be traced back.
                gdf["source_file"] = base_name
                gdfs.append(gdf)
        else:
            # Save individual files.
            raster_to_vector(
                raster_file,
                output_path=out_file,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
                output_format=output_format,
            )

    # Merge output if requested.
    if merge_output and gdfs:
        merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))

        # pd.concat can drop the CRS; restore it from the first input.
        if merged_gdf.crs is None:
            merged_gdf.crs = gdfs[0].crs

        merged_file = os.path.join(output_dir, f"{merge_filename}{extension}")
        if driver is not None:
            merged_gdf.to_file(merged_file, driver=driver)
        else:
            merged_gdf.to_file(merged_file)

        print(f"Merged vector data saved to {merged_file}")
        return merged_gdf

    return None
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
# # Example usage
|
|
262
|
+
# if __name__ == "__main__":
|
|
263
|
+
# # Single file conversion example
|
|
264
|
+
# gdf = raster_to_vector(
|
|
265
|
+
# raster_path="output/labels/tile_000001.tif",
|
|
266
|
+
# output_path="output/labels/tile_000001.geojson",
|
|
267
|
+
# threshold=0,
|
|
268
|
+
# min_area=10,
|
|
269
|
+
# simplify_tolerance=0.5,
|
|
270
|
+
# class_values=[1], # For a binary mask, use [1]
|
|
271
|
+
# attribute_name='class',
|
|
272
|
+
# plot_result=True
|
|
273
|
+
# )
|
|
274
|
+
|
|
275
|
+
# Batch conversion example
|
|
276
|
+
# batch_raster_to_vector(
|
|
277
|
+
# input_dir="path/to/labels",
|
|
278
|
+
# output_dir="path/to/vectors",
|
|
279
|
+
# pattern="*.tif",
|
|
280
|
+
# threshold=0,
|
|
281
|
+
# min_area=10,
|
|
282
|
+
# class_values=[1, 2, 3], # For a multiclass mask
|
|
283
|
+
# merge_output=True
|
|
284
|
+
# )
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def vector_to_raster(
    vector_path,
    output_path=None,
    reference_raster=None,
    attribute_field=None,
    output_shape=None,
    transform=None,
    pixel_size=None,
    bounds=None,
    crs=None,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
    plot_result=False,
):
    """
    Convert vector data to a raster.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_path (str): Path to save the output raster file. If None, returns the array without saving.
        reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
        attribute_field (str): Field name in the vector data to use for pixel values.
            If None (or not present in the data), all features are burned with value 1.
        output_shape (tuple): Shape of the output raster as (height, width).
            Required if reference_raster is not provided.
        transform (affine.Affine): Affine transformation matrix.
            Required if reference_raster is not provided.
        pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
            Used to calculate transform if transform is not provided.
        bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
            Used to calculate transform if transform is not provided.
        crs (str or CRS): Coordinate reference system of the output raster.
            Required if reference_raster is not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
            If False, only pixels whose center is within the geometry will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.
        plot_result (bool): Whether to plot the resulting raster.

    Returns:
        numpy.ndarray: The rasterized data array.

    Raises:
        ValueError: If neither reference_raster nor sufficient grid parameters
            (transform/output_shape or pixel_size/bounds) are provided, or if
            no CRS can be determined.
    """
    # Load vector data (accept either a path or an in-memory GeoDataFrame).
    if isinstance(vector_path, gpd.GeoDataFrame):
        gdf = vector_path
    else:
        gdf = gpd.read_file(vector_path)

    if gdf.empty:
        warnings.warn("The input vector data is empty. Creating an empty raster.")

    # Fall back to the vector CRS when nothing else is specified.
    if crs is None and reference_raster is None:
        crs = gdf.crs

    if reference_raster is not None:
        # Take grid geometry and CRS straight from the reference raster.
        with rasterio.open(reference_raster) as src:
            transform = src.transform
            output_shape = src.shape
            crs = src.crs
            if nodata is None:
                nodata = src.nodata
    else:
        # Derive the transform from pixel size and bounds if not given.
        if transform is None:
            if pixel_size is None or bounds is None:
                raise ValueError(
                    "Either reference_raster, transform, or both pixel_size and bounds must be provided."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size
                y_res = abs(y_res) * -1  # negative for a north-up raster

            left, bottom, right, top = bounds
            transform = rasterio.transform.from_bounds(
                left,
                bottom,
                right,
                top,
                int((right - left) / x_res),
                int((top - bottom) / abs(y_res)),
            )

        # Derive the output shape from pixel size and bounds if not given.
        if output_shape is None:
            if bounds is None or pixel_size is None:
                raise ValueError(
                    "output_shape must be provided if reference_raster is not provided and "
                    "cannot be calculated from bounds and pixel_size."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size

            left, bottom, right, top = bounds
            width = int((right - left) / x_res)
            height = int((top - bottom) / abs(y_res))
            output_shape = (height, width)

    # Ensure a CRS was determined one way or another.
    if crs is None:
        raise ValueError(
            "CRS must be provided either directly, from reference_raster, or from input vector data."
        )

    # Reproject vector data if its CRS doesn't match the output CRS.
    if gdf.crs != crs:
        print(f"Reprojecting vector data from {gdf.crs} to {crs}")
        gdf = gdf.to_crs(crs)

    # Start from a raster filled with fill_value; overwritten if features exist.
    raster_data = np.full(output_shape, fill_value, dtype=dtype)

    if not gdf.empty:
        if attribute_field is not None and attribute_field in gdf.columns:
            # Use attribute field values as burn values.
            shapes = [
                (geom, value) for geom, value in zip(gdf.geometry, gdf[attribute_field])
            ]
        else:
            # Robustness fix: a requested attribute_field that doesn't exist
            # used to fall back to burning 1 silently; warn so typos surface.
            if attribute_field is not None:
                warnings.warn(
                    f"Field '{attribute_field}' not found in vector data; burning value 1 instead."
                )
            shapes = [(geom, 1) for geom in gdf.geometry]

        raster_data = features.rasterize(
            shapes=shapes,
            out_shape=output_shape,
            transform=transform,
            fill=fill_value,
            all_touched=all_touched,
            dtype=dtype,
        )

    # Save raster if an output path is provided.
    if output_path is not None:
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        metadata = {
            "driver": "GTiff",
            "height": output_shape[0],
            "width": output_shape[1],
            "count": 1,
            "dtype": raster_data.dtype,
            "crs": crs,
            "transform": transform,
        }
        if nodata is not None:
            metadata["nodata"] = nodata

        with rasterio.open(output_path, "w", **metadata) as dst:
            dst.write(raster_data, 1)

        print(f"Rasterized data saved to {output_path}")

    # Plot result if requested.
    if plot_result:
        # Determine the raster extent first so both the image and the vector
        # overlay can be drawn in the same (map) coordinate system.
        if output_path is not None:
            with rasterio.open(output_path) as src:
                left, bottom, right, top = src.bounds
                raster_bbox = box(left, bottom, right, top)
        else:
            # Calculate extent from transform and shape.
            height, width = output_shape
            left, top = transform * (0, 0)
            right, bottom = transform * (width, height)
            raster_bbox = box(left, bottom, right, top)

        fig, ax = plt.subplots(figsize=(10, 10))

        # Fix: pass extent= so the raster is drawn in CRS units; otherwise the
        # vector boundaries (plotted in map coordinates) would not line up.
        im = ax.imshow(raster_data, cmap="viridis", extent=[left, right, bottom, top])
        plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")

        # Clip vector to the raster extent for clarity in the plot.
        if not gdf.empty:
            gdf_clipped = gpd.clip(gdf, raster_bbox)
            if not gdf_clipped.empty:
                gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)

        ax.set_title("Rasterized Vector Data")
        plt.tight_layout()
        plt.show()

    return raster_data
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def batch_vector_to_raster(
    vector_path,
    output_dir,
    attribute_field=None,
    reference_rasters=None,
    bounds_list=None,
    output_filename_pattern="{vector_name}_{index}",
    pixel_size=1.0,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
):
    """
    Batch convert vector data to multiple rasters based on different extents or reference rasters.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_dir (str): Directory to save output raster files.
        attribute_field (str): Field name in the vector data to use for pixel values.
        reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
        bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
        output_filename_pattern (str): Pattern for output filenames.
            Can include {vector_name} and {index} placeholders.
        pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.

    Returns:
        list: List of paths to the created raster files.

    Raises:
        ValueError: If neither reference_rasters nor bounds_list is provided.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Accept either a file path or an in-memory GeoDataFrame.
    if isinstance(vector_path, str):
        gdf = gpd.read_file(vector_path)
        vector_name = os.path.splitext(os.path.basename(vector_path))[0]
    else:
        gdf = vector_path
        vector_name = "vector"

    if reference_rasters is None and bounds_list is None:
        raise ValueError("Either reference_rasters or bounds_list must be provided.")

    # Reference rasters take precedence over explicit bounds.
    use_rasters = reference_rasters is not None
    targets = reference_rasters if use_rasters else bounds_list

    output_files = []

    for idx, target in enumerate(tqdm(targets, desc="Processing")):
        # Build the output path from the filename pattern.
        filename = output_filename_pattern.format(vector_name=vector_name, index=idx)
        if not filename.endswith(".tif"):
            filename = f"{filename}.tif"
        destination = os.path.join(output_dir, filename)

        # Arguments shared by both dispatch branches.
        shared = dict(
            vector_path=gdf,
            output_path=destination,
            attribute_field=attribute_field,
            all_touched=all_touched,
            fill_value=fill_value,
            dtype=dtype,
            nodata=nodata,
        )

        if use_rasters:
            # Grid geometry and CRS come from the reference raster.
            vector_to_raster(reference_raster=target, **shared)
        else:
            # Grid geometry derived from the bounds tuple and pixel size.
            vector_to_raster(bounds=target, pixel_size=pixel_size, **shared)

        output_files.append(destination)

    return output_files
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
# # Example usage
|
|
597
|
+
# if __name__ == "__main__":
|
|
598
|
+
# # Single file conversion example
|
|
599
|
+
# raster_data = vector_to_raster(
|
|
600
|
+
# vector_path="buildings_train.geojson",
|
|
601
|
+
# output_path="buildings_train.tif",
|
|
602
|
+
# reference_raster="naip_train.tif", # Optional, can use other parameters instead
|
|
603
|
+
# # attribute_field="class", # Optional, uses field values for pixel values
|
|
604
|
+
# all_touched=True, # Ensures small features are captured
|
|
605
|
+
# plot_result=True
|
|
606
|
+
# )
|
|
607
|
+
|
|
608
|
+
# Example with custom dimensions
|
|
609
|
+
# raster_data = vector_to_raster(
|
|
610
|
+
# vector_path="path/to/buildings.geojson",
|
|
611
|
+
# output_path="path/to/rasterized_buildings.tif",
|
|
612
|
+
# pixel_size=0.5, # 0.5 meter resolution
|
|
613
|
+
# bounds=(454780, 5277567, 456282, 5278242), # from original data
|
|
614
|
+
# crs="EPSG:26911",
|
|
615
|
+
# output_shape=(1350, 3000), # custom dimensions
|
|
616
|
+
# attribute_field="class"
|
|
617
|
+
# )
|
|
618
|
+
|
|
619
|
+
# Batch conversion example
|
|
620
|
+
# output_files = batch_vector_to_raster(
|
|
621
|
+
# vector_path="path/to/buildings.geojson",
|
|
622
|
+
# output_dir="path/to/output",
|
|
623
|
+
# reference_rasters=["path/to/ref1.tif", "path/to/ref2.tif"],
|
|
624
|
+
# attribute_field="class",
|
|
625
|
+
# all_touched=True
|
|
626
|
+
# )
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def export_geotiff_tiles(
    in_raster,
    out_folder,
    in_class_data,
    tile_size=256,
    stride=128,
    class_value_field="class",
    buffer_radius=0,
    max_tiles=None,
    quiet=False,
    all_touched=True,
    create_overview=False,
    skip_empty_tiles=False,
):
    """
    Export georeferenced GeoTIFF tiles and labels from raster and classification data.

    Writes three outputs under ``out_folder``: ``images/`` (GeoTIFF chips of the
    input raster), ``labels/`` (uint8 GeoTIFF class masks), and ``annotations/``
    (PASCAL-VOC-style XML bounding boxes, vector class data only).

    Args:
        in_raster (str): Path to input raster image
        out_folder (str): Path to output folder
        in_class_data (str): Path to classification data - can be vector file or raster
        tile_size (int): Size of tiles in pixels (square)
        stride (int): Step size between tiles
        class_value_field (str): Field containing class values (for vector data)
        buffer_radius (float): Buffer to add around features (in units of the CRS)
        max_tiles (int): Maximum number of tiles to process (None for all)
        quiet (bool): If True, suppress non-essential output
        all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
        create_overview (bool): Whether to create an overview image of all tiles
        skip_empty_tiles (bool): If True, skip tiles with no features

    Returns:
        dict: Export statistics (``total_tiles``, ``tiles_with_features``,
        ``feature_pixels``, ``errors``, ``tile_coordinates``).
    """
    # Create output directories
    os.makedirs(out_folder, exist_ok=True)
    image_dir = os.path.join(out_folder, "images")
    os.makedirs(image_dir, exist_ok=True)
    label_dir = os.path.join(out_folder, "labels")
    os.makedirs(label_dir, exist_ok=True)
    ann_dir = os.path.join(out_folder, "annotations")
    os.makedirs(ann_dir, exist_ok=True)

    # Determine if class data is raster or vector by file extension, then by
    # whether rasterio can actually open it.
    is_class_data_raster = False
    if isinstance(in_class_data, str):
        file_ext = Path(in_class_data).suffix.lower()
        # Common raster extensions
        if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
            try:
                with rasterio.open(in_class_data) as src:
                    is_class_data_raster = True
                    if not quiet:
                        print(f"Detected in_class_data as raster: {in_class_data}")
                        print(f"Raster CRS: {src.crs}")
                        print(f"Raster dimensions: {src.width} x {src.height}")
            except Exception:
                is_class_data_raster = False
                if not quiet:
                    print(f"Unable to open {in_class_data} as raster, trying as vector")

    # Open the input raster
    with rasterio.open(in_raster) as src:
        if not quiet:
            print(f"\nRaster info for {in_raster}:")
            print(f"  CRS: {src.crs}")
            print(f"  Dimensions: {src.width} x {src.height}")
            print(f"  Bounds: {src.bounds}")

        # Calculate number of tiles (ceil so edge tiles are included; edge
        # windows are shifted back inside the raster below)
        num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
        num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
        total_tiles = num_tiles_x * num_tiles_y

        if max_tiles is None:
            max_tiles = total_tiles

        # Process classification data: build a mapping of original class value
        # -> contiguous label id (1-based; 0 is reserved for background)
        class_to_id = {}

        if is_class_data_raster:
            # Load raster class data
            with rasterio.open(in_class_data) as class_src:
                # Check if raster CRS matches
                if class_src.crs != src.crs:
                    warnings.warn(
                        f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
                        f"Results may be misaligned."
                    )

                # Get unique values from raster
                # Sample to avoid loading huge rasters
                sample_data = class_src.read(
                    1,
                    out_shape=(
                        1,
                        min(class_src.height, 1000),
                        min(class_src.width, 1000),
                    ),
                )

                unique_classes = np.unique(sample_data)
                unique_classes = unique_classes[
                    unique_classes > 0
                ]  # Remove 0 as it's typically background

                if not quiet:
                    print(
                        f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
                    )

                # Create class mapping
                class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
        else:
            # Load vector class data
            try:
                gdf = gpd.read_file(in_class_data)
                if not quiet:
                    print(f"Loaded {len(gdf)} features from {in_class_data}")
                    print(f"Vector CRS: {gdf.crs}")

                # Always reproject to match raster CRS
                if gdf.crs != src.crs:
                    if not quiet:
                        print(f"Reprojecting features from {gdf.crs} to {src.crs}")
                    gdf = gdf.to_crs(src.crs)

                # Apply buffer if specified
                if buffer_radius > 0:
                    gdf["geometry"] = gdf.buffer(buffer_radius)
                    if not quiet:
                        print(f"Applied buffer of {buffer_radius} units")

                # Check if class_value_field exists
                if class_value_field in gdf.columns:
                    unique_classes = gdf[class_value_field].unique()
                    if not quiet:
                        print(
                            f"Found {len(unique_classes)} unique classes: {unique_classes}"
                        )
                    # Create class mapping
                    class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
                else:
                    if not quiet:
                        print(
                            f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
                        )
                    class_to_id = {1: 1}  # Default mapping
            except Exception as e:
                raise ValueError(f"Error processing vector data: {e}")

        # Create progress bar
        pbar = tqdm(
            total=min(total_tiles, max_tiles),
            desc="Generating tiles",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )

        # Track statistics for summary
        stats = {
            "total_tiles": 0,
            "tiles_with_features": 0,
            "feature_pixels": 0,
            "errors": 0,
            "tile_coordinates": [],  # For overview image
        }

        # Process tiles
        tile_index = 0
        for y in range(num_tiles_y):
            for x in range(num_tiles_x):
                if tile_index >= max_tiles:
                    break

                # Calculate window coordinates
                window_x = x * stride
                window_y = y * stride

                # Adjust for edge cases: shift the window back so it stays
                # fully inside the raster
                if window_x + tile_size > src.width:
                    window_x = src.width - tile_size
                if window_y + tile_size > src.height:
                    window_y = src.height - tile_size

                # Define window
                window = Window(window_x, window_y, tile_size, tile_size)

                # Get window transform and bounds
                window_transform = src.window_transform(window)

                # Calculate window bounds in world coordinates
                minx = window_transform[2]  # Upper left x
                maxy = window_transform[5]  # Upper left y
                maxx = minx + tile_size * window_transform[0]  # Add width
                miny = maxy + tile_size * window_transform[4]  # Add height (negative y pixel size)

                window_bounds = box(minx, miny, maxx, maxy)

                # Store tile coordinates for overview
                if create_overview:
                    stats["tile_coordinates"].append(
                        {
                            "index": tile_index,
                            "x": window_x,
                            "y": window_y,
                            "bounds": [minx, miny, maxx, maxy],
                            "has_features": False,
                        }
                    )

                # Create label mask
                label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
                has_features = False

                # Process classification data to create labels
                if is_class_data_raster:
                    # For raster class data
                    with rasterio.open(in_class_data) as class_src:
                        # Calculate window in class raster
                        src_bounds = src.bounds
                        class_bounds = class_src.bounds

                        # Check if windows overlap
                        if (
                            src_bounds.left > class_bounds.right
                            or src_bounds.right < class_bounds.left
                            or src_bounds.bottom > class_bounds.top
                            or src_bounds.top < class_bounds.bottom
                        ):
                            warnings.warn(
                                "Class raster and input raster do not overlap."
                            )
                        else:
                            # Get corresponding window in class raster
                            window_class = rasterio.windows.from_bounds(
                                minx, miny, maxx, maxy, class_src.transform
                            )

                            # Read label data
                            try:
                                label_data = class_src.read(
                                    1,
                                    window=window_class,
                                    boundless=True,
                                    out_shape=(tile_size, tile_size),
                                )

                                # Remap class values if needed
                                if class_to_id:
                                    remapped_data = np.zeros_like(label_data)
                                    for orig_val, new_val in class_to_id.items():
                                        remapped_data[label_data == orig_val] = new_val
                                    label_mask = remapped_data
                                else:
                                    label_mask = label_data

                                # Check if we have any features.
                                # NOTE: feature_pixels is counted once, when the
                                # label GeoTIFF is saved below (previously it was
                                # double-counted on this path).
                                if np.any(label_mask > 0):
                                    has_features = True
                            except Exception as e:
                                pbar.write(f"Error reading class raster window: {e}")
                                stats["errors"] += 1
                else:
                    # For vector class data
                    # Find features that intersect with window
                    window_features = gdf[gdf.intersects(window_bounds)]

                    if len(window_features) > 0:
                        for idx, feature in window_features.iterrows():
                            # Get class value
                            if class_value_field in feature:
                                class_val = feature[class_value_field]
                                class_id = class_to_id.get(class_val, 1)
                            else:
                                class_id = 1

                            # Get geometry in window coordinates
                            geom = feature.geometry.intersection(window_bounds)
                            if not geom.is_empty:
                                try:
                                    # Rasterize feature
                                    feature_mask = features.rasterize(
                                        [(geom, class_id)],
                                        out_shape=(tile_size, tile_size),
                                        transform=window_transform,
                                        fill=0,
                                        all_touched=all_touched,
                                    )

                                    # Add to label mask
                                    label_mask = np.maximum(label_mask, feature_mask)

                                    # Check if the feature was actually rasterized
                                    if np.any(feature_mask):
                                        has_features = True
                                except Exception as e:
                                    pbar.write(f"Error rasterizing feature {idx}: {e}")
                                    stats["errors"] += 1

                # Record feature status on the overview entry for BOTH raster
                # and vector class paths (previously only the vector path set
                # it, so raster-labelled tiles always rendered as empty).
                if create_overview and stats["tile_coordinates"]:
                    stats["tile_coordinates"][-1]["has_features"] = has_features

                # Skip tile if no features and skip_empty_tiles is True
                if skip_empty_tiles and not has_features:
                    pbar.update(1)
                    tile_index += 1
                    continue

                # Read image data
                image_data = src.read(window=window)

                # Export image as GeoTIFF
                image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")

                # Create profile for image GeoTIFF
                image_profile = src.profile.copy()
                image_profile.update(
                    {
                        "height": tile_size,
                        "width": tile_size,
                        "count": image_data.shape[0],
                        "transform": window_transform,
                    }
                )

                # Save image as GeoTIFF
                try:
                    with rasterio.open(image_path, "w", **image_profile) as dst:
                        dst.write(image_data)
                    stats["total_tiles"] += 1
                except Exception as e:
                    pbar.write(f"ERROR saving image GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create profile for label GeoTIFF
                label_profile = {
                    "driver": "GTiff",
                    "height": tile_size,
                    "width": tile_size,
                    "count": 1,
                    "dtype": "uint8",
                    "crs": src.crs,
                    "transform": window_transform,
                }

                # Export label as GeoTIFF
                label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
                try:
                    with rasterio.open(label_path, "w", **label_profile) as dst:
                        dst.write(label_mask.astype(np.uint8), 1)

                    if has_features:
                        stats["tiles_with_features"] += 1
                        stats["feature_pixels"] += np.count_nonzero(label_mask)
                except Exception as e:
                    pbar.write(f"ERROR saving label GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create XML annotation for object detection if using vector class data
                if (
                    not is_class_data_raster
                    and "gdf" in locals()
                    and len(window_features) > 0
                ):
                    # Create XML annotation
                    root = ET.Element("annotation")
                    ET.SubElement(root, "folder").text = "images"
                    ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"

                    size = ET.SubElement(root, "size")
                    ET.SubElement(size, "width").text = str(tile_size)
                    ET.SubElement(size, "height").text = str(tile_size)
                    ET.SubElement(size, "depth").text = str(image_data.shape[0])

                    # Add georeference information
                    geo = ET.SubElement(root, "georeference")
                    ET.SubElement(geo, "crs").text = str(src.crs)
                    ET.SubElement(geo, "transform").text = str(
                        window_transform
                    ).replace("\n", "")
                    ET.SubElement(geo, "bounds").text = (
                        f"{minx}, {miny}, {maxx}, {maxy}"
                    )

                    # Add objects
                    for idx, feature in window_features.iterrows():
                        # Get feature class
                        if class_value_field in feature:
                            class_val = feature[class_value_field]
                        else:
                            class_val = "object"

                        # Get geometry bounds in pixel coordinates
                        geom = feature.geometry.intersection(window_bounds)
                        if not geom.is_empty:
                            # Get bounds in world coordinates
                            minx_f, miny_f, maxx_f, maxy_f = geom.bounds

                            # Convert to pixel coordinates (note y axis flips:
                            # maxy_f maps to the top row)
                            col_min, row_min = ~window_transform * (minx_f, maxy_f)
                            col_max, row_max = ~window_transform * (maxx_f, miny_f)

                            # Ensure coordinates are within tile bounds
                            xmin = max(0, min(tile_size, int(col_min)))
                            ymin = max(0, min(tile_size, int(row_min)))
                            xmax = max(0, min(tile_size, int(col_max)))
                            ymax = max(0, min(tile_size, int(row_max)))

                            # Only add if the box has non-zero area
                            if xmax > xmin and ymax > ymin:
                                obj = ET.SubElement(root, "object")
                                ET.SubElement(obj, "name").text = str(class_val)
                                ET.SubElement(obj, "difficult").text = "0"

                                bbox = ET.SubElement(obj, "bndbox")
                                ET.SubElement(bbox, "xmin").text = str(xmin)
                                ET.SubElement(bbox, "ymin").text = str(ymin)
                                ET.SubElement(bbox, "xmax").text = str(xmax)
                                ET.SubElement(bbox, "ymax").text = str(ymax)

                    # Save XML
                    tree = ET.ElementTree(root)
                    xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
                    tree.write(xml_path)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(
                    f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
                )

                tile_index += 1
                if tile_index >= max_tiles:
                    break

            if tile_index >= max_tiles:
                break

        # Close progress bar
        pbar.close()

        # Create overview image if requested
        if create_overview and stats["tile_coordinates"]:
            try:
                create_overview_image(
                    src,
                    stats["tile_coordinates"],
                    os.path.join(out_folder, "overview.png"),
                    tile_size,
                    stride,
                )
            except Exception as e:
                print(f"Failed to create overview image: {e}")

        # Report results
        if not quiet:
            print("\n------- Export Summary -------")
            print(f"Total tiles exported: {stats['total_tiles']}")
            print(
                f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
            )
            if stats["tiles_with_features"] > 0:
                print(
                    f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
                )
            if stats["errors"] > 0:
                print(f"Errors encountered: {stats['errors']}")
            print(f"Output saved to: {out_folder}")

            # Verify georeference in a sample image and label.
            # FIX: tiles are written as tile_{index:06d}.tif, so the first
            # tile is tile_000000.tif (the old "tile_0.tif" never existed).
            if stats["total_tiles"] > 0:
                print("\n------- Georeference Verification -------")
                sample_image = os.path.join(image_dir, "tile_000000.tif")
                sample_label = os.path.join(label_dir, "tile_000000.tif")

                if os.path.exists(sample_image):
                    try:
                        with rasterio.open(sample_image) as img:
                            print(f"Image CRS: {img.crs}")
                            print(f"Image transform: {img.transform}")
                            print(
                                f"Image has georeference: {img.crs is not None and img.transform is not None}"
                            )
                            print(
                                f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying image georeference: {e}")

                if os.path.exists(sample_label):
                    try:
                        with rasterio.open(sample_label) as lbl:
                            print(f"Label CRS: {lbl.crs}")
                            print(f"Label transform: {lbl.transform}")
                            print(
                                f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
                            )
                            print(
                                f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying label georeference: {e}")

        # Return statistics dictionary for further processing if needed
        return stats
|
1139
|
+
def create_overview_image(src, tile_coordinates, output_path, tile_size, stride):
    """Create an overview image showing all tiles and their status.

    Draws a downsampled view of ``src`` with one rectangle per entry in
    ``tile_coordinates`` (green = contains features, red = empty) and saves
    it to ``output_path``.

    Args:
        src: Open rasterio dataset to render as the background.
        tile_coordinates (list[dict]): Entries with "index", "x", "y",
            "bounds", and "has_features" keys.
        output_path (str): Destination PNG path.
        tile_size (int): Tile size in pixels (at full resolution).
        stride (int): Step between tiles (kept for interface compatibility;
            not used in the drawing itself).
    """
    # Read a reduced version of the source image (scaled to max ~2000 px)
    overview_scale = max(1, int(max(src.width, src.height) / 2000))
    overview_width = src.width // overview_scale
    overview_height = src.height // overview_scale

    # Read downsampled image
    overview_data = src.read(
        out_shape=(src.count, overview_height, overview_width),
        resampling=rasterio.enums.Resampling.average,
    )

    # Create RGB image for display
    if overview_data.shape[0] >= 3:
        rgb = np.moveaxis(overview_data[:3], 0, -1)
    else:
        # For single band, create grayscale RGB
        rgb = np.stack([overview_data[0], overview_data[0], overview_data[0]], axis=-1)

    # FIX: work in float. The source raster is typically an integer dtype;
    # assigning the 0-1 normalized values back into an integer array would
    # truncate them all to 0/1 and render a black overview.
    rgb = rgb.astype(np.float32)

    # Normalize each band to 0-1 for display using a 2-98 percentile stretch
    for i in range(rgb.shape[-1]):
        band = rgb[..., i]
        non_zero = band[band > 0]
        if len(non_zero) > 0:
            p2, p98 = np.percentile(non_zero, (2, 98))
            if p98 > p2:
                rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
            else:
                # FIX: constant band — avoid division by zero; just clamp
                rgb[..., i] = np.clip(band, 0, 1)

    # Create figure
    plt.figure(figsize=(12, 12))
    plt.imshow(rgb)

    # Draw tile boundaries
    for tile in tile_coordinates:
        # Calculate scaled pixel coordinates in the downsampled overview
        x_min = int((tile["x"]) / overview_scale)
        y_min = int((tile["y"]) / overview_scale)
        width = int(tile_size / overview_scale)
        height = int(tile_size / overview_scale)

        # Draw rectangle
        color = "lime" if tile["has_features"] else "red"
        rect = plt.Rectangle(
            (x_min, y_min), width, height, fill=False, edgecolor=color, linewidth=0.5
        )
        plt.gca().add_patch(rect)

        # Add tile number if not too crowded
        if width > 20 and height > 20:
            plt.text(
                x_min + width / 2,
                y_min + height / 2,
                str(tile["index"]),
                color="white",
                ha="center",
                va="center",
                fontsize=8,
            )

    plt.title("Tile Overview (Green = Contains Features, Red = Empty)")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"Overview image saved to {output_path}")
|
1210
|
+
|
|
1211
|
+
# # Example usage
|
|
1212
|
+
# if __name__ == "__main__":
|
|
1213
|
+
# # Try to install tqdm if not available
|
|
1214
|
+
# try:
|
|
1215
|
+
# import tqdm
|
|
1216
|
+
# except ImportError:
|
|
1217
|
+
# print("Installing tqdm progress bar library...")
|
|
1218
|
+
# import sys
|
|
1219
|
+
# import subprocess
|
|
1220
|
+
|
|
1221
|
+
# subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
|
|
1222
|
+
# import tqdm
|
|
1223
|
+
|
|
1224
|
+
# # Example with vector class data
|
|
1225
|
+
# export_geotiff_tiles(
|
|
1226
|
+
# in_raster="naip_train.tif",
|
|
1227
|
+
# out_folder="geotiff_output_vector",
|
|
1228
|
+
# in_class_data="buildings_train.geojson",
|
|
1229
|
+
# tile_size=256,
|
|
1230
|
+
# stride=128,
|
|
1231
|
+
# class_value_field="class",
|
|
1232
|
+
# buffer_radius=2,
|
|
1233
|
+
# create_overview=True,
|
|
1234
|
+
# )
|
|
1235
|
+
|
|
1236
|
+
# # Example with raster class data
|
|
1237
|
+
# export_geotiff_tiles(
|
|
1238
|
+
# in_raster="naip_train.tif",
|
|
1239
|
+
# out_folder="geotiff_output_raster",
|
|
1240
|
+
# in_class_data="buildings_train.tif", # This would be a raster mask
|
|
1241
|
+
# tile_size=256,
|
|
1242
|
+
# stride=128,
|
|
1243
|
+
# create_overview=True,
|
|
1244
|
+
# skip_empty_tiles=True,
|
|
1245
|
+
# )
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
def export_training_data(
|
|
1249
|
+
in_raster,
|
|
1250
|
+
out_folder,
|
|
1251
|
+
in_class_data,
|
|
1252
|
+
image_chip_format="GEOTIFF",
|
|
1253
|
+
tile_size_x=256,
|
|
1254
|
+
tile_size_y=256,
|
|
1255
|
+
stride_x=None,
|
|
1256
|
+
stride_y=None,
|
|
1257
|
+
output_nofeature_tiles=True,
|
|
1258
|
+
metadata_format="PASCAL_VOC",
|
|
1259
|
+
start_index=0,
|
|
1260
|
+
class_value_field="class",
|
|
1261
|
+
buffer_radius=0,
|
|
1262
|
+
in_mask_polygons=None,
|
|
1263
|
+
rotation_angle=0,
|
|
1264
|
+
reference_system=None,
|
|
1265
|
+
blacken_around_feature=False,
|
|
1266
|
+
crop_mode="FIXED_SIZE", # Implemented but not fully used yet
|
|
1267
|
+
in_raster2=None,
|
|
1268
|
+
in_instance_data=None,
|
|
1269
|
+
instance_class_value_field=None, # Implemented but not fully used yet
|
|
1270
|
+
min_polygon_overlap_ratio=0.0,
|
|
1271
|
+
all_touched=True,
|
|
1272
|
+
save_geotiff=True,
|
|
1273
|
+
quiet=False,
|
|
1274
|
+
):
|
|
1275
|
+
"""
|
|
1276
|
+
Export training data for deep learning using TorchGeo with progress bar.
|
|
1277
|
+
|
|
1278
|
+
Args:
|
|
1279
|
+
in_raster (str): Path to input raster image.
|
|
1280
|
+
out_folder (str): Output folder path where chips and labels will be saved.
|
|
1281
|
+
in_class_data (str): Path to vector file containing class polygons.
|
|
1282
|
+
image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
|
|
1283
|
+
tile_size_x (int): Width of image chips in pixels.
|
|
1284
|
+
tile_size_y (int): Height of image chips in pixels.
|
|
1285
|
+
stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
|
|
1286
|
+
stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
|
|
1287
|
+
output_nofeature_tiles (bool): Whether to export chips without features.
|
|
1288
|
+
metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
|
|
1289
|
+
start_index (int): Starting index for chip filenames.
|
|
1290
|
+
class_value_field (str): Field name in in_class_data containing class values.
|
|
1291
|
+
buffer_radius (float): Buffer radius around features (in CRS units).
|
|
1292
|
+
in_mask_polygons (str): Path to vector file containing mask polygons.
|
|
1293
|
+
rotation_angle (float): Rotation angle in degrees.
|
|
1294
|
+
reference_system (str): Reference system code.
|
|
1295
|
+
blacken_around_feature (bool): Whether to mask areas outside of features.
|
|
1296
|
+
crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
|
|
1297
|
+
in_raster2 (str): Path to secondary raster image.
|
|
1298
|
+
in_instance_data (str): Path to vector file containing instance polygons.
|
|
1299
|
+
instance_class_value_field (str): Field name in in_instance_data for instance classes.
|
|
1300
|
+
min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
|
|
1301
|
+
all_touched (bool): Whether to use all_touched=True in rasterization.
|
|
1302
|
+
save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
|
|
1303
|
+
quiet (bool): If True, suppress most output messages.
|
|
1304
|
+
"""
|
|
1305
|
+
# Create output directories
|
|
1306
|
+
image_dir = os.path.join(out_folder, "images")
|
|
1307
|
+
os.makedirs(image_dir, exist_ok=True)
|
|
1308
|
+
|
|
1309
|
+
label_dir = os.path.join(out_folder, "labels")
|
|
1310
|
+
os.makedirs(label_dir, exist_ok=True)
|
|
1311
|
+
|
|
1312
|
+
# Define annotation directories based on metadata format
|
|
1313
|
+
if metadata_format == "PASCAL_VOC":
|
|
1314
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
|
1315
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
1316
|
+
elif metadata_format == "COCO":
|
|
1317
|
+
ann_dir = os.path.join(out_folder, "annotations")
|
|
1318
|
+
os.makedirs(ann_dir, exist_ok=True)
|
|
1319
|
+
# Initialize COCO annotations dictionary
|
|
1320
|
+
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
1321
|
+
|
|
1322
|
+
# Initialize statistics dictionary
|
|
1323
|
+
stats = {
|
|
1324
|
+
"total_tiles": 0,
|
|
1325
|
+
"tiles_with_features": 0,
|
|
1326
|
+
"feature_pixels": 0,
|
|
1327
|
+
"errors": 0,
|
|
1328
|
+
}
|
|
1329
|
+
|
|
1330
|
+
# Open raster
|
|
1331
|
+
with rasterio.open(in_raster) as src:
|
|
1332
|
+
if not quiet:
|
|
1333
|
+
print(f"\nRaster info for {in_raster}:")
|
|
1334
|
+
print(f" CRS: {src.crs}")
|
|
1335
|
+
print(f" Dimensions: {src.width} x {src.height}")
|
|
1336
|
+
print(f" Bounds: {src.bounds}")
|
|
1337
|
+
|
|
1338
|
+
# Set defaults for stride if not provided
|
|
1339
|
+
if stride_x is None:
|
|
1340
|
+
stride_x = tile_size_x
|
|
1341
|
+
if stride_y is None:
|
|
1342
|
+
stride_y = tile_size_y
|
|
1343
|
+
|
|
1344
|
+
# Calculate number of tiles in x and y directions
|
|
1345
|
+
num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
|
|
1346
|
+
num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
|
|
1347
|
+
total_tiles = num_tiles_x * num_tiles_y
|
|
1348
|
+
|
|
1349
|
+
# Read class data
|
|
1350
|
+
gdf = gpd.read_file(in_class_data)
|
|
1351
|
+
if not quiet:
|
|
1352
|
+
print(f"Loaded {len(gdf)} features from {in_class_data}")
|
|
1353
|
+
print(f"Available columns: {gdf.columns.tolist()}")
|
|
1354
|
+
print(f"GeoJSON CRS: {gdf.crs}")
|
|
1355
|
+
|
|
1356
|
+
# Check if class_value_field exists
|
|
1357
|
+
if class_value_field not in gdf.columns:
|
|
1358
|
+
if not quiet:
|
|
1359
|
+
print(
|
|
1360
|
+
f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
|
|
1361
|
+
)
|
|
1362
|
+
# Add a default class column
|
|
1363
|
+
gdf[class_value_field] = 1
|
|
1364
|
+
unique_classes = [1]
|
|
1365
|
+
else:
|
|
1366
|
+
# Print unique classes for debugging
|
|
1367
|
+
unique_classes = gdf[class_value_field].unique()
|
|
1368
|
+
if not quiet:
|
|
1369
|
+
print(f"Found {len(unique_classes)} unique classes: {unique_classes}")
|
|
1370
|
+
|
|
1371
|
+
# CRITICAL: Always reproject to match raster CRS to ensure proper alignment
|
|
1372
|
+
if gdf.crs != src.crs:
|
|
1373
|
+
if not quiet:
|
|
1374
|
+
print(f"Reprojecting features from {gdf.crs} to {src.crs}")
|
|
1375
|
+
gdf = gdf.to_crs(src.crs)
|
|
1376
|
+
elif reference_system and gdf.crs != reference_system:
|
|
1377
|
+
if not quiet:
|
|
1378
|
+
print(
|
|
1379
|
+
f"Reprojecting features to specified reference system {reference_system}"
|
|
1380
|
+
)
|
|
1381
|
+
gdf = gdf.to_crs(reference_system)
|
|
1382
|
+
|
|
1383
|
+
# Check overlap between raster and vector data
|
|
1384
|
+
raster_bounds = box(*src.bounds)
|
|
1385
|
+
vector_bounds = box(*gdf.total_bounds)
|
|
1386
|
+
if not raster_bounds.intersects(vector_bounds):
|
|
1387
|
+
if not quiet:
|
|
1388
|
+
print(
|
|
1389
|
+
"WARNING: The vector data doesn't intersect with the raster extent!"
|
|
1390
|
+
)
|
|
1391
|
+
print(f"Raster bounds: {src.bounds}")
|
|
1392
|
+
print(f"Vector bounds: {gdf.total_bounds}")
|
|
1393
|
+
else:
|
|
1394
|
+
overlap = (
|
|
1395
|
+
raster_bounds.intersection(vector_bounds).area / vector_bounds.area
|
|
1396
|
+
)
|
|
1397
|
+
if not quiet:
|
|
1398
|
+
print(f"Overlap between raster and vector: {overlap:.2%}")
|
|
1399
|
+
|
|
1400
|
+
# Apply buffer if specified
|
|
1401
|
+
if buffer_radius > 0:
|
|
1402
|
+
gdf["geometry"] = gdf.buffer(buffer_radius)
|
|
1403
|
+
|
|
1404
|
+
# Initialize class mapping (ensure all classes are mapped to non-zero values)
|
|
1405
|
+
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
1406
|
+
|
|
1407
|
+
# Store category info for COCO format
|
|
1408
|
+
if metadata_format == "COCO":
|
|
1409
|
+
for cls_val in unique_classes:
|
|
1410
|
+
coco_annotations["categories"].append(
|
|
1411
|
+
{
|
|
1412
|
+
"id": class_to_id[cls_val],
|
|
1413
|
+
"name": str(cls_val),
|
|
1414
|
+
"supercategory": "object",
|
|
1415
|
+
}
|
|
1416
|
+
)
|
|
1417
|
+
|
|
1418
|
+
# Load mask polygons if provided
|
|
1419
|
+
mask_gdf = None
|
|
1420
|
+
if in_mask_polygons:
|
|
1421
|
+
mask_gdf = gpd.read_file(in_mask_polygons)
|
|
1422
|
+
if reference_system:
|
|
1423
|
+
mask_gdf = mask_gdf.to_crs(reference_system)
|
|
1424
|
+
elif mask_gdf.crs != src.crs:
|
|
1425
|
+
mask_gdf = mask_gdf.to_crs(src.crs)
|
|
1426
|
+
|
|
1427
|
+
# Process instance data if provided
|
|
1428
|
+
instance_gdf = None
|
|
1429
|
+
if in_instance_data:
|
|
1430
|
+
instance_gdf = gpd.read_file(in_instance_data)
|
|
1431
|
+
if reference_system:
|
|
1432
|
+
instance_gdf = instance_gdf.to_crs(reference_system)
|
|
1433
|
+
elif instance_gdf.crs != src.crs:
|
|
1434
|
+
instance_gdf = instance_gdf.to_crs(src.crs)
|
|
1435
|
+
|
|
1436
|
+
# Load secondary raster if provided
|
|
1437
|
+
src2 = None
|
|
1438
|
+
if in_raster2:
|
|
1439
|
+
src2 = rasterio.open(in_raster2)
|
|
1440
|
+
|
|
1441
|
+
# Set up augmentation if rotation is specified
|
|
1442
|
+
augmentation = None
|
|
1443
|
+
if rotation_angle != 0:
|
|
1444
|
+
# Fixed: Added data_keys parameter to AugmentationSequential
|
|
1445
|
+
augmentation = torchgeo.transforms.AugmentationSequential(
|
|
1446
|
+
torch.nn.ModuleList([RandomRotation(rotation_angle)]),
|
|
1447
|
+
data_keys=["image"], # Add data_keys parameter
|
|
1448
|
+
)
|
|
1449
|
+
|
|
1450
|
+
# Initialize annotation ID for COCO format
|
|
1451
|
+
ann_id = 0
|
|
1452
|
+
|
|
1453
|
+
# Create progress bar
|
|
1454
|
+
pbar = tqdm(
|
|
1455
|
+
total=total_tiles,
|
|
1456
|
+
desc=f"Generating tiles (with features: 0)",
|
|
1457
|
+
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
|
|
1458
|
+
)
|
|
1459
|
+
|
|
1460
|
+
# Generate tiles
|
|
1461
|
+
chip_index = start_index
|
|
1462
|
+
for y in range(num_tiles_y):
|
|
1463
|
+
for x in range(num_tiles_x):
|
|
1464
|
+
# Calculate window coordinates
|
|
1465
|
+
window_x = x * stride_x
|
|
1466
|
+
window_y = y * stride_y
|
|
1467
|
+
|
|
1468
|
+
# Adjust for edge cases
|
|
1469
|
+
if window_x + tile_size_x > src.width:
|
|
1470
|
+
window_x = src.width - tile_size_x
|
|
1471
|
+
if window_y + tile_size_y > src.height:
|
|
1472
|
+
window_y = src.height - tile_size_y
|
|
1473
|
+
|
|
1474
|
+
# Adjust window based on crop_mode
|
|
1475
|
+
if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
|
|
1476
|
+
# Find the nearest feature to the center of this window
|
|
1477
|
+
window_center_x = window_x + tile_size_x // 2
|
|
1478
|
+
window_center_y = window_y + tile_size_y // 2
|
|
1479
|
+
|
|
1480
|
+
# Convert center to world coordinates
|
|
1481
|
+
center_x, center_y = src.xy(window_center_y, window_center_x)
|
|
1482
|
+
center_point = gpd.points_from_xy([center_x], [center_y])[0]
|
|
1483
|
+
|
|
1484
|
+
# Find nearest feature
|
|
1485
|
+
distances = gdf.geometry.distance(center_point)
|
|
1486
|
+
nearest_idx = distances.idxmin()
|
|
1487
|
+
nearest_feature = gdf.iloc[nearest_idx]
|
|
1488
|
+
|
|
1489
|
+
# Get centroid of nearest feature
|
|
1490
|
+
feature_centroid = nearest_feature.geometry.centroid
|
|
1491
|
+
|
|
1492
|
+
# Convert feature centroid to pixel coordinates
|
|
1493
|
+
feature_row, feature_col = src.index(
|
|
1494
|
+
feature_centroid.x, feature_centroid.y
|
|
1495
|
+
)
|
|
1496
|
+
|
|
1497
|
+
# Adjust window to center on feature
|
|
1498
|
+
window_x = max(
|
|
1499
|
+
0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
|
|
1500
|
+
)
|
|
1501
|
+
window_y = max(
|
|
1502
|
+
0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
|
|
1503
|
+
)
|
|
1504
|
+
|
|
1505
|
+
# Define window
|
|
1506
|
+
window = Window(window_x, window_y, tile_size_x, tile_size_y)
|
|
1507
|
+
|
|
1508
|
+
# Get window transform and bounds in source CRS
|
|
1509
|
+
window_transform = src.window_transform(window)
|
|
1510
|
+
|
|
1511
|
+
# Calculate window bounds more explicitly and accurately
|
|
1512
|
+
minx = window_transform[2] # Upper left x
|
|
1513
|
+
maxy = window_transform[5] # Upper left y
|
|
1514
|
+
maxx = minx + tile_size_x * window_transform[0] # Add width
|
|
1515
|
+
miny = (
|
|
1516
|
+
maxy + tile_size_y * window_transform[4]
|
|
1517
|
+
) # Add height (note: transform[4] is typically negative)
|
|
1518
|
+
|
|
1519
|
+
window_bounds = box(minx, miny, maxx, maxy)
|
|
1520
|
+
|
|
1521
|
+
# Apply rotation if specified
|
|
1522
|
+
if rotation_angle != 0:
|
|
1523
|
+
window_bounds = rotate(
|
|
1524
|
+
window_bounds, rotation_angle, origin="center"
|
|
1525
|
+
)
|
|
1526
|
+
|
|
1527
|
+
# Find features that intersect with window
|
|
1528
|
+
window_features = gdf[gdf.intersects(window_bounds)]
|
|
1529
|
+
|
|
1530
|
+
# Process instance data if provided
|
|
1531
|
+
window_instances = None
|
|
1532
|
+
if instance_gdf is not None and instance_class_value_field is not None:
|
|
1533
|
+
window_instances = instance_gdf[
|
|
1534
|
+
instance_gdf.intersects(window_bounds)
|
|
1535
|
+
]
|
|
1536
|
+
if len(window_instances) > 0:
|
|
1537
|
+
if not quiet:
|
|
1538
|
+
pbar.write(
|
|
1539
|
+
f"Found {len(window_instances)} instances in tile {chip_index}"
|
|
1540
|
+
)
|
|
1541
|
+
|
|
1542
|
+
# Skip if no features and output_nofeature_tiles is False
|
|
1543
|
+
if not output_nofeature_tiles and len(window_features) == 0:
|
|
1544
|
+
pbar.update(1) # Still update progress bar
|
|
1545
|
+
continue
|
|
1546
|
+
|
|
1547
|
+
# Check polygon overlap ratio if specified
|
|
1548
|
+
if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
|
|
1549
|
+
valid_features = []
|
|
1550
|
+
for _, feature in window_features.iterrows():
|
|
1551
|
+
overlap_ratio = (
|
|
1552
|
+
feature.geometry.intersection(window_bounds).area
|
|
1553
|
+
/ feature.geometry.area
|
|
1554
|
+
)
|
|
1555
|
+
if overlap_ratio >= min_polygon_overlap_ratio:
|
|
1556
|
+
valid_features.append(feature)
|
|
1557
|
+
|
|
1558
|
+
if len(valid_features) > 0:
|
|
1559
|
+
window_features = gpd.GeoDataFrame(valid_features)
|
|
1560
|
+
elif not output_nofeature_tiles:
|
|
1561
|
+
pbar.update(1) # Still update progress bar
|
|
1562
|
+
continue
|
|
1563
|
+
|
|
1564
|
+
# Apply mask if provided
|
|
1565
|
+
if mask_gdf is not None:
|
|
1566
|
+
mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
|
|
1567
|
+
if len(mask_features) == 0:
|
|
1568
|
+
pbar.update(1) # Still update progress bar
|
|
1569
|
+
continue
|
|
1570
|
+
|
|
1571
|
+
# Read image data - keep original for GeoTIFF export
|
|
1572
|
+
orig_image_data = src.read(window=window)
|
|
1573
|
+
|
|
1574
|
+
# Create a copy for processing
|
|
1575
|
+
image_data = orig_image_data.copy().astype(np.float32)
|
|
1576
|
+
|
|
1577
|
+
# Normalize image data for processing
|
|
1578
|
+
for band in range(image_data.shape[0]):
|
|
1579
|
+
band_min, band_max = np.percentile(image_data[band], (1, 99))
|
|
1580
|
+
if band_max > band_min:
|
|
1581
|
+
image_data[band] = np.clip(
|
|
1582
|
+
(image_data[band] - band_min) / (band_max - band_min), 0, 1
|
|
1583
|
+
)
|
|
1584
|
+
|
|
1585
|
+
# Read secondary image data if provided
|
|
1586
|
+
if src2:
|
|
1587
|
+
image_data2 = src2.read(window=window)
|
|
1588
|
+
# Stack the two images
|
|
1589
|
+
image_data = np.vstack((image_data, image_data2))
|
|
1590
|
+
|
|
1591
|
+
# Apply blacken_around_feature if needed
|
|
1592
|
+
if blacken_around_feature and len(window_features) > 0:
|
|
1593
|
+
mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
|
|
1594
|
+
for _, feature in window_features.iterrows():
|
|
1595
|
+
# Project feature to pixel coordinates
|
|
1596
|
+
feature_pixels = features.rasterize(
|
|
1597
|
+
[(feature.geometry, 1)],
|
|
1598
|
+
out_shape=(tile_size_y, tile_size_x),
|
|
1599
|
+
transform=window_transform,
|
|
1600
|
+
)
|
|
1601
|
+
mask = np.logical_or(mask, feature_pixels.astype(bool))
|
|
1602
|
+
|
|
1603
|
+
# Apply mask to image
|
|
1604
|
+
for band in range(image_data.shape[0]):
|
|
1605
|
+
temp = image_data[band, :, :]
|
|
1606
|
+
temp[~mask] = 0
|
|
1607
|
+
image_data[band, :, :] = temp
|
|
1608
|
+
|
|
1609
|
+
# Apply rotation if specified
|
|
1610
|
+
if augmentation:
|
|
1611
|
+
# Convert to torch tensor for augmentation
|
|
1612
|
+
image_tensor = torch.from_numpy(image_data).unsqueeze(
|
|
1613
|
+
0
|
|
1614
|
+
) # Add batch dimension
|
|
1615
|
+
# Apply augmentation with proper data format
|
|
1616
|
+
augmented = augmentation({"image": image_tensor})
|
|
1617
|
+
image_data = (
|
|
1618
|
+
augmented["image"].squeeze(0).numpy()
|
|
1619
|
+
) # Remove batch dimension
|
|
1620
|
+
|
|
1621
|
+
# Create a processed version for regular image formats
|
|
1622
|
+
processed_image = (image_data * 255).astype(np.uint8)
|
|
1623
|
+
|
|
1624
|
+
# Create label mask
|
|
1625
|
+
label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
|
|
1626
|
+
has_features = False
|
|
1627
|
+
|
|
1628
|
+
if len(window_features) > 0:
|
|
1629
|
+
for idx, feature in window_features.iterrows():
|
|
1630
|
+
# Get class value
|
|
1631
|
+
class_val = (
|
|
1632
|
+
feature[class_value_field]
|
|
1633
|
+
if class_value_field in feature
|
|
1634
|
+
else 1
|
|
1635
|
+
)
|
|
1636
|
+
if isinstance(class_val, str):
|
|
1637
|
+
# If class is a string, use its position in the unique classes list
|
|
1638
|
+
class_id = class_to_id.get(class_val, 1)
|
|
1639
|
+
else:
|
|
1640
|
+
# If class is already a number, use it directly
|
|
1641
|
+
class_id = int(class_val) if class_val > 0 else 1
|
|
1642
|
+
|
|
1643
|
+
# Get the geometry in pixel coordinates
|
|
1644
|
+
geom = feature.geometry.intersection(window_bounds)
|
|
1645
|
+
if not geom.is_empty:
|
|
1646
|
+
try:
|
|
1647
|
+
# Rasterize the feature
|
|
1648
|
+
feature_mask = features.rasterize(
|
|
1649
|
+
[(geom, class_id)],
|
|
1650
|
+
out_shape=(tile_size_y, tile_size_x),
|
|
1651
|
+
transform=window_transform,
|
|
1652
|
+
fill=0,
|
|
1653
|
+
all_touched=all_touched,
|
|
1654
|
+
)
|
|
1655
|
+
|
|
1656
|
+
# Update mask with higher class values taking precedence
|
|
1657
|
+
label_mask = np.maximum(label_mask, feature_mask)
|
|
1658
|
+
|
|
1659
|
+
# Check if any pixels were added
|
|
1660
|
+
if np.any(feature_mask):
|
|
1661
|
+
has_features = True
|
|
1662
|
+
except Exception as e:
|
|
1663
|
+
if not quiet:
|
|
1664
|
+
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
1665
|
+
stats["errors"] += 1
|
|
1666
|
+
|
|
1667
|
+
# Save as GeoTIFF if requested
|
|
1668
|
+
if save_geotiff or image_chip_format.upper() in [
|
|
1669
|
+
"TIFF",
|
|
1670
|
+
"TIF",
|
|
1671
|
+
"GEOTIFF",
|
|
1672
|
+
]:
|
|
1673
|
+
# Standardize extension to .tif for GeoTIFF files
|
|
1674
|
+
image_filename = f"tile_{chip_index:06d}.tif"
|
|
1675
|
+
image_path = os.path.join(image_dir, image_filename)
|
|
1676
|
+
|
|
1677
|
+
# Create profile for the GeoTIFF
|
|
1678
|
+
profile = src.profile.copy()
|
|
1679
|
+
profile.update(
|
|
1680
|
+
{
|
|
1681
|
+
"height": tile_size_y,
|
|
1682
|
+
"width": tile_size_x,
|
|
1683
|
+
"count": orig_image_data.shape[0],
|
|
1684
|
+
"transform": window_transform,
|
|
1685
|
+
}
|
|
1686
|
+
)
|
|
1687
|
+
|
|
1688
|
+
# Save the GeoTIFF with original data
|
|
1689
|
+
try:
|
|
1690
|
+
with rasterio.open(image_path, "w", **profile) as dst:
|
|
1691
|
+
dst.write(orig_image_data)
|
|
1692
|
+
stats["total_tiles"] += 1
|
|
1693
|
+
except Exception as e:
|
|
1694
|
+
if not quiet:
|
|
1695
|
+
pbar.write(
|
|
1696
|
+
f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
|
|
1697
|
+
)
|
|
1698
|
+
stats["errors"] += 1
|
|
1699
|
+
else:
|
|
1700
|
+
# For non-GeoTIFF formats, use PIL to save the image
|
|
1701
|
+
image_filename = (
|
|
1702
|
+
f"tile_{chip_index:06d}.{image_chip_format.lower()}"
|
|
1703
|
+
)
|
|
1704
|
+
image_path = os.path.join(image_dir, image_filename)
|
|
1705
|
+
|
|
1706
|
+
# Create PIL image for saving
|
|
1707
|
+
if processed_image.shape[0] == 1:
|
|
1708
|
+
img = Image.fromarray(processed_image[0])
|
|
1709
|
+
elif processed_image.shape[0] == 3:
|
|
1710
|
+
# For RGB, need to transpose and make sure it's the right data type
|
|
1711
|
+
rgb_data = np.transpose(processed_image, (1, 2, 0))
|
|
1712
|
+
img = Image.fromarray(rgb_data)
|
|
1713
|
+
else:
|
|
1714
|
+
# For multiband images, save only RGB or first three bands
|
|
1715
|
+
rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
|
|
1716
|
+
img = Image.fromarray(rgb_data)
|
|
1717
|
+
|
|
1718
|
+
# Save image
|
|
1719
|
+
try:
|
|
1720
|
+
img.save(image_path)
|
|
1721
|
+
stats["total_tiles"] += 1
|
|
1722
|
+
except Exception as e:
|
|
1723
|
+
if not quiet:
|
|
1724
|
+
pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
|
|
1725
|
+
stats["errors"] += 1
|
|
1726
|
+
|
|
1727
|
+
# Save label as GeoTIFF
|
|
1728
|
+
label_filename = f"tile_{chip_index:06d}.tif"
|
|
1729
|
+
label_path = os.path.join(label_dir, label_filename)
|
|
1730
|
+
|
|
1731
|
+
# Create profile for label GeoTIFF
|
|
1732
|
+
label_profile = {
|
|
1733
|
+
"driver": "GTiff",
|
|
1734
|
+
"height": tile_size_y,
|
|
1735
|
+
"width": tile_size_x,
|
|
1736
|
+
"count": 1,
|
|
1737
|
+
"dtype": "uint8",
|
|
1738
|
+
"crs": src.crs,
|
|
1739
|
+
"transform": window_transform,
|
|
1740
|
+
}
|
|
1741
|
+
|
|
1742
|
+
# Save label GeoTIFF
|
|
1743
|
+
try:
|
|
1744
|
+
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
1745
|
+
dst.write(label_mask, 1)
|
|
1746
|
+
|
|
1747
|
+
if has_features:
|
|
1748
|
+
pixel_count = np.count_nonzero(label_mask)
|
|
1749
|
+
stats["tiles_with_features"] += 1
|
|
1750
|
+
stats["feature_pixels"] += pixel_count
|
|
1751
|
+
except Exception as e:
|
|
1752
|
+
if not quiet:
|
|
1753
|
+
pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
|
|
1754
|
+
stats["errors"] += 1
|
|
1755
|
+
|
|
1756
|
+
# Also save a PNG version for easy visualization if requested
|
|
1757
|
+
if metadata_format == "PASCAL_VOC":
|
|
1758
|
+
try:
|
|
1759
|
+
# Ensure correct data type for PIL
|
|
1760
|
+
png_label = label_mask.astype(np.uint8)
|
|
1761
|
+
label_img = Image.fromarray(png_label)
|
|
1762
|
+
label_png_path = os.path.join(
|
|
1763
|
+
label_dir, f"tile_{chip_index:06d}.png"
|
|
1764
|
+
)
|
|
1765
|
+
label_img.save(label_png_path)
|
|
1766
|
+
except Exception as e:
|
|
1767
|
+
if not quiet:
|
|
1768
|
+
pbar.write(
|
|
1769
|
+
f"ERROR saving PNG label for tile {chip_index}: {e}"
|
|
1770
|
+
)
|
|
1771
|
+
pbar.write(
|
|
1772
|
+
f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
|
|
1773
|
+
)
|
|
1774
|
+
# Try again with explicit conversion
|
|
1775
|
+
try:
|
|
1776
|
+
# Alternative approach for problematic arrays
|
|
1777
|
+
png_data = np.zeros(
|
|
1778
|
+
(tile_size_y, tile_size_x), dtype=np.uint8
|
|
1779
|
+
)
|
|
1780
|
+
np.copyto(png_data, label_mask, casting="unsafe")
|
|
1781
|
+
label_img = Image.fromarray(png_data)
|
|
1782
|
+
label_img.save(label_png_path)
|
|
1783
|
+
pbar.write(
|
|
1784
|
+
f" Succeeded using alternative conversion method"
|
|
1785
|
+
)
|
|
1786
|
+
except Exception as e2:
|
|
1787
|
+
pbar.write(f" Second attempt also failed: {e2}")
|
|
1788
|
+
stats["errors"] += 1
|
|
1789
|
+
|
|
1790
|
+
# Generate annotations
|
|
1791
|
+
if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
|
|
1792
|
+
# Create XML annotation
|
|
1793
|
+
root = ET.Element("annotation")
|
|
1794
|
+
ET.SubElement(root, "folder").text = "images"
|
|
1795
|
+
ET.SubElement(root, "filename").text = image_filename
|
|
1796
|
+
|
|
1797
|
+
size = ET.SubElement(root, "size")
|
|
1798
|
+
ET.SubElement(size, "width").text = str(tile_size_x)
|
|
1799
|
+
ET.SubElement(size, "height").text = str(tile_size_y)
|
|
1800
|
+
ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))
|
|
1801
|
+
|
|
1802
|
+
# Add georeference information
|
|
1803
|
+
geo = ET.SubElement(root, "georeference")
|
|
1804
|
+
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
1805
|
+
ET.SubElement(geo, "transform").text = str(
|
|
1806
|
+
window_transform
|
|
1807
|
+
).replace("\n", "")
|
|
1808
|
+
ET.SubElement(geo, "bounds").text = (
|
|
1809
|
+
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
1810
|
+
)
|
|
1811
|
+
|
|
1812
|
+
for _, feature in window_features.iterrows():
|
|
1813
|
+
# Convert feature geometry to pixel coordinates
|
|
1814
|
+
feature_bounds = feature.geometry.intersection(window_bounds)
|
|
1815
|
+
if feature_bounds.is_empty:
|
|
1816
|
+
continue
|
|
1817
|
+
|
|
1818
|
+
# Get pixel coordinates of bounds
|
|
1819
|
+
minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
|
|
1820
|
+
|
|
1821
|
+
# Convert to pixel coordinates
|
|
1822
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
1823
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
1824
|
+
|
|
1825
|
+
# Ensure coordinates are within bounds
|
|
1826
|
+
xmin = max(0, min(tile_size_x, int(col_min)))
|
|
1827
|
+
ymin = max(0, min(tile_size_y, int(row_min)))
|
|
1828
|
+
xmax = max(0, min(tile_size_x, int(col_max)))
|
|
1829
|
+
ymax = max(0, min(tile_size_y, int(row_max)))
|
|
1830
|
+
|
|
1831
|
+
# Skip if box is too small
|
|
1832
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
1833
|
+
continue
|
|
1834
|
+
|
|
1835
|
+
obj = ET.SubElement(root, "object")
|
|
1836
|
+
ET.SubElement(obj, "name").text = str(
|
|
1837
|
+
feature[class_value_field]
|
|
1838
|
+
)
|
|
1839
|
+
ET.SubElement(obj, "difficult").text = "0"
|
|
1840
|
+
|
|
1841
|
+
bbox = ET.SubElement(obj, "bndbox")
|
|
1842
|
+
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
1843
|
+
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
1844
|
+
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
1845
|
+
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
1846
|
+
|
|
1847
|
+
# Save XML
|
|
1848
|
+
try:
|
|
1849
|
+
tree = ET.ElementTree(root)
|
|
1850
|
+
xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
|
|
1851
|
+
tree.write(xml_path)
|
|
1852
|
+
except Exception as e:
|
|
1853
|
+
if not quiet:
|
|
1854
|
+
pbar.write(
|
|
1855
|
+
f"ERROR saving XML annotation for tile {chip_index}: {e}"
|
|
1856
|
+
)
|
|
1857
|
+
stats["errors"] += 1
|
|
1858
|
+
|
|
1859
|
+
elif metadata_format == "COCO" and len(window_features) > 0:
|
|
1860
|
+
# Add image info
|
|
1861
|
+
image_id = chip_index
|
|
1862
|
+
coco_annotations["images"].append(
|
|
1863
|
+
{
|
|
1864
|
+
"id": image_id,
|
|
1865
|
+
"file_name": image_filename,
|
|
1866
|
+
"width": tile_size_x,
|
|
1867
|
+
"height": tile_size_y,
|
|
1868
|
+
"crs": str(src.crs),
|
|
1869
|
+
"transform": str(window_transform),
|
|
1870
|
+
}
|
|
1871
|
+
)
|
|
1872
|
+
|
|
1873
|
+
# Add annotations for each feature
|
|
1874
|
+
for _, feature in window_features.iterrows():
|
|
1875
|
+
feature_bounds = feature.geometry.intersection(window_bounds)
|
|
1876
|
+
if feature_bounds.is_empty:
|
|
1877
|
+
continue
|
|
1878
|
+
|
|
1879
|
+
# Get pixel coordinates of bounds
|
|
1880
|
+
minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
|
|
1881
|
+
|
|
1882
|
+
# Convert to pixel coordinates
|
|
1883
|
+
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
1884
|
+
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
1885
|
+
|
|
1886
|
+
# Ensure coordinates are within bounds
|
|
1887
|
+
xmin = max(0, min(tile_size_x, int(col_min)))
|
|
1888
|
+
ymin = max(0, min(tile_size_y, int(row_min)))
|
|
1889
|
+
xmax = max(0, min(tile_size_x, int(col_max)))
|
|
1890
|
+
ymax = max(0, min(tile_size_y, int(row_max)))
|
|
1891
|
+
|
|
1892
|
+
# Skip if box is too small
|
|
1893
|
+
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
1894
|
+
continue
|
|
1895
|
+
|
|
1896
|
+
width = xmax - xmin
|
|
1897
|
+
height = ymax - ymin
|
|
1898
|
+
|
|
1899
|
+
# Add annotation
|
|
1900
|
+
ann_id += 1
|
|
1901
|
+
category_id = class_to_id[feature[class_value_field]]
|
|
1902
|
+
|
|
1903
|
+
coco_annotations["annotations"].append(
|
|
1904
|
+
{
|
|
1905
|
+
"id": ann_id,
|
|
1906
|
+
"image_id": image_id,
|
|
1907
|
+
"category_id": category_id,
|
|
1908
|
+
"bbox": [xmin, ymin, width, height],
|
|
1909
|
+
"area": width * height,
|
|
1910
|
+
"iscrowd": 0,
|
|
1911
|
+
}
|
|
1912
|
+
)
|
|
1913
|
+
|
|
1914
|
+
# Update progress bar
|
|
1915
|
+
pbar.update(1)
|
|
1916
|
+
pbar.set_description(
|
|
1917
|
+
f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
|
|
1918
|
+
)
|
|
1919
|
+
|
|
1920
|
+
chip_index += 1
|
|
1921
|
+
|
|
1922
|
+
# Close progress bar
|
|
1923
|
+
pbar.close()
|
|
1924
|
+
|
|
1925
|
+
# Save COCO annotations if applicable
|
|
1926
|
+
if metadata_format == "COCO":
|
|
1927
|
+
try:
|
|
1928
|
+
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
|
1929
|
+
json.dump(coco_annotations, f)
|
|
1930
|
+
except Exception as e:
|
|
1931
|
+
if not quiet:
|
|
1932
|
+
print(f"ERROR saving COCO annotations: {e}")
|
|
1933
|
+
stats["errors"] += 1
|
|
1934
|
+
|
|
1935
|
+
# Close secondary raster if opened
|
|
1936
|
+
if src2:
|
|
1937
|
+
src2.close()
|
|
1938
|
+
|
|
1939
|
+
# Print summary
|
|
1940
|
+
if not quiet:
|
|
1941
|
+
print("\n------- Export Summary -------")
|
|
1942
|
+
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
1943
|
+
print(
|
|
1944
|
+
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
1945
|
+
)
|
|
1946
|
+
if stats["tiles_with_features"] > 0:
|
|
1947
|
+
print(
|
|
1948
|
+
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
1949
|
+
)
|
|
1950
|
+
if stats["errors"] > 0:
|
|
1951
|
+
print(f"Errors encountered: {stats['errors']}")
|
|
1952
|
+
print(f"Output saved to: {out_folder}")
|
|
1953
|
+
|
|
1954
|
+
# Verify georeference in a sample image and label
|
|
1955
|
+
if stats["total_tiles"] > 0:
|
|
1956
|
+
print("\n------- Georeference Verification -------")
|
|
1957
|
+
sample_image = os.path.join(image_dir, f"tile_{start_index}.tif")
|
|
1958
|
+
sample_label = os.path.join(label_dir, f"tile_{start_index}.tif")
|
|
1959
|
+
|
|
1960
|
+
if os.path.exists(sample_image):
|
|
1961
|
+
try:
|
|
1962
|
+
with rasterio.open(sample_image) as img:
|
|
1963
|
+
print(f"Image CRS: {img.crs}")
|
|
1964
|
+
print(f"Image transform: {img.transform}")
|
|
1965
|
+
print(
|
|
1966
|
+
f"Image has georeference: {img.crs is not None and img.transform is not None}"
|
|
1967
|
+
)
|
|
1968
|
+
print(
|
|
1969
|
+
f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
|
|
1970
|
+
)
|
|
1971
|
+
except Exception as e:
|
|
1972
|
+
print(f"Error verifying image georeference: {e}")
|
|
1973
|
+
|
|
1974
|
+
if os.path.exists(sample_label):
|
|
1975
|
+
try:
|
|
1976
|
+
with rasterio.open(sample_label) as lbl:
|
|
1977
|
+
print(f"Label CRS: {lbl.crs}")
|
|
1978
|
+
print(f"Label transform: {lbl.transform}")
|
|
1979
|
+
print(
|
|
1980
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
1981
|
+
)
|
|
1982
|
+
print(
|
|
1983
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
1984
|
+
)
|
|
1985
|
+
except Exception as e:
|
|
1986
|
+
print(f"Error verifying label georeference: {e}")
|
|
1987
|
+
|
|
1988
|
+
# Return statistics
|
|
1989
|
+
return stats, out_folder
|
|
1990
|
+
|
|
1991
|
+
|
|
1992
|
+
# if __name__ == "__main__":
|
|
1993
|
+
# # Example parameters
|
|
1994
|
+
# export_training_data(
|
|
1995
|
+
# in_raster="naip_train.tif",
|
|
1996
|
+
# out_folder="output",
|
|
1997
|
+
# in_class_data="buildings_train.geojson",
|
|
1998
|
+
# image_chip_format="GEOTIFF", # Use GeoTIFF format to preserve georeference
|
|
1999
|
+
# tile_size_x=256,
|
|
2000
|
+
# tile_size_y=256,
|
|
2001
|
+
# stride_x=128, # Use overlapping tiles to increase chance of capturing features
|
|
2002
|
+
# stride_y=128,
|
|
2003
|
+
# metadata_format="PASCAL_VOC",
|
|
2004
|
+
# class_value_field="class",
|
|
2005
|
+
# buffer_radius=2, # Add small buffer to buildings to ensure they're captured
|
|
2006
|
+
# all_touched=True, # Ensure small features are rasterized
|
|
2007
|
+
# save_geotiff=True, # Always save as GeoTIFF regardless of image_chip_format
|
|
2008
|
+
# )
|