geoai-py 0.3.2__py2.py3-none-any.whl → 0.3.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +38 -17
- geoai/geoai.py +0 -1
- geoai/preprocess.py +20 -6
- geoai/utils.py +3788 -13
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.3.dist-info}/METADATA +11 -4
- geoai_py-0.3.3.dist-info/RECORD +13 -0
- geoai_py-0.3.2.dist-info/RECORD +0 -13
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.3.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.3.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.2.dist-info → geoai_py-0.3.3.dist-info}/top_level.txt +0 -0
geoai/utils.py
CHANGED
```diff
@@ -1,22 +1,37 @@
-"""The
+"""The utils module contains common functions and classes used by the other modules."""

+import json
+import math
 import os
 from collections.abc import Iterable
-from
+from PIL import Image
+from pathlib import Path
+import requests
+import warnings
+import xml.etree.ElementTree as ET
+from typing import Any, Dict, List, Optional, Tuple, Union
+import numpy as np
 import matplotlib.pyplot as plt
 import geopandas as gpd
 import leafmap
-import torch
 import numpy as np
+import pandas as pd
 import xarray as xr
-import rioxarray
-import rasterio
-from
+import rioxarray as rxr
+import rasterio
+from torchvision.transforms import RandomRotation
+from rasterio.windows import Window
+from rasterio import features
+from rasterio.plot import show
+from shapely.geometry import box, shape, mapping, Polygon, MultiPolygon
+from shapely.affinity import rotate
+from tqdm import tqdm
+import torch
+import torchgeo
+import cv2

 try:
-    from torchgeo.datasets import RasterDataset,
-    from torchgeo.samplers import RandomGeoSampler, Units
-    from torchgeo.transforms import indices
+    from torchgeo.datasets import RasterDataset, unbind_samples
 except ImportError as e:
     raise ImportError(
         "Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
```
```diff
@@ -121,7 +136,7 @@ def view_image(
     if isinstance(image, torch.Tensor):
         image = image.cpu().numpy()
     elif isinstance(image, str):
-        image =
+        image = rasterio.open(image).read().transpose(1, 2, 0)

     plt.figure(figsize=figsize)

```
```diff
@@ -274,7 +289,6 @@ def calc_stats(
     Returns:
         Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.
     """
-    import rasterio as rio

     # To avoid loading the entire dataset in memory, we will loop through each img
     # The filenames will be retrieved from the dataset's rtree index
```
```diff
@@ -288,7 +302,7 @@ def calc_stats(
     accum_std = 0

     for file in files:
-        img =
+        img = rasterio.open(file).read() / divide_by  # type: ignore
         accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)
         accum_std += img.reshape((img.shape[0], -1)).std(axis=1)

```
```diff
@@ -396,7 +410,7 @@ def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:

 def dict_to_image(
     data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
-) ->
+) -> rasterio.DatasetReader:
    """Convert a dictionary containing spatial data to a rasterio dataset or save it to
    a file. The dictionary should contain the following keys: "crs", "bounds", and "image".
    It can be generated from a TorchGeo dataset sampler.
```
```diff
@@ -1174,3 +1188,3764 @@ def create_split_map(
     )

     return m
+
+
+def download_file(url, output_path=None, overwrite=False):
+    """
+    Download a file from a given URL with a progress bar.
+
+    Args:
+        url (str): The URL of the file to download.
+        output_path (str, optional): The path where the downloaded file will be saved.
+            If not provided, the filename from the URL will be used.
+        overwrite (bool, optional): Whether to overwrite the file if it already exists.
+
+    Returns:
+        str: The path to the downloaded file.
+    """
+    # Get the filename from the URL if output_path is not provided
+    if output_path is None:
+        output_path = os.path.basename(url)
+
+    # Check if the file already exists
+    if os.path.exists(output_path) and not overwrite:
+        print(f"File already exists: {output_path}")
+        return output_path
+
+    # Send a streaming GET request
+    response = requests.get(url, stream=True, timeout=50)
+    response.raise_for_status()  # Raise an exception for HTTP errors
+
+    # Get the total file size if available
+    total_size = int(response.headers.get("content-length", 0))
+
+    # Open the output file
+    with (
+        open(output_path, "wb") as file,
+        tqdm(
+            desc=os.path.basename(output_path),
+            total=total_size,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar,
+    ):
+
+        # Download the file in chunks and update the progress bar
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:  # filter out keep-alive new chunks
+                file.write(chunk)
+                progress_bar.update(len(chunk))
+
+    return output_path
+
+
```
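The new `download_file` helper streams the response in 1 KB chunks, so memory use stays flat regardless of file size, and it skips the download when the target already exists. A minimal usage sketch (the URL and filename below are placeholders, not real project data):

```python
from geoai.utils import download_file

# Hypothetical URL; any direct link to a file works the same way.
path = download_file(
    "https://example.com/data/naip_sample.tif",
    output_path="naip_sample.tif",
)
print(path)  # "naip_sample.tif"; re-running returns early unless overwrite=True
```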
```diff
+def get_raster_info(raster_path):
+    """Display basic information about a raster dataset.
+
+    Args:
+        raster_path (str): Path to the raster file
+
+    Returns:
+        dict: Dictionary containing the basic information about the raster
+    """
+    # Open the raster dataset
+    with rasterio.open(raster_path) as src:
+        # Get basic metadata
+        info = {
+            "driver": src.driver,
+            "width": src.width,
+            "height": src.height,
+            "count": src.count,
+            "dtype": src.dtypes[0],
+            "crs": src.crs.to_string() if src.crs else "No CRS defined",
+            "transform": src.transform,
+            "bounds": src.bounds,
+            "resolution": (src.transform[0], -src.transform[4]),
+            "nodata": src.nodata,
+        }
+
+        # Calculate statistics for each band
+        stats = []
+        for i in range(1, src.count + 1):
+            band = src.read(i, masked=True)
+            band_stats = {
+                "band": i,
+                "min": float(band.min()),
+                "max": float(band.max()),
+                "mean": float(band.mean()),
+                "std": float(band.std()),
+            }
+            stats.append(band_stats)
+
+        info["band_stats"] = stats
+
+    return info
+
+
```
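`get_raster_info` returns a plain dict, so it composes easily with logging or JSON export. A sketch assuming a local GeoTIFF at a hypothetical path:

```python
import json
from geoai.utils import get_raster_info

info = get_raster_info("naip_sample.tif")  # hypothetical file
print(info["crs"], info["resolution"])
# band_stats is a list of per-band dicts with min/max/mean/std
print(json.dumps(info["band_stats"], indent=2))
```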
```diff
+def get_raster_stats(raster_path, divide_by=1.0):
+    """Calculate statistics for each band in a raster dataset.
+
+    This function computes min, max, mean, and standard deviation values
+    for each band in the provided raster, returning results in a dictionary
+    with lists for each statistic type.
+
+    Args:
+        raster_path (str): Path to the raster file
+        divide_by (float, optional): Value to divide pixel values by.
+            Defaults to 1.0, which keeps the original pixel values.
+
+    Returns:
+        dict: Dictionary containing lists of statistics with keys:
+            - 'min': List of minimum values for each band
+            - 'max': List of maximum values for each band
+            - 'mean': List of mean values for each band
+            - 'std': List of standard deviation values for each band
+    """
+    # Initialize the results dictionary with empty lists
+    stats = {"min": [], "max": [], "mean": [], "std": []}
+
+    # Open the raster dataset
+    with rasterio.open(raster_path) as src:
+        # Calculate statistics for each band
+        for i in range(1, src.count + 1):
+            band = src.read(i, masked=True)
+
+            # Append statistics for this band to each list
+            stats["min"].append(float(band.min()) / divide_by)
+            stats["max"].append(float(band.max()) / divide_by)
+            stats["mean"].append(float(band.mean()) / divide_by)
+            stats["std"].append(float(band.std()) / divide_by)
+
+    return stats
+
+
```
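Unlike `get_raster_info`, `get_raster_stats` groups results by statistic rather than by band, which matches the per-band `mean`/`std` list format that normalization code typically expects. A sketch with a hypothetical 3-band image, scaling 8-bit values to the 0-1 range:

```python
from geoai.utils import get_raster_stats

stats = get_raster_stats("naip_sample.tif", divide_by=255.0)  # hypothetical file
print(stats["mean"])  # one value per band, e.g. [0.41, 0.38, 0.33]
print(stats["std"])
```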
```diff
+def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
+    """Print formatted information about a raster dataset and optionally show a preview.
+
+    Args:
+        raster_path (str): Path to the raster file
+        show_preview (bool, optional): Whether to display a visual preview of the raster.
+            Defaults to True.
+        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
+
+    Returns:
+        dict: Dictionary containing raster information if successful, None otherwise
+    """
+    try:
+        info = get_raster_info(raster_path)
+
+        # Print basic information
+        print(f"===== RASTER INFORMATION: {raster_path} =====")
+        print(f"Driver: {info['driver']}")
+        print(f"Dimensions: {info['width']} x {info['height']} pixels")
+        print(f"Number of bands: {info['count']}")
+        print(f"Data type: {info['dtype']}")
+        print(f"Coordinate Reference System: {info['crs']}")
+        print(f"Georeferenced Bounds: {info['bounds']}")
+        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
+        print(f"NoData Value: {info['nodata']}")
+
+        # Print band statistics
+        print("\n----- Band Statistics -----")
+        for band_stat in info["band_stats"]:
+            print(f"Band {band_stat['band']}:")
+            print(f"  Min: {band_stat['min']:.2f}")
+            print(f"  Max: {band_stat['max']:.2f}")
+            print(f"  Mean: {band_stat['mean']:.2f}")
+            print(f"  Std Dev: {band_stat['std']:.2f}")
+
+        # Show a preview if requested
+        if show_preview:
+            with rasterio.open(raster_path) as src:
+                # For multi-band images, show RGB composite or first band
+                if src.count >= 3:
+                    # Try to show RGB composite
+                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
+                    plt.figure(figsize=figsize)
+                    plt.imshow(rgb)
+                    plt.title(f"RGB Preview: {raster_path}")
+                else:
+                    # Show first band for single-band images
+                    plt.figure(figsize=figsize)
+                    show(
+                        src.read(1),
+                        cmap="viridis",
+                        title=f"Band 1 Preview: {raster_path}",
+                    )
+                    plt.colorbar(label="Pixel Value")
+                plt.show()
+
+    except Exception as e:
+        print(f"Error reading raster: {str(e)}")
+
+
```
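`print_raster_info` wraps `get_raster_info` with formatted console output and an optional matplotlib preview (an RGB composite when the raster has three or more bands, a single colormapped band otherwise). A sketch with a hypothetical path:

```python
from geoai.utils import print_raster_info

# Prints metadata and band statistics; disable the preview in headless environments
print_raster_info("naip_sample.tif", show_preview=False)
```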
```diff
+def get_raster_info_gdal(raster_path):
+    """Get basic information about a raster dataset using GDAL.
+
+    Args:
+        raster_path (str): Path to the raster file
+
+    Returns:
+        dict: Dictionary containing the basic information about the raster,
+            or None if the file cannot be opened
+    """
+
+    from osgeo import gdal
+
+    # Open the dataset
+    ds = gdal.Open(raster_path)
+    if ds is None:
+        print(f"Error: Could not open {raster_path}")
+        return None
+
+    # Get basic information
+    info = {
+        "driver": ds.GetDriver().ShortName,
+        "width": ds.RasterXSize,
+        "height": ds.RasterYSize,
+        "count": ds.RasterCount,
+        "projection": ds.GetProjection(),
+        "geotransform": ds.GetGeoTransform(),
+    }
+
+    # Calculate resolution
+    gt = ds.GetGeoTransform()
+    if gt:
+        info["resolution"] = (abs(gt[1]), abs(gt[5]))
+        info["origin"] = (gt[0], gt[3])
+
+    # Get band information
+    bands_info = []
+    for i in range(1, ds.RasterCount + 1):
+        band = ds.GetRasterBand(i)
+        stats = band.GetStatistics(True, True)
+        band_info = {
+            "band": i,
+            "datatype": gdal.GetDataTypeName(band.DataType),
+            "min": stats[0],
+            "max": stats[1],
+            "mean": stats[2],
+            "std": stats[3],
+            "nodata": band.GetNoDataValue(),
+        }
+        bands_info.append(band_info)
+
+    info["bands"] = bands_info
+
+    # Close the dataset
+    ds = None
+
+    return info
+
+
```
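`get_raster_info_gdal` is the GDAL counterpart of `get_raster_info`; it imports `osgeo.gdal` lazily, so GDAL is only required when this function is actually called, and it returns None rather than raising when the file cannot be opened. A sketch with a hypothetical path:

```python
from geoai.utils import get_raster_info_gdal

info = get_raster_info_gdal("naip_sample.tif")  # hypothetical file
if info is not None:
    print(info["projection"])
    for band in info["bands"]:
        print(band["band"], band["datatype"], band["min"], band["max"])
```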
```diff
+def get_vector_info(vector_path):
+    """Display basic information about a vector dataset using GeoPandas.
+
+    Args:
+        vector_path (str): Path to the vector file
+
+    Returns:
+        dict: Dictionary containing the basic information about the vector dataset
+    """
+    # Open the vector dataset
+    gdf = (
+        gpd.read_parquet(vector_path)
+        if vector_path.endswith(".parquet")
+        else gpd.read_file(vector_path)
+    )
+
+    # Get basic metadata
+    info = {
+        "file_path": vector_path,
+        "driver": os.path.splitext(vector_path)[1][1:].upper(),  # Format from extension
+        "feature_count": len(gdf),
+        "crs": str(gdf.crs),
+        "geometry_type": str(gdf.geom_type.value_counts().to_dict()),
+        "attribute_count": len(gdf.columns) - 1,  # Subtract the geometry column
+        "attribute_names": list(gdf.columns[gdf.columns != "geometry"]),
+        "bounds": gdf.total_bounds.tolist(),
+    }
+
+    # Add statistics about numeric attributes
+    numeric_columns = gdf.select_dtypes(include=["number"]).columns
+    attribute_stats = {}
+    for col in numeric_columns:
+        if col != "geometry":
+            attribute_stats[col] = {
+                "min": gdf[col].min(),
+                "max": gdf[col].max(),
+                "mean": gdf[col].mean(),
+                "std": gdf[col].std(),
+                "null_count": gdf[col].isna().sum(),
+            }
+
+    info["attribute_stats"] = attribute_stats
+
+    return info
+
+
```
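`get_vector_info` reads GeoParquet via `gpd.read_parquet` when the path ends in `.parquet`, and falls back to `gpd.read_file` for everything else. A sketch with a hypothetical GeoJSON:

```python
from geoai.utils import get_vector_info

info = get_vector_info("buildings.geojson")  # hypothetical file
print(info["feature_count"], info["geometry_type"])
print(info["attribute_stats"].keys())  # one entry per numeric column
```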
```diff
+def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
+    """Print formatted information about a vector dataset and optionally show a preview.
+
+    Args:
+        vector_path (str): Path to the vector file
+        show_preview (bool, optional): Whether to display a visual preview of the vector data.
+            Defaults to True.
+        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
+
+    Returns:
+        dict: Dictionary containing vector information if successful, None otherwise
+    """
+    try:
+        info = get_vector_info(vector_path)
+
+        # Print basic information
+        print(f"===== VECTOR INFORMATION: {vector_path} =====")
+        print(f"Driver: {info['driver']}")
+        print(f"Feature count: {info['feature_count']}")
+        print(f"Geometry types: {info['geometry_type']}")
+        print(f"Coordinate Reference System: {info['crs']}")
+        print(f"Bounds: {info['bounds']}")
+        print(f"Number of attributes: {info['attribute_count']}")
+        print(f"Attribute names: {', '.join(info['attribute_names'])}")
+
+        # Print attribute statistics
+        if info["attribute_stats"]:
+            print("\n----- Attribute Statistics -----")
+            for attr, stats in info["attribute_stats"].items():
+                print(f"Attribute: {attr}")
+                for stat_name, stat_value in stats.items():
+                    print(
+                        f"  {stat_name}: {stat_value:.4f}"
+                        if isinstance(stat_value, float)
+                        else f"  {stat_name}: {stat_value}"
+                    )
+
+        # Show a preview if requested
+        if show_preview:
+            gdf = (
+                gpd.read_parquet(vector_path)
+                if vector_path.endswith(".parquet")
+                else gpd.read_file(vector_path)
+            )
+            fig, ax = plt.subplots(figsize=figsize)
+            gdf.plot(ax=ax, cmap="viridis")
+            ax.set_title(f"Preview: {vector_path}")
+            plt.tight_layout()
+            plt.show()
+
+            # # Show a sample of the attribute table
+            # if not gdf.empty:
+            #     print("\n----- Sample of attribute table (first 5 rows) -----")
+            #     print(gdf.head().to_string())
+
+    except Exception as e:
+        print(f"Error reading vector data: {str(e)}")
+
+
```
```diff
+# Alternative implementation using OGR directly
+def get_vector_info_ogr(vector_path):
+    """Get basic information about a vector dataset using OGR.
+
+    Args:
+        vector_path (str): Path to the vector file
+
+    Returns:
+        dict: Dictionary containing the basic information about the vector dataset,
+            or None if the file cannot be opened
+    """
+    from osgeo import ogr
+
+    # Register all OGR drivers
+    ogr.RegisterAll()
+
+    # Open the dataset
+    ds = ogr.Open(vector_path)
+    if ds is None:
+        print(f"Error: Could not open {vector_path}")
+        return None
+
+    # Basic dataset information
+    info = {
+        "file_path": vector_path,
+        "driver": ds.GetDriver().GetName(),
+        "layer_count": ds.GetLayerCount(),
+        "layers": [],
+    }
+
+    # Extract information for each layer
+    for i in range(ds.GetLayerCount()):
+        layer = ds.GetLayer(i)
+        layer_info = {
+            "name": layer.GetName(),
+            "feature_count": layer.GetFeatureCount(),
+            "geometry_type": ogr.GeometryTypeToName(layer.GetGeomType()),
+            "spatial_ref": (
+                layer.GetSpatialRef().ExportToWkt() if layer.GetSpatialRef() else "None"
+            ),
+            "extent": layer.GetExtent(),
+            "fields": [],
+        }
+
+        # Get field information
+        defn = layer.GetLayerDefn()
+        for j in range(defn.GetFieldCount()):
+            field_defn = defn.GetFieldDefn(j)
+            field_info = {
+                "name": field_defn.GetName(),
+                "type": field_defn.GetTypeName(),
+                "width": field_defn.GetWidth(),
+                "precision": field_defn.GetPrecision(),
+            }
+            layer_info["fields"].append(field_info)
+
+        info["layers"].append(layer_info)
+
+    # Close the dataset
+    ds = None
+
+    return info
+
+
```
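The OGR variant exposes layer-level detail (field definitions, per-layer extents) that the GeoPandas version flattens away, which matters for multi-layer formats such as GeoPackage. A sketch with a hypothetical file:

```python
from geoai.utils import get_vector_info_ogr

info = get_vector_info_ogr("landcover.gpkg")  # hypothetical file
if info is not None:
    for layer in info["layers"]:
        print(layer["name"], layer["feature_count"], layer["geometry_type"])
```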
```diff
+def analyze_vector_attributes(vector_path, attribute_name):
+    """Analyze a specific attribute in a vector dataset and create a histogram.
+
+    Args:
+        vector_path (str): Path to the vector file
+        attribute_name (str): Name of the attribute to analyze
+
+    Returns:
+        dict: Dictionary containing analysis results for the attribute
+    """
+    try:
+        gdf = gpd.read_file(vector_path)
+
+        # Check if attribute exists
+        if attribute_name not in gdf.columns:
+            print(f"Attribute '{attribute_name}' not found in the dataset")
+            return None
+
+        # Get the attribute series
+        attr = gdf[attribute_name]
+
+        # Perform different analyses based on data type
+        if pd.api.types.is_numeric_dtype(attr):
+            # Numeric attribute
+            analysis = {
+                "attribute": attribute_name,
+                "type": "numeric",
+                "count": attr.count(),
+                "null_count": attr.isna().sum(),
+                "min": attr.min(),
+                "max": attr.max(),
+                "mean": attr.mean(),
+                "median": attr.median(),
+                "std": attr.std(),
+                "unique_values": attr.nunique(),
+            }
+
+            # Create histogram
+            plt.figure(figsize=(10, 6))
+            plt.hist(attr.dropna(), bins=20, alpha=0.7, color="blue")
+            plt.title(f"Histogram of {attribute_name}")
+            plt.xlabel(attribute_name)
+            plt.ylabel("Frequency")
+            plt.grid(True, alpha=0.3)
+            plt.show()
+
+        else:
+            # Categorical attribute
+            analysis = {
+                "attribute": attribute_name,
+                "type": "categorical",
+                "count": attr.count(),
+                "null_count": attr.isna().sum(),
+                "unique_values": attr.nunique(),
+                "value_counts": attr.value_counts().to_dict(),
+            }
+
+            # Create bar plot for top categories
+            top_n = min(10, attr.nunique())
+            plt.figure(figsize=(10, 6))
+            attr.value_counts().head(top_n).plot(kind="bar", color="skyblue")
+            plt.title(f"Top {top_n} values for {attribute_name}")
+            plt.xlabel(attribute_name)
+            plt.ylabel("Count")
+            plt.xticks(rotation=45)
+            plt.grid(True, alpha=0.3)
+            plt.tight_layout()
+            plt.show()
+
+        return analysis
+
+    except Exception as e:
+        print(f"Error analyzing attribute: {str(e)}")
+        return None
+
+
```
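`analyze_vector_attributes` branches on dtype: numeric columns get summary statistics plus a histogram, anything else gets value counts plus a bar chart of the top 10 categories. A sketch with hypothetical file and column names:

```python
from geoai.utils import analyze_vector_attributes

# "height" is a hypothetical numeric column; a histogram window will open
result = analyze_vector_attributes("buildings.geojson", "height")
if result and result["type"] == "numeric":
    print(result["mean"], result["std"])
```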
```diff
+def visualize_vector_by_attribute(
+    vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
+):
+    """Create a thematic map visualization of vector data based on an attribute.
+
+    Args:
+        vector_path (str): Path to the vector file
+        attribute_name (str): Name of the attribute to visualize
+        cmap (str, optional): Matplotlib colormap name. Defaults to 'viridis'.
+        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
+
+    Returns:
+        bool: True if visualization was successful, False otherwise
+    """
+    try:
+        # Read the vector data
+        gdf = gpd.read_file(vector_path)
+
+        # Check if attribute exists
+        if attribute_name not in gdf.columns:
+            print(f"Attribute '{attribute_name}' not found in the dataset")
+            return False
+
+        # Create the plot
+        fig, ax = plt.subplots(figsize=figsize)
+
+        # Determine plot type based on data type
+        if pd.api.types.is_numeric_dtype(gdf[attribute_name]):
+            # Continuous data
+            gdf.plot(column=attribute_name, cmap=cmap, legend=True, ax=ax)
+        else:
+            # Categorical data
+            gdf.plot(column=attribute_name, categorical=True, legend=True, ax=ax)
+
+        # Add title and labels
+        ax.set_title(f"{os.path.basename(vector_path)} - {attribute_name}")
+        ax.set_xlabel("Longitude")
+        ax.set_ylabel("Latitude")
+
+        # Add basemap or additional elements if available
+        # Note: Additional options could be added here for more complex maps
+
+        plt.tight_layout()
+        plt.show()
+
+    except Exception as e:
+        print(f"Error visualizing data: {str(e)}")
+
+
```
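`visualize_vector_by_attribute` picks a continuous colormap for numeric columns and a categorical legend otherwise, so only the column name and optional styling need to be supplied. A sketch with hypothetical inputs:

```python
from geoai.utils import visualize_vector_by_attribute

visualize_vector_by_attribute(
    "buildings.geojson",  # hypothetical file
    "height",             # hypothetical numeric column -> continuous colormap
    cmap="plasma",
    figsize=(8, 6),
)
```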
```diff
+def clip_raster_by_bbox(
+    input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
+):
+    """
+    Clip a raster dataset using a bounding box and optionally select specific bands.
+
+    Args:
+        input_raster (str): Path to the input raster file.
+        output_raster (str): Path where the clipped raster will be saved.
+        bbox (tuple): Bounding box coordinates either as:
+            - Geographic coordinates (minx, miny, maxx, maxy) if bbox_type="geo"
+            - Pixel indices (min_row, min_col, max_row, max_col) if bbox_type="pixel"
+        bands (list, optional): List of band indices to keep (1-based indexing).
+            If None, all bands will be kept.
+        bbox_type (str, optional): Type of bounding box coordinates. Either "geo" for
+            geographic coordinates or "pixel" for row/column indices.
+            Default is "geo".
+        bbox_crs (str or dict, optional): CRS of the bbox if different from the raster CRS.
+            Can be provided as EPSG code (e.g., "EPSG:4326") or
+            as a proj4 string. Only applies when bbox_type="geo".
+            If None, assumes bbox is in the same CRS as the raster.
+
+    Returns:
+        str: Path to the clipped output raster.
+
+    Raises:
+        ImportError: If required dependencies are not installed.
+        ValueError: If the bbox is invalid, bands are out of range, or bbox_type is invalid.
+        RuntimeError: If the clipping operation fails.
+
+    Examples:
+        Clip using geographic coordinates in the same CRS as the raster
+        >>> clip_raster_by_bbox('input.tif', 'clipped_geo.tif', (100, 200, 300, 400))
+        'clipped_geo.tif'
+
+        Clip using WGS84 coordinates when the raster is in a different CRS
+        >>> clip_raster_by_bbox('input.tif', 'clipped_wgs84.tif', (-122.5, 37.7, -122.4, 37.8),
+        ...                     bbox_crs="EPSG:4326")
+        'clipped_wgs84.tif'
+
+        Clip using row/column indices
+        >>> clip_raster_by_bbox('input.tif', 'clipped_pixel.tif', (50, 100, 150, 200),
+        ...                     bbox_type="pixel")
+        'clipped_pixel.tif'
+
+        Clip with band selection
+        >>> clip_raster_by_bbox('input.tif', 'clipped_bands.tif', (100, 200, 300, 400),
+        ...                     bands=[1, 3])
+        'clipped_bands.tif'
+    """
+    from rasterio.transform import from_bounds
+    from rasterio.warp import transform_bounds
+
+    # Validate bbox_type
+    if bbox_type not in ["geo", "pixel"]:
+        raise ValueError("bbox_type must be either 'geo' or 'pixel'")
+
+    # Validate bbox
+    if len(bbox) != 4:
+        raise ValueError("bbox must contain exactly 4 values")
+
+    # Open the source raster
+    with rasterio.open(input_raster) as src:
+        # Get the source CRS
+        src_crs = src.crs
+
+        # Handle different bbox types
+        if bbox_type == "geo":
+            minx, miny, maxx, maxy = bbox
+
+            # Validate geographic bbox
+            if minx >= maxx or miny >= maxy:
+                raise ValueError(
+                    "Invalid geographic bbox. Expected (minx, miny, maxx, maxy) where minx < maxx and miny < maxy"
+                )
+
+            # If bbox_crs is provided and different from the source CRS, transform the bbox
+            if bbox_crs is not None and bbox_crs != src_crs:
+                try:
+                    # Transform bbox coordinates from bbox_crs to src_crs
+                    minx, miny, maxx, maxy = transform_bounds(
+                        bbox_crs, src_crs, minx, miny, maxx, maxy
+                    )
+                except Exception as e:
+                    raise ValueError(
+                        f"Failed to transform bbox from {bbox_crs} to {src_crs}: {str(e)}"
+                    )
+
+            # Calculate the pixel window from geographic coordinates
+            window = src.window(minx, miny, maxx, maxy)
+
+            # Use the same bounds for the output transform
+            output_bounds = (minx, miny, maxx, maxy)
+
+        else:  # bbox_type == "pixel"
+            min_row, min_col, max_row, max_col = bbox
+
+            # Validate pixel bbox
+            if min_row >= max_row or min_col >= max_col:
+                raise ValueError(
+                    "Invalid pixel bbox. Expected (min_row, min_col, max_row, max_col) where min_row < max_row and min_col < max_col"
+                )
+
+            if (
+                min_row < 0
+                or min_col < 0
+                or max_row > src.height
+                or max_col > src.width
+            ):
+                raise ValueError(
+                    f"Pixel indices out of bounds. Raster dimensions are {src.height} rows x {src.width} columns"
+                )
+
+            # Create a window from pixel coordinates
+            window = Window(min_col, min_row, max_col - min_col, max_row - min_row)
+
+            # Calculate the geographic bounds for this window
+            window_transform = src.window_transform(window)
+            output_bounds = rasterio.transform.array_bounds(
+                window.height, window.width, window_transform
+            )
+            # Reorder to (minx, miny, maxx, maxy)
+            output_bounds = (
+                output_bounds[0],
+                output_bounds[1],
+                output_bounds[2],
+                output_bounds[3],
+            )
+
+        # Get window dimensions
+        window_width = int(window.width)
+        window_height = int(window.height)
+
+        # Check if the window is valid
+        if window_width <= 0 or window_height <= 0:
+            raise ValueError("Bounding box results in an empty window")
+
+        # Handle band selection
+        if bands is None:
+            # Use all bands
+            bands_to_read = list(range(1, src.count + 1))
+        else:
+            # Validate band indices
+            if not all(1 <= b <= src.count for b in bands):
+                raise ValueError(f"Band indices must be between 1 and {src.count}")
+            bands_to_read = bands
+
+        # Calculate new transform for the clipped raster
+        new_transform = from_bounds(
+            output_bounds[0],
+            output_bounds[1],
+            output_bounds[2],
+            output_bounds[3],
+            window_width,
+            window_height,
+        )
+
+        # Create a metadata dictionary for the output
+        out_meta = src.meta.copy()
+        out_meta.update(
+            {
+                "height": window_height,
+                "width": window_width,
+                "transform": new_transform,
+                "count": len(bands_to_read),
+            }
+        )
+
+        # Read the data for the selected bands
+        data = []
+        for band_idx in bands_to_read:
+            band_data = src.read(band_idx, window=window)
+            data.append(band_data)
+
+        # Stack the bands into a single array
+        if len(data) > 1:
+            clipped_data = np.stack(data)
+        else:
+            clipped_data = data[0][np.newaxis, :, :]
+
+    # Write the output raster
+    with rasterio.open(output_raster, "w", **out_meta) as dst:
+        dst.write(clipped_data)
+
+    return output_raster
+
+
+def raster_to_vector(
+    raster_path,
+    output_path=None,
+    threshold=0,
+    min_area=10,
+    simplify_tolerance=None,
+    class_values=None,
+    attribute_name="class",
+    output_format="geojson",
+    plot_result=False,
+):
+    """
+    Convert a raster label mask to vector polygons.
+
+    Args:
+        raster_path (str): Path to the input raster file (e.g., GeoTIFF).
+        output_path (str): Path to save the output vector file. If None, returns GeoDataFrame without saving.
+        threshold (int/float): Pixel values greater than this threshold will be vectorized.
+        min_area (float): Minimum polygon area in square map units to keep.
+        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
+        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
+        attribute_name (str): Name of the attribute field for the class values.
+        output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
+        plot_result (bool): Whether to plot the resulting polygons overlaid on the raster.
+
+    Returns:
+        geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized polygons.
+    """
+    # Open the raster file
+    with rasterio.open(raster_path) as src:
+        # Read the data
+        data = src.read(1)
+
+        # Get metadata
+        transform = src.transform
+        crs = src.crs
+
+        # Create mask based on threshold and class values
+        if class_values is not None:
+            # Create a mask for each specified class value
+            masks = {val: (data == val) for val in class_values}
+        else:
+            # Create a mask for values above threshold
+            masks = {1: (data > threshold)}
+            class_values = [1]  # Default class
+
+        # Initialize list to store features
+        all_features = []
+
+        # Process each class value
+        for class_val in class_values:
+            mask = masks[class_val]
+
+            # Vectorize the mask
+            for geom, value in features.shapes(
+                mask.astype(np.uint8), mask=mask, transform=transform
+            ):
+                # Convert to shapely geometry
+                geom = shape(geom)
+
+                # Skip small polygons
+                if geom.area < min_area:
+                    continue
+
+                # Simplify geometry if requested
+                if simplify_tolerance is not None:
+                    geom = geom.simplify(simplify_tolerance)
+
+                # Add to features list with class value
+                all_features.append({"geometry": geom, attribute_name: class_val})
+
+        # Create GeoDataFrame
+        if all_features:
+            gdf = gpd.GeoDataFrame(all_features, crs=crs)
+        else:
+            print("Warning: No features were extracted from the raster.")
+            # Return empty GeoDataFrame with correct CRS
+            gdf = gpd.GeoDataFrame([], geometry=[], crs=crs)
+
+        # Save to file if requested
+        if output_path is not None:
+            # Create directory if it doesn't exist
+            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+
+            # Save to file based on format
+            if output_format.lower() == "geojson":
+                gdf.to_file(output_path, driver="GeoJSON")
+            elif output_format.lower() == "shapefile":
+                gdf.to_file(output_path)
+            elif output_format.lower() == "gpkg":
+                gdf.to_file(output_path, driver="GPKG")
+            else:
+                raise ValueError(f"Unsupported output format: {output_format}")
+
+            print(f"Vectorized data saved to {output_path}")
+
+        # Plot result if requested
+        if plot_result:
+            fig, ax = plt.subplots(figsize=(12, 12))
+
+            # Plot raster
+            raster_img = src.read()
+            if raster_img.shape[0] == 1:
+                plt.imshow(raster_img[0], cmap="viridis", alpha=0.7)
+            else:
+                # Use first 3 bands for RGB display
+                rgb = raster_img[:3].transpose(1, 2, 0)
+                # Normalize for display
+                rgb = np.clip(rgb / rgb.max(), 0, 1)
+                plt.imshow(rgb)
+
+            # Plot vector boundaries
+            if not gdf.empty:
+                gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)
+
+            plt.title("Raster with Vectorized Boundaries")
+            plt.axis("off")
+            plt.tight_layout()
+            plt.show()
+
+    return gdf
+
+
```
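`raster_to_vector` turns a label mask into polygons via `rasterio.features.shapes`, filtering by `min_area` and optionally simplifying; with `output_path=None` it simply returns the GeoDataFrame. A sketch for a hypothetical prediction mask where classes 1 and 2 are of interest:

```python
from geoai.utils import raster_to_vector

gdf = raster_to_vector(
    "prediction_mask.tif",   # hypothetical single-band label raster
    output_path=None,        # keep in memory instead of writing a file
    class_values=[1, 2],     # vectorize only these pixel values
    min_area=25,             # drop polygons smaller than 25 square map units
    simplify_tolerance=0.5,
)
print(len(gdf), gdf["class"].unique())
```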
```diff
+def batch_raster_to_vector(
+    input_dir,
+    output_dir,
+    pattern="*.tif",
+    threshold=0,
+    min_area=10,
+    simplify_tolerance=None,
+    class_values=None,
+    attribute_name="class",
+    output_format="geojson",
+    merge_output=False,
+    merge_filename="merged_vectors",
+):
+    """
+    Batch convert multiple raster files to vector polygons.
+
+    Args:
+        input_dir (str): Directory containing input raster files.
+        output_dir (str): Directory to save output vector files.
+        pattern (str): Pattern to match raster files (e.g., '*.tif').
+        threshold (int/float): Pixel values greater than this threshold will be vectorized.
+        min_area (float): Minimum polygon area in square map units to keep.
+        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
+        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
+        attribute_name (str): Name of the attribute field for the class values.
+        output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
+        merge_output (bool): Whether to merge all output vectors into a single file.
+        merge_filename (str): Filename for the merged output (without extension).
+
+    Returns:
+        geopandas.GeoDataFrame or None: If merge_output is True, returns the merged GeoDataFrame.
+    """
+    import glob
+
+    # Create output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Get list of raster files
+    raster_files = glob.glob(os.path.join(input_dir, pattern))
+
+    if not raster_files:
+        print(f"No files matching pattern '{pattern}' found in {input_dir}")
+        return None
+
+    print(f"Found {len(raster_files)} raster files to process")
+
+    # Process each raster file
+    gdfs = []
+    for raster_file in tqdm(raster_files, desc="Processing rasters"):
+        # Get output filename
+        base_name = os.path.splitext(os.path.basename(raster_file))[0]
+        if output_format.lower() == "geojson":
+            out_file = os.path.join(output_dir, f"{base_name}.geojson")
+        elif output_format.lower() == "shapefile":
+            out_file = os.path.join(output_dir, f"{base_name}.shp")
+        elif output_format.lower() == "gpkg":
+            out_file = os.path.join(output_dir, f"{base_name}.gpkg")
+        else:
+            raise ValueError(f"Unsupported output format: {output_format}")
+
+        # Convert raster to vector
+        if merge_output:
+            # Don't save individual files if merging
+            gdf = raster_to_vector(
+                raster_file,
+                output_path=None,
+                threshold=threshold,
+                min_area=min_area,
+                simplify_tolerance=simplify_tolerance,
+                class_values=class_values,
+                attribute_name=attribute_name,
+            )
+
+            # Add filename as attribute
+            if not gdf.empty:
+                gdf["source_file"] = base_name
+                gdfs.append(gdf)
+        else:
+            # Save individual files
+            raster_to_vector(
+                raster_file,
+                output_path=out_file,
+                threshold=threshold,
+                min_area=min_area,
+                simplify_tolerance=simplify_tolerance,
+                class_values=class_values,
+                attribute_name=attribute_name,
+                output_format=output_format,
+            )
+
+    # Merge output if requested
+    if merge_output and gdfs:
+        merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))
+
+        # Set CRS to the CRS of the first GeoDataFrame
+        if merged_gdf.crs is None and gdfs:
+            merged_gdf.crs = gdfs[0].crs
+
+        # Save merged output
+        if output_format.lower() == "geojson":
+            merged_file = os.path.join(output_dir, f"{merge_filename}.geojson")
+            merged_gdf.to_file(merged_file, driver="GeoJSON")
+        elif output_format.lower() == "shapefile":
+            merged_file = os.path.join(output_dir, f"{merge_filename}.shp")
+            merged_gdf.to_file(merged_file)
+        elif output_format.lower() == "gpkg":
+            merged_file = os.path.join(output_dir, f"{merge_filename}.gpkg")
+            merged_gdf.to_file(merged_file, driver="GPKG")
+
+        print(f"Merged vector data saved to {merged_file}")
+        return merged_gdf
+
+    return None
+
+
```
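`batch_raster_to_vector` applies the same conversion over a directory; with `merge_output=True` the per-tile results are concatenated (tagged with a `source_file` column) instead of being written individually. A sketch with hypothetical directories:

```python
from geoai.utils import batch_raster_to_vector

merged = batch_raster_to_vector(
    input_dir="masks/",        # hypothetical folder of *.tif masks
    output_dir="vectors/",
    threshold=0,
    merge_output=True,         # returns one merged GeoDataFrame
    merge_filename="all_buildings",
)
if merged is not None:
    print(merged["source_file"].nunique(), "tiles merged")
```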
```diff
+def vector_to_raster(
+    vector_path,
+    output_path=None,
+    reference_raster=None,
+    attribute_field=None,
+    output_shape=None,
+    transform=None,
+    pixel_size=None,
+    bounds=None,
+    crs=None,
+    all_touched=False,
+    fill_value=0,
+    dtype=np.uint8,
+    nodata=None,
+    plot_result=False,
+):
+    """
+    Convert vector data to a raster.
+
+    Args:
+        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
+        output_path (str): Path to save the output raster file. If None, returns the array without saving.
+        reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
+        attribute_field (str): Field name in the vector data to use for pixel values.
+            If None, all vector features will be burned with value 1.
+        output_shape (tuple): Shape of the output raster as (height, width).
+            Required if reference_raster is not provided.
+        transform (affine.Affine): Affine transformation matrix.
+            Required if reference_raster is not provided.
+        pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
+            Used to calculate transform if transform is not provided.
+        bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
+            Used to calculate transform if transform is not provided.
+        crs (str or CRS): Coordinate reference system of the output raster.
+            Required if reference_raster is not provided.
+        all_touched (bool): If True, all pixels touched by geometries will be burned in.
+            If False, only pixels whose center is within the geometry will be burned in.
+        fill_value (int): Value to fill the raster with before burning in features.
+        dtype (numpy.dtype): Data type of the output raster.
+        nodata (int): No data value for the output raster.
+        plot_result (bool): Whether to plot the resulting raster.
+
+    Returns:
+        numpy.ndarray: The rasterized data array if output_path is None, else None.
+    """
+    # Load vector data
+    if isinstance(vector_path, gpd.GeoDataFrame):
+        gdf = vector_path
+    else:
+        gdf = gpd.read_file(vector_path)
+
+    # Check if vector data is empty
+    if gdf.empty:
+        warnings.warn("The input vector data is empty. Creating an empty raster.")
+
+    # Get CRS from vector data if not provided
+    if crs is None and reference_raster is None:
+        crs = gdf.crs
+
+    # Get transform and output shape from reference raster if provided
+    if reference_raster is not None:
+        with rasterio.open(reference_raster) as src:
+            transform = src.transform
+            output_shape = src.shape
+            crs = src.crs
+            if nodata is None:
+                nodata = src.nodata
+    else:
+        # Check if we have all required parameters
+        if transform is None:
+            if pixel_size is None or bounds is None:
+                raise ValueError(
+                    "Either reference_raster, transform, or both pixel_size and bounds must be provided."
+                )
+
+            # Calculate transform from pixel size and bounds
+            if isinstance(pixel_size, (int, float)):
+                x_res = y_res = float(pixel_size)
+            else:
+                x_res, y_res = pixel_size
+                y_res = abs(y_res) * -1  # Convert to negative for north-up raster
+
+            left, bottom, right, top = bounds
+            transform = rasterio.transform.from_bounds(
+                left,
+                bottom,
+                right,
+                top,
+                int((right - left) / x_res),
+                int((top - bottom) / abs(y_res)),
+            )
+
+        if output_shape is None:
+            # Calculate output shape from bounds and pixel size
+            if bounds is None or pixel_size is None:
+                raise ValueError(
+                    "output_shape must be provided if reference_raster is not provided and "
+                    "cannot be calculated from bounds and pixel_size."
+                )
+
+            if isinstance(pixel_size, (int, float)):
+                x_res = y_res = float(pixel_size)
+            else:
+                x_res, y_res = pixel_size
+
+            left, bottom, right, top = bounds
+            width = int((right - left) / x_res)
+            height = int((top - bottom) / abs(y_res))
+            output_shape = (height, width)
+
+    # Ensure CRS is set
+    if crs is None:
+        raise ValueError(
+            "CRS must be provided either directly, from reference_raster, or from input vector data."
+        )
+
+    # Reproject vector data if its CRS doesn't match the output CRS
+    if gdf.crs != crs:
+        print(f"Reprojecting vector data from {gdf.crs} to {crs}")
+        gdf = gdf.to_crs(crs)
+
+    # Create empty raster filled with fill_value
+    raster_data = np.full(output_shape, fill_value, dtype=dtype)
+
+    # Burn vector features into raster
+    if not gdf.empty:
+        # Prepare shapes for burning
+        if attribute_field is not None and attribute_field in gdf.columns:
+            # Use attribute field for values
+            shapes = [
+                (geom, value) for geom, value in zip(gdf.geometry, gdf[attribute_field])
+            ]
+        else:
+            # Burn with value 1
+            shapes = [(geom, 1) for geom in gdf.geometry]
+
+        # Burn shapes into raster
+        burned = features.rasterize(
+            shapes=shapes,
+            out_shape=output_shape,
+            transform=transform,
+            fill=fill_value,
+            all_touched=all_touched,
+            dtype=dtype,
+        )
+
+        # Update raster data
+        raster_data = burned
+
+    # Save raster if output path is provided
+    if output_path is not None:
+        # Create directory if it doesn't exist
+        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+
+        # Define metadata
+        metadata = {
+            "driver": "GTiff",
+            "height": output_shape[0],
+            "width": output_shape[1],
+            "count": 1,
+            "dtype": raster_data.dtype,
+            "crs": crs,
+            "transform": transform,
+        }
+
+        # Add nodata value if provided
+        if nodata is not None:
+            metadata["nodata"] = nodata
+
+        # Write raster
+        with rasterio.open(output_path, "w", **metadata) as dst:
+            dst.write(raster_data, 1)
+
+        print(f"Rasterized data saved to {output_path}")
+
+    # Plot result if requested
+    if plot_result:
+        fig, ax = plt.subplots(figsize=(10, 10))
+
+        # Plot raster
+        im = ax.imshow(raster_data, cmap="viridis")
+        plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")
+
+        # Plot vector boundaries for reference
+        if output_path is not None:
+            # Get the extent of the raster
+            with rasterio.open(output_path) as src:
+                bounds = src.bounds
+                raster_bbox = box(*bounds)
+        else:
+            # Calculate extent from transform and shape
+            height, width = output_shape
+            left, top = transform * (0, 0)
+            right, bottom = transform * (width, height)
+            raster_bbox = box(left, bottom, right, top)
+
+        # Clip vector to raster extent for clarity in plot
+        if not gdf.empty:
+            gdf_clipped = gpd.clip(gdf, raster_bbox)
+            if not gdf_clipped.empty:
+                gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)
+
+        plt.title("Rasterized Vector Data")
+        plt.tight_layout()
+        plt.show()
+
+    return raster_data
+
+
```
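`vector_to_raster` is the inverse operation; the simplest call borrows the grid (shape, transform, CRS) from a reference raster and burns each feature's `attribute_field` value, or 1 when no field is given. A sketch with hypothetical paths:

```python
from geoai.utils import vector_to_raster

arr = vector_to_raster(
    "buildings.geojson",                 # hypothetical vector file
    output_path="buildings_mask.tif",
    reference_raster="naip_sample.tif",  # grid is copied from this raster
    attribute_field=None,                # burn 1 for every feature
    all_touched=True,
)
print(arr.shape, arr.max())  # the burned array is also returned
```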
```diff
+def batch_vector_to_raster(
+    vector_path,
+    output_dir,
+    attribute_field=None,
+    reference_rasters=None,
+    bounds_list=None,
+    output_filename_pattern="{vector_name}_{index}",
+    pixel_size=1.0,
+    all_touched=False,
+    fill_value=0,
+    dtype=np.uint8,
+    nodata=None,
+):
+    """
+    Batch convert vector data to multiple rasters based on different extents or reference rasters.
+
+    Args:
+        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
+        output_dir (str): Directory to save output raster files.
+        attribute_field (str): Field name in the vector data to use for pixel values.
+        reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
+        bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
+        output_filename_pattern (str): Pattern for output filenames.
+            Can include {vector_name} and {index} placeholders.
+        pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
+        all_touched (bool): If True, all pixels touched by geometries will be burned in.
+        fill_value (int): Value to fill the raster with before burning in features.
+        dtype (numpy.dtype): Data type of the output raster.
+        nodata (int): No data value for the output raster.
+
+    Returns:
+        list: List of paths to the created raster files.
+    """
+    # Create output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Load vector data if it's a path
+    if isinstance(vector_path, str):
+        gdf = gpd.read_file(vector_path)
+        vector_name = os.path.splitext(os.path.basename(vector_path))[0]
+    else:
+        gdf = vector_path
+        vector_name = "vector"
+
+    # Check input parameters
+    if reference_rasters is None and bounds_list is None:
+        raise ValueError("Either reference_rasters or bounds_list must be provided.")
+
+    # Use reference_rasters if provided, otherwise use bounds_list
+    if reference_rasters is not None:
+        sources = reference_rasters
+        is_raster_reference = True
+    else:
+        sources = bounds_list
+        is_raster_reference = False
+
+    # Create output filenames
+    output_files = []
+
+    # Process each source (reference raster or bounds)
+    for i, source in enumerate(tqdm(sources, desc="Processing")):
+        # Generate output filename
+        output_filename = output_filename_pattern.format(
+            vector_name=vector_name, index=i
+        )
+        if not output_filename.endswith(".tif"):
+            output_filename += ".tif"
+        output_path = os.path.join(output_dir, output_filename)
+
+        if is_raster_reference:
+            # Use reference raster
+            vector_to_raster(
+                vector_path=gdf,
+                output_path=output_path,
+                reference_raster=source,
+                attribute_field=attribute_field,
+                all_touched=all_touched,
+                fill_value=fill_value,
+                dtype=dtype,
+                nodata=nodata,
+            )
+        else:
+            # Use bounds
+            vector_to_raster(
+                vector_path=gdf,
+                output_path=output_path,
+                bounds=source,
+                pixel_size=pixel_size,
+                attribute_field=attribute_field,
+                all_touched=all_touched,
+                fill_value=fill_value,
+                dtype=dtype,
+                nodata=nodata,
+            )
+
+        output_files.append(output_path)
+
+    return output_files
+
+
```
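`batch_vector_to_raster` rasterizes one vector layer against many extents, either a list of reference rasters or a list of bounds tuples; filenames come from `output_filename_pattern`. A sketch with hypothetical tiles and a hypothetical `class` column:

```python
from geoai.utils import batch_vector_to_raster

paths = batch_vector_to_raster(
    "buildings.geojson",                             # hypothetical vector file
    output_dir="label_tiles/",
    reference_rasters=["tile_0.tif", "tile_1.tif"],  # hypothetical tiles
    attribute_field="class",
)
print(paths)  # ["label_tiles/buildings_0.tif", "label_tiles/buildings_1.tif"]
```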
|
|
2470
|
+
def export_geotiff_tiles(
    in_raster,
    out_folder,
    in_class_data,
    tile_size=256,
    stride=128,
    class_value_field="class",
    buffer_radius=0,
    max_tiles=None,
    quiet=False,
    all_touched=True,
    create_overview=False,
    skip_empty_tiles=False,
):
"""
|
|
2485
|
+
Export georeferenced GeoTIFF tiles and labels from raster and classification data.
|
|
2486
|
+
|
|
2487
|
+
Args:
|
|
2488
|
+
in_raster (str): Path to input raster image
|
|
2489
|
+
out_folder (str): Path to output folder
|
|
2490
|
+
in_class_data (str): Path to classification data - can be vector file or raster
|
|
2491
|
+
tile_size (int): Size of tiles in pixels (square)
|
|
2492
|
+
stride (int): Step size between tiles
|
|
2493
|
+
class_value_field (str): Field containing class values (for vector data)
|
|
2494
|
+
buffer_radius (float): Buffer to add around features (in units of the CRS)
|
|
2495
|
+
max_tiles (int): Maximum number of tiles to process (None for all)
|
|
2496
|
+
quiet (bool): If True, suppress non-essential output
|
|
2497
|
+
all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
|
|
2498
|
+
create_overview (bool): Whether to create an overview image of all tiles
|
|
2499
|
+
skip_empty_tiles (bool): If True, skip tiles with no features
|
|
2500
|
+
"""
|
|
2501
|
+
    # Create output directories
    os.makedirs(out_folder, exist_ok=True)
    image_dir = os.path.join(out_folder, "images")
    os.makedirs(image_dir, exist_ok=True)
    label_dir = os.path.join(out_folder, "labels")
    os.makedirs(label_dir, exist_ok=True)
    ann_dir = os.path.join(out_folder, "annotations")
    os.makedirs(ann_dir, exist_ok=True)

    # Determine if class data is raster or vector
    is_class_data_raster = False
    if isinstance(in_class_data, str):
        file_ext = Path(in_class_data).suffix.lower()
        # Common raster extensions
        if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
            try:
                with rasterio.open(in_class_data) as src:
                    is_class_data_raster = True
                    if not quiet:
                        print(f"Detected in_class_data as raster: {in_class_data}")
                        print(f"Raster CRS: {src.crs}")
                        print(f"Raster dimensions: {src.width} x {src.height}")
            except Exception:
                is_class_data_raster = False
                if not quiet:
                    print(f"Unable to open {in_class_data} as raster, trying as vector")

    # Open the input raster
    with rasterio.open(in_raster) as src:
        if not quiet:
            print(f"\nRaster info for {in_raster}:")
            print(f" CRS: {src.crs}")
            print(f" Dimensions: {src.width} x {src.height}")
            print(f" Bounds: {src.bounds}")

        # Calculate number of tiles
        num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
        num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
        total_tiles = num_tiles_x * num_tiles_y
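        # Worked example (added comment): for a 1000 x 1000 raster with the
        # defaults tile_size=256 and stride=128, each axis yields
        # ceil((1000 - 256) / 128) + 1 = 7 tiles, i.e. 49 overlapping tiles.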

        if max_tiles is None:
            max_tiles = total_tiles

        # Process classification data
        class_to_id = {}

        if is_class_data_raster:
            # Load raster class data
            with rasterio.open(in_class_data) as class_src:
                # Check if raster CRS matches
                if class_src.crs != src.crs:
                    warnings.warn(
                        f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
                        f"Results may be misaligned."
                    )

                # Get unique values from raster
                # Sample to avoid loading huge rasters
                sample_data = class_src.read(
                    1,
                    out_shape=(
                        1,
                        min(class_src.height, 1000),
                        min(class_src.width, 1000),
                    ),
                )

                unique_classes = np.unique(sample_data)
                unique_classes = unique_classes[
                    unique_classes > 0
                ]  # Remove 0 as it's typically background

                if not quiet:
                    print(
                        f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
                    )

                # Create class mapping
                class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
        else:
            # Load vector class data
            try:
                gdf = gpd.read_file(in_class_data)
                if not quiet:
                    print(f"Loaded {len(gdf)} features from {in_class_data}")
                    print(f"Vector CRS: {gdf.crs}")

                # Always reproject to match raster CRS
                if gdf.crs != src.crs:
                    if not quiet:
                        print(f"Reprojecting features from {gdf.crs} to {src.crs}")
                    gdf = gdf.to_crs(src.crs)

                # Apply buffer if specified
                if buffer_radius > 0:
                    gdf["geometry"] = gdf.buffer(buffer_radius)
                    if not quiet:
                        print(f"Applied buffer of {buffer_radius} units")

                # Check if class_value_field exists
                if class_value_field in gdf.columns:
                    unique_classes = gdf[class_value_field].unique()
                    if not quiet:
                        print(
                            f"Found {len(unique_classes)} unique classes: {unique_classes}"
                        )
                    # Create class mapping
                    class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
                else:
                    if not quiet:
                        print(
                            f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
                        )
                    class_to_id = {1: 1}  # Default mapping
            except Exception as e:
                raise ValueError(f"Error processing vector data: {e}")

        # Create progress bar
        pbar = tqdm(
            total=min(total_tiles, max_tiles),
            desc="Generating tiles",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )

        # Track statistics for summary
        stats = {
            "total_tiles": 0,
            "tiles_with_features": 0,
            "feature_pixels": 0,
            "errors": 0,
            "tile_coordinates": [],  # For overview image
        }

        # Process tiles
        tile_index = 0
        for y in range(num_tiles_y):
            for x in range(num_tiles_x):
                if tile_index >= max_tiles:
                    break

                # Calculate window coordinates
                window_x = x * stride
                window_y = y * stride

                # Adjust for edge cases
                if window_x + tile_size > src.width:
                    window_x = src.width - tile_size
                if window_y + tile_size > src.height:
                    window_y = src.height - tile_size

                # Define window
                window = Window(window_x, window_y, tile_size, tile_size)

                # Get window transform and bounds
                window_transform = src.window_transform(window)

                # Calculate window bounds
                minx = window_transform[2]  # Upper left x
                maxy = window_transform[5]  # Upper left y
                maxx = minx + tile_size * window_transform[0]  # Add width
                miny = maxy + tile_size * window_transform[4]  # Add height

                window_bounds = box(minx, miny, maxx, maxy)

                # Store tile coordinates for overview
                if create_overview:
                    stats["tile_coordinates"].append(
                        {
                            "index": tile_index,
                            "x": window_x,
                            "y": window_y,
                            "bounds": [minx, miny, maxx, maxy],
                            "has_features": False,
                        }
                    )

                # Create label mask
                label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
                has_features = False

                # Process classification data to create labels
                if is_class_data_raster:
                    # For raster class data
                    with rasterio.open(in_class_data) as class_src:
                        # Calculate window in class raster
                        src_bounds = src.bounds
                        class_bounds = class_src.bounds

                        # Check if windows overlap
                        if (
                            src_bounds.left > class_bounds.right
                            or src_bounds.right < class_bounds.left
                            or src_bounds.bottom > class_bounds.top
                            or src_bounds.top < class_bounds.bottom
                        ):
                            warnings.warn(
                                "Class raster and input raster do not overlap."
                            )
                        else:
                            # Get corresponding window in class raster
                            window_class = rasterio.windows.from_bounds(
                                minx, miny, maxx, maxy, class_src.transform
                            )

                            # Read label data
                            try:
                                label_data = class_src.read(
                                    1,
                                    window=window_class,
                                    boundless=True,
                                    out_shape=(tile_size, tile_size),
                                )

                                # Remap class values if needed
                                if class_to_id:
                                    remapped_data = np.zeros_like(label_data)
                                    for orig_val, new_val in class_to_id.items():
                                        remapped_data[label_data == orig_val] = new_val
                                    label_mask = remapped_data
                                else:
                                    label_mask = label_data

                                # Check if we have any features
                                if np.any(label_mask > 0):
                                    has_features = True
stats["feature_pixels"] += np.count_nonzero(
|
|
2727
|
+
label_mask
|
|
2728
|
+
)
|
|
2729
|
+
                            except Exception as e:
                                pbar.write(f"Error reading class raster window: {e}")
                                stats["errors"] += 1
                else:
                    # For vector class data
                    # Find features that intersect with window
                    window_features = gdf[gdf.intersects(window_bounds)]

                    if len(window_features) > 0:
                        for idx, feature in window_features.iterrows():
                            # Get class value
                            if class_value_field in feature:
                                class_val = feature[class_value_field]
                                class_id = class_to_id.get(class_val, 1)
                            else:
                                class_id = 1

                            # Get geometry in window coordinates
                            geom = feature.geometry.intersection(window_bounds)
                            if not geom.is_empty:
                                try:
                                    # Rasterize feature
                                    feature_mask = features.rasterize(
                                        [(geom, class_id)],
                                        out_shape=(tile_size, tile_size),
                                        transform=window_transform,
                                        fill=0,
                                        all_touched=all_touched,
                                    )

                                    # Add to label mask
                                    label_mask = np.maximum(label_mask, feature_mask)

                                    # Check if the feature was actually rasterized
                                    if np.any(feature_mask):
                                        has_features = True
                                        if create_overview and tile_index < len(
                                            stats["tile_coordinates"]
                                        ):
                                            stats["tile_coordinates"][tile_index][
                                                "has_features"
                                            ] = True
                                except Exception as e:
                                    pbar.write(f"Error rasterizing feature {idx}: {e}")
                                    stats["errors"] += 1

                # Skip tile if no features and skip_empty_tiles is True
                if skip_empty_tiles and not has_features:
                    pbar.update(1)
                    tile_index += 1
                    continue

                # Read image data
                image_data = src.read(window=window)

                # Export image as GeoTIFF
                image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")

                # Create profile for image GeoTIFF
                image_profile = src.profile.copy()
                image_profile.update(
                    {
                        "height": tile_size,
                        "width": tile_size,
                        "count": image_data.shape[0],
                        "transform": window_transform,
                    }
                )

                # Save image as GeoTIFF
                try:
                    with rasterio.open(image_path, "w", **image_profile) as dst:
                        dst.write(image_data)
                    stats["total_tiles"] += 1
                except Exception as e:
                    pbar.write(f"ERROR saving image GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create profile for label GeoTIFF
                label_profile = {
                    "driver": "GTiff",
                    "height": tile_size,
                    "width": tile_size,
                    "count": 1,
                    "dtype": "uint8",
                    "crs": src.crs,
                    "transform": window_transform,
                }

                # Export label as GeoTIFF
                label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
                try:
                    with rasterio.open(label_path, "w", **label_profile) as dst:
                        dst.write(label_mask.astype(np.uint8), 1)

                    if has_features:
                        stats["tiles_with_features"] += 1
                        stats["feature_pixels"] += np.count_nonzero(label_mask)
                except Exception as e:
                    pbar.write(f"ERROR saving label GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create XML annotation for object detection if using vector class data
                if (
                    not is_class_data_raster
                    and "gdf" in locals()
                    and len(window_features) > 0
                ):
                    # Create XML annotation
                    root = ET.Element("annotation")
                    ET.SubElement(root, "folder").text = "images"
                    ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"

                    size = ET.SubElement(root, "size")
                    ET.SubElement(size, "width").text = str(tile_size)
                    ET.SubElement(size, "height").text = str(tile_size)
                    ET.SubElement(size, "depth").text = str(image_data.shape[0])

                    # Add georeference information
                    geo = ET.SubElement(root, "georeference")
                    ET.SubElement(geo, "crs").text = str(src.crs)
                    ET.SubElement(geo, "transform").text = str(
                        window_transform
                    ).replace("\n", "")
                    ET.SubElement(geo, "bounds").text = (
                        f"{minx}, {miny}, {maxx}, {maxy}"
                    )

                    # Add objects
                    for idx, feature in window_features.iterrows():
                        # Get feature class
                        if class_value_field in feature:
                            class_val = feature[class_value_field]
                        else:
                            class_val = "object"

                        # Get geometry bounds in pixel coordinates
                        geom = feature.geometry.intersection(window_bounds)
                        if not geom.is_empty:
                            # Get bounds in world coordinates
                            minx_f, miny_f, maxx_f, maxy_f = geom.bounds

                            # Convert to pixel coordinates
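                            # Note (added comment): a rasterio Affine maps
                            # (col, row) -> (x, y); its inverse ~window_transform
                            # maps world coordinates back to fractional pixel
                            # indices. Rows grow downward, so the feature's max
                            # y yields row_min.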
                            col_min, row_min = ~window_transform * (minx_f, maxy_f)
                            col_max, row_max = ~window_transform * (maxx_f, miny_f)

                            # Ensure coordinates are within tile bounds
                            xmin = max(0, min(tile_size, int(col_min)))
                            ymin = max(0, min(tile_size, int(row_min)))
                            xmax = max(0, min(tile_size, int(col_max)))
                            ymax = max(0, min(tile_size, int(row_max)))

                            # Only add if the box has non-zero area
                            if xmax > xmin and ymax > ymin:
                                obj = ET.SubElement(root, "object")
                                ET.SubElement(obj, "name").text = str(class_val)
                                ET.SubElement(obj, "difficult").text = "0"

                                bbox = ET.SubElement(obj, "bndbox")
                                ET.SubElement(bbox, "xmin").text = str(xmin)
                                ET.SubElement(bbox, "ymin").text = str(ymin)
                                ET.SubElement(bbox, "xmax").text = str(xmax)
                                ET.SubElement(bbox, "ymax").text = str(ymax)

                    # Save XML
                    tree = ET.ElementTree(root)
                    xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
                    tree.write(xml_path)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(
                    f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
                )

                tile_index += 1
                if tile_index >= max_tiles:
                    break

            if tile_index >= max_tiles:
                break

        # Close progress bar
        pbar.close()

        # Create overview image if requested
        if create_overview and stats["tile_coordinates"]:
            try:
                create_overview_image(
                    src,
                    stats["tile_coordinates"],
                    os.path.join(out_folder, "overview.png"),
                    tile_size,
                    stride,
                )
            except Exception as e:
                print(f"Failed to create overview image: {e}")

        # Report results
        if not quiet:
            print("\n------- Export Summary -------")
            print(f"Total tiles exported: {stats['total_tiles']}")
            print(
                f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
            )
            if stats["tiles_with_features"] > 0:
                print(
                    f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
                )
            if stats["errors"] > 0:
                print(f"Errors encountered: {stats['errors']}")
            print(f"Output saved to: {out_folder}")

            # Verify georeference in a sample image and label
            if stats["total_tiles"] > 0:
                print("\n------- Georeference Verification -------")
                # Tile filenames are zero-padded (tile_{tile_index:06d}.tif)
                sample_image = os.path.join(image_dir, "tile_000000.tif")
                sample_label = os.path.join(label_dir, "tile_000000.tif")

                if os.path.exists(sample_image):
                    try:
                        with rasterio.open(sample_image) as img:
                            print(f"Image CRS: {img.crs}")
                            print(f"Image transform: {img.transform}")
                            print(
                                f"Image has georeference: {img.crs is not None and img.transform is not None}"
                            )
                            print(
                                f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying image georeference: {e}")

                if os.path.exists(sample_label):
                    try:
                        with rasterio.open(sample_label) as lbl:
                            print(f"Label CRS: {lbl.crs}")
                            print(f"Label transform: {lbl.transform}")
                            print(
                                f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
                            )
                            print(
                                f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying label georeference: {e}")

    # Return statistics dictionary for further processing if needed
    return stats


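# Example (illustrative, not part of the module): a minimal sketch of calling
# export_geotiff_tiles; the input file names are placeholders.
#
#   stats = export_geotiff_tiles(
#       in_raster="scene.tif",
#       out_folder="tiles",
#       in_class_data="buildings.geojson",
#       tile_size=256,
#       stride=128,
#       skip_empty_tiles=True,
#   )
#   # tiles/images and tiles/labels now hold aligned 256x256 GeoTIFF pairs,
#   # and stats["tiles_with_features"] counts tiles with a non-empty mask.

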
def create_overview_image(
    src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
):
    """Create an overview image showing all tiles and their status, with optional GeoJSON export.

    Args:
        src (rasterio.io.DatasetReader): The source raster dataset.
        tile_coordinates (list): A list of dictionaries containing tile information.
        output_path (str): The path where the overview image will be saved.
        tile_size (int): The size of each tile in pixels.
        stride (int): The stride between tiles in pixels. Controls overlap between adjacent tiles.
        geojson_path (str, optional): If provided, exports the tile rectangles as GeoJSON to this path.

    Returns:
        str: Path to the saved overview image.
    """
    # Read a reduced version of the source image
    overview_scale = max(
        1, int(max(src.width, src.height) / 2000)
    )  # Scale to max ~2000px
    overview_width = src.width // overview_scale
    overview_height = src.height // overview_scale

    # Read downsampled image
    overview_data = src.read(
        out_shape=(src.count, overview_height, overview_width),
        resampling=rasterio.enums.Resampling.average,
    )

    # Create RGB image for display
    if overview_data.shape[0] >= 3:
        rgb = np.moveaxis(overview_data[:3], 0, -1)
    else:
        # For single band, create grayscale RGB
        rgb = np.stack([overview_data[0], overview_data[0], overview_data[0]], axis=-1)

    # Normalize for display (cast to float first so integer imagery
    # is not truncated by the 0-1 stretch)
    rgb = rgb.astype(np.float32)
    for i in range(rgb.shape[-1]):
        band = rgb[..., i]
        non_zero = band[band > 0]
        if len(non_zero) > 0:
            p2, p98 = np.percentile(non_zero, (2, 98))
            rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)

    # Create figure
    plt.figure(figsize=(12, 12))
    plt.imshow(rgb)

    # If GeoJSON export is requested, prepare GeoJSON structures
    if geojson_path:
        features = []

    # Draw tile boundaries
    for tile in tile_coordinates:
        # Convert bounds to pixel coordinates in overview
        bounds = tile["bounds"]
        # Calculate scaled pixel coordinates
        x_min = int((tile["x"]) / overview_scale)
        y_min = int((tile["y"]) / overview_scale)
        width = int(tile_size / overview_scale)
        height = int(tile_size / overview_scale)

        # Draw rectangle
        color = "lime" if tile["has_features"] else "red"
        rect = plt.Rectangle(
            (x_min, y_min), width, height, fill=False, edgecolor=color, linewidth=0.5
        )
        plt.gca().add_patch(rect)

        # Add tile number if not too crowded
        if width > 20 and height > 20:
            plt.text(
                x_min + width / 2,
                y_min + height / 2,
                str(tile["index"]),
                color="white",
                ha="center",
                va="center",
                fontsize=8,
            )

        # Add to GeoJSON features if exporting
        if geojson_path:
            # Create a polygon from the bounds (already in geo-coordinates)
            minx, miny, maxx, maxy = bounds
            polygon = box(minx, miny, maxx, maxy)

            # Calculate overlap with neighboring tiles
            overlap = 0
            if stride < tile_size:
                overlap = tile_size - stride

            # Create a GeoJSON feature
            feature = {
                "type": "Feature",
                "geometry": mapping(polygon),
                "properties": {
                    "index": tile["index"],
                    "has_features": tile["has_features"],
                    "bounds_pixel": [
                        tile["x"],
                        tile["y"],
                        tile["x"] + tile_size,
                        tile["y"] + tile_size,
                    ],
                    "tile_size_px": tile_size,
                    "stride_px": stride,
                    "overlap_px": overlap,
                },
            }

            # Add any additional properties from the tile
            for key, value in tile.items():
                if key not in ["x", "y", "index", "has_features", "bounds"]:
                    feature["properties"][key] = value

            features.append(feature)

    plt.title("Tile Overview (Green = Contains Features, Red = Empty)")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"Overview image saved to {output_path}")

    # Export GeoJSON if requested
    if geojson_path:
        geojson_collection = {
            "type": "FeatureCollection",
            "features": features,
            "properties": {
                "crs": (
                    src.crs.to_string()
                    if hasattr(src.crs, "to_string")
                    else str(src.crs)
                ),
                "total_tiles": len(features),
                "source_raster_dimensions": [src.width, src.height],
            },
        }

        # Save to file
        with open(geojson_path, "w") as f:
            json.dump(geojson_collection, f)

        print(f"GeoJSON saved to {geojson_path}")

    return output_path


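# Example (illustrative, not part of the module): create_overview_image takes
# the open source dataset plus the tile records that export_geotiff_tiles
# collects in stats["tile_coordinates"] (populated when create_overview=True);
# file names are placeholders.
#
#   with rasterio.open("scene.tif") as src:
#       create_overview_image(
#           src,
#           stats["tile_coordinates"],
#           "tiles/overview.png",
#           tile_size=256,
#           stride=128,
#       )

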
def export_tiles_to_geojson(
    tile_coordinates, src, output_path, tile_size=None, stride=None
):
    """
    Export tile rectangles directly to GeoJSON without creating an overview image.

    Args:
        tile_coordinates (list): A list of dictionaries containing tile information.
        src (rasterio.io.DatasetReader): The source raster dataset.
        output_path (str): The path where the GeoJSON will be saved.
        tile_size (int, optional): The size of each tile in pixels. Only needed if not in tile_coordinates.
        stride (int, optional): The stride between tiles in pixels. Used to calculate overlaps between tiles.

    Returns:
        str: Path to the saved GeoJSON file.
    """
    features = []

    for tile in tile_coordinates:
        # Get the size from the tile or use the provided parameter
        tile_width = tile.get("width", tile.get("size", tile_size))
        tile_height = tile.get("height", tile.get("size", tile_size))

        if tile_width is None or tile_height is None:
            raise ValueError(
                "Tile size not found in tile data and no tile_size parameter provided"
            )

        # Get bounds from the tile
        if "bounds" in tile:
            # If bounds are already in geo coordinates
            minx, miny, maxx, maxy = tile["bounds"]
        else:
            # Try to calculate bounds from transform if available
            if hasattr(src, "transform"):
                # Convert pixel coordinates to geo coordinates
                window_transform = src.transform
                x, y = tile["x"], tile["y"]
                minx = window_transform[2] + x * window_transform[0]
                maxy = window_transform[5] + y * window_transform[4]
                maxx = minx + tile_width * window_transform[0]
                miny = maxy + tile_height * window_transform[4]
            else:
                raise ValueError(
                    "Cannot determine bounds. Neither 'bounds' in tile nor transform in src."
                )

        # Calculate overlap with neighboring tiles if stride is provided
        overlap = 0
        if stride is not None and stride < tile_width:
            overlap = tile_width - stride

        # Create a polygon from the bounds
        polygon = box(minx, miny, maxx, maxy)

        # Create a GeoJSON feature
        feature = {
            "type": "Feature",
            "geometry": mapping(polygon),
            "properties": {
                "index": tile["index"],
                "has_features": tile.get("has_features", False),
                "tile_width_px": tile_width,
                "tile_height_px": tile_height,
            },
        }

        # Add overlap information if stride is provided
        if stride is not None:
            feature["properties"]["stride_px"] = stride
            feature["properties"]["overlap_px"] = overlap

        # Add additional properties from the tile
        for key, value in tile.items():
            if key not in ["bounds", "geometry"]:
                feature["properties"][key] = value

        features.append(feature)

    # Create the GeoJSON collection
    geojson_collection = {
        "type": "FeatureCollection",
        "features": features,
        "properties": {
            "crs": (
                src.crs.to_string() if hasattr(src.crs, "to_string") else str(src.crs)
            ),
            "total_tiles": len(features),
            "source_raster_dimensions": (
                [src.width, src.height] if hasattr(src, "width") else None
            ),
        },
    }

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(output_path)) or ".", exist_ok=True)

    # Save to file
    with open(output_path, "w") as f:
        json.dump(geojson_collection, f)

    print(f"GeoJSON saved to {output_path}")
    return output_path


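# Example (illustrative, not part of the module): the same tile records can be
# exported straight to GeoJSON without rendering an overview image; file names
# are placeholders.
#
#   with rasterio.open("scene.tif") as src:
#       export_tiles_to_geojson(
#           stats["tile_coordinates"],
#           src,
#           "tiles/tiles.geojson",
#           tile_size=256,
#           stride=128,
#       )

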
def export_training_data(
    in_raster,
    out_folder,
    in_class_data,
    image_chip_format="GEOTIFF",
    tile_size_x=256,
    tile_size_y=256,
    stride_x=None,
    stride_y=None,
    output_nofeature_tiles=True,
    metadata_format="PASCAL_VOC",
    start_index=0,
    class_value_field="class",
    buffer_radius=0,
    in_mask_polygons=None,
    rotation_angle=0,
    reference_system=None,
    blacken_around_feature=False,
    crop_mode="FIXED_SIZE",  # Implemented but not fully used yet
    in_raster2=None,
    in_instance_data=None,
    instance_class_value_field=None,  # Implemented but not fully used yet
    min_polygon_overlap_ratio=0.0,
    all_touched=True,
    save_geotiff=True,
    quiet=False,
):
    """
    Export training data for deep learning using TorchGeo with progress bar.

    Args:
        in_raster (str): Path to input raster image.
        out_folder (str): Output folder path where chips and labels will be saved.
        in_class_data (str): Path to vector file containing class polygons.
        image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
        tile_size_x (int): Width of image chips in pixels.
        tile_size_y (int): Height of image chips in pixels.
        stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
        stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
        output_nofeature_tiles (bool): Whether to export chips without features.
        metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
        start_index (int): Starting index for chip filenames.
        class_value_field (str): Field name in in_class_data containing class values.
        buffer_radius (float): Buffer radius around features (in CRS units).
        in_mask_polygons (str): Path to vector file containing mask polygons.
        rotation_angle (float): Rotation angle in degrees.
        reference_system (str): Reference system code.
        blacken_around_feature (bool): Whether to mask areas outside of features.
        crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
        in_raster2 (str): Path to secondary raster image.
        in_instance_data (str): Path to vector file containing instance polygons.
        instance_class_value_field (str): Field name in in_instance_data for instance classes.
        min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
        all_touched (bool): Whether to use all_touched=True in rasterization.
        save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
        quiet (bool): If True, suppress most output messages.
    """
    # Create output directories
    image_dir = os.path.join(out_folder, "images")
    os.makedirs(image_dir, exist_ok=True)

    label_dir = os.path.join(out_folder, "labels")
    os.makedirs(label_dir, exist_ok=True)

    # Define annotation directories based on metadata format
    if metadata_format == "PASCAL_VOC":
        ann_dir = os.path.join(out_folder, "annotations")
        os.makedirs(ann_dir, exist_ok=True)
    elif metadata_format == "COCO":
        ann_dir = os.path.join(out_folder, "annotations")
        os.makedirs(ann_dir, exist_ok=True)
        # Initialize COCO annotations dictionary
        coco_annotations = {"images": [], "annotations": [], "categories": []}

    # Initialize statistics dictionary
    stats = {
        "total_tiles": 0,
        "tiles_with_features": 0,
        "feature_pixels": 0,
        "errors": 0,
    }

    # Open raster
    with rasterio.open(in_raster) as src:
        if not quiet:
            print(f"\nRaster info for {in_raster}:")
            print(f" CRS: {src.crs}")
            print(f" Dimensions: {src.width} x {src.height}")
            print(f" Bounds: {src.bounds}")

        # Set defaults for stride if not provided
        if stride_x is None:
            stride_x = tile_size_x
        if stride_y is None:
            stride_y = tile_size_y

        # Calculate number of tiles in x and y directions
        num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
        num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
        total_tiles = num_tiles_x * num_tiles_y

        # Read class data
        gdf = gpd.read_file(in_class_data)
        if not quiet:
            print(f"Loaded {len(gdf)} features from {in_class_data}")
            print(f"Available columns: {gdf.columns.tolist()}")
            print(f"GeoJSON CRS: {gdf.crs}")

        # Check if class_value_field exists
        if class_value_field not in gdf.columns:
            if not quiet:
                print(
                    f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
                )
            # Add a default class column
            gdf[class_value_field] = 1
            unique_classes = [1]
        else:
            # Print unique classes for debugging
            unique_classes = gdf[class_value_field].unique()
            if not quiet:
                print(f"Found {len(unique_classes)} unique classes: {unique_classes}")

        # CRITICAL: Always reproject to match raster CRS to ensure proper alignment
        if gdf.crs != src.crs:
            if not quiet:
                print(f"Reprojecting features from {gdf.crs} to {src.crs}")
            gdf = gdf.to_crs(src.crs)
        elif reference_system and gdf.crs != reference_system:
            if not quiet:
                print(
                    f"Reprojecting features to specified reference system {reference_system}"
                )
            gdf = gdf.to_crs(reference_system)

        # Check overlap between raster and vector data
        raster_bounds = box(*src.bounds)
        vector_bounds = box(*gdf.total_bounds)
        if not raster_bounds.intersects(vector_bounds):
            if not quiet:
                print(
                    "WARNING: The vector data doesn't intersect with the raster extent!"
                )
                print(f"Raster bounds: {src.bounds}")
                print(f"Vector bounds: {gdf.total_bounds}")
        else:
            overlap = (
                raster_bounds.intersection(vector_bounds).area / vector_bounds.area
            )
            if not quiet:
                print(f"Overlap between raster and vector: {overlap:.2%}")

        # Apply buffer if specified
        if buffer_radius > 0:
            gdf["geometry"] = gdf.buffer(buffer_radius)

        # Initialize class mapping (ensure all classes are mapped to non-zero values)
        class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}

        # Store category info for COCO format
        if metadata_format == "COCO":
            for cls_val in unique_classes:
                coco_annotations["categories"].append(
                    {
                        "id": class_to_id[cls_val],
                        "name": str(cls_val),
                        "supercategory": "object",
                    }
                )

        # Load mask polygons if provided
        mask_gdf = None
        if in_mask_polygons:
            mask_gdf = gpd.read_file(in_mask_polygons)
            if reference_system:
                mask_gdf = mask_gdf.to_crs(reference_system)
            elif mask_gdf.crs != src.crs:
                mask_gdf = mask_gdf.to_crs(src.crs)

        # Process instance data if provided
        instance_gdf = None
        if in_instance_data:
            instance_gdf = gpd.read_file(in_instance_data)
            if reference_system:
                instance_gdf = instance_gdf.to_crs(reference_system)
            elif instance_gdf.crs != src.crs:
                instance_gdf = instance_gdf.to_crs(src.crs)

        # Load secondary raster if provided
        src2 = None
        if in_raster2:
            src2 = rasterio.open(in_raster2)

        # Set up augmentation if rotation is specified
        augmentation = None
        if rotation_angle != 0:
            # Fixed: Added data_keys parameter to AugmentationSequential
            augmentation = torchgeo.transforms.AugmentationSequential(
                torch.nn.ModuleList([RandomRotation(rotation_angle)]),
                data_keys=["image"],  # Add data_keys parameter
            )

        # Initialize annotation ID for COCO format
        ann_id = 0

        # Create progress bar
        pbar = tqdm(
            total=total_tiles,
            desc="Generating tiles (with features: 0)",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )

        # Generate tiles
        chip_index = start_index
        for y in range(num_tiles_y):
            for x in range(num_tiles_x):
                # Calculate window coordinates
                window_x = x * stride_x
                window_y = y * stride_y

                # Adjust for edge cases
                if window_x + tile_size_x > src.width:
                    window_x = src.width - tile_size_x
                if window_y + tile_size_y > src.height:
                    window_y = src.height - tile_size_y

                # Adjust window based on crop_mode
                if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
                    # Find the nearest feature to the center of this window
                    window_center_x = window_x + tile_size_x // 2
                    window_center_y = window_y + tile_size_y // 2

                    # Convert center to world coordinates
                    center_x, center_y = src.xy(window_center_y, window_center_x)
                    center_point = gpd.points_from_xy([center_x], [center_y])[0]

                    # Find nearest feature
                    distances = gdf.geometry.distance(center_point)
                    nearest_idx = distances.idxmin()
                    # idxmin() returns an index label, so use .loc for the lookup
                    nearest_feature = gdf.loc[nearest_idx]

                    # Get centroid of nearest feature
                    feature_centroid = nearest_feature.geometry.centroid

                    # Convert feature centroid to pixel coordinates
                    feature_row, feature_col = src.index(
                        feature_centroid.x, feature_centroid.y
                    )

                    # Adjust window to center on feature
                    window_x = max(
                        0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
                    )
                    window_y = max(
                        0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
                    )

                # Define window
                window = Window(window_x, window_y, tile_size_x, tile_size_y)

                # Get window transform and bounds in source CRS
                window_transform = src.window_transform(window)

                # Calculate window bounds more explicitly and accurately
                minx = window_transform[2]  # Upper left x
                maxy = window_transform[5]  # Upper left y
                maxx = minx + tile_size_x * window_transform[0]  # Add width
                miny = (
                    maxy + tile_size_y * window_transform[4]
                )  # Add height (note: transform[4] is typically negative)

                window_bounds = box(minx, miny, maxx, maxy)

                # Apply rotation if specified
                if rotation_angle != 0:
                    window_bounds = rotate(
                        window_bounds, rotation_angle, origin="center"
                    )

                # Find features that intersect with window
                window_features = gdf[gdf.intersects(window_bounds)]

                # Process instance data if provided
                window_instances = None
                if instance_gdf is not None and instance_class_value_field is not None:
                    window_instances = instance_gdf[
                        instance_gdf.intersects(window_bounds)
                    ]
                    if len(window_instances) > 0:
                        if not quiet:
                            pbar.write(
                                f"Found {len(window_instances)} instances in tile {chip_index}"
                            )

                # Skip if no features and output_nofeature_tiles is False
                if not output_nofeature_tiles and len(window_features) == 0:
                    pbar.update(1)  # Still update progress bar
                    continue

                # Check polygon overlap ratio if specified
                if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
                    valid_features = []
                    for _, feature in window_features.iterrows():
                        overlap_ratio = (
                            feature.geometry.intersection(window_bounds).area
                            / feature.geometry.area
                        )
                        if overlap_ratio >= min_polygon_overlap_ratio:
                            valid_features.append(feature)

                    if len(valid_features) > 0:
                        window_features = gpd.GeoDataFrame(valid_features)
                    elif not output_nofeature_tiles:
                        pbar.update(1)  # Still update progress bar
                        continue

                # Apply mask if provided
                if mask_gdf is not None:
                    mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
                    if len(mask_features) == 0:
                        pbar.update(1)  # Still update progress bar
                        continue

                # Read image data - keep original for GeoTIFF export
                orig_image_data = src.read(window=window)

                # Create a copy for processing
                image_data = orig_image_data.copy().astype(np.float32)

                # Normalize image data for processing
                for band in range(image_data.shape[0]):
                    band_min, band_max = np.percentile(image_data[band], (1, 99))
                    if band_max > band_min:
                        image_data[band] = np.clip(
                            (image_data[band] - band_min) / (band_max - band_min), 0, 1
                        )

                # Read secondary image data if provided
                if src2:
                    image_data2 = src2.read(window=window)
                    # Stack the two images
                    image_data = np.vstack((image_data, image_data2))

                # Apply blacken_around_feature if needed
                if blacken_around_feature and len(window_features) > 0:
                    mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
                    for _, feature in window_features.iterrows():
                        # Project feature to pixel coordinates
                        feature_pixels = features.rasterize(
                            [(feature.geometry, 1)],
                            out_shape=(tile_size_y, tile_size_x),
                            transform=window_transform,
                        )
                        mask = np.logical_or(mask, feature_pixels.astype(bool))

                    # Apply mask to image
                    for band in range(image_data.shape[0]):
                        temp = image_data[band, :, :]
                        temp[~mask] = 0
                        image_data[band, :, :] = temp

                # Apply rotation if specified
                if augmentation:
                    # Convert to torch tensor for augmentation
                    image_tensor = torch.from_numpy(image_data).unsqueeze(
                        0
                    )  # Add batch dimension
                    # Apply augmentation with proper data format
                    augmented = augmentation({"image": image_tensor})
                    image_data = (
                        augmented["image"].squeeze(0).numpy()
                    )  # Remove batch dimension

                # Create a processed version for regular image formats
                processed_image = (image_data * 255).astype(np.uint8)

                # Create label mask
                label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
                has_features = False

                if len(window_features) > 0:
                    for idx, feature in window_features.iterrows():
                        # Get class value
                        class_val = (
                            feature[class_value_field]
                            if class_value_field in feature
                            else 1
                        )
                        if isinstance(class_val, str):
                            # If class is a string, use its position in the unique classes list
                            class_id = class_to_id.get(class_val, 1)
                        else:
                            # If class is already a number, use it directly
                            class_id = int(class_val) if class_val > 0 else 1

                        # Get the geometry in pixel coordinates
                        geom = feature.geometry.intersection(window_bounds)
                        if not geom.is_empty:
                            try:
                                # Rasterize the feature
                                feature_mask = features.rasterize(
                                    [(geom, class_id)],
                                    out_shape=(tile_size_y, tile_size_x),
                                    transform=window_transform,
                                    fill=0,
                                    all_touched=all_touched,
                                )

                                # Update mask with higher class values taking precedence
                                label_mask = np.maximum(label_mask, feature_mask)

                                # Check if any pixels were added
                                if np.any(feature_mask):
                                    has_features = True
                            except Exception as e:
                                if not quiet:
                                    pbar.write(f"Error rasterizing feature {idx}: {e}")
                                stats["errors"] += 1

                # Save as GeoTIFF if requested
                if save_geotiff or image_chip_format.upper() in [
                    "TIFF",
                    "TIF",
                    "GEOTIFF",
                ]:
                    # Standardize extension to .tif for GeoTIFF files
                    image_filename = f"tile_{chip_index:06d}.tif"
                    image_path = os.path.join(image_dir, image_filename)

                    # Create profile for the GeoTIFF
                    profile = src.profile.copy()
                    profile.update(
                        {
                            "height": tile_size_y,
                            "width": tile_size_x,
                            "count": orig_image_data.shape[0],
                            "transform": window_transform,
                        }
                    )

                    # Save the GeoTIFF with original data
                    try:
                        with rasterio.open(image_path, "w", **profile) as dst:
                            dst.write(orig_image_data)
                        stats["total_tiles"] += 1
                    except Exception as e:
                        if not quiet:
                            pbar.write(
                                f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
                            )
                        stats["errors"] += 1
                else:
                    # For non-GeoTIFF formats, use PIL to save the image
                    image_filename = (
                        f"tile_{chip_index:06d}.{image_chip_format.lower()}"
                    )
                    image_path = os.path.join(image_dir, image_filename)

                    # Create PIL image for saving
                    if processed_image.shape[0] == 1:
                        img = Image.fromarray(processed_image[0])
                    elif processed_image.shape[0] == 3:
                        # For RGB, need to transpose and make sure it's the right data type
                        rgb_data = np.transpose(processed_image, (1, 2, 0))
                        img = Image.fromarray(rgb_data)
                    else:
                        # For multiband images, save only RGB or first three bands
                        rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
                        img = Image.fromarray(rgb_data)

                    # Save image
                    try:
                        img.save(image_path)
                        stats["total_tiles"] += 1
                    except Exception as e:
                        if not quiet:
                            pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
                        stats["errors"] += 1

                # Save label as GeoTIFF
                label_filename = f"tile_{chip_index:06d}.tif"
                label_path = os.path.join(label_dir, label_filename)

                # Create profile for label GeoTIFF
                label_profile = {
                    "driver": "GTiff",
                    "height": tile_size_y,
                    "width": tile_size_x,
                    "count": 1,
                    "dtype": "uint8",
                    "crs": src.crs,
                    "transform": window_transform,
                }

                # Save label GeoTIFF
                try:
                    with rasterio.open(label_path, "w", **label_profile) as dst:
                        dst.write(label_mask, 1)

                    if has_features:
                        pixel_count = np.count_nonzero(label_mask)
                        stats["tiles_with_features"] += 1
                        stats["feature_pixels"] += pixel_count
                except Exception as e:
                    if not quiet:
                        pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
                    stats["errors"] += 1

                # Also save a PNG version for easy visualization if requested
                if metadata_format == "PASCAL_VOC":
                    try:
                        # Ensure correct data type for PIL
                        png_label = label_mask.astype(np.uint8)
                        label_img = Image.fromarray(png_label)
                        label_png_path = os.path.join(
                            label_dir, f"tile_{chip_index:06d}.png"
                        )
                        label_img.save(label_png_path)
                    except Exception as e:
                        if not quiet:
                            pbar.write(
                                f"ERROR saving PNG label for tile {chip_index}: {e}"
                            )
                            pbar.write(
                                f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
                            )
                            # Try again with explicit conversion
                            try:
                                # Alternative approach for problematic arrays
                                png_data = np.zeros(
                                    (tile_size_y, tile_size_x), dtype=np.uint8
                                )
                                np.copyto(png_data, label_mask, casting="unsafe")
                                label_img = Image.fromarray(png_data)
                                label_img.save(label_png_path)
                                pbar.write(
f" Succeeded using alternative conversion method"
|
|
3773
|
+
                                )
                            except Exception as e2:
                                pbar.write(f" Second attempt also failed: {e2}")
                                stats["errors"] += 1

                # Generate annotations
                if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
                    # Create XML annotation
                    root = ET.Element("annotation")
                    ET.SubElement(root, "folder").text = "images"
                    ET.SubElement(root, "filename").text = image_filename

                    size = ET.SubElement(root, "size")
                    ET.SubElement(size, "width").text = str(tile_size_x)
                    ET.SubElement(size, "height").text = str(tile_size_y)
                    ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))

                    # Add georeference information
                    geo = ET.SubElement(root, "georeference")
                    ET.SubElement(geo, "crs").text = str(src.crs)
                    ET.SubElement(geo, "transform").text = str(
                        window_transform
                    ).replace("\n", "")
                    ET.SubElement(geo, "bounds").text = (
                        f"{minx}, {miny}, {maxx}, {maxy}"
                    )

                    for _, feature in window_features.iterrows():
                        # Convert feature geometry to pixel coordinates
                        feature_bounds = feature.geometry.intersection(window_bounds)
                        if feature_bounds.is_empty:
                            continue

                        # Get pixel coordinates of bounds
                        minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds

                        # Convert to pixel coordinates
                        col_min, row_min = ~window_transform * (minx_f, maxy_f)
                        col_max, row_max = ~window_transform * (maxx_f, miny_f)

                        # Ensure coordinates are within bounds
                        xmin = max(0, min(tile_size_x, int(col_min)))
                        ymin = max(0, min(tile_size_y, int(row_min)))
                        xmax = max(0, min(tile_size_x, int(col_max)))
                        ymax = max(0, min(tile_size_y, int(row_max)))

                        # Skip if box is too small
                        if xmax - xmin < 1 or ymax - ymin < 1:
                            continue

                        obj = ET.SubElement(root, "object")
                        ET.SubElement(obj, "name").text = str(
                            feature[class_value_field]
                        )
                        ET.SubElement(obj, "difficult").text = "0"

                        bbox = ET.SubElement(obj, "bndbox")
                        ET.SubElement(bbox, "xmin").text = str(xmin)
                        ET.SubElement(bbox, "ymin").text = str(ymin)
                        ET.SubElement(bbox, "xmax").text = str(xmax)
                        ET.SubElement(bbox, "ymax").text = str(ymax)

                    # Save XML
                    try:
                        tree = ET.ElementTree(root)
                        xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
                        tree.write(xml_path)
                    except Exception as e:
                        if not quiet:
                            pbar.write(
                                f"ERROR saving XML annotation for tile {chip_index}: {e}"
                            )
                        stats["errors"] += 1

                elif metadata_format == "COCO" and len(window_features) > 0:
                    # Add image info
                    image_id = chip_index
                    coco_annotations["images"].append(
                        {
                            "id": image_id,
                            "file_name": image_filename,
                            "width": tile_size_x,
                            "height": tile_size_y,
                            "crs": str(src.crs),
                            "transform": str(window_transform),
                        }
                    )

                    # Add annotations for each feature
                    for _, feature in window_features.iterrows():
                        feature_bounds = feature.geometry.intersection(window_bounds)
                        if feature_bounds.is_empty:
                            continue

                        # Get pixel coordinates of bounds
                        minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds

                        # Convert to pixel coordinates
                        col_min, row_min = ~window_transform * (minx_f, maxy_f)
                        col_max, row_max = ~window_transform * (maxx_f, miny_f)

                        # Ensure coordinates are within bounds
                        xmin = max(0, min(tile_size_x, int(col_min)))
                        ymin = max(0, min(tile_size_y, int(row_min)))
                        xmax = max(0, min(tile_size_x, int(col_max)))
                        ymax = max(0, min(tile_size_y, int(row_max)))

                        # Skip if box is too small
                        if xmax - xmin < 1 or ymax - ymin < 1:
                            continue

                        width = xmax - xmin
                        height = ymax - ymin

                        # Add annotation
                        ann_id += 1
                        category_id = class_to_id[feature[class_value_field]]

                        coco_annotations["annotations"].append(
                            {
                                "id": ann_id,
                                "image_id": image_id,
                                "category_id": category_id,
                                "bbox": [xmin, ymin, width, height],
                                "area": width * height,
                                "iscrowd": 0,
                            }
                        )

                # Update progress bar
                pbar.update(1)
                pbar.set_description(
                    f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
                )

                chip_index += 1

        # Close progress bar
        pbar.close()

        # Save COCO annotations if applicable
        if metadata_format == "COCO":
            try:
                with open(os.path.join(ann_dir, "instances.json"), "w") as f:
                    json.dump(coco_annotations, f)
            except Exception as e:
                if not quiet:
                    print(f"ERROR saving COCO annotations: {e}")
                stats["errors"] += 1

        # Close secondary raster if opened
        if src2:
            src2.close()

        # Print summary
        if not quiet:
            print("\n------- Export Summary -------")
            print(f"Total tiles exported: {stats['total_tiles']}")
            print(
                f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
            )
            if stats["tiles_with_features"] > 0:
                print(
                    f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
                )
            if stats["errors"] > 0:
                print(f"Errors encountered: {stats['errors']}")
            print(f"Output saved to: {out_folder}")

            # Verify georeference in a sample image and label
            if stats["total_tiles"] > 0:
                print("\n------- Georeference Verification -------")
sample_image = os.path.join(image_dir, f"tile_{start_index}.tif")
|
|
3946
|
+
sample_label = os.path.join(label_dir, f"tile_{start_index}.tif")
|
|
3947
|
+
|
|
3948
|
+
if os.path.exists(sample_image):
|
|
3949
|
+
try:
|
|
3950
|
+
with rasterio.open(sample_image) as img:
|
|
3951
|
+
print(f"Image CRS: {img.crs}")
|
|
3952
|
+
print(f"Image transform: {img.transform}")
|
|
3953
|
+
print(
|
|
3954
|
+
f"Image has georeference: {img.crs is not None and img.transform is not None}"
|
|
3955
|
+
)
|
|
3956
|
+
print(
|
|
3957
|
+
f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
|
|
3958
|
+
)
|
|
3959
|
+
except Exception as e:
|
|
3960
|
+
print(f"Error verifying image georeference: {e}")
|
|
3961
|
+
|
|
3962
|
+
if os.path.exists(sample_label):
|
|
3963
|
+
try:
|
|
3964
|
+
with rasterio.open(sample_label) as lbl:
|
|
3965
|
+
print(f"Label CRS: {lbl.crs}")
|
|
3966
|
+
print(f"Label transform: {lbl.transform}")
|
|
3967
|
+
print(
|
|
3968
|
+
f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
|
|
3969
|
+
)
|
|
3970
|
+
print(
|
|
3971
|
+
f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
|
|
3972
|
+
)
|
|
3973
|
+
except Exception as e:
|
|
3974
|
+
print(f"Error verifying label georeference: {e}")
|
|
3975
|
+
|
|
3976
|
+
# Return statistics
|
|
3977
|
+
return stats, out_folder
|
|
3978
|
+
|
|
3979
|
+
|
|
3980
|
+
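The bounding boxes above come from inverting the tile's affine transform (`~window_transform`) to map world coordinates back to pixel indices. A minimal standalone sketch of that mapping, using a hypothetical 1 m north-up transform rather than anything from this package:

from affine import Affine

# Hypothetical transform: 1 m pixels, origin at (500000, 4200000), north-up.
transform = Affine(1.0, 0.0, 500000.0, 0.0, -1.0, 4200000.0)

x, y = transform * (10, 20)      # forward: (col, row) -> world coordinates
col, row = ~transform * (x, y)   # inverse: world coordinates -> (col, row)
print(x, y, col, row)            # 500010.0 4199980.0 10.0 20.0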
+def masks_to_vector(
+    mask_path,
+    output_path=None,
+    simplify_tolerance=1.0,
+    mask_threshold=0.5,
+    min_object_area=100,
+    max_object_area=None,
+    nms_iou_threshold=0.5,
+):
+    """
+    Convert a building mask GeoTIFF to vector polygons and save as a vector dataset.
+
+    Args:
+        mask_path: Path to the building masks GeoTIFF
+        output_path: Path to save the output GeoJSON (default: None, no file is written)
+        simplify_tolerance: Tolerance for polygon simplification (default: 1.0)
+        mask_threshold: Threshold for mask binarization (default: 0.5)
+        min_object_area: Minimum area in pixels to keep a building (default: 100)
+        max_object_area: Maximum area in pixels to keep a building (default: None)
+        nms_iou_threshold: IoU threshold for non-maximum suppression (default: 0.5)
+
+    Returns:
+        GeoDataFrame with building footprints
+    """
+    # Set default output path if not provided
+    # if output_path is None:
+    #     output_path = os.path.splitext(mask_path)[0] + ".geojson"
+
+    print("Converting mask to GeoJSON with parameters:")
+    print(f"- Mask threshold: {mask_threshold}")
+    print(f"- Min building area: {min_object_area}")
+    print(f"- Simplify tolerance: {simplify_tolerance}")
+    print(f"- NMS IoU threshold: {nms_iou_threshold}")
+
+    # Open the mask raster
+    with rasterio.open(mask_path) as src:
+        # Read the mask data
+        mask_data = src.read(1)
+        transform = src.transform
+        crs = src.crs
+
+        # Print mask statistics
+        print(f"Mask dimensions: {mask_data.shape}")
+        print(f"Mask value range: {mask_data.min()} to {mask_data.max()}")
+
+    # Prepare for connected component analysis
+    # Binarize the mask based on threshold
+    binary_mask = (mask_data > (mask_threshold * 255)).astype(np.uint8)
+
+    # Apply morphological operations for better results (optional)
+    kernel = np.ones((3, 3), np.uint8)
+    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
+
+    # Find connected components
+    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+        binary_mask, connectivity=8
+    )
+
+    print(f"Found {num_labels-1} potential buildings")  # Subtract 1 for background
+
+    # Create list to store polygons and confidence values
+    all_polygons = []
+    all_confidences = []
+
+    # Process each component (skip the first one which is background)
+    for i in tqdm(range(1, num_labels)):
+        # Extract this building
+        area = stats[i, cv2.CC_STAT_AREA]
+
+        # Skip if too small
+        if area < min_object_area:
+            continue
+
+        # Skip if too large
+        if max_object_area is not None and area > max_object_area:
+            continue
+
+        # Create a mask for this building
+        building_mask = (labels == i).astype(np.uint8)
+
+        # Find contours
+        contours, _ = cv2.findContours(
+            building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+        )
+
+        # Process each contour
+        for contour in contours:
+            # Skip if too few points
+            if contour.shape[0] < 3:
+                continue
+
+            # Simplify contour if it has many points
+            if contour.shape[0] > 50 and simplify_tolerance > 0:
+                epsilon = simplify_tolerance * cv2.arcLength(contour, True)
+                contour = cv2.approxPolyDP(contour, epsilon, True)
+
+            # Convert to list of (x, y) coordinates
+            polygon_points = contour.reshape(-1, 2)
+
+            # Convert pixel coordinates to geographic coordinates
+            geo_points = []
+            for x, y in polygon_points:
+                gx, gy = transform * (x, y)
+                geo_points.append((gx, gy))
+
+            # Create Shapely polygon
+            if len(geo_points) >= 3:
+                try:
+                    shapely_poly = Polygon(geo_points)
+                    if shapely_poly.is_valid and shapely_poly.area > 0:
+                        all_polygons.append(shapely_poly)
+
+                        # Calculate "confidence" as normalized size
+                        # This is a proxy since we don't have model confidence scores
+                        normalized_size = min(1.0, area / 1000)  # Cap at 1.0
+                        all_confidences.append(normalized_size)
+                except Exception as e:
+                    print(f"Error creating polygon: {e}")
+
+    print(f"Created {len(all_polygons)} valid polygons")
+
+    # Create GeoDataFrame
+    if not all_polygons:
+        print("No valid polygons found")
+        return None
+
+    gdf = gpd.GeoDataFrame(
+        {
+            "geometry": all_polygons,
+            "confidence": all_confidences,
+            "class": 1,  # Building class
+        },
+        crs=crs,
+    )
+
+    def filter_overlapping_polygons(gdf, **kwargs):
+        """
+        Filter overlapping polygons using non-maximum suppression.
+
+        Args:
+            gdf: GeoDataFrame with polygons
+            **kwargs: Optional parameters:
+                nms_iou_threshold: IoU threshold for filtering
+
+        Returns:
+            Filtered GeoDataFrame
+        """
+        if len(gdf) <= 1:
+            return gdf
+
+        # Get parameters from kwargs or fall back to the enclosing function's arguments
+        iou_threshold = kwargs.get("nms_iou_threshold", nms_iou_threshold)
+
+        # Sort by confidence
+        gdf = gdf.sort_values("confidence", ascending=False)
+
+        # Fix any invalid geometries
+        gdf["geometry"] = gdf["geometry"].apply(
+            lambda geom: geom.buffer(0) if not geom.is_valid else geom
+        )
+
+        keep_indices = []
+        polygons = gdf.geometry.values
+
+        for i in range(len(polygons)):
+            if i in keep_indices:
+                continue
+
+            keep = True
+            for j in keep_indices:
+                # Skip invalid geometries
+                if not polygons[i].is_valid or not polygons[j].is_valid:
+                    continue
+
+                # Calculate IoU
+                try:
+                    intersection = polygons[i].intersection(polygons[j]).area
+                    union = polygons[i].area + polygons[j].area - intersection
+                    iou = intersection / union if union > 0 else 0
+
+                    if iou > iou_threshold:
+                        keep = False
+                        break
+                except Exception:
+                    # Skip on topology exceptions
+                    continue
+
+            if keep:
+                keep_indices.append(i)
+
+        return gdf.iloc[keep_indices]
+
+    # Apply non-maximum suppression to remove overlapping polygons
+    gdf = filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
+
+    print(f"Final building count after filtering: {len(gdf)}")
+
+    # Save to file
+    if output_path is not None:
+        gdf.to_file(output_path)
+        print(f"Saved {len(gdf)} building footprints to {output_path}")
+
+    return gdf
+
+
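A hypothetical usage sketch for the new masks_to_vector helper; the file names below are placeholders, not from the diff:

from geoai.utils import masks_to_vector

gdf = masks_to_vector(
    "buildings_mask.tif",            # single-band mask GeoTIFF (placeholder)
    output_path="buildings.geojson",
    min_object_area=50,              # drop components smaller than 50 px
    nms_iou_threshold=0.5,           # suppress polygons overlapping > 50% IoU
)
if gdf is not None:
    print(len(gdf), "footprints")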
+def read_vector(source, layer=None, **kwargs):
+    """Reads vector data from various formats including GeoParquet.
+
+    This function dynamically determines the file type based on extension
+    and reads it into a GeoDataFrame. It supports both local files and HTTP/HTTPS URLs.
+
+    Args:
+        source: String path to the vector file or URL.
+        layer: String or integer specifying which layer to read from multi-layer
+            files (only applicable for formats like GPKG, GeoJSON, etc.).
+            Defaults to None.
+        **kwargs: Additional keyword arguments to pass to the underlying reader.
+
+    Returns:
+        geopandas.GeoDataFrame: A GeoDataFrame containing the vector data.
+
+    Raises:
+        ValueError: If the file format is not supported or source cannot be accessed.
+
+    Examples:
+        Read a local shapefile
+        >>> gdf = read_vector("path/to/data.shp")
+        >>>
+        Read a GeoParquet file from URL
+        >>> gdf = read_vector("https://example.com/data.parquet")
+        >>>
+        Read a specific layer from a GeoPackage
+        >>> gdf = read_vector("path/to/data.gpkg", layer="layer_name")
+    """
+    import fiona
+    import urllib.parse
+
+    # Determine if source is a URL or local file
+    parsed_url = urllib.parse.urlparse(source)
+    is_url = parsed_url.scheme in ["http", "https"]
+
+    # If it's a local file, check if it exists
+    if not is_url and not os.path.exists(source):
+        raise ValueError(f"File does not exist: {source}")
+
+    # Get file extension
+    _, ext = os.path.splitext(source)
+    ext = ext.lower()
+
+    # Handle GeoParquet files
+    if ext in [".parquet", ".pq", ".geoparquet"]:
+        return gpd.read_parquet(source, **kwargs)
+
+    # Handle common vector formats
+    if ext in [".shp", ".geojson", ".json", ".gpkg", ".gml", ".kml", ".gpx"]:
+        # For formats that might have multiple layers
+        if ext in [".gpkg", ".gml"] and layer is not None:
+            return gpd.read_file(source, layer=layer, **kwargs)
+        return gpd.read_file(source, **kwargs)
+
+    # Try to use fiona to identify valid layers for formats that might have them
+    # Only attempt this for local files as fiona.listlayers might not work with URLs
+    if layer is None and ext in [".gpkg", ".gml"] and not is_url:
+        try:
+            layers = fiona.listlayers(source)
+            if layers:
+                return gpd.read_file(source, layer=layers[0], **kwargs)
+        except Exception:
+            # If listing layers fails, we'll fall through to the generic read attempt
+            pass
+
+    # For other formats or when layer listing fails, attempt to read using GeoPandas
+    try:
+        return gpd.read_file(source, **kwargs)
+    except Exception as e:
+        raise ValueError(f"Could not read from source '{source}': {str(e)}")
+
+
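A usage sketch for read_vector, mirroring the docstring examples; the paths and layer name are placeholders:

from geoai.utils import read_vector

roads = read_vector("roads.gpkg", layer="primary")        # one GeoPackage layer
tiles = read_vector("https://example.com/data.parquet")   # GeoParquet over HTTPS
print(roads.crs, len(tiles))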
+def read_raster(source, band=None, masked=True, **kwargs):
+    """Reads raster data from various formats using rioxarray.
+
+    This function reads raster data from local files or URLs into a rioxarray
+    data structure with preserved geospatial metadata.
+
+    Args:
+        source: String path to the raster file or URL.
+        band: Integer or list of integers specifying which band(s) to read.
+            Defaults to None (all bands).
+        masked: Boolean indicating whether to mask nodata values.
+            Defaults to True.
+        **kwargs: Additional keyword arguments to pass to rioxarray.open_rasterio.
+
+    Returns:
+        xarray.DataArray: A DataArray containing the raster data with geospatial
+            metadata preserved.
+
+    Raises:
+        ValueError: If the file format is not supported or source cannot be accessed.
+
+    Examples:
+        Read a local GeoTIFF
+        >>> raster = read_raster("path/to/data.tif")
+        >>>
+        Read only band 1 from a remote GeoTIFF
+        >>> raster = read_raster("https://example.com/data.tif", band=1)
+        >>>
+        Read a raster without masking nodata values
+        >>> raster = read_raster("path/to/data.tif", masked=False)
+    """
+    import urllib.parse
+    from rasterio.errors import RasterioIOError
+
+    # Determine if source is a URL or local file
+    parsed_url = urllib.parse.urlparse(source)
+    is_url = parsed_url.scheme in ["http", "https"]
+
+    # If it's a local file, check if it exists
+    if not is_url and not os.path.exists(source):
+        raise ValueError(f"Raster file does not exist: {source}")
+
+    try:
+        # Open the raster with rioxarray
+        raster = rxr.open_rasterio(source, masked=masked, **kwargs)
+
+        # Handle band selection if specified
+        if band is not None:
+            if isinstance(band, (list, tuple)):
+                # Convert from 1-based indexing to 0-based indexing
+                band_indices = [b - 1 for b in band]
+                raster = raster.isel(band=band_indices)
+            else:
+                # Single band selection (convert from 1-based to 0-based indexing)
+                raster = raster.isel(band=band - 1)
+
+        return raster
+
+    except RasterioIOError as e:
+        raise ValueError(f"Could not read raster from source '{source}': {str(e)}")
+    except Exception as e:
+        raise ValueError(f"Error reading raster data: {str(e)}")
+
+
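A usage sketch for read_raster; the path is a placeholder. Note that `band` is 1-based, matching the docstring:

from geoai.utils import read_raster

red = read_raster("scene.tif", band=3)           # one band -> 2D DataArray
rgb = read_raster("scene.tif", band=[3, 2, 1])   # band subset, reordered
print(red.rio.crs, rgb.shape)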
+def temp_file_path(ext):
+    """Returns a temporary file path.
+
+    Args:
+        ext (str): The file extension.
+
+    Returns:
+        str: The temporary file path.
+    """
+    import tempfile
+    import uuid
+
+    if not ext.startswith("."):
+        ext = "." + ext
+    file_id = str(uuid.uuid4())
+    file_path = os.path.join(tempfile.gettempdir(), f"{file_id}{ext}")
+
+    return file_path
+
+
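A quick sketch of temp_file_path; it only builds a unique path, nothing is created on disk:

from geoai.utils import temp_file_path

path = temp_file_path("tif")   # the leading dot is added automatically
print(path)                    # e.g. /tmp/<uuid4>.tif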
+def region_groups(
+    image: Union[str, "xr.DataArray", np.ndarray],
+    connectivity: int = 1,
+    min_size: int = 10,
+    max_size: Optional[int] = None,
+    threshold: Optional[int] = None,
+    properties: Optional[List[str]] = None,
+    intensity_image: Optional[Union[str, "xr.DataArray", np.ndarray]] = None,
+    out_csv: Optional[str] = None,
+    out_vector: Optional[str] = None,
+    out_image: Optional[str] = None,
+    **kwargs: Any,
+) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]:
+    """
+    Segment regions in an image and filter them based on size.
+
+    Args:
+        image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file
+            path, xarray DataArray, or numpy array.
+        connectivity (int, optional): Connectivity for labeling. Defaults to 1
+            for 4-connectivity. Use 2 for 8-connectivity.
+        min_size (int, optional): Minimum size of regions to keep. Defaults to 10.
+        max_size (Optional[int], optional): Maximum size of regions to keep.
+            Defaults to None.
+        threshold (Optional[int], optional): Threshold for filling holes.
+            Defaults to None, which is equal to min_size.
+        properties (Optional[List[str]], optional): List of properties to measure.
+            See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
+            Defaults to None.
+        intensity_image (Optional[Union[str, xr.DataArray, np.ndarray]], optional):
+            Intensity image to measure properties. Defaults to None.
+        out_csv (Optional[str], optional): Path to save the properties as a CSV file.
+            Defaults to None.
+        out_vector (Optional[str], optional): Path to save the vector file.
+            Defaults to None.
+        out_image (Optional[str], optional): Path to save the output image.
+            Defaults to None.
+
+    Returns:
+        Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
+    """
+    import scipy.ndimage as ndi
+    from skimage import measure
+
+    if isinstance(image, str):
+        ds = rxr.open_rasterio(image)
+        da = ds.sel(band=1)
+        array = da.values.squeeze()
+    elif isinstance(image, xr.DataArray):
+        da = image
+        array = image.values.squeeze()
+    elif isinstance(image, np.ndarray):
+        array = image
+    else:
+        raise ValueError(
+            "The input image must be a file path, xarray DataArray, or numpy array."
+        )
+
+    if threshold is None:
+        threshold = min_size
+
+    # Define a custom function to calculate median intensity
+    def intensity_median(region, intensity_image):
+        # Extract the intensity values for the region
+        return np.median(intensity_image[region])
+
+    # Add the custom function to the list of extra properties
+    if intensity_image is not None:
+        extra_props = (intensity_median,)
+    else:
+        extra_props = None
+
+    if properties is None:
+        properties = [
+            "label",
+            "area",
+            "area_bbox",
+            "area_convex",
+            "area_filled",
+            "axis_major_length",
+            "axis_minor_length",
+            "eccentricity",
+            "equivalent_diameter_area",
+            "extent",
+            "orientation",
+            "perimeter",
+            "solidity",
+        ]
+
+        if intensity_image is not None:
+            properties += [
+                "intensity_max",
+                "intensity_mean",
+                "intensity_min",
+                "intensity_std",
+            ]
+
+    if intensity_image is not None:
+        if isinstance(intensity_image, str):
+            ds = rxr.open_rasterio(intensity_image)
+            intensity_da = ds.sel(band=1)
+            intensity_image = intensity_da.values.squeeze()
+        elif isinstance(intensity_image, xr.DataArray):
+            intensity_image = intensity_image.values.squeeze()
+        elif isinstance(intensity_image, np.ndarray):
+            pass
+        else:
+            raise ValueError(
+                "The intensity_image must be a file path, xarray DataArray, or numpy array."
+            )
+
+    label_image = measure.label(array, connectivity=connectivity)
+    props = measure.regionprops_table(
+        label_image, properties=properties, intensity_image=intensity_image, **kwargs
+    )
+
+    df = pd.DataFrame(props)
+
+    # Get the labels of regions with area smaller than the threshold
+    small_regions = df[df["area"] < min_size]["label"].values
+    # Set the corresponding labels in the label_image to zero
+    for region_label in small_regions:
+        label_image[label_image == region_label] = 0
+
+    if max_size is not None:
+        large_regions = df[df["area"] > max_size]["label"].values
+        for region_label in large_regions:
+            label_image[label_image == region_label] = 0
+
+    # Find the background (holes) which are zeros
+    holes = label_image == 0
+
+    # Label the holes (connected components in the background)
+    labeled_holes, _ = ndi.label(holes)
+
+    # Measure properties of the labeled holes, including area and bounding box
+    hole_props = measure.regionprops(labeled_holes)
+
+    # Loop through each hole and fill it if it is smaller than the threshold
+    for prop in hole_props:
+        if prop.area < threshold:
+            # Get the coordinates of the small hole
+            coords = prop.coords
+
+            # Find the surrounding region's ID (non-zero value near the hole)
+            surrounding_region_values = []
+            for coord in coords:
+                x, y = coord
+                # Get a 3x3 neighborhood around the hole pixel
+                neighbors = label_image[max(0, x - 1) : x + 2, max(0, y - 1) : y + 2]
+                # Exclude the hole pixels (zeros) and get region values
+                region_values = neighbors[neighbors != 0]
+                if region_values.size > 0:
+                    surrounding_region_values.append(
+                        region_values[0]
+                    )  # Take the first non-zero value
+
+            if surrounding_region_values:
+                # Fill the hole with the mode (most frequent) of the surrounding region values
+                fill_value = max(
+                    set(surrounding_region_values), key=surrounding_region_values.count
+                )
+                label_image[coords[:, 0], coords[:, 1]] = fill_value
+
+    label_image, num_labels = measure.label(
+        label_image, connectivity=connectivity, return_num=True
+    )
+    props = measure.regionprops_table(
+        label_image,
+        properties=properties,
+        intensity_image=intensity_image,
+        extra_properties=extra_props,
+        **kwargs,
+    )
+
+    df = pd.DataFrame(props)
+    df["elongation"] = df["axis_major_length"] / df["axis_minor_length"]
+
+    dtype = "uint8"
+    if num_labels > 255 and num_labels <= 65535:
+        dtype = "uint16"
+    elif num_labels > 65535:
+        dtype = "uint32"
+
+    if out_csv is not None:
+        df.to_csv(out_csv, index=False)
+
+    if isinstance(image, np.ndarray):
+        return label_image, df
+    else:
+        da.values = label_image
+        if out_image is not None:
+            da.rio.to_raster(out_image, dtype=dtype)
+        if out_vector is not None:
+            if out_image is None:
+                # Vectorizing requires a raster on disk; fall back to a temporary file
+                out_image = temp_file_path(".tif")
+                da.rio.to_raster(out_image, dtype=dtype)
+            tmp_vector = temp_file_path(".gpkg")
+            raster_to_vector(out_image, tmp_vector)
+            gdf = gpd.read_file(tmp_vector)
+            gdf["label"] = gdf["value"].astype(int)
+            gdf.drop(columns=["value"], inplace=True)
+            gdf2 = pd.merge(gdf, df, on="label", how="left")
+            gdf2.sort_values("label", inplace=True)
+            gdf2.to_file(out_vector)
+            df = gdf2
+        return da, df
+
+
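A hypothetical usage sketch for region_groups; all paths are placeholders:

from geoai.utils import region_groups

labels, props = region_groups(
    "classified.tif",            # single-band raster of region values
    connectivity=1,              # 4-connectivity
    min_size=20,                 # drop regions < 20 px; fill holes < 20 px
    out_csv="regions.csv",
    out_image="labels.tif",
    out_vector="regions.gpkg",
)
print(props[["label", "area", "elongation"]].head())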
+def add_geometric_properties(data, properties=None, area_unit="m2", length_unit="m"):
+    """Calculates geometric properties and adds them to the GeoDataFrame.
+
+    This function calculates various geometric properties of features in a
+    GeoDataFrame and adds them as new columns without modifying existing attributes.
+
+    Args:
+        data: GeoDataFrame containing vector features.
+        properties: List of geometric properties to calculate. Options include:
+            'area', 'length', 'perimeter', 'centroid_x', 'centroid_y', 'bounds',
+            'convex_hull_area', 'orientation', 'complexity', 'area_bbox',
+            'area_convex', 'area_filled', 'major_length', 'minor_length',
+            'eccentricity', 'diameter_area', 'extent', 'solidity',
+            'elongation'.
+            Defaults to the full list above if None.
+        area_unit: String specifying the unit for area calculation ('m2', 'km2',
+            'ha'). Defaults to 'm2'.
+        length_unit: String specifying the unit for length calculation ('m', 'km').
+            Defaults to 'm'.
+
+    Returns:
+        geopandas.GeoDataFrame: A copy of the input GeoDataFrame with added
+            geometric property columns.
+    """
+    from shapely.ops import unary_union
+
+    if isinstance(data, str):
+        data = read_vector(data)
+
+    # Make a copy to avoid modifying the original
+    result = data.copy()
+
+    # Default properties to calculate
+    if properties is None:
+        properties = [
+            "area",
+            "length",
+            "perimeter",
+            "convex_hull_area",
+            "orientation",
+            "complexity",
+            "area_bbox",
+            "area_convex",
+            "area_filled",
+            "major_length",
+            "minor_length",
+            "eccentricity",
+            "diameter_area",
+            "extent",
+            "solidity",
+            "elongation",
+        ]
+
+    # Make sure we're working with a GeoDataFrame with a valid CRS
+    if not isinstance(result, gpd.GeoDataFrame):
+        raise ValueError("Input must be a GeoDataFrame")
+
+    if result.crs is None:
+        raise ValueError(
+            "GeoDataFrame must have a defined coordinate reference system (CRS)"
+        )
+
+    # Ensure we're working with a projected CRS for accurate measurements
+    if result.crs.is_geographic:
+        # Reproject to a suitable projected CRS for accurate measurements
+        result = result.to_crs(result.estimate_utm_crs())
+
+    # Basic area calculation with unit conversion
+    if "area" in properties:
+        # Calculate area (only for polygons)
+        result["area"] = result.geometry.apply(
+            lambda geom: geom.area if isinstance(geom, (Polygon, MultiPolygon)) else 0
+        )
+
+        # Convert to requested units
+        if area_unit == "km2":
+            result["area"] = result["area"] / 1_000_000  # m² to km²
+            result.rename(columns={"area": "area_km2"}, inplace=True)
+        elif area_unit == "ha":
+            result["area"] = result["area"] / 10_000  # m² to hectares
+            result.rename(columns={"area": "area_ha"}, inplace=True)
+        else:  # Default is m²
+            result.rename(columns={"area": "area_m2"}, inplace=True)
+
+    # Length calculation with unit conversion
+    if "length" in properties:
+        # Calculate length (works for lines and polygon boundaries)
+        result["length"] = result.geometry.length
+
+        # Convert to requested units
+        if length_unit == "km":
+            result["length"] = result["length"] / 1_000  # m to km
+            result.rename(columns={"length": "length_km"}, inplace=True)
+        else:  # Default is m
+            result.rename(columns={"length": "length_m"}, inplace=True)
+
+    # Perimeter calculation (for polygons)
+    if "perimeter" in properties:
+        result["perimeter"] = result.geometry.apply(
+            lambda geom: (
+                geom.boundary.length if isinstance(geom, (Polygon, MultiPolygon)) else 0
+            )
+        )
+
+        # Convert to requested units
+        if length_unit == "km":
+            result["perimeter"] = result["perimeter"] / 1_000  # m to km
+            result.rename(columns={"perimeter": "perimeter_km"}, inplace=True)
+        else:  # Default is m
+            result.rename(columns={"perimeter": "perimeter_m"}, inplace=True)
+
+    # Centroid coordinates
+    if "centroid_x" in properties or "centroid_y" in properties:
+        centroids = result.geometry.centroid
+
+        if "centroid_x" in properties:
+            result["centroid_x"] = centroids.x
+
+        if "centroid_y" in properties:
+            result["centroid_y"] = centroids.y
+
+    # Bounding box properties
+    if "bounds" in properties:
+        bounds = result.geometry.bounds
+        result["minx"] = bounds.minx
+        result["miny"] = bounds.miny
+        result["maxx"] = bounds.maxx
+        result["maxy"] = bounds.maxy
+
+    # Area of bounding box
+    if "area_bbox" in properties:
+        bounds = result.geometry.bounds
+        result["area_bbox"] = (bounds.maxx - bounds.minx) * (bounds.maxy - bounds.miny)
+
+        # Convert to requested units
+        if area_unit == "km2":
+            result["area_bbox"] = result["area_bbox"] / 1_000_000
+            result.rename(columns={"area_bbox": "area_bbox_km2"}, inplace=True)
+        elif area_unit == "ha":
+            result["area_bbox"] = result["area_bbox"] / 10_000
+            result.rename(columns={"area_bbox": "area_bbox_ha"}, inplace=True)
+        else:  # Default is m²
+            result.rename(columns={"area_bbox": "area_bbox_m2"}, inplace=True)
+
+    # Area of convex hull
+    if "area_convex" in properties or "convex_hull_area" in properties:
+        result["area_convex"] = result.geometry.convex_hull.area
+
+        # Convert to requested units
+        if area_unit == "km2":
+            result["area_convex"] = result["area_convex"] / 1_000_000
+            result.rename(columns={"area_convex": "area_convex_km2"}, inplace=True)
+        elif area_unit == "ha":
+            result["area_convex"] = result["area_convex"] / 10_000
+            result.rename(columns={"area_convex": "area_convex_ha"}, inplace=True)
+        else:  # Default is m²
+            result.rename(columns={"area_convex": "area_convex_m2"}, inplace=True)
+
+    # For backward compatibility: when only the legacy 'convex_hull_area' name was
+    # requested, expose it under that name (the column was renamed with a unit
+    # suffix above, so the rename must target the suffixed column)
+    if "convex_hull_area" in properties and "area_convex" not in properties:
+        if area_unit == "km2":
+            result.rename(
+                columns={"area_convex_km2": "convex_hull_area_km2"}, inplace=True
+            )
+        elif area_unit == "ha":
+            result.rename(
+                columns={"area_convex_ha": "convex_hull_area_ha"}, inplace=True
+            )
+        else:
+            result.rename(
+                columns={"area_convex_m2": "convex_hull_area_m2"}, inplace=True
+            )
+
+    # Area of filled geometry (no holes)
+    if "area_filled" in properties:
+
+        def get_filled_area(geom):
+            if not isinstance(geom, (Polygon, MultiPolygon)):
+                return 0
+
+            if isinstance(geom, MultiPolygon):
+                # For MultiPolygon, fill all constituent polygons
+                filled_polys = [Polygon(p.exterior) for p in geom.geoms]
+                return unary_union(filled_polys).area
+            else:
+                # For single Polygon, create a new one with just the exterior ring
+                return Polygon(geom.exterior).area
+
+        result["area_filled"] = result.geometry.apply(get_filled_area)
+
+        # Convert to requested units
+        if area_unit == "km2":
+            result["area_filled"] = result["area_filled"] / 1_000_000
+            result.rename(columns={"area_filled": "area_filled_km2"}, inplace=True)
+        elif area_unit == "ha":
+            result["area_filled"] = result["area_filled"] / 10_000
+            result.rename(columns={"area_filled": "area_filled_ha"}, inplace=True)
+        else:  # Default is m²
+            result.rename(columns={"area_filled": "area_filled_m2"}, inplace=True)
+
+    # Axes lengths, eccentricity, orientation, and elongation
+    if any(
+        p in properties
+        for p in [
+            "major_length",
+            "minor_length",
+            "eccentricity",
+            "orientation",
+            "elongation",
+        ]
+    ):
+
+        def get_axes_properties(geom):
+            # Skip non-polygons
+            if not isinstance(geom, (Polygon, MultiPolygon)):
+                return None, None, None, None, None
+
+            # Handle multipolygons by using the largest polygon
+            if isinstance(geom, MultiPolygon):
+                # Get the polygon with the largest area
+                geom = sorted(list(geom.geoms), key=lambda p: p.area, reverse=True)[0]
+
+            try:
+                # Get the minimum rotated rectangle
+                rect = geom.minimum_rotated_rectangle
+
+                # Extract coordinates
+                coords = list(rect.exterior.coords)[
+                    :-1
+                ]  # Remove the duplicated last point
+
+                if len(coords) < 4:
+                    return None, None, None, None, None
+
+                # Calculate lengths of all four sides
+                sides = []
+                for i in range(len(coords)):
+                    p1 = coords[i]
+                    p2 = coords[(i + 1) % len(coords)]
+                    dx = p2[0] - p1[0]
+                    dy = p2[1] - p1[1]
+                    length = np.sqrt(dx**2 + dy**2)
+                    angle = np.degrees(np.arctan2(dy, dx)) % 180
+                    sides.append((length, angle, p1, p2))
+
+                # Group sides by length (allowing for small differences due to floating point precision)
+                # This ensures we correctly identify the rectangle's dimensions
+                sides_grouped = {}
+                tolerance = 1e-6  # Tolerance for length comparison
+
+                for s in sides:
+                    length, angle = s[0], s[1]
+                    matched = False
+
+                    for key in sides_grouped:
+                        if abs(length - key) < tolerance:
+                            sides_grouped[key].append(s)
+                            matched = True
+                            break
+
+                    if not matched:
+                        sides_grouped[length] = [s]
+
+                # Get unique lengths (should be 2 for a rectangle, parallel sides have equal length)
+                unique_lengths = sorted(sides_grouped.keys(), reverse=True)
+
+                if len(unique_lengths) != 2:
+                    # If we don't get exactly 2 unique lengths, something is wrong with the rectangle
+                    # Fall back to simpler method using bounds
+                    bounds = rect.bounds
+                    width = bounds[2] - bounds[0]
+                    height = bounds[3] - bounds[1]
+                    major_length = max(width, height)
+                    minor_length = min(width, height)
+                    orientation = 0 if width > height else 90
+                else:
+                    major_length = unique_lengths[0]
+                    minor_length = unique_lengths[1]
+                    # Get orientation from the major axis
+                    orientation = sides_grouped[major_length][0][1]
+
+                # Calculate eccentricity
+                if major_length > 0:
+                    # Eccentricity for an ellipse: e = sqrt(1 - (b²/a²))
+                    # where a is the semi-major axis and b is the semi-minor axis
+                    eccentricity = np.sqrt(
+                        1 - ((minor_length / 2) ** 2 / (major_length / 2) ** 2)
+                    )
+                else:
+                    eccentricity = 0
+
+                # Calculate elongation (ratio of major to minor axis)
+                elongation = major_length / minor_length if minor_length > 0 else 1
+
+                return major_length, minor_length, eccentricity, orientation, elongation
+
+            except Exception as e:
+                # For debugging
+                # print(f"Error calculating axes: {e}")
+                return None, None, None, None, None
+
+        # Apply the function and split the results
+        axes_data = result.geometry.apply(get_axes_properties)
+
+        if "major_length" in properties:
+            result["major_length"] = axes_data.apply(lambda x: x[0] if x else None)
+            # Convert to requested units
+            if length_unit == "km":
+                result["major_length"] = result["major_length"] / 1_000
+                result.rename(columns={"major_length": "major_length_km"}, inplace=True)
+            else:
+                result.rename(columns={"major_length": "major_length_m"}, inplace=True)
+
+        if "minor_length" in properties:
+            result["minor_length"] = axes_data.apply(lambda x: x[1] if x else None)
+            # Convert to requested units
+            if length_unit == "km":
+                result["minor_length"] = result["minor_length"] / 1_000
+                result.rename(columns={"minor_length": "minor_length_km"}, inplace=True)
+            else:
+                result.rename(columns={"minor_length": "minor_length_m"}, inplace=True)
+
+        if "eccentricity" in properties:
+            result["eccentricity"] = axes_data.apply(lambda x: x[2] if x else None)
+
+        if "orientation" in properties:
+            result["orientation"] = axes_data.apply(lambda x: x[3] if x else None)
+
+        if "elongation" in properties:
+            result["elongation"] = axes_data.apply(lambda x: x[4] if x else None)
+
+    # Equivalent diameter based on area
+    if "diameter_area" in properties:
+
+        def get_equivalent_diameter(geom):
+            if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
+                return None
+            # Diameter of a circle with the same area: d = 2 * sqrt(A / π)
+            return 2 * np.sqrt(geom.area / np.pi)
+
+        result["diameter_area"] = result.geometry.apply(get_equivalent_diameter)
+
+        # Convert to requested units
+        if length_unit == "km":
+            result["diameter_area"] = result["diameter_area"] / 1_000
+            result.rename(
+                columns={"diameter_area": "equivalent_diameter_area_km"},
+                inplace=True,
+            )
+        else:
+            result.rename(
+                columns={"diameter_area": "equivalent_diameter_area_m"},
+                inplace=True,
+            )
+
+    # Extent (ratio of shape area to bounding box area)
+    if "extent" in properties:
+
+        def get_extent(geom):
+            if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
+                return None
+
+            bounds = geom.bounds
+            bbox_area = (bounds[2] - bounds[0]) * (bounds[3] - bounds[1])
+
+            if bbox_area > 0:
+                return geom.area / bbox_area
+            return None
+
+        result["extent"] = result.geometry.apply(get_extent)
+
+    # Solidity (ratio of shape area to convex hull area)
+    if "solidity" in properties:
+
+        def get_solidity(geom):
+            if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
+                return None
+
+            convex_hull_area = geom.convex_hull.area
+
+            if convex_hull_area > 0:
+                return geom.area / convex_hull_area
+            return None
+
+        result["solidity"] = result.geometry.apply(get_solidity)
+
+    # Complexity (ratio of perimeter to area)
+    if "complexity" in properties:
+
+        def calc_complexity(geom):
+            if isinstance(geom, (Polygon, MultiPolygon)) and geom.area > 0:
+                # Shape index: P / (2 * sqrt(π * A))
+                # Normalized to 1 for a circle, higher for more complex shapes
+                return geom.boundary.length / (2 * np.sqrt(np.pi * geom.area))
+            return None
+
+        result["complexity"] = result.geometry.apply(calc_complexity)
+
+    return result
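A hypothetical usage sketch for add_geometric_properties; the input path is a placeholder. Area and length columns carry unit suffixes, e.g. area_ha and perimeter_m:

from geoai.utils import add_geometric_properties

gdf = add_geometric_properties(
    "buildings.geojson",
    properties=["area", "perimeter", "solidity", "elongation"],
    area_unit="ha",
    length_unit="m",
)
print(gdf[["area_ha", "perimeter_m", "solidity", "elongation"]].head())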