geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/preprocess.py DELETED
@@ -1,3021 +0,0 @@
1
- import json
2
- import math
3
- import os
4
- from PIL import Image
5
- from pathlib import Path
6
- import requests
7
- import warnings
8
- import xml.etree.ElementTree as ET
9
- import numpy as np
10
- import rasterio
11
- import geopandas as gpd
12
- import pandas as pd
13
- from rasterio.windows import Window
14
- from rasterio import features
15
- from rasterio.plot import show
16
- from shapely.geometry import box, shape, mapping, Polygon
17
- import matplotlib.pyplot as plt
18
- from tqdm import tqdm
19
- from torchvision.transforms import RandomRotation
20
- from shapely.affinity import rotate
21
- import torch
22
- import cv2
23
-
24
# Fail fast at import time when torchgeo is unavailable or broken.
try:
    import torchgeo
except ImportError as e:
    # Chain the original ImportError so the user can distinguish a missing
    # package from a broken/old installation (the original bound `e` but
    # never used it, hiding the underlying cause).
    raise ImportError(
        "Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
    ) from e
30
-
31
-
32
def download_file(url, output_path=None, overwrite=False):
    """Download a file from a URL, showing a tqdm progress bar.

    Args:
        url (str): The URL of the file to download.
        output_path (str, optional): Destination path. Defaults to the
            basename of the URL when omitted.
        overwrite (bool, optional): Re-download even if the file exists.

    Returns:
        str: The path to the downloaded file.
    """
    # Derive the destination from the URL when none was supplied.
    if output_path is None:
        output_path = os.path.basename(url)

    # Short-circuit when the target already exists and overwriting is off.
    if os.path.exists(output_path) and not overwrite:
        print(f"File already exists: {output_path}")
        return output_path

    # Stream the response so large files are never held fully in memory.
    response = requests.get(url, stream=True, timeout=50)
    response.raise_for_status()  # surface HTTP errors immediately

    # Content length may be missing; tqdm then shows an indeterminate bar.
    total_size = int(response.headers.get("content-length", 0))

    progress = tqdm(
        desc=os.path.basename(output_path),
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    )
    with open(output_path, "wb") as sink, progress as bar:
        for chunk in response.iter_content(chunk_size=1024):
            # Keep-alive chunks arrive empty; skip them.
            if not chunk:
                continue
            sink.write(chunk)
            bar.update(len(chunk))

    return output_path
80
-
81
-
82
def get_raster_info(raster_path):
    """Collect basic metadata and per-band statistics for a raster.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: Dictionary containing the basic information about the raster
    """
    with rasterio.open(raster_path) as src:
        # Core metadata read straight off the dataset handle.
        info = {
            "driver": src.driver,
            "width": src.width,
            "height": src.height,
            "count": src.count,
            "dtype": src.dtypes[0],
            "crs": src.crs.to_string() if src.crs else "No CRS defined",
            "transform": src.transform,
            "bounds": src.bounds,
            # Pixel size: x scale and negated y scale from the affine transform.
            "resolution": (src.transform[0], -src.transform[4]),
            "nodata": src.nodata,
        }

        # Per-band statistics on masked reads, so nodata pixels are excluded.
        band_stats = []
        for band_number in range(1, src.count + 1):
            pixels = src.read(band_number, masked=True)
            band_stats.append(
                {
                    "band": band_number,
                    "min": float(pixels.min()),
                    "max": float(pixels.max()),
                    "mean": float(pixels.mean()),
                    "std": float(pixels.std()),
                }
            )
        info["band_stats"] = band_stats

    return info
123
-
124
-
125
def get_raster_stats(raster_path, divide_by=1.0):
    """Compute per-band min/max/mean/std for a raster.

    Each statistic is divided by ``divide_by``, which is useful for
    rescaling (e.g. 8-bit data to 0-1). Statistics are computed on masked
    reads, so nodata pixels are excluded.

    Args:
        raster_path (str): Path to the raster file
        divide_by (float, optional): Value to divide pixel statistics by.
            Defaults to 1.0 (no rescaling).

    Returns:
        dict: Lists of statistics keyed by 'min', 'max', 'mean', 'std',
            one entry per band.
    """
    stats = {"min": [], "max": [], "mean": [], "std": []}

    with rasterio.open(raster_path) as src:
        for band_number in range(1, src.count + 1):
            pixels = src.read(band_number, masked=True)
            # Append the rescaled statistic for this band under each key.
            for key, value in (
                ("min", pixels.min()),
                ("max", pixels.max()),
                ("mean", pixels.mean()),
                ("std", pixels.std()),
            ):
                stats[key].append(float(value) / divide_by)

    return stats
160
-
161
-
162
def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a raster dataset and optionally show a preview.

    Args:
        raster_path (str): Path to the raster file
        show_preview (bool, optional): Whether to display a visual preview of the raster.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: Dictionary containing raster information if successful, None otherwise
    """
    try:
        info = get_raster_info(raster_path)

        # Print basic information
        print(f"===== RASTER INFORMATION: {raster_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Dimensions: {info['width']} x {info['height']} pixels")
        print(f"Number of bands: {info['count']}")
        print(f"Data type: {info['dtype']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Georeferenced Bounds: {info['bounds']}")
        print(f"Pixel Resolution: {info['resolution'][0]}, {info['resolution'][1]}")
        print(f"NoData Value: {info['nodata']}")

        # Print band statistics
        print("\n----- Band Statistics -----")
        for band_stat in info["band_stats"]:
            print(f"Band {band_stat['band']}:")
            print(f"  Min: {band_stat['min']:.2f}")
            print(f"  Max: {band_stat['max']:.2f}")
            print(f"  Mean: {band_stat['mean']:.2f}")
            print(f"  Std Dev: {band_stat['std']:.2f}")

        # Show a preview if requested
        if show_preview:
            with rasterio.open(raster_path) as src:
                if src.count >= 3:
                    # Multi-band: show an RGB composite of the first 3 bands.
                    rgb = np.dstack([src.read(i) for i in range(1, 4)])
                    plt.figure(figsize=figsize)
                    plt.imshow(rgb)
                    plt.title(f"RGB Preview: {raster_path}")
                else:
                    # Single-band: show band 1 with a colorbar.
                    plt.figure(figsize=figsize)
                    show(
                        src.read(1),
                        cmap="viridis",
                        title=f"Band 1 Preview: {raster_path}",
                    )
                    plt.colorbar(label="Pixel Value")
                plt.show()

        # Fix: the docstring promises the info dict on success, but the
        # original returned None unconditionally.
        return info

    except Exception as e:
        print(f"Error reading raster: {str(e)}")
        return None
220
-
221
-
222
def get_raster_info_gdal(raster_path):
    """Get basic information about a raster dataset using GDAL.

    Args:
        raster_path (str): Path to the raster file

    Returns:
        dict: Dictionary containing the basic information about the raster,
            or None if the file cannot be opened
    """
    from osgeo import gdal

    dataset = gdal.Open(raster_path)
    if dataset is None:
        print(f"Error: Could not open {raster_path}")
        return None

    # Dataset-level metadata.
    info = {
        "driver": dataset.GetDriver().ShortName,
        "width": dataset.RasterXSize,
        "height": dataset.RasterYSize,
        "count": dataset.RasterCount,
        "projection": dataset.GetProjection(),
        "geotransform": dataset.GetGeoTransform(),
    }

    # Resolution and origin derive from the geotransform when present.
    geotransform = dataset.GetGeoTransform()
    if geotransform:
        info["resolution"] = (abs(geotransform[1]), abs(geotransform[5]))
        info["origin"] = (geotransform[0], geotransform[3])

    # Per-band statistics (approx_ok=True, force=True matches the original).
    bands_info = []
    for band_number in range(1, dataset.RasterCount + 1):
        band = dataset.GetRasterBand(band_number)
        minimum, maximum, mean, std = band.GetStatistics(True, True)
        bands_info.append(
            {
                "band": band_number,
                "datatype": gdal.GetDataTypeName(band.DataType),
                "min": minimum,
                "max": maximum,
                "mean": mean,
                "std": std,
                "nodata": band.GetNoDataValue(),
            }
        )
    info["bands"] = bands_info

    # Dereference to close the GDAL dataset.
    dataset = None

    return info
279
-
280
-
281
def get_vector_info(vector_path):
    """Collect basic metadata about a vector dataset using GeoPandas.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: Dictionary containing the basic information about the vector dataset
    """
    # Parquet files need the dedicated reader; everything else goes
    # through the generic fiona-backed reader.
    if vector_path.endswith(".parquet"):
        gdf = gpd.read_parquet(vector_path)
    else:
        gdf = gpd.read_file(vector_path)

    info = {
        "file_path": vector_path,
        # Format inferred from the file extension.
        "driver": os.path.splitext(vector_path)[1][1:].upper(),
        "feature_count": len(gdf),
        "crs": str(gdf.crs),
        "geometry_type": str(gdf.geom_type.value_counts().to_dict()),
        # Exclude the geometry column from the attribute count.
        "attribute_count": len(gdf.columns) - 1,
        "attribute_names": list(gdf.columns[gdf.columns != "geometry"]),
        "bounds": gdf.total_bounds.tolist(),
    }

    # Summary statistics for each numeric attribute column.
    attribute_stats = {}
    for column in gdf.select_dtypes(include=["number"]).columns:
        if column == "geometry":
            continue
        series = gdf[column]
        attribute_stats[column] = {
            "min": series.min(),
            "max": series.max(),
            "mean": series.mean(),
            "std": series.std(),
            "null_count": series.isna().sum(),
        }
    info["attribute_stats"] = attribute_stats

    return info
325
-
326
-
327
def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
    """Print formatted information about a vector dataset and optionally show a preview.

    Args:
        vector_path (str): Path to the vector file
        show_preview (bool, optional): Whether to display a visual preview of the vector data.
            Defaults to True.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        dict: Dictionary containing vector information if successful, None otherwise
    """
    try:
        info = get_vector_info(vector_path)

        # Print basic information
        print(f"===== VECTOR INFORMATION: {vector_path} =====")
        print(f"Driver: {info['driver']}")
        print(f"Feature count: {info['feature_count']}")
        print(f"Geometry types: {info['geometry_type']}")
        print(f"Coordinate Reference System: {info['crs']}")
        print(f"Bounds: {info['bounds']}")
        print(f"Number of attributes: {info['attribute_count']}")
        print(f"Attribute names: {', '.join(info['attribute_names'])}")

        # Print attribute statistics (floats formatted to 4 decimals)
        if info["attribute_stats"]:
            print("\n----- Attribute Statistics -----")
            for attr, stats in info["attribute_stats"].items():
                print(f"Attribute: {attr}")
                for stat_name, stat_value in stats.items():
                    print(
                        f"  {stat_name}: {stat_value:.4f}"
                        if isinstance(stat_value, float)
                        else f"  {stat_name}: {stat_value}"
                    )

        # Show a preview if requested
        if show_preview:
            gdf = (
                gpd.read_parquet(vector_path)
                if vector_path.endswith(".parquet")
                else gpd.read_file(vector_path)
            )
            fig, ax = plt.subplots(figsize=figsize)
            gdf.plot(ax=ax, cmap="viridis")
            ax.set_title(f"Preview: {vector_path}")
            plt.tight_layout()
            plt.show()

        # Fix: the docstring promises the info dict on success, but the
        # original returned None unconditionally.
        return info

    except Exception as e:
        print(f"Error reading vector data: {str(e)}")
        return None
384
-
385
-
386
- # Alternative implementation using OGR directly
387
def get_vector_info_ogr(vector_path):
    """Get basic information about a vector dataset using OGR.

    Args:
        vector_path (str): Path to the vector file

    Returns:
        dict: Dictionary containing the basic information about the vector dataset,
            or None if the file cannot be opened
    """
    from osgeo import ogr

    # Ensure all vector drivers are registered before opening.
    ogr.RegisterAll()

    datasource = ogr.Open(vector_path)
    if datasource is None:
        print(f"Error: Could not open {vector_path}")
        return None

    # Top-level datasource metadata.
    info = {
        "file_path": vector_path,
        "driver": datasource.GetDriver().GetName(),
        "layer_count": datasource.GetLayerCount(),
        "layers": [],
    }

    # Describe each layer: geometry, CRS, extent, and field schema.
    for layer_index in range(datasource.GetLayerCount()):
        layer = datasource.GetLayer(layer_index)
        srs = layer.GetSpatialRef()
        layer_info = {
            "name": layer.GetName(),
            "feature_count": layer.GetFeatureCount(),
            "geometry_type": ogr.GeometryTypeToName(layer.GetGeomType()),
            "spatial_ref": srs.ExportToWkt() if srs else "None",
            "extent": layer.GetExtent(),
            "fields": [],
        }

        schema = layer.GetLayerDefn()
        for field_index in range(schema.GetFieldCount()):
            field = schema.GetFieldDefn(field_index)
            layer_info["fields"].append(
                {
                    "name": field.GetName(),
                    "type": field.GetTypeName(),
                    "width": field.GetWidth(),
                    "precision": field.GetPrecision(),
                }
            )

        info["layers"].append(layer_info)

    # Dereference to close the datasource.
    datasource = None

    return info
448
-
449
-
450
def analyze_vector_attributes(vector_path, attribute_name):
    """Analyze one attribute of a vector dataset and plot its distribution.

    Numeric attributes get summary statistics plus a histogram; other
    attributes get value counts plus a bar chart of the top categories.

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to analyze

    Returns:
        dict: Dictionary containing analysis results for the attribute,
            or None on error / missing attribute.
    """
    try:
        gdf = gpd.read_file(vector_path)

        # Bail out early when the requested column is absent.
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return None

        attr = gdf[attribute_name]

        if pd.api.types.is_numeric_dtype(attr):
            analysis = {
                "attribute": attribute_name,
                "type": "numeric",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "min": attr.min(),
                "max": attr.max(),
                "mean": attr.mean(),
                "median": attr.median(),
                "std": attr.std(),
                "unique_values": attr.nunique(),
            }

            # Histogram of the numeric values (NaNs dropped).
            plt.figure(figsize=(10, 6))
            plt.hist(attr.dropna(), bins=20, alpha=0.7, color="blue")
            plt.title(f"Histogram of {attribute_name}")
            plt.xlabel(attribute_name)
            plt.ylabel("Frequency")
            plt.grid(True, alpha=0.3)
            plt.show()
        else:
            analysis = {
                "attribute": attribute_name,
                "type": "categorical",
                "count": attr.count(),
                "null_count": attr.isna().sum(),
                "unique_values": attr.nunique(),
                "value_counts": attr.value_counts().to_dict(),
            }

            # Bar chart of the most frequent categories (at most 10).
            top_n = min(10, attr.nunique())
            plt.figure(figsize=(10, 6))
            attr.value_counts().head(top_n).plot(kind="bar", color="skyblue")
            plt.title(f"Top {top_n} values for {attribute_name}")
            plt.xlabel(attribute_name)
            plt.ylabel("Count")
            plt.xticks(rotation=45)
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.show()

        return analysis

    except Exception as e:
        print(f"Error analyzing attribute: {str(e)}")
        return None
524
-
525
-
526
def visualize_vector_by_attribute(
    vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
):
    """Create a thematic map visualization of vector data based on an attribute.

    Args:
        vector_path (str): Path to the vector file
        attribute_name (str): Name of the attribute to visualize
        cmap (str, optional): Matplotlib colormap name. Defaults to 'viridis'.
        figsize (tuple, optional): Figure size as (width, height). Defaults to (10, 8).

    Returns:
        bool: True if visualization was successful, False otherwise
    """
    try:
        # Read the vector data
        gdf = gpd.read_file(vector_path)

        # Check if attribute exists
        if attribute_name not in gdf.columns:
            print(f"Attribute '{attribute_name}' not found in the dataset")
            return False

        fig, ax = plt.subplots(figsize=figsize)

        # Continuous data uses a colormap; categorical data a discrete legend.
        if pd.api.types.is_numeric_dtype(gdf[attribute_name]):
            gdf.plot(column=attribute_name, cmap=cmap, legend=True, ax=ax)
        else:
            gdf.plot(column=attribute_name, categorical=True, legend=True, ax=ax)

        ax.set_title(f"{os.path.basename(vector_path)} - {attribute_name}")
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        plt.tight_layout()
        plt.show()

        # Fix: the documented bool contract was never fulfilled — the
        # original returned None on both the success and error paths.
        return True

    except Exception as e:
        print(f"Error visualizing data: {str(e)}")
        return False
573
-
574
-
575
def clip_raster_by_bbox(
    input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
):
    """
    Clip a raster dataset using a bounding box and optionally select specific bands.

    Args:
        input_raster (str): Path to the input raster file.
        output_raster (str): Path where the clipped raster will be saved.
        bbox (tuple): Bounding box coordinates either as:
            - Geographic coordinates (minx, miny, maxx, maxy) if bbox_type="geo"
            - Pixel indices (min_row, min_col, max_row, max_col) if bbox_type="pixel"
        bands (list, optional): List of band indices to keep (1-based indexing).
            If None, all bands will be kept.
        bbox_type (str, optional): Type of bounding box coordinates. Either "geo" for
            geographic coordinates or "pixel" for row/column indices. Default is "geo".
        bbox_crs (str or dict, optional): CRS of the bbox if different from the raster
            CRS (e.g. "EPSG:4326"). Only applies when bbox_type="geo". If None,
            assumes bbox is in the same CRS as the raster.

    Returns:
        str: Path to the clipped output raster.

    Raises:
        ValueError: If the bbox is invalid, bands are out of range, or bbox_type
            is invalid.

    Examples:
        >>> clip_raster_by_bbox('input.tif', 'clipped_geo.tif', (100, 200, 300, 400))
        'clipped_geo.tif'
        >>> clip_raster_by_bbox('input.tif', 'clipped_pixel.tif', (50, 100, 150, 200),
        ...                     bbox_type="pixel")
        'clipped_pixel.tif'
    """
    from rasterio.transform import from_bounds
    from rasterio.warp import transform_bounds

    # Validate bbox_type and bbox shape before touching the file.
    if bbox_type not in ["geo", "pixel"]:
        raise ValueError("bbox_type must be either 'geo' or 'pixel'")
    if len(bbox) != 4:
        raise ValueError("bbox must contain exactly 4 values")

    with rasterio.open(input_raster) as src:
        src_crs = src.crs

        if bbox_type == "geo":
            minx, miny, maxx, maxy = bbox

            # Geographic bbox must be properly ordered.
            if minx >= maxx or miny >= maxy:
                raise ValueError(
                    "Invalid geographic bbox. Expected (minx, miny, maxx, maxy) where minx < maxx and miny < maxy"
                )

            # Reproject the bbox into the raster's CRS when it was given
            # in a different one.
            if bbox_crs is not None and bbox_crs != src_crs:
                try:
                    minx, miny, maxx, maxy = transform_bounds(
                        bbox_crs, src_crs, minx, miny, maxx, maxy
                    )
                except Exception as e:
                    # Fix: chain the underlying error (the original dropped it).
                    raise ValueError(
                        f"Failed to transform bbox from {bbox_crs} to {src_crs}: {str(e)}"
                    ) from e

            # Pixel window covering the geographic extent; the same bounds
            # define the output transform.
            window = src.window(minx, miny, maxx, maxy)
            output_bounds = (minx, miny, maxx, maxy)

        else:  # bbox_type == "pixel"
            min_row, min_col, max_row, max_col = bbox

            # Pixel bbox must be properly ordered and inside the raster.
            if min_row >= max_row or min_col >= max_col:
                raise ValueError(
                    "Invalid pixel bbox. Expected (min_row, min_col, max_row, max_col) where min_row < max_row and min_col < max_col"
                )
            if (
                min_row < 0
                or min_col < 0
                or max_row > src.height
                or max_col > src.width
            ):
                raise ValueError(
                    f"Pixel indices out of bounds. Raster dimensions are {src.height} rows x {src.width} columns"
                )

            window = Window(min_col, min_row, max_col - min_col, max_row - min_row)

            # Geographic bounds of this pixel window. Fix: array_bounds
            # already returns (minx, miny, maxx, maxy); the original
            # re-packed the tuple element-by-element, a no-op.
            window_transform = src.window_transform(window)
            output_bounds = rasterio.transform.array_bounds(
                window.height, window.width, window_transform
            )

        window_width = int(window.width)
        window_height = int(window.height)
        if window_width <= 0 or window_height <= 0:
            raise ValueError("Bounding box results in an empty window")

        # Resolve the band selection (1-based, defaults to all bands).
        if bands is None:
            bands_to_read = list(range(1, src.count + 1))
        else:
            if not all(1 <= b <= src.count for b in bands):
                raise ValueError(f"Band indices must be between 1 and {src.count}")
            bands_to_read = bands

        # Affine transform anchored to the clipped extent.
        new_transform = from_bounds(
            output_bounds[0],
            output_bounds[1],
            output_bounds[2],
            output_bounds[3],
            window_width,
            window_height,
        )

        out_meta = src.meta.copy()
        out_meta.update(
            {
                "height": window_height,
                "width": window_width,
                "transform": new_transform,
                "count": len(bands_to_read),
            }
        )

        # Read the selected bands through the window and stack them into
        # a (bands, rows, cols) array.
        data = [src.read(band_idx, window=window) for band_idx in bands_to_read]
        if len(data) > 1:
            clipped_data = np.stack(data)
        else:
            clipped_data = data[0][np.newaxis, :, :]

    with rasterio.open(output_raster, "w", **out_meta) as dst:
        dst.write(clipped_data)

    return output_raster
760
-
761
-
762
def raster_to_vector(
    raster_path,
    output_path=None,
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    plot_result=False,
):
    """
    Convert a raster label mask to vector polygons.

    Args:
        raster_path (str): Path to the input raster file (e.g., GeoTIFF).
        output_path (str): Path to save the output vector file. If None, returns GeoDataFrame without saving.
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
        plot_result (bool): Whether to plot the resulting polygons overlaid on the raster.

    Returns:
        geopandas.GeoDataFrame: A GeoDataFrame containing the vectorized polygons.

    Raises:
        ValueError: If output_path is given with an unsupported output_format.
    """
    # Open the raster file. NOTE: only band 1 is used for vectorization;
    # the other bands matter only for the optional RGB plot below.
    with rasterio.open(raster_path) as src:
        # Read the data
        data = src.read(1)

        # Get metadata (transform maps pixel -> map coordinates)
        transform = src.transform
        crs = src.crs

        # Create mask based on threshold and class values
        if class_values is not None:
            # Create a boolean mask for each specified class value
            masks = {val: (data == val) for val in class_values}
        else:
            # Create a single mask for values above threshold
            masks = {1: (data > threshold)}
            class_values = [1]  # Default class

        # Initialize list to store features
        all_features = []

        # Process each class value
        for class_val in class_values:
            mask = masks[class_val]

            # Vectorize the mask; passing mask= restricts shapes() to the
            # True pixels, so only foreground polygons are emitted.
            for geom, value in features.shapes(
                mask.astype(np.uint8), mask=mask, transform=transform
            ):
                # Convert the GeoJSON-like dict to a shapely geometry
                geom = shape(geom)

                # Skip small polygons (area is in square map units)
                if geom.area < min_area:
                    continue

                # Simplify geometry if requested
                if simplify_tolerance is not None:
                    geom = geom.simplify(simplify_tolerance)

                # Add to features list tagged with its class value
                all_features.append({"geometry": geom, attribute_name: class_val})

        # Create GeoDataFrame
        if all_features:
            gdf = gpd.GeoDataFrame(all_features, crs=crs)
        else:
            print("Warning: No features were extracted from the raster.")
            # Return empty GeoDataFrame with correct CRS
            gdf = gpd.GeoDataFrame([], geometry=[], crs=crs)

        # Save to file if requested
        if output_path is not None:
            # Create directory if it doesn't exist
            os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

            # Save to file based on format
            if output_format.lower() == "geojson":
                gdf.to_file(output_path, driver="GeoJSON")
            elif output_format.lower() == "shapefile":
                gdf.to_file(output_path)
            elif output_format.lower() == "gpkg":
                gdf.to_file(output_path, driver="GPKG")
            else:
                raise ValueError(f"Unsupported output format: {output_format}")

            print(f"Vectorized data saved to {output_path}")

        # Plot result if requested (still inside the `with`: the plot
        # re-reads all bands from the open dataset)
        if plot_result:
            fig, ax = plt.subplots(figsize=(12, 12))

            # Plot raster
            raster_img = src.read()
            if raster_img.shape[0] == 1:
                plt.imshow(raster_img[0], cmap="viridis", alpha=0.7)
            else:
                # Use first 3 bands for RGB display
                rgb = raster_img[:3].transpose(1, 2, 0)
                # Normalize for display
                rgb = np.clip(rgb / rgb.max(), 0, 1)
                plt.imshow(rgb)

            # Plot vector boundaries
            if not gdf.empty:
                gdf.plot(ax=ax, facecolor="none", edgecolor="red", linewidth=2)

            plt.title("Raster with Vectorized Boundaries")
            plt.axis("off")
            plt.tight_layout()
            plt.show()

    return gdf
883
-
884
-
885
def batch_raster_to_vector(
    input_dir,
    output_dir,
    pattern="*.tif",
    threshold=0,
    min_area=10,
    simplify_tolerance=None,
    class_values=None,
    attribute_name="class",
    output_format="geojson",
    merge_output=False,
    merge_filename="merged_vectors",
):
    """
    Batch convert multiple raster files to vector polygons.

    Args:
        input_dir (str): Directory containing input raster files.
        output_dir (str): Directory to save output vector files.
        pattern (str): Pattern to match raster files (e.g., '*.tif').
        threshold (int/float): Pixel values greater than this threshold will be vectorized.
        min_area (float): Minimum polygon area in square map units to keep.
        simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
        class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
        attribute_name (str): Name of the attribute field for the class values.
        output_format (str): Format for output files - 'geojson', 'shapefile', 'gpkg'.
        merge_output (bool): Whether to merge all output vectors into a single file.
        merge_filename (str): Filename for the merged output (without extension).

    Returns:
        geopandas.GeoDataFrame or None: If merge_output is True, returns the merged GeoDataFrame.
    """
    import glob

    os.makedirs(output_dir, exist_ok=True)

    raster_files = glob.glob(os.path.join(input_dir, pattern))
    if not raster_files:
        print(f"No files matching pattern '{pattern}' found in {input_dir}")
        return None

    print(f"Found {len(raster_files)} raster files to process")

    # Supported output formats mapped to their file extensions.
    extensions = {"geojson": ".geojson", "shapefile": ".shp", "gpkg": ".gpkg"}

    collected = []
    for raster_file in tqdm(raster_files, desc="Processing rasters"):
        base_name = os.path.splitext(os.path.basename(raster_file))[0]

        # Validate the format (and build the per-file path) up front,
        # before spending time on the conversion.
        fmt = output_format.lower()
        if fmt not in extensions:
            raise ValueError(f"Unsupported output format: {output_format}")
        out_file = os.path.join(output_dir, base_name + extensions[fmt])

        if merge_output:
            # Collect results in memory; no individual files when merging.
            gdf = raster_to_vector(
                raster_file,
                output_path=None,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
            )
            if not gdf.empty:
                # Track which raster each polygon came from.
                gdf["source_file"] = base_name
                collected.append(gdf)
        else:
            # Write each result to its own file.
            raster_to_vector(
                raster_file,
                output_path=out_file,
                threshold=threshold,
                min_area=min_area,
                simplify_tolerance=simplify_tolerance,
                class_values=class_values,
                attribute_name=attribute_name,
                output_format=output_format,
            )

    if merge_output and collected:
        merged_gdf = gpd.GeoDataFrame(pd.concat(collected, ignore_index=True))

        # Inherit the CRS of the first result when concat dropped it.
        if merged_gdf.crs is None:
            merged_gdf.crs = collected[0].crs

        fmt = output_format.lower()
        merged_file = os.path.join(output_dir, merge_filename + extensions[fmt])
        if fmt == "geojson":
            merged_gdf.to_file(merged_file, driver="GeoJSON")
        elif fmt == "shapefile":
            merged_gdf.to_file(merged_file)
        elif fmt == "gpkg":
            merged_gdf.to_file(merged_file, driver="GPKG")

        print(f"Merged vector data saved to {merged_file}")
        return merged_gdf

    return None
998
-
999
-
1000
def vector_to_raster(
    vector_path,
    output_path=None,
    reference_raster=None,
    attribute_field=None,
    output_shape=None,
    transform=None,
    pixel_size=None,
    bounds=None,
    crs=None,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
    plot_result=False,
):
    """
    Convert vector data to a raster.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_path (str): Path to save the output raster file. If None, the array is
            only returned, not written to disk.
        reference_raster (str): Path to a reference raster for dimensions, transform and CRS.
        attribute_field (str): Field name in the vector data to use for pixel values.
            If None (or the field is missing), all vector features are burned with value 1.
        output_shape (tuple): Shape of the output raster as (height, width).
            Required if reference_raster is not provided.
        transform (affine.Affine): Affine transformation matrix.
            Required if reference_raster is not provided.
        pixel_size (float or tuple): Pixel size (resolution) as single value or (x_res, y_res).
            Used to calculate transform if transform is not provided.
        bounds (tuple): Bounds of the output raster as (left, bottom, right, top).
            Used to calculate transform if transform is not provided.
        crs (str or CRS): Coordinate reference system of the output raster.
            Required if reference_raster is not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
            If False, only pixels whose center is within the geometry will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.
        plot_result (bool): Whether to plot the resulting raster.

    Returns:
        numpy.ndarray: The rasterized data array. The array is always returned,
        regardless of whether output_path was provided.
    """
    # Load vector data (accept either a path or an in-memory GeoDataFrame)
    if isinstance(vector_path, gpd.GeoDataFrame):
        gdf = vector_path
    else:
        gdf = gpd.read_file(vector_path)

    # An empty vector layer still produces a (fill_value-filled) raster
    if gdf.empty:
        warnings.warn("The input vector data is empty. Creating an empty raster.")

    # Fall back to the vector's CRS when neither crs nor reference_raster is given
    if crs is None and reference_raster is None:
        crs = gdf.crs

    # Get transform and output shape from reference raster if provided
    if reference_raster is not None:
        with rasterio.open(reference_raster) as src:
            transform = src.transform
            output_shape = src.shape
            crs = src.crs
            if nodata is None:
                nodata = src.nodata
    else:
        # Derive the transform from pixel_size + bounds when not given directly
        if transform is None:
            if pixel_size is None or bounds is None:
                raise ValueError(
                    "Either reference_raster, transform, or both pixel_size and bounds must be provided."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size

            left, bottom, right, top = bounds
            # from_bounds handles the north-up (negative y) orientation itself;
            # abs() tolerates callers passing a negative y resolution.
            transform = rasterio.transform.from_bounds(
                left,
                bottom,
                right,
                top,
                int((right - left) / x_res),
                int((top - bottom) / abs(y_res)),
            )

        if output_shape is None:
            # Calculate output shape from bounds and pixel size
            if bounds is None or pixel_size is None:
                raise ValueError(
                    "output_shape must be provided if reference_raster is not provided and "
                    "cannot be calculated from bounds and pixel_size."
                )

            if isinstance(pixel_size, (int, float)):
                x_res = y_res = float(pixel_size)
            else:
                x_res, y_res = pixel_size

            left, bottom, right, top = bounds
            width = int((right - left) / x_res)
            height = int((top - bottom) / abs(y_res))
            output_shape = (height, width)

    # Ensure CRS is set
    if crs is None:
        raise ValueError(
            "CRS must be provided either directly, from reference_raster, or from input vector data."
        )

    # Reproject vector data if its CRS doesn't match the output CRS.
    # Guard against a CRS-less GeoDataFrame: to_crs() would raise on it,
    # so assume its coordinates already match the target CRS and warn.
    if gdf.crs is not None and gdf.crs != crs:
        print(f"Reprojecting vector data from {gdf.crs} to {crs}")
        gdf = gdf.to_crs(crs)
    elif gdf.crs is None and not gdf.empty:
        warnings.warn(
            "Input vector data has no CRS; assuming it already matches the output CRS."
        )

    # Create empty raster filled with fill_value
    raster_data = np.full(output_shape, fill_value, dtype=dtype)

    # Burn vector features into raster
    if not gdf.empty:
        if attribute_field is not None and attribute_field in gdf.columns:
            # Use attribute field for per-feature burn values
            shapes = [
                (geom, value) for geom, value in zip(gdf.geometry, gdf[attribute_field])
            ]
        else:
            # Burn all features with value 1
            shapes = [(geom, 1) for geom in gdf.geometry]

        raster_data = features.rasterize(
            shapes=shapes,
            out_shape=output_shape,
            transform=transform,
            fill=fill_value,
            all_touched=all_touched,
            dtype=dtype,
        )

    # Save raster if output path is provided
    if output_path is not None:
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)

        metadata = {
            "driver": "GTiff",
            "height": output_shape[0],
            "width": output_shape[1],
            "count": 1,
            "dtype": raster_data.dtype,
            "crs": crs,
            "transform": transform,
        }
        if nodata is not None:
            metadata["nodata"] = nodata

        with rasterio.open(output_path, "w", **metadata) as dst:
            dst.write(raster_data, 1)

        print(f"Rasterized data saved to {output_path}")

    # Plot result if requested
    if plot_result:
        fig, ax = plt.subplots(figsize=(10, 10))

        im = ax.imshow(raster_data, cmap="viridis")
        plt.colorbar(im, ax=ax, label=attribute_field if attribute_field else "Value")

        # Determine the raster extent for clipping the vector overlay
        if output_path is not None:
            with rasterio.open(output_path) as src:
                bounds = src.bounds
                raster_bbox = box(*bounds)
        else:
            # Calculate extent from transform and shape
            height, width = output_shape
            left, top = transform * (0, 0)
            right, bottom = transform * (width, height)
            raster_bbox = box(left, bottom, right, top)

        # Clip vector to raster extent for clarity in plot
        if not gdf.empty:
            gdf_clipped = gpd.clip(gdf, raster_bbox)
            if not gdf_clipped.empty:
                gdf_clipped.boundary.plot(ax=ax, color="red", linewidth=1)

        plt.title("Rasterized Vector Data")
        plt.tight_layout()
        plt.show()

    return raster_data
1207
-
1208
-
1209
def batch_vector_to_raster(
    vector_path,
    output_dir,
    attribute_field=None,
    reference_rasters=None,
    bounds_list=None,
    output_filename_pattern="{vector_name}_{index}",
    pixel_size=1.0,
    all_touched=False,
    fill_value=0,
    dtype=np.uint8,
    nodata=None,
):
    """
    Batch convert vector data to multiple rasters based on different extents or reference rasters.

    Args:
        vector_path (str or GeoDataFrame): Path to the input vector file or a GeoDataFrame.
        output_dir (str): Directory to save output raster files.
        attribute_field (str): Field name in the vector data to use for pixel values.
        reference_rasters (list): List of paths to reference rasters for dimensions, transform and CRS.
        bounds_list (list): List of bounds tuples (left, bottom, right, top) to use if reference_rasters not provided.
        output_filename_pattern (str): Pattern for output filenames.
            Can include {vector_name} and {index} placeholders.
        pixel_size (float or tuple): Pixel size to use if reference_rasters not provided.
        all_touched (bool): If True, all pixels touched by geometries will be burned in.
        fill_value (int): Value to fill the raster with before burning in features.
        dtype (numpy.dtype): Data type of the output raster.
        nodata (int): No data value for the output raster.

    Returns:
        list: List of paths to the created raster files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Accept either a file path (load it) or an in-memory GeoDataFrame
    if isinstance(vector_path, str):
        gdf = gpd.read_file(vector_path)
        vector_name = os.path.splitext(os.path.basename(vector_path))[0]
    else:
        gdf = vector_path
        vector_name = "vector"

    if reference_rasters is None and bounds_list is None:
        raise ValueError("Either reference_rasters or bounds_list must be provided.")

    # Reference rasters take priority over the bounds list
    use_reference_rasters = reference_rasters is not None
    sources = reference_rasters if use_reference_rasters else bounds_list

    # Arguments shared by every vector_to_raster call in this batch
    shared_kwargs = dict(
        vector_path=gdf,
        attribute_field=attribute_field,
        all_touched=all_touched,
        fill_value=fill_value,
        dtype=dtype,
        nodata=nodata,
    )

    created_paths = []
    for index, source in enumerate(tqdm(sources, desc="Processing")):
        # Build the output filename from the pattern, ensuring a .tif suffix
        filename = output_filename_pattern.format(vector_name=vector_name, index=index)
        if not filename.endswith(".tif"):
            filename += ".tif"
        destination = os.path.join(output_dir, filename)

        if use_reference_rasters:
            # Dimensions, transform and CRS come from the reference raster
            vector_to_raster(
                output_path=destination,
                reference_raster=source,
                **shared_kwargs,
            )
        else:
            # Dimensions come from the bounds tuple plus the pixel size
            vector_to_raster(
                output_path=destination,
                bounds=source,
                pixel_size=pixel_size,
                **shared_kwargs,
            )

        created_paths.append(destination)

    return created_paths
1307
-
1308
-
1309
def export_geotiff_tiles(
    in_raster,
    out_folder,
    in_class_data,
    tile_size=256,
    stride=128,
    class_value_field="class",
    buffer_radius=0,
    max_tiles=None,
    quiet=False,
    all_touched=True,
    create_overview=False,
    skip_empty_tiles=False,
):
    """
    Export georeferenced GeoTIFF tiles and labels from raster and classification data.

    Args:
        in_raster (str): Path to input raster image
        out_folder (str): Path to output folder
        in_class_data (str): Path to classification data - can be vector file or raster
        tile_size (int): Size of tiles in pixels (square)
        stride (int): Step size between tiles
        class_value_field (str): Field containing class values (for vector data)
        buffer_radius (float): Buffer to add around features (in units of the CRS)
        max_tiles (int): Maximum number of tiles to process (None for all)
        quiet (bool): If True, suppress non-essential output
        all_touched (bool): Whether to use all_touched=True in rasterization (for vector data)
        create_overview (bool): Whether to create an overview image of all tiles
        skip_empty_tiles (bool): If True, skip tiles with no features

    Returns:
        dict: Export statistics with keys "total_tiles", "tiles_with_features",
        "feature_pixels", "errors", and "tile_coordinates".
    """
    # Create output directories
    os.makedirs(out_folder, exist_ok=True)
    image_dir = os.path.join(out_folder, "images")
    os.makedirs(image_dir, exist_ok=True)
    label_dir = os.path.join(out_folder, "labels")
    os.makedirs(label_dir, exist_ok=True)
    ann_dir = os.path.join(out_folder, "annotations")
    os.makedirs(ann_dir, exist_ok=True)

    # Determine if class data is raster or vector
    is_class_data_raster = False
    if isinstance(in_class_data, str):
        file_ext = Path(in_class_data).suffix.lower()
        # Common raster extensions
        if file_ext in [".tif", ".tiff", ".img", ".jp2", ".png", ".bmp", ".gif"]:
            try:
                with rasterio.open(in_class_data) as src:
                    is_class_data_raster = True
                    if not quiet:
                        print(f"Detected in_class_data as raster: {in_class_data}")
                        print(f"Raster CRS: {src.crs}")
                        print(f"Raster dimensions: {src.width} x {src.height}")
            except Exception:
                is_class_data_raster = False
                if not quiet:
                    print(f"Unable to open {in_class_data} as raster, trying as vector")

    # Open the input raster
    with rasterio.open(in_raster) as src:
        if not quiet:
            print(f"\nRaster info for {in_raster}:")
            print(f"  CRS: {src.crs}")
            print(f"  Dimensions: {src.width} x {src.height}")
            print(f"  Bounds: {src.bounds}")

        # Calculate number of tiles
        num_tiles_x = math.ceil((src.width - tile_size) / stride) + 1
        num_tiles_y = math.ceil((src.height - tile_size) / stride) + 1
        total_tiles = num_tiles_x * num_tiles_y

        if max_tiles is None:
            max_tiles = total_tiles

        # Process classification data
        class_to_id = {}

        if is_class_data_raster:
            # Load raster class data
            with rasterio.open(in_class_data) as class_src:
                # Check if raster CRS matches
                if class_src.crs != src.crs:
                    warnings.warn(
                        f"CRS mismatch: Class raster ({class_src.crs}) doesn't match input raster ({src.crs}). "
                        f"Results may be misaligned."
                    )

                # Get unique values from raster
                # Sample to avoid loading huge rasters
                sample_data = class_src.read(
                    1,
                    out_shape=(
                        1,
                        min(class_src.height, 1000),
                        min(class_src.width, 1000),
                    ),
                )

                unique_classes = np.unique(sample_data)
                unique_classes = unique_classes[
                    unique_classes > 0
                ]  # Remove 0 as it's typically background

                if not quiet:
                    print(
                        f"Found {len(unique_classes)} unique classes in raster: {unique_classes}"
                    )

                # Create class mapping
                class_to_id = {int(cls): i + 1 for i, cls in enumerate(unique_classes)}
        else:
            # Load vector class data
            try:
                gdf = gpd.read_file(in_class_data)
                if not quiet:
                    print(f"Loaded {len(gdf)} features from {in_class_data}")
                    print(f"Vector CRS: {gdf.crs}")

                # Always reproject to match raster CRS
                if gdf.crs != src.crs:
                    if not quiet:
                        print(f"Reprojecting features from {gdf.crs} to {src.crs}")
                    gdf = gdf.to_crs(src.crs)

                # Apply buffer if specified
                if buffer_radius > 0:
                    gdf["geometry"] = gdf.buffer(buffer_radius)
                    if not quiet:
                        print(f"Applied buffer of {buffer_radius} units")

                # Check if class_value_field exists
                if class_value_field in gdf.columns:
                    unique_classes = gdf[class_value_field].unique()
                    if not quiet:
                        print(
                            f"Found {len(unique_classes)} unique classes: {unique_classes}"
                        )
                    # Create class mapping
                    class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
                else:
                    if not quiet:
                        print(
                            f"WARNING: '{class_value_field}' not found in vector data. Using default class ID 1."
                        )
                    class_to_id = {1: 1}  # Default mapping
            except Exception as e:
                raise ValueError(f"Error processing vector data: {e}")

        # Create progress bar
        pbar = tqdm(
            total=min(total_tiles, max_tiles),
            desc="Generating tiles",
            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )

        # Track statistics for summary
        stats = {
            "total_tiles": 0,
            "tiles_with_features": 0,
            "feature_pixels": 0,
            "errors": 0,
            "tile_coordinates": [],  # For overview image
        }

        # Process tiles
        tile_index = 0
        for y in range(num_tiles_y):
            for x in range(num_tiles_x):
                if tile_index >= max_tiles:
                    break

                # Calculate window coordinates
                window_x = x * stride
                window_y = y * stride

                # Adjust for edge cases: shift the last window back so it
                # stays fully inside the raster (tiles may overlap there)
                if window_x + tile_size > src.width:
                    window_x = src.width - tile_size
                if window_y + tile_size > src.height:
                    window_y = src.height - tile_size

                # Define window
                window = Window(window_x, window_y, tile_size, tile_size)

                # Get window transform and bounds
                window_transform = src.window_transform(window)

                # Calculate window bounds in world coordinates
                minx = window_transform[2]  # Upper left x
                maxy = window_transform[5]  # Upper left y
                maxx = minx + tile_size * window_transform[0]  # Add width
                miny = maxy + tile_size * window_transform[4]  # Add height

                window_bounds = box(minx, miny, maxx, maxy)

                # Store tile coordinates for overview
                if create_overview:
                    stats["tile_coordinates"].append(
                        {
                            "index": tile_index,
                            "x": window_x,
                            "y": window_y,
                            "bounds": [minx, miny, maxx, maxy],
                            "has_features": False,
                        }
                    )

                # Create label mask
                label_mask = np.zeros((tile_size, tile_size), dtype=np.uint8)
                has_features = False

                # Process classification data to create labels
                if is_class_data_raster:
                    # For raster class data
                    with rasterio.open(in_class_data) as class_src:
                        # Calculate window in class raster
                        src_bounds = src.bounds
                        class_bounds = class_src.bounds

                        # Check if windows overlap
                        if (
                            src_bounds.left > class_bounds.right
                            or src_bounds.right < class_bounds.left
                            or src_bounds.bottom > class_bounds.top
                            or src_bounds.top < class_bounds.bottom
                        ):
                            warnings.warn(
                                "Class raster and input raster do not overlap."
                            )
                        else:
                            # Get corresponding window in class raster
                            window_class = rasterio.windows.from_bounds(
                                minx, miny, maxx, maxy, class_src.transform
                            )

                            # Read label data
                            try:
                                label_data = class_src.read(
                                    1,
                                    window=window_class,
                                    boundless=True,
                                    out_shape=(tile_size, tile_size),
                                )

                                # Remap class values if needed
                                if class_to_id:
                                    remapped_data = np.zeros_like(label_data)
                                    for orig_val, new_val in class_to_id.items():
                                        remapped_data[label_data == orig_val] = new_val
                                    label_mask = remapped_data
                                else:
                                    label_mask = label_data

                                # Check if we have any features.
                                # NOTE: feature_pixels is counted once, at
                                # label-save time below; counting here as well
                                # would double-count raster-derived tiles.
                                if np.any(label_mask > 0):
                                    has_features = True
                                    if create_overview and tile_index < len(
                                        stats["tile_coordinates"]
                                    ):
                                        stats["tile_coordinates"][tile_index][
                                            "has_features"
                                        ] = True
                            except Exception as e:
                                pbar.write(f"Error reading class raster window: {e}")
                                stats["errors"] += 1
                else:
                    # For vector class data
                    # Find features that intersect with window
                    window_features = gdf[gdf.intersects(window_bounds)]

                    if len(window_features) > 0:
                        for idx, feature in window_features.iterrows():
                            # Get class value
                            if class_value_field in feature:
                                class_val = feature[class_value_field]
                                class_id = class_to_id.get(class_val, 1)
                            else:
                                class_id = 1

                            # Get geometry in window coordinates
                            geom = feature.geometry.intersection(window_bounds)
                            if not geom.is_empty:
                                try:
                                    # Rasterize feature
                                    feature_mask = features.rasterize(
                                        [(geom, class_id)],
                                        out_shape=(tile_size, tile_size),
                                        transform=window_transform,
                                        fill=0,
                                        all_touched=all_touched,
                                    )

                                    # Add to label mask
                                    label_mask = np.maximum(label_mask, feature_mask)

                                    # Check if the feature was actually rasterized
                                    if np.any(feature_mask):
                                        has_features = True
                                        if create_overview and tile_index < len(
                                            stats["tile_coordinates"]
                                        ):
                                            stats["tile_coordinates"][tile_index][
                                                "has_features"
                                            ] = True
                                except Exception as e:
                                    pbar.write(f"Error rasterizing feature {idx}: {e}")
                                    stats["errors"] += 1

                # Skip tile if no features and skip_empty_tiles is True
                if skip_empty_tiles and not has_features:
                    pbar.update(1)
                    tile_index += 1
                    continue

                # Read image data
                image_data = src.read(window=window)

                # Export image as GeoTIFF
                image_path = os.path.join(image_dir, f"tile_{tile_index:06d}.tif")

                # Create profile for image GeoTIFF
                image_profile = src.profile.copy()
                image_profile.update(
                    {
                        "height": tile_size,
                        "width": tile_size,
                        "count": image_data.shape[0],
                        "transform": window_transform,
                    }
                )

                # Save image as GeoTIFF
                try:
                    with rasterio.open(image_path, "w", **image_profile) as dst:
                        dst.write(image_data)
                    stats["total_tiles"] += 1
                except Exception as e:
                    pbar.write(f"ERROR saving image GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create profile for label GeoTIFF
                label_profile = {
                    "driver": "GTiff",
                    "height": tile_size,
                    "width": tile_size,
                    "count": 1,
                    "dtype": "uint8",
                    "crs": src.crs,
                    "transform": window_transform,
                }

                # Export label as GeoTIFF
                label_path = os.path.join(label_dir, f"tile_{tile_index:06d}.tif")
                try:
                    with rasterio.open(label_path, "w", **label_profile) as dst:
                        dst.write(label_mask.astype(np.uint8), 1)

                    if has_features:
                        stats["tiles_with_features"] += 1
                        stats["feature_pixels"] += np.count_nonzero(label_mask)
                except Exception as e:
                    pbar.write(f"ERROR saving label GeoTIFF: {e}")
                    stats["errors"] += 1

                # Create XML annotation for object detection if using vector class data
                if (
                    not is_class_data_raster
                    and "gdf" in locals()
                    and len(window_features) > 0
                ):
                    # Create XML annotation
                    root = ET.Element("annotation")
                    ET.SubElement(root, "folder").text = "images"
                    ET.SubElement(root, "filename").text = f"tile_{tile_index:06d}.tif"

                    size = ET.SubElement(root, "size")
                    ET.SubElement(size, "width").text = str(tile_size)
                    ET.SubElement(size, "height").text = str(tile_size)
                    ET.SubElement(size, "depth").text = str(image_data.shape[0])

                    # Add georeference information
                    geo = ET.SubElement(root, "georeference")
                    ET.SubElement(geo, "crs").text = str(src.crs)
                    ET.SubElement(geo, "transform").text = str(
                        window_transform
                    ).replace("\n", "")
                    ET.SubElement(geo, "bounds").text = (
                        f"{minx}, {miny}, {maxx}, {maxy}"
                    )

                    # Add objects
                    for idx, feature in window_features.iterrows():
                        # Get feature class
                        if class_value_field in feature:
                            class_val = feature[class_value_field]
                        else:
                            class_val = "object"

                        # Get geometry bounds in pixel coordinates
                        geom = feature.geometry.intersection(window_bounds)
                        if not geom.is_empty:
                            # Get bounds in world coordinates
                            minx_f, miny_f, maxx_f, maxy_f = geom.bounds

                            # Convert to pixel coordinates
                            col_min, row_min = ~window_transform * (minx_f, maxy_f)
                            col_max, row_max = ~window_transform * (maxx_f, miny_f)

                            # Ensure coordinates are within tile bounds
                            xmin = max(0, min(tile_size, int(col_min)))
                            ymin = max(0, min(tile_size, int(row_min)))
                            xmax = max(0, min(tile_size, int(col_max)))
                            ymax = max(0, min(tile_size, int(row_max)))

                            # Only add if the box has non-zero area
                            if xmax > xmin and ymax > ymin:
                                obj = ET.SubElement(root, "object")
                                ET.SubElement(obj, "name").text = str(class_val)
                                ET.SubElement(obj, "difficult").text = "0"

                                bbox = ET.SubElement(obj, "bndbox")
                                ET.SubElement(bbox, "xmin").text = str(xmin)
                                ET.SubElement(bbox, "ymin").text = str(ymin)
                                ET.SubElement(bbox, "xmax").text = str(xmax)
                                ET.SubElement(bbox, "ymax").text = str(ymax)

                    # Save XML
                    tree = ET.ElementTree(root)
                    xml_path = os.path.join(ann_dir, f"tile_{tile_index:06d}.xml")
                    tree.write(xml_path)

                # Update progress bar
                pbar.update(1)
                pbar.set_description(
                    f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
                )

                tile_index += 1
                if tile_index >= max_tiles:
                    break

            if tile_index >= max_tiles:
                break

        # Close progress bar
        pbar.close()

        # Create overview image if requested
        if create_overview and stats["tile_coordinates"]:
            try:
                create_overview_image(
                    src,
                    stats["tile_coordinates"],
                    os.path.join(out_folder, "overview.png"),
                    tile_size,
                    stride,
                )
            except Exception as e:
                print(f"Failed to create overview image: {e}")

        # Report results
        if not quiet:
            print("\n------- Export Summary -------")
            print(f"Total tiles exported: {stats['total_tiles']}")
            print(
                f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
            )
            if stats["tiles_with_features"] > 0:
                print(
                    f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
                )
            if stats["errors"] > 0:
                print(f"Errors encountered: {stats['errors']}")
            print(f"Output saved to: {out_folder}")

            # Verify georeference in a sample image and label.
            # Tiles are written as tile_{index:06d}.tif, so the first tile is
            # tile_000000.tif (the old "tile_0.tif" name never existed and the
            # verification silently did nothing).
            if stats["total_tiles"] > 0:
                print("\n------- Georeference Verification -------")
                sample_image = os.path.join(image_dir, "tile_000000.tif")
                sample_label = os.path.join(label_dir, "tile_000000.tif")

                if os.path.exists(sample_image):
                    try:
                        with rasterio.open(sample_image) as img:
                            print(f"Image CRS: {img.crs}")
                            print(f"Image transform: {img.transform}")
                            print(
                                f"Image has georeference: {img.crs is not None and img.transform is not None}"
                            )
                            print(
                                f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying image georeference: {e}")

                if os.path.exists(sample_label):
                    try:
                        with rasterio.open(sample_label) as lbl:
                            print(f"Label CRS: {lbl.crs}")
                            print(f"Label transform: {lbl.transform}")
                            print(
                                f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
                            )
                            print(
                                f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
                            )
                    except Exception as e:
                        print(f"Error verifying label georeference: {e}")

    # Return statistics dictionary for further processing if needed
    return stats
1817
-
1818
-
1819
- def create_overview_image(
1820
- src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
1821
- ):
1822
- """Create an overview image showing all tiles and their status, with optional GeoJSON export.
1823
-
1824
- Args:
1825
- src (rasterio.io.DatasetReader): The source raster dataset.
1826
- tile_coordinates (list): A list of dictionaries containing tile information.
1827
- output_path (str): The path where the overview image will be saved.
1828
- tile_size (int): The size of each tile in pixels.
1829
- stride (int): The stride between tiles in pixels. Controls overlap between adjacent tiles.
1830
- geojson_path (str, optional): If provided, exports the tile rectangles as GeoJSON to this path.
1831
-
1832
- Returns:
1833
- str: Path to the saved overview image.
1834
- """
1835
- # Read a reduced version of the source image
1836
- overview_scale = max(
1837
- 1, int(max(src.width, src.height) / 2000)
1838
- ) # Scale to max ~2000px
1839
- overview_width = src.width // overview_scale
1840
- overview_height = src.height // overview_scale
1841
-
1842
- # Read downsampled image
1843
- overview_data = src.read(
1844
- out_shape=(src.count, overview_height, overview_width),
1845
- resampling=rasterio.enums.Resampling.average,
1846
- )
1847
-
1848
- # Create RGB image for display
1849
- if overview_data.shape[0] >= 3:
1850
- rgb = np.moveaxis(overview_data[:3], 0, -1)
1851
- else:
1852
- # For single band, create grayscale RGB
1853
- rgb = np.stack([overview_data[0], overview_data[0], overview_data[0]], axis=-1)
1854
-
1855
- # Normalize for display
1856
- for i in range(rgb.shape[-1]):
1857
- band = rgb[..., i]
1858
- non_zero = band[band > 0]
1859
- if len(non_zero) > 0:
1860
- p2, p98 = np.percentile(non_zero, (2, 98))
1861
- rgb[..., i] = np.clip((band - p2) / (p98 - p2), 0, 1)
1862
-
1863
- # Create figure
1864
- plt.figure(figsize=(12, 12))
1865
- plt.imshow(rgb)
1866
-
1867
- # If GeoJSON export is requested, prepare GeoJSON structures
1868
- if geojson_path:
1869
- features = []
1870
-
1871
- # Draw tile boundaries
1872
- for tile in tile_coordinates:
1873
- # Convert bounds to pixel coordinates in overview
1874
- bounds = tile["bounds"]
1875
- # Calculate scaled pixel coordinates
1876
- x_min = int((tile["x"]) / overview_scale)
1877
- y_min = int((tile["y"]) / overview_scale)
1878
- width = int(tile_size / overview_scale)
1879
- height = int(tile_size / overview_scale)
1880
-
1881
- # Draw rectangle
1882
- color = "lime" if tile["has_features"] else "red"
1883
- rect = plt.Rectangle(
1884
- (x_min, y_min), width, height, fill=False, edgecolor=color, linewidth=0.5
1885
- )
1886
- plt.gca().add_patch(rect)
1887
-
1888
- # Add tile number if not too crowded
1889
- if width > 20 and height > 20:
1890
- plt.text(
1891
- x_min + width / 2,
1892
- y_min + height / 2,
1893
- str(tile["index"]),
1894
- color="white",
1895
- ha="center",
1896
- va="center",
1897
- fontsize=8,
1898
- )
1899
-
1900
- # Add to GeoJSON features if exporting
1901
- if geojson_path:
1902
- # Create a polygon from the bounds (already in geo-coordinates)
1903
- minx, miny, maxx, maxy = bounds
1904
- polygon = box(minx, miny, maxx, maxy)
1905
-
1906
- # Calculate overlap with neighboring tiles
1907
- overlap = 0
1908
- if stride < tile_size:
1909
- overlap = tile_size - stride
1910
-
1911
- # Create a GeoJSON feature
1912
- feature = {
1913
- "type": "Feature",
1914
- "geometry": mapping(polygon),
1915
- "properties": {
1916
- "index": tile["index"],
1917
- "has_features": tile["has_features"],
1918
- "bounds_pixel": [
1919
- tile["x"],
1920
- tile["y"],
1921
- tile["x"] + tile_size,
1922
- tile["y"] + tile_size,
1923
- ],
1924
- "tile_size_px": tile_size,
1925
- "stride_px": stride,
1926
- "overlap_px": overlap,
1927
- },
1928
- }
1929
-
1930
- # Add any additional properties from the tile
1931
- for key, value in tile.items():
1932
- if key not in ["x", "y", "index", "has_features", "bounds"]:
1933
- feature["properties"][key] = value
1934
-
1935
- features.append(feature)
1936
-
1937
- plt.title("Tile Overview (Green = Contains Features, Red = Empty)")
1938
- plt.axis("off")
1939
- plt.tight_layout()
1940
- plt.savefig(output_path, dpi=300, bbox_inches="tight")
1941
- plt.close()
1942
-
1943
- print(f"Overview image saved to {output_path}")
1944
-
1945
- # Export GeoJSON if requested
1946
- if geojson_path:
1947
- geojson_collection = {
1948
- "type": "FeatureCollection",
1949
- "features": features,
1950
- "properties": {
1951
- "crs": (
1952
- src.crs.to_string()
1953
- if hasattr(src.crs, "to_string")
1954
- else str(src.crs)
1955
- ),
1956
- "total_tiles": len(features),
1957
- "source_raster_dimensions": [src.width, src.height],
1958
- },
1959
- }
1960
-
1961
- # Save to file
1962
- with open(geojson_path, "w") as f:
1963
- json.dump(geojson_collection, f)
1964
-
1965
- print(f"GeoJSON saved to {geojson_path}")
1966
-
1967
- return output_path
1968
-
1969
-
1970
- def export_tiles_to_geojson(
1971
- tile_coordinates, src, output_path, tile_size=None, stride=None
1972
- ):
1973
- """
1974
- Export tile rectangles directly to GeoJSON without creating an overview image.
1975
-
1976
- Args:
1977
- tile_coordinates (list): A list of dictionaries containing tile information.
1978
- src (rasterio.io.DatasetReader): The source raster dataset.
1979
- output_path (str): The path where the GeoJSON will be saved.
1980
- tile_size (int, optional): The size of each tile in pixels. Only needed if not in tile_coordinates.
1981
- stride (int, optional): The stride between tiles in pixels. Used to calculate overlaps between tiles.
1982
-
1983
- Returns:
1984
- str: Path to the saved GeoJSON file.
1985
- """
1986
- features = []
1987
-
1988
- for tile in tile_coordinates:
1989
- # Get the size from the tile or use the provided parameter
1990
- tile_width = tile.get("width", tile.get("size", tile_size))
1991
- tile_height = tile.get("height", tile.get("size", tile_size))
1992
-
1993
- if tile_width is None or tile_height is None:
1994
- raise ValueError(
1995
- "Tile size not found in tile data and no tile_size parameter provided"
1996
- )
1997
-
1998
- # Get bounds from the tile
1999
- if "bounds" in tile:
2000
- # If bounds are already in geo coordinates
2001
- minx, miny, maxx, maxy = tile["bounds"]
2002
- else:
2003
- # Try to calculate bounds from transform if available
2004
- if hasattr(src, "transform"):
2005
- # Convert pixel coordinates to geo coordinates
2006
- window_transform = src.transform
2007
- x, y = tile["x"], tile["y"]
2008
- minx = window_transform[2] + x * window_transform[0]
2009
- maxy = window_transform[5] + y * window_transform[4]
2010
- maxx = minx + tile_width * window_transform[0]
2011
- miny = maxy + tile_height * window_transform[4]
2012
- else:
2013
- raise ValueError(
2014
- "Cannot determine bounds. Neither 'bounds' in tile nor transform in src."
2015
- )
2016
-
2017
- # Calculate overlap with neighboring tiles if stride is provided
2018
- overlap = 0
2019
- if stride is not None and stride < tile_width:
2020
- overlap = tile_width - stride
2021
-
2022
- # Create a polygon from the bounds
2023
- polygon = box(minx, miny, maxx, maxy)
2024
-
2025
- # Create a GeoJSON feature
2026
- feature = {
2027
- "type": "Feature",
2028
- "geometry": mapping(polygon),
2029
- "properties": {
2030
- "index": tile["index"],
2031
- "has_features": tile.get("has_features", False),
2032
- "tile_width_px": tile_width,
2033
- "tile_height_px": tile_height,
2034
- },
2035
- }
2036
-
2037
- # Add overlap information if stride is provided
2038
- if stride is not None:
2039
- feature["properties"]["stride_px"] = stride
2040
- feature["properties"]["overlap_px"] = overlap
2041
-
2042
- # Add additional properties from the tile
2043
- for key, value in tile.items():
2044
- if key not in ["bounds", "geometry"]:
2045
- feature["properties"][key] = value
2046
-
2047
- features.append(feature)
2048
-
2049
- # Create the GeoJSON collection
2050
- geojson_collection = {
2051
- "type": "FeatureCollection",
2052
- "features": features,
2053
- "properties": {
2054
- "crs": (
2055
- src.crs.to_string() if hasattr(src.crs, "to_string") else str(src.crs)
2056
- ),
2057
- "total_tiles": len(features),
2058
- "source_raster_dimensions": (
2059
- [src.width, src.height] if hasattr(src, "width") else None
2060
- ),
2061
- },
2062
- }
2063
-
2064
- # Create directory if it doesn't exist
2065
- os.makedirs(os.path.dirname(os.path.abspath(output_path)) or ".", exist_ok=True)
2066
-
2067
- # Save to file
2068
- with open(output_path, "w") as f:
2069
- json.dump(geojson_collection, f)
2070
-
2071
- print(f"GeoJSON saved to {output_path}")
2072
- return output_path
2073
-
2074
-
2075
- def export_training_data(
2076
- in_raster,
2077
- out_folder,
2078
- in_class_data,
2079
- image_chip_format="GEOTIFF",
2080
- tile_size_x=256,
2081
- tile_size_y=256,
2082
- stride_x=None,
2083
- stride_y=None,
2084
- output_nofeature_tiles=True,
2085
- metadata_format="PASCAL_VOC",
2086
- start_index=0,
2087
- class_value_field="class",
2088
- buffer_radius=0,
2089
- in_mask_polygons=None,
2090
- rotation_angle=0,
2091
- reference_system=None,
2092
- blacken_around_feature=False,
2093
- crop_mode="FIXED_SIZE", # Implemented but not fully used yet
2094
- in_raster2=None,
2095
- in_instance_data=None,
2096
- instance_class_value_field=None, # Implemented but not fully used yet
2097
- min_polygon_overlap_ratio=0.0,
2098
- all_touched=True,
2099
- save_geotiff=True,
2100
- quiet=False,
2101
- ):
2102
- """
2103
- Export training data for deep learning using TorchGeo with progress bar.
2104
-
2105
- Args:
2106
- in_raster (str): Path to input raster image.
2107
- out_folder (str): Output folder path where chips and labels will be saved.
2108
- in_class_data (str): Path to vector file containing class polygons.
2109
- image_chip_format (str): Output image format (PNG, JPEG, TIFF, GEOTIFF).
2110
- tile_size_x (int): Width of image chips in pixels.
2111
- tile_size_y (int): Height of image chips in pixels.
2112
- stride_x (int): Horizontal stride between chips. If None, uses tile_size_x.
2113
- stride_y (int): Vertical stride between chips. If None, uses tile_size_y.
2114
- output_nofeature_tiles (bool): Whether to export chips without features.
2115
- metadata_format (str): Output metadata format (PASCAL_VOC, KITTI, COCO).
2116
- start_index (int): Starting index for chip filenames.
2117
- class_value_field (str): Field name in in_class_data containing class values.
2118
- buffer_radius (float): Buffer radius around features (in CRS units).
2119
- in_mask_polygons (str): Path to vector file containing mask polygons.
2120
- rotation_angle (float): Rotation angle in degrees.
2121
- reference_system (str): Reference system code.
2122
- blacken_around_feature (bool): Whether to mask areas outside of features.
2123
- crop_mode (str): Crop mode (FIXED_SIZE, CENTERED_ON_FEATURE).
2124
- in_raster2 (str): Path to secondary raster image.
2125
- in_instance_data (str): Path to vector file containing instance polygons.
2126
- instance_class_value_field (str): Field name in in_instance_data for instance classes.
2127
- min_polygon_overlap_ratio (float): Minimum overlap ratio for polygons.
2128
- all_touched (bool): Whether to use all_touched=True in rasterization.
2129
- save_geotiff (bool): Whether to save as GeoTIFF with georeferencing.
2130
- quiet (bool): If True, suppress most output messages.
2131
- """
2132
- # Create output directories
2133
- image_dir = os.path.join(out_folder, "images")
2134
- os.makedirs(image_dir, exist_ok=True)
2135
-
2136
- label_dir = os.path.join(out_folder, "labels")
2137
- os.makedirs(label_dir, exist_ok=True)
2138
-
2139
- # Define annotation directories based on metadata format
2140
- if metadata_format == "PASCAL_VOC":
2141
- ann_dir = os.path.join(out_folder, "annotations")
2142
- os.makedirs(ann_dir, exist_ok=True)
2143
- elif metadata_format == "COCO":
2144
- ann_dir = os.path.join(out_folder, "annotations")
2145
- os.makedirs(ann_dir, exist_ok=True)
2146
- # Initialize COCO annotations dictionary
2147
- coco_annotations = {"images": [], "annotations": [], "categories": []}
2148
-
2149
- # Initialize statistics dictionary
2150
- stats = {
2151
- "total_tiles": 0,
2152
- "tiles_with_features": 0,
2153
- "feature_pixels": 0,
2154
- "errors": 0,
2155
- }
2156
-
2157
- # Open raster
2158
- with rasterio.open(in_raster) as src:
2159
- if not quiet:
2160
- print(f"\nRaster info for {in_raster}:")
2161
- print(f" CRS: {src.crs}")
2162
- print(f" Dimensions: {src.width} x {src.height}")
2163
- print(f" Bounds: {src.bounds}")
2164
-
2165
- # Set defaults for stride if not provided
2166
- if stride_x is None:
2167
- stride_x = tile_size_x
2168
- if stride_y is None:
2169
- stride_y = tile_size_y
2170
-
2171
- # Calculate number of tiles in x and y directions
2172
- num_tiles_x = math.ceil((src.width - tile_size_x) / stride_x) + 1
2173
- num_tiles_y = math.ceil((src.height - tile_size_y) / stride_y) + 1
2174
- total_tiles = num_tiles_x * num_tiles_y
2175
-
2176
- # Read class data
2177
- gdf = gpd.read_file(in_class_data)
2178
- if not quiet:
2179
- print(f"Loaded {len(gdf)} features from {in_class_data}")
2180
- print(f"Available columns: {gdf.columns.tolist()}")
2181
- print(f"GeoJSON CRS: {gdf.crs}")
2182
-
2183
- # Check if class_value_field exists
2184
- if class_value_field not in gdf.columns:
2185
- if not quiet:
2186
- print(
2187
- f"WARNING: '{class_value_field}' field not found in the input data. Using default class value 1."
2188
- )
2189
- # Add a default class column
2190
- gdf[class_value_field] = 1
2191
- unique_classes = [1]
2192
- else:
2193
- # Print unique classes for debugging
2194
- unique_classes = gdf[class_value_field].unique()
2195
- if not quiet:
2196
- print(f"Found {len(unique_classes)} unique classes: {unique_classes}")
2197
-
2198
- # CRITICAL: Always reproject to match raster CRS to ensure proper alignment
2199
- if gdf.crs != src.crs:
2200
- if not quiet:
2201
- print(f"Reprojecting features from {gdf.crs} to {src.crs}")
2202
- gdf = gdf.to_crs(src.crs)
2203
- elif reference_system and gdf.crs != reference_system:
2204
- if not quiet:
2205
- print(
2206
- f"Reprojecting features to specified reference system {reference_system}"
2207
- )
2208
- gdf = gdf.to_crs(reference_system)
2209
-
2210
- # Check overlap between raster and vector data
2211
- raster_bounds = box(*src.bounds)
2212
- vector_bounds = box(*gdf.total_bounds)
2213
- if not raster_bounds.intersects(vector_bounds):
2214
- if not quiet:
2215
- print(
2216
- "WARNING: The vector data doesn't intersect with the raster extent!"
2217
- )
2218
- print(f"Raster bounds: {src.bounds}")
2219
- print(f"Vector bounds: {gdf.total_bounds}")
2220
- else:
2221
- overlap = (
2222
- raster_bounds.intersection(vector_bounds).area / vector_bounds.area
2223
- )
2224
- if not quiet:
2225
- print(f"Overlap between raster and vector: {overlap:.2%}")
2226
-
2227
- # Apply buffer if specified
2228
- if buffer_radius > 0:
2229
- gdf["geometry"] = gdf.buffer(buffer_radius)
2230
-
2231
- # Initialize class mapping (ensure all classes are mapped to non-zero values)
2232
- class_to_id = {cls: i + 1 for i, cls in enumerate(unique_classes)}
2233
-
2234
- # Store category info for COCO format
2235
- if metadata_format == "COCO":
2236
- for cls_val in unique_classes:
2237
- coco_annotations["categories"].append(
2238
- {
2239
- "id": class_to_id[cls_val],
2240
- "name": str(cls_val),
2241
- "supercategory": "object",
2242
- }
2243
- )
2244
-
2245
- # Load mask polygons if provided
2246
- mask_gdf = None
2247
- if in_mask_polygons:
2248
- mask_gdf = gpd.read_file(in_mask_polygons)
2249
- if reference_system:
2250
- mask_gdf = mask_gdf.to_crs(reference_system)
2251
- elif mask_gdf.crs != src.crs:
2252
- mask_gdf = mask_gdf.to_crs(src.crs)
2253
-
2254
- # Process instance data if provided
2255
- instance_gdf = None
2256
- if in_instance_data:
2257
- instance_gdf = gpd.read_file(in_instance_data)
2258
- if reference_system:
2259
- instance_gdf = instance_gdf.to_crs(reference_system)
2260
- elif instance_gdf.crs != src.crs:
2261
- instance_gdf = instance_gdf.to_crs(src.crs)
2262
-
2263
- # Load secondary raster if provided
2264
- src2 = None
2265
- if in_raster2:
2266
- src2 = rasterio.open(in_raster2)
2267
-
2268
- # Set up augmentation if rotation is specified
2269
- augmentation = None
2270
- if rotation_angle != 0:
2271
- # Fixed: Added data_keys parameter to AugmentationSequential
2272
- augmentation = torchgeo.transforms.AugmentationSequential(
2273
- torch.nn.ModuleList([RandomRotation(rotation_angle)]),
2274
- data_keys=["image"], # Add data_keys parameter
2275
- )
2276
-
2277
- # Initialize annotation ID for COCO format
2278
- ann_id = 0
2279
-
2280
- # Create progress bar
2281
- pbar = tqdm(
2282
- total=total_tiles,
2283
- desc=f"Generating tiles (with features: 0)",
2284
- bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
2285
- )
2286
-
2287
- # Generate tiles
2288
- chip_index = start_index
2289
- for y in range(num_tiles_y):
2290
- for x in range(num_tiles_x):
2291
- # Calculate window coordinates
2292
- window_x = x * stride_x
2293
- window_y = y * stride_y
2294
-
2295
- # Adjust for edge cases
2296
- if window_x + tile_size_x > src.width:
2297
- window_x = src.width - tile_size_x
2298
- if window_y + tile_size_y > src.height:
2299
- window_y = src.height - tile_size_y
2300
-
2301
- # Adjust window based on crop_mode
2302
- if crop_mode == "CENTERED_ON_FEATURE" and len(gdf) > 0:
2303
- # Find the nearest feature to the center of this window
2304
- window_center_x = window_x + tile_size_x // 2
2305
- window_center_y = window_y + tile_size_y // 2
2306
-
2307
- # Convert center to world coordinates
2308
- center_x, center_y = src.xy(window_center_y, window_center_x)
2309
- center_point = gpd.points_from_xy([center_x], [center_y])[0]
2310
-
2311
- # Find nearest feature
2312
- distances = gdf.geometry.distance(center_point)
2313
- nearest_idx = distances.idxmin()
2314
- nearest_feature = gdf.iloc[nearest_idx]
2315
-
2316
- # Get centroid of nearest feature
2317
- feature_centroid = nearest_feature.geometry.centroid
2318
-
2319
- # Convert feature centroid to pixel coordinates
2320
- feature_row, feature_col = src.index(
2321
- feature_centroid.x, feature_centroid.y
2322
- )
2323
-
2324
- # Adjust window to center on feature
2325
- window_x = max(
2326
- 0, min(src.width - tile_size_x, feature_col - tile_size_x // 2)
2327
- )
2328
- window_y = max(
2329
- 0, min(src.height - tile_size_y, feature_row - tile_size_y // 2)
2330
- )
2331
-
2332
- # Define window
2333
- window = Window(window_x, window_y, tile_size_x, tile_size_y)
2334
-
2335
- # Get window transform and bounds in source CRS
2336
- window_transform = src.window_transform(window)
2337
-
2338
- # Calculate window bounds more explicitly and accurately
2339
- minx = window_transform[2] # Upper left x
2340
- maxy = window_transform[5] # Upper left y
2341
- maxx = minx + tile_size_x * window_transform[0] # Add width
2342
- miny = (
2343
- maxy + tile_size_y * window_transform[4]
2344
- ) # Add height (note: transform[4] is typically negative)
2345
-
2346
- window_bounds = box(minx, miny, maxx, maxy)
2347
-
2348
- # Apply rotation if specified
2349
- if rotation_angle != 0:
2350
- window_bounds = rotate(
2351
- window_bounds, rotation_angle, origin="center"
2352
- )
2353
-
2354
- # Find features that intersect with window
2355
- window_features = gdf[gdf.intersects(window_bounds)]
2356
-
2357
- # Process instance data if provided
2358
- window_instances = None
2359
- if instance_gdf is not None and instance_class_value_field is not None:
2360
- window_instances = instance_gdf[
2361
- instance_gdf.intersects(window_bounds)
2362
- ]
2363
- if len(window_instances) > 0:
2364
- if not quiet:
2365
- pbar.write(
2366
- f"Found {len(window_instances)} instances in tile {chip_index}"
2367
- )
2368
-
2369
- # Skip if no features and output_nofeature_tiles is False
2370
- if not output_nofeature_tiles and len(window_features) == 0:
2371
- pbar.update(1) # Still update progress bar
2372
- continue
2373
-
2374
- # Check polygon overlap ratio if specified
2375
- if min_polygon_overlap_ratio > 0 and len(window_features) > 0:
2376
- valid_features = []
2377
- for _, feature in window_features.iterrows():
2378
- overlap_ratio = (
2379
- feature.geometry.intersection(window_bounds).area
2380
- / feature.geometry.area
2381
- )
2382
- if overlap_ratio >= min_polygon_overlap_ratio:
2383
- valid_features.append(feature)
2384
-
2385
- if len(valid_features) > 0:
2386
- window_features = gpd.GeoDataFrame(valid_features)
2387
- elif not output_nofeature_tiles:
2388
- pbar.update(1) # Still update progress bar
2389
- continue
2390
-
2391
- # Apply mask if provided
2392
- if mask_gdf is not None:
2393
- mask_features = mask_gdf[mask_gdf.intersects(window_bounds)]
2394
- if len(mask_features) == 0:
2395
- pbar.update(1) # Still update progress bar
2396
- continue
2397
-
2398
- # Read image data - keep original for GeoTIFF export
2399
- orig_image_data = src.read(window=window)
2400
-
2401
- # Create a copy for processing
2402
- image_data = orig_image_data.copy().astype(np.float32)
2403
-
2404
- # Normalize image data for processing
2405
- for band in range(image_data.shape[0]):
2406
- band_min, band_max = np.percentile(image_data[band], (1, 99))
2407
- if band_max > band_min:
2408
- image_data[band] = np.clip(
2409
- (image_data[band] - band_min) / (band_max - band_min), 0, 1
2410
- )
2411
-
2412
- # Read secondary image data if provided
2413
- if src2:
2414
- image_data2 = src2.read(window=window)
2415
- # Stack the two images
2416
- image_data = np.vstack((image_data, image_data2))
2417
-
2418
- # Apply blacken_around_feature if needed
2419
- if blacken_around_feature and len(window_features) > 0:
2420
- mask = np.zeros((tile_size_y, tile_size_x), dtype=bool)
2421
- for _, feature in window_features.iterrows():
2422
- # Project feature to pixel coordinates
2423
- feature_pixels = features.rasterize(
2424
- [(feature.geometry, 1)],
2425
- out_shape=(tile_size_y, tile_size_x),
2426
- transform=window_transform,
2427
- )
2428
- mask = np.logical_or(mask, feature_pixels.astype(bool))
2429
-
2430
- # Apply mask to image
2431
- for band in range(image_data.shape[0]):
2432
- temp = image_data[band, :, :]
2433
- temp[~mask] = 0
2434
- image_data[band, :, :] = temp
2435
-
2436
- # Apply rotation if specified
2437
- if augmentation:
2438
- # Convert to torch tensor for augmentation
2439
- image_tensor = torch.from_numpy(image_data).unsqueeze(
2440
- 0
2441
- ) # Add batch dimension
2442
- # Apply augmentation with proper data format
2443
- augmented = augmentation({"image": image_tensor})
2444
- image_data = (
2445
- augmented["image"].squeeze(0).numpy()
2446
- ) # Remove batch dimension
2447
-
2448
- # Create a processed version for regular image formats
2449
- processed_image = (image_data * 255).astype(np.uint8)
2450
-
2451
- # Create label mask
2452
- label_mask = np.zeros((tile_size_y, tile_size_x), dtype=np.uint8)
2453
- has_features = False
2454
-
2455
- if len(window_features) > 0:
2456
- for idx, feature in window_features.iterrows():
2457
- # Get class value
2458
- class_val = (
2459
- feature[class_value_field]
2460
- if class_value_field in feature
2461
- else 1
2462
- )
2463
- if isinstance(class_val, str):
2464
- # If class is a string, use its position in the unique classes list
2465
- class_id = class_to_id.get(class_val, 1)
2466
- else:
2467
- # If class is already a number, use it directly
2468
- class_id = int(class_val) if class_val > 0 else 1
2469
-
2470
- # Get the geometry in pixel coordinates
2471
- geom = feature.geometry.intersection(window_bounds)
2472
- if not geom.is_empty:
2473
- try:
2474
- # Rasterize the feature
2475
- feature_mask = features.rasterize(
2476
- [(geom, class_id)],
2477
- out_shape=(tile_size_y, tile_size_x),
2478
- transform=window_transform,
2479
- fill=0,
2480
- all_touched=all_touched,
2481
- )
2482
-
2483
- # Update mask with higher class values taking precedence
2484
- label_mask = np.maximum(label_mask, feature_mask)
2485
-
2486
- # Check if any pixels were added
2487
- if np.any(feature_mask):
2488
- has_features = True
2489
- except Exception as e:
2490
- if not quiet:
2491
- pbar.write(f"Error rasterizing feature {idx}: {e}")
2492
- stats["errors"] += 1
2493
-
2494
- # Save as GeoTIFF if requested
2495
- if save_geotiff or image_chip_format.upper() in [
2496
- "TIFF",
2497
- "TIF",
2498
- "GEOTIFF",
2499
- ]:
2500
- # Standardize extension to .tif for GeoTIFF files
2501
- image_filename = f"tile_{chip_index:06d}.tif"
2502
- image_path = os.path.join(image_dir, image_filename)
2503
-
2504
- # Create profile for the GeoTIFF
2505
- profile = src.profile.copy()
2506
- profile.update(
2507
- {
2508
- "height": tile_size_y,
2509
- "width": tile_size_x,
2510
- "count": orig_image_data.shape[0],
2511
- "transform": window_transform,
2512
- }
2513
- )
2514
-
2515
- # Save the GeoTIFF with original data
2516
- try:
2517
- with rasterio.open(image_path, "w", **profile) as dst:
2518
- dst.write(orig_image_data)
2519
- stats["total_tiles"] += 1
2520
- except Exception as e:
2521
- if not quiet:
2522
- pbar.write(
2523
- f"ERROR saving image GeoTIFF for tile {chip_index}: {e}"
2524
- )
2525
- stats["errors"] += 1
2526
- else:
2527
- # For non-GeoTIFF formats, use PIL to save the image
2528
- image_filename = (
2529
- f"tile_{chip_index:06d}.{image_chip_format.lower()}"
2530
- )
2531
- image_path = os.path.join(image_dir, image_filename)
2532
-
2533
- # Create PIL image for saving
2534
- if processed_image.shape[0] == 1:
2535
- img = Image.fromarray(processed_image[0])
2536
- elif processed_image.shape[0] == 3:
2537
- # For RGB, need to transpose and make sure it's the right data type
2538
- rgb_data = np.transpose(processed_image, (1, 2, 0))
2539
- img = Image.fromarray(rgb_data)
2540
- else:
2541
- # For multiband images, save only RGB or first three bands
2542
- rgb_data = np.transpose(processed_image[:3], (1, 2, 0))
2543
- img = Image.fromarray(rgb_data)
2544
-
2545
- # Save image
2546
- try:
2547
- img.save(image_path)
2548
- stats["total_tiles"] += 1
2549
- except Exception as e:
2550
- if not quiet:
2551
- pbar.write(f"ERROR saving image for tile {chip_index}: {e}")
2552
- stats["errors"] += 1
2553
-
2554
- # Save label as GeoTIFF
2555
- label_filename = f"tile_{chip_index:06d}.tif"
2556
- label_path = os.path.join(label_dir, label_filename)
2557
-
2558
- # Create profile for label GeoTIFF
2559
- label_profile = {
2560
- "driver": "GTiff",
2561
- "height": tile_size_y,
2562
- "width": tile_size_x,
2563
- "count": 1,
2564
- "dtype": "uint8",
2565
- "crs": src.crs,
2566
- "transform": window_transform,
2567
- }
2568
-
2569
- # Save label GeoTIFF
2570
- try:
2571
- with rasterio.open(label_path, "w", **label_profile) as dst:
2572
- dst.write(label_mask, 1)
2573
-
2574
- if has_features:
2575
- pixel_count = np.count_nonzero(label_mask)
2576
- stats["tiles_with_features"] += 1
2577
- stats["feature_pixels"] += pixel_count
2578
- except Exception as e:
2579
- if not quiet:
2580
- pbar.write(f"ERROR saving label for tile {chip_index}: {e}")
2581
- stats["errors"] += 1
2582
-
2583
- # Also save a PNG version for easy visualization if requested
2584
- if metadata_format == "PASCAL_VOC":
2585
- try:
2586
- # Ensure correct data type for PIL
2587
- png_label = label_mask.astype(np.uint8)
2588
- label_img = Image.fromarray(png_label)
2589
- label_png_path = os.path.join(
2590
- label_dir, f"tile_{chip_index:06d}.png"
2591
- )
2592
- label_img.save(label_png_path)
2593
- except Exception as e:
2594
- if not quiet:
2595
- pbar.write(
2596
- f"ERROR saving PNG label for tile {chip_index}: {e}"
2597
- )
2598
- pbar.write(
2599
- f" Label mask shape: {label_mask.shape}, dtype: {label_mask.dtype}"
2600
- )
2601
- # Try again with explicit conversion
2602
- try:
2603
- # Alternative approach for problematic arrays
2604
- png_data = np.zeros(
2605
- (tile_size_y, tile_size_x), dtype=np.uint8
2606
- )
2607
- np.copyto(png_data, label_mask, casting="unsafe")
2608
- label_img = Image.fromarray(png_data)
2609
- label_img.save(label_png_path)
2610
- pbar.write(
2611
- f" Succeeded using alternative conversion method"
2612
- )
2613
- except Exception as e2:
2614
- pbar.write(f" Second attempt also failed: {e2}")
2615
- stats["errors"] += 1
2616
-
2617
- # Generate annotations
2618
- if metadata_format == "PASCAL_VOC" and len(window_features) > 0:
2619
- # Create XML annotation
2620
- root = ET.Element("annotation")
2621
- ET.SubElement(root, "folder").text = "images"
2622
- ET.SubElement(root, "filename").text = image_filename
2623
-
2624
- size = ET.SubElement(root, "size")
2625
- ET.SubElement(size, "width").text = str(tile_size_x)
2626
- ET.SubElement(size, "height").text = str(tile_size_y)
2627
- ET.SubElement(size, "depth").text = str(min(image_data.shape[0], 3))
2628
-
2629
- # Add georeference information
2630
- geo = ET.SubElement(root, "georeference")
2631
- ET.SubElement(geo, "crs").text = str(src.crs)
2632
- ET.SubElement(geo, "transform").text = str(
2633
- window_transform
2634
- ).replace("\n", "")
2635
- ET.SubElement(geo, "bounds").text = (
2636
- f"{minx}, {miny}, {maxx}, {maxy}"
2637
- )
2638
-
2639
- for _, feature in window_features.iterrows():
2640
- # Convert feature geometry to pixel coordinates
2641
- feature_bounds = feature.geometry.intersection(window_bounds)
2642
- if feature_bounds.is_empty:
2643
- continue
2644
-
2645
- # Get pixel coordinates of bounds
2646
- minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
2647
-
2648
- # Convert to pixel coordinates
2649
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
2650
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
2651
-
2652
- # Ensure coordinates are within bounds
2653
- xmin = max(0, min(tile_size_x, int(col_min)))
2654
- ymin = max(0, min(tile_size_y, int(row_min)))
2655
- xmax = max(0, min(tile_size_x, int(col_max)))
2656
- ymax = max(0, min(tile_size_y, int(row_max)))
2657
-
2658
- # Skip if box is too small
2659
- if xmax - xmin < 1 or ymax - ymin < 1:
2660
- continue
2661
-
2662
- obj = ET.SubElement(root, "object")
2663
- ET.SubElement(obj, "name").text = str(
2664
- feature[class_value_field]
2665
- )
2666
- ET.SubElement(obj, "difficult").text = "0"
2667
-
2668
- bbox = ET.SubElement(obj, "bndbox")
2669
- ET.SubElement(bbox, "xmin").text = str(xmin)
2670
- ET.SubElement(bbox, "ymin").text = str(ymin)
2671
- ET.SubElement(bbox, "xmax").text = str(xmax)
2672
- ET.SubElement(bbox, "ymax").text = str(ymax)
2673
-
2674
- # Save XML
2675
- try:
2676
- tree = ET.ElementTree(root)
2677
- xml_path = os.path.join(ann_dir, f"tile_{chip_index:06d}.xml")
2678
- tree.write(xml_path)
2679
- except Exception as e:
2680
- if not quiet:
2681
- pbar.write(
2682
- f"ERROR saving XML annotation for tile {chip_index}: {e}"
2683
- )
2684
- stats["errors"] += 1
2685
-
2686
- elif metadata_format == "COCO" and len(window_features) > 0:
2687
- # Add image info
2688
- image_id = chip_index
2689
- coco_annotations["images"].append(
2690
- {
2691
- "id": image_id,
2692
- "file_name": image_filename,
2693
- "width": tile_size_x,
2694
- "height": tile_size_y,
2695
- "crs": str(src.crs),
2696
- "transform": str(window_transform),
2697
- }
2698
- )
2699
-
2700
- # Add annotations for each feature
2701
- for _, feature in window_features.iterrows():
2702
- feature_bounds = feature.geometry.intersection(window_bounds)
2703
- if feature_bounds.is_empty:
2704
- continue
2705
-
2706
- # Get pixel coordinates of bounds
2707
- minx_f, miny_f, maxx_f, maxy_f = feature_bounds.bounds
2708
-
2709
- # Convert to pixel coordinates
2710
- col_min, row_min = ~window_transform * (minx_f, maxy_f)
2711
- col_max, row_max = ~window_transform * (maxx_f, miny_f)
2712
-
2713
- # Ensure coordinates are within bounds
2714
- xmin = max(0, min(tile_size_x, int(col_min)))
2715
- ymin = max(0, min(tile_size_y, int(row_min)))
2716
- xmax = max(0, min(tile_size_x, int(col_max)))
2717
- ymax = max(0, min(tile_size_y, int(row_max)))
2718
-
2719
- # Skip if box is too small
2720
- if xmax - xmin < 1 or ymax - ymin < 1:
2721
- continue
2722
-
2723
- width = xmax - xmin
2724
- height = ymax - ymin
2725
-
2726
- # Add annotation
2727
- ann_id += 1
2728
- category_id = class_to_id[feature[class_value_field]]
2729
-
2730
- coco_annotations["annotations"].append(
2731
- {
2732
- "id": ann_id,
2733
- "image_id": image_id,
2734
- "category_id": category_id,
2735
- "bbox": [xmin, ymin, width, height],
2736
- "area": width * height,
2737
- "iscrowd": 0,
2738
- }
2739
- )
2740
-
2741
- # Update progress bar
2742
- pbar.update(1)
2743
- pbar.set_description(
2744
- f"Generated: {stats['total_tiles']}, With features: {stats['tiles_with_features']}"
2745
- )
2746
-
2747
- chip_index += 1
2748
-
2749
- # Close progress bar
2750
- pbar.close()
2751
-
2752
- # Save COCO annotations if applicable
2753
- if metadata_format == "COCO":
2754
- try:
2755
- with open(os.path.join(ann_dir, "instances.json"), "w") as f:
2756
- json.dump(coco_annotations, f)
2757
- except Exception as e:
2758
- if not quiet:
2759
- print(f"ERROR saving COCO annotations: {e}")
2760
- stats["errors"] += 1
2761
-
2762
- # Close secondary raster if opened
2763
- if src2:
2764
- src2.close()
2765
-
2766
- # Print summary
2767
- if not quiet:
2768
- print("\n------- Export Summary -------")
2769
- print(f"Total tiles exported: {stats['total_tiles']}")
2770
- print(
2771
- f"Tiles with features: {stats['tiles_with_features']} ({stats['tiles_with_features']/max(1, stats['total_tiles'])*100:.1f}%)"
2772
- )
2773
- if stats["tiles_with_features"] > 0:
2774
- print(
2775
- f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}"
2776
- )
2777
- if stats["errors"] > 0:
2778
- print(f"Errors encountered: {stats['errors']}")
2779
- print(f"Output saved to: {out_folder}")
2780
-
2781
- # Verify georeference in a sample image and label
2782
- if stats["total_tiles"] > 0:
2783
- print("\n------- Georeference Verification -------")
2784
- sample_image = os.path.join(image_dir, f"tile_{start_index}.tif")
2785
- sample_label = os.path.join(label_dir, f"tile_{start_index}.tif")
2786
-
2787
- if os.path.exists(sample_image):
2788
- try:
2789
- with rasterio.open(sample_image) as img:
2790
- print(f"Image CRS: {img.crs}")
2791
- print(f"Image transform: {img.transform}")
2792
- print(
2793
- f"Image has georeference: {img.crs is not None and img.transform is not None}"
2794
- )
2795
- print(
2796
- f"Image dimensions: {img.width}x{img.height}, {img.count} bands, {img.dtypes[0]} type"
2797
- )
2798
- except Exception as e:
2799
- print(f"Error verifying image georeference: {e}")
2800
-
2801
- if os.path.exists(sample_label):
2802
- try:
2803
- with rasterio.open(sample_label) as lbl:
2804
- print(f"Label CRS: {lbl.crs}")
2805
- print(f"Label transform: {lbl.transform}")
2806
- print(
2807
- f"Label has georeference: {lbl.crs is not None and lbl.transform is not None}"
2808
- )
2809
- print(
2810
- f"Label dimensions: {lbl.width}x{lbl.height}, {lbl.count} bands, {lbl.dtypes[0]} type"
2811
- )
2812
- except Exception as e:
2813
- print(f"Error verifying label georeference: {e}")
2814
-
2815
- # Return statistics
2816
- return stats, out_folder
2817
-
2818
-
2819
def masks_to_vector(
    mask_path,
    output_path=None,
    simplify_tolerance=1.0,
    mask_threshold=0.5,
    min_object_area=100,
    max_object_area=None,
    nms_iou_threshold=0.5,
):
    """
    Convert a building-mask GeoTIFF to vector polygons.

    The mask is binarized, cleaned with a morphological closing, split into
    connected components, and each component's external contour is traced,
    optionally simplified, and mapped to geographic coordinates using the
    raster's affine transform. Overlapping polygons are then pruned with a
    greedy IoU-based non-maximum suppression.

    Args:
        mask_path: Path to the building-mask GeoTIFF. Pixel values are
            assumed to be scaled 0-255 (the threshold is applied as
            ``mask_threshold * 255``).
        output_path: Path to save the output vector dataset (e.g. GeoJSON).
            If None, nothing is written to disk.
        simplify_tolerance: Contour simplification factor; the Douglas-Peucker
            epsilon is this fraction of the contour arc length. Applied only
            to contours with more than 50 points. Default 1.0.
        mask_threshold: Binarization threshold in the 0-1 range. Default 0.5.
        min_object_area: Minimum component area in pixels to keep. Default 100.
        max_object_area: Maximum component area in pixels to keep; None
            disables the upper bound.
        nms_iou_threshold: IoU above which a lower-confidence polygon is
            suppressed. Default 0.5.

    Returns:
        GeoDataFrame with polygon geometries, a size-based ``confidence``
        proxy, and a constant ``class`` of 1, or None if no valid polygons
        were found.
    """
    print("Converting mask to GeoJSON with parameters:")
    print(f"- Mask threshold: {mask_threshold}")
    print(f"- Min building area: {min_object_area}")
    print(f"- Simplify tolerance: {simplify_tolerance}")
    print(f"- NMS IoU threshold: {nms_iou_threshold}")

    # Read the mask band together with its georeferencing info.
    with rasterio.open(mask_path) as src:
        mask_data = src.read(1)
        transform = src.transform
        crs = src.crs

    print(f"Mask dimensions: {mask_data.shape}")
    print(f"Mask value range: {mask_data.min()} to {mask_data.max()}")

    # Binarize assuming 0-255 scaling.
    # NOTE(review): a 0/1-valued mask would need `mask_data > mask_threshold`
    # instead -- confirm the expected input scaling.
    binary_mask = (mask_data > (mask_threshold * 255)).astype(np.uint8)

    # Morphological closing fills small holes/gaps before labeling.
    kernel = np.ones((3, 3), np.uint8)
    binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # Label connected components (label 0 is the background).
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        binary_mask, connectivity=8
    )

    print(f"Found {num_labels-1} potential buildings")  # Subtract 1 for background

    all_polygons = []
    all_confidences = []

    # Vectorize each component that passes the area filters.
    for i in tqdm(range(1, num_labels)):
        area = stats[i, cv2.CC_STAT_AREA]

        # Apply lower and upper area bounds.
        if area < min_object_area:
            continue
        if max_object_area is not None and area > max_object_area:
            continue

        # Isolate this component and trace its outer boundary.
        building_mask = (labels == i).astype(np.uint8)
        contours, _ = cv2.findContours(
            building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for contour in contours:
            # A polygon needs at least 3 vertices.
            if contour.shape[0] < 3:
                continue

            # Simplify only long contours; epsilon scales with arc length.
            if contour.shape[0] > 50 and simplify_tolerance > 0:
                epsilon = simplify_tolerance * cv2.arcLength(contour, True)
                contour = cv2.approxPolyDP(contour, epsilon, True)

            polygon_points = contour.reshape(-1, 2)

            # Map pixel coordinates to CRS coordinates via the affine
            # transform. NOTE(review): this uses the pixel's upper-left
            # corner; add 0.5 to each coordinate if pixel centers are the
            # desired convention -- confirm against downstream use.
            geo_points = [transform * (x, y) for x, y in polygon_points]

            if len(geo_points) >= 3:
                try:
                    shapely_poly = Polygon(geo_points)
                    if shapely_poly.is_valid and shapely_poly.area > 0:
                        all_polygons.append(shapely_poly)
                        # Size-based confidence proxy, capped at 1.0 (no
                        # model confidence scores are available here).
                        all_confidences.append(min(1.0, area / 1000))
                except Exception as e:
                    print(f"Error creating polygon: {e}")

    print(f"Created {len(all_polygons)} valid polygons")

    if not all_polygons:
        print("No valid polygons found")
        return None

    gdf = gpd.GeoDataFrame(
        {
            "geometry": all_polygons,
            "confidence": all_confidences,
            "class": 1,  # Building class
        },
        crs=crs,
    )

    def filter_overlapping_polygons(gdf, **kwargs):
        """
        Greedy IoU-based non-maximum suppression over polygons.

        Polygons are visited in descending confidence order; a polygon is
        dropped if its IoU with any already-kept polygon exceeds the
        threshold.

        Args:
            gdf: GeoDataFrame with 'geometry' and 'confidence' columns.
            **kwargs: Optional parameters:
                nms_iou_threshold: IoU threshold for filtering.

        Returns:
            Filtered GeoDataFrame.
        """
        if len(gdf) <= 1:
            return gdf

        # Get parameter from kwargs or fall back to the enclosing default.
        iou_threshold = kwargs.get("nms_iou_threshold", nms_iou_threshold)

        # Sort by confidence so the most confident polygon wins ties.
        gdf = gdf.sort_values("confidence", ascending=False)

        # buffer(0) is the standard trick to repair invalid geometries.
        gdf["geometry"] = gdf["geometry"].apply(
            lambda geom: geom.buffer(0) if not geom.is_valid else geom
        )

        keep_indices = []
        polygons = gdf.geometry.values

        for i in range(len(polygons)):
            # (The original `if i in keep_indices` guard was dead code:
            # indices are only appended for earlier iterations, so the
            # current i can never already be present.)
            keep = True
            for j in keep_indices:
                # Skip comparisons involving invalid geometries.
                if not polygons[i].is_valid or not polygons[j].is_valid:
                    continue
                try:
                    intersection = polygons[i].intersection(polygons[j]).area
                    union = polygons[i].area + polygons[j].area - intersection
                    iou = intersection / union if union > 0 else 0
                    if iou > iou_threshold:
                        keep = False
                        break
                except Exception:
                    # Skip pairs that raise topology exceptions.
                    continue

            if keep:
                keep_indices.append(i)

        return gdf.iloc[keep_indices]

    # Apply non-maximum suppression to remove overlapping polygons.
    gdf = filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)

    print(f"Final building count after filtering: {len(gdf)}")

    # Save to file if requested.
    if output_path is not None:
        gdf.to_file(output_path)
        print(f"Saved {len(gdf)} building footprints to {output_path}")

    return gdf