geoai-py 0.3.2__py2.py3-none-any.whl → 0.3.3__py2.py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
geoai/utils.py CHANGED
@@ -1,22 +1,37 @@
1
- """The common module contains common functions and classes used by the other modules."""
1
+ """The utils module contains common functions and classes used by the other modules."""
2
2
 
3
+ import json
4
+ import math
3
5
  import os
4
6
  from collections.abc import Iterable
5
- from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
7
+ from PIL import Image
8
+ from pathlib import Path
9
+ import requests
10
+ import warnings
11
+ import xml.etree.ElementTree as ET
12
+ from typing import Any, Dict, List, Optional, Tuple, Union
13
+ import numpy as np
6
14
  import matplotlib.pyplot as plt
7
15
  import geopandas as gpd
8
16
  import leafmap
9
- import torch
10
17
  import numpy as np
18
+ import pandas as pd
11
19
  import xarray as xr
12
- import rioxarray
13
- import rasterio as rio
14
- from torch.utils.data import DataLoader
20
+ import rioxarray as rxr
21
+ import rasterio
22
+ from torchvision.transforms import RandomRotation
23
+ from rasterio.windows import Window
24
+ from rasterio import features
25
+ from rasterio.plot import show
26
+ from shapely.geometry import box, shape, mapping, Polygon, MultiPolygon
27
+ from shapely.affinity import rotate
28
+ from tqdm import tqdm
29
+ import torch
30
+ import torchgeo
31
+ import cv2
15
32
 
16
33
  try:
17
- from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
18
- from torchgeo.samplers import RandomGeoSampler, Units
19
- from torchgeo.transforms import indices
34
+ from torchgeo.datasets import RasterDataset, unbind_samples
20
35
  except ImportError as e:
21
36
  raise ImportError(
22
37
  "Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
@@ -121,7 +136,7 @@ def view_image(
121
136
  if isinstance(image, torch.Tensor):
122
137
  image = image.cpu().numpy()
123
138
  elif isinstance(image, str):
124
- image = rio.open(image).read().transpose(1, 2, 0)
139
+ image = rasterio.open(image).read().transpose(1, 2, 0)
125
140
 
126
141
  plt.figure(figsize=figsize)
127
142
 
@@ -274,7 +289,6 @@ def calc_stats(
274
289
  Returns:
275
290
  Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.
276
291
  """
277
- import rasterio as rio
278
292
 
279
293
  # To avoid loading the entire dataset in memory, we will loop through each img
280
294
  # The filenames will be retrieved from the dataset's rtree index
@@ -288,7 +302,7 @@ def calc_stats(
288
302
  accum_std = 0
289
303
 
290
304
  for file in files:
291
- img = rio.open(file).read() / divide_by # type: ignore
305
+ img = rasterio.open(file).read() / divide_by # type: ignore
292
306
  accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)
293
307
  accum_std += img.reshape((img.shape[0], -1)).std(axis=1)
294
308
 
@@ -396,7 +410,7 @@ def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
396
410
 
397
411
  def dict_to_image(
398
412
  data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
399
- ) -> rio.DatasetReader:
413
+ ) -> rasterio.DatasetReader:
400
414
  """Convert a dictionary containing spatial data to a rasterio dataset or save it to
401
415
  a file. The dictionary should contain the following keys: "crs", "bounds", and "image".
402
416
  It can be generated from a TorchGeo dataset sampler.
@@ -1174,3 +1188,3764 @@ def create_split_map(
1174
1188
  )
1175
1189
 
1176
1190
  return m
1191
+
1192
+
1193
+ def download_file(url, output_path=None, overwrite=False):
1194
+ """
1195
+ Download a file from a given URL with a progress bar.
1196
+
1197
+ Args:
1198
+ url (str): The URL of the file to download.
1199
+ output_path (str, optional): The path where the downloaded file will be saved.
1200
+ If not provided, the filename from the URL will be used.
1201
+ overwrite (bool, optional): Whether to overwrite the file if it already exists. Defaults to False.
1202
+
1203
+ Returns:
1204
+ str: The path to the downloaded file.
1205
+ """
1206
+ # Get the filename from the URL if output_path is not provided
1207
+ if output_path is None:
1208
+ output_path = os.path.basename(url)
1209
+
1210
+ # Check if the file already exists
1211
+ if os.path.exists(output_path) and not overwrite:
1212
+ print(f"File already exists: {output_path}")
1213
+ return output_path
1214
+
1215
+ # Send a streaming GET request
1216
+ response = requests.get(url, stream=True, timeout=50)
1217
+ response.raise_for_status() # Raise an exception for HTTP errors
1218
+
1219
+ # Get the total file size if available
1220
+ total_size = int(response.headers.get("content-length", 0))
1221
+
1222
+ # Open the output file
1223
+ with (
1224
+ open(output_path, "wb") as file,
1225
+ tqdm(
1226
+ desc=os.path.basename(output_path),
1227
+ total=total_size,
1228
+ unit="B",
1229
+ unit_scale=True,
1230
+ unit_divisor=1024,
1231
+ ) as progress_bar,
1232
+ ):
1233
+
1234
+ # Download the file in chunks and update the progress bar
1235
+ for chunk in response.iter_content(chunk_size=1024):
1236
+ if chunk: # filter out keep-alive new chunks
1237
+ file.write(chunk)
1238
+ progress_bar.update(len(chunk))
1239
+
1240
+ return output_path
1241
+
1242
+
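A minimal usage sketch for the new download_file helper (the URL and filename below are hypothetical):
>>> from geoai.utils import download_file
>>> download_file("https://example.com/sample.tif", output_path="sample.tif")
Re-running the call is safe: with overwrite=False the function returns the existing path without downloading again.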
1243
+ def get_raster_info(raster_path):
1244
+ """Display basic information about a raster dataset.
1245
+
1246
+ Args:
1247
+ raster_path (str): Path to the raster file
1248
+
1249
+ Returns:
1250
+ dict: Dictionary containing the basic information about the raster
1251
+ """
1252
+ # Open the raster dataset
1253
+ with rasterio.open(raster_path) as src:
1254
+ # Get basic metadata
1255
+ info = {
1256
+ "driver": src.driver,
1257
+ "width": src.width,
1258
+ "height": src.height,
1259
+ "count": src.count,
1260
+ "dtype": src.dtypes[0],
1261
+ "crs": src.crs.to_string() if src.crs else "No CRS defined",
1262
+ "transform": src.transform,
1263
+ "bounds": src.bounds,
1264
+ "resolution": (src.transform[0], -src.transform[4]),
1265
+ "nodata": src.nodata,
1266
+ }
1267
+
1268
+ # Calculate statistics for each band
1269
+ stats = []
1270
+ for i in range(1, src.count + 1):
1271
+ band = src.read(i, masked=True)
1272
+ band_stats = {
1273
+ "band": i,
1274
+ "min": float(band.min()),
1275
+ "max": float(band.max()),
1276
+ "mean": float(band.mean()),
1277
+ "std": float(band.std()),
1278
+ }
1279
+ stats.append(band_stats)
1280
+
1281
+ info["band_stats"] = stats
1282
+
1283
+ return info
1284
+
1285
+
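Hypothetical usage of get_raster_info (the file name is illustrative; actual values depend on the raster):
>>> info = get_raster_info("sample.tif")
>>> info["width"], info["height"], info["count"]  # dimensions and band count
>>> info["band_stats"][0]["mean"]  # per-band statistics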
1286
+ def get_raster_stats(raster_path, divide_by=1.0):
1287
+ """Calculate statistics for each band in a raster dataset.
1288
+
1289
+ This function computes min, max, mean, and standard deviation values
1290
+ for each band in the provided raster, returning results in a dictionary
1291
+ with lists for each statistic type.
1292
+
1293
+ Args:
1294
+ raster_path (str): Path to the raster file
1295
+ divide_by (float, optional): Value to divide pixel values by.
1296
+ Defaults to 1.0, which keeps the original pixel values unchanged.
1297
+
1298
+ Returns:
1299
+ dict: Dictionary containing lists of statistics with keys:
1300
+ - 'min': List of minimum values for each band
1301
+ - 'max': List of maximum values for each band
1302
+ - 'mean': List of mean values for each band
1303
+ - 'std': List of standard deviation values for each band
1304
+ """
1305
+ # Initialize the results dictionary with empty lists
1306
+ stats = {"min": [], "max": [], "mean": [], "std": []}
1307
+
1308
+ # Open the raster dataset
1309
+ with rasterio.open(raster_path) as src:
1310
+ # Calculate statistics for each band
1311
+ for i in range(1, src.count + 1):
1312
+ band = src.read(i, masked=True)
1313
+
1314
+ # Append statistics for this band to each list
1315
+ stats["min"].append(float(band.min()) / divide_by)
1316
+ stats["max"].append(float(band.max()) / divide_by)
1317
+ stats["mean"].append(float(band.mean()) / divide_by)
1318
+ stats["std"].append(float(band.std()) / divide_by)
1319
+
1320
+ return stats
1321
+
1322
+
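A sketch of get_raster_stats; dividing by 255 rescales 8-bit imagery to the 0-1 range before computing statistics:
>>> stats = get_raster_stats("sample.tif", divide_by=255.0)
>>> stats["mean"]  # one value per band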
1323
+ def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
1324
+ """Print formatted information about a raster dataset and optionally show a preview.
1325
+
1326
+ Args:
1327
+ raster_path (str): Path to the raster file
1328
+ show_preview (bool, optional): Whether to display a visual preview of the raster.
1329
+ Defaults to True.
1330
+ figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
1331
+
1332
+ Returns:
1333
+ dict: Dictionary containing raster information if successful, None otherwise
1334
+ """
1335
+ try:
1336
+ info = get_raster_info(raster_path)
1337
+
1338
+ # Print basic information
1339
+ print(f"===== RASTER INFORMATION: {raster_path} =====")
1340
+ print(f"Driver: {info['driver']}")
1341
+ print(f"Dimensions: {info['width']} x {info['height']} pixels")
1342
+ print(f"Number of bands: {info['count']}")
1343
+ print(f"Data type: {info['dtype']}")
1344
+ print(f"Coordinate Reference System: {info['crs']}")
1345
+ print(f"Georeferenced Bounds: {info['bounds']}")
1346
+ print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
1347
+ print(f"NoData Value: {info['nodata']}")
1348
+
1349
+ # Print band statistics
1350
+ print("\n----- Band Statistics -----")
1351
+ for band_stat in info["band_stats"]:
1352
+ print(f"Band {band_stat['band']}:")
1353
+ print(f" Min: {band_stat['min']:.2f}")
1354
+ print(f" Max: {band_stat['max']:.2f}")
1355
+ print(f" Mean: {band_stat['mean']:.2f}")
1356
+ print(f" Std Dev: {band_stat['std']:.2f}")
1357
+
1358
+ # Show a preview if requested
1359
+ if show_preview:
1360
+ with rasterio.open(raster_path) as src:
1361
+ # For multi-band images, show RGB composite or first band
1362
+ if src.count >= 3:
1363
+ # Try to show RGB composite
1364
+ rgb = np.dstack([src.read(i) for i in range(1, 4)])
1365
+ plt.figure(figsize=figsize)
1366
+ plt.imshow(rgb)
1367
+ plt.title(f"RGB Preview: {raster_path}")
1368
+ else:
1369
+ # Show first band for single-band images
1370
+ plt.figure(figsize=figsize)
1371
+ show(
1372
+ src.read(1),
1373
+ cmap="viridis",
1374
+ title=f"Band 1 Preview: {raster_path}",
1375
+ )
1376
+ plt.colorbar(label="Pixel Value")
1377
+ plt.show()
1378
+
+ return info
+
+ except Exception as e:
+ print(f"Error reading raster: {str(e)}")
+ return None
1381
+
1382
+
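print_raster_info is the interactive companion to get_raster_info; a hypothetical call that skips the matplotlib preview:
>>> print_raster_info("sample.tif", show_preview=False)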
1383
+ def get_raster_info_gdal(raster_path):
1384
+ """Get basic information about a raster dataset using GDAL.
1385
+
1386
+ Args:
1387
+ raster_path (str): Path to the raster file
1388
+
1389
+ Returns:
1390
+ dict: Dictionary containing the basic information about the raster,
1391
+ or None if the file cannot be opened
1392
+ """
1393
+
1394
+ from osgeo import gdal
1395
+
1396
+ # Open the dataset
1397
+ ds = gdal.Open(raster_path)
1398
+ if ds is None:
1399
+ print(f"Error: Could not open {raster_path}")
1400
+ return None
1401
+
1402
+ # Get basic information
1403
+ info = {
1404
+ "driver": ds.GetDriver().ShortName,
1405
+ "width": ds.RasterXSize,
1406
+ "height": ds.RasterYSize,
1407
+ "count": ds.RasterCount,
1408
+ "projection": ds.GetProjection(),
1409
+ "geotransform": ds.GetGeoTransform(),
1410
+ }
1411
+
1412
+ # Calculate resolution
1413
+ gt = ds.GetGeoTransform()
1414
+ if gt:
1415
+ info["resolution"] = (abs(gt[1]), abs(gt[5]))
1416
+ info["origin"] = (gt[0], gt[3])
1417
+
1418
+ # Get band information
1419
+ bands_info = []
1420
+ for i in range(1, ds.RasterCount + 1):
1421
+ band = ds.GetRasterBand(i)
1422
+ stats = band.GetStatistics(True, True)
1423
+ band_info = {
1424
+ "band": i,
1425
+ "datatype": gdal.GetDataTypeName(band.DataType),
1426
+ "min": stats[0],
1427
+ "max": stats[1],
1428
+ "mean": stats[2],
1429
+ "std": stats[3],
1430
+ "nodata": band.GetNoDataValue(),
1431
+ }
1432
+ bands_info.append(band_info)
1433
+
1434
+ info["bands"] = bands_info
1435
+
1436
+ # Close the dataset
1437
+ ds = None
1438
+
1439
+ return info
1440
+
1441
+
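get_raster_info_gdal assumes the GDAL Python bindings (osgeo) are installed, which is not a hard dependency of the package:
>>> info = get_raster_info_gdal("sample.tif")
>>> info["bands"][0]["datatype"] if info else None  # info is None if the file cannot be opened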
1442
+ def get_vector_info(vector_path):
1443
+ """Display basic information about a vector dataset using GeoPandas.
1444
+
1445
+ Args:
1446
+ vector_path (str): Path to the vector file
1447
+
1448
+ Returns:
1449
+ dict: Dictionary containing the basic information about the vector dataset
1450
+ """
1451
+ # Open the vector dataset
1452
+ gdf = (
1453
+ gpd.read_parquet(vector_path)
1454
+ if vector_path.endswith(".parquet")
1455
+ else gpd.read_file(vector_path)
1456
+ )
1457
+
1458
+ # Get basic metadata
1459
+ info = {
1460
+ "file_path": vector_path,
1461
+ "driver": os.path.splitext(vector_path)[1][1:].upper(), # Format from extension
1462
+ "feature_count": len(gdf),
1463
+ "crs": str(gdf.crs),
1464
+ "geometry_type": str(gdf.geom_type.value_counts().to_dict()),
1465
+ "attribute_count": len(gdf.columns) - 1, # Subtract the geometry column
1466
+ "attribute_names": list(gdf.columns[gdf.columns != "geometry"]),
1467
+ "bounds": gdf.total_bounds.tolist(),
1468
+ }
1469
+
1470
+ # Add statistics about numeric attributes
1471
+ numeric_columns = gdf.select_dtypes(include=["number"]).columns
1472
+ attribute_stats = {}
1473
+ for col in numeric_columns:
1474
+ if col != "geometry":
1475
+ attribute_stats[col] = {
1476
+ "min": gdf[col].min(),
1477
+ "max": gdf[col].max(),
1478
+ "mean": gdf[col].mean(),
1479
+ "std": gdf[col].std(),
1480
+ "null_count": gdf[col].isna().sum(),
1481
+ }
1482
+
1483
+ info["attribute_stats"] = attribute_stats
1484
+
1485
+ return info
1486
+
1487
+
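Hypothetical usage of get_vector_info; GeoParquet files are read with read_parquet, everything else with read_file:
>>> info = get_vector_info("buildings.geojson")
>>> info["feature_count"], info["geometry_type"]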
1488
+ def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
1489
+ """Print formatted information about a vector dataset and optionally show a preview.
1490
+
1491
+ Args:
1492
+ vector_path (str): Path to the vector file
1493
+ show_preview (bool, optional): Whether to display a visual preview of the vector data.
1494
+ Defaults to True.
1495
+ figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
1496
+
1497
+ Returns:
1498
+ dict: Dictionary containing vector information if successful, None otherwise
1499
+ """
1500
+ try:
1501
+ info = get_vector_info(vector_path)
1502
+
1503
+ # Print basic information
1504
+ print(f"===== VECTOR INFORMATION: {vector_path} =====")
1505
+ print(f"Driver: {info['driver']}")
1506
+ print(f"Feature count: {info['feature_count']}")
1507
+ print(f"Geometry types: {info['geometry_type']}")
1508
+ print(f"Coordinate Reference System: {info['crs']}")
1509
+ print(f"Bounds: {info['bounds']}")
1510
+ print(f"Number of attributes: {info['attribute_count']}")
1511
+ print(f"Attribute names: {', '.join(info['attribute_names'])}")
1512
+
1513
+ # Print attribute statistics
1514
+ if info["attribute_stats"]:
1515
+ print("\n----- Attribute Statistics -----")
1516
+ for attr, stats in info["attribute_stats"].items():
1517
+ print(f"Attribute: {attr}")
1518
+ for stat_name, stat_value in stats.items():
1519
+ print(
1520
+ f" {stat_name}: {stat_value:.4f}"
1521
+ if isinstance(stat_value, float)
1522
+ else f" {stat_name}: {stat_value}"
1523
+ )
1524
+
1525
+ # Show a preview if requested
1526
+ if show_preview:
1527
+ gdf = (
1528
+ gpd.read_parquet(vector_path)
1529
+ if vector_path.endswith(".parquet")
1530
+ else gpd.read_file(vector_path)
1531
+ )
1532
+ fig, ax = plt.subplots(figsize=figsize)
1533
+ gdf.plot(ax=ax, cmap="viridis")
1534
+ ax.set_title(f"Preview: {vector_path}")
1535
+ plt.tight_layout()
1536
+ plt.show()
1537
+
1538
+ # # Show a sample of the attribute table
1539
+ # if not gdf.empty:
1540
+ # print("\n----- Sample of attribute table (first 5 rows) -----")
1541
+ # print(gdf.head().to_string())
1542
+
+ return info
+
+ except Exception as e:
+ print(f"Error reading vector data: {str(e)}")
+ return None
1545
+
1546
+
1547
+ # Alternative implementation using OGR directly
1548
+ def get_vector_info_ogr(vector_path):
1549
+ """Get basic information about a vector dataset using OGR.
1550
+
1551
+ Args:
1552
+ vector_path (str): Path to the vector file
1553
+
1554
+ Returns:
1555
+ dict: Dictionary containing the basic information about the vector dataset,
1556
+ or None if the file cannot be opened
1557
+ """
1558
+ from osgeo import ogr
1559
+
1560
+ # Register all OGR drivers
1561
+ ogr.RegisterAll()
1562
+
1563
+ # Open the dataset
1564
+ ds = ogr.Open(vector_path)
1565
+ if ds is None:
1566
+ print(f"Error: Could not open {vector_path}")
1567
+ return None
1568
+
1569
+ # Basic dataset information
1570
+ info = {
1571
+ "file_path": vector_path,
1572
+ "driver": ds.GetDriver().GetName(),
1573
+ "layer_count": ds.GetLayerCount(),
1574
+ "layers": [],
1575
+ }
1576
+
1577
+ # Extract information for each layer
1578
+ for i in range(ds.GetLayerCount()):
1579
+ layer = ds.GetLayer(i)
1580
+ layer_info = {
1581
+ "name": layer.GetName(),
1582
+ "feature_count": layer.GetFeatureCount(),
1583
+ "geometry_type": ogr.GeometryTypeToName(layer.GetGeomType()),
1584
+ "spatial_ref": (
1585
+ layer.GetSpatialRef().ExportToWkt() if layer.GetSpatialRef() else "None"
1586
+ ),
1587
+ "extent": layer.GetExtent(),
1588
+ "fields": [],
1589
+ }
1590
+
1591
+ # Get field information
1592
+ defn = layer.GetLayerDefn()
1593
+ for j in range(defn.GetFieldCount()):
1594
+ field_defn = defn.GetFieldDefn(j)
1595
+ field_info = {
1596
+ "name": field_defn.GetName(),
1597
+ "type": field_defn.GetTypeName(),
1598
+ "width": field_defn.GetWidth(),
1599
+ "precision": field_defn.GetPrecision(),
1600
+ }
1601
+ layer_info["fields"].append(field_info)
1602
+
1603
+ info["layers"].append(layer_info)
1604
+
1605
+ # Close the dataset
1606
+ ds = None
1607
+
1608
+ return info
1609
+
1610
+
1611
+ def analyze_vector_attributes(vector_path, attribute_name):
1612
+ """Analyze a specific attribute in a vector dataset and create a histogram.
1613
+
1614
+ Args:
1615
+ vector_path (str): Path to the vector file
1616
+ attribute_name (str): Name of the attribute to analyze
1617
+
1618
+ Returns:
1619
+ dict: Dictionary containing analysis results for the attribute
1620
+ """
1621
+ try:
1622
+ gdf = gpd.read_file(vector_path)
1623
+
1624
+ # Check if attribute exists
1625
+ if attribute_name not in gdf.columns:
1626
+ print(f"Attribute '{attribute_name}' not found in the dataset")
1627
+ return None
1628
+
1629
+ # Get the attribute series
1630
+ attr = gdf[attribute_name]
1631
+
1632
+ # Perform different analyses based on data type
1633
+ if pd.api.types.is_numeric_dtype(attr):
1634
+ # Numeric attribute
1635
+ analysis = {
1636
+ "attribute": attribute_name,
1637
+ "type": "numeric",
1638
+ "count": attr.count(),
1639
+ "null_count": attr.isna().sum(),
1640
+ "min": attr.min(),
1641
+ "max": attr.max(),
1642
+ "mean": attr.mean(),
1643
+ "median": attr.median(),
1644
+ "std": attr.std(),
1645
+ "unique_values": attr.nunique(),
1646
+ }
1647
+
1648
+ # Create histogram
1649
+ plt.figure(figsize=(10, 6))
1650
+ plt.hist(attr.dropna(), bins=20, alpha=0.7, color="blue")
1651
+ plt.title(f"Histogram of {attribute_name}")
1652
+ plt.xlabel(attribute_name)
1653
+ plt.ylabel("Frequency")
1654
+ plt.grid(True, alpha=0.3)
1655
+ plt.show()
1656
+
1657
+ else:
1658
+ # Categorical attribute
1659
+ analysis = {
1660
+ "attribute": attribute_name,
1661
+ "type": "categorical",
1662
+ "count": attr.count(),
1663
+ "null_count": attr.isna().sum(),
1664
+ "unique_values": attr.nunique(),
1665
+ "value_counts": attr.value_counts().to_dict(),
1666
+ }
1667
+
1668
+ # Create bar plot for top categories
1669
+ top_n = min(10, attr.nunique())
1670
+ plt.figure(figsize=(10, 6))
1671
+ attr.value_counts().head(top_n).plot(kind="bar", color="skyblue")
1672
+ plt.title(f"Top {top_n} values for {attribute_name}")
1673
+ plt.xlabel(attribute_name)
1674
+ plt.ylabel("Count")
1675
+ plt.xticks(rotation=45)
1676
+ plt.grid(True, alpha=0.3)
1677
+ plt.tight_layout()
1678
+ plt.show()
1679
+
1680
+ return analysis
1681
+
1682
+ except Exception as e:
1683
+ print(f"Error analyzing attribute: {str(e)}")
1684
+ return None
1685
+
1686
+
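A sketch of analyze_vector_attributes with a hypothetical file and attribute; numeric columns get a histogram, categorical columns a bar plot:
>>> result = analyze_vector_attributes("buildings.geojson", "height")
>>> result["type"]  # 'numeric' or 'categorical'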
1687
+ def visualize_vector_by_attribute(
1688
+ vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
1689
+ ):
1690
+ """Create a thematic map visualization of vector data based on an attribute.
1691
+
1692
+ Args:
1693
+ vector_path (str): Path to the vector file
1694
+ attribute_name (str): Name of the attribute to visualize
1695
+ cmap (str, optional): Matplotlib colormap name. Defaults to 'viridis'.
1696
+ figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).
1697
+
1698
+ Returns:
1699
+ bool: True if visualization was successful, False otherwise
1700
+ """
1701
+ try:
1702
+ # Read the vector data
1703
+ gdf = gpd.read_file(vector_path)
1704
+
1705
+ # Check if attribute exists
1706
+ if attribute_name not in gdf.columns:
1707
+ print(f"Attribute '{attribute_name}' not found in the dataset")
1708
+ return False
1709
+
1710
+ # Create the plot
1711
+ fig, ax = plt.subplots(figsize=figsize)
1712
+
1713
+ # Determine plot type based on data type
1714
+ if pd.api.types.is_numeric_dtype(gdf[attribute_name]):
1715
+ # Continuous data
1716
+ gdf.plot(column=attribute_name, cmap=cmap, legend=True, ax=ax)
1717
+ else:
1718
+ # Categorical data
1719
+ gdf.plot(column=attribute_name, categorical=True, legend=True, ax=ax)
1720
+
1721
+ # Add title and labels
1722
+ ax.set_title(f"{os.path.basename(vector_path)} - {attribute_name}")
1723
+ ax.set_xlabel("Longitude")
1724
+ ax.set_ylabel("Latitude")
1725
+
1726
+ # Add basemap or additional elements if available
1727
+ # Note: Additional options could be added here for more complex maps
1728
+
1729
+ plt.tight_layout()
1730
+ plt.show()
1731
+
+ return True
+
+ except Exception as e:
+ print(f"Error visualizing data: {str(e)}")
+ return False
1734
+
1735
+
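And a matching sketch for visualize_vector_by_attribute (file and attribute are again hypothetical):
>>> visualize_vector_by_attribute("buildings.geojson", "height", cmap="plasma")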
1736
+ def clip_raster_by_bbox(
1737
+ input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
1738
+ ):
1739
+ """
1740
+ Clip a raster dataset using a bounding box and optionally select specific bands.
1741
+
1742
+ Args:
1743
+ input_raster (str): Path to the input raster file.
1744
+ output_raster (str): Path where the clipped raster will be saved.
1745
+ bbox (tuple): Bounding box coordinates either as:
1746
+ - Geographic coordinates (minx, miny, maxx, maxy) if bbox_type="geo"
1747
+ - Pixel indices (min_row, min_col, max_row, max_col) if bbox_type="pixel"
1748
+ bands (list, optional): List of band indices to keep (1-based indexing).
1749
+ If None, all bands will be kept.
1750
+ bbox_type (str, optional): Type of bounding box coordinates. Either "geo" for
1751
+ geographic coordinates or "pixel" for row/column indices.
1752
+ Default is "geo".
1753
+ bbox_crs (str or dict, optional): CRS of the bbox if different from the raster CRS.
1754
+ Can be provided as EPSG code (e.g., "EPSG:4326") or
1755
+ as a proj4 string. Only applies when bbox_type="geo".
1756
+ If None, assumes bbox is in the same CRS as the raster.
1757
+
1758
+ Returns:
1759
+ str: Path to the clipped output raster.
1760
+
1761
+ Raises:
1762
+ ImportError: If required dependencies are not installed.
1763
+ ValueError: If the bbox is invalid, bands are out of range, or bbox_type is invalid.
1764
+ RuntimeError: If the clipping operation fails.
1765
+
1766
+ Examples:
1767
+ Clip using geographic coordinates in the same CRS as the raster
1768
+ >>> clip_raster_by_bbox('input.tif', 'clipped_geo.tif', (100, 200, 300, 400))
1769
+ 'clipped_geo.tif'
1770
+
1771
+ Clip using WGS84 coordinates when the raster is in a different CRS
1772
+ >>> clip_raster_by_bbox('input.tif', 'clipped_wgs84.tif', (-122.5, 37.7, -122.4, 37.8),
1773
+ ... bbox_crs="EPSG:4326")
1774
+ 'clipped_wgs84.tif'
1775
+
1776
+ Clip using row/column indices
1777
+ >>> clip_raster_by_bbox('input.tif', 'clipped_pixel.tif', (50, 100, 150, 200),
1778
+ ... bbox_type="pixel")
1779
+ 'clipped_pixel.tif'
1780
+
1781
+ Clip with band selection
1782
+ >>> clip_raster_by_bbox('input.tif', 'clipped_bands.tif', (100, 200, 300, 400),
1783
+ ... bands=[1, 3])
1784
+ 'clipped_bands.tif'
1785
+ """
1786
+ from rasterio.transform import from_bounds
1787
+ from rasterio.warp import transform_bounds
1788
+
1789
+ # Validate bbox_type
1790
+ if bbox_type not in ["geo", "pixel"]:
1791
+ raise ValueError("bbox_type must be either 'geo' or 'pixel'")
1792
+
1793
+ # Validate bbox
1794
+ if len(bbox) != 4:
1795
+ raise ValueError("bbox must contain exactly 4 values")
1796
+
1797
+ # Open the source raster
1798
+ with rasterio.open(input_raster) as src:
1799
+ # Get the source CRS
1800
+ src_crs = src.crs
1801
+
1802
+ # Handle different bbox types
1803
+ if bbox_type == "geo":
1804
+ minx, miny, maxx, maxy = bbox
1805
+
1806
+ # Validate geographic bbox
1807
+ if minx >= maxx or miny >= maxy:
1808
+ raise ValueError(
1809
+ "Invalid geographic bbox. Expected (minx, miny, maxx, maxy) where minx < maxx and miny < maxy"
1810
+ )
1811
+
1812
+ # If bbox_crs is provided and different from the source CRS, transform the bbox
1813
+ if bbox_crs is not None and bbox_crs != src_crs:
1814
+ try:
1815
+ # Transform bbox coordinates from bbox_crs to src_crs
1816
+ minx, miny, maxx, maxy = transform_bounds(
1817
+ bbox_crs, src_crs, minx, miny, maxx, maxy
1818
+ )
1819
+ except Exception as e:
1820
+ raise ValueError(
1821
+ f"Failed to transform bbox from {bbox_crs} to {src_crs}: {str(e)}"
1822
+ )
1823
+
1824
+ # Calculate the pixel window from geographic coordinates
1825
+ window = src.window(minx, miny, maxx, maxy)
1826
+
1827
+ # Use the same bounds for the output transform
1828
+ output_bounds = (minx, miny, maxx, maxy)
1829
+
1830
+ else: # bbox_type == "pixel"
1831
+ min_row, min_col, max_row, max_col = bbox
1832
+
1833
+ # Validate pixel bbox
1834
+ if min_row >= max_row or min_col >= max_col:
1835
+ raise ValueError(
1836
+ "Invalid pixel bbox. Expected (min_row, min_col, max_row, max_col) where min_row < max_row and min_col < max_col"
1837
+ )
1838
+
1839
+ if (
1840
+ min_row < 0
1841
+ or min_col < 0
1842
+ or max_row > src.height
1843
+ or max_col > src.width
1844
+ ):
1845
+ raise ValueError(
1846
+ f"Pixel indices out of bounds. Raster dimensions are {src.height} rows x {src.width} columns"
1847
+ )
1848
+
1849
+ # Create a window from pixel coordinates
1850
+ window = Window(min_col, min_row, max_col - min_col, max_row - min_row)
1851
+
1852
+ # Calculate the geographic bounds for this window
1853
+ window_transform = src.window_transform(window)
1854
+ output_bounds = rasterio.transform.array_bounds(
1855
+ window.height, window.width, window_transform
1856
+ )
1857
+ # array_bounds already returns (minx, miny, maxx, maxy)
1864
+
1865
+ # Get window dimensions
1866
+ window_width = int(window.width)
1867
+ window_height = int(window.height)
1868
+
1869
+ # Check if the window is valid
1870
+ if window_width <= 0 or window_height <= 0:
1871
+ raise ValueError("Bounding box results in an empty window")
1872
+
1873
+ # Handle band selection
1874
+ if bands is None:
1875
+ # Use all bands
1876
+ bands_to_read = list(range(1, src.count + 1))
1877
+ else:
1878
+ # Validate band indices
1879
+ if not all(1 <= b <= src.count for b in bands):
1880
+ raise ValueError(f"Band indices must be between 1 and {src.count}")
1881
+ bands_to_read = bands
1882
+
1883
+ # Calculate new transform for the clipped raster
1884
+ new_transform = from_bounds(
1885
+ output_bounds[0],
1886
+ output_bounds[1],
1887
+ output_bounds[2],
1888
+ output_bounds[3],
1889
+ window_width,
1890
+ window_height,
1891
+ )
1892
+
1893
+ # Create a metadata dictionary for the output
1894
+ out_meta = src.meta.copy()
1895
+ out_meta.update(
1896
+ {
1897
+ "height": window_height,
1898
+ "width": window_width,
1899
+ "transform": new_transform,
1900
+ "count": len(bands_to_read),
1901
+ }
1902
+ )
1903
+
1904
+ # Read the data for the selected bands
1905
+ data = []
1906
+ for band_idx in bands_to_read:
1907
+ band_data = src.read(band_idx, window=window)
1908
+ data.append(band_data)
1909
+
1910
+ # Stack the bands into a single array
1911
+ if len(data) > 1:
1912
+ clipped_data = np.stack(data)
1913
+ else:
1914
+ clipped_data = data[0][np.newaxis, :, :]
1915
+
1916
+ # Write the output raster
1917
+ with rasterio.open(output_raster, "w", **out_meta) as dst:
1918
+ dst.write(clipped_data)
1919
+
1920
+ return output_raster
1921
+
1922
+
1923
+ def raster_to_vector(
1924
+ raster_path,
1925
+ output_path=None,
1926
+ threshold=0,
1927
+ min_area=10,
1928
+ simplify_tolerance=None,
1929
+ class_values=None,
1930
+ attribute_name="class",
1931
+ output_format="geojson",
1932
+ plot_result=False,
1933
+ ):
1934
+ """
1935
+ Convert a raster label mask to vector polygons.
1936
+
1937
+ Args:
1938
+ raster_path (str): Path to the input raster file (e.g., GeoTIFF).
1939
+ output_path (str): Path to save the output vector file. If None, returns GeoDataFrame without saving.
1940
+ threshold (int/float): Pixel values greater than this threshold will be vectorized.
1941
+ min_area (float): Minimum polygon area in square map units to keep.
1942
+ simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
1943
+ class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
1944
+ attribute_name (str): Name of the attribute field for the class values.
1945
+ output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
1946
+ plot_result (bool): Whether to plot the resulting polygons overlaid on the raster.
1947
+
1948
+ Returns:
1949
+ geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized polygons.
1950
+ """
1951
+ # Open the raster file
1952
+ with rasterio.open(raster_path) as src:
1953
+ # Read the data
1954
+ data = src.read(1)
1955
+
1956
+ # Get metadata
1957
+ transform = src.transform
1958
+ crs = src.crs
1959
+
1960
+ # Create mask based on threshold and class values
1961
+ if class_values is not None:
1962
+ # Create a mask for each specified class value
1963
+ masks = {val: (data == val) for val in class_values}
1964
+ else:
1965
+ # Create a mask for values above threshold
1966
+ masks = {1: (data > threshold)}
1967
+ class_values = [1] # Default class
1968
+
1969
+ # Initialize list to store features
1970
+ all_features = []
1971
+
1972
+ # Process each class value
1973
+ for class_val in class_values:
1974
+ mask = masks[class_val]
1975
+
1976
+ # Vectorize the mask
1977
+ for geom, value in features.shapes(
1978
+ mask.astype(np.uint8), mask=mask, transform=transform
1979
+ ):
1980
+ # Convert to shapely geometry
1981
+ geom = shape(geom)
1982
+
1983
+ # Skip small polygons
1984
+ if geom.area < min_area:
1985
+ continue
1986
+
1987
+ # Simplify geometry if requested
1988
+ if simplify_tolerance is not None:
1989
+ geom = geom.simplify(simplify_tolerance)
1990
+
1991
+ # Add to features list with class value
1992
+ all_features.append({"geometry": geom, attribute_name: class_val})
1993
+
1994
+ # Create GeoDataFrame
1995
+ if all_features:
1996
+ gdf = gpd.GeoDataFrame(all_features, crs=crs)
1997
+ else:
1998
+ print("Warning: No features were extracted from the raster.")
1999
+ # Return empty GeoDataFrame with correct CRS
2000
+ gdf = gpd.GeoDataFrame([], geometry=[], crs=crs)
2001
+
2002
+ # Save to file if requested
2003
+ if output_path is not None:
2004
+ # Create directory if it doesn't exist
2005
+ os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
2006
+
2007
+ # Save to file based on format
2008
+ if output_format.lower() == "geojson":
2009
+ gdf.to_file(output_path, driver="GeoJSON")
2010
+ elif output_format.lower() == "shapefile":
2011
+ gdf.to_file(output_path)
2012
+ elif output_format.lower() == "gpkg":
2013
+ gdf.to_file(output_path, driver="GPKG")
2014
+ else:
2015
+ raise ValueError(f"Unsupported output format: {output_format}")
2016
+
2017
+ print(f"Vectorized data saved to {output_path}")
2018
+
2019
+ # Plot result if requested
2020
+ if plot_result:
2021
+ fig, ax = plt.subplots(figsize=(12, 12))
2022
+
2023
+ # Plot raster
2024
+ raster_img = src.read()
2025
+ if raster_img.shape[0] == 1:
2026
+ plt.imshow(raster_img[0], cmap="viridis", alpha=0.7)
2027
+ else:
2028
+ # Use first 3 bands for RGB display
2029
+ rgb = raster_img[:3].transpose(1, 2, 0)
2030
+ # Normalize for display
2031
+ rgb = np.clip(rgb / rgb.max(), 0, 1)
2032
+ plt.imshow(rgb)
2033
+
2034
+ # Plot vector boundaries
2035
+ if not gdf.empty:
2036
+ gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)
2037
+
2038
+ plt.title("Raster with Vectorized Boundaries")
2039
+ plt.axis("off")
2040
+ plt.tight_layout()
2041
+ plt.show()
2042
+
2043
+ return gdf
2044
+
2045
+
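A hypothetical raster_to_vector call; min_area and simplify_tolerance are in map units of the raster's CRS:
>>> gdf = raster_to_vector("building_mask.tif", "buildings.geojson", min_area=20, simplify_tolerance=0.5)
>>> len(gdf)  # number of polygons kept after the area filter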
2046
+ def batch_raster_to_vector(
2047
+ input_dir,
2048
+ output_dir,
2049
+ pattern="*.tif",
2050
+ threshold=0,
2051
+ min_area=10,
2052
+ simplify_tolerance=None,
2053
+ class_values=None,
2054
+ attribute_name="class",
2055
+ output_format="geojson",
2056
+ merge_output=False,
2057
+ merge_filename="merged_vectors",
2058
+ ):
2059
+ """
2060
+ Batch convert multiple raster files to vector polygons.
2061
+
2062
+ Args:
2063
+ input_dir (str): Directory containing input raster files.
2064
+ output_dir (str): Directory to save output vector files.
2065
+ pattern (str): Pattern to match raster files (e.g., '*.tif').
2066
+ threshold (int/float): Pixel values greater than this threshold will be vectorized.
2067
+ min_area (float): Minimum polygon area in square map units to keep.
2068
+ simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
2069
+ class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
2070
+ attribute_name (str): Name of the attribute field for the class values.
2071
+ output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
2072
+ merge_output (bool): Whether to merge all output vectors into a single file.
2073
+ merge_filename (str): Filename for the merged output (without extension).
2074
+
2075
+ Returns:
2076
+ geopandas.GeoDataFrame or None: If merge_output is True, returns the merged GeoDataFrame.
2077
+ """
2078
+ import glob
2079
+
2080
+ # Create output directory if it doesn't exist
2081
+ os.makedirs(output_dir, exist_ok=True)
2082
+
2083
+ # Get list of raster files
2084
+ raster_files = glob.glob(os.path.join(input_dir, pattern))
2085
+
2086
+ if not raster_files:
2087
+ print(f"No files matching pattern '{pattern}' found in {input_dir}")
2088
+ return None
2089
+
2090
+ print(f"Found {len(raster_files)} raster files to process")
2091
+
2092
+ # Process each raster file
2093
+ gdfs = []
2094
+ for raster_file in tqdm(raster_files, desc="Processing rasters"):
2095
+ # Get output filename
2096
+ base_name = os.path.splitext(os.path.basename(raster_file))[0]
2097
+ if output_format.lower() == "geojson":
2098
+ out_file = os.path.join(output_dir, f"{base_name}.geojson")
2099
+ elif output_format.lower() == "shapefile":
2100
+ out_file = os.path.join(output_dir, f"{base_name}.shp")
2101
+ elif output_format.lower() == "gpkg":
2102
+ out_file = os.path.join(output_dir, f"{base_name}.gpkg")
2103
+ else:
2104
+ raise ValueError(f"Unsupported output format: {output_format}")
2105
+
2106
+ # Convert raster to vector
2107
+ if merge_output:
2108
+ # Don't save individual files if merging
2109
+ gdf = raster_to_vector(
2110
+ raster_file,
2111
+ output_path=None,
2112
+ threshold=threshold,
2113
+ min_area=min_area,
2114
+ simplify_tolerance=simplify_tolerance,
2115
+ class_values=class_values,
2116
+ attribute_name=attribute_name,
2117
+ )
2118
+
2119
+ # Add filename as attribute
2120
+ if not gdf.empty:
2121
+ gdf["source_file"] = base_name
2122
+ gdfs.append(gdf)
2123
+ else:
2124
+ # Save individual files
2125
+ raster_to_vector(
2126
+ raster_file,
2127
+ output_path=out_file,
2128
+ threshold=threshold,
2129
+ min_area=min_area,
2130
+ simplify_tolerance=simplify_tolerance,
2131
+ class_values=class_values,
2132
+ attribute_name=attribute_name,
2133
+ output_format=output_format,
2134
+ )
2135
+
2136
+ # Merge output if requested
2137
+ if merge_output and gdfs:
2138
+ merged_gdf = gpd.GeoDataFrame(pd.concat(gdfs, ignore_index=True))
2139
+
2140
+ # Set CRS to the CRS of the first GeoDataFrame
2141
+ if merged_gdf.crs is None and gdfs:
2142
+ merged_gdf.crs = gdfs[0].crs
2143
+
2144
+ # Save merged output
2145
+ if output_format.lower() == "geojson":
2146
+ merged_file = os.path.join(output_dir, f"{merge_filename}.geojson")
2147
+ merged_gdf.to_file(merged_file, driver="GeoJSON")
2148
+ elif output_format.lower() == "shapefile":
2149
+ merged_file = os.path.join(output_dir, f"{merge_filename}.shp")
2150
+ merged_gdf.to_file(merged_file)
2151
+ elif output_format.lower() == "gpkg":
2152
+ merged_file = os.path.join(output_dir, f"{merge_filename}.gpkg")
2153
+ merged_gdf.to_file(merged_file, driver="GPKG")
2154
+
2155
+ print(f"Merged vector data saved to {merged_file}")
2156
+ return merged_gdf
2157
+
2158
+ return None
2159
+
2160
+
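batch_raster_to_vector applies the same conversion to every matching file; a sketch with hypothetical directories:
>>> merged = batch_raster_to_vector("masks/", "vectors/", pattern="*.tif", merge_output=True)
With merge_output=True the per-tile results are concatenated (tagged by source_file) instead of being written as individual files.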
2161
+ def vector_to_raster(
2162
+ vector_path,
2163
+ output_path=None,
2164
+ reference_raster=None,
2165
+ attribute_field=None,
2166
+ output_shape=None,
2167
+ transform=None,
2168
+ pixel_size=None,
2169
+ bounds=None,
2170
+ crs=None,
2171
+ all_touched=False,
2172
+ fill_value=0,
2173
+ dtype=np.uint8,
2174
+ nodata=None,
2175
+ plot_result=False,
2176
+ ):
2177
+ """
2178
+ Convert vector data to a raster.
2179
+
2180
+ Args:
2181
+ vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
2182
+ output_path (str): Path to save the output raster file. If None, returns the array without saving.
2183
+ reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
2184
+ attribute_field (str): Field name in the vector data to use for pixel values.
2185
+ If None, all vector features will be burned with value 1.
2186
+ output_shape (tuple): Shape of the output raster as (height, width).
2187
+ Required if reference_raster is not provided.
2188
+ transform (affine.Affine): Affine transformation matrix.
2189
+ Required if reference_raster is not provided.
2190
+ pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
2191
+ Used to calculate transform if transform is not provided.
2192
+ bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
2193
+ Used to calculate transform if transform is not provided.
2194
+ crs (str or CRS): Coordinate reference system of the output raster.
2195
+ Required if reference_raster is not provided.
2196
+ all_touched (bool): If True, all pixels touched by geometries will be burned in.
2197
+ If False, only pixels whose center is within the geometry will be burned in.
2198
+ fill_value (int): Value to fill the raster with before burning in features.
2199
+ dtype (numpy.dtype): Data type of the output raster.
2200
+ nodata (int): No data value for the output raster.
2201
+ plot_result (bool): Whether to plot the resulting raster.
2202
+
2203
+ Returns:
2204
+ numpy.ndarray: The rasterized data array (also written to output_path when provided).
2205
+ """
2206
+ # Load vector data
2207
+ if isinstance(vector_path, gpd.GeoDataFrame):
2208
+ gdf = vector_path
2209
+ else:
2210
+ gdf = gpd.read_file(vector_path)
2211
+
2212
+ # Check if vector data is empty
2213
+ if gdf.empty:
2214
+ warnings.warn("The input vector data is empty. Creating an empty raster.")
2215
+
2216
+ # Get CRS from vector data if not provided
2217
+ if crs is None and reference_raster is None:
2218
+ crs = gdf.crs
2219
+
2220
+ # Get transform and output shape from reference raster if provided
2221
+ if reference_raster is not None:
2222
+ with rasterio.open(reference_raster) as src:
2223
+ transform = src.transform
2224
+ output_shape = src.shape
2225
+ crs = src.crs
2226
+ if nodata is None:
2227
+ nodata = src.nodata
2228
+ else:
2229
+ # Check if we have all required parameters
2230
+ if transform is None:
2231
+ if pixel_size is None or bounds is None:
2232
+ raise ValueError(
2233
+ "Either reference_raster, transform, or both pixel_size and bounds must be provided."
2234
+ )
2235
+
2236
+ # Calculate transform from pixel size and bounds
2237
+ if isinstance(pixel_size, (int, float)):
2238
+ x_res = y_res = float(pixel_size)
2239
+ else:
2240
+ x_res, y_res = pixel_size
2241
+ y_res = abs(y_res) * -1 # Convert to negative for north-up raster
2242
+
2243
+ left, bottom, right, top = bounds
2244
+ transform = rasterio.transform.from_bounds(
2245
+ left,
2246
+ bottom,
2247
+ right,
2248
+ top,
2249
+ int((right - left) / x_res),
2250
+ int((top - bottom) / abs(y_res)),
2251
+ )
2252
+
2253
+ if output_shape is None:
2254
+ # Calculate output shape from bounds and pixel size
2255
+ if bounds is None or pixel_size is None:
2256
+ raise ValueError(
2257
+ "output_shape must be provided if reference_raster is not provided and "
2258
+ "cannot be calculated from bounds and pixel_size."
2259
+ )
2260
+
2261
+ if isinstance(pixel_size, (int, float)):
2262
+ x_res = y_res = float(pixel_size)
2263
+ else:
2264
+ x_res, y_res = pixel_size
2265
+
2266
+ left, bottom, right, top = bounds
2267
+ width = int((right - left) / x_res)
2268
+ height = int((top - bottom) / abs(y_res))
2269
+ output_shape = (height, width)
2270
+
2271
+ # Ensure CRS is set
2272
+ if crs is None:
2273
+ raise ValueError(
2274
+ "CRS must be provided either directly, from reference_raster, or from input vector data."
2275
+ )
2276
+
2277
+ # Reproject vector data if its CRS doesn't match the output CRS
2278
+ if gdf.crs != crs:
2279
+ print(f"Reprojecting vector data from {gdf.crs} to {crs}")
2280
+ gdf = gdf.to_crs(crs)
2281
+
2282
+ # Create empty raster filled with fill_value
2283
+ raster_data = np.full(output_shape, fill_value, dtype=dtype)
2284
+
2285
+ # Burn vector features into raster
2286
+ if not gdf.empty:
2287
+ # Prepare shapes for burning
2288
+ if attribute_field is not None and attribute_field in gdf.columns:
2289
+ # Use attribute field for values
2290
+ shapes = [
2291
+ (geom, value) for geom, value in zip(gdf.geometry, gdf[attribute_field])
2292
+ ]
2293
+ else:
2294
+ # Burn with value 1
2295
+ shapes = [(geom, 1) for geom in gdf.geometry]
2296
+
2297
+ # Burn shapes into raster
2298
+ burned = features.rasterize(
2299
+ shapes=shapes,
2300
+ out_shape=output_shape,
2301
+ transform=transform,
2302
+ fill=fill_value,
2303
+ all_touched=all_touched,
2304
+ dtype=dtype,
2305
+ )
2306
+
2307
+ # Update raster data
2308
+ raster_data = burned
2309
+
2310
+ # Save raster if output path is provided
2311
+ if output_path is not None:
2312
+ # Create directory if it doesn't exist
2313
+ os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
2314
+
2315
+ # Define metadata
2316
+ metadata = {
2317
+ "driver": "GTiff",
2318
+ "height": output_shape[0],
2319
+ "width": output_shape[1],
2320
+ "count": 1,
2321
+ "dtype": raster_data.dtype,
2322
+ "crs": crs,
2323
+ "transform": transform,
2324
+ }
2325
+
2326
+ # Add nodata value if provided
2327
+ if nodata is not None:
2328
+ metadata["nodata"] = nodata
2329
+
2330
+ # Write raster
2331
+ with rasterio.open(output_path, "w", **metadata) as dst:
2332
+ dst.write(raster_data, 1)
2333
+
2334
+ print(f"Rasterized data saved to {output_path}")
2335
+
2336
+ # Plot result if requested
2337
+ if plot_result:
2338
+ fig, ax = plt.subplots(figsize=(10, 10))
2339
+
2340
+ # Plot raster
2341
+ im = ax.imshow(raster_data, cmap="viridis")
2342
+ plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")
2343
+
2344
+ # Plot vector boundaries for reference
2345
+ if output_path is not None:
2346
+ # Get the extent of the raster
2347
+ with rasterio.open(output_path) as src:
2348
+ bounds = src.bounds
2349
+ raster_bbox = box(*bounds)
2350
+ else:
2351
+ # Calculate extent from transform and shape
2352
+ height, width = output_shape
2353
+ left, top = transform * (0, 0)
2354
+ right, bottom = transform * (width, height)
2355
+ raster_bbox = box(left, bottom, right, top)
2356
+
2357
+ # Clip vector to raster extent for clarity in plot
2358
+ if not gdf.empty:
2359
+ gdf_clipped = gpd.clip(gdf, raster_bbox)
2360
+ if not gdf_clipped.empty:
2361
+ gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)
2362
+
2363
+ plt.title("Rasterized Vector Data")
2364
+ plt.tight_layout()
2365
+ plt.show()
2366
+
2367
+ return raster_data
2368
+
2369
+
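vector_to_raster is the inverse operation; the simplest sketch borrows shape, transform, and CRS from a reference raster (file names hypothetical):
>>> mask = vector_to_raster("buildings.geojson", "building_mask.tif", reference_raster="sample.tif")
>>> mask.shape  # matches the reference raster's (height, width)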
2370
+ def batch_vector_to_raster(
2371
+ vector_path,
2372
+ output_dir,
2373
+ attribute_field=None,
2374
+ reference_rasters=None,
2375
+ bounds_list=None,
2376
+ output_filename_pattern="{vector_name}_{index}",
2377
+ pixel_size=1.0,
2378
+ all_touched=False,
2379
+ fill_value=0,
2380
+ dtype=np.uint8,
2381
+ nodata=None,
2382
+ ):
2383
+ """
2384
+ Batch convert vector data to multiple rasters based on different extents or reference rasters.
2385
+
2386
+ Args:
2387
+ vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
2388
+ output_dir (str): Directory to save output raster files.
2389
+ attribute_field (str): Field name in the vector data to use for pixel values.
2390
+ reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
2391
+ bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
2392
+ output_filename_pattern (str): Pattern for output filenames.
2393
+ Can include {vector_name} and {index} placeholders.
2394
+ pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
2395
+ all_touched (bool): If True, all pixels touched by geometries will be burned in.
2396
+ fill_value (int): Value to fill the raster with before burning in features.
2397
+ dtype (numpy.dtype): Data type of the output raster.
2398
+ nodata (int): No data value for the output raster.
2399
+
2400
+ Returns:
2401
+ list: List of paths to the created raster files.
2402
+ """
2403
+ # Create output directory if it doesn't exist
2404
+ os.makedirs(output_dir, exist_ok=True)
2405
+
2406
+ # Load vector data if it's a path
2407
+ if isinstance(vector_path, str):
2408
+ gdf = gpd.read_file(vector_path)
2409
+ vector_name = os.path.splitext(os.path.basename(vector_path))[0]
2410
+ else:
2411
+ gdf = vector_path
2412
+ vector_name = "vector"
2413
+
2414
+ # Check input parameters
2415
+ if reference_rasters is None and bounds_list is None:
2416
+ raise ValueError("Either reference_rasters or bounds_list must be provided.")
2417
+
2418
+ # Use reference_rasters if provided, otherwise use bounds_list
2419
+ if reference_rasters is not None:
2420
+ sources = reference_rasters
2421
+ is_raster_reference = True
2422
+ else:
2423
+ sources = bounds_list
2424
+ is_raster_reference = False
2425
+
2426
+ # Create output filenames
2427
+ output_files = []
2428
+
2429
+ # Process each source (reference raster or bounds)
2430
+ for i, source in enumerate(tqdm(sources, desc="Processing")):
2431
+ # Generate output filename
2432
+ output_filename = output_filename_pattern.format(
2433
+ vector_name=vector_name, index=i
2434
+ )
2435
+ if not output_filename.endswith(".tif"):
2436
+ output_filename += ".tif"
2437
+ output_path = os.path.join(output_dir, output_filename)
2438
+
2439
+ if is_raster_reference:
2440
+ # Use reference raster
2441
+ vector_to_raster(
2442
+ vector_path=gdf,
2443
+ output_path=output_path,
2444
+ reference_raster=source,
2445
+ attribute_field=attribute_field,
2446
+ all_touched=all_touched,
2447
+ fill_value=fill_value,
2448
+ dtype=dtype,
2449
+ nodata=nodata,
2450
+ )
2451
+ else:
2452
+ # Use bounds
2453
+ vector_to_raster(
2454
+ vector_path=gdf,
2455
+ output_path=output_path,
2456
+ bounds=source,
2457
+ pixel_size=pixel_size,
2458
+ attribute_field=attribute_field,
2459
+ all_touched=all_touched,
2460
+ fill_value=fill_value,
2461
+ dtype=dtype,
2462
+ nodata=nodata,
2463
+ )
2464
+
2465
+ output_files.append(output_path)
2466
+
2467
+ return output_files
2468
+
2469
+
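A sketch of batch_vector_to_raster burning one vector layer into several tile extents (paths hypothetical):
>>> files = batch_vector_to_raster("buildings.geojson", "masks/", reference_rasters=["tile_0.tif", "tile_1.tif"])
>>> files  # ['masks/buildings_0.tif', 'masks/buildings_1.tif'] with the default filename pattern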
2470
+ def export_geotiff_tiles(
2471
+ in_raster,
2472
+ out_folder,
2473
+ in_class_data,
2474
+ tile_size=256,
2475
+ stride=128,
2476
+ class_value_field="class",
2477
+ buffer_radius=0,
2478
+ max_tiles=None,
2479
+ quiet=False,
2480
+ all_touched=True,
2481
+ create_overview=False,
2482
+ skip_empty_tiles=False,
2483
+ ):
2484
+ """
2485
+ Export georeferenced GeoTIFF tiles and labels from raster and classification data.
2486
+
2487
+ Args:
2488
+ in_raster (str): Path to input raster image
2489
+ out_folder (str): Path to output folder
2490
+ in_class_data (str): Path to classification data - can be vector file or raster
2491
+ tile_size (int): Size of tiles in pixels (square)
2492
+ stride (int): Step size between tiles
2493
+ class_value_field (str): Field containing class values (for vector data)
2494
+ buffer_radius (float): Buffer to add around features (in units of the CRS)
2495
+ max_tiles (int): Maximum number of tiles to process (None for all)
2496
+ quiet (bool): If True, suppress non-essential output
2497
+ all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
2498
+ create_overview (bool): Whether to create an overview image of all tiles
2499
+ skip_empty_tiles (bool): If True, skip tiles with no features
2500
+ """
2501
+ # Create output directories
2502
+ os.makedirs(out_folder, exist_ok=True)
2503
+ image_dir = os.path.join(out_folder, "images")
2504
+ os.makedirs(image_dir, exist_ok=True)
2505
+ label_dir = os.path.join(out_folder, "labels")
2506
+ os.makedirs(label_dir, exist_ok=True)
2507
+ ann_dir = os.path.join(out_folder, "annotations")
2508
+ os.makedirs(ann_dir, exist_ok=True)
2509
+
2510
+ # Determine if class data is raster or vector
2511
+ is_class_data_raster = False
2512
+ if isinstance(in_class_data, str):
2513
+ file_ext = Path(in_class_data).suffix.lower()
2514
+ # Common raster extensions
2515
+ if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
2516
+ try:
2517
+ with rasterio.open(in_class_data) as src:
2518
+ is_class_data_raster = True
2519
+ if not quiet:
2520
+ print(f"Detected in_class_data as raster: {in_class_data}")
2521
+ print(f"Raster CRS: {src.crs}")
2522
+ print(f"Raster dimensions: {src.width} x {src.height}")
2523
+ except Exception:
2524
+ is_class_data_raster = False
2525
+ if not quiet:
2526
+ print(f"Unable to open {in_class_data} as raster, trying as vector")
2527
+
2528
+ # Open the input raster
2529
+ with rasterio.open(in_raster) as src:
2530
+ if not quiet:
2531
+ print(f"\nRaster info for {in_raster}:")
2532
+ print(f" CRS: {src.crs}")
2533
+ print(f" Dimensions: {src.width} x {src.height}")
2534
+ print(f" Bounds: {src.bounds}")
2535
+
2536
+ # Calculate number of tiles
2537
+ num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
2538
+ num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
2539
+ total_tiles = num_tiles_x * num_tiles_y
2540
+
2541
+ if max_tiles is None:
2542
+ max_tiles = total_tiles
2543
+
2544
+ # Process classification data
2545
+ class_to_id = {}
2546
+
2547
+ if is_class_data_raster:
2548
+ # Load raster class data
2549
+ with rasterio.open(in_class_data) as class_src:
2550
+ # Check if raster CRS matches
2551
+ if class_src.crs != src.crs:
2552
+ warnings.warn(
2553
+ f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
2554
+ f"Results may be misaligned."
2555
+ )
2556
+
2557
+ # Get unique values from raster
2558
+ # Sample to avoid loading huge rasters
2559
+ sample_data = class_src.read(
2560
+ 1,
2561
+ out_shape=(
2562
+ 1,
2563
+ min(class_src.height, 1000),
2564
+ min(class_src.width, 1000),
2565
+ ),
2566
+ )
2567
+
2568
+ unique_classes = np.unique(sample_data)
2569
+ unique_classes = unique_classes[
2570
+ unique_classes > 0
2571
+ ] # Remove 0 as it's typically background
2572
+
2573
+ if not quiet:
2574
+ print(
2575
+ f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
2576
+ )
2577
+
2578
+ # Create class mapping
2579
+ class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
2580
+ else:
2581
+ # Load vector class data
2582
+ try:
2583
+ gdf = gpd.read_file(in_class_data)
2584
+ if not quiet:
2585
+ print(f"Loaded {len(gdf)} features from {in_class_data}")
2586
+ print(f"Vector CRS: {gdf.crs}")
2587
+
2588
+ # Always reproject to match raster CRS
2589
+ if gdf.crs != src.crs:
2590
+ if not quiet:
2591
+ print(f"Reprojecting features from {gdf.crs} to {src.crs}")
2592
+ gdf = gdf.to_crs(src.crs)
2593
+
2594
+ # Apply buffer if specified
2595
+ if buffer_radius > 0:
2596
+ gdf["geometry"] = gdf.buffer(buffer_radius)
2597
+ if not quiet:
2598
+ print(f"Applied buffer of {buffer_radius} units")
2599
+
2600
+ # Check if class_value_field exists
2601
+ if class_value_field in gdf.columns:
2602
+ unique_classes = gdf[class_value_field].unique()
2603
+ if not quiet:
2604
+ print(
2605
+ f"Found {len(unique_classes)} unique classes: {unique_classes}"
2606
+ )
2607
+ # Create class mapping
2608
+ class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
2609
+ else:
2610
+ if not quiet:
2611
+ print(
2612
+ f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
2613
+ )
2614
+ class_to_id = {1: 1} # Default mapping
2615
+ except Exception as e:
2616
+ raise ValueError(f"Error processing vector data: {e}")
2617
+
2618
+ # Create progress bar
2619
+ pbar = tqdm(
2620
+ total=min(total_tiles, max_tiles),
2621
+ desc="Generating tiles",
2622
+ bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
2623
+ )
2624
+
2625
+ # Track statistics for summary
2626
+ stats = {
2627
+ "total_tiles": 0,
2628
+ "tiles_with_features": 0,
2629
+ "feature_pixels": 0,
2630
+ "errors": 0,
2631
+ "tile_coordinates": [], # For overview image
2632
+ }
2633
+
2634
+ # Process tiles
2635
+ tile_index = 0
2636
+ for y in range(num_tiles_y):
2637
+ for x in range(num_tiles_x):
2638
+ if tile_index >= max_tiles:
2639
+ break
2640
+
2641
+ # Calculate window coordinates
2642
+ window_x = x * stride
2643
+ window_y = y * stride
2644
+
2645
+ # Adjust for edge cases
2646
+ if window_x + tile_size > src.width:
2647
+ window_x = src.width - tile_size
2648
+ if window_y + tile_size > src.height:
2649
+ window_y = src.height - tile_size
2650
+
2651
+ # Define window
2652
+ window = Window(window_x, window_y, tile_size, tile_size)
2653
+
2654
+ # Get window transform and bounds
2655
+ window_transform = src.window_transform(window)
2656
+
2657
+ # Calculate window bounds
2658
+ minx = window_transform[2] # Upper left x
2659
+ maxy = window_transform[5] # Upper left y
2660
+ maxx = minx + tile_size * window_transform[0] # Add width
2661
+ miny = maxy + tile_size * window_transform[4] # Add height
2662
+
2663
+ window_bounds = box(minx, miny, maxx, maxy)
2664
+ 
+             # Store tile coordinates for overview
+             if create_overview:
+                 stats["tile_coordinates"].append(
+                     {
+                         "index": tile_index,
+                         "x": window_x,
+                         "y": window_y,
+                         "bounds": [minx, miny, maxx, maxy],
+                         "has_features": False,
+                     }
+                 )
+ 
+             # Create label mask
+             label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
+             has_features = False
+ 
+             # Process classification data to create labels
+             if is_class_data_raster:
+                 # For raster class data
+                 with rasterio.open(in_class_data) as class_src:
+                     # Calculate window in class raster
+                     src_bounds = src.bounds
+                     class_bounds = class_src.bounds
+ 
+                     # Check if windows overlap
+                     if (
+                         src_bounds.left > class_bounds.right
+                         or src_bounds.right < class_bounds.left
+                         or src_bounds.bottom > class_bounds.top
+                         or src_bounds.top < class_bounds.bottom
+                     ):
+                         warnings.warn(
+                             "Class raster and input raster do not overlap."
+                         )
+                     else:
+                         # Get corresponding window in class raster
+                         window_class = rasterio.windows.from_bounds(
+                             minx, miny, maxx, maxy, class_src.transform
+                         )
+ 
+                         # Read label data
+                         try:
+                             label_data = class_src.read(
+                                 1,
+                                 window=window_class,
+                                 boundless=True,
+                                 out_shape=(tile_size, tile_size),
+                             )
+ 
+                             # Remap class values if needed
+                             if class_to_id:
+                                 remapped_data = np.zeros_like(label_data)
+                                 for orig_val, new_val in class_to_id.items():
+                                     remapped_data[label_data == orig_val] = new_val
+                                 label_mask = remapped_data
+                             else:
+                                 label_mask = label_data
+ 
+                             # Check if we have any features
+                             if np.any(label_mask > 0):
+                                 has_features = True
+ stats["feature_pixels"] += np.count_nonzero(
2727
+ label_mask
2728
+ )
2729
+                         except Exception as e:
+                             pbar.write(f"Error reading class raster window: {e}")
+                             stats["errors"] += 1
+             else:
+                 # For vector class data
+                 # Find features that intersect with window
+                 window_features = gdf[gdf.intersects(window_bounds)]
+ 
+                 if len(window_features) > 0:
+                     for idx, feature in window_features.iterrows():
+                         # Get class value
+                         if class_value_field in feature:
+                             class_val = feature[class_value_field]
+                             class_id = class_to_id.get(class_val, 1)
+                         else:
+                             class_id = 1
+ 
+                         # Get geometry in window coordinates
+                         geom = feature.geometry.intersection(window_bounds)
+                         if not geom.is_empty:
+                             try:
+                                 # Rasterize feature
+                                 feature_mask = features.rasterize(
+                                     [(geom, class_id)],
+                                     out_shape=(tile_size, tile_size),
+                                     transform=window_transform,
+                                     fill=0,
+                                     all_touched=all_touched,
+                                 )
+ 
+                                 # Add to label mask
+                                 label_mask = np.maximum(label_mask, feature_mask)
+ 
+                                 # Check if the feature was actually rasterized
+                                 if np.any(feature_mask):
+                                     has_features = True
+                                     if create_overview and tile_index < len(
+                                         stats["tile_coordinates"]
+                                     ):
+                                         stats["tile_coordinates"][tile_index][
+                                             "has_features"
+                                         ] = True
+                             except Exception as e:
+                                 pbar.write(f"Error rasterizing feature {idx}: {e}")
+                                 stats["errors"] += 1
+ 
+             # Skip tile if no features and skip_empty_tiles is True
+             if skip_empty_tiles and not has_features:
+                 pbar.update(1)
+                 tile_index += 1
+                 continue
+ 
+             # Read image data
+             image_data = src.read(window=window)
+ 
+             # Export image as GeoTIFF
+             image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")
+ 
+             # Create profile for image GeoTIFF
+             image_profile = src.profile.copy()
+             image_profile.update(
+                 {
+                     "height": tile_size,
+                     "width": tile_size,
+                     "count": image_data.shape[0],
+                     "transform": window_transform,
+                 }
+             )
+ 
+             # Save image as GeoTIFF
+             try:
+                 with rasterio.open(image_path, "w", **image_profile) as dst:
+                     dst.write(image_data)
+                 stats["total_tiles"] += 1
+             except Exception as e:
+                 pbar.write(f"ERROR saving image GeoTIFF: {e}")
+                 stats["errors"] += 1
+ 
+             # Create profile for label GeoTIFF
+             label_profile = {
+                 "driver": "GTiff",
+                 "height": tile_size,
+                 "width": tile_size,
+                 "count": 1,
+                 "dtype": "uint8",
+                 "crs": src.crs,
+                 "transform": window_transform,
+             }
+ 
+             # Export label as GeoTIFF
+             label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
+             try:
+                 with rasterio.open(label_path, "w", **label_profile) as dst:
+                     dst.write(label_mask.astype(np.uint8), 1)
+ 
+                 if has_features:
+                     stats["tiles_with_features"] += 1
+                     stats["feature_pixels"] += np.count_nonzero(label_mask)
+             except Exception as e:
+                 pbar.write(f"ERROR saving label GeoTIFF: {e}")
+                 stats["errors"] += 1
+ 
+             # Create XML annotation for object detection if using vector class data
+             if (
+                 not is_class_data_raster
+                 and "gdf" in locals()
+                 and len(window_features) > 0
+             ):
+                 # Create XML annotation
+                 root = ET.Element("annotation")
+                 ET.SubElement(root, "folder").text = "images"
+                 ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"
+ 
+                 size = ET.SubElement(root, "size")
+                 ET.SubElement(size, "width").text = str(tile_size)
+                 ET.SubElement(size, "height").text = str(tile_size)
+                 ET.SubElement(size, "depth").text = str(image_data.shape[0])
+ 
+                 # Add georeference information
+                 geo = ET.SubElement(root, "georeference")
+                 ET.SubElement(geo, "crs").text = str(src.crs)
+                 ET.SubElement(geo, "transform").text = str(
+                     window_transform
+                 ).replace("\n", "")
+                 ET.SubElement(geo, "bounds").text = (
+                     f"{minx}, {miny}, {maxx}, {maxy}"
+                 )
+ 
+                 # Add objects
+                 for idx, feature in window_features.iterrows():
+                     # Get feature class
+                     if class_value_field in feature:
+                         class_val = feature[class_value_field]
+                     else:
+                         class_val = "object"
+ 
+                     # Get geometry bounds in pixel coordinates
+                     geom = feature.geometry.intersection(window_bounds)
+                     if not geom.is_empty:
+                         # Get bounds in world coordinates
+                         minx_f, miny_f, maxx_f, maxy_f = geom.bounds
+ 
+                         # Convert to pixel coordinates
+                         col_min, row_min = ~window_transform * (minx_f, maxy_f)
+                         col_max, row_max = ~window_transform * (maxx_f, miny_f)
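+                         # ~window_transform is the inverse affine transform: it maps
+                         # world (x, y) to fractional pixel (col, row). maxy_f pairs
+                         # with row_min because the row scale (transform[4]) is negative.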
+ 
+                         # Ensure coordinates are within tile bounds
+                         xmin = max(0, min(tile_size, int(col_min)))
+                         ymin = max(0, min(tile_size, int(row_min)))
+                         xmax = max(0, min(tile_size, int(col_max)))
+                         ymax = max(0, min(tile_size, int(row_max)))
+ 
+                         # Only add if the box has non-zero area
+                         if xmax > xmin and ymax > ymin:
+                             obj = ET.SubElement(root, "object")
+                             ET.SubElement(obj, "name").text = str(class_val)
+                             ET.SubElement(obj, "difficult").text = "0"
+ 
+                             bbox = ET.SubElement(obj, "bndbox")
+                             ET.SubElement(bbox, "xmin").text = str(xmin)
+                             ET.SubElement(bbox, "ymin").text = str(ymin)
+                             ET.SubElement(bbox, "xmax").text = str(xmax)
+                             ET.SubElement(bbox, "ymax").text = str(ymax)
+ 
+                 # Save XML
+                 tree = ET.ElementTree(root)
+                 xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
+                 tree.write(xml_path)
+ 
+             # Update progress bar
+             pbar.update(1)
+             pbar.set_description(
+                 f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
+             )
+ 
+             tile_index += 1
+             if tile_index >= max_tiles:
+                 break
+ 
+         if tile_index >= max_tiles:
+             break
+ 
+     # Close progress bar
+     pbar.close()
+ 
+     # Create overview image if requested
+     if create_overview and stats["tile_coordinates"]:
+         try:
+             create_overview_image(
+                 src,
+                 stats["tile_coordinates"],
+                 os.path.join(out_folder, "overview.png"),
+                 tile_size,
+                 stride,
+             )
+         except Exception as e:
+             print(f"Failed to create overview image: {e}")
+ 
+     # Report results
+     if not quiet:
+         print("\n------- Export Summary -------")
+         print(f"Total tiles exported: {stats['total_tiles']}")
+         print(
+             f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
+         )
+         if stats["tiles_with_features"] > 0:
+             print(
+                 f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
+             )
+         if stats["errors"] > 0:
+             print(f"Errors encountered: {stats['errors']}")
+         print(f"Output saved to: {out_folder}")
+ 
+     # Verify georeference in a sample image and label
+     if stats["total_tiles"] > 0:
+         print("\n------- Georeference Verification -------")
+         sample_image = os.path.join(image_dir, "tile_000000.tif")
+         sample_label = os.path.join(label_dir, "tile_000000.tif")
+ 
+         if os.path.exists(sample_image):
+             try:
+                 with rasterio.open(sample_image) as img:
+                     print(f"Image CRS: {img.crs}")
+                     print(f"Image transform: {img.transform}")
+                     print(
+                         f"Image has georeference: {img.crs is not None and img.transform is not None}"
+                     )
+                     print(
+                         f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
+                     )
+             except Exception as e:
+                 print(f"Error verifying image georeference: {e}")
+ 
+         if os.path.exists(sample_label):
+             try:
+                 with rasterio.open(sample_label) as lbl:
+                     print(f"Label CRS: {lbl.crs}")
+                     print(f"Label transform: {lbl.transform}")
+                     print(
+                         f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
+                     )
+                     print(
+                         f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
+                     )
+             except Exception as e:
+                 print(f"Error verifying label georeference: {e}")
+ 
+     # Return statistics dictionary for further processing if needed
+     return stats
+ 
+ 
+ def create_overview_image(
+     src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
+ ):
+     """Create an overview image showing all tiles and their status, with optional GeoJSON export.
+ 
+     Args:
+         src (rasterio.io.DatasetReader): The source raster dataset.
+         tile_coordinates (list): A list of dictionaries containing tile information.
+         output_path (str): The path where the overview image will be saved.
+         tile_size (int): The size of each tile in pixels.
+         stride (int): The stride between tiles in pixels. Controls overlap between adjacent tiles.
+         geojson_path (str, optional): If provided, exports the tile rectangles as GeoJSON to this path.
+ 
+     Returns:
+         str: Path to the saved overview image.
+     """
+     # Read a reduced version of the source image
+     overview_scale = max(
+         1, int(max(src.width, src.height) / 2000)
+     )  # Scale to max ~2000px
+     overview_width = src.width // overview_scale
+     overview_height = src.height // overview_scale
+ 
+     # Read downsampled image
+     overview_data = src.read(
+         out_shape=(src.count, overview_height, overview_width),
+         resampling=rasterio.enums.Resampling.average,
+     )
+ 
+     # Create RGB image for display (as float32 so the percentile stretch
+     # below is not truncated by integer dtypes)
+     if overview_data.shape[0] >= 3:
+         rgb = np.moveaxis(overview_data[:3], 0, -1).astype(np.float32)
+     else:
+         # For single band, create grayscale RGB
+         rgb = np.stack(
+             [overview_data[0], overview_data[0], overview_data[0]], axis=-1
+         ).astype(np.float32)
+ 
+     # Normalize for display
+     for i in range(rgb.shape[-1]):
+         band = rgb[..., i]
+         non_zero = band[band > 0]
+         if len(non_zero) > 0:
+             p2, p98 = np.percentile(non_zero, (2, 98))
+             if p98 > p2:
+                 rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
+ 
+     # Create figure
+     plt.figure(figsize=(12, 12))
+     plt.imshow(rgb)
+ 
+     # If GeoJSON export is requested, prepare GeoJSON structures
+     if geojson_path:
+         features = []
+ 
+     # Draw tile boundaries
+     for tile in tile_coordinates:
+         # Convert bounds to pixel coordinates in overview
+         bounds = tile["bounds"]
+         # Calculate scaled pixel coordinates
+         x_min = int((tile["x"]) / overview_scale)
+         y_min = int((tile["y"]) / overview_scale)
+         width = int(tile_size / overview_scale)
+         height = int(tile_size / overview_scale)
+ 
+         # Draw rectangle
+         color = "lime" if tile["has_features"] else "red"
+         rect = plt.Rectangle(
+             (x_min, y_min), width, height, fill=False, edgecolor=color, linewidth=0.5
+         )
+         plt.gca().add_patch(rect)
+ 
+         # Add tile number if not too crowded
+         if width > 20 and height > 20:
+             plt.text(
+                 x_min + width / 2,
+                 y_min + height / 2,
+                 str(tile["index"]),
+                 color="white",
+                 ha="center",
+                 va="center",
+                 fontsize=8,
+             )
+ 
+         # Add to GeoJSON features if exporting
+         if geojson_path:
+             # Create a polygon from the bounds (already in geo-coordinates)
+             minx, miny, maxx, maxy = bounds
+             polygon = box(minx, miny, maxx, maxy)
+ 
+             # Calculate overlap with neighboring tiles
+             overlap = 0
+             if stride < tile_size:
+                 overlap = tile_size - stride
+ 
+             # Create a GeoJSON feature
+             feature = {
+                 "type": "Feature",
+                 "geometry": mapping(polygon),
+                 "properties": {
+                     "index": tile["index"],
+                     "has_features": tile["has_features"],
+                     "bounds_pixel": [
+                         tile["x"],
+                         tile["y"],
+                         tile["x"] + tile_size,
+                         tile["y"] + tile_size,
+                     ],
+                     "tile_size_px": tile_size,
+                     "stride_px": stride,
+                     "overlap_px": overlap,
+                 },
+             }
+ 
+             # Add any additional properties from the tile
+             for key, value in tile.items():
+                 if key not in ["x", "y", "index", "has_features", "bounds"]:
+                     feature["properties"][key] = value
+ 
+             features.append(feature)
+ 
+     plt.title("Tile Overview (Green = Contains Features, Red = Empty)")
+     plt.axis("off")
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight")
+     plt.close()
+ 
+     print(f"Overview image saved to {output_path}")
+ 
+     # Export GeoJSON if requested
+     if geojson_path:
+         geojson_collection = {
+             "type": "FeatureCollection",
+             "features": features,
+             "properties": {
+                 "crs": (
+                     src.crs.to_string()
+                     if hasattr(src.crs, "to_string")
+                     else str(src.crs)
+                 ),
+                 "total_tiles": len(features),
+                 "source_raster_dimensions": [src.width, src.height],
+             },
+         }
+ 
+         # Save to file
+         with open(geojson_path, "w") as f:
+             json.dump(geojson_collection, f)
+ 
+         print(f"GeoJSON saved to {geojson_path}")
+ 
+     return output_path
+ 
+ 
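+ # Illustrative usage of create_overview_image (hypothetical paths; `stats` is
+ # the dictionary returned by the tiling routine above):
+ #
+ #     with rasterio.open("input.tif") as src:
+ #         create_overview_image(
+ #             src, stats["tile_coordinates"], "overview.png", tile_size=512, stride=256
+ #         )
+ 
+ 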
+ def export_tiles_to_geojson(
+     tile_coordinates, src, output_path, tile_size=None, stride=None
+ ):
+     """
+     Export tile rectangles directly to GeoJSON without creating an overview image.
+ 
+     Args:
+         tile_coordinates (list): A list of dictionaries containing tile information.
+         src (rasterio.io.DatasetReader): The source raster dataset.
+         output_path (str): The path where the GeoJSON will be saved.
+         tile_size (int, optional): The size of each tile in pixels. Only needed if not in tile_coordinates.
+         stride (int, optional): The stride between tiles in pixels. Used to calculate overlaps between tiles.
+ 
+     Returns:
+         str: Path to the saved GeoJSON file.
+     """
+     features = []
+ 
+     for tile in tile_coordinates:
+         # Get the size from the tile or use the provided parameter
+         tile_width = tile.get("width", tile.get("size", tile_size))
+         tile_height = tile.get("height", tile.get("size", tile_size))
+ 
+         if tile_width is None or tile_height is None:
+             raise ValueError(
+                 "Tile size not found in tile data and no tile_size parameter provided"
+             )
+ 
+         # Get bounds from the tile
+         if "bounds" in tile:
+             # If bounds are already in geo coordinates
+             minx, miny, maxx, maxy = tile["bounds"]
+         else:
+             # Try to calculate bounds from transform if available
+             if hasattr(src, "transform"):
+                 # Convert pixel coordinates to geo coordinates
+                 window_transform = src.transform
+                 x, y = tile["x"], tile["y"]
+                 minx = window_transform[2] + x * window_transform[0]
+                 maxy = window_transform[5] + y * window_transform[4]
+                 maxx = minx + tile_width * window_transform[0]
+                 miny = maxy + tile_height * window_transform[4]
+             else:
+                 raise ValueError(
+                     "Cannot determine bounds. Neither 'bounds' in tile nor transform in src."
+                 )
+ 
+         # Calculate overlap with neighboring tiles if stride is provided
+         overlap = 0
+         if stride is not None and stride < tile_width:
+             overlap = tile_width - stride
+ 
+         # Create a polygon from the bounds
+         polygon = box(minx, miny, maxx, maxy)
+ 
+         # Create a GeoJSON feature
+         feature = {
+             "type": "Feature",
+             "geometry": mapping(polygon),
+             "properties": {
+                 "index": tile["index"],
+                 "has_features": tile.get("has_features", False),
+                 "tile_width_px": tile_width,
+                 "tile_height_px": tile_height,
+             },
+         }
+ 
+         # Add overlap information if stride is provided
+         if stride is not None:
+             feature["properties"]["stride_px"] = stride
+             feature["properties"]["overlap_px"] = overlap
+ 
+         # Add additional properties from the tile
+         for key, value in tile.items():
+             if key not in ["bounds", "geometry"]:
+                 feature["properties"][key] = value
+ 
+         features.append(feature)
+ 
+     # Create the GeoJSON collection
+     geojson_collection = {
+         "type": "FeatureCollection",
+         "features": features,
+         "properties": {
+             "crs": (
+                 src.crs.to_string() if hasattr(src.crs, "to_string") else str(src.crs)
+             ),
+             "total_tiles": len(features),
+             "source_raster_dimensions": (
+                 [src.width, src.height] if hasattr(src, "width") else None
+             ),
+         },
+     }
+ 
+     # Create directory if it doesn't exist
+     os.makedirs(os.path.dirname(os.path.abspath(output_path)) or ".", exist_ok=True)
+ 
+     # Save to file
+     with open(output_path, "w") as f:
+         json.dump(geojson_collection, f)
+ 
+     print(f"GeoJSON saved to {output_path}")
+     return output_path
+ 
+ 
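+ # Illustrative usage of export_tiles_to_geojson (hypothetical paths):
+ #
+ #     with rasterio.open("input.tif") as src:
+ #         export_tiles_to_geojson(
+ #             stats["tile_coordinates"], src, "tiles.geojson", tile_size=512, stride=256
+ #         )
+ 
+ 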
+ def export_training_data(
+     in_raster,
+     out_folder,
+     in_class_data,
+     image_chip_format="GEOTIFF",
+     tile_size_x=256,
+     tile_size_y=256,
+     stride_x=None,
+     stride_y=None,
+     output_nofeature_tiles=True,
+     metadata_format="PASCAL_VOC",
+     start_index=0,
+     class_value_field="class",
+     buffer_radius=0,
+     in_mask_polygons=None,
+     rotation_angle=0,
+     reference_system=None,
+     blacken_around_feature=False,
+     crop_mode="FIXED_SIZE",  # Implemented but not fully used yet
+     in_raster2=None,
+     in_instance_data=None,
+     instance_class_value_field=None,  # Implemented but not fully used yet
+     min_polygon_overlap_ratio=0.0,
+     all_touched=True,
+     save_geotiff=True,
+     quiet=False,
+ ):
+     """
+     Export training data for deep learning using TorchGeo with progress bar.
+ 
+     Args:
+         in_raster (str): Path to input raster image.
+         out_folder (str): Output folder path where chips and labels will be saved.
+         in_class_data (str): Path to vector file containing class polygons.
+         image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
+         tile_size_x (int): Width of image chips in pixels.
+         tile_size_y (int): Height of image chips in pixels.
+         stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
+         stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
+         output_nofeature_tiles (bool): Whether to export chips without features.
+         metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
+         start_index (int): Starting index for chip filenames.
+         class_value_field (str): Field name in in_class_data containing class values.
+         buffer_radius (float): Buffer radius around features (in CRS units).
+         in_mask_polygons (str): Path to vector file containing mask polygons.
+         rotation_angle (float): Rotation angle in degrees.
+         reference_system (str): Reference system code.
+         blacken_around_feature (bool): Whether to mask areas outside of features.
+         crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
+         in_raster2 (str): Path to secondary raster image.
+         in_instance_data (str): Path to vector file containing instance polygons.
+         instance_class_value_field (str): Field name in in_instance_data for instance classes.
+         min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
+         all_touched (bool): Whether to use all_touched=True in rasterization.
+         save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
+         quiet (bool): If True, suppress most output messages.
+     """
+     # Create output directories
+     image_dir = os.path.join(out_folder, "images")
+     os.makedirs(image_dir, exist_ok=True)
+ 
+     label_dir = os.path.join(out_folder, "labels")
+     os.makedirs(label_dir, exist_ok=True)
+ 
+     # Define annotation directories based on metadata format
+     if metadata_format == "PASCAL_VOC":
+         ann_dir = os.path.join(out_folder, "annotations")
+         os.makedirs(ann_dir, exist_ok=True)
+     elif metadata_format == "COCO":
+         ann_dir = os.path.join(out_folder, "annotations")
+         os.makedirs(ann_dir, exist_ok=True)
+         # Initialize COCO annotations dictionary
+         coco_annotations = {"images": [], "annotations": [], "categories": []}
+ 
+     # Initialize statistics dictionary
+     stats = {
+         "total_tiles": 0,
+         "tiles_with_features": 0,
+         "feature_pixels": 0,
+         "errors": 0,
+     }
+ 
+     # Open raster
+     with rasterio.open(in_raster) as src:
+         if not quiet:
+             print(f"\nRaster info for {in_raster}:")
+             print(f" CRS: {src.crs}")
+             print(f" Dimensions: {src.width} x {src.height}")
+             print(f" Bounds: {src.bounds}")
+ 
+         # Set defaults for stride if not provided
+         if stride_x is None:
+             stride_x = tile_size_x
+         if stride_y is None:
+             stride_y = tile_size_y
+ 
+         # Calculate number of tiles in x and y directions
+         num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
+         num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
+         total_tiles = num_tiles_x * num_tiles_y
+ 
+         # Read class data
+         gdf = gpd.read_file(in_class_data)
+         if not quiet:
+             print(f"Loaded {len(gdf)} features from {in_class_data}")
+             print(f"Available columns: {gdf.columns.tolist()}")
+             print(f"GeoJSON CRS: {gdf.crs}")
+ 
+         # Check if class_value_field exists
+         if class_value_field not in gdf.columns:
+             if not quiet:
+                 print(
+                     f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
+                 )
+             # Add a default class column
+             gdf[class_value_field] = 1
+             unique_classes = [1]
+         else:
+             # Print unique classes for debugging
+             unique_classes = gdf[class_value_field].unique()
+             if not quiet:
+                 print(f"Found {len(unique_classes)} unique classes: {unique_classes}")
+ 
+         # CRITICAL: Always reproject to match raster CRS to ensure proper alignment
+         if gdf.crs != src.crs:
+             if not quiet:
+                 print(f"Reprojecting features from {gdf.crs} to {src.crs}")
+             gdf = gdf.to_crs(src.crs)
+         elif reference_system and gdf.crs != reference_system:
+             if not quiet:
+                 print(
+                     f"Reprojecting features to specified reference system {reference_system}"
+                 )
+             gdf = gdf.to_crs(reference_system)
+ 
+         # Check overlap between raster and vector data
+         raster_bounds = box(*src.bounds)
+         vector_bounds = box(*gdf.total_bounds)
+         if not raster_bounds.intersects(vector_bounds):
+             if not quiet:
+                 print(
+                     "WARNING: The vector data doesn't intersect with the raster extent!"
+                 )
+                 print(f"Raster bounds: {src.bounds}")
+                 print(f"Vector bounds: {gdf.total_bounds}")
+         else:
+             overlap = (
+                 raster_bounds.intersection(vector_bounds).area / vector_bounds.area
+             )
+             if not quiet:
+                 print(f"Overlap between raster and vector: {overlap:.2%}")
+ 
+         # Apply buffer if specified
+         if buffer_radius > 0:
+             gdf["geometry"] = gdf.buffer(buffer_radius)
+ 
+         # Initialize class mapping (ensure all classes are mapped to non-zero values)
+         class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
+ 
+         # Store category info for COCO format
+         if metadata_format == "COCO":
+             for cls_val in unique_classes:
+                 coco_annotations["categories"].append(
+                     {
+                         "id": class_to_id[cls_val],
+                         "name": str(cls_val),
+                         "supercategory": "object",
+                     }
+                 )
+ 
+         # Load mask polygons if provided
+         mask_gdf = None
+         if in_mask_polygons:
+             mask_gdf = gpd.read_file(in_mask_polygons)
+             if reference_system:
+                 mask_gdf = mask_gdf.to_crs(reference_system)
+             elif mask_gdf.crs != src.crs:
+                 mask_gdf = mask_gdf.to_crs(src.crs)
+ 
+         # Process instance data if provided
+         instance_gdf = None
+         if in_instance_data:
+             instance_gdf = gpd.read_file(in_instance_data)
+             if reference_system:
+                 instance_gdf = instance_gdf.to_crs(reference_system)
+             elif instance_gdf.crs != src.crs:
+                 instance_gdf = instance_gdf.to_crs(src.crs)
+ 
+         # Load secondary raster if provided
+         src2 = None
+         if in_raster2:
+             src2 = rasterio.open(in_raster2)
+ 
+         # Set up augmentation if rotation is specified
+         augmentation = None
+         if rotation_angle != 0:
+             # Fixed: Added data_keys parameter to AugmentationSequential
+             augmentation = torchgeo.transforms.AugmentationSequential(
+                 torch.nn.ModuleList([RandomRotation(rotation_angle)]),
+                 data_keys=["image"],  # Add data_keys parameter
+             )
+ 
+         # Initialize annotation ID for COCO format
+         ann_id = 0
+ 
+         # Create progress bar
+         pbar = tqdm(
+             total=total_tiles,
+ desc=f"Generating tiles (with features: 0)",
3445
+             bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
+         )
+ 
+         # Generate tiles
+         chip_index = start_index
+         for y in range(num_tiles_y):
+             for x in range(num_tiles_x):
+                 # Calculate window coordinates
+                 window_x = x * stride_x
+                 window_y = y * stride_y
+ 
+                 # Adjust for edge cases
+                 if window_x + tile_size_x > src.width:
+                     window_x = src.width - tile_size_x
+                 if window_y + tile_size_y > src.height:
+                     window_y = src.height - tile_size_y
+ 
+                 # Adjust window based on crop_mode
+                 if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
+                     # Find the nearest feature to the center of this window
+                     window_center_x = window_x + tile_size_x // 2
+                     window_center_y = window_y + tile_size_y // 2
+ 
+                     # Convert center to world coordinates
+                     center_x, center_y = src.xy(window_center_y, window_center_x)
+                     center_point = gpd.points_from_xy([center_x], [center_y])[0]
+ 
+                     # Find nearest feature
+                     distances = gdf.geometry.distance(center_point)
+                     nearest_idx = distances.idxmin()
+                     nearest_feature = gdf.iloc[nearest_idx]
+ 
+                     # Get centroid of nearest feature
+                     feature_centroid = nearest_feature.geometry.centroid
+ 
+                     # Convert feature centroid to pixel coordinates
+                     feature_row, feature_col = src.index(
+                         feature_centroid.x, feature_centroid.y
+                     )
+ 
+                     # Adjust window to center on feature
+                     window_x = max(
+                         0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
+                     )
+                     window_y = max(
+                         0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
+                     )
+ 
+                 # Define window
+                 window = Window(window_x, window_y, tile_size_x, tile_size_y)
+ 
+                 # Get window transform and bounds in source CRS
+                 window_transform = src.window_transform(window)
+ 
+                 # Calculate window bounds more explicitly and accurately
+                 minx = window_transform[2]  # Upper left x
+                 maxy = window_transform[5]  # Upper left y
+                 maxx = minx + tile_size_x * window_transform[0]  # Add width
+                 miny = (
+                     maxy + tile_size_y * window_transform[4]
+                 )  # Add height (note: transform[4] is typically negative)
+ 
+                 window_bounds = box(minx, miny, maxx, maxy)
+ 
+                 # Apply rotation if specified
+                 if rotation_angle != 0:
+                     window_bounds = rotate(
+                         window_bounds, rotation_angle, origin="center"
+                     )
+ 
+                 # Find features that intersect with window
+                 window_features = gdf[gdf.intersects(window_bounds)]
+ 
+                 # Process instance data if provided
+                 window_instances = None
+                 if instance_gdf is not None and instance_class_value_field is not None:
+                     window_instances = instance_gdf[
+                         instance_gdf.intersects(window_bounds)
+                     ]
+                     if len(window_instances) > 0:
+                         if not quiet:
+                             pbar.write(
+                                 f"Found {len(window_instances)} instances in tile {chip_index}"
+                             )
+ 
+                 # Skip if no features and output_nofeature_tiles is False
+                 if not output_nofeature_tiles and len(window_features) == 0:
+                     pbar.update(1)  # Still update progress bar
+                     continue
+ 
+                 # Check polygon overlap ratio if specified
+                 if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
+                     valid_features = []
+                     for _, feature in window_features.iterrows():
+                         overlap_ratio = (
+                             feature.geometry.intersection(window_bounds).area
+                             / feature.geometry.area
+                         )
+                         if overlap_ratio >= min_polygon_overlap_ratio:
+                             valid_features.append(feature)
+ 
+                     if len(valid_features) > 0:
+                         window_features = gpd.GeoDataFrame(valid_features, crs=gdf.crs)
+                     elif not output_nofeature_tiles:
+                         pbar.update(1)  # Still update progress bar
+                         continue
+ 
+                 # Apply mask if provided
+                 if mask_gdf is not None:
+                     mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
+                     if len(mask_features) == 0:
+                         pbar.update(1)  # Still update progress bar
+                         continue
+ 
+                 # Read image data - keep original for GeoTIFF export
+                 orig_image_data = src.read(window=window)
+ 
+                 # Create a copy for processing
+                 image_data = orig_image_data.copy().astype(np.float32)
+ 
+                 # Normalize image data for processing
+                 for band in range(image_data.shape[0]):
+                     band_min, band_max = np.percentile(image_data[band], (1, 99))
+                     if band_max > band_min:
+                         image_data[band] = np.clip(
+                             (image_data[band] - band_min) / (band_max - band_min), 0, 1
+                         )
+ 
+                 # Read secondary image data if provided
+                 if src2:
+                     image_data2 = src2.read(window=window)
+                     # Stack the two images
+                     image_data = np.vstack((image_data, image_data2))
+ 
+                 # Apply blacken_around_feature if needed
+                 if blacken_around_feature and len(window_features) > 0:
+                     mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
+                     for _, feature in window_features.iterrows():
+                         # Project feature to pixel coordinates
+                         feature_pixels = features.rasterize(
+                             [(feature.geometry, 1)],
+                             out_shape=(tile_size_y, tile_size_x),
+                             transform=window_transform,
+                         )
+                         mask = np.logical_or(mask, feature_pixels.astype(bool))
+ 
+                     # Apply mask to image
+                     for band in range(image_data.shape[0]):
+                         temp = image_data[band, :, :]
+                         temp[~mask] = 0
+                         image_data[band, :, :] = temp
+ 
+                 # Apply rotation if specified
+                 if augmentation:
+                     # Convert to torch tensor for augmentation
+                     image_tensor = torch.from_numpy(image_data).unsqueeze(
+                         0
+                     )  # Add batch dimension
+                     # Apply augmentation with proper data format
+                     augmented = augmentation({"image": image_tensor})
+                     image_data = (
+                         augmented["image"].squeeze(0).numpy()
+                     )  # Remove batch dimension
+ 
+                 # Create a processed version for regular image formats
+                 processed_image = (image_data * 255).astype(np.uint8)
+ 
+                 # Create label mask
+                 label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
+                 has_features = False
+ 
+                 if len(window_features) > 0:
+                     for idx, feature in window_features.iterrows():
+                         # Get class value
+                         class_val = (
+                             feature[class_value_field]
+                             if class_value_field in feature
+                             else 1
+                         )
+                         if isinstance(class_val, str):
+                             # If class is a string, use its position in the unique classes list
+                             class_id = class_to_id.get(class_val, 1)
+                         else:
+                             # If class is already a number, use it directly
+                             class_id = int(class_val) if class_val > 0 else 1
+ 
+                         # Get the geometry in pixel coordinates
+                         geom = feature.geometry.intersection(window_bounds)
+                         if not geom.is_empty:
+                             try:
+                                 # Rasterize the feature
+                                 feature_mask = features.rasterize(
+                                     [(geom, class_id)],
+                                     out_shape=(tile_size_y, tile_size_x),
+                                     transform=window_transform,
+                                     fill=0,
+                                     all_touched=all_touched,
+                                 )
+ 
+                                 # Update mask with higher class values taking precedence
+                                 label_mask = np.maximum(label_mask, feature_mask)
+ 
+                                 # Check if any pixels were added
+                                 if np.any(feature_mask):
+                                     has_features = True
+                             except Exception as e:
+                                 if not quiet:
+                                     pbar.write(f"Error rasterizing feature {idx}: {e}")
+                                 stats["errors"] += 1
+ 
+                 # Save as GeoTIFF if requested
+                 if save_geotiff or image_chip_format.upper() in [
+                     "TIFF",
+                     "TIF",
+                     "GEOTIFF",
+                 ]:
+                     # Standardize extension to .tif for GeoTIFF files
+                     image_filename = f"tile_{chip_index:06d}.tif"
+                     image_path = os.path.join(image_dir, image_filename)
+ 
+                     # Create profile for the GeoTIFF
+                     profile = src.profile.copy()
+                     profile.update(
+                         {
+                             "height": tile_size_y,
+                             "width": tile_size_x,
+                             "count": orig_image_data.shape[0],
+                             "transform": window_transform,
+                         }
+                     )
+ 
+                     # Save the GeoTIFF with original data
+                     try:
+                         with rasterio.open(image_path, "w", **profile) as dst:
+                             dst.write(orig_image_data)
+                         stats["total_tiles"] += 1
+                     except Exception as e:
+                         if not quiet:
+                             pbar.write(
+                                 f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
+                             )
+                         stats["errors"] += 1
+                 else:
+                     # For non-GeoTIFF formats, use PIL to save the image
+                     image_filename = (
+                         f"tile_{chip_index:06d}.{image_chip_format.lower()}"
+                     )
+                     image_path = os.path.join(image_dir, image_filename)
+ 
+                     # Create PIL image for saving
+                     if processed_image.shape[0] == 1:
+                         img = Image.fromarray(processed_image[0])
+                     elif processed_image.shape[0] == 3:
+                         # For RGB, need to transpose and make sure it's the right data type
+                         rgb_data = np.transpose(processed_image, (1, 2, 0))
+                         img = Image.fromarray(rgb_data)
+                     else:
+                         # For multiband images, save only RGB or first three bands
+                         rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
+                         img = Image.fromarray(rgb_data)
+ 
+                     # Save image
+                     try:
+                         img.save(image_path)
+                         stats["total_tiles"] += 1
+                     except Exception as e:
+                         if not quiet:
+                             pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
+                         stats["errors"] += 1
+ 
+                 # Save label as GeoTIFF
+                 label_filename = f"tile_{chip_index:06d}.tif"
+                 label_path = os.path.join(label_dir, label_filename)
+ 
+                 # Create profile for label GeoTIFF
+                 label_profile = {
+                     "driver": "GTiff",
+                     "height": tile_size_y,
+                     "width": tile_size_x,
+                     "count": 1,
+                     "dtype": "uint8",
+                     "crs": src.crs,
+                     "transform": window_transform,
+                 }
+ 
+                 # Save label GeoTIFF
+                 try:
+                     with rasterio.open(label_path, "w", **label_profile) as dst:
+                         dst.write(label_mask, 1)
+ 
+                     if has_features:
+                         pixel_count = np.count_nonzero(label_mask)
+                         stats["tiles_with_features"] += 1
+                         stats["feature_pixels"] += pixel_count
+                 except Exception as e:
+                     if not quiet:
+                         pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
+                     stats["errors"] += 1
+ 
+                 # Also save a PNG version for easy visualization if requested
+                 if metadata_format == "PASCAL_VOC":
+                     # Compute the PNG path up front so it is also available to
+                     # the fallback handler below
+                     label_png_path = os.path.join(
+                         label_dir, f"tile_{chip_index:06d}.png"
+                     )
+                     try:
+                         # Ensure correct data type for PIL
+                         png_label = label_mask.astype(np.uint8)
+                         label_img = Image.fromarray(png_label)
+                         label_img.save(label_png_path)
+                     except Exception as e:
+                         if not quiet:
+                             pbar.write(
+                                 f"ERROR saving PNG label for tile {chip_index}: {e}"
+                             )
+                             pbar.write(
+                                 f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
+                             )
+                         # Try again with explicit conversion
+                         try:
+                             # Alternative approach for problematic arrays
+                             png_data = np.zeros(
+                                 (tile_size_y, tile_size_x), dtype=np.uint8
+                             )
+                             np.copyto(png_data, label_mask, casting="unsafe")
+                             label_img = Image.fromarray(png_data)
+                             label_img.save(label_png_path)
+                             pbar.write(
+                                 " Succeeded using alternative conversion method"
+                             )
+                         except Exception as e2:
+                             pbar.write(f" Second attempt also failed: {e2}")
+                             stats["errors"] += 1
+ 
+                 # Generate annotations
+                 if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
+                     # Create XML annotation
+                     root = ET.Element("annotation")
+                     ET.SubElement(root, "folder").text = "images"
+                     ET.SubElement(root, "filename").text = image_filename
+ 
+                     size = ET.SubElement(root, "size")
+                     ET.SubElement(size, "width").text = str(tile_size_x)
+                     ET.SubElement(size, "height").text = str(tile_size_y)
+                     ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))
+ 
+                     # Add georeference information
+                     geo = ET.SubElement(root, "georeference")
+                     ET.SubElement(geo, "crs").text = str(src.crs)
+                     ET.SubElement(geo, "transform").text = str(
+                         window_transform
+                     ).replace("\n", "")
+                     ET.SubElement(geo, "bounds").text = (
+                         f"{minx}, {miny}, {maxx}, {maxy}"
+                     )
+ 
+                     for _, feature in window_features.iterrows():
+                         # Convert feature geometry to pixel coordinates
+                         feature_bounds = feature.geometry.intersection(window_bounds)
+                         if feature_bounds.is_empty:
+                             continue
+ 
+                         # Get pixel coordinates of bounds
+                         minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
+ 
+                         # Convert to pixel coordinates
+                         col_min, row_min = ~window_transform * (minx_f, maxy_f)
+                         col_max, row_max = ~window_transform * (maxx_f, miny_f)
+ 
+                         # Ensure coordinates are within bounds
+                         xmin = max(0, min(tile_size_x, int(col_min)))
+                         ymin = max(0, min(tile_size_y, int(row_min)))
+                         xmax = max(0, min(tile_size_x, int(col_max)))
+                         ymax = max(0, min(tile_size_y, int(row_max)))
+ 
+                         # Skip if box is too small
+                         if xmax - xmin < 1 or ymax - ymin < 1:
+                             continue
+ 
+                         obj = ET.SubElement(root, "object")
+                         ET.SubElement(obj, "name").text = str(
+                             feature[class_value_field]
+                         )
+                         ET.SubElement(obj, "difficult").text = "0"
+ 
+                         bbox = ET.SubElement(obj, "bndbox")
+                         ET.SubElement(bbox, "xmin").text = str(xmin)
+                         ET.SubElement(bbox, "ymin").text = str(ymin)
+                         ET.SubElement(bbox, "xmax").text = str(xmax)
+                         ET.SubElement(bbox, "ymax").text = str(ymax)
+ 
+                     # Save XML
+                     try:
+                         tree = ET.ElementTree(root)
+                         xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
+                         tree.write(xml_path)
+                     except Exception as e:
+                         if not quiet:
+                             pbar.write(
+                                 f"ERROR saving XML annotation for tile {chip_index}: {e}"
+                             )
+                         stats["errors"] += 1
+ 
+                 elif metadata_format == "COCO" and len(window_features) > 0:
+                     # Add image info
+                     image_id = chip_index
+                     coco_annotations["images"].append(
+                         {
+                             "id": image_id,
+                             "file_name": image_filename,
+                             "width": tile_size_x,
+                             "height": tile_size_y,
+                             "crs": str(src.crs),
+                             "transform": str(window_transform),
+                         }
+                     )
+ 
+                     # Add annotations for each feature
+                     for _, feature in window_features.iterrows():
+                         feature_bounds = feature.geometry.intersection(window_bounds)
+                         if feature_bounds.is_empty:
+                             continue
+ 
+                         # Get pixel coordinates of bounds
+                         minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
+ 
+                         # Convert to pixel coordinates
+                         col_min, row_min = ~window_transform * (minx_f, maxy_f)
+                         col_max, row_max = ~window_transform * (maxx_f, miny_f)
+ 
+                         # Ensure coordinates are within bounds
+                         xmin = max(0, min(tile_size_x, int(col_min)))
+                         ymin = max(0, min(tile_size_y, int(row_min)))
+                         xmax = max(0, min(tile_size_x, int(col_max)))
+                         ymax = max(0, min(tile_size_y, int(row_max)))
+ 
+                         # Skip if box is too small
+                         if xmax - xmin < 1 or ymax - ymin < 1:
+                             continue
+ 
+                         width = xmax - xmin
+                         height = ymax - ymin
+ 
+                         # Add annotation
+                         ann_id += 1
+                         category_id = class_to_id[feature[class_value_field]]
+ 
+                         coco_annotations["annotations"].append(
+                             {
+                                 "id": ann_id,
+                                 "image_id": image_id,
+                                 "category_id": category_id,
+                                 "bbox": [xmin, ymin, width, height],
+                                 "area": width * height,
+                                 "iscrowd": 0,
+                             }
+                         )
+ 
+                 # Update progress bar
+                 pbar.update(1)
+                 pbar.set_description(
+                     f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
+                 )
+ 
+                 chip_index += 1
+ 
+     # Close progress bar
+     pbar.close()
+ 
+     # Save COCO annotations if applicable
+     if metadata_format == "COCO":
+         try:
+             with open(os.path.join(ann_dir, "instances.json"), "w") as f:
+                 json.dump(coco_annotations, f)
+         except Exception as e:
+             if not quiet:
+                 print(f"ERROR saving COCO annotations: {e}")
+             stats["errors"] += 1
+ 
+     # Close secondary raster if opened
+     if src2:
+         src2.close()
+ 
+     # Print summary
+     if not quiet:
+         print("\n------- Export Summary -------")
+         print(f"Total tiles exported: {stats['total_tiles']}")
+         print(
+             f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
+         )
+         if stats["tiles_with_features"] > 0:
+             print(
+                 f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
+             )
+         if stats["errors"] > 0:
+             print(f"Errors encountered: {stats['errors']}")
+         print(f"Output saved to: {out_folder}")
+ 
+     # Verify georeference in a sample image and label
+     if stats["total_tiles"] > 0:
+         print("\n------- Georeference Verification -------")
+         sample_image = os.path.join(image_dir, f"tile_{start_index:06d}.tif")
+         sample_label = os.path.join(label_dir, f"tile_{start_index:06d}.tif")
+ 
+         if os.path.exists(sample_image):
+             try:
+                 with rasterio.open(sample_image) as img:
+                     print(f"Image CRS: {img.crs}")
+                     print(f"Image transform: {img.transform}")
+                     print(
+                         f"Image has georeference: {img.crs is not None and img.transform is not None}"
+                     )
+                     print(
+                         f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
+                     )
+             except Exception as e:
+                 print(f"Error verifying image georeference: {e}")
+ 
+         if os.path.exists(sample_label):
+             try:
+                 with rasterio.open(sample_label) as lbl:
+                     print(f"Label CRS: {lbl.crs}")
+                     print(f"Label transform: {lbl.transform}")
+                     print(
+                         f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
+                     )
+                     print(
+                         f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
+                     )
+             except Exception as e:
+                 print(f"Error verifying label georeference: {e}")
+ 
+     # Return statistics
+     return stats, out_folder
+ 
+ 
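+ # Illustrative usage of export_training_data (hypothetical paths):
+ #
+ #     stats, folder = export_training_data(
+ #         in_raster="naip.tif",
+ #         out_folder="training_data",
+ #         in_class_data="buildings.geojson",
+ #         tile_size_x=256,
+ #         tile_size_y=256,
+ #         class_value_field="class",
+ #     )
+ 
+ 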
+ def masks_to_vector(
+     mask_path,
+     output_path=None,
+     simplify_tolerance=1.0,
+     mask_threshold=0.5,
+     min_object_area=100,
+     max_object_area=None,
+     nms_iou_threshold=0.5,
+ ):
+     """
+     Convert a building mask GeoTIFF to vector polygons and save as a vector dataset.
+ 
+     Args:
+         mask_path: Path to the building masks GeoTIFF.
+         output_path: Path to save the output vector dataset. If None, the
+             result is returned without being written to disk.
+         simplify_tolerance: Tolerance for polygon simplification.
+         mask_threshold: Threshold for mask binarization.
+         min_object_area: Minimum area in pixels to keep a building.
+         max_object_area: Maximum area in pixels to keep a building.
+         nms_iou_threshold: IoU threshold for non-maximum suppression.
+ 
+     Returns:
+         GeoDataFrame with building footprints.
+     """
+     print("Converting mask to GeoJSON with parameters:")
+     print(f"- Mask threshold: {mask_threshold}")
+     print(f"- Min building area: {min_object_area}")
+     print(f"- Simplify tolerance: {simplify_tolerance}")
+     print(f"- NMS IoU threshold: {nms_iou_threshold}")
+ 
+     # Open the mask raster
+     with rasterio.open(mask_path) as src:
+         # Read the mask data
+         mask_data = src.read(1)
+         transform = src.transform
+         crs = src.crs
+ 
+     # Print mask statistics
+     print(f"Mask dimensions: {mask_data.shape}")
+     print(f"Mask value range: {mask_data.min()} to {mask_data.max()}")
+ 
+     # Prepare for connected component analysis
+     # Binarize the mask based on threshold
+     binary_mask = (mask_data > (mask_threshold * 255)).astype(np.uint8)
+ 
+     # Apply morphological operations for better results (optional)
+     kernel = np.ones((3, 3), np.uint8)
+     binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
+ 
+     # Find connected components
+     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+         binary_mask, connectivity=8
+     )
+ 
+     print(f"Found {num_labels-1} potential buildings")  # Subtract 1 for background
+ 
+     # Create list to store polygons and confidence values
+     all_polygons = []
+     all_confidences = []
+ 
+     # Process each component (skip the first one which is background)
+     for i in tqdm(range(1, num_labels)):
+         # Extract this building
+         area = stats[i, cv2.CC_STAT_AREA]
+ 
+         # Skip if too small
+         if area < min_object_area:
+             continue
+ 
+         # Skip if too large
+         if max_object_area is not None and area > max_object_area:
+             continue
+ 
+         # Create a mask for this building
+         building_mask = (labels == i).astype(np.uint8)
+ 
+         # Find contours
+         contours, _ = cv2.findContours(
+             building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+         )
+ 
+         # Process each contour
+         for contour in contours:
+             # Skip if too few points
+             if contour.shape[0] < 3:
+                 continue
+ 
+             # Simplify contour if it has many points
+             if contour.shape[0] > 50 and simplify_tolerance > 0:
+                 epsilon = simplify_tolerance * cv2.arcLength(contour, True)
+                 contour = cv2.approxPolyDP(contour, epsilon, True)
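+             # approxPolyDP treats epsilon as the maximum deviation from the
+             # original contour; scaling by the perimeter (arcLength) keeps the
+             # simplification roughly proportional to feature size.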
+ 
+             # Convert to list of (x, y) coordinates
+             polygon_points = contour.reshape(-1, 2)
+ 
+             # Convert pixel coordinates to geographic coordinates
+             geo_points = []
+             for x, y in polygon_points:
+                 gx, gy = transform * (x, y)
+                 geo_points.append((gx, gy))
+ 
+             # Create Shapely polygon
+             if len(geo_points) >= 3:
+                 try:
+                     shapely_poly = Polygon(geo_points)
+                     if shapely_poly.is_valid and shapely_poly.area > 0:
+                         all_polygons.append(shapely_poly)
+ 
+                         # Calculate "confidence" as normalized size
+                         # This is a proxy since we don't have model confidence scores
+                         normalized_size = min(1.0, area / 1000)  # Cap at 1.0
+                         all_confidences.append(normalized_size)
+                 except Exception as e:
+                     print(f"Error creating polygon: {e}")
+ 
+     print(f"Created {len(all_polygons)} valid polygons")
+ 
+     # Create GeoDataFrame
+     if not all_polygons:
+         print("No valid polygons found")
+         return None
+ 
+     gdf = gpd.GeoDataFrame(
+         {
+             "geometry": all_polygons,
+             "confidence": all_confidences,
+             "class": 1,  # Building class
+         },
+         crs=crs,
+     )
+ 
+     def filter_overlapping_polygons(gdf, **kwargs):
+         """
+         Filter overlapping polygons using non-maximum suppression.
+ 
+         Args:
+             gdf: GeoDataFrame with polygons.
+             **kwargs: Optional parameters:
+                 nms_iou_threshold: IoU threshold for filtering.
+ 
+         Returns:
+             Filtered GeoDataFrame.
+         """
+         if len(gdf) <= 1:
+             return gdf
+ 
+         # Get the threshold from kwargs or fall back to the enclosing function's value
+         iou_threshold = kwargs.get("nms_iou_threshold", nms_iou_threshold)
+ 
+         # Sort by confidence
+         gdf = gdf.sort_values("confidence", ascending=False)
+ 
+         # Fix any invalid geometries
+         gdf["geometry"] = gdf["geometry"].apply(
+             lambda geom: geom.buffer(0) if not geom.is_valid else geom
+         )
+ 
+         keep_indices = []
+         polygons = gdf.geometry.values
+ 
+         for i in range(len(polygons)):
+             if i in keep_indices:
+                 continue
+ 
+             keep = True
+             for j in keep_indices:
+                 # Skip invalid geometries
+                 if not polygons[i].is_valid or not polygons[j].is_valid:
+                     continue
+ 
+                 # Calculate IoU
+                 try:
+                     intersection = polygons[i].intersection(polygons[j]).area
+                     union = polygons[i].area + polygons[j].area - intersection
+                     iou = intersection / union if union > 0 else 0
+ 
+                     if iou > iou_threshold:
+                         keep = False
+                         break
+                 except Exception:
+                     # Skip on topology exceptions
+                     continue
+ 
+             if keep:
+                 keep_indices.append(i)
+ 
+         return gdf.iloc[keep_indices]
+ 
+     # Apply non-maximum suppression to remove overlapping polygons
+     gdf = filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
+ 
+     print(f"Final building count after filtering: {len(gdf)}")
+ 
+     # Save to file
+     if output_path is not None:
+         gdf.to_file(output_path)
+         print(f"Saved {len(gdf)} building footprints to {output_path}")
+ 
+     return gdf
+ 
+ 
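+ # Illustrative usage of masks_to_vector (hypothetical paths):
+ #
+ #     buildings = masks_to_vector(
+ #         "building_masks.tif",
+ #         output_path="buildings.geojson",
+ #         min_object_area=100,
+ #     )
+ 
+ 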
+ def read_vector(source, layer=None, **kwargs):
+     """Reads vector data from various formats including GeoParquet.
+ 
+     This function dynamically determines the file type based on extension
+     and reads it into a GeoDataFrame. It supports both local files and HTTP/HTTPS URLs.
+ 
+     Args:
+         source: String path to the vector file or URL.
+         layer: String or integer specifying which layer to read from multi-layer
+             files (only applicable for formats like GPKG, GML, etc.).
+             Defaults to None.
+         **kwargs: Additional keyword arguments to pass to the underlying reader.
+ 
+     Returns:
+         geopandas.GeoDataFrame: A GeoDataFrame containing the vector data.
+ 
+     Raises:
+         ValueError: If the file format is not supported or source cannot be accessed.
+ 
+     Examples:
+         Read a local shapefile
+         >>> gdf = read_vector("path/to/data.shp")
+ 
+         Read a GeoParquet file from URL
+         >>> gdf = read_vector("https://example.com/data.parquet")
+ 
+         Read a specific layer from a GeoPackage
+         >>> gdf = read_vector("path/to/data.gpkg", layer="layer_name")
+     """
+     import fiona
+     import urllib.parse
+ 
+     # Determine if source is a URL or local file
+     parsed_url = urllib.parse.urlparse(source)
+     is_url = parsed_url.scheme in ["http", "https"]
+ 
+     # If it's a local file, check if it exists
+     if not is_url and not os.path.exists(source):
+         raise ValueError(f"File does not exist: {source}")
+ 
+     # Get file extension
+     _, ext = os.path.splitext(source)
+     ext = ext.lower()
+ 
+     # Handle GeoParquet files
+     if ext in [".parquet", ".pq", ".geoparquet"]:
+         return gpd.read_parquet(source, **kwargs)
+ 
+     # Handle common single-layer vector formats
+     if ext in [".shp", ".geojson", ".json", ".kml", ".gpx"]:
+         return gpd.read_file(source, **kwargs)
+ 
+     # Handle formats that might have multiple layers
+     if ext in [".gpkg", ".gml"] and layer is not None:
+         return gpd.read_file(source, layer=layer, **kwargs)
+ 
+     # Try to use fiona to identify valid layers for formats that might have them
+     # Only attempt this for local files as fiona.listlayers might not work with URLs
+     if layer is None and ext in [".gpkg", ".gml"] and not is_url:
+         try:
+             layers = fiona.listlayers(source)
+             if layers:
+                 return gpd.read_file(source, layer=layers[0], **kwargs)
+         except Exception:
+             # If listing layers fails, we'll fall through to the generic read attempt
+             pass
+ 
+     # For other formats or when layer listing fails, attempt to read using GeoPandas
+     try:
+         return gpd.read_file(source, **kwargs)
+     except Exception as e:
+         raise ValueError(f"Could not read from source '{source}': {str(e)}")
+ 
+ 
+ def read_raster(source, band=None, masked=True, **kwargs):
+     """Reads raster data from various formats using rioxarray.
+ 
+     This function reads raster data from local files or URLs into a rioxarray
+     data structure with preserved geospatial metadata.
+ 
+     Args:
+         source: String path to the raster file or URL.
+         band: Integer or list of integers specifying which band(s) to read.
+             Defaults to None (all bands).
+         masked: Boolean indicating whether to mask nodata values.
+             Defaults to True.
+         **kwargs: Additional keyword arguments to pass to rioxarray.open_rasterio.
+ 
+     Returns:
+         xarray.DataArray: A DataArray containing the raster data with geospatial
+             metadata preserved.
+ 
+     Raises:
+         ValueError: If the file format is not supported or source cannot be accessed.
+ 
+     Examples:
+         Read a local GeoTIFF
+         >>> raster = read_raster("path/to/data.tif")
+ 
+         Read only band 1 from a remote GeoTIFF
+         >>> raster = read_raster("https://example.com/data.tif", band=1)
+ 
+         Read a raster without masking nodata values
+         >>> raster = read_raster("path/to/data.tif", masked=False)
+     """
+     import urllib.parse
+     from rasterio.errors import RasterioIOError
+ 
+     # Determine if source is a URL or local file
+     parsed_url = urllib.parse.urlparse(source)
+     is_url = parsed_url.scheme in ["http", "https"]
+ 
+     # If it's a local file, check if it exists
+     if not is_url and not os.path.exists(source):
+         raise ValueError(f"Raster file does not exist: {source}")
+ 
+     try:
+         # Open the raster with rioxarray
+         raster = rxr.open_rasterio(source, masked=masked, **kwargs)
+ 
+         # Handle band selection if specified
+         if band is not None:
+             if isinstance(band, (list, tuple)):
+                 # Convert from 1-based indexing to 0-based indexing
+                 band_indices = [b - 1 for b in band]
+                 raster = raster.isel(band=band_indices)
+             else:
+                 # Single band selection (convert from 1-based to 0-based indexing)
+                 raster = raster.isel(band=band - 1)
+ 
+         return raster
+ 
+     except RasterioIOError as e:
+         raise ValueError(f"Could not read raster from source '{source}': {str(e)}")
+     except Exception as e:
+         raise ValueError(f"Error reading raster data: {str(e)}")
+ 
+ 
+ def temp_file_path(ext):
+     """Returns a temporary file path.
+ 
+     Args:
+         ext (str): The file extension.
+ 
+     Returns:
+         str: The temporary file path.
+     """
+     import tempfile
+     import uuid
+ 
+     if not ext.startswith("."):
+         ext = "." + ext
+     file_id = str(uuid.uuid4())
+     file_path = os.path.join(tempfile.gettempdir(), f"{file_id}{ext}")
+ 
+     return file_path
+ 
+ 
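+ # Illustrative: temp_file_path("tif") returns something like
+ # "/tmp/8f14e45f-....tif", depending on the platform temp directory.
+ 
+ 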
+ def region_groups(
4345
+ image: Union[str, "xr.DataArray", np.ndarray],
4346
+ connectivity: int = 1,
4347
+ min_size: int = 10,
4348
+ max_size: Optional[int] = None,
4349
+ threshold: Optional[int] = None,
4350
+ properties: Optional[List[str]] = None,
4351
+ intensity_image: Optional[Union[str, "xr.DataArray", np.ndarray]] = None,
4352
+ out_csv: Optional[str] = None,
4353
+ out_vector: Optional[str] = None,
4354
+ out_image: Optional[str] = None,
4355
+ **kwargs: Any,
4356
+ ) -> Union[Tuple[np.ndarray, "pd.DataFrame"], Tuple["xr.DataArray", "pd.DataFrame"]]:
4357
+ """
4358
+ Segment regions in an image and filter them based on size.
4359
+
4360
+ Args:
4361
+ image (Union[str, xr.DataArray, np.ndarray]): Input image, can be a file
4362
+ path, xarray DataArray, or numpy array.
4363
+ connectivity (int, optional): Connectivity for labeling. Defaults to 1
4364
+ for 4-connectivity. Use 2 for 8-connectivity.
4365
+ min_size (int, optional): Minimum size of regions to keep. Defaults to 10.
4366
+ max_size (Optional[int], optional): Maximum size of regions to keep.
4367
+ Defaults to None.
4368
+ threshold (Optional[int], optional): Threshold for filling holes.
4369
+ Defaults to None, in which case min_size is used.
4370
+ properties (Optional[List[str]], optional): List of properties to measure.
4371
+ See https://scikit-image.org/docs/stable/api/skimage.measure.html#skimage.measure.regionprops
4372
+ Defaults to None.
4373
+ intensity_image (Optional[Union[str, xr.DataArray, np.ndarray]], optional):
4374
+ Intensity image to measure properties. Defaults to None.
4375
+ out_csv (Optional[str], optional): Path to save the properties as a CSV file.
4376
+ Defaults to None.
4377
+ out_vector (Optional[str], optional): Path to save the vector file.
4378
+ Defaults to None.
4379
+ out_image (Optional[str], optional): Path to save the output image.
4380
+ Defaults to None.
+ **kwargs: Additional keyword arguments passed to
+ skimage.measure.regionprops_table.
4381
+
4382
+ Returns:
4383
+ Union[Tuple[np.ndarray, pd.DataFrame], Tuple[xr.DataArray, pd.DataFrame]]: Labeled image and properties DataFrame.
4384
+ """
4385
+ from skimage import measure
4386
+ import scipy.ndimage as ndi
4387
+
4388
+ if isinstance(image, str):
4389
+ ds = rxr.open_rasterio(image)
4390
+ da = ds.sel(band=1)
4391
+ array = da.values.squeeze()
4392
+ elif isinstance(image, xr.DataArray):
4393
+ da = image
4394
+ array = image.values.squeeze()
4395
+ elif isinstance(image, np.ndarray):
4396
+ array = image
4397
+ else:
4398
+ raise ValueError(
4399
+ "The input image must be a file path, xarray DataArray, or numpy array."
4400
+ )
4401
+
4402
+ if threshold is None:
4403
+ threshold = min_size
4404
+
4405
+ # Define a custom function to calculate median intensity
4406
+ def intensity_median(region, intensity_image):
4407
+ # Extract the intensity values for the region
4408
+ return np.median(intensity_image[region])
4409
+
4410
+ # Register the custom median function as a regionprops extra property
4411
+ if intensity_image is not None:
4412
+ extra_props = (intensity_median,)
4413
+ else:
4414
+ extra_props = None
4415
+
4416
+ if properties is None:
4417
+ properties = [
4418
+ "label",
4419
+ "area",
4420
+ "area_bbox",
4421
+ "area_convex",
4422
+ "area_filled",
4423
+ "major_length",
4424
+ "minor_length",
4425
+ "eccentricity",
4426
+ "diameter_areagth",
4427
+ "extent",
4428
+ "orientation",
4429
+ "perimeter",
4430
+ "solidity",
4431
+ ]
4432
+
4433
+ if intensity_image is not None:
4434
+ properties += [
4436
+ "intensity_max",
4437
+ "intensity_mean",
4438
+ "intensity_min",
4439
+ "intensity_std",
4440
+ ]
4441
+
4442
+ if intensity_image is not None:
4443
+ if isinstance(intensity_image, str):
4444
+ ds = rxr.open_rasterio(intensity_image)
4445
+ intensity_da = ds.sel(band=1)
4446
+ intensity_image = intensity_da.values.squeeze()
4447
+ elif isinstance(intensity_image, xr.DataArray):
4448
+ intensity_image = intensity_image.values.squeeze()
4449
+ elif isinstance(intensity_image, np.ndarray):
4450
+ pass
4451
+ else:
4452
+ raise ValueError(
4453
+ "The intensity_image must be a file path, xarray DataArray, or numpy array."
4454
+ )
4455
+
4456
+ label_image = measure.label(array, connectivity=connectivity)
4457
+ props = measure.regionprops_table(
4458
+ label_image, properties=properties, intensity_image=intensity_image, **kwargs
4459
+ )
4460
+
4461
+ df = pd.DataFrame(props)
4462
+
4463
+ # Get the labels of regions with area smaller than the threshold
4464
+ small_regions = df[df["area"] < min_size]["label"].values
4465
+ # Set the corresponding labels in the label_image to zero
4466
+ for region_label in small_regions:
4467
+ label_image[label_image == region_label] = 0
4468
+
4469
+ if max_size is not None:
4470
+ large_regions = df[df["area"] > max_size]["label"].values
4471
+ for region_label in large_regions:
4472
+ label_image[label_image == region_label] = 0
4473
+
4474
+ # Find the background (holes) which are zeros
4475
+ holes = label_image == 0
4476
+
4477
+ # Label the holes (connected components in the background)
4478
+ labeled_holes, _ = ndi.label(holes)
4479
+
4480
+ # Measure properties of the labeled holes, including area and bounding box
4481
+ hole_props = measure.regionprops(labeled_holes)
4482
+
4483
+ # Loop through each hole and fill it if it is smaller than the threshold
4484
+ for prop in hole_props:
4485
+ if prop.area < threshold:
4486
+ # Get the coordinates of the small hole
4487
+ coords = prop.coords
4488
+
4489
+ # Find the surrounding region's ID (non-zero value near the hole)
4490
+ surrounding_region_values = []
4491
+ for coord in coords:
4492
+ x, y = coord
4493
+ # Get a 3x3 neighborhood around the hole pixel
4494
+ neighbors = label_image[max(0, x - 1) : x + 2, max(0, y - 1) : y + 2]
4495
+ # Exclude the hole pixels (zeros) and get region values
4496
+ region_values = neighbors[neighbors != 0]
4497
+ if region_values.size > 0:
4498
+ surrounding_region_values.append(
4499
+ region_values[0]
4500
+ ) # Take the first non-zero value
4501
+
4502
+ if surrounding_region_values:
4503
+ # Fill the hole with the mode (most frequent) of the surrounding region values
4504
+ fill_value = max(
4505
+ set(surrounding_region_values), key=surrounding_region_values.count
4506
+ )
4507
+ label_image[coords[:, 0], coords[:, 1]] = fill_value
4508
+
4509
+ label_image, num_labels = measure.label(
4510
+ label_image, connectivity=connectivity, return_num=True
4511
+ )
4512
+ props = measure.regionprops_table(
4513
+ label_image,
4514
+ properties=properties,
4515
+ intensity_image=intensity_image,
4516
+ extra_properties=extra_props,
4517
+ **kwargs,
4518
+ )
4519
+
4520
+ df = pd.DataFrame(props)
4521
+ df["elongation"] = df["major_length"] / df["minor_length"]
4522
+
4523
+ dtype = "uint8"
4524
+ if 255 < num_labels <= 65535:
4525
+ dtype = "uint16"
4526
+ elif num_labels > 65535:
4527
+ dtype = "uint32"
4528
+
4529
+ if out_csv is not None:
4530
+ df.to_csv(out_csv, index=False)
4531
+
4532
+ if isinstance(image, np.ndarray):
4533
+ return label_image, df
4534
+ else:
4535
+ da.values = label_image
4536
+ if out_image is not None:
4537
+ da.rio.to_raster(out_image, dtype=dtype)
4538
+ if out_vector is not None:
4539
+ # raster_to_vector reads from a raster on disk, so fall back to a
+ # temporary file when no out_image was requested.
+ src_raster = out_image
+ if src_raster is None:
+ src_raster = temp_file_path(".tif")
+ da.rio.to_raster(src_raster, dtype=dtype)
4540
+ tmp_vector = temp_file_path(".gpkg")
+ raster_to_vector(src_raster, tmp_vector)
4541
+ gdf = gpd.read_file(tmp_vector)
4542
+ gdf["label"] = gdf["value"].astype(int)
4543
+ gdf.drop(columns=["value"], inplace=True)
4544
+ gdf2 = pd.merge(gdf, df, on="label", how="left")
4545
+ gdf2.sort_values("label", inplace=True)
4546
+ gdf2.to_file(out_vector)
4547
+ df = gdf2
4548
+ return da, df
4549
+
4550
+
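A hedged end-to-end sketch of region_groups() on a synthetic in-memory mask (all inputs illustrative), showing the min_size filter and the returned properties table:

    import numpy as np

    mask = np.zeros((64, 64), dtype="uint8")
    mask[5:30, 5:30] = 1     # 625-pixel region, kept
    mask[40:42, 40:42] = 1   # 4-pixel region, dropped when min_size=10

    labels, props = region_groups(mask, connectivity=1, min_size=10)
    print(props[["label", "area", "eccentricity"]].head())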
4551
+ def add_geometric_properties(data, properties=None, area_unit="m2", length_unit="m"):
4552
+ """Calculates geometric properties and adds them to the GeoDataFrame.
4553
+
4554
+ This function calculates various geometric properties of features in a
4555
+ GeoDataFrame and adds them as new columns without modifying existing attributes.
4556
+
4557
+ Args:
4558
+ data: GeoDataFrame containing vector features.
4559
+ properties: List of geometric properties to calculate. Options include:
4560
+ 'area', 'length', 'perimeter', 'centroid_x', 'centroid_y', 'bounds',
4561
+ 'convex_hull_area', 'orientation', 'complexity', 'area_bbox',
4562
+ 'area_convex', 'area_filled', 'major_length', 'minor_length',
4563
+ 'eccentricity', 'diameter_area', 'extent', 'solidity',
4564
+ 'elongation'.
4565
+ Defaults to all of the above except 'centroid_x', 'centroid_y',
+ and 'bounds' if None.
4566
+ area_unit: String specifying the unit for area calculation ('m2', 'km2',
4567
+ 'ha'). Defaults to 'm2'.
4568
+ length_unit: String specifying the unit for length calculation ('m', 'km').
4569
+ Defaults to 'm'.
4570
+
4571
+ Returns:
4572
+ geopandas.GeoDataFrame: A copy of the input GeoDataFrame with added
4573
+ geometric property columns.
4574
+ """
4575
+ from shapely.ops import unary_union
4576
+
4577
+ if isinstance(data, str):
4578
+ data = read_vector(data)
4579
+
4580
+ # Make a copy to avoid modifying the original
4581
+ result = data.copy()
4582
+
4583
+ # Default properties to calculate
4584
+ if properties is None:
4585
+ properties = [
4586
+ "area",
4587
+ "length",
4588
+ "perimeter",
4589
+ "convex_hull_area",
4590
+ "orientation",
4591
+ "complexity",
4592
+ "area_bbox",
4593
+ "area_convex",
4594
+ "area_filled",
4595
+ "major_length",
4596
+ "minor_length",
4597
+ "eccentricity",
4598
+ "diameter_area",
4599
+ "extent",
4600
+ "solidity",
4601
+ "elongation",
4602
+ ]
4603
+
4604
+ # Make sure we're working with a GeoDataFrame with a valid CRS
4605
+
4606
+ if not isinstance(result, gpd.GeoDataFrame):
4607
+ raise ValueError("Input must be a GeoDataFrame")
4608
+
4609
+ if result.crs is None:
4610
+ raise ValueError(
4611
+ "GeoDataFrame must have a defined coordinate reference system (CRS)"
4612
+ )
4613
+
4614
+ # Ensure we're working with a projected CRS for accurate measurements
4615
+ if result.crs.is_geographic:
4616
+ # Reproject to a suitable projected CRS for accurate measurements
4617
+ result = result.to_crs(result.estimate_utm_crs())
4618
+
4619
+ # Basic area calculation with unit conversion
4620
+ if "area" in properties:
4621
+ # Calculate area (only for polygons)
4622
+ result["area"] = result.geometry.apply(
4623
+ lambda geom: geom.area if isinstance(geom, (Polygon, MultiPolygon)) else 0
4624
+ )
4625
+
4626
+ # Convert to requested units
4627
+ if area_unit == "km2":
4628
+ result["area"] = result["area"] / 1_000_000 # m² to km²
4629
+ result.rename(columns={"area": "area_km2"}, inplace=True)
4630
+ elif area_unit == "ha":
4631
+ result["area"] = result["area"] / 10_000 # m² to hectares
4632
+ result.rename(columns={"area": "area_ha"}, inplace=True)
4633
+ else: # Default is m²
4634
+ result.rename(columns={"area": "area_m2"}, inplace=True)
4635
+
4636
+ # Length calculation with unit conversion
4637
+ if "length" in properties:
4638
+ # Calculate length (works for lines and polygon boundaries)
4639
+ result["length"] = result.geometry.length
4640
+
4641
+ # Convert to requested units
4642
+ if length_unit == "km":
4643
+ result["length"] = result["length"] / 1_000 # m to km
4644
+ result.rename(columns={"length": "length_km"}, inplace=True)
4645
+ else: # Default is m
4646
+ result.rename(columns={"length": "length_m"}, inplace=True)
4647
+
4648
+ # Perimeter calculation (for polygons)
4649
+ if "perimeter" in properties:
4650
+ result["perimeter"] = result.geometry.apply(
4651
+ lambda geom: (
4652
+ geom.boundary.length if isinstance(geom, (Polygon, MultiPolygon)) else 0
4653
+ )
4654
+ )
4655
+
4656
+ # Convert to requested units
4657
+ if length_unit == "km":
4658
+ result["perimeter"] = result["perimeter"] / 1_000 # m to km
4659
+ result.rename(columns={"perimeter": "perimeter_km"}, inplace=True)
4660
+ else: # Default is m
4661
+ result.rename(columns={"perimeter": "perimeter_m"}, inplace=True)
4662
+
4663
+ # Centroid coordinates
4664
+ if "centroid_x" in properties or "centroid_y" in properties:
4665
+ centroids = result.geometry.centroid
4666
+
4667
+ if "centroid_x" in properties:
4668
+ result["centroid_x"] = centroids.x
4669
+
4670
+ if "centroid_y" in properties:
4671
+ result["centroid_y"] = centroids.y
4672
+
4673
+ # Bounding box properties
4674
+ if "bounds" in properties:
4675
+ bounds = result.geometry.bounds
4676
+ result["minx"] = bounds.minx
4677
+ result["miny"] = bounds.miny
4678
+ result["maxx"] = bounds.maxx
4679
+ result["maxy"] = bounds.maxy
4680
+
4681
+ # Area of bounding box
4682
+ if "area_bbox" in properties:
4683
+ bounds = result.geometry.bounds
4684
+ result["area_bbox"] = (bounds.maxx - bounds.minx) * (bounds.maxy - bounds.miny)
4685
+
4686
+ # Convert to requested units
4687
+ if area_unit == "km2":
4688
+ result["area_bbox"] = result["area_bbox"] / 1_000_000
4689
+ result.rename(columns={"area_bbox": "area_bbox_km2"}, inplace=True)
4690
+ elif area_unit == "ha":
4691
+ result["area_bbox"] = result["area_bbox"] / 10_000
4692
+ result.rename(columns={"area_bbox": "area_bbox_ha"}, inplace=True)
4693
+ else: # Default is m²
4694
+ result.rename(columns={"area_bbox": "area_bbox_m2"}, inplace=True)
4695
+
4696
+ # Area of convex hull
4697
+ if "area_convex" in properties or "convex_hull_area" in properties:
4698
+ result["area_convex"] = result.geometry.convex_hull.area
4699
+
4700
+ # Convert to requested units
4701
+ if area_unit == "km2":
4702
+ result["area_convex"] = result["area_convex"] / 1_000_000
4703
+ result.rename(columns={"area_convex": "area_convex_km2"}, inplace=True)
4704
+ elif area_unit == "ha":
4705
+ result["area_convex"] = result["area_convex"] / 10_000
4706
+ result.rename(columns={"area_convex": "area_convex_ha"}, inplace=True)
4707
+ else: # Default is m²
4708
+ result.rename(columns={"area_convex": "area_convex_m2"}, inplace=True)
4709
+
4710
+ # For backward compatibility
4711
+ if "convex_hull_area" in properties and "area_convex" not in properties:
4712
+ result["convex_hull_area"] = result["area_convex"]
4713
+ if area_unit == "km2":
4714
+ result.rename(
4715
+ columns={"convex_hull_area": "convex_hull_area_km2"}, inplace=True
4716
+ )
4717
+ elif area_unit == "ha":
4718
+ result.rename(
4719
+ columns={"convex_hull_area": "convex_hull_area_ha"}, inplace=True
4720
+ )
4721
+ else:
4722
+ result.rename(
4723
+ columns={"convex_hull_area": "convex_hull_area_m2"}, inplace=True
4724
+ )
4725
+
4726
+ # Area of filled geometry (no holes)
4727
+ if "area_filled" in properties:
4728
+
4729
+ def get_filled_area(geom):
4730
+ if not isinstance(geom, (Polygon, MultiPolygon)):
4731
+ return 0
4732
+
4733
+ if isinstance(geom, MultiPolygon):
4734
+ # For MultiPolygon, fill all constituent polygons
4735
+ filled_polys = [Polygon(p.exterior) for p in geom.geoms]
4736
+ return unary_union(filled_polys).area
4737
+ else:
4738
+ # For single Polygon, create a new one with just the exterior ring
4739
+ return Polygon(geom.exterior).area
4740
+
4741
+ result["area_filled"] = result.geometry.apply(get_filled_area)
4742
+
4743
+ # Convert to requested units
4744
+ if area_unit == "km2":
4745
+ result["area_filled"] = result["area_filled"] / 1_000_000
4746
+ result.rename(columns={"area_filled": "area_filled_km2"}, inplace=True)
4747
+ elif area_unit == "ha":
4748
+ result["area_filled"] = result["area_filled"] / 10_000
4749
+ result.rename(columns={"area_filled": "area_filled_ha"}, inplace=True)
4750
+ else: # Default is m²
4751
+ result.rename(columns={"area_filled": "area_filled_m2"}, inplace=True)
4752
+
4753
+ # Axes lengths, eccentricity, orientation, and elongation
4754
+ if any(
4755
+ p in properties
4756
+ for p in [
4757
+ "major_length",
4758
+ "minor_length",
4759
+ "eccentricity",
4760
+ "orientation",
4761
+ "elongation",
4762
+ ]
4763
+ ):
4764
+
4765
+ def get_axes_properties(geom):
4766
+ # Skip non-polygons
4767
+ if not isinstance(geom, (Polygon, MultiPolygon)):
4768
+ return None, None, None, None, None
4769
+
4770
+ # Handle multipolygons by using the largest polygon
4771
+ if isinstance(geom, MultiPolygon):
4772
+ # Get the polygon with the largest area
4773
+ geom = sorted(list(geom.geoms), key=lambda p: p.area, reverse=True)[0]
4774
+
4775
+ try:
4776
+ # Get the minimum rotated rectangle
4777
+ rect = geom.minimum_rotated_rectangle
4778
+
4779
+ # Extract coordinates
4780
+ coords = list(rect.exterior.coords)[
4781
+ :-1
4782
+ ] # Remove the duplicated last point
4783
+
4784
+ if len(coords) < 4:
4785
+ return None, None, None, None, None
4786
+
4787
+ # Calculate lengths of all four sides
4788
+ sides = []
4789
+ for i in range(len(coords)):
4790
+ p1 = coords[i]
4791
+ p2 = coords[(i + 1) % len(coords)]
4792
+ dx = p2[0] - p1[0]
4793
+ dy = p2[1] - p1[1]
4794
+ length = np.sqrt(dx**2 + dy**2)
4795
+ angle = np.degrees(np.arctan2(dy, dx)) % 180
4796
+ sides.append((length, angle, p1, p2))
4797
+
4798
+ # Group sides by length (allowing for small differences due to floating point precision)
4799
+ # This ensures we correctly identify the rectangle's dimensions
4800
+ sides_grouped = {}
4801
+ tolerance = 1e-6 # Tolerance for length comparison
4802
+
4803
+ for s in sides:
4804
+ length, angle = s[0], s[1]
4805
+ matched = False
4806
+
4807
+ for key in sides_grouped:
4808
+ if abs(length - key) < tolerance:
4809
+ sides_grouped[key].append(s)
4810
+ matched = True
4811
+ break
4812
+
4813
+ if not matched:
4814
+ sides_grouped[length] = [s]
4815
+
4816
+ # Get unique lengths (should be 2 for a rectangle, parallel sides have equal length)
4817
+ unique_lengths = sorted(sides_grouped.keys(), reverse=True)
4818
+
4819
+ if len(unique_lengths) != 2:
4820
+ # If we don't get exactly 2 unique lengths, something is wrong with the rectangle
4821
+ # Fall back to simpler method using bounds
4822
+ bounds = rect.bounds
4823
+ width = bounds[2] - bounds[0]
4824
+ height = bounds[3] - bounds[1]
4825
+ major_length = max(width, height)
4826
+ minor_length = min(width, height)
4827
+ orientation = 0 if width > height else 90
4828
+ else:
4829
+ major_length = unique_lengths[0]
4830
+ minor_length = unique_lengths[1]
4831
+ # Get orientation from the major axis
4832
+ orientation = sides_grouped[major_length][0][1]
4833
+
4834
+ # Calculate eccentricity
4835
+ if major_length > 0:
4836
+ # Eccentricity for an ellipse: e = sqrt(1 - (b²/a²))
4837
+ # where a is the semi-major axis and b is the semi-minor axis
4838
+ eccentricity = np.sqrt(
4839
+ 1 - ((minor_length / 2) ** 2 / (major_length / 2) ** 2)
4840
+ )
4841
+ else:
4842
+ eccentricity = 0
4843
+
4844
+ # Calculate elongation (ratio of major to minor axis length)
4845
+ elongation = major_length / minor_length if minor_length > 0 else 1
4846
+
4847
+ return major_length, minor_length, eccentricity, orientation, elongation
4848
+
4849
+ except Exception:
4850
+ # Degenerate geometry; report no measurement
4851
+ return None, None, None, None, None
4853
+
4854
+ # Apply the function and split the results
4855
+ axes_data = result.geometry.apply(get_axes_properties)
4856
+
4857
+ if "major_length" in properties:
4858
+ result["major_length"] = axes_data.apply(lambda x: x[0] if x else None)
4859
+ # Convert to requested units
4860
+ if length_unit == "km":
4861
+ result["major_length"] = result["major_length"] / 1_000
4862
+ result.rename(columns={"major_length": "major_length_km"}, inplace=True)
4863
+ else:
4864
+ result.rename(columns={"major_length": "major_length_m"}, inplace=True)
4865
+
4866
+ if "minor_length" in properties:
4867
+ result["minor_length"] = axes_data.apply(lambda x: x[1] if x else None)
4868
+ # Convert to requested units
4869
+ if length_unit == "km":
4870
+ result["minor_length"] = result["minor_length"] / 1_000
4871
+ result.rename(columns={"minor_length": "minor_length_km"}, inplace=True)
4872
+ else:
4873
+ result.rename(columns={"minor_length": "minor_length_m"}, inplace=True)
4874
+
4875
+ if "eccentricity" in properties:
4876
+ result["eccentricity"] = axes_data.apply(lambda x: x[2] if x else None)
4877
+
4878
+ if "orientation" in properties:
4879
+ result["orientation"] = axes_data.apply(lambda x: x[3] if x else None)
4880
+
4881
+ if "elongation" in properties:
4882
+ result["elongation"] = axes_data.apply(lambda x: x[4] if x else None)
4883
+
4884
+ # Equivalent diameter based on area
4885
+ if "diameter_areagth" in properties:
4886
+
4887
+ def get_equivalent_diameter(geom):
4888
+ if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
4889
+ return None
4890
+ # Diameter of a circle with the same area: d = 2 * sqrt(A / π)
4891
+ return 2 * np.sqrt(geom.area / np.pi)
4892
+
4893
+ result["diameter_areagth"] = result.geometry.apply(get_equivalent_diameter)
4894
+
4895
+ # Convert to requested units
4896
+ if length_unit == "km":
4897
+ result["diameter_areagth"] = result["diameter_areagth"] / 1_000
4898
+ result.rename(
4899
+ columns={"diameter_areagth": "equivalent_diameter_area_km"},
4900
+ inplace=True,
4901
+ )
4902
+ else:
4903
+ result.rename(
4904
+ columns={"diameter_areagth": "equivalent_diameter_area_m"},
4905
+ inplace=True,
4906
+ )
4907
+
4908
+ # Extent (ratio of shape area to bounding box area)
4909
+ if "extent" in properties:
4910
+
4911
+ def get_extent(geom):
4912
+ if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
4913
+ return None
4914
+
4915
+ bounds = geom.bounds
4916
+ bbox_area = (bounds[2] - bounds[0]) * (bounds[3] - bounds[1])
4917
+
4918
+ if bbox_area > 0:
4919
+ return geom.area / bbox_area
4920
+ return None
4921
+
4922
+ result["extent"] = result.geometry.apply(get_extent)
4923
+
4924
+ # Solidity (ratio of shape area to convex hull area)
4925
+ if "solidity" in properties:
4926
+
4927
+ def get_solidity(geom):
4928
+ if not isinstance(geom, (Polygon, MultiPolygon)) or geom.area <= 0:
4929
+ return None
4930
+
4931
+ convex_hull_area = geom.convex_hull.area
4932
+
4933
+ if convex_hull_area > 0:
4934
+ return geom.area / convex_hull_area
4935
+ return None
4936
+
4937
+ result["solidity"] = result.geometry.apply(get_solidity)
4938
+
4939
+ # Complexity (ratio of perimeter to area)
4940
+ if "complexity" in properties:
4941
+
4942
+ def calc_complexity(geom):
4943
+ if isinstance(geom, (Polygon, MultiPolygon)) and geom.area > 0:
4944
+ # Shape index: P / (2 * sqrt(π * A))
4945
+ # Normalized to 1 for a circle, higher for more complex shapes
4946
+ return geom.boundary.length / (2 * np.sqrt(np.pi * geom.area))
4947
+ return None
4948
+
4949
+ result["complexity"] = result.geometry.apply(calc_complexity)
4950
+
4951
+ return result
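A short sketch of add_geometric_properties(), assuming a synthetic square in a projected CRS (the geometry and EPSG code are illustrative):

    import geopandas as gpd
    from shapely.geometry import box

    # A 100 m x 100 m square in a metric CRS, so no UTM reprojection is needed
    gdf = gpd.GeoDataFrame(geometry=[box(0, 0, 100, 100)], crs="EPSG:3857")
    out = add_geometric_properties(gdf, properties=["area", "perimeter", "solidity"])
    print(out[["area_m2", "perimeter_m", "solidity"]])  # 10000.0, 400.0, 1.0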