geoai-py 0.1.7__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,34 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.1.7"
5
+ __version__ = "0.2.1"
6
6
 
7
7
 
8
+ import os
9
+ import sys
10
+
11
+
12
def set_proj_lib_path():
    """Point PROJ and GDAL at the data directories of the active environment.

    The environment root is ``CONDA_PREFIX`` when it is set and non-empty,
    otherwise ``sys.prefix``. ``PROJ_LIB`` is set to ``<root>/share/proj`` and
    ``GDAL_DATA`` to ``<root>/share/gdal``, each only if that path is an
    existing directory. Existing values of these variables are overwritten.

    This is best-effort: any error is reported as a ``RuntimeWarning``
    instead of being raised, so it cannot break ``import geoai``.
    """
    import warnings

    try:
        env_root = os.environ.get("CONDA_PREFIX") or sys.prefix
        data_dirs = {
            "PROJ_LIB": os.path.join(env_root, "share", "proj"),
            "GDAL_DATA": os.path.join(env_root, "share", "gdal"),
        }
        for env_var, path in data_dirs.items():
            # isdir rather than exists: a stray *file* with this name is not
            # a usable data directory and must not be exported.
            if os.path.isdir(path):
                os.environ[env_var] = path
    except Exception as e:  # deliberate best-effort: never fail the import
        warnings.warn(f"Could not set PROJ/GDAL data paths: {e}", RuntimeWarning)


# Skipped on Google Colab (presumably so the runtime's own PROJ/GDAL setup is
# left untouched -- NOTE(review): confirm intent).
if "google.colab" not in sys.modules:
    set_proj_lib_path()
34
+
8
35
  from .geoai import *
geoai/common.py CHANGED
@@ -5,8 +5,12 @@ from collections.abc import Iterable
5
5
  from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
6
6
  import matplotlib.pyplot as plt
7
7
 
8
+ import leafmap
8
9
  import torch
9
10
  import numpy as np
11
+ import xarray as xr
12
+ import rioxarray
13
+ import rasterio as rio
10
14
  from torch.utils.data import DataLoader
11
15
  from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
12
16
  from torchgeo.samplers import RandomGeoSampler, Units
@@ -55,10 +59,12 @@ def viz_raster(
55
59
  Returns:
56
60
  leafmap.Map: The map object with the raster layer added.
57
61
  """
58
- import leafmap
59
62
 
60
63
  m = leafmap.Map(basemap=basemap)
61
64
 
65
+ if isinstance(source, dict):
66
+ source = dict_to_image(source)
67
+
62
68
  m.add_raster(
63
69
  source=source,
64
70
  indexes=indexes,
@@ -86,6 +92,7 @@ def viz_image(
86
92
  scale_factor: float = 1.0,
87
93
  figsize: Tuple[int, int] = (10, 5),
88
94
  axis_off: bool = True,
95
+ title: Optional[str] = None,
89
96
  **kwargs: Any,
90
97
  ) -> None:
91
98
  """
@@ -98,6 +105,7 @@ def viz_image(
98
105
  scale_factor (float, optional): The scale factor to apply to the image. Defaults to 1.0.
99
106
  figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
100
107
  axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
108
+ title (Optional[str], optional): The title of the plot. Defaults to None.
101
109
  **kwargs (Any): Additional keyword arguments for plt.imshow().
102
110
 
103
111
  Returns:
@@ -124,6 +132,8 @@ def viz_image(
124
132
  plt.imshow(image, **kwargs)
125
133
  if axis_off:
126
134
  plt.axis("off")
135
+ if title is not None:
136
+ plt.title(title)
127
137
  plt.show()
128
138
  plt.close()
129
139
 
@@ -277,3 +287,150 @@ def calc_stats(
277
287
  # at the end, we shall have 2 vectors with length n=chnls
278
288
  # we will average them considering the number of images
279
289
  return accum_mean / len(files), accum_std / len(files)
290
+
291
+
292
def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
    """Convert a dictionary to a georeferenced xarray DataArray.

    The dictionary should contain the keys "crs", "bounds", and "image". It
    can be generated from a TorchGeo dataset sampler (use one sample, not a
    stacked batch).

    The output layout is chosen from the rank of the image, not from its
    band count:

    * 2-D ``(y, x)``: a band axis of length 1 is added -> ``("band", "y", "x")``
    * 3-D ``(band, y, x)``: ``("band", "y", "x")``
    * 4-D ``(time, band, y, x)``: ``("time", "band", "y", "x")``

    Args:
        data_dict (Dict): The dictionary containing the data. ``bounds`` must
            expose ``minx``, ``maxx``, ``miny`` and ``maxy``. For 4-D images,
            ``mint``/``maxt`` (if present) are used for the time coordinate.

    Returns:
        xr.DataArray: The DataArray, with 1-based band coordinates, pixel-centre
        x/y coordinates, and the CRS and affine transform written via rioxarray.

    Raises:
        ValueError: If the image is not 2-D, 3-D or 4-D.
    """

    from affine import Affine

    # Extract components from the dictionary
    crs = data_dict["crs"]
    bounds = data_dict["bounds"]
    image_tensor = data_dict["image"]

    # Convert to numpy. For PyTorch tensors, detach()/cpu() are required
    # because a bare .numpy() raises on tensors that require grad or live on
    # a GPU; for plain CPU tensors the result is the same as .numpy().
    if hasattr(image_tensor, "detach") and hasattr(image_tensor, "cpu"):
        image_array = image_tensor.detach().cpu().numpy()
    elif hasattr(image_tensor, "numpy"):
        image_array = image_tensor.numpy()
    else:
        # Already a numpy array or array-like; np.array copies the data.
        image_array = np.array(image_tensor)

    if image_array.ndim == 2:
        # Single-band image without a band axis.
        image_array = image_array[np.newaxis, ...]
    if image_array.ndim not in (3, 4):
        raise ValueError(
            "Expected a 2-D (y, x), 3-D (band, y, x) or 4-D (time, band, y, x) "
            f"image, got shape {image_array.shape}."
        )

    # The last two axes are (y, x) whatever the rank.
    height, width = image_array.shape[-2:]

    res_x = (bounds.maxx - bounds.minx) / width
    res_y = (bounds.maxy - bounds.miny) / height

    # North-up transform: origin at the top-left corner, negative y step.
    transform = Affine(res_x, 0.0, bounds.minx, 0.0, -res_y, bounds.maxy)

    # Pixel-centre coordinates; y runs from north to south.
    x_coords = np.linspace(bounds.minx + res_x / 2, bounds.maxx - res_x / 2, width)
    y_coords = np.linspace(bounds.maxy - res_y / 2, bounds.miny + res_y / 2, height)

    if image_array.ndim == 3:
        dims = ("band", "y", "x")
        coords = {
            "band": np.arange(1, image_array.shape[0] + 1),
            "y": y_coords,
            "x": x_coords,
        }
    else:
        n_times = image_array.shape[0]
        if hasattr(bounds, "mint") and hasattr(bounds, "maxt"):
            if n_times == 1:
                t_coords = [bounds.mint]
            else:
                # Evenly spaced between the temporal bounds, one per time step.
                t_coords = np.linspace(bounds.mint, bounds.maxt, n_times)
        else:
            t_coords = np.arange(n_times)
        dims = ("time", "band", "y", "x")
        coords = {
            "time": t_coords,
            "band": np.arange(1, image_array.shape[1] + 1),
            "y": y_coords,
            "x": x_coords,
        }

    da = xr.DataArray(image_array, dims=dims, coords=coords)

    # Set spatial attributes
    da.rio.write_crs(crs, inplace=True)
    da.rio.write_transform(transform, inplace=True)

    return da
387
+
388
+
389
def dict_to_image(
    data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
) -> Union[str, rio.DatasetReader]:
    """Convert a dictionary containing spatial data to an image, or save it to a file.

    The dictionary should contain the keys "crs", "bounds", and "image". It
    can be generated from a TorchGeo dataset sampler. The dictionary is first
    converted to a georeferenced DataArray with ``dict_to_rioxarray``. It is
    then either written to ``output`` with rioxarray's ``to_raster``, or
    passed to ``leafmap.array_to_image``.

    Args:
        data_dict: A dictionary containing:
            - 'crs': a CRS object accepted by rioxarray's ``write_crs``
            - 'bounds': an object with minx, maxx, miny, maxy attributes
              and optionally mint, maxt for temporal bounds
            - 'image': a tensor or array-like object with image data
        output: Optional path of the raster file to write. Missing parent
            directories are created. If not provided, the image is returned
            from ``leafmap.array_to_image`` instead.
        **kwargs: Additional keyword arguments passed to leafmap.array_to_image.
            They are only used when ``output`` is None. Common options include:
            - colormap: str, name of the colormap (e.g., 'viridis', 'terrain')
            - vmin: float, minimum value for colormap scaling
            - vmax: float, maximum value for colormap scaling

    Returns:
        If ``output`` is given, the ``output`` path after the file is written.
        Otherwise, whatever ``leafmap.array_to_image`` returns (a rasterio
        dataset that can be used for visualization or further processing).

    Examples:
        >>> image = dict_to_image(
        ...     {'crs': CRS.from_epsg(26911), 'bounds': bbox, 'image': tensor},
        ...     colormap='terrain'
        ... )
        >>> fig, ax = plt.subplots(figsize=(10, 10))
        >>> show(image, ax=ax)
    """
    da = dict_to_rioxarray(data_dict)

    if output is None:
        return leafmap.array_to_image(da, **kwargs)

    # exist_ok=True makes a prior existence check unnecessary. For a bare file
    # name, dirname() is "" and abspath("") is the current directory.
    os.makedirs(os.path.abspath(os.path.dirname(output)), exist_ok=True)
    da.rio.to_raster(output)
    return output
geoai/download.py CHANGED
@@ -26,6 +26,7 @@ def download_naip(
26
26
  output_dir: str,
27
27
  year: Optional[int] = None,
28
28
  max_items: int = 10,
29
+ overwrite: bool = False,
29
30
  preview: bool = False,
30
31
  **kwargs: Any,
31
32
  ) -> List[str]:
@@ -40,6 +41,7 @@ def download_naip(
40
41
  output_dir: Directory to save the downloaded imagery.
41
42
  year: Specific year of NAIP imagery to download (e.g., 2020). If None, returns imagery from all available years.
42
43
  max_items: Maximum number of items to download.
44
+ overwrite: If True, overwrite existing files with the same name.
43
45
  preview: If True, display a preview of the downloaded imagery.
44
46
 
45
47
  Returns:
@@ -75,6 +77,9 @@ def download_naip(
75
77
  search_results = catalog.search(**search_params)
76
78
  items = list(search_results.items())
77
79
 
80
+ if len(items) > max_items:
81
+ items = items[:max_items]
82
+
78
83
  if not items:
79
84
  print("No NAIP imagery found for the specified region and parameters.")
80
85
  return []
@@ -98,6 +103,10 @@ def download_naip(
98
103
  rgb_asset.href.split("?")[0]
99
104
  ) # Remove query parameters
100
105
  output_path = os.path.join(output_dir, original_filename)
106
+ if not overwrite and os.path.exists(output_path):
107
+ print(f"Skipping existing file: {output_path}")
108
+ downloaded_files.append(output_path)
109
+ continue
101
110
 
102
111
  print(f"Downloading item {i+1}/{len(items)}: {original_filename}")
103
112