geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/utils.py CHANGED
@@ -11,7 +11,17 @@ import xml.etree.ElementTree as ET
  from collections import OrderedDict
  from collections.abc import Iterable
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Generator,
+     Iterator,
+     List,
+     Optional,
+     Tuple,
+     Union,
+ )

  # Third-Party Libraries
  import cv2
@@ -55,8 +65,8 @@ def view_raster(
      basemap: Optional[str] = "OpenStreetMap",
      basemap_args: Optional[Dict] = None,
      backend: Optional[str] = "folium",
-     **kwargs,
- ):
+     **kwargs: Any,
+ ) -> Any:
      """
      Visualize a raster using leafmap.

@@ -471,8 +481,8 @@ def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:


  def dict_to_image(
-     data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
- ) -> rasterio.DatasetReader:
+     data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs: Any
+ ) -> Union[str, Any]:
      """Convert a dictionary containing spatial data to a rasterio dataset or save it to
      a file. The dictionary should contain the following keys: "crs", "bounds", and "image".
      It can be generated from a TorchGeo dataset sampler.
@@ -521,24 +531,24 @@ def dict_to_image(


  def view_vector(
-     vector_data,
-     column=None,
-     cmap="viridis",
-     figsize=(10, 10),
-     title=None,
-     legend=True,
-     basemap=False,
-     basemap_type="streets",
-     alpha=0.7,
-     edge_color="black",
-     classification="quantiles",
-     n_classes=5,
-     highlight_index=None,
-     highlight_color="red",
-     scheme=None,
-     save_path=None,
-     dpi=300,
- ):
+     vector_data: Union[str, gpd.GeoDataFrame],
+     column: Optional[str] = None,
+     cmap: str = "viridis",
+     figsize: Tuple[int, int] = (10, 10),
+     title: Optional[str] = None,
+     legend: bool = True,
+     basemap: bool = False,
+     basemap_type: str = "streets",
+     alpha: float = 0.7,
+     edge_color: str = "black",
+     classification: str = "quantiles",
+     n_classes: int = 5,
+     highlight_index: Optional[int] = None,
+     highlight_color: str = "red",
+     scheme: Optional[str] = None,
+     save_path: Optional[str] = None,
+     dpi: int = 300,
+ ) -> Any:
      """
      Visualize vector datasets with options for styling, classification, basemaps and more.
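
Note: the view_vector signature above is now fully annotated (a file path or a GeoDataFrame is accepted). A minimal usage sketch under those annotations; the file path and column name are hypothetical:

    from geoai.utils import view_vector

    view_vector(
        "buildings.geojson",          # hypothetical path; a GeoDataFrame also works
        column="height",              # hypothetical attribute to color by
        classification="quantiles",
        n_classes=5,
        basemap=True,
    )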
 
@@ -675,11 +685,11 @@ def view_vector(


  def view_vector_interactive(
-     vector_data,
-     layer_name="Vector Layer",
-     tiles_args=None,
-     **kwargs,
- ):
+     vector_data: Union[str, gpd.GeoDataFrame],
+     layer_name: str = "Vector Layer",
+     tiles_args: Optional[Dict] = None,
+     **kwargs: Any,
+ ) -> Any:
      """
      Visualize vector datasets with options for styling, classification, basemaps and more.

@@ -791,12 +801,12 @@ def view_vector_interactive(


  def regularization(
-     building_polygons,
-     angle_tolerance=10,
-     simplify_tolerance=0.5,
-     orthogonalize=True,
-     preserve_topology=True,
- ):
+     building_polygons: Union[gpd.GeoDataFrame, List[Polygon]],
+     angle_tolerance: float = 10,
+     simplify_tolerance: float = 0.5,
+     orthogonalize: bool = True,
+     preserve_topology: bool = True,
+ ) -> Union[gpd.GeoDataFrame, List[Polygon]]:
      """
      Regularizes building footprint polygons with multiple techniques beyond minimum
      rotated rectangles.
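
Note: regularization, as annotated above, accepts either a GeoDataFrame or a list of shapely Polygons and returns the same type. A minimal sketch; the input file is hypothetical:

    import geopandas as gpd
    from geoai.utils import regularization

    footprints = gpd.read_file("footprints.geojson")   # hypothetical building footprints
    regular = regularization(
        footprints,
        angle_tolerance=10,
        simplify_tolerance=0.5,
        orthogonalize=True,
        preserve_topology=True,
    )
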
@@ -920,7 +930,9 @@ def regularization(
      return regularized_buildings


- def hybrid_regularization(building_polygons):
+ def hybrid_regularization(
+     building_polygons: Union[gpd.GeoDataFrame, List[Polygon]],
+ ) -> Union[gpd.GeoDataFrame, List[Polygon]]:
      """
      A comprehensive hybrid approach to building footprint regularization.

@@ -1031,8 +1043,11 @@ def hybrid_regularization(building_polygons):


  def adaptive_regularization(
-     building_polygons, simplify_tolerance=0.5, area_threshold=0.9, preserve_shape=True
- ):
+     building_polygons: Union[gpd.GeoDataFrame, List[Polygon]],
+     simplify_tolerance: float = 0.5,
+     area_threshold: float = 0.9,
+     preserve_shape: bool = True,
+ ) -> Union[gpd.GeoDataFrame, List[Polygon]]:
      """
      Adaptively regularizes building footprints based on their characteristics.

@@ -1132,7 +1147,7 @@ def adaptive_regularization(
      return results


- def install_package(package):
+ def install_package(package: Union[str, List[str]]) -> None:
      """Install a Python package.

      Args:
@@ -1188,9 +1203,9 @@ def create_split_map(
      zoom: Optional[int] = 2,
      height: Optional[int] = "600px",
      basemap: Optional[str] = None,
-     basemap_args: Optional[dict] = None,
-     m=None,
-     **kwargs,
+     basemap_args: Optional[Dict] = None,
+     m: Optional[Any] = None,
+     **kwargs: Any,
  ) -> None:
      """Adds split map.

@@ -1261,7 +1276,12 @@ def create_split_map(
      return m


- def download_file(url, output_path=None, overwrite=False, unzip=True):
+ def download_file(
+     url: str,
+     output_path: Optional[str] = None,
+     overwrite: bool = False,
+     unzip: bool = True,
+ ) -> str:
      """
      Download a file from a given URL with a progress bar.
      Optionally unzip the file if it's a ZIP archive.
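
Note: download_file now declares a str return (the downloaded file or the extracted directory). A minimal sketch; the URL is hypothetical:

    from geoai.utils import download_file

    path = download_file(
        "https://example.com/data.zip",   # hypothetical URL
        output_path="data.zip",
        unzip=True,                       # extract if the download is a ZIP archive
    )
    print(path)
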
@@ -1277,9 +1297,10 @@ def download_file(url, output_path=None, overwrite=False, unzip=True):
          str: The path to the downloaded file or the extracted directory.
      """

-     from tqdm import tqdm
      import zipfile

+     from tqdm import tqdm
+
      if output_path is None:
          output_path = os.path.basename(url)

@@ -1318,7 +1339,7 @@ def download_file(url, output_path=None, overwrite=False, unzip=True):
      return output_path


- def get_raster_info(raster_path):
+ def get_raster_info(raster_path: str) -> Dict[str, Any]:
      """Display basic information about a raster dataset.

      Args:
@@ -1361,7 +1382,7 @@ def get_raster_info(raster_path):
      return info


- def get_raster_stats(raster_path, divide_by=1.0):
+ def get_raster_stats(raster_path: str, divide_by: float = 1.0) -> Dict[str, Any]:
      """Calculate statistics for each band in a raster dataset.

      This function computes min, max, mean, and standard deviation values
@@ -1398,7 +1419,9 @@ def get_raster_stats(raster_path, divide_by=1.0):
      return stats


- def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
+ def print_raster_info(
+     raster_path: str, show_preview: bool = True, figsize: Tuple[int, int] = (10, 8)
+ ) -> Optional[Dict[str, Any]]:
      """Print formatted information about a raster dataset and optionally show a preview.

      Args:
@@ -1458,7 +1481,7 @@ def print_raster_info(raster_path, show_preview=True, figsize=(10, 8)):
          print(f"Error reading raster: {str(e)}")


- def get_raster_info_gdal(raster_path):
+ def get_raster_info_gdal(raster_path: str) -> Optional[Dict[str, Any]]:
      """Get basic information about a raster dataset using GDAL.

      Args:
@@ -1517,7 +1540,7 @@ def get_raster_info_gdal(raster_path):
      return info


- def get_vector_info(vector_path):
+ def get_vector_info(vector_path: str) -> Optional[Dict[str, Any]]:
      """Display basic information about a vector dataset using GeoPandas.

      Args:
@@ -1563,7 +1586,9 @@ def get_vector_info(vector_path):
      return info


- def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):
+ def print_vector_info(
+     vector_path: str, show_preview: bool = True, figsize: Tuple[int, int] = (10, 8)
+ ) -> Optional[Dict[str, Any]]:
      """Print formatted information about a vector dataset and optionally show a preview.

      Args:
@@ -1623,7 +1648,7 @@ def print_vector_info(vector_path, show_preview=True, figsize=(10, 8)):


  # Alternative implementation using OGR directly
- def get_vector_info_ogr(vector_path):
+ def get_vector_info_ogr(vector_path: str) -> Optional[Dict[str, Any]]:
      """Get basic information about a vector dataset using OGR.

      Args:
@@ -1686,7 +1711,9 @@ def get_vector_info_ogr(vector_path):
      return info


- def analyze_vector_attributes(vector_path, attribute_name):
+ def analyze_vector_attributes(
+     vector_path: str, attribute_name: str
+ ) -> Optional[Dict[str, Any]]:
      """Analyze a specific attribute in a vector dataset and create a histogram.

      Args:
@@ -1763,8 +1790,11 @@ def analyze_vector_attributes(vector_path, attribute_name):


  def visualize_vector_by_attribute(
-     vector_path, attribute_name, cmap="viridis", figsize=(10, 8)
- ):
+     vector_path: str,
+     attribute_name: str,
+     cmap: str = "viridis",
+     figsize: Tuple[int, int] = (10, 8),
+ ) -> bool:
      """Create a thematic map visualization of vector data based on an attribute.

      Args:
@@ -1812,8 +1842,13 @@ def visualize_vector_by_attribute(


  def clip_raster_by_bbox(
-     input_raster, output_raster, bbox, bands=None, bbox_type="geo", bbox_crs=None
- ):
+     input_raster: str,
+     output_raster: str,
+     bbox: List[float],
+     bands: Optional[List[int]] = None,
+     bbox_type: str = "geo",
+     bbox_crs: Optional[str] = None,
+ ) -> str:
      """
      Clip a raster dataset using a bounding box and optionally select specific bands.

@@ -1999,17 +2034,17 @@ def clip_raster_by_bbox(


  def raster_to_vector(
-     raster_path,
-     output_path=None,
-     threshold=0,
-     min_area=10,
-     simplify_tolerance=None,
-     class_values=None,
-     attribute_name="class",
-     unique_attribute_value=False,
-     output_format="geojson",
-     plot_result=False,
- ):
+     raster_path: str,
+     output_path: Optional[str] = None,
+     threshold: float = 0,
+     min_area: float = 10,
+     simplify_tolerance: Optional[float] = None,
+     class_values: Optional[List[int]] = None,
+     attribute_name: str = "class",
+     unique_attribute_value: bool = False,
+     output_format: str = "geojson",
+     plot_result: bool = False,
+ ) -> gpd.GeoDataFrame:
      """
      Convert a raster label mask to vector polygons.
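
Note: raster_to_vector is now annotated to return a GeoDataFrame. A minimal sketch based on the signature above; the paths are hypothetical:

    from geoai.utils import raster_to_vector

    gdf = raster_to_vector(
        "mask.tif",                   # hypothetical label raster
        output_path="mask.geojson",
        min_area=10,
        simplify_tolerance=0.5,
        attribute_name="class",
    )
    print(len(gdf), "polygons")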
 
@@ -2131,18 +2166,18 @@ def raster_to_vector(


  def raster_to_vector_batch(
-     input_dir,
-     output_dir,
-     pattern="*.tif",
-     threshold=0,
-     min_area=10,
-     simplify_tolerance=None,
-     class_values=None,
-     attribute_name="class",
-     output_format="geojson",
-     merge_output=False,
-     merge_filename="merged_vectors",
- ):
+     input_dir: str,
+     output_dir: str,
+     pattern: str = "*.tif",
+     threshold: float = 0,
+     min_area: float = 10,
+     simplify_tolerance: Optional[float] = None,
+     class_values: Optional[List[int]] = None,
+     attribute_name: str = "class",
+     output_format: str = "geojson",
+     merge_output: bool = False,
+     merge_filename: str = "merged_vectors",
+ ) -> Optional[gpd.GeoDataFrame]:
      """
      Batch convert multiple raster files to vector polygons.

@@ -2246,21 +2281,21 @@ def raster_to_vector_batch(


  def vector_to_raster(
-     vector_path,
-     output_path=None,
-     reference_raster=None,
-     attribute_field=None,
-     output_shape=None,
-     transform=None,
-     pixel_size=None,
-     bounds=None,
-     crs=None,
-     all_touched=False,
-     fill_value=0,
-     dtype=np.uint8,
-     nodata=None,
-     plot_result=False,
- ):
+     vector_path: Union[str, gpd.GeoDataFrame],
+     output_path: Optional[str] = None,
+     reference_raster: Optional[str] = None,
+     attribute_field: Optional[str] = None,
+     output_shape: Optional[Tuple[int, int]] = None,
+     transform: Optional[Any] = None,
+     pixel_size: Optional[float] = None,
+     bounds: Optional[List[float]] = None,
+     crs: Optional[str] = None,
+     all_touched: bool = False,
+     fill_value: Union[int, float] = 0,
+     dtype: Any = np.uint8,
+     nodata: Optional[Union[int, float]] = None,
+     plot_result: bool = False,
+ ) -> np.ndarray:
      """
      Convert vector data to a raster.
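
Note: vector_to_raster accepts a path or a GeoDataFrame and is annotated to return the rasterized NumPy array. A minimal sketch; the paths and attribute name are hypothetical:

    from geoai.utils import vector_to_raster

    arr = vector_to_raster(
        "parcels.geojson",              # hypothetical vector input
        output_path="parcels.tif",
        reference_raster="image.tif",   # hypothetical reference grid
        attribute_field="class",
    )
    print(arr.shape, arr.dtype)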
 
@@ -4782,7 +4817,9 @@ def masks_to_vector(
      return gdf


- def read_vector(source, layer=None, **kwargs):
+ def read_vector(
+     source: str, layer: Optional[str] = None, **kwargs: Any
+ ) -> gpd.GeoDataFrame:
      """Reads vector data from various formats including GeoParquet.

      This function dynamically determines the file type based on extension
@@ -4857,7 +4894,12 @@ def read_vector(source, layer=None, **kwargs):
          raise ValueError(f"Could not read from source '{source}': {str(e)}")


- def read_raster(source, band=None, masked=True, **kwargs):
+ def read_raster(
+     source: str,
+     band: Optional[Union[int, List[int]]] = None,
+     masked: bool = True,
+     **kwargs: Any,
+ ) -> xr.DataArray:
      """Reads raster data from various formats using rioxarray.

      This function reads raster data from local files or URLs into a rioxarray
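
Note: per the annotations, read_vector returns a GeoDataFrame and read_raster returns an xarray DataArray (via rioxarray). A minimal sketch; the file names are hypothetical:

    from geoai.utils import read_raster, read_vector

    gdf = read_vector("features.parquet")              # format inferred from the extension
    da = read_raster("image.tif", band=1, masked=True)
    print(gdf.crs, da.shape)
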
@@ -4922,7 +4964,7 @@ def read_raster(source, band=None, masked=True, **kwargs):
          raise ValueError(f"Error reading raster data: {str(e)}")


- def temp_file_path(ext):
+ def temp_file_path(ext: str) -> str:
      """Returns a temporary file path.

      Args:
@@ -5179,7 +5221,12 @@ def region_groups(
      return da, df


- def add_geometric_properties(data, properties=None, area_unit="m2", length_unit="m"):
+ def add_geometric_properties(
+     data: gpd.GeoDataFrame,
+     properties: Optional[List[str]] = None,
+     area_unit: str = "m2",
+     length_unit: str = "m",
+ ) -> gpd.GeoDataFrame:
      """Calculates geometric properties and adds them to the GeoDataFrame.

      This function calculates various geometric properties of features in a
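
Note: add_geometric_properties now takes and returns a GeoDataFrame. A minimal sketch; the input file is hypothetical and leaving properties unset falls back to the function's defaults:

    import geopandas as gpd
    from geoai.utils import add_geometric_properties

    gdf = gpd.read_file("parcels.geojson")   # hypothetical input
    gdf = add_geometric_properties(gdf, area_unit="m2", length_unit="m")
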
@@ -6617,7 +6664,7 @@ def orthogonalize(
      return gdf


- def inspect_pth_file(pth_path):
+ def inspect_pth_file(pth_path: str) -> Dict[str, Any]:
      """
      Inspect a PyTorch .pth model file to determine its architecture.

@@ -6769,7 +6816,7 @@ def inspect_pth_file(pth_path):
          print(f"Error loading the model file: {str(e)}")


- def try_common_architectures(state_dict):
+ def try_common_architectures(state_dict: Dict[str, Any]) -> Optional[str]:
      """
      Try to load the state_dict into common architectures to see which one fits.

@@ -6809,7 +6856,9 @@ def try_common_architectures(state_dict):
          print(f"- {name}: Failed to load - {str(e)}")


- def mosaic_geotiffs(input_dir, output_file, mask_file=None):
+ def mosaic_geotiffs(
+     input_dir: str, output_file: str, mask_file: Optional[str] = None
+ ) -> None:
      """Create a mosaic from all GeoTIFF files as a Cloud Optimized GeoTIFF (COG).

      This function identifies all GeoTIFF files in the specified directory,
@@ -6972,7 +7021,7 @@ def mosaic_geotiffs(input_dir, output_file, mask_file=None):
      return True


- def download_model_from_hf(model_path, repo_id=None):
+ def download_model_from_hf(model_path: str, repo_id: Optional[str] = None) -> str:
      """
      Download the object detection model from Hugging Face.

@@ -7102,7 +7151,9 @@ def regularize(
      return gdf


- def vector_to_geojson(filename, output=None, **kwargs):
+ def vector_to_geojson(
+     filename: str, output: Optional[str] = None, **kwargs: Any
+ ) -> str:
      """Converts a vector file to a geojson file.

      Args:
@@ -7162,8 +7213,8 @@ def coords_to_xy(
      src_fp: str,
      coords: np.ndarray,
      coord_crs: str = "epsg:4326",
-     return_out_of_bounds=False,
-     **kwargs,
+     return_out_of_bounds: bool = False,
+     **kwargs: Any,
  ) -> np.ndarray:
      """Converts a list or array of coordinates to pixel coordinates, i.e., (col, row) coordinates.

@@ -7230,7 +7281,13 @@ def coords_to_xy(
      return output


- def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs):
+ def boxes_to_vector(
+     coords: Union[List[List[float]], np.ndarray],
+     src_crs: str,
+     dst_crs: str = "EPSG:4326",
+     output: Optional[str] = None,
+     **kwargs: Any,
+ ) -> gpd.GeoDataFrame:
      """
      Convert a list of bounding box coordinates to vector data.

@@ -7264,16 +7321,16 @@ def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs)


  def rowcol_to_xy(
-     src_fp,
-     rows=None,
-     cols=None,
-     boxes=None,
-     zs=None,
-     offset="center",
-     output=None,
-     dst_crs="EPSG:4326",
-     **kwargs,
- ):
+     src_fp: str,
+     rows: Optional[List[int]] = None,
+     cols: Optional[List[int]] = None,
+     boxes: Optional[List[List[int]]] = None,
+     zs: Optional[List[float]] = None,
+     offset: str = "center",
+     output: Optional[str] = None,
+     dst_crs: str = "EPSG:4326",
+     **kwargs: Any,
+ ) -> Tuple[List[float], List[float]]:
      """Converts a list of (row, col) coordinates to (x, y) coordinates.

      Args:
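
Note: the coordinate helpers are now typed; rowcol_to_xy is annotated to return parallel lists of x and y map coordinates. A minimal sketch; the raster path and indices are hypothetical:

    from geoai.utils import rowcol_to_xy

    xs, ys = rowcol_to_xy(
        "image.tif",          # hypothetical source raster
        rows=[10, 20],
        cols=[30, 40],
        offset="center",
        dst_crs="EPSG:4326",
    )
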
@@ -7320,8 +7377,8 @@ def rowcol_to_xy(


  def bbox_to_xy(
-     src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
- ) -> list:
+     src_fp: str, coords: List[float], coord_crs: str = "epsg:4326", **kwargs: Any
+ ) -> List[float]:
      """Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
      Note that map bbox coords is [minx, miny, maxx, maxy] from bottomleft to topright
      While rasterio bbox coords is [minx, max, maxx, min] from topleft to bottomright
@@ -7411,8 +7468,8 @@ def bbox_to_xy(


  def geojson_to_xy(
-     src_fp: str, geojson: str, coord_crs: str = "epsg:4326", **kwargs
- ) -> list:
+     src_fp: str, geojson: str, coord_crs: str = "epsg:4326", **kwargs: Any
+ ) -> List[List[float]]:
      """Converts a geojson file or a dictionary of feature collection to a list of pixel coordinates.

      Args:
@@ -7430,7 +7487,11 @@ def geojson_to_xy(
      return coords_to_xy(src_fp, coords, src_crs, **kwargs)


- def write_colormap(image, colormap, output=None):
+ def write_colormap(
+     image: Union[str, np.ndarray],
+     colormap: Union[str, Dict],
+     output: Optional[str] = None,
+ ) -> Optional[str]:
      """Write a colormap to an image.

      Args:
@@ -7443,7 +7504,9 @@ def write_colormap(image, colormap, output=None):
      leafmap.write_image_colormap(image, colormap, output)


- def plot_performance_metrics(history_path, figsize=(15, 5), verbose=True):
+ def plot_performance_metrics(
+     history_path: str, figsize: Tuple[int, int] = (15, 5), verbose: bool = True
+ ) -> None:
      """Plot performance metrics from a history object.

      Args:
@@ -7510,7 +7573,7 @@ def plot_performance_metrics(history_path, figsize=(15, 5), verbose=True):
          print(f"Final Dice: {history[val_dice_key][-1]:.4f}")


- def get_device():
+ def get_device() -> torch.device:
      """
      Returns the best available device for deep learning in the order:
      CUDA (NVIDIA GPU) > MPS (Apple Silicon GPU) > CPU
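
Note: get_device is now typed to return a torch.device, preferring CUDA, then MPS, then CPU as the docstring states. A minimal sketch; the small module is only a placeholder:

    import torch
    from geoai.utils import get_device

    device = get_device()                      # cuda, mps, or cpu
    model = torch.nn.Linear(4, 2).to(device)   # placeholder module moved to that device
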
@@ -7536,7 +7599,7 @@ def plot_prediction_comparison(
      original_colormap: Optional[str] = None,
      indexes: Optional[List[int]] = None,
      divider: Optional[float] = None,
- ):
+ ) -> None:
      """Plot original image, prediction, and optional ground truth side by side.

      Supports input as file paths, NumPy arrays, or PIL Images. For multi-band
@@ -7699,8 +7762,12 @@ def stack_bands(
      Returns:
          str: Path to the output file.
      """
+     import leafmap
+
      if not input_files:
          raise ValueError("No input files provided.")
+     elif isinstance(input_files, str):
+         input_files = leafmap.find_files(input_files, ".tif")

      if os.path.exists(output_file) and not overwrite:
          print(f"Output file already exists: {output_file}")
@@ -7744,3 +7811,11 @@ def stack_bands(
          os.remove(temp_vrt)

      return output_file
+
+
+ def empty_cache() -> None:
+     """Empty the cache of the current device."""
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+         torch.mps.empty_cache()
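
Note: the new empty_cache helper releases cached accelerator memory on CUDA or Apple MPS and does nothing on CPU-only machines. A minimal sketch:

    from geoai.utils import empty_cache

    # After a large training or inference step, release cached GPU/MPS memory.
    empty_cache()
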
{geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: geoai-py
- Version: 0.8.3
+ Version: 0.9.1
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
  Author-email: Qiusheng Wu <giswqs@gmail.com>
  License: MIT License
@@ -19,6 +19,8 @@ License-File: LICENSE
  Requires-Dist: albumentations
  Requires-Dist: buildingregulariser
  Requires-Dist: contextily
+ Requires-Dist: datasets>=3.0
+ Requires-Dist: ever-beta
  Requires-Dist: geopandas
  Requires-Dist: huggingface_hub
  Requires-Dist: jupyter-server-proxy
@@ -29,12 +31,14 @@ Requires-Dist: maplibre
  Requires-Dist: opencv-python-headless
  Requires-Dist: overturemaps
  Requires-Dist: planetary-computer
+ Requires-Dist: pyarrow
  Requires-Dist: pystac-client
  Requires-Dist: rasterio
  Requires-Dist: rioxarray
  Requires-Dist: scikit-image
  Requires-Dist: scikit-learn
  Requires-Dist: torch
+ Requires-Dist: torchange
  Requires-Dist: torchgeo
  Requires-Dist: torchinfo
  Requires-Dist: tqdm

geoai_py-0.9.1.dist-info/RECORD ADDED
@@ -0,0 +1,19 @@
+ geoai/__init__.py,sha256=kKhbJueAgebEbAHDJ0qn3PMl5C_JZR93qvBtwhGtImk,3765
+ geoai/change_detection.py,sha256=DUhC8nqOYL2cKSuc1s7G4tw9tFb7armDIS84mHulSnE,59552
+ geoai/classify.py,sha256=0DcComVR6vKU4qWtH2oHVeXc7ZTcV0mFvdXRtlNmolo,35637
+ geoai/detectron2.py,sha256=dOOFM9M9-6PV8q2A4-mnIPrz7yTo-MpEvDiAW34nl0w,14610
+ geoai/download.py,sha256=B0EwpQFndJknOKmwRfEEnnCJhplOAwcLyNzFuA6FjZs,47633
+ geoai/extract.py,sha256=595MBcSaFx-gQLIEv5g3oEM90QA5In4L59GPVgBOlQc,122092
+ geoai/geoai.py,sha256=ZPr7hyJhOnwPO9c-nVJVaOUqMRZ77UpK95TFjKzDt0A,9782
+ geoai/hf.py,sha256=HbfJfpO6XnANKhmFOBvpwULiC65TeMlnLNtyQHHmlKA,17248
+ geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
+ geoai/segment.py,sha256=yBGTxA-ti8lBpk7WVaBOp6yP23HkaulKJQk88acrmZ0,43788
+ geoai/segmentation.py,sha256=7yEzBSKCyHW1dNssoK0rdvhxi2IXsIQIFSga817KdI4,11535
+ geoai/train.py,sha256=k-nniOSRSZgwItBPukvFhTfTymAIAGPlCDoNYOV8S5Y,135885
+ geoai/utils.py,sha256=PrqF1hMvmsd_nIJYTOZUDo4TAYfHsAWdadkPaAhoPec,300406
+ geoai_py-0.9.1.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
+ geoai_py-0.9.1.dist-info/METADATA,sha256=j2ouzTmHTBIQf4DqT-tRcZQDrIq4L4pLvTyyU9IrO78,6763
+ geoai_py-0.9.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
+ geoai_py-0.9.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
+ geoai_py-0.9.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
+ geoai_py-0.9.1.dist-info/RECORD,,