geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/download.py +9 -8
- geoai/extract.py +158 -38
- geoai/geoai.py +4 -1
- geoai/hf.py +447 -0
- geoai/segment.py +306 -0
- geoai/segmentation.py +8 -7
- geoai/train.py +1039 -0
- geoai/utils.py +863 -25
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/METADATA +5 -1
- geoai_py-0.4.0.dist-info/RECORD +15 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/WHEEL +1 -1
- geoai/preprocess.py +0 -3021
- geoai_py-0.3.5.dist-info/RECORD +0 -13
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/top_level.txt +0 -0
geoai/preprocess.py
DELETED
|
@@ -1,3021 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import math
|
|
3
|
-
import os
|
|
4
|
-
from PIL import Image
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
import requests
|
|
7
|
-
import warnings
|
|
8
|
-
import xml.etree.ElementTree as ET
|
|
9
|
-
import numpy as np
|
|
10
|
-
import rasterio
|
|
11
|
-
import geopandas as gpd
|
|
12
|
-
import pandas as pd
|
|
13
|
-
from rasterio.windows import Window
|
|
14
|
-
from rasterio import features
|
|
15
|
-
from rasterio.plot import show
|
|
16
|
-
from shapely.geometry import box, shape, mapping, Polygon
|
|
17
|
-
import matplotlib.pyplot as plt
|
|
18
|
-
from tqdm import tqdm
|
|
19
|
-
from torchvision.transforms import RandomRotation
|
|
20
|
-
from shapely.affinity import rotate
|
|
21
|
-
import torch
|
|
22
|
-
import cv2
|
|
23
|
-
|
|
24
|
-
try:
|
|
25
|
-
import torchgeo
|
|
26
|
-
except ImportError as e:
|
|
27
|
-
raise ImportError(
|
|
28
|
-
"Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
|
|
29
|
-
)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def download_file(url, output_path=None, overwrite=False):
    """
    Download a file from a given URL with a progress bar.

    Args:
        url (str): The URL of the file to download.
        output_path (str, optional): The path where the downloaded file will be saved.
            If not provided, the filename from the URL will be used.
        overwrite (bool, optional): Whether to overwrite the file if it already exists.

    Returns:
        str: The path to the downloaded file.

    Raises:
        requests.HTTPError: If the server responds with an error status code.
    """
    # Get the filename from the URL if output_path is not provided
    if output_path is None:
        output_path = os.path.basename(url)

    # Check if the file already exists
    if os.path.exists(output_path) and not overwrite:
        print(f"File already exists: {output_path}")
        return output_path

    # Use a context manager for the streaming response so the underlying
    # connection is released even if an exception occurs mid-download.
    with requests.get(url, stream=True, timeout=50) as response:
        response.raise_for_status()  # Raise an exception for HTTP errors

        # Total size may be 0 when the server omits content-length
        total_size = int(response.headers.get("content-length", 0))

        # Open the output file and a progress bar together
        with (
            open(output_path, "wb") as file,
            tqdm(
                desc=os.path.basename(output_path),
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar,
        ):
            # Download the file in chunks and update the progress bar
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
                    progress_bar.update(len(chunk))

    return output_path
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def get_raster_info(raster_path):
    """Display basic information about a raster dataset.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: Dictionary containing the basic information about the raster
    """
    with rasterio.open(raster_path) as src:
        # Collect the core dataset-level metadata up front
        info = {
            "driver": src.driver,
            "width": src.width,
            "height": src.height,
            "count": src.count,
            "dtype": src.dtypes[0],
            "crs": src.crs.to_string() if src.crs else "No CRS defined",
            "transform": src.transform,
            "bounds": src.bounds,
            "resolution": (src.transform[0], -src.transform[4]),
            "nodata": src.nodata,
        }

        # Per-band summary statistics, computed on masked reads so that
        # nodata pixels are excluded
        band_stats = []
        for band_index in range(1, src.count + 1):
            masked = src.read(band_index, masked=True)
            band_stats.append(
                {
                    "band": band_index,
                    "min": float(masked.min()),
                    "max": float(masked.max()),
                    "mean": float(masked.mean()),
                    "std": float(masked.std()),
                }
            )

        info["band_stats"] = band_stats

    return info
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def get_raster_stats(raster_path, divide_by=1.0):
    """Calculate statistics for each band in a raster dataset.

    This function computes min, max, mean, and standard deviation values
    for each band in the provided raster, returning results in a dictionary
    with lists for each statistic type.

    Args:
        raster_path (str): Path to the raster file
        divide_by (float, optional): Value to divide pixel values by.
            Defaults to 1.0, which keeps the original pixel values.

    Returns:
        dict: Dictionary containing lists of statistics with keys:
            - 'min': List of minimum values for each band
            - 'max': List of maximum values for each band
            - 'mean': List of mean values for each band
            - 'std': List of standard deviation values for each band
    """
    # One list per statistic, filled band by band
    stats = {key: [] for key in ("min", "max", "mean", "std")}

    with rasterio.open(raster_path) as src:
        for band_index in range(1, src.count + 1):
            # Masked read excludes nodata pixels from the statistics
            band = src.read(band_index, masked=True)
            for key, value in (
                ("min", band.min()),
                ("max", band.max()),
                ("mean", band.mean()),
                ("std", band.std()),
            ):
                stats[key].append(float(value) / divide_by)

    return stats
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a raster dataset and optionally show a preview.

    Args:
        raster_path (str): Path to the raster file
        show_preview (bool, optional): Whether to display a visual preview of the raster.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: Dictionary containing raster information if successful, None otherwise
    """
    try:
        info = get_raster_info(raster_path)

        # Print basic information
        print(f"===== RASTER INFORMATION: {raster_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Dimensions: {info['width']} x {info['height']} pixels")
        print(f"Number of bands: {info['count']}")
        print(f"Data type: {info['dtype']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Georeferenced Bounds: {info['bounds']}")
        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
        print(f"NoData Value: {info['nodata']}")

        # Print band statistics
        print("\n----- Band Statistics -----")
        for band_stat in info["band_stats"]:
            print(f"Band {band_stat['band']}:")
            print(f"  Min: {band_stat['min']:.2f}")
            print(f"  Max: {band_stat['max']:.2f}")
            print(f"  Mean: {band_stat['mean']:.2f}")
            print(f"  Std Dev: {band_stat['std']:.2f}")

        # Show a preview if requested
        if show_preview:
            with rasterio.open(raster_path) as src:
                # For multi-band images, show RGB composite or first band
                if src.count >= 3:
                    # Try to show RGB composite from the first three bands
                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
                    plt.figure(figsize=figsize)
                    plt.imshow(rgb)
                    plt.title(f"RGB Preview: {raster_path}")
                else:
                    # Show first band for single-band images
                    plt.figure(figsize=figsize)
                    show(
                        src.read(1),
                        cmap="viridis",
                        title=f"Band 1 Preview: {raster_path}",
                    )
                    plt.colorbar(label="Pixel Value")
                plt.show()

        # Fix: the docstring promises the info dict on success, but the
        # original implementation never returned it.
        return info

    except Exception as e:
        print(f"Error reading raster: {str(e)}")
        return None
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
def get_raster_info_gdal(raster_path):
    """Get basic information about a raster dataset using GDAL.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: Dictionary containing the basic information about the raster,
            or None if the file cannot be opened
    """

    from osgeo import gdal

    dataset = gdal.Open(raster_path)
    if dataset is None:
        print(f"Error: Could not open {raster_path}")
        return None

    geotransform = dataset.GetGeoTransform()

    # Core dataset-level metadata
    info = {
        "driver": dataset.GetDriver().ShortName,
        "width": dataset.RasterXSize,
        "height": dataset.RasterYSize,
        "count": dataset.RasterCount,
        "projection": dataset.GetProjection(),
        "geotransform": geotransform,
    }

    # Pixel size and top-left origin derived from the geotransform
    if geotransform:
        info["resolution"] = (abs(geotransform[1]), abs(geotransform[5]))
        info["origin"] = (geotransform[0], geotransform[3])

    # Per-band statistics; GetStatistics(True, True) allows approximate
    # stats and forces computation if none are cached
    bands_info = []
    for band_index in range(1, dataset.RasterCount + 1):
        band = dataset.GetRasterBand(band_index)
        minimum, maximum, mean, std = band.GetStatistics(True, True)
        bands_info.append(
            {
                "band": band_index,
                "datatype": gdal.GetDataTypeName(band.DataType),
                "min": minimum,
                "max": maximum,
                "mean": mean,
                "std": std,
                "nodata": band.GetNoDataValue(),
            }
        )

    info["bands"] = bands_info

    # Dereference to close the GDAL dataset
    dataset = None

    return info
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
def get_vector_info(vector_path):
    """Display basic information about a vector dataset using GeoPandas.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: Dictionary containing the basic information about the vector dataset
    """
    # GeoParquet needs a dedicated reader; everything else goes through read_file
    if vector_path.endswith(".parquet"):
        gdf = gpd.read_parquet(vector_path)
    else:
        gdf = gpd.read_file(vector_path)

    # Core dataset-level metadata
    info = {
        "file_path": vector_path,
        "driver": os.path.splitext(vector_path)[1][1:].upper(),  # Format from extension
        "feature_count": len(gdf),
        "crs": str(gdf.crs),
        "geometry_type": str(gdf.geom_type.value_counts().to_dict()),
        "attribute_count": len(gdf.columns) - 1,  # Subtract the geometry column
        "attribute_names": list(gdf.columns[gdf.columns != "geometry"]),
        "bounds": gdf.total_bounds.tolist(),
    }

    # Summarize every numeric (non-geometry) attribute
    attribute_stats = {}
    for column in gdf.select_dtypes(include=["number"]).columns:
        if column == "geometry":
            continue
        series = gdf[column]
        attribute_stats[column] = {
            "min": series.min(),
            "max": series.max(),
            "mean": series.mean(),
            "std": series.std(),
            "null_count": series.isna().sum(),
        }

    info["attribute_stats"] = attribute_stats

    return info
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a vector dataset and optionally show a preview.

    Args:
        vector_path (str): Path to the vector file
        show_preview (bool, optional): Whether to display a visual preview of the vector data.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: Dictionary containing vector information if successful, None otherwise
    """
    try:
        info = get_vector_info(vector_path)

        # Print basic information
        print(f"===== VECTOR INFORMATION: {vector_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Feature count: {info['feature_count']}")
        print(f"Geometry types: {info['geometry_type']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Bounds: {info['bounds']}")
        print(f"Number of attributes: {info['attribute_count']}")
        print(f"Attribute names: {', '.join(info['attribute_names'])}")

        # Print attribute statistics
        if info["attribute_stats"]:
            print("\n----- Attribute Statistics -----")
            for attr, stats in info["attribute_stats"].items():
                print(f"Attribute: {attr}")
                for stat_name, stat_value in stats.items():
                    print(
                        f"  {stat_name}: {stat_value:.4f}"
                        if isinstance(stat_value, float)
                        else f"  {stat_name}: {stat_value}"
                    )

        # Show a preview if requested
        if show_preview:
            # GeoParquet needs a dedicated reader
            gdf = (
                gpd.read_parquet(vector_path)
                if vector_path.endswith(".parquet")
                else gpd.read_file(vector_path)
            )
            fig, ax = plt.subplots(figsize=figsize)
            gdf.plot(ax=ax, cmap="viridis")
            ax.set_title(f"Preview: {vector_path}")
            plt.tight_layout()
            plt.show()

        # Fix: the docstring promises the info dict on success, but the
        # original implementation never returned it.
        return info

    except Exception as e:
        print(f"Error reading vector data: {str(e)}")
        return None
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
# Alternative implementation using OGR directly
|
|
387
|
-
def get_vector_info_ogr(vector_path):
    """Get basic information about a vector dataset using OGR.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: Dictionary containing the basic information about the vector dataset,
            or None if the file cannot be opened
    """
    from osgeo import ogr

    # Make every vector driver available before opening
    ogr.RegisterAll()

    datasource = ogr.Open(vector_path)
    if datasource is None:
        print(f"Error: Could not open {vector_path}")
        return None

    # Dataset-level information; layers are appended below
    info = {
        "file_path": vector_path,
        "driver": datasource.GetDriver().GetName(),
        "layer_count": datasource.GetLayerCount(),
        "layers": [],
    }

    for layer_index in range(datasource.GetLayerCount()):
        layer = datasource.GetLayer(layer_index)
        spatial_ref = layer.GetSpatialRef()
        layer_info = {
            "name": layer.GetName(),
            "feature_count": layer.GetFeatureCount(),
            "geometry_type": ogr.GeometryTypeToName(layer.GetGeomType()),
            "spatial_ref": spatial_ref.ExportToWkt() if spatial_ref else "None",
            "extent": layer.GetExtent(),
            "fields": [],
        }

        # Describe every attribute field defined on the layer
        layer_defn = layer.GetLayerDefn()
        for field_index in range(layer_defn.GetFieldCount()):
            field_defn = layer_defn.GetFieldDefn(field_index)
            layer_info["fields"].append(
                {
                    "name": field_defn.GetName(),
                    "type": field_defn.GetTypeName(),
                    "width": field_defn.GetWidth(),
                    "precision": field_defn.GetPrecision(),
                }
            )

        info["layers"].append(layer_info)

    # Dereference to close the OGR datasource
    datasource = None

    return info
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
def analyze_vector_attributes(vector_path, attribute_name):
    """Analyze a specific attribute in a vector dataset and create a histogram.

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to analyze

    Returns:
        dict: Dictionary containing analysis results for the attribute,
            or None if the attribute is missing or an error occurs
    """
    try:
        # Consistency fix: support GeoParquet files the same way
        # get_vector_info does, instead of failing on read_file.
        gdf = (
            gpd.read_parquet(vector_path)
            if vector_path.endswith(".parquet")
            else gpd.read_file(vector_path)
        )

        # Check if attribute exists
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return None

        # Get the attribute series
        attr = gdf[attribute_name]

        # Perform different analyses based on data type
        if pd.api.types.is_numeric_dtype(attr):
            # Numeric attribute: summary statistics plus a histogram
            analysis = {
                "attribute": attribute_name,
                "type": "numeric",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "min": attr.min(),
                "max": attr.max(),
                "mean": attr.mean(),
                "median": attr.median(),
                "std": attr.std(),
                "unique_values": attr.nunique(),
            }

            # Create histogram
            plt.figure(figsize=(10, 6))
            plt.hist(attr.dropna(), bins=20, alpha=0.7, color="blue")
            plt.title(f"Histogram of {attribute_name}")
            plt.xlabel(attribute_name)
            plt.ylabel("Frequency")
            plt.grid(True, alpha=0.3)
            plt.show()

        else:
            # Categorical attribute: value counts plus a bar chart
            analysis = {
                "attribute": attribute_name,
                "type": "categorical",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "unique_values": attr.nunique(),
                "value_counts": attr.value_counts().to_dict(),
            }

            # Create bar plot for top categories
            top_n = min(10, attr.nunique())
            plt.figure(figsize=(10, 6))
            attr.value_counts().head(top_n).plot(kind="bar", color="skyblue")
            plt.title(f"Top {top_n} values for {attribute_name}")
            plt.xlabel(attribute_name)
            plt.ylabel("Count")
            plt.xticks(rotation=45)
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()

        return analysis

    except Exception as e:
        print(f"Error analyzing attribute: {str(e)}")
        return None
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
def visualize_vector_by_attribute(
    vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
):
    """Create a thematic map visualization of vector data based on an attribute.

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to visualize
        cmap (str, optional): Matplotlib colormap name. Defaults to 'viridis'.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        bool: True if visualization was successful, False otherwise
    """
    try:
        # Consistency fix: support GeoParquet files the same way
        # get_vector_info does.
        gdf = (
            gpd.read_parquet(vector_path)
            if vector_path.endswith(".parquet")
            else gpd.read_file(vector_path)
        )

        # Check if attribute exists
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return False

        # Create the plot
        fig, ax = plt.subplots(figsize=figsize)

        # Continuous data gets a colormap ramp; categorical data gets
        # discrete colors with a legend
        if pd.api.types.is_numeric_dtype(gdf[attribute_name]):
            gdf.plot(column=attribute_name, cmap=cmap, legend=True, ax=ax)
        else:
            gdf.plot(column=attribute_name, categorical=True, legend=True, ax=ax)

        # Add title and labels
        ax.set_title(f"{os.path.basename(vector_path)} - {attribute_name}")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        plt.tight_layout()
        plt.show()

        # Fix: the docstring promises a boolean, but the original returned
        # None on the success path.
        return True

    except Exception as e:
        print(f"Error visualizing data: {str(e)}")
        return False
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
def clip_raster_by_bbox(
    input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
):
    """
    Clip a raster dataset using a bounding box and optionally select specific bands.

    Args:
        input_raster (str): Path to the input raster file.
        output_raster (str): Path where the clipped raster will be saved.
        bbox (tuple): Bounding box coordinates either as:
            - Geographic coordinates (minx, miny, maxx, maxy) if bbox_type="geo"
            - Pixel indices (min_row, min_col, max_row, max_col) if bbox_type="pixel"
        bands (list, optional): List of band indices to keep (1-based indexing).
            If None, all bands will be kept.
        bbox_type (str, optional): Type of bounding box coordinates. Either "geo" for
            geographic coordinates or "pixel" for row/column indices. Default is "geo".
        bbox_crs (str or dict, optional): CRS of the bbox if different from the raster CRS.
            Can be provided as EPSG code (e.g., "EPSG:4326") or as a proj4 string.
            Only applies when bbox_type="geo". If None, assumes bbox is in the
            same CRS as the raster.

    Returns:
        str: Path to the clipped output raster.

    Raises:
        ValueError: If the bbox is invalid, bands are out of range, or bbox_type is invalid.

    Examples:
        # Clip using geographic coordinates in the same CRS as the raster
        >>> clip_raster_by_bbox('input.tif', 'clipped_geo.tif', (100, 200, 300, 400))
        'clipped_geo.tif'

        # Clip using WGS84 coordinates when the raster is in a different CRS
        >>> clip_raster_by_bbox('input.tif', 'clipped_wgs84.tif', (-122.5, 37.7, -122.4, 37.8),
        ...                     bbox_crs="EPSG:4326")
        'clipped_wgs84.tif'

        # Clip using row/column indices
        >>> clip_raster_by_bbox('input.tif', 'clipped_pixel.tif', (50, 100, 150, 200),
        ...                     bbox_type="pixel")
        'clipped_pixel.tif'

        # Clip with band selection
        >>> clip_raster_by_bbox('input.tif', 'clipped_bands.tif', (100, 200, 300, 400),
        ...                     bands=[1, 3])
        'clipped_bands.tif'
    """
    from rasterio.transform import from_bounds
    from rasterio.warp import transform_bounds

    # Validate bbox_type
    if bbox_type not in ["geo", "pixel"]:
        raise ValueError("bbox_type must be either 'geo' or 'pixel'")

    # Validate bbox
    if len(bbox) != 4:
        raise ValueError("bbox must contain exactly 4 values")

    # Open the source raster
    with rasterio.open(input_raster) as src:
        src_crs = src.crs

        if bbox_type == "geo":
            minx, miny, maxx, maxy = bbox

            # Validate geographic bbox ordering
            if minx >= maxx or miny >= maxy:
                raise ValueError(
                    "Invalid geographic bbox. Expected (minx, miny, maxx, maxy) where minx < maxx and miny < maxy"
                )

            # Reproject the bbox into the raster CRS when they differ
            if bbox_crs is not None and bbox_crs != src_crs:
                try:
                    minx, miny, maxx, maxy = transform_bounds(
                        bbox_crs, src_crs, minx, miny, maxx, maxy
                    )
                except Exception as e:
                    raise ValueError(
                        f"Failed to transform bbox from {bbox_crs} to {src_crs}: {str(e)}"
                    )

            # Pixel window corresponding to the geographic bounds
            window = src.window(minx, miny, maxx, maxy)

            # Use the same bounds for the output transform
            output_bounds = (minx, miny, maxx, maxy)

        else:  # bbox_type == "pixel"
            min_row, min_col, max_row, max_col = bbox

            # Validate pixel bbox ordering
            if min_row >= max_row or min_col >= max_col:
                raise ValueError(
                    "Invalid pixel bbox. Expected (min_row, min_col, max_row, max_col) where min_row < max_row and min_col < max_col"
                )

            if (
                min_row < 0
                or min_col < 0
                or max_row > src.height
                or max_col > src.width
            ):
                raise ValueError(
                    f"Pixel indices out of bounds. Raster dimensions are {src.height} rows x {src.width} columns"
                )

            # Create a window from pixel coordinates
            window = Window(min_col, min_row, max_col - min_col, max_row - min_row)

            # array_bounds already returns (minx, miny, maxx, maxy); the
            # previous version "reordered" this tuple into an identical
            # tuple, which was dead code and has been removed.
            window_transform = src.window_transform(window)
            output_bounds = rasterio.transform.array_bounds(
                window.height, window.width, window_transform
            )

        # Get window dimensions
        window_width = int(window.width)
        window_height = int(window.height)

        # Check if the window is valid
        if window_width <= 0 or window_height <= 0:
            raise ValueError("Bounding box results in an empty window")

        # Handle band selection (1-based indices, rasterio convention)
        if bands is None:
            bands_to_read = list(range(1, src.count + 1))
        else:
            if not all(1 <= b <= src.count for b in bands):
                raise ValueError(f"Band indices must be between 1 and {src.count}")
            bands_to_read = bands

        # Calculate new transform for the clipped raster
        new_transform = from_bounds(
            output_bounds[0],
            output_bounds[1],
            output_bounds[2],
            output_bounds[3],
            window_width,
            window_height,
        )

        # Create a metadata dictionary for the output
        out_meta = src.meta.copy()
        out_meta.update(
            {
                "height": window_height,
                "width": window_width,
                "transform": new_transform,
                "count": len(bands_to_read),
            }
        )

        # np.stack handles the single-band case too (yielding a (1, h, w)
        # array), so the previous single-band special case was redundant.
        clipped_data = np.stack(
            [src.read(band_idx, window=window) for band_idx in bands_to_read]
        )

    # Write the output raster
    with rasterio.open(output_raster, "w", **out_meta) as dst:
        dst.write(clipped_data)

    return output_raster
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
def raster_to_vector(
    raster_path,
    output_path=None,
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    plot_result=False,
):
    """
    Convert a raster label mask to vector polygons.

    Args:
        raster_path (str): Path to the input raster file (e.g., GeoTIFF).
        output_path (str): Path to save the output vector file. If None, returns GeoDataFrame without saving.
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
        plot_result (bool): Whether to plot the resulting polygons overlaid on the raster.

    Returns:
        geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized polygons.
    """
    with rasterio.open(raster_path) as src:
        band = src.read(1)
        affine = src.transform
        raster_crs = src.crs

        # Build one boolean mask per class value, or a single thresholded mask
        # when no explicit class values were requested.
        if class_values is not None:
            class_masks = {val: band == val for val in class_values}
        else:
            class_masks = {1: band > threshold}
            class_values = [1]  # single default class

        records = []
        for cls in class_values:
            cls_mask = class_masks[cls]
            # Trace polygon boundaries around every connected region of the mask.
            for geom_dict, _ in features.shapes(
                cls_mask.astype(np.uint8), mask=cls_mask, transform=affine
            ):
                polygon = shape(geom_dict)

                # Drop slivers smaller than the area cutoff (map units).
                if polygon.area < min_area:
                    continue

                if simplify_tolerance is not None:
                    polygon = polygon.simplify(simplify_tolerance)

                records.append({"geometry": polygon, attribute_name: cls})

        if records:
            gdf = gpd.GeoDataFrame(records, crs=raster_crs)
        else:
            print("Warning: No features were extracted from the raster.")
            # Empty result still carries the raster CRS so callers can reproject.
            gdf = gpd.GeoDataFrame([], geometry=[], crs=raster_crs)

        if output_path is not None:
            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

            fmt = output_format.lower()
            if fmt == "geojson":
                gdf.to_file(output_path, driver="GeoJSON")
            elif fmt == "shapefile":
                gdf.to_file(output_path)
            elif fmt == "gpkg":
                gdf.to_file(output_path, driver="GPKG")
            else:
                raise ValueError(f"Unsupported output format: {output_format}")

            print(f"Vectorized data saved to {output_path}")

        if plot_result:
            fig, ax = plt.subplots(figsize=(12, 12))

            img = src.read()
            if img.shape[0] == 1:
                plt.imshow(img[0], cmap="viridis", alpha=0.7)
            else:
                # Show the first three bands as RGB, normalized for display.
                rgb = img[:3].transpose(1, 2, 0)
                rgb = np.clip(rgb / rgb.max(), 0, 1)
                plt.imshow(rgb)

            if not gdf.empty:
                gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)

            plt.title("Raster with Vectorized Boundaries")
            plt.axis("off")
            plt.tight_layout()
            plt.show()

        return gdf
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
def batch_raster_to_vector(
    input_dir,
    output_dir,
    pattern="*.tif",
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    merge_output=False,
    merge_filename="merged_vectors",
):
    """
    Batch convert multiple raster files to vector polygons.

    Args:
        input_dir (str): Directory containing input raster files.
        output_dir (str): Directory to save output vector files.
        pattern (str): Pattern to match raster files (e.g., '*.tif').
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
        merge_output (bool): Whether to merge all output vectors into a single file.
        merge_filename (str): Filename for the merged output (without extension).

    Returns:
        geopandas.GeoDataFrame or None: If merge_output is True, returns the merged GeoDataFrame.

    Raises:
        ValueError: If output_format is not one of 'geojson', 'shapefile', 'gpkg'.
    """
    import glob

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Get list of raster files
    raster_files = glob.glob(os.path.join(input_dir, pattern))

    if not raster_files:
        print(f"No files matching pattern '{pattern}' found in {input_dir}")
        return None

    # Validate the output format once, up front, instead of failing mid-batch
    # after some files have already been processed. The same mapping also
    # covers the merged output, which previously fell through silently for an
    # unsupported format.
    extensions = {"geojson": ".geojson", "shapefile": ".shp", "gpkg": ".gpkg"}
    fmt = output_format.lower()
    if fmt not in extensions:
        raise ValueError(f"Unsupported output format: {output_format}")
    ext = extensions[fmt]

    print(f"Found {len(raster_files)} raster files to process")

    # Process each raster file
    gdfs = []
    for raster_file in tqdm(raster_files, desc="Processing rasters"):
        base_name = os.path.splitext(os.path.basename(raster_file))[0]

        if merge_output:
            # Don't save individual files if merging; collect GeoDataFrames.
            gdf = raster_to_vector(
                raster_file,
                output_path=None,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
            )

            # Tag each feature with its source raster for traceability.
            if not gdf.empty:
                gdf["source_file"] = base_name
                gdfs.append(gdf)
        else:
            # Save individual files (filename only needed in this branch).
            out_file = os.path.join(output_dir, f"{base_name}{ext}")
            raster_to_vector(
                raster_file,
                output_path=out_file,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
                output_format=output_format,
            )

    # Merge output if requested
    if merge_output and gdfs:
        merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))

        # concat can drop the CRS; restore it from the first GeoDataFrame.
        if merged_gdf.crs is None:
            merged_gdf.crs = gdfs[0].crs

        merged_file = os.path.join(output_dir, f"{merge_filename}{ext}")
        if fmt == "geojson":
            merged_gdf.to_file(merged_file, driver="GeoJSON")
        elif fmt == "shapefile":
            merged_gdf.to_file(merged_file)
        else:  # gpkg — fmt already validated above
            merged_gdf.to_file(merged_file, driver="GPKG")

        print(f"Merged vector data saved to {merged_file}")
        return merged_gdf

    return None
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
def vector_to_raster(
    vector_path,
    output_path=None,
    reference_raster=None,
    attribute_field=None,
    output_shape=None,
    transform=None,
    pixel_size=None,
    bounds=None,
    crs=None,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
    plot_result=False,
):
    """
    Convert vector data to a raster.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_path (str): Path to save the output raster file. If None, returns the array without saving.
        reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
        attribute_field (str): Field name in the vector data to use for pixel values.
            If None, all vector features will be burned with value 1.
        output_shape (tuple): Shape of the output raster as (height, width).
            Required if reference_raster is not provided.
        transform (affine.Affine): Affine transformation matrix.
            Required if reference_raster is not provided.
        pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
            Used to calculate transform if transform is not provided.
        bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
            Used to calculate transform if transform is not provided.
        crs (str or CRS): Coordinate reference system of the output raster.
            Required if reference_raster is not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
            If False, only pixels whose center is within the geometry will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.
        plot_result (bool): Whether to plot the resulting raster.

    Returns:
        numpy.ndarray: The rasterized data array if output_path is None, else None.
    """
    # Accept either an in-memory GeoDataFrame or a path on disk.
    gdf = vector_path if isinstance(vector_path, gpd.GeoDataFrame) else gpd.read_file(
        vector_path
    )

    if gdf.empty:
        warnings.warn("The input vector data is empty. Creating an empty raster.")

    # Fall back to the vector's own CRS when neither an explicit CRS nor a
    # reference raster was supplied.
    if crs is None and reference_raster is None:
        crs = gdf.crs

    if reference_raster is not None:
        # Grid geometry comes entirely from the reference raster.
        with rasterio.open(reference_raster) as src:
            transform = src.transform
            output_shape = src.shape
            crs = src.crs
            if nodata is None:
                nodata = src.nodata
    else:
        if transform is None:
            if pixel_size is None or bounds is None:
                raise ValueError(
                    "Either reference_raster, transform, or both pixel_size and bounds must be provided."
                )

            # Derive the affine transform from the requested resolution and extent.
            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size
                y_res = abs(y_res) * -1  # Convert to negative for north-up raster

            left, bottom, right, top = bounds
            transform = rasterio.transform.from_bounds(
                left,
                bottom,
                right,
                top,
                int((right - left) / x_res),
                int((top - bottom) / abs(y_res)),
            )

        if output_shape is None:
            if bounds is None or pixel_size is None:
                raise ValueError(
                    "output_shape must be provided if reference_raster is not provided and "
                    "cannot be calculated from bounds and pixel_size."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size

            left, bottom, right, top = bounds
            n_cols = int((right - left) / x_res)
            n_rows = int((top - bottom) / abs(y_res))
            output_shape = (n_rows, n_cols)

    if crs is None:
        raise ValueError(
            "CRS must be provided either directly, from reference_raster, or from input vector data."
        )

    # Make the vector layer match the output grid's CRS before burning.
    if gdf.crs != crs:
        print(f"Reprojecting vector data from {gdf.crs} to {crs}")
        gdf = gdf.to_crs(crs)

    # Start from a canvas filled with the background value.
    raster_data = np.full(output_shape, fill_value, dtype=dtype)

    if not gdf.empty:
        # Pair every geometry with its burn value: either the attribute field
        # or a constant 1.
        if attribute_field is not None and attribute_field in gdf.columns:
            shapes = list(zip(gdf.geometry, gdf[attribute_field]))
        else:
            shapes = [(geom, 1) for geom in gdf.geometry]

        raster_data = features.rasterize(
            shapes=shapes,
            out_shape=output_shape,
            transform=transform,
            fill=fill_value,
            all_touched=all_touched,
            dtype=dtype,
        )

    if output_path is not None:
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        metadata = {
            "driver": "GTiff",
            "height": output_shape[0],
            "width": output_shape[1],
            "count": 1,
            "dtype": raster_data.dtype,
            "crs": crs,
            "transform": transform,
        }
        if nodata is not None:
            metadata["nodata"] = nodata

        with rasterio.open(output_path, "w", **metadata) as dst:
            dst.write(raster_data, 1)

        print(f"Rasterized data saved to {output_path}")

    if plot_result:
        fig, ax = plt.subplots(figsize=(10, 10))

        im = ax.imshow(raster_data, cmap="viridis")
        plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")

        # Overlay the vector boundaries, clipped to the raster footprint.
        if output_path is not None:
            with rasterio.open(output_path) as src:
                bounds = src.bounds
                raster_bbox = box(*bounds)
        else:
            # Derive the footprint from the transform and shape.
            height, width = output_shape
            left, top = transform * (0, 0)
            right, bottom = transform * (width, height)
            raster_bbox = box(left, bottom, right, top)

        if not gdf.empty:
            gdf_clipped = gpd.clip(gdf, raster_bbox)
            if not gdf_clipped.empty:
                gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)

        plt.title("Rasterized Vector Data")
        plt.tight_layout()
        plt.show()

    return raster_data
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
def batch_vector_to_raster(
    vector_path,
    output_dir,
    attribute_field=None,
    reference_rasters=None,
    bounds_list=None,
    output_filename_pattern="{vector_name}_{index}",
    pixel_size=1.0,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
):
    """
    Batch convert vector data to multiple rasters based on different extents or reference rasters.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_dir (str): Directory to save output raster files.
        attribute_field (str): Field name in the vector data to use for pixel values.
        reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
        bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
        output_filename_pattern (str): Pattern for output filenames.
            Can include {vector_name} and {index} placeholders.
        pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.

    Returns:
        list: List of paths to the created raster files.

    Raises:
        ValueError: If neither reference_rasters nor bounds_list is provided.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Resolve the vector source and a display name for filename templating.
    if isinstance(vector_path, str):
        gdf = gpd.read_file(vector_path)
        vector_name = os.path.splitext(os.path.basename(vector_path))[0]
    else:
        gdf = vector_path
        vector_name = "vector"

    if reference_rasters is None and bounds_list is None:
        raise ValueError("Either reference_rasters or bounds_list must be provided.")

    # Reference rasters take precedence over explicit bounds.
    use_rasters = reference_rasters is not None
    sources = reference_rasters if use_rasters else bounds_list

    output_files = []

    # Arguments common to both call variants of vector_to_raster.
    shared_kwargs = dict(
        vector_path=gdf,
        attribute_field=attribute_field,
        all_touched=all_touched,
        fill_value=fill_value,
        dtype=dtype,
        nodata=nodata,
    )

    for i, source in enumerate(tqdm(sources, desc="Processing")):
        # Build the output filename from the template, forcing a .tif suffix.
        output_filename = output_filename_pattern.format(
            vector_name=vector_name, index=i
        )
        if not output_filename.endswith(".tif"):
            output_filename += ".tif"
        output_path = os.path.join(output_dir, output_filename)

        if use_rasters:
            vector_to_raster(
                output_path=output_path,
                reference_raster=source,
                **shared_kwargs,
            )
        else:
            vector_to_raster(
                output_path=output_path,
                bounds=source,
                pixel_size=pixel_size,
                **shared_kwargs,
            )

        output_files.append(output_path)

    return output_files
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
def export_geotiff_tiles(
|
|
1310
|
-
in_raster,
|
|
1311
|
-
out_folder,
|
|
1312
|
-
in_class_data,
|
|
1313
|
-
tile_size=256,
|
|
1314
|
-
stride=128,
|
|
1315
|
-
class_value_field="class",
|
|
1316
|
-
buffer_radius=0,
|
|
1317
|
-
max_tiles=None,
|
|
1318
|
-
quiet=False,
|
|
1319
|
-
all_touched=True,
|
|
1320
|
-
create_overview=False,
|
|
1321
|
-
skip_empty_tiles=False,
|
|
1322
|
-
):
|
|
1323
|
-
"""
|
|
1324
|
-
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
|
1325
|
-
|
|
1326
|
-
Args:
|
|
1327
|
-
in_raster (str): Path to input raster image
|
|
1328
|
-
out_folder (str): Path to output folder
|
|
1329
|
-
in_class_data (str): Path to classification data - can be vector file or raster
|
|
1330
|
-
tile_size (int): Size of tiles in pixels (square)
|
|
1331
|
-
stride (int): Step size between tiles
|
|
1332
|
-
class_value_field (str): Field containing class values (for vector data)
|
|
1333
|
-
buffer_radius (float): Buffer to add around features (in units of the CRS)
|
|
1334
|
-
max_tiles (int): Maximum number of tiles to process (None for all)
|
|
1335
|
-
quiet (bool): If True, suppress non-essential output
|
|
1336
|
-
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
1337
|
-
create_overview (bool): Whether to create an overview image of all tiles
|
|
1338
|
-
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
1339
|
-
"""
|
|
1340
|
-
# Create output directories
|
|
1341
|
-
os.makedirs(out_folder, exist_ok=True)
|
|
1342
|
-
image_dir = os.path.join(out_folder, "images")
|
|
1343
|
-
os.makedirs(image_dir, exist_ok=True)
|
|
1344
|
-
label_dir = os.path.join(out_folder, "labels")
|
|
1345
|
-
os.makedirs(label_dir, exist_ok=True)
|
|
1346
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
1347
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
1348
|
-
|
|
1349
|
-
# Determine if class data is raster or vector
|
|
1350
|
-
is_class_data_raster = False
|
|
1351
|
-
if isinstance(in_class_data, str):
|
|
1352
|
-
file_ext = Path(in_class_data).suffix.lower()
|
|
1353
|
-
# Common raster extensions
|
|
1354
|
-
if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
|
|
1355
|
-
try:
|
|
1356
|
-
with rasterio.open(in_class_data) as src:
|
|
1357
|
-
is_class_data_raster = True
|
|
1358
|
-
if not quiet:
|
|
1359
|
-
print(f"Detected in_class_data as raster: {in_class_data}")
|
|
1360
|
-
print(f"Raster CRS: {src.crs}")
|
|
1361
|
-
print(f"Raster dimensions: {src.width} x {src.height}")
|
|
1362
|
-
except Exception:
|
|
1363
|
-
is_class_data_raster = False
|
|
1364
|
-
if not quiet:
|
|
1365
|
-
print(f"Unable to open {in_class_data} as raster, trying as vector")
|
|
1366
|
-
|
|
1367
|
-
# Open the input raster
|
|
1368
|
-
with rasterio.open(in_raster) as src:
|
|
1369
|
-
if not quiet:
|
|
1370
|
-
print(f"\nRaster info for {in_raster}:")
|
|
1371
|
-
print(f" CRS: {src.crs}")
|
|
1372
|
-
print(f" Dimensions: {src.width} x {src.height}")
|
|
1373
|
-
print(f" Bounds: {src.bounds}")
|
|
1374
|
-
|
|
1375
|
-
# Calculate number of tiles
|
|
1376
|
-
num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
|
|
1377
|
-
num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
|
|
1378
|
-
total_tiles = num_tiles_x * num_tiles_y
|
|
1379
|
-
|
|
1380
|
-
if max_tiles is None:
|
|
1381
|
-
max_tiles = total_tiles
|
|
1382
|
-
|
|
1383
|
-
# Process classification data
|
|
1384
|
-
class_to_id = {}
|
|
1385
|
-
|
|
1386
|
-
if is_class_data_raster:
|
|
1387
|
-
# Load raster class data
|
|
1388
|
-
with rasterio.open(in_class_data) as class_src:
|
|
1389
|
-
# Check if raster CRS matches
|
|
1390
|
-
if class_src.crs != src.crs:
|
|
1391
|
-
warnings.warn(
|
|
1392
|
-
f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
|
|
1393
|
-
f"Results may be misaligned."
|
|
1394
|
-
)
|
|
1395
|
-
|
|
1396
|
-
# Get unique values from raster
|
|
1397
|
-
# Sample to avoid loading huge rasters
|
|
1398
|
-
sample_data = class_src.read(
|
|
1399
|
-
1,
|
|
1400
|
-
out_shape=(
|
|
1401
|
-
1,
|
|
1402
|
-
min(class_src.height, 1000),
|
|
1403
|
-
min(class_src.width, 1000),
|
|
1404
|
-
),
|
|
1405
|
-
)
|
|
1406
|
-
|
|
1407
|
-
unique_classes = np.unique(sample_data)
|
|
1408
|
-
unique_classes = unique_classes[
|
|
1409
|
-
unique_classes > 0
|
|
1410
|
-
] # Remove 0 as it's typically background
|
|
1411
|
-
|
|
1412
|
-
if not quiet:
|
|
1413
|
-
print(
|
|
1414
|
-
f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
|
|
1415
|
-
)
|
|
1416
|
-
|
|
1417
|
-
# Create class mapping
|
|
1418
|
-
class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
|
|
1419
|
-
else:
|
|
1420
|
-
# Load vector class data
|
|
1421
|
-
try:
|
|
1422
|
-
gdf = gpd.read_file(in_class_data)
|
|
1423
|
-
if not quiet:
|
|
1424
|
-
print(f"Loaded {len(gdf)} features from {in_class_data}")
|
|
1425
|
-
print(f"Vector CRS: {gdf.crs}")
|
|
1426
|
-
|
|
1427
|
-
# Always reproject to match raster CRS
|
|
1428
|
-
if gdf.crs != src.crs:
|
|
1429
|
-
if not quiet:
|
|
1430
|
-
print(f"Reprojecting features from {gdf.crs} to {src.crs}")
|
|
1431
|
-
gdf = gdf.to_crs(src.crs)
|
|
1432
|
-
|
|
1433
|
-
# Apply buffer if specified
|
|
1434
|
-
if buffer_radius > 0:
|
|
1435
|
-
gdf["geometry"] = gdf.buffer(buffer_radius)
|
|
1436
|
-
if not quiet:
|
|
1437
|
-
print(f"Applied buffer of {buffer_radius} units")
|
|
1438
|
-
|
|
1439
|
-
# Check if class_value_field exists
|
|
1440
|
-
if class_value_field in gdf.columns:
|
|
1441
|
-
unique_classes = gdf[class_value_field].unique()
|
|
1442
|
-
if not quiet:
|
|
1443
|
-
print(
|
|
1444
|
-
f"Found {len(unique_classes)} unique classes: {unique_classes}"
|
|
1445
|
-
)
|
|
1446
|
-
# Create class mapping
|
|
1447
|
-
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
1448
|
-
else:
|
|
1449
|
-
if not quiet:
|
|
1450
|
-
print(
|
|
1451
|
-
f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
|
|
1452
|
-
)
|
|
1453
|
-
class_to_id = {1: 1} # Default mapping
|
|
1454
|
-
except Exception as e:
|
|
1455
|
-
raise ValueError(f"Error processing vector data: {e}")
|
|
1456
|
-
|
|
1457
|
-
# Create progress bar
|
|
1458
|
-
pbar = tqdm(
|
|
1459
|
-
total=min(total_tiles, max_tiles),
|
|
1460
|
-
desc="Generating tiles",
|
|
1461
|
-
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
|
|
1462
|
-
)
|
|
1463
|
-
|
|
1464
|
-
# Track statistics for summary
|
|
1465
|
-
stats = {
|
|
1466
|
-
"total_tiles": 0,
|
|
1467
|
-
"tiles_with_features": 0,
|
|
1468
|
-
"feature_pixels": 0,
|
|
1469
|
-
"errors": 0,
|
|
1470
|
-
"tile_coordinates": [], # For overview image
|
|
1471
|
-
}
|
|
1472
|
-
|
|
1473
|
-
# Process tiles
|
|
1474
|
-
tile_index = 0
|
|
1475
|
-
for y in range(num_tiles_y):
|
|
1476
|
-
for x in range(num_tiles_x):
|
|
1477
|
-
if tile_index >= max_tiles:
|
|
1478
|
-
break
|
|
1479
|
-
|
|
1480
|
-
# Calculate window coordinates
|
|
1481
|
-
window_x = x * stride
|
|
1482
|
-
window_y = y * stride
|
|
1483
|
-
|
|
1484
|
-
# Adjust for edge cases
|
|
1485
|
-
if window_x + tile_size > src.width:
|
|
1486
|
-
window_x = src.width - tile_size
|
|
1487
|
-
if window_y + tile_size > src.height:
|
|
1488
|
-
window_y = src.height - tile_size
|
|
1489
|
-
|
|
1490
|
-
# Define window
|
|
1491
|
-
window = Window(window_x, window_y, tile_size, tile_size)
|
|
1492
|
-
|
|
1493
|
-
# Get window transform and bounds
|
|
1494
|
-
window_transform = src.window_transform(window)
|
|
1495
|
-
|
|
1496
|
-
# Calculate window bounds
|
|
1497
|
-
minx = window_transform[2] # Upper left x
|
|
1498
|
-
maxy = window_transform[5] # Upper left y
|
|
1499
|
-
maxx = minx + tile_size * window_transform[0] # Add width
|
|
1500
|
-
miny = maxy + tile_size * window_transform[4] # Add height
|
|
1501
|
-
|
|
1502
|
-
window_bounds = box(minx, miny, maxx, maxy)
|
|
1503
|
-
|
|
1504
|
-
# Store tile coordinates for overview
|
|
1505
|
-
if create_overview:
|
|
1506
|
-
stats["tile_coordinates"].append(
|
|
1507
|
-
{
|
|
1508
|
-
"index": tile_index,
|
|
1509
|
-
"x": window_x,
|
|
1510
|
-
"y": window_y,
|
|
1511
|
-
"bounds": [minx, miny, maxx, maxy],
|
|
1512
|
-
"has_features": False,
|
|
1513
|
-
}
|
|
1514
|
-
)
|
|
1515
|
-
|
|
1516
|
-
# Create label mask
|
|
1517
|
-
label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
|
|
1518
|
-
has_features = False
|
|
1519
|
-
|
|
1520
|
-
# Process classification data to create labels
|
|
1521
|
-
if is_class_data_raster:
|
|
1522
|
-
# For raster class data
|
|
1523
|
-
with rasterio.open(in_class_data) as class_src:
|
|
1524
|
-
# Calculate window in class raster
|
|
1525
|
-
src_bounds = src.bounds
|
|
1526
|
-
class_bounds = class_src.bounds
|
|
1527
|
-
|
|
1528
|
-
# Check if windows overlap
|
|
1529
|
-
if (
|
|
1530
|
-
src_bounds.left > class_bounds.right
|
|
1531
|
-
or src_bounds.right < class_bounds.left
|
|
1532
|
-
or src_bounds.bottom > class_bounds.top
|
|
1533
|
-
or src_bounds.top < class_bounds.bottom
|
|
1534
|
-
):
|
|
1535
|
-
warnings.warn(
|
|
1536
|
-
"Class raster and input raster do not overlap."
|
|
1537
|
-
)
|
|
1538
|
-
else:
|
|
1539
|
-
# Get corresponding window in class raster
|
|
1540
|
-
window_class = rasterio.windows.from_bounds(
|
|
1541
|
-
minx, miny, maxx, maxy, class_src.transform
|
|
1542
|
-
)
|
|
1543
|
-
|
|
1544
|
-
# Read label data
|
|
1545
|
-
try:
|
|
1546
|
-
label_data = class_src.read(
|
|
1547
|
-
1,
|
|
1548
|
-
window=window_class,
|
|
1549
|
-
boundless=True,
|
|
1550
|
-
out_shape=(tile_size, tile_size),
|
|
1551
|
-
)
|
|
1552
|
-
|
|
1553
|
-
# Remap class values if needed
|
|
1554
|
-
if class_to_id:
|
|
1555
|
-
remapped_data = np.zeros_like(label_data)
|
|
1556
|
-
for orig_val, new_val in class_to_id.items():
|
|
1557
|
-
remapped_data[label_data == orig_val] = new_val
|
|
1558
|
-
label_mask = remapped_data
|
|
1559
|
-
else:
|
|
1560
|
-
label_mask = label_data
|
|
1561
|
-
|
|
1562
|
-
# Check if we have any features
|
|
1563
|
-
if np.any(label_mask > 0):
|
|
1564
|
-
has_features = True
|
|
1565
|
-
stats["feature_pixels"] += np.count_nonzero(
|
|
1566
|
-
label_mask
|
|
1567
|
-
)
|
|
1568
|
-
except Exception as e:
|
|
1569
|
-
pbar.write(f"Error reading class raster window: {e}")
|
|
1570
|
-
stats["errors"] += 1
|
|
1571
|
-
else:
|
|
1572
|
-
# For vector class data
|
|
1573
|
-
# Find features that intersect with window
|
|
1574
|
-
window_features = gdf[gdf.intersects(window_bounds)]
|
|
1575
|
-
|
|
1576
|
-
if len(window_features) > 0:
|
|
1577
|
-
for idx, feature in window_features.iterrows():
|
|
1578
|
-
# Get class value
|
|
1579
|
-
if class_value_field in feature:
|
|
1580
|
-
class_val = feature[class_value_field]
|
|
1581
|
-
class_id = class_to_id.get(class_val, 1)
|
|
1582
|
-
else:
|
|
1583
|
-
class_id = 1
|
|
1584
|
-
|
|
1585
|
-
# Get geometry in window coordinates
|
|
1586
|
-
geom = feature.geometry.intersection(window_bounds)
|
|
1587
|
-
if not geom.is_empty:
|
|
1588
|
-
try:
|
|
1589
|
-
# Rasterize feature
|
|
1590
|
-
feature_mask = features.rasterize(
|
|
1591
|
-
[(geom, class_id)],
|
|
1592
|
-
out_shape=(tile_size, tile_size),
|
|
1593
|
-
transform=window_transform,
|
|
1594
|
-
fill=0,
|
|
1595
|
-
all_touched=all_touched,
|
|
1596
|
-
)
|
|
1597
|
-
|
|
1598
|
-
# Add to label mask
|
|
1599
|
-
label_mask = np.maximum(label_mask, feature_mask)
|
|
1600
|
-
|
|
1601
|
-
# Check if the feature was actually rasterized
|
|
1602
|
-
if np.any(feature_mask):
|
|
1603
|
-
has_features = True
|
|
1604
|
-
if create_overview and tile_index < len(
|
|
1605
|
-
stats["tile_coordinates"]
|
|
1606
|
-
):
|
|
1607
|
-
stats["tile_coordinates"][tile_index][
|
|
1608
|
-
"has_features"
|
|
1609
|
-
] = True
|
|
1610
|
-
except Exception as e:
|
|
1611
|
-
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
1612
|
-
stats["errors"] += 1
|
|
1613
|
-
|
|
1614
|
-
# Skip tile if no features and skip_empty_tiles is True
|
|
1615
|
-
if skip_empty_tiles and not has_features:
|
|
1616
|
-
pbar.update(1)
|
|
1617
|
-
tile_index += 1
|
|
1618
|
-
continue
|
|
1619
|
-
|
|
1620
|
-
# Read image data
|
|
1621
|
-
image_data = src.read(window=window)
|
|
1622
|
-
|
|
1623
|
-
# Export image as GeoTIFF
|
|
1624
|
-
image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
|
|
1625
|
-
|
|
1626
|
-
# Create profile for image GeoTIFF
|
|
1627
|
-
image_profile = src.profile.copy()
|
|
1628
|
-
image_profile.update(
|
|
1629
|
-
{
|
|
1630
|
-
"height": tile_size,
|
|
1631
|
-
"width": tile_size,
|
|
1632
|
-
"count": image_data.shape[0],
|
|
1633
|
-
"transform": window_transform,
|
|
1634
|
-
}
|
|
1635
|
-
)
|
|
1636
|
-
|
|
1637
|
-
# Save image as GeoTIFF
|
|
1638
|
-
try:
|
|
1639
|
-
with rasterio.open(image_path, "w", **image_profile) as dst:
|
|
1640
|
-
dst.write(image_data)
|
|
1641
|
-
stats["total_tiles"] += 1
|
|
1642
|
-
except Exception as e:
|
|
1643
|
-
pbar.write(f"ERROR saving image GeoTIFF: {e}")
|
|
1644
|
-
stats["errors"] += 1
|
|
1645
|
-
|
|
1646
|
-
# Create profile for label GeoTIFF
|
|
1647
|
-
label_profile = {
|
|
1648
|
-
"driver": "GTiff",
|
|
1649
|
-
"height": tile_size,
|
|
1650
|
-
"width": tile_size,
|
|
1651
|
-
"count": 1,
|
|
1652
|
-
"dtype": "uint8",
|
|
1653
|
-
"crs": src.crs,
|
|
1654
|
-
"transform": window_transform,
|
|
1655
|
-
}
|
|
1656
|
-
|
|
1657
|
-
# Export label as GeoTIFF
|
|
1658
|
-
label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
|
|
1659
|
-
try:
|
|
1660
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
1661
|
-
dst.write(label_mask.astype(np.uint8), 1)
|
|
1662
|
-
|
|
1663
|
-
if has_features:
|
|
1664
|
-
stats["tiles_with_features"] += 1
|
|
1665
|
-
stats["feature_pixels"] += np.count_nonzero(label_mask)
|
|
1666
|
-
except Exception as e:
|
|
1667
|
-
pbar.write(f"ERROR saving label GeoTIFF: {e}")
|
|
1668
|
-
stats["errors"] += 1
|
|
1669
|
-
|
|
1670
|
-
# Create XML annotation for object detection if using vector class data
|
|
1671
|
-
if (
|
|
1672
|
-
not is_class_data_raster
|
|
1673
|
-
and "gdf" in locals()
|
|
1674
|
-
and len(window_features) > 0
|
|
1675
|
-
):
|
|
1676
|
-
# Create XML annotation
|
|
1677
|
-
root = ET.Element("annotation")
|
|
1678
|
-
ET.SubElement(root, "folder").text = "images"
|
|
1679
|
-
ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
|
|
1680
|
-
|
|
1681
|
-
size = ET.SubElement(root, "size")
|
|
1682
|
-
ET.SubElement(size, "width").text = str(tile_size)
|
|
1683
|
-
ET.SubElement(size, "height").text = str(tile_size)
|
|
1684
|
-
ET.SubElement(size, "depth").text = str(image_data.shape[0])
|
|
1685
|
-
|
|
1686
|
-
# Add georeference information
|
|
1687
|
-
geo = ET.SubElement(root, "georeference")
|
|
1688
|
-
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
1689
|
-
ET.SubElement(geo, "transform").text = str(
|
|
1690
|
-
window_transform
|
|
1691
|
-
).replace("\n", "")
|
|
1692
|
-
ET.SubElement(geo, "bounds").text = (
|
|
1693
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
1694
|
-
)
|
|
1695
|
-
|
|
1696
|
-
# Add objects
|
|
1697
|
-
for idx, feature in window_features.iterrows():
|
|
1698
|
-
# Get feature class
|
|
1699
|
-
if class_value_field in feature:
|
|
1700
|
-
class_val = feature[class_value_field]
|
|
1701
|
-
else:
|
|
1702
|
-
class_val = "object"
|
|
1703
|
-
|
|
1704
|
-
# Get geometry bounds in pixel coordinates
|
|
1705
|
-
geom = feature.geometry.intersection(window_bounds)
|
|
1706
|
-
if not geom.is_empty:
|
|
1707
|
-
# Get bounds in world coordinates
|
|
1708
|
-
minx_f, miny_f, maxx_f, maxy_f = geom.bounds
|
|
1709
|
-
|
|
1710
|
-
# Convert to pixel coordinates
|
|
1711
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
1712
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
1713
|
-
|
|
1714
|
-
# Ensure coordinates are within tile bounds
|
|
1715
|
-
xmin = max(0, min(tile_size, int(col_min)))
|
|
1716
|
-
ymin = max(0, min(tile_size, int(row_min)))
|
|
1717
|
-
xmax = max(0, min(tile_size, int(col_max)))
|
|
1718
|
-
ymax = max(0, min(tile_size, int(row_max)))
|
|
1719
|
-
|
|
1720
|
-
# Only add if the box has non-zero area
|
|
1721
|
-
if xmax > xmin and ymax > ymin:
|
|
1722
|
-
obj = ET.SubElement(root, "object")
|
|
1723
|
-
ET.SubElement(obj, "name").text = str(class_val)
|
|
1724
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
1725
|
-
|
|
1726
|
-
bbox = ET.SubElement(obj, "bndbox")
|
|
1727
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
1728
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
1729
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
1730
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
1731
|
-
|
|
1732
|
-
# Save XML
|
|
1733
|
-
tree = ET.ElementTree(root)
|
|
1734
|
-
xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
|
|
1735
|
-
tree.write(xml_path)
|
|
1736
|
-
|
|
1737
|
-
# Update progress bar
|
|
1738
|
-
pbar.update(1)
|
|
1739
|
-
pbar.set_description(
|
|
1740
|
-
f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
|
|
1741
|
-
)
|
|
1742
|
-
|
|
1743
|
-
tile_index += 1
|
|
1744
|
-
if tile_index >= max_tiles:
|
|
1745
|
-
break
|
|
1746
|
-
|
|
1747
|
-
if tile_index >= max_tiles:
|
|
1748
|
-
break
|
|
1749
|
-
|
|
1750
|
-
# Close progress bar
|
|
1751
|
-
pbar.close()
|
|
1752
|
-
|
|
1753
|
-
# Create overview image if requested
|
|
1754
|
-
if create_overview and stats["tile_coordinates"]:
|
|
1755
|
-
try:
|
|
1756
|
-
create_overview_image(
|
|
1757
|
-
src,
|
|
1758
|
-
stats["tile_coordinates"],
|
|
1759
|
-
os.path.join(out_folder, "overview.png"),
|
|
1760
|
-
tile_size,
|
|
1761
|
-
stride,
|
|
1762
|
-
)
|
|
1763
|
-
except Exception as e:
|
|
1764
|
-
print(f"Failed to create overview image: {e}")
|
|
1765
|
-
|
|
1766
|
-
# Report results
|
|
1767
|
-
if not quiet:
|
|
1768
|
-
print("\n------- Export Summary -------")
|
|
1769
|
-
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
1770
|
-
print(
|
|
1771
|
-
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
1772
|
-
)
|
|
1773
|
-
if stats["tiles_with_features"] > 0:
|
|
1774
|
-
print(
|
|
1775
|
-
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
1776
|
-
)
|
|
1777
|
-
if stats["errors"] > 0:
|
|
1778
|
-
print(f"Errors encountered: {stats['errors']}")
|
|
1779
|
-
print(f"Output saved to: {out_folder}")
|
|
1780
|
-
|
|
1781
|
-
# Verify georeference in a sample image and label
|
|
1782
|
-
if stats["total_tiles"] > 0:
|
|
1783
|
-
print("\n------- Georeference Verification -------")
|
|
1784
|
-
sample_image = os.path.join(image_dir, f"tile_0.tif")
|
|
1785
|
-
sample_label = os.path.join(label_dir, f"tile_0.tif")
|
|
1786
|
-
|
|
1787
|
-
if os.path.exists(sample_image):
|
|
1788
|
-
try:
|
|
1789
|
-
with rasterio.open(sample_image) as img:
|
|
1790
|
-
print(f"Image CRS: {img.crs}")
|
|
1791
|
-
print(f"Image transform: {img.transform}")
|
|
1792
|
-
print(
|
|
1793
|
-
f"Image has georeference: {img.crs is not None and img.transform is not None}"
|
|
1794
|
-
)
|
|
1795
|
-
print(
|
|
1796
|
-
f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
|
|
1797
|
-
)
|
|
1798
|
-
except Exception as e:
|
|
1799
|
-
print(f"Error verifying image georeference: {e}")
|
|
1800
|
-
|
|
1801
|
-
if os.path.exists(sample_label):
|
|
1802
|
-
try:
|
|
1803
|
-
with rasterio.open(sample_label) as lbl:
|
|
1804
|
-
print(f"Label CRS: {lbl.crs}")
|
|
1805
|
-
print(f"Label transform: {lbl.transform}")
|
|
1806
|
-
print(
|
|
1807
|
-
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
1808
|
-
)
|
|
1809
|
-
print(
|
|
1810
|
-
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
1811
|
-
)
|
|
1812
|
-
except Exception as e:
|
|
1813
|
-
print(f"Error verifying label georeference: {e}")
|
|
1814
|
-
|
|
1815
|
-
# Return statistics dictionary for further processing if needed
|
|
1816
|
-
return stats
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
def create_overview_image(
|
|
1820
|
-
src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
|
|
1821
|
-
):
|
|
1822
|
-
"""Create an overview image showing all tiles and their status, with optional GeoJSON export.
|
|
1823
|
-
|
|
1824
|
-
Args:
|
|
1825
|
-
src (rasterio.io.DatasetReader): The source raster dataset.
|
|
1826
|
-
tile_coordinates (list): A list of dictionaries containing tile information.
|
|
1827
|
-
output_path (str): The path where the overview image will be saved.
|
|
1828
|
-
tile_size (int): The size of each tile in pixels.
|
|
1829
|
-
stride (int): The stride between tiles in pixels. Controls overlap between adjacent tiles.
|
|
1830
|
-
geojson_path (str, optional): If provided, exports the tile rectangles as GeoJSON to this path.
|
|
1831
|
-
|
|
1832
|
-
Returns:
|
|
1833
|
-
str: Path to the saved overview image.
|
|
1834
|
-
"""
|
|
1835
|
-
# Read a reduced version of the source image
|
|
1836
|
-
overview_scale = max(
|
|
1837
|
-
1, int(max(src.width, src.height) / 2000)
|
|
1838
|
-
) # Scale to max ~2000px
|
|
1839
|
-
overview_width = src.width // overview_scale
|
|
1840
|
-
overview_height = src.height // overview_scale
|
|
1841
|
-
|
|
1842
|
-
# Read downsampled image
|
|
1843
|
-
overview_data = src.read(
|
|
1844
|
-
out_shape=(src.count, overview_height, overview_width),
|
|
1845
|
-
resampling=rasterio.enums.Resampling.average,
|
|
1846
|
-
)
|
|
1847
|
-
|
|
1848
|
-
# Create RGB image for display
|
|
1849
|
-
if overview_data.shape[0] >= 3:
|
|
1850
|
-
rgb = np.moveaxis(overview_data[:3], 0, -1)
|
|
1851
|
-
else:
|
|
1852
|
-
# For single band, create grayscale RGB
|
|
1853
|
-
rgb = np.stack([overview_data[0], overview_data[0], overview_data[0]], axis=-1)
|
|
1854
|
-
|
|
1855
|
-
# Normalize for display
|
|
1856
|
-
for i in range(rgb.shape[-1]):
|
|
1857
|
-
band = rgb[..., i]
|
|
1858
|
-
non_zero = band[band > 0]
|
|
1859
|
-
if len(non_zero) > 0:
|
|
1860
|
-
p2, p98 = np.percentile(non_zero, (2, 98))
|
|
1861
|
-
rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
|
|
1862
|
-
|
|
1863
|
-
# Create figure
|
|
1864
|
-
plt.figure(figsize=(12, 12))
|
|
1865
|
-
plt.imshow(rgb)
|
|
1866
|
-
|
|
1867
|
-
# If GeoJSON export is requested, prepare GeoJSON structures
|
|
1868
|
-
if geojson_path:
|
|
1869
|
-
features = []
|
|
1870
|
-
|
|
1871
|
-
# Draw tile boundaries
|
|
1872
|
-
for tile in tile_coordinates:
|
|
1873
|
-
# Convert bounds to pixel coordinates in overview
|
|
1874
|
-
bounds = tile["bounds"]
|
|
1875
|
-
# Calculate scaled pixel coordinates
|
|
1876
|
-
x_min = int((tile["x"]) / overview_scale)
|
|
1877
|
-
y_min = int((tile["y"]) / overview_scale)
|
|
1878
|
-
width = int(tile_size / overview_scale)
|
|
1879
|
-
height = int(tile_size / overview_scale)
|
|
1880
|
-
|
|
1881
|
-
# Draw rectangle
|
|
1882
|
-
color = "lime" if tile["has_features"] else "red"
|
|
1883
|
-
rect = plt.Rectangle(
|
|
1884
|
-
(x_min, y_min), width, height, fill=False, edgecolor=color, linewidth=0.5
|
|
1885
|
-
)
|
|
1886
|
-
plt.gca().add_patch(rect)
|
|
1887
|
-
|
|
1888
|
-
# Add tile number if not too crowded
|
|
1889
|
-
if width > 20 and height > 20:
|
|
1890
|
-
plt.text(
|
|
1891
|
-
x_min + width / 2,
|
|
1892
|
-
y_min + height / 2,
|
|
1893
|
-
str(tile["index"]),
|
|
1894
|
-
color="white",
|
|
1895
|
-
ha="center",
|
|
1896
|
-
va="center",
|
|
1897
|
-
fontsize=8,
|
|
1898
|
-
)
|
|
1899
|
-
|
|
1900
|
-
# Add to GeoJSON features if exporting
|
|
1901
|
-
if geojson_path:
|
|
1902
|
-
# Create a polygon from the bounds (already in geo-coordinates)
|
|
1903
|
-
minx, miny, maxx, maxy = bounds
|
|
1904
|
-
polygon = box(minx, miny, maxx, maxy)
|
|
1905
|
-
|
|
1906
|
-
# Calculate overlap with neighboring tiles
|
|
1907
|
-
overlap = 0
|
|
1908
|
-
if stride < tile_size:
|
|
1909
|
-
overlap = tile_size - stride
|
|
1910
|
-
|
|
1911
|
-
# Create a GeoJSON feature
|
|
1912
|
-
feature = {
|
|
1913
|
-
"type": "Feature",
|
|
1914
|
-
"geometry": mapping(polygon),
|
|
1915
|
-
"properties": {
|
|
1916
|
-
"index": tile["index"],
|
|
1917
|
-
"has_features": tile["has_features"],
|
|
1918
|
-
"bounds_pixel": [
|
|
1919
|
-
tile["x"],
|
|
1920
|
-
tile["y"],
|
|
1921
|
-
tile["x"] + tile_size,
|
|
1922
|
-
tile["y"] + tile_size,
|
|
1923
|
-
],
|
|
1924
|
-
"tile_size_px": tile_size,
|
|
1925
|
-
"stride_px": stride,
|
|
1926
|
-
"overlap_px": overlap,
|
|
1927
|
-
},
|
|
1928
|
-
}
|
|
1929
|
-
|
|
1930
|
-
# Add any additional properties from the tile
|
|
1931
|
-
for key, value in tile.items():
|
|
1932
|
-
if key not in ["x", "y", "index", "has_features", "bounds"]:
|
|
1933
|
-
feature["properties"][key] = value
|
|
1934
|
-
|
|
1935
|
-
features.append(feature)
|
|
1936
|
-
|
|
1937
|
-
plt.title("Tile Overview (Green = Contains Features, Red = Empty)")
|
|
1938
|
-
plt.axis("off")
|
|
1939
|
-
plt.tight_layout()
|
|
1940
|
-
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
|
1941
|
-
plt.close()
|
|
1942
|
-
|
|
1943
|
-
print(f"Overview image saved to {output_path}")
|
|
1944
|
-
|
|
1945
|
-
# Export GeoJSON if requested
|
|
1946
|
-
if geojson_path:
|
|
1947
|
-
geojson_collection = {
|
|
1948
|
-
"type": "FeatureCollection",
|
|
1949
|
-
"features": features,
|
|
1950
|
-
"properties": {
|
|
1951
|
-
"crs": (
|
|
1952
|
-
src.crs.to_string()
|
|
1953
|
-
if hasattr(src.crs, "to_string")
|
|
1954
|
-
else str(src.crs)
|
|
1955
|
-
),
|
|
1956
|
-
"total_tiles": len(features),
|
|
1957
|
-
"source_raster_dimensions": [src.width, src.height],
|
|
1958
|
-
},
|
|
1959
|
-
}
|
|
1960
|
-
|
|
1961
|
-
# Save to file
|
|
1962
|
-
with open(geojson_path, "w") as f:
|
|
1963
|
-
json.dump(geojson_collection, f)
|
|
1964
|
-
|
|
1965
|
-
print(f"GeoJSON saved to {geojson_path}")
|
|
1966
|
-
|
|
1967
|
-
return output_path
|
|
1968
|
-
|
|
1969
|
-
|
|
1970
|
-
def export_tiles_to_geojson(
|
|
1971
|
-
tile_coordinates, src, output_path, tile_size=None, stride=None
|
|
1972
|
-
):
|
|
1973
|
-
"""
|
|
1974
|
-
Export tile rectangles directly to GeoJSON without creating an overview image.
|
|
1975
|
-
|
|
1976
|
-
Args:
|
|
1977
|
-
tile_coordinates (list): A list of dictionaries containing tile information.
|
|
1978
|
-
src (rasterio.io.DatasetReader): The source raster dataset.
|
|
1979
|
-
output_path (str): The path where the GeoJSON will be saved.
|
|
1980
|
-
tile_size (int, optional): The size of each tile in pixels. Only needed if not in tile_coordinates.
|
|
1981
|
-
stride (int, optional): The stride between tiles in pixels. Used to calculate overlaps between tiles.
|
|
1982
|
-
|
|
1983
|
-
Returns:
|
|
1984
|
-
str: Path to the saved GeoJSON file.
|
|
1985
|
-
"""
|
|
1986
|
-
features = []
|
|
1987
|
-
|
|
1988
|
-
for tile in tile_coordinates:
|
|
1989
|
-
# Get the size from the tile or use the provided parameter
|
|
1990
|
-
tile_width = tile.get("width", tile.get("size", tile_size))
|
|
1991
|
-
tile_height = tile.get("height", tile.get("size", tile_size))
|
|
1992
|
-
|
|
1993
|
-
if tile_width is None or tile_height is None:
|
|
1994
|
-
raise ValueError(
|
|
1995
|
-
"Tile size not found in tile data and no tile_size parameter provided"
|
|
1996
|
-
)
|
|
1997
|
-
|
|
1998
|
-
# Get bounds from the tile
|
|
1999
|
-
if "bounds" in tile:
|
|
2000
|
-
# If bounds are already in geo coordinates
|
|
2001
|
-
minx, miny, maxx, maxy = tile["bounds"]
|
|
2002
|
-
else:
|
|
2003
|
-
# Try to calculate bounds from transform if available
|
|
2004
|
-
if hasattr(src, "transform"):
|
|
2005
|
-
# Convert pixel coordinates to geo coordinates
|
|
2006
|
-
window_transform = src.transform
|
|
2007
|
-
x, y = tile["x"], tile["y"]
|
|
2008
|
-
minx = window_transform[2] + x * window_transform[0]
|
|
2009
|
-
maxy = window_transform[5] + y * window_transform[4]
|
|
2010
|
-
maxx = minx + tile_width * window_transform[0]
|
|
2011
|
-
miny = maxy + tile_height * window_transform[4]
|
|
2012
|
-
else:
|
|
2013
|
-
raise ValueError(
|
|
2014
|
-
"Cannot determine bounds. Neither 'bounds' in tile nor transform in src."
|
|
2015
|
-
)
|
|
2016
|
-
|
|
2017
|
-
# Calculate overlap with neighboring tiles if stride is provided
|
|
2018
|
-
overlap = 0
|
|
2019
|
-
if stride is not None and stride < tile_width:
|
|
2020
|
-
overlap = tile_width - stride
|
|
2021
|
-
|
|
2022
|
-
# Create a polygon from the bounds
|
|
2023
|
-
polygon = box(minx, miny, maxx, maxy)
|
|
2024
|
-
|
|
2025
|
-
# Create a GeoJSON feature
|
|
2026
|
-
feature = {
|
|
2027
|
-
"type": "Feature",
|
|
2028
|
-
"geometry": mapping(polygon),
|
|
2029
|
-
"properties": {
|
|
2030
|
-
"index": tile["index"],
|
|
2031
|
-
"has_features": tile.get("has_features", False),
|
|
2032
|
-
"tile_width_px": tile_width,
|
|
2033
|
-
"tile_height_px": tile_height,
|
|
2034
|
-
},
|
|
2035
|
-
}
|
|
2036
|
-
|
|
2037
|
-
# Add overlap information if stride is provided
|
|
2038
|
-
if stride is not None:
|
|
2039
|
-
feature["properties"]["stride_px"] = stride
|
|
2040
|
-
feature["properties"]["overlap_px"] = overlap
|
|
2041
|
-
|
|
2042
|
-
# Add additional properties from the tile
|
|
2043
|
-
for key, value in tile.items():
|
|
2044
|
-
if key not in ["bounds", "geometry"]:
|
|
2045
|
-
feature["properties"][key] = value
|
|
2046
|
-
|
|
2047
|
-
features.append(feature)
|
|
2048
|
-
|
|
2049
|
-
# Create the GeoJSON collection
|
|
2050
|
-
geojson_collection = {
|
|
2051
|
-
"type": "FeatureCollection",
|
|
2052
|
-
"features": features,
|
|
2053
|
-
"properties": {
|
|
2054
|
-
"crs": (
|
|
2055
|
-
src.crs.to_string() if hasattr(src.crs, "to_string") else str(src.crs)
|
|
2056
|
-
),
|
|
2057
|
-
"total_tiles": len(features),
|
|
2058
|
-
"source_raster_dimensions": (
|
|
2059
|
-
[src.width, src.height] if hasattr(src, "width") else None
|
|
2060
|
-
),
|
|
2061
|
-
},
|
|
2062
|
-
}
|
|
2063
|
-
|
|
2064
|
-
# Create directory if it doesn't exist
|
|
2065
|
-
os.makedirs(os.path.dirname(os.path.abspath(output_path)) or ".", exist_ok=True)
|
|
2066
|
-
|
|
2067
|
-
# Save to file
|
|
2068
|
-
with open(output_path, "w") as f:
|
|
2069
|
-
json.dump(geojson_collection, f)
|
|
2070
|
-
|
|
2071
|
-
print(f"GeoJSON saved to {output_path}")
|
|
2072
|
-
return output_path
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
def export_training_data(
|
|
2076
|
-
in_raster,
|
|
2077
|
-
out_folder,
|
|
2078
|
-
in_class_data,
|
|
2079
|
-
image_chip_format="GEOTIFF",
|
|
2080
|
-
tile_size_x=256,
|
|
2081
|
-
tile_size_y=256,
|
|
2082
|
-
stride_x=None,
|
|
2083
|
-
stride_y=None,
|
|
2084
|
-
output_nofeature_tiles=True,
|
|
2085
|
-
metadata_format="PASCAL_VOC",
|
|
2086
|
-
start_index=0,
|
|
2087
|
-
class_value_field="class",
|
|
2088
|
-
buffer_radius=0,
|
|
2089
|
-
in_mask_polygons=None,
|
|
2090
|
-
rotation_angle=0,
|
|
2091
|
-
reference_system=None,
|
|
2092
|
-
blacken_around_feature=False,
|
|
2093
|
-
crop_mode="FIXED_SIZE", # Implemented but not fully used yet
|
|
2094
|
-
in_raster2=None,
|
|
2095
|
-
in_instance_data=None,
|
|
2096
|
-
instance_class_value_field=None, # Implemented but not fully used yet
|
|
2097
|
-
min_polygon_overlap_ratio=0.0,
|
|
2098
|
-
all_touched=True,
|
|
2099
|
-
save_geotiff=True,
|
|
2100
|
-
quiet=False,
|
|
2101
|
-
):
|
|
2102
|
-
"""
|
|
2103
|
-
Export training data for deep learning using TorchGeo with progress bar.
|
|
2104
|
-
|
|
2105
|
-
Args:
|
|
2106
|
-
in_raster (str): Path to input raster image.
|
|
2107
|
-
out_folder (str): Output folder path where chips and labels will be saved.
|
|
2108
|
-
in_class_data (str): Path to vector file containing class polygons.
|
|
2109
|
-
image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
|
|
2110
|
-
tile_size_x (int): Width of image chips in pixels.
|
|
2111
|
-
tile_size_y (int): Height of image chips in pixels.
|
|
2112
|
-
stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
|
|
2113
|
-
stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
|
|
2114
|
-
output_nofeature_tiles (bool): Whether to export chips without features.
|
|
2115
|
-
metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
|
|
2116
|
-
start_index (int): Starting index for chip filenames.
|
|
2117
|
-
class_value_field (str): Field name in in_class_data containing class values.
|
|
2118
|
-
buffer_radius (float): Buffer radius around features (in CRS units).
|
|
2119
|
-
in_mask_polygons (str): Path to vector file containing mask polygons.
|
|
2120
|
-
rotation_angle (float): Rotation angle in degrees.
|
|
2121
|
-
reference_system (str): Reference system code.
|
|
2122
|
-
blacken_around_feature (bool): Whether to mask areas outside of features.
|
|
2123
|
-
crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
|
|
2124
|
-
in_raster2 (str): Path to secondary raster image.
|
|
2125
|
-
in_instance_data (str): Path to vector file containing instance polygons.
|
|
2126
|
-
instance_class_value_field (str): Field name in in_instance_data for instance classes.
|
|
2127
|
-
min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
|
|
2128
|
-
all_touched (bool): Whether to use all_touched=True in rasterization.
|
|
2129
|
-
save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
|
|
2130
|
-
quiet (bool): If True, suppress most output messages.
|
|
2131
|
-
"""
|
|
2132
|
-
# Create output directories
|
|
2133
|
-
image_dir = os.path.join(out_folder, "images")
|
|
2134
|
-
os.makedirs(image_dir, exist_ok=True)
|
|
2135
|
-
|
|
2136
|
-
label_dir = os.path.join(out_folder, "labels")
|
|
2137
|
-
os.makedirs(label_dir, exist_ok=True)
|
|
2138
|
-
|
|
2139
|
-
# Define annotation directories based on metadata format
|
|
2140
|
-
if metadata_format == "PASCAL_VOC":
|
|
2141
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
2142
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
2143
|
-
elif metadata_format == "COCO":
|
|
2144
|
-
ann_dir = os.path.join(out_folder, "annotations")
|
|
2145
|
-
os.makedirs(ann_dir, exist_ok=True)
|
|
2146
|
-
# Initialize COCO annotations dictionary
|
|
2147
|
-
coco_annotations = {"images": [], "annotations": [], "categories": []}
|
|
2148
|
-
|
|
2149
|
-
# Initialize statistics dictionary
|
|
2150
|
-
stats = {
|
|
2151
|
-
"total_tiles": 0,
|
|
2152
|
-
"tiles_with_features": 0,
|
|
2153
|
-
"feature_pixels": 0,
|
|
2154
|
-
"errors": 0,
|
|
2155
|
-
}
|
|
2156
|
-
|
|
2157
|
-
# Open raster
|
|
2158
|
-
with rasterio.open(in_raster) as src:
|
|
2159
|
-
if not quiet:
|
|
2160
|
-
print(f"\nRaster info for {in_raster}:")
|
|
2161
|
-
print(f" CRS: {src.crs}")
|
|
2162
|
-
print(f" Dimensions: {src.width} x {src.height}")
|
|
2163
|
-
print(f" Bounds: {src.bounds}")
|
|
2164
|
-
|
|
2165
|
-
# Set defaults for stride if not provided
|
|
2166
|
-
if stride_x is None:
|
|
2167
|
-
stride_x = tile_size_x
|
|
2168
|
-
if stride_y is None:
|
|
2169
|
-
stride_y = tile_size_y
|
|
2170
|
-
|
|
2171
|
-
# Calculate number of tiles in x and y directions
|
|
2172
|
-
num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
|
|
2173
|
-
num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
|
|
2174
|
-
total_tiles = num_tiles_x * num_tiles_y
|
|
2175
|
-
|
|
2176
|
-
# Read class data
|
|
2177
|
-
gdf = gpd.read_file(in_class_data)
|
|
2178
|
-
if not quiet:
|
|
2179
|
-
print(f"Loaded {len(gdf)} features from {in_class_data}")
|
|
2180
|
-
print(f"Available columns: {gdf.columns.tolist()}")
|
|
2181
|
-
print(f"GeoJSON CRS: {gdf.crs}")
|
|
2182
|
-
|
|
2183
|
-
# Check if class_value_field exists
|
|
2184
|
-
if class_value_field not in gdf.columns:
|
|
2185
|
-
if not quiet:
|
|
2186
|
-
print(
|
|
2187
|
-
f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
|
|
2188
|
-
)
|
|
2189
|
-
# Add a default class column
|
|
2190
|
-
gdf[class_value_field] = 1
|
|
2191
|
-
unique_classes = [1]
|
|
2192
|
-
else:
|
|
2193
|
-
# Print unique classes for debugging
|
|
2194
|
-
unique_classes = gdf[class_value_field].unique()
|
|
2195
|
-
if not quiet:
|
|
2196
|
-
print(f"Found {len(unique_classes)} unique classes: {unique_classes}")
|
|
2197
|
-
|
|
2198
|
-
# CRITICAL: Always reproject to match raster CRS to ensure proper alignment
|
|
2199
|
-
if gdf.crs != src.crs:
|
|
2200
|
-
if not quiet:
|
|
2201
|
-
print(f"Reprojecting features from {gdf.crs} to {src.crs}")
|
|
2202
|
-
gdf = gdf.to_crs(src.crs)
|
|
2203
|
-
elif reference_system and gdf.crs != reference_system:
|
|
2204
|
-
if not quiet:
|
|
2205
|
-
print(
|
|
2206
|
-
f"Reprojecting features to specified reference system {reference_system}"
|
|
2207
|
-
)
|
|
2208
|
-
gdf = gdf.to_crs(reference_system)
|
|
2209
|
-
|
|
2210
|
-
# Check overlap between raster and vector data
|
|
2211
|
-
raster_bounds = box(*src.bounds)
|
|
2212
|
-
vector_bounds = box(*gdf.total_bounds)
|
|
2213
|
-
if not raster_bounds.intersects(vector_bounds):
|
|
2214
|
-
if not quiet:
|
|
2215
|
-
print(
|
|
2216
|
-
"WARNING: The vector data doesn't intersect with the raster extent!"
|
|
2217
|
-
)
|
|
2218
|
-
print(f"Raster bounds: {src.bounds}")
|
|
2219
|
-
print(f"Vector bounds: {gdf.total_bounds}")
|
|
2220
|
-
else:
|
|
2221
|
-
overlap = (
|
|
2222
|
-
raster_bounds.intersection(vector_bounds).area / vector_bounds.area
|
|
2223
|
-
)
|
|
2224
|
-
if not quiet:
|
|
2225
|
-
print(f"Overlap between raster and vector: {overlap:.2%}")
|
|
2226
|
-
|
|
2227
|
-
# Apply buffer if specified
|
|
2228
|
-
if buffer_radius > 0:
|
|
2229
|
-
gdf["geometry"] = gdf.buffer(buffer_radius)
|
|
2230
|
-
|
|
2231
|
-
# Initialize class mapping (ensure all classes are mapped to non-zero values)
|
|
2232
|
-
class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
|
|
2233
|
-
|
|
2234
|
-
# Store category info for COCO format
|
|
2235
|
-
if metadata_format == "COCO":
|
|
2236
|
-
for cls_val in unique_classes:
|
|
2237
|
-
coco_annotations["categories"].append(
|
|
2238
|
-
{
|
|
2239
|
-
"id": class_to_id[cls_val],
|
|
2240
|
-
"name": str(cls_val),
|
|
2241
|
-
"supercategory": "object",
|
|
2242
|
-
}
|
|
2243
|
-
)
|
|
2244
|
-
|
|
2245
|
-
# Load mask polygons if provided
|
|
2246
|
-
mask_gdf = None
|
|
2247
|
-
if in_mask_polygons:
|
|
2248
|
-
mask_gdf = gpd.read_file(in_mask_polygons)
|
|
2249
|
-
if reference_system:
|
|
2250
|
-
mask_gdf = mask_gdf.to_crs(reference_system)
|
|
2251
|
-
elif mask_gdf.crs != src.crs:
|
|
2252
|
-
mask_gdf = mask_gdf.to_crs(src.crs)
|
|
2253
|
-
|
|
2254
|
-
# Process instance data if provided
|
|
2255
|
-
instance_gdf = None
|
|
2256
|
-
if in_instance_data:
|
|
2257
|
-
instance_gdf = gpd.read_file(in_instance_data)
|
|
2258
|
-
if reference_system:
|
|
2259
|
-
instance_gdf = instance_gdf.to_crs(reference_system)
|
|
2260
|
-
elif instance_gdf.crs != src.crs:
|
|
2261
|
-
instance_gdf = instance_gdf.to_crs(src.crs)
|
|
2262
|
-
|
|
2263
|
-
# Load secondary raster if provided
|
|
2264
|
-
src2 = None
|
|
2265
|
-
if in_raster2:
|
|
2266
|
-
src2 = rasterio.open(in_raster2)
|
|
2267
|
-
|
|
2268
|
-
# Set up augmentation if rotation is specified
|
|
2269
|
-
augmentation = None
|
|
2270
|
-
if rotation_angle != 0:
|
|
2271
|
-
# Fixed: Added data_keys parameter to AugmentationSequential
|
|
2272
|
-
augmentation = torchgeo.transforms.AugmentationSequential(
|
|
2273
|
-
torch.nn.ModuleList([RandomRotation(rotation_angle)]),
|
|
2274
|
-
data_keys=["image"], # Add data_keys parameter
|
|
2275
|
-
)
|
|
2276
|
-
|
|
2277
|
-
# Initialize annotation ID for COCO format
|
|
2278
|
-
ann_id = 0
|
|
2279
|
-
|
|
2280
|
-
# Create progress bar
|
|
2281
|
-
pbar = tqdm(
|
|
2282
|
-
total=total_tiles,
|
|
2283
|
-
desc=f"Generating tiles (with features: 0)",
|
|
2284
|
-
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
|
|
2285
|
-
)
|
|
2286
|
-
|
|
2287
|
-
# Generate tiles
|
|
2288
|
-
chip_index = start_index
|
|
2289
|
-
for y in range(num_tiles_y):
|
|
2290
|
-
for x in range(num_tiles_x):
|
|
2291
|
-
# Calculate window coordinates
|
|
2292
|
-
window_x = x * stride_x
|
|
2293
|
-
window_y = y * stride_y
|
|
2294
|
-
|
|
2295
|
-
# Adjust for edge cases
|
|
2296
|
-
if window_x + tile_size_x > src.width:
|
|
2297
|
-
window_x = src.width - tile_size_x
|
|
2298
|
-
if window_y + tile_size_y > src.height:
|
|
2299
|
-
window_y = src.height - tile_size_y
|
|
2300
|
-
|
|
2301
|
-
# Adjust window based on crop_mode
|
|
2302
|
-
if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
|
|
2303
|
-
# Find the nearest feature to the center of this window
|
|
2304
|
-
window_center_x = window_x + tile_size_x // 2
|
|
2305
|
-
window_center_y = window_y + tile_size_y // 2
|
|
2306
|
-
|
|
2307
|
-
# Convert center to world coordinates
|
|
2308
|
-
center_x, center_y = src.xy(window_center_y, window_center_x)
|
|
2309
|
-
center_point = gpd.points_from_xy([center_x], [center_y])[0]
|
|
2310
|
-
|
|
2311
|
-
# Find nearest feature
|
|
2312
|
-
distances = gdf.geometry.distance(center_point)
|
|
2313
|
-
nearest_idx = distances.idxmin()
|
|
2314
|
-
nearest_feature = gdf.iloc[nearest_idx]
|
|
2315
|
-
|
|
2316
|
-
# Get centroid of nearest feature
|
|
2317
|
-
feature_centroid = nearest_feature.geometry.centroid
|
|
2318
|
-
|
|
2319
|
-
# Convert feature centroid to pixel coordinates
|
|
2320
|
-
feature_row, feature_col = src.index(
|
|
2321
|
-
feature_centroid.x, feature_centroid.y
|
|
2322
|
-
)
|
|
2323
|
-
|
|
2324
|
-
# Adjust window to center on feature
|
|
2325
|
-
window_x = max(
|
|
2326
|
-
0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
|
|
2327
|
-
)
|
|
2328
|
-
window_y = max(
|
|
2329
|
-
0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
|
|
2330
|
-
)
|
|
2331
|
-
|
|
2332
|
-
# Define window
|
|
2333
|
-
window = Window(window_x, window_y, tile_size_x, tile_size_y)
|
|
2334
|
-
|
|
2335
|
-
# Get window transform and bounds in source CRS
|
|
2336
|
-
window_transform = src.window_transform(window)
|
|
2337
|
-
|
|
2338
|
-
# Calculate window bounds more explicitly and accurately
|
|
2339
|
-
minx = window_transform[2] # Upper left x
|
|
2340
|
-
maxy = window_transform[5] # Upper left y
|
|
2341
|
-
maxx = minx + tile_size_x * window_transform[0] # Add width
|
|
2342
|
-
miny = (
|
|
2343
|
-
maxy + tile_size_y * window_transform[4]
|
|
2344
|
-
) # Add height (note: transform[4] is typically negative)
|
|
2345
|
-
|
|
2346
|
-
window_bounds = box(minx, miny, maxx, maxy)
|
|
2347
|
-
|
|
2348
|
-
# Apply rotation if specified
|
|
2349
|
-
if rotation_angle != 0:
|
|
2350
|
-
window_bounds = rotate(
|
|
2351
|
-
window_bounds, rotation_angle, origin="center"
|
|
2352
|
-
)
|
|
2353
|
-
|
|
2354
|
-
# Find features that intersect with window
|
|
2355
|
-
window_features = gdf[gdf.intersects(window_bounds)]
|
|
2356
|
-
|
|
2357
|
-
# Process instance data if provided
|
|
2358
|
-
window_instances = None
|
|
2359
|
-
if instance_gdf is not None and instance_class_value_field is not None:
|
|
2360
|
-
window_instances = instance_gdf[
|
|
2361
|
-
instance_gdf.intersects(window_bounds)
|
|
2362
|
-
]
|
|
2363
|
-
if len(window_instances) > 0:
|
|
2364
|
-
if not quiet:
|
|
2365
|
-
pbar.write(
|
|
2366
|
-
f"Found {len(window_instances)} instances in tile {chip_index}"
|
|
2367
|
-
)
|
|
2368
|
-
|
|
2369
|
-
# Skip if no features and output_nofeature_tiles is False
|
|
2370
|
-
if not output_nofeature_tiles and len(window_features) == 0:
|
|
2371
|
-
pbar.update(1) # Still update progress bar
|
|
2372
|
-
continue
|
|
2373
|
-
|
|
2374
|
-
# Check polygon overlap ratio if specified
|
|
2375
|
-
if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
|
|
2376
|
-
valid_features = []
|
|
2377
|
-
for _, feature in window_features.iterrows():
|
|
2378
|
-
overlap_ratio = (
|
|
2379
|
-
feature.geometry.intersection(window_bounds).area
|
|
2380
|
-
/ feature.geometry.area
|
|
2381
|
-
)
|
|
2382
|
-
if overlap_ratio >= min_polygon_overlap_ratio:
|
|
2383
|
-
valid_features.append(feature)
|
|
2384
|
-
|
|
2385
|
-
if len(valid_features) > 0:
|
|
2386
|
-
window_features = gpd.GeoDataFrame(valid_features)
|
|
2387
|
-
elif not output_nofeature_tiles:
|
|
2388
|
-
pbar.update(1) # Still update progress bar
|
|
2389
|
-
continue
|
|
2390
|
-
|
|
2391
|
-
# Apply mask if provided
|
|
2392
|
-
if mask_gdf is not None:
|
|
2393
|
-
mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
|
|
2394
|
-
if len(mask_features) == 0:
|
|
2395
|
-
pbar.update(1) # Still update progress bar
|
|
2396
|
-
continue
|
|
2397
|
-
|
|
2398
|
-
# Read image data - keep original for GeoTIFF export
|
|
2399
|
-
orig_image_data = src.read(window=window)
|
|
2400
|
-
|
|
2401
|
-
# Create a copy for processing
|
|
2402
|
-
image_data = orig_image_data.copy().astype(np.float32)
|
|
2403
|
-
|
|
2404
|
-
# Normalize image data for processing
|
|
2405
|
-
for band in range(image_data.shape[0]):
|
|
2406
|
-
band_min, band_max = np.percentile(image_data[band], (1, 99))
|
|
2407
|
-
if band_max > band_min:
|
|
2408
|
-
image_data[band] = np.clip(
|
|
2409
|
-
(image_data[band] - band_min) / (band_max - band_min), 0, 1
|
|
2410
|
-
)
|
|
2411
|
-
|
|
2412
|
-
# Read secondary image data if provided
|
|
2413
|
-
if src2:
|
|
2414
|
-
image_data2 = src2.read(window=window)
|
|
2415
|
-
# Stack the two images
|
|
2416
|
-
image_data = np.vstack((image_data, image_data2))
|
|
2417
|
-
|
|
2418
|
-
# Apply blacken_around_feature if needed
|
|
2419
|
-
if blacken_around_feature and len(window_features) > 0:
|
|
2420
|
-
mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
|
|
2421
|
-
for _, feature in window_features.iterrows():
|
|
2422
|
-
# Project feature to pixel coordinates
|
|
2423
|
-
feature_pixels = features.rasterize(
|
|
2424
|
-
[(feature.geometry, 1)],
|
|
2425
|
-
out_shape=(tile_size_y, tile_size_x),
|
|
2426
|
-
transform=window_transform,
|
|
2427
|
-
)
|
|
2428
|
-
mask = np.logical_or(mask, feature_pixels.astype(bool))
|
|
2429
|
-
|
|
2430
|
-
# Apply mask to image
|
|
2431
|
-
for band in range(image_data.shape[0]):
|
|
2432
|
-
temp = image_data[band, :, :]
|
|
2433
|
-
temp[~mask] = 0
|
|
2434
|
-
image_data[band, :, :] = temp
|
|
2435
|
-
|
|
2436
|
-
# Apply rotation if specified
|
|
2437
|
-
if augmentation:
|
|
2438
|
-
# Convert to torch tensor for augmentation
|
|
2439
|
-
image_tensor = torch.from_numpy(image_data).unsqueeze(
|
|
2440
|
-
0
|
|
2441
|
-
) # Add batch dimension
|
|
2442
|
-
# Apply augmentation with proper data format
|
|
2443
|
-
augmented = augmentation({"image": image_tensor})
|
|
2444
|
-
image_data = (
|
|
2445
|
-
augmented["image"].squeeze(0).numpy()
|
|
2446
|
-
) # Remove batch dimension
|
|
2447
|
-
|
|
2448
|
-
# Create a processed version for regular image formats
|
|
2449
|
-
processed_image = (image_data * 255).astype(np.uint8)
|
|
2450
|
-
|
|
2451
|
-
# Create label mask
|
|
2452
|
-
label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
|
|
2453
|
-
has_features = False
|
|
2454
|
-
|
|
2455
|
-
if len(window_features) > 0:
|
|
2456
|
-
for idx, feature in window_features.iterrows():
|
|
2457
|
-
# Get class value
|
|
2458
|
-
class_val = (
|
|
2459
|
-
feature[class_value_field]
|
|
2460
|
-
if class_value_field in feature
|
|
2461
|
-
else 1
|
|
2462
|
-
)
|
|
2463
|
-
if isinstance(class_val, str):
|
|
2464
|
-
# If class is a string, use its position in the unique classes list
|
|
2465
|
-
class_id = class_to_id.get(class_val, 1)
|
|
2466
|
-
else:
|
|
2467
|
-
# If class is already a number, use it directly
|
|
2468
|
-
class_id = int(class_val) if class_val > 0 else 1
|
|
2469
|
-
|
|
2470
|
-
# Get the geometry in pixel coordinates
|
|
2471
|
-
geom = feature.geometry.intersection(window_bounds)
|
|
2472
|
-
if not geom.is_empty:
|
|
2473
|
-
try:
|
|
2474
|
-
# Rasterize the feature
|
|
2475
|
-
feature_mask = features.rasterize(
|
|
2476
|
-
[(geom, class_id)],
|
|
2477
|
-
out_shape=(tile_size_y, tile_size_x),
|
|
2478
|
-
transform=window_transform,
|
|
2479
|
-
fill=0,
|
|
2480
|
-
all_touched=all_touched,
|
|
2481
|
-
)
|
|
2482
|
-
|
|
2483
|
-
# Update mask with higher class values taking precedence
|
|
2484
|
-
label_mask = np.maximum(label_mask, feature_mask)
|
|
2485
|
-
|
|
2486
|
-
# Check if any pixels were added
|
|
2487
|
-
if np.any(feature_mask):
|
|
2488
|
-
has_features = True
|
|
2489
|
-
except Exception as e:
|
|
2490
|
-
if not quiet:
|
|
2491
|
-
pbar.write(f"Error rasterizing feature {idx}: {e}")
|
|
2492
|
-
stats["errors"] += 1
|
|
2493
|
-
|
|
2494
|
-
# Save as GeoTIFF if requested
|
|
2495
|
-
if save_geotiff or image_chip_format.upper() in [
|
|
2496
|
-
"TIFF",
|
|
2497
|
-
"TIF",
|
|
2498
|
-
"GEOTIFF",
|
|
2499
|
-
]:
|
|
2500
|
-
# Standardize extension to .tif for GeoTIFF files
|
|
2501
|
-
image_filename = f"tile_{chip_index:06d}.tif"
|
|
2502
|
-
image_path = os.path.join(image_dir, image_filename)
|
|
2503
|
-
|
|
2504
|
-
# Create profile for the GeoTIFF
|
|
2505
|
-
profile = src.profile.copy()
|
|
2506
|
-
profile.update(
|
|
2507
|
-
{
|
|
2508
|
-
"height": tile_size_y,
|
|
2509
|
-
"width": tile_size_x,
|
|
2510
|
-
"count": orig_image_data.shape[0],
|
|
2511
|
-
"transform": window_transform,
|
|
2512
|
-
}
|
|
2513
|
-
)
|
|
2514
|
-
|
|
2515
|
-
# Save the GeoTIFF with original data
|
|
2516
|
-
try:
|
|
2517
|
-
with rasterio.open(image_path, "w", **profile) as dst:
|
|
2518
|
-
dst.write(orig_image_data)
|
|
2519
|
-
stats["total_tiles"] += 1
|
|
2520
|
-
except Exception as e:
|
|
2521
|
-
if not quiet:
|
|
2522
|
-
pbar.write(
|
|
2523
|
-
f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
|
|
2524
|
-
)
|
|
2525
|
-
stats["errors"] += 1
|
|
2526
|
-
else:
|
|
2527
|
-
# For non-GeoTIFF formats, use PIL to save the image
|
|
2528
|
-
image_filename = (
|
|
2529
|
-
f"tile_{chip_index:06d}.{image_chip_format.lower()}"
|
|
2530
|
-
)
|
|
2531
|
-
image_path = os.path.join(image_dir, image_filename)
|
|
2532
|
-
|
|
2533
|
-
# Create PIL image for saving
|
|
2534
|
-
if processed_image.shape[0] == 1:
|
|
2535
|
-
img = Image.fromarray(processed_image[0])
|
|
2536
|
-
elif processed_image.shape[0] == 3:
|
|
2537
|
-
# For RGB, need to transpose and make sure it's the right data type
|
|
2538
|
-
rgb_data = np.transpose(processed_image, (1, 2, 0))
|
|
2539
|
-
img = Image.fromarray(rgb_data)
|
|
2540
|
-
else:
|
|
2541
|
-
# For multiband images, save only RGB or first three bands
|
|
2542
|
-
rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
|
|
2543
|
-
img = Image.fromarray(rgb_data)
|
|
2544
|
-
|
|
2545
|
-
# Save image
|
|
2546
|
-
try:
|
|
2547
|
-
img.save(image_path)
|
|
2548
|
-
stats["total_tiles"] += 1
|
|
2549
|
-
except Exception as e:
|
|
2550
|
-
if not quiet:
|
|
2551
|
-
pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
|
|
2552
|
-
stats["errors"] += 1
|
|
2553
|
-
|
|
2554
|
-
# Save label as GeoTIFF
|
|
2555
|
-
label_filename = f"tile_{chip_index:06d}.tif"
|
|
2556
|
-
label_path = os.path.join(label_dir, label_filename)
|
|
2557
|
-
|
|
2558
|
-
# Create profile for label GeoTIFF
|
|
2559
|
-
label_profile = {
|
|
2560
|
-
"driver": "GTiff",
|
|
2561
|
-
"height": tile_size_y,
|
|
2562
|
-
"width": tile_size_x,
|
|
2563
|
-
"count": 1,
|
|
2564
|
-
"dtype": "uint8",
|
|
2565
|
-
"crs": src.crs,
|
|
2566
|
-
"transform": window_transform,
|
|
2567
|
-
}
|
|
2568
|
-
|
|
2569
|
-
# Save label GeoTIFF
|
|
2570
|
-
try:
|
|
2571
|
-
with rasterio.open(label_path, "w", **label_profile) as dst:
|
|
2572
|
-
dst.write(label_mask, 1)
|
|
2573
|
-
|
|
2574
|
-
if has_features:
|
|
2575
|
-
pixel_count = np.count_nonzero(label_mask)
|
|
2576
|
-
stats["tiles_with_features"] += 1
|
|
2577
|
-
stats["feature_pixels"] += pixel_count
|
|
2578
|
-
except Exception as e:
|
|
2579
|
-
if not quiet:
|
|
2580
|
-
pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
|
|
2581
|
-
stats["errors"] += 1
|
|
2582
|
-
|
|
2583
|
-
# Also save a PNG version for easy visualization if requested
|
|
2584
|
-
if metadata_format == "PASCAL_VOC":
|
|
2585
|
-
try:
|
|
2586
|
-
# Ensure correct data type for PIL
|
|
2587
|
-
png_label = label_mask.astype(np.uint8)
|
|
2588
|
-
label_img = Image.fromarray(png_label)
|
|
2589
|
-
label_png_path = os.path.join(
|
|
2590
|
-
label_dir, f"tile_{chip_index:06d}.png"
|
|
2591
|
-
)
|
|
2592
|
-
label_img.save(label_png_path)
|
|
2593
|
-
except Exception as e:
|
|
2594
|
-
if not quiet:
|
|
2595
|
-
pbar.write(
|
|
2596
|
-
f"ERROR saving PNG label for tile {chip_index}: {e}"
|
|
2597
|
-
)
|
|
2598
|
-
pbar.write(
|
|
2599
|
-
f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
|
|
2600
|
-
)
|
|
2601
|
-
# Try again with explicit conversion
|
|
2602
|
-
try:
|
|
2603
|
-
# Alternative approach for problematic arrays
|
|
2604
|
-
png_data = np.zeros(
|
|
2605
|
-
(tile_size_y, tile_size_x), dtype=np.uint8
|
|
2606
|
-
)
|
|
2607
|
-
np.copyto(png_data, label_mask, casting="unsafe")
|
|
2608
|
-
label_img = Image.fromarray(png_data)
|
|
2609
|
-
label_img.save(label_png_path)
|
|
2610
|
-
pbar.write(
|
|
2611
|
-
f" Succeeded using alternative conversion method"
|
|
2612
|
-
)
|
|
2613
|
-
except Exception as e2:
|
|
2614
|
-
pbar.write(f" Second attempt also failed: {e2}")
|
|
2615
|
-
stats["errors"] += 1
|
|
2616
|
-
|
|
2617
|
-
# Generate annotations
|
|
2618
|
-
if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
|
|
2619
|
-
# Create XML annotation
|
|
2620
|
-
root = ET.Element("annotation")
|
|
2621
|
-
ET.SubElement(root, "folder").text = "images"
|
|
2622
|
-
ET.SubElement(root, "filename").text = image_filename
|
|
2623
|
-
|
|
2624
|
-
size = ET.SubElement(root, "size")
|
|
2625
|
-
ET.SubElement(size, "width").text = str(tile_size_x)
|
|
2626
|
-
ET.SubElement(size, "height").text = str(tile_size_y)
|
|
2627
|
-
ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))
|
|
2628
|
-
|
|
2629
|
-
# Add georeference information
|
|
2630
|
-
geo = ET.SubElement(root, "georeference")
|
|
2631
|
-
ET.SubElement(geo, "crs").text = str(src.crs)
|
|
2632
|
-
ET.SubElement(geo, "transform").text = str(
|
|
2633
|
-
window_transform
|
|
2634
|
-
).replace("\n", "")
|
|
2635
|
-
ET.SubElement(geo, "bounds").text = (
|
|
2636
|
-
f"{minx}, {miny}, {maxx}, {maxy}"
|
|
2637
|
-
)
|
|
2638
|
-
|
|
2639
|
-
for _, feature in window_features.iterrows():
|
|
2640
|
-
# Convert feature geometry to pixel coordinates
|
|
2641
|
-
feature_bounds = feature.geometry.intersection(window_bounds)
|
|
2642
|
-
if feature_bounds.is_empty:
|
|
2643
|
-
continue
|
|
2644
|
-
|
|
2645
|
-
# Get pixel coordinates of bounds
|
|
2646
|
-
minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
|
|
2647
|
-
|
|
2648
|
-
# Convert to pixel coordinates
|
|
2649
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
2650
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
2651
|
-
|
|
2652
|
-
# Ensure coordinates are within bounds
|
|
2653
|
-
xmin = max(0, min(tile_size_x, int(col_min)))
|
|
2654
|
-
ymin = max(0, min(tile_size_y, int(row_min)))
|
|
2655
|
-
xmax = max(0, min(tile_size_x, int(col_max)))
|
|
2656
|
-
ymax = max(0, min(tile_size_y, int(row_max)))
|
|
2657
|
-
|
|
2658
|
-
# Skip if box is too small
|
|
2659
|
-
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
2660
|
-
continue
|
|
2661
|
-
|
|
2662
|
-
obj = ET.SubElement(root, "object")
|
|
2663
|
-
ET.SubElement(obj, "name").text = str(
|
|
2664
|
-
feature[class_value_field]
|
|
2665
|
-
)
|
|
2666
|
-
ET.SubElement(obj, "difficult").text = "0"
|
|
2667
|
-
|
|
2668
|
-
bbox = ET.SubElement(obj, "bndbox")
|
|
2669
|
-
ET.SubElement(bbox, "xmin").text = str(xmin)
|
|
2670
|
-
ET.SubElement(bbox, "ymin").text = str(ymin)
|
|
2671
|
-
ET.SubElement(bbox, "xmax").text = str(xmax)
|
|
2672
|
-
ET.SubElement(bbox, "ymax").text = str(ymax)
|
|
2673
|
-
|
|
2674
|
-
# Save XML
|
|
2675
|
-
try:
|
|
2676
|
-
tree = ET.ElementTree(root)
|
|
2677
|
-
xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
|
|
2678
|
-
tree.write(xml_path)
|
|
2679
|
-
except Exception as e:
|
|
2680
|
-
if not quiet:
|
|
2681
|
-
pbar.write(
|
|
2682
|
-
f"ERROR saving XML annotation for tile {chip_index}: {e}"
|
|
2683
|
-
)
|
|
2684
|
-
stats["errors"] += 1
|
|
2685
|
-
|
|
2686
|
-
elif metadata_format == "COCO" and len(window_features) > 0:
|
|
2687
|
-
# Add image info
|
|
2688
|
-
image_id = chip_index
|
|
2689
|
-
coco_annotations["images"].append(
|
|
2690
|
-
{
|
|
2691
|
-
"id": image_id,
|
|
2692
|
-
"file_name": image_filename,
|
|
2693
|
-
"width": tile_size_x,
|
|
2694
|
-
"height": tile_size_y,
|
|
2695
|
-
"crs": str(src.crs),
|
|
2696
|
-
"transform": str(window_transform),
|
|
2697
|
-
}
|
|
2698
|
-
)
|
|
2699
|
-
|
|
2700
|
-
# Add annotations for each feature
|
|
2701
|
-
for _, feature in window_features.iterrows():
|
|
2702
|
-
feature_bounds = feature.geometry.intersection(window_bounds)
|
|
2703
|
-
if feature_bounds.is_empty:
|
|
2704
|
-
continue
|
|
2705
|
-
|
|
2706
|
-
# Get pixel coordinates of bounds
|
|
2707
|
-
minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
|
|
2708
|
-
|
|
2709
|
-
# Convert to pixel coordinates
|
|
2710
|
-
col_min, row_min = ~window_transform * (minx_f, maxy_f)
|
|
2711
|
-
col_max, row_max = ~window_transform * (maxx_f, miny_f)
|
|
2712
|
-
|
|
2713
|
-
# Ensure coordinates are within bounds
|
|
2714
|
-
xmin = max(0, min(tile_size_x, int(col_min)))
|
|
2715
|
-
ymin = max(0, min(tile_size_y, int(row_min)))
|
|
2716
|
-
xmax = max(0, min(tile_size_x, int(col_max)))
|
|
2717
|
-
ymax = max(0, min(tile_size_y, int(row_max)))
|
|
2718
|
-
|
|
2719
|
-
# Skip if box is too small
|
|
2720
|
-
if xmax - xmin < 1 or ymax - ymin < 1:
|
|
2721
|
-
continue
|
|
2722
|
-
|
|
2723
|
-
width = xmax - xmin
|
|
2724
|
-
height = ymax - ymin
|
|
2725
|
-
|
|
2726
|
-
# Add annotation
|
|
2727
|
-
ann_id += 1
|
|
2728
|
-
category_id = class_to_id[feature[class_value_field]]
|
|
2729
|
-
|
|
2730
|
-
coco_annotations["annotations"].append(
|
|
2731
|
-
{
|
|
2732
|
-
"id": ann_id,
|
|
2733
|
-
"image_id": image_id,
|
|
2734
|
-
"category_id": category_id,
|
|
2735
|
-
"bbox": [xmin, ymin, width, height],
|
|
2736
|
-
"area": width * height,
|
|
2737
|
-
"iscrowd": 0,
|
|
2738
|
-
}
|
|
2739
|
-
)
|
|
2740
|
-
|
|
2741
|
-
# Update progress bar
|
|
2742
|
-
pbar.update(1)
|
|
2743
|
-
pbar.set_description(
|
|
2744
|
-
f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
|
|
2745
|
-
)
|
|
2746
|
-
|
|
2747
|
-
chip_index += 1
|
|
2748
|
-
|
|
2749
|
-
# Close progress bar
|
|
2750
|
-
pbar.close()
|
|
2751
|
-
|
|
2752
|
-
# Save COCO annotations if applicable
|
|
2753
|
-
if metadata_format == "COCO":
|
|
2754
|
-
try:
|
|
2755
|
-
with open(os.path.join(ann_dir, "instances.json"), "w") as f:
|
|
2756
|
-
json.dump(coco_annotations, f)
|
|
2757
|
-
except Exception as e:
|
|
2758
|
-
if not quiet:
|
|
2759
|
-
print(f"ERROR saving COCO annotations: {e}")
|
|
2760
|
-
stats["errors"] += 1
|
|
2761
|
-
|
|
2762
|
-
# Close secondary raster if opened
|
|
2763
|
-
if src2:
|
|
2764
|
-
src2.close()
|
|
2765
|
-
|
|
2766
|
-
# Print summary
|
|
2767
|
-
if not quiet:
|
|
2768
|
-
print("\n------- Export Summary -------")
|
|
2769
|
-
print(f"Total tiles exported: {stats['total_tiles']}")
|
|
2770
|
-
print(
|
|
2771
|
-
f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
|
|
2772
|
-
)
|
|
2773
|
-
if stats["tiles_with_features"] > 0:
|
|
2774
|
-
print(
|
|
2775
|
-
f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
|
|
2776
|
-
)
|
|
2777
|
-
if stats["errors"] > 0:
|
|
2778
|
-
print(f"Errors encountered: {stats['errors']}")
|
|
2779
|
-
print(f"Output saved to: {out_folder}")
|
|
2780
|
-
|
|
2781
|
-
# Verify georeference in a sample image and label
|
|
2782
|
-
if stats["total_tiles"] > 0:
|
|
2783
|
-
print("\n------- Georeference Verification -------")
|
|
2784
|
-
sample_image = os.path.join(image_dir, f"tile_{start_index}.tif")
|
|
2785
|
-
sample_label = os.path.join(label_dir, f"tile_{start_index}.tif")
|
|
2786
|
-
|
|
2787
|
-
if os.path.exists(sample_image):
|
|
2788
|
-
try:
|
|
2789
|
-
with rasterio.open(sample_image) as img:
|
|
2790
|
-
print(f"Image CRS: {img.crs}")
|
|
2791
|
-
print(f"Image transform: {img.transform}")
|
|
2792
|
-
print(
|
|
2793
|
-
f"Image has georeference: {img.crs is not None and img.transform is not None}"
|
|
2794
|
-
)
|
|
2795
|
-
print(
|
|
2796
|
-
f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
|
|
2797
|
-
)
|
|
2798
|
-
except Exception as e:
|
|
2799
|
-
print(f"Error verifying image georeference: {e}")
|
|
2800
|
-
|
|
2801
|
-
if os.path.exists(sample_label):
|
|
2802
|
-
try:
|
|
2803
|
-
with rasterio.open(sample_label) as lbl:
|
|
2804
|
-
print(f"Label CRS: {lbl.crs}")
|
|
2805
|
-
print(f"Label transform: {lbl.transform}")
|
|
2806
|
-
print(
|
|
2807
|
-
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
2808
|
-
)
|
|
2809
|
-
print(
|
|
2810
|
-
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
2811
|
-
)
|
|
2812
|
-
except Exception as e:
|
|
2813
|
-
print(f"Error verifying label georeference: {e}")
|
|
2814
|
-
|
|
2815
|
-
# Return statistics
|
|
2816
|
-
return stats, out_folder
|
|
2817
|
-
|
|
2818
|
-
|
|
2819
|
-
def _nms_filter_polygons(gdf, iou_threshold):
    """Greedy non-maximum suppression on a polygon GeoDataFrame.

    Polygons are visited in descending ``confidence`` order; a polygon is
    dropped when its IoU with any already-kept polygon exceeds
    *iou_threshold*.

    Args:
        gdf: GeoDataFrame with ``geometry`` and ``confidence`` columns.
        iou_threshold: IoU above which the lower-confidence polygon is dropped.

    Returns:
        Filtered GeoDataFrame (possibly the input unchanged when len <= 1).
    """
    if len(gdf) <= 1:
        return gdf

    # Highest-confidence polygons are considered first, so they win ties.
    gdf = gdf.sort_values("confidence", ascending=False)

    # buffer(0) is the standard shapely trick to repair self-intersections.
    gdf["geometry"] = gdf["geometry"].apply(
        lambda geom: geom.buffer(0) if not geom.is_valid else geom
    )

    keep_indices = []
    polygons = gdf.geometry.values

    for i in range(len(polygons)):
        keep = True
        for j in keep_indices:
            # Skip pairs where a geometry is still invalid after repair.
            if not polygons[i].is_valid or not polygons[j].is_valid:
                continue
            try:
                intersection = polygons[i].intersection(polygons[j]).area
                union = polygons[i].area + polygons[j].area - intersection
                iou = intersection / union if union > 0 else 0
                if iou > iou_threshold:
                    keep = False
                    break
            except Exception:
                # Topology errors: treat the pair as non-overlapping.
                continue

        if keep:
            keep_indices.append(i)

    return gdf.iloc[keep_indices]


def masks_to_vector(
    mask_path,
    output_path=None,
    simplify_tolerance=1.0,
    mask_threshold=0.5,
    min_object_area=100,
    max_object_area=None,
    nms_iou_threshold=0.5,
):
    """
    Convert a building mask GeoTIFF to vector polygons and save as a vector dataset.

    Connected components of the binarized mask are traced with OpenCV
    contours, optionally simplified, and georeferenced through the raster's
    affine transform. Overlapping footprints are then removed with greedy
    non-maximum suppression; since no model scores are available, a
    size-based proxy is used as the confidence value.

    Args:
        mask_path: Path to the building masks GeoTIFF.
            NOTE(review): binarization compares pixel values against
            ``mask_threshold * 255``, i.e. the mask is assumed to be scaled
            0-255 — confirm for float 0-1 masks.
        output_path: Path to save the output vector file (e.g. GeoJSON).
            If None, nothing is written to disk.
        simplify_tolerance: Fraction of the contour perimeter used as the
            Douglas-Peucker epsilon; applied only to contours with more
            than 50 points (default 1.0).
        mask_threshold: Threshold in [0, 1] for mask binarization (default 0.5).
        min_object_area: Minimum area in pixels to keep a building (default 100).
        max_object_area: Maximum area in pixels to keep a building
            (default None, i.e. no upper cap).
        nms_iou_threshold: IoU threshold for non-maximum suppression (default 0.5).

    Returns:
        GeoDataFrame with building footprints, or None if no valid polygons
        were found.
    """
    print(f"Converting mask to GeoJSON with parameters:")
    print(f"- Mask threshold: {mask_threshold}")
    print(f"- Min building area: {min_object_area}")
    print(f"- Simplify tolerance: {simplify_tolerance}")
    print(f"- NMS IoU threshold: {nms_iou_threshold}")

    # Read the mask band plus the georeferencing needed to map pixels to CRS coords.
    with rasterio.open(mask_path) as src:
        mask_data = src.read(1)
        transform = src.transform
        crs = src.crs

    print(f"Mask dimensions: {mask_data.shape}")
    print(f"Mask value range: {mask_data.min()} to {mask_data.max()}")

    # Binarize (mask assumed 0-255 — see docstring) and close small holes
    # before connected-component labeling.
    binary_mask = (mask_data > (mask_threshold * 255)).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_mask, connectivity=8
    )

    print(f"Found {num_labels-1} potential buildings")  # Subtract 1 for background

    all_polygons = []
    all_confidences = []

    # Process each component; label 0 is the background and is skipped.
    for i in tqdm(range(1, num_labels)):
        area = stats[i, cv2.CC_STAT_AREA]

        # Apply area filters in pixel units.
        if area < min_object_area:
            continue
        if max_object_area is not None and area > max_object_area:
            continue

        # Isolate this component and trace its outer boundary.
        building_mask = (labels == i).astype(np.uint8)
        contours, _ = cv2.findContours(
            building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            # A polygon needs at least 3 vertices.
            if contour.shape[0] < 3:
                continue

            # Simplify very detailed contours; epsilon is a fraction of the perimeter.
            if contour.shape[0] > 50 and simplify_tolerance > 0:
                epsilon = simplify_tolerance * cv2.arcLength(contour, True)
                contour = cv2.approxPolyDP(contour, epsilon, True)

            polygon_points = contour.reshape(-1, 2)

            # Pixel (col, row) -> geographic (x, y) via the affine transform.
            geo_points = []
            for x, y in polygon_points:
                gx, gy = transform * (x, y)
                geo_points.append((gx, gy))

            if len(geo_points) >= 3:
                try:
                    shapely_poly = Polygon(geo_points)
                    if shapely_poly.is_valid and shapely_poly.area > 0:
                        all_polygons.append(shapely_poly)

                        # "Confidence" is a normalized-size proxy since no
                        # model scores are available; capped at 1.0.
                        normalized_size = min(1.0, area / 1000)
                        all_confidences.append(normalized_size)
                except Exception as e:
                    print(f"Error creating polygon: {e}")

    print(f"Created {len(all_polygons)} valid polygons")

    if not all_polygons:
        print("No valid polygons found")
        return None

    gdf = gpd.GeoDataFrame(
        {
            "geometry": all_polygons,
            "confidence": all_confidences,
            "class": 1,  # Building class
        },
        crs=crs,
    )

    # Remove overlapping footprints via greedy NMS.
    gdf = _nms_filter_polygons(gdf, nms_iou_threshold)

    print(f"Final building count after filtering: {len(gdf)}")

    if output_path is not None:
        gdf.to_file(output_path)
        print(f"Saved {len(gdf)} building footprints to {output_path}")

    return gdf