geoai-py 0.1.6__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,32 @@
 
 __author__ = """Qiusheng Wu"""
 __email__ = "giswqs@gmail.com"
-__version__ = "0.1.6"
+__version__ = "0.2.0"
 
 
+import os
+import sys
+
+
+def set_proj_lib_path():
+    """Set the PROJ_LIB environment variable based on the current conda environment."""
+    try:
+        # Get conda environment path
+        conda_env_path = os.environ.get("CONDA_PREFIX") or sys.prefix
+
+        # Set PROJ_LIB environment variable
+        proj_path = os.path.join(conda_env_path, "share", "proj")
+        gdal_path = os.path.join(conda_env_path, "share", "gdal")
+
+        # Check if the directory exists before setting
+        if os.path.exists(proj_path):
+            os.environ["PROJ_LIB"] = proj_path
+        if os.path.exists(gdal_path):
+            os.environ["GDAL_DATA"] = gdal_path
+    except Exception as e:
+        print(e)
+        return
+
+
+set_proj_lib_path()
 from .geoai import *
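Since `set_proj_lib_path()` now runs at import time, a quick sanity check is to confirm that the environment variables point into the active conda environment. A minimal sketch (assuming geoai is installed in a conda environment that ships the `share/proj` and `share/gdal` directories):

```python
import os

import geoai  # importing the package calls set_proj_lib_path()

# Each variable is only set if the corresponding directory exists
print(os.environ.get("PROJ_LIB"))   # e.g. $CONDA_PREFIX/share/proj
print(os.environ.get("GDAL_DATA"))  # e.g. $CONDA_PREFIX/share/gdal
```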
geoai/common.py CHANGED
@@ -5,8 +5,12 @@ from collections.abc import Iterable
 from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
 import matplotlib.pyplot as plt
 
+import leafmap
 import torch
 import numpy as np
+import xarray as xr
+import rioxarray
+import rasterio as rio
 from torch.utils.data import DataLoader
 from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
 from torchgeo.samplers import RandomGeoSampler, Units
@@ -55,10 +59,12 @@ def viz_raster(
     Returns:
         leafmap.Map: The map object with the raster layer added.
     """
-    import leafmap
 
     m = leafmap.Map(basemap=basemap)
 
+    if isinstance(source, dict):
+        source = dict_to_image(source)
+
     m.add_raster(
         source=source,
         indexes=indexes,
@@ -86,6 +92,7 @@ def viz_image(
     scale_factor: float = 1.0,
     figsize: Tuple[int, int] = (10, 5),
     axis_off: bool = True,
+    title: Optional[str] = None,
     **kwargs: Any,
 ) -> None:
     """
@@ -98,6 +105,7 @@
         scale_factor (float, optional): The scale factor to apply to the image. Defaults to 1.0.
         figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
         axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
+        title (Optional[str], optional): The title of the plot. Defaults to None.
         **kwargs (Any): Additional keyword arguments for plt.imshow().
 
     Returns:
@@ -124,6 +132,8 @@
     plt.imshow(image, **kwargs)
     if axis_off:
         plt.axis("off")
+    if title is not None:
+        plt.title(title)
     plt.show()
     plt.close()
 
@@ -277,3 +287,150 @@ def calc_stats(
     # at the end, we shall have 2 vectors with length n=chnls
     # we will average them considering the number of images
     return accum_mean / len(files), accum_std / len(files)
+
+
+def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
+    """Convert a dictionary to a xarray DataArray. The dictionary should contain the
+    following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
+    dataset sampler.
+
+    Args:
+        data_dict (Dict): The dictionary containing the data.
+
+    Returns:
+        xr.DataArray: The xarray DataArray.
+    """
+
+    from affine import Affine
+
+    # Extract components from the dictionary
+    crs = data_dict["crs"]
+    bounds = data_dict["bounds"]
+    image_tensor = data_dict["image"]
+
+    # Convert tensor to numpy array if needed
+    if hasattr(image_tensor, "numpy"):
+        # For PyTorch tensors
+        image_array = image_tensor.numpy()
+    else:
+        # If it's already a numpy array or similar
+        image_array = np.array(image_tensor)
+
+    # Calculate pixel resolution
+    width = image_array.shape[2]  # Width is the size of the last dimension
+    height = image_array.shape[1]  # Height is the size of the middle dimension
+
+    res_x = (bounds.maxx - bounds.minx) / width
+    res_y = (bounds.maxy - bounds.miny) / height
+
+    # Create the transform matrix
+    transform = Affine(res_x, 0.0, bounds.minx, 0.0, -res_y, bounds.maxy)
+
+    # Create dimensions
+    x_coords = np.linspace(bounds.minx + res_x / 2, bounds.maxx - res_x / 2, width)
+    y_coords = np.linspace(bounds.maxy - res_y / 2, bounds.miny + res_y / 2, height)
+
+    # If time dimension exists in the bounds
+    if hasattr(bounds, "mint") and hasattr(bounds, "maxt"):
+        # Create a single time value or range if needed
+        t_coords = [
+            bounds.mint
+        ]  # Or np.linspace(bounds.mint, bounds.maxt, num_time_steps)
+
+        # Create DataArray with time dimension
+        dims = (
+            ("band", "y", "x")
+            if image_array.shape[0] <= 10
+            else ("time", "band", "y", "x")
+        )
+
+        if dims[0] == "band":
+            # For multi-band single time
+            da = xr.DataArray(
+                image_array,
+                dims=dims,
+                coords={
+                    "band": np.arange(1, image_array.shape[0] + 1),
+                    "y": y_coords,
+                    "x": x_coords,
+                },
+            )
+        else:
+            # For multi-time multi-band
+            da = xr.DataArray(
+                image_array,
+                dims=dims,
+                coords={
+                    "time": t_coords,
+                    "band": np.arange(1, image_array.shape[1] + 1),
+                    "y": y_coords,
+                    "x": x_coords,
+                },
+            )
+    else:
+        # Create DataArray without time dimension
+        da = xr.DataArray(
+            image_array,
+            dims=("band", "y", "x"),
+            coords={
+                "band": np.arange(1, image_array.shape[0] + 1),
+                "y": y_coords,
+                "x": x_coords,
+            },
+        )
+
+    # Set spatial attributes
+    da.rio.write_crs(crs, inplace=True)
+    da.rio.write_transform(transform, inplace=True)
+
+    return da
+
+
+def dict_to_image(
+    data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
+) -> rio.DatasetReader:
+    """Convert a dictionary containing spatial data to a rasterio dataset or save it to
+    a file. The dictionary should contain the following keys: "crs", "bounds", and "image".
+    It can be generated from a TorchGeo dataset sampler.
+
+    This function transforms a dictionary with CRS, bounding box, and image data
+    into a rasterio DatasetReader using leafmap's array_to_image utility after
+    first converting to a rioxarray DataArray.
+
+    Args:
+        data_dict: A dictionary containing:
+            - 'crs': A pyproj CRS object
+            - 'bounds': A BoundingBox object with minx, maxx, miny, maxy attributes
+              and optionally mint, maxt for temporal bounds
+            - 'image': A tensor or array-like object with image data
+        output: Optional path to save the image to a file. If not provided, the image
+            will be returned as a rasterio DatasetReader object.
+        **kwargs: Additional keyword arguments to pass to leafmap.array_to_image.
+            Common options include:
+            - colormap: str, name of the colormap (e.g., 'viridis', 'terrain')
+            - vmin: float, minimum value for colormap scaling
+            - vmax: float, maximum value for colormap scaling
+
+    Returns:
+        A rasterio DatasetReader object that can be used for visualization or
+        further processing.
+
+    Examples:
+        >>> image = dict_to_image(
+        ...     {'crs': CRS.from_epsg(26911), 'bounds': bbox, 'image': tensor},
+        ...     colormap='terrain'
+        ... )
+        >>> fig, ax = plt.subplots(figsize=(10, 10))
+        >>> show(image, ax=ax)
+    """
+    da = dict_to_rioxarray(data_dict)
+
+    if output is not None:
+        out_dir = os.path.abspath(os.path.dirname(output))
+        if not os.path.exists(out_dir):
+            os.makedirs(out_dir, exist_ok=True)
+        da.rio.to_raster(output)
+        return output
+    else:
+        image = leafmap.array_to_image(da, **kwargs)
+        return image
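With `dict_to_rioxarray` and `dict_to_image` in place, `viz_raster` can now take a TorchGeo-style sample dict directly. A minimal sketch with a hand-built sample (the CRS, bounds, and random tensor are illustrative stand-ins for what a TorchGeo sampler would return):

```python
import torch
from pyproj import CRS
from torchgeo.datasets import BoundingBox

from geoai.common import dict_to_image, viz_raster

sample = {
    "crs": CRS.from_epsg(26911),
    "bounds": BoundingBox(minx=0, maxx=2560, miny=0, maxy=2560, mint=0, maxt=0),
    "image": torch.rand(3, 256, 256),  # (band, y, x)
}

image = dict_to_image(sample)                      # in-memory rasterio dataset
path = dict_to_image(sample, output="sample.tif")  # or write a GeoTIFF instead
m = viz_raster(sample)                             # dicts are converted automatically
```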
geoai/download.py ADDED
@@ -0,0 +1,395 @@
+"""This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
+
+import os
+from typing import List, Tuple, Optional, Dict, Any
+import rioxarray
+import numpy as np
+import matplotlib.pyplot as plt
+from pystac_client import Client
+import planetary_computer as pc
+import geopandas as gpd
+from shapely.geometry import box
+from tqdm import tqdm
+import requests
+import subprocess
+import logging
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def download_naip(
+    bbox: Tuple[float, float, float, float],
+    output_dir: str,
+    year: Optional[int] = None,
+    max_items: int = 10,
+    overwrite: bool = False,
+    preview: bool = False,
+    **kwargs: Any,
+) -> List[str]:
+    """Download NAIP imagery from Planetary Computer based on a bounding box.
+
+    This function searches for NAIP (National Agriculture Imagery Program) imagery
+    from Microsoft's Planetary Computer that intersects with the specified bounding box.
+    It downloads the imagery and saves it as GeoTIFF files.
+
+    Args:
+        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat) in WGS84 coordinates.
+        output_dir: Directory to save the downloaded imagery.
+        year: Specific year of NAIP imagery to download (e.g., 2020). If None, returns imagery from all available years.
+        max_items: Maximum number of items to download.
+        overwrite: If True, overwrite existing files with the same name.
+        preview: If True, display a preview of the downloaded imagery.
+
+    Returns:
+        List of downloaded file paths.
+
+    Raises:
+        Exception: If there is an error downloading or saving the imagery.
+    """
+    # Create output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Create a geometry from the bounding box
+    geometry = box(*bbox)
+
+    # Connect to Planetary Computer STAC API
+    catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
+
+    # Build query for NAIP data
+    search_params = {
+        "collections": ["naip"],
+        "intersects": geometry,
+        "limit": max_items,
+    }
+
+    # Add year filter if specified
+    if year:
+        search_params["query"] = {"naip:year": {"eq": year}}
+
+    for key, value in kwargs.items():
+        search_params[key] = value
+
+    # Search for NAIP imagery
+    search_results = catalog.search(**search_params)
+    items = list(search_results.items())
+
+    if len(items) > max_items:
+        items = items[:max_items]
+
+    if not items:
+        print("No NAIP imagery found for the specified region and parameters.")
+        return []
+
+    print(f"Found {len(items)} NAIP items.")
+
+    # Download and save each item
+    downloaded_files = []
+    for i, item in enumerate(items):
+        # Sign the assets (required for Planetary Computer)
+        signed_item = pc.sign(item)
+
+        # Get the RGB asset URL
+        rgb_asset = signed_item.assets.get("image")
+        if not rgb_asset:
+            print(f"No RGB asset found for item {i+1}")
+            continue
+
+        # Use the original filename from the asset
+        original_filename = os.path.basename(
+            rgb_asset.href.split("?")[0]
+        )  # Remove query parameters
+        output_path = os.path.join(output_dir, original_filename)
+        if not overwrite and os.path.exists(output_path):
+            print(f"Skipping existing file: {output_path}")
+            downloaded_files.append(output_path)
+            continue
+
+        print(f"Downloading item {i+1}/{len(items)}: {original_filename}")
+
+        try:
+            # Save the data, with a progress bar
+            # for direct HTTP file downloads
+            if rgb_asset.href.startswith("http"):
+                download_with_progress(rgb_asset.href, output_path)
+            #
+            else:
+                # Fallback to direct rioxarray opening (less common case)
+                data = rioxarray.open_rasterio(rgb_asset.href)
+                data.rio.to_raster(output_path)
+
+            downloaded_files.append(output_path)
+            print(f"Successfully saved to {output_path}")
+
+            # Optional: Display a preview if requested
+            if preview:
+                data = rioxarray.open_rasterio(output_path)
+                preview_raster(data)
+
+        except Exception as e:
+            print(f"Error downloading item {i+1}: {str(e)}")
+
+    return downloaded_files
+
+
+def download_with_progress(url: str, output_path: str) -> None:
+    """Download a file with a progress bar.
+
+    Args:
+        url: URL of the file to download.
+        output_path: Path where the file will be saved.
+    """
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1024  # 1 Kibibyte
+
+    with (
+        open(output_path, "wb") as file,
+        tqdm(
+            desc=os.path.basename(output_path),
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as bar,
+    ):
+        for data in response.iter_content(block_size):
+            size = file.write(data)
+            bar.update(size)
+
+
+def preview_raster(data: Any, title: Optional[str] = None) -> None:
+    """Display a preview of the downloaded imagery.
+
+    This function creates a visualization of the downloaded NAIP imagery
+    by converting it to an RGB array and displaying it with matplotlib.
+
+    Args:
+        data: The raster data as a rioxarray object.
+        title: The title for the preview plot.
+    """
+    # Convert to 8-bit RGB for display
+    rgb_data = data.transpose("y", "x", "band").values[:, :, 0:3]
+    rgb_data = np.where(rgb_data > 255, 255, rgb_data).astype(np.uint8)
+
+    plt.figure(figsize=(10, 10))
+    plt.imshow(rgb_data)
+    if title is not None:
+        plt.title(title)
+    plt.axis("off")
+    plt.show()
+
+
+# Helper function to convert NumPy types to native Python types for JSON serialization
+def json_serializable(obj: Any) -> Any:
+    """Convert NumPy types to native Python types for JSON serialization.
+
+    Args:
+        obj: Any object to convert.
+
+    Returns:
+        JSON serializable version of the object.
+    """
+    if isinstance(obj, np.integer):
+        return int(obj)
+    elif isinstance(obj, np.floating):
+        return float(obj)
+    elif isinstance(obj, np.ndarray):
+        return obj.tolist()
+    else:
+        return obj
+
+
+def download_overture_buildings(
+    bbox: Tuple[float, float, float, float],
+    output_file: str,
+    output_format: str = "geojson",
+    data_type: str = "building",
+    verbose: bool = True,
+) -> str:
+    """Download building data from Overture Maps for a given bounding box using the overturemaps CLI tool.
+
+    Args:
+        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat) in WGS84 coordinates.
+        output_file: Path to save the output file.
+        output_format: Format to save the output, one of "geojson", "geojsonseq", or "geoparquet".
+        data_type: The Overture Maps data type to download (building, place, etc.).
+        verbose: Whether to print verbose output.
+
+    Returns:
+        Path to the output file.
+    """
+    # Create output directory if needed
+    output_dir = os.path.dirname(output_file)
+    if output_dir and not os.path.exists(output_dir):
+        os.makedirs(output_dir, exist_ok=True)
+
+    # Format the bounding box string for the command
+    west, south, east, north = bbox
+    bbox_str = f"{west},{south},{east},{north}"
+
+    # Build the command
+    cmd = [
+        "overturemaps",
+        "download",
+        "--bbox",
+        bbox_str,
+        "-f",
+        output_format,
+        "--type",
+        data_type,
+        "--output",
+        output_file,
+    ]
+
+    if verbose:
+        logger.info(f"Running command: {' '.join(cmd)}")
+        logger.info("Downloading %s data for area: %s", data_type, bbox_str)
+
+    try:
+        # Run the command
+        result = subprocess.run(
+            cmd,
+            check=True,
+            stdout=subprocess.PIPE if not verbose else None,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        # Check if the file was created
+        if os.path.exists(output_file):
+            file_size = os.path.getsize(output_file) / (1024 * 1024)  # Size in MB
+            logger.info(
+                f"Successfully downloaded data to {output_file} ({file_size:.2f} MB)"
+            )
+
+            # Optionally show some stats about the downloaded data
+            if output_format == "geojson" and os.path.getsize(output_file) > 0:
+                try:
+                    gdf = gpd.read_file(output_file)
+                    logger.info(f"Downloaded {len(gdf)} features")
+
+                    if len(gdf) > 0 and verbose:
+                        # Show a sample of the attribute names
+                        attrs = list(gdf.columns)
+                        attrs.remove("geometry")
+                        logger.info(f"Available attributes: {', '.join(attrs[:10])}...")
+                except Exception as e:
+                    logger.warning(f"Could not read the GeoJSON file: {str(e)}")
+
+            return output_file
+        else:
+            logger.error(f"Command completed but file {output_file} was not created")
+            if result.stderr:
+                logger.error(f"Command error output: {result.stderr}")
+            return None
+
+    except subprocess.CalledProcessError as e:
+        logger.error(f"Error running overturemaps command: {str(e)}")
+        if e.stderr:
+            logger.error(f"Command error output: {e.stderr}")
+        raise RuntimeError(f"Failed to download Overture Maps data: {str(e)}")
+    except Exception as e:
+        logger.error(f"Unexpected error: {str(e)}")
+        raise
+
+
+def convert_vector_format(
+    input_file: str,
+    output_format: str = "geojson",
+    filter_expression: Optional[str] = None,
+) -> str:
+    """Convert the downloaded data to a different format or filter it.
+
+    Args:
+        input_file: Path to the input file.
+        output_format: Format to convert to, one of "geojson", "parquet", "shapefile", "csv".
+        filter_expression: Optional GeoDataFrame query expression to filter the data.
+
+    Returns:
+        Path to the converted file.
+    """
+    try:
+        # Read the input file
+        logger.info(f"Reading {input_file}")
+        gdf = gpd.read_file(input_file)
+
+        # Apply filter if specified
+        if filter_expression:
+            logger.info(f"Filtering data using expression: {filter_expression}")
+            gdf = gdf.query(filter_expression)
+            logger.info(f"After filtering: {len(gdf)} features")
+
+        # Define output file path
+        base_path = os.path.splitext(input_file)[0]
+
+        if output_format == "geojson":
+            output_file = f"{base_path}.geojson"
+            logger.info(f"Converting to GeoJSON: {output_file}")
+            gdf.to_file(output_file, driver="GeoJSON")
+        elif output_format == "parquet":
+            output_file = f"{base_path}.parquet"
+            logger.info(f"Converting to Parquet: {output_file}")
+            gdf.to_parquet(output_file)
+        elif output_format == "shapefile":
+            output_file = f"{base_path}.shp"
+            logger.info(f"Converting to Shapefile: {output_file}")
+            gdf.to_file(output_file)
+        elif output_format == "csv":
+            output_file = f"{base_path}.csv"
+            logger.info(f"Converting to CSV: {output_file}")
+
+            # For CSV, we need to convert geometry to WKT
+            gdf["geometry_wkt"] = gdf.geometry.apply(lambda g: g.wkt)
+
+            # Save to CSV with geometry as WKT
+            gdf.drop(columns=["geometry"]).to_csv(output_file, index=False)
+        else:
+            raise ValueError(f"Unsupported output format: {output_format}")
+
+        return output_file
+
+    except Exception as e:
+        logger.error(f"Error converting data: {str(e)}")
+        raise
+
+
+def extract_building_stats(geojson_file: str) -> Dict[str, Any]:
+    """Extract statistics from the building data.
+
+    Args:
+        geojson_file: Path to the GeoJSON file.
+
+    Returns:
+        Dictionary with statistics.
+    """
+    try:
+        # Read the GeoJSON file
+        gdf = gpd.read_file(geojson_file)
+
+        # Calculate statistics
+        bbox = gdf.total_bounds.tolist()
+        # Convert numpy values to Python native types
+        bbox = [float(x) for x in bbox]
+
+        stats = {
+            "total_buildings": int(len(gdf)),
+            "has_height": (
+                int(gdf["height"].notna().sum()) if "height" in gdf.columns else 0
+            ),
+            "has_name": (
+                int(gdf["names.common.value"].notna().sum())
+                if "names.common.value" in gdf.columns
+                else 0
+            ),
+            "bbox": bbox,
+        }
+
+        return stats
+
+    except Exception as e:
+        logger.error(f"Error extracting statistics: {str(e)}")
+        return {"error": str(e)}