geoai-py 0.1.5__py2.py3-none-any.whl → 0.1.7__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,4 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.1.5"
5
+ __version__ = "0.1.7"
6
+
7
+
8
+ from .geoai import *
geoai/common.py CHANGED
@@ -1,6 +1,279 @@
1
1
  """The common module contains common functions and classes used by the other modules."""
2
2
 
3
+ import os
4
+ from collections.abc import Iterable
5
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
6
+ import matplotlib.pyplot as plt
3
7
 
4
- def hello_world():
5
- """Prints "Hello World!" to the console."""
6
- print("Hello World!")
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader
11
+ from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
12
+ from torchgeo.samplers import RandomGeoSampler, Units
13
+ from torchgeo.transforms import indices
14
+
15
+
16
def viz_raster(
    source: str,
    indexes: Optional[int] = None,
    colormap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    nodata: Optional[float] = None,
    attribution: Optional[str] = None,
    layer_name: Optional[str] = "Raster",
    layer_index: Optional[int] = None,
    zoom_to_layer: Optional[bool] = True,
    visible: Optional[bool] = True,
    opacity: Optional[float] = 1.0,
    array_args: Optional[Dict] = None,
    client_args: Optional[Dict] = None,
    basemap: Optional[str] = "OpenStreetMap",
    **kwargs,
):
    """
    Visualize a raster using leafmap.

    Args:
        source (str): The source of the raster.
        indexes (Optional[int], optional): The band indexes to visualize. Defaults to None.
        colormap (Optional[str], optional): The colormap to apply. Defaults to None.
        vmin (Optional[float], optional): The minimum value for colormap scaling. Defaults to None.
        vmax (Optional[float], optional): The maximum value for colormap scaling. Defaults to None.
        nodata (Optional[float], optional): The nodata value. Defaults to None.
        attribution (Optional[str], optional): The attribution for the raster. Defaults to None.
        layer_name (Optional[str], optional): The name of the layer. Defaults to "Raster".
        layer_index (Optional[int], optional): The index of the layer. Defaults to None.
        zoom_to_layer (Optional[bool], optional): Whether to zoom to the layer. Defaults to True.
        visible (Optional[bool], optional): Whether the layer is visible. Defaults to True.
        opacity (Optional[float], optional): The opacity of the layer. Defaults to 1.0.
        array_args (Optional[Dict], optional): Additional arguments for array processing.
            Defaults to None, which is treated as an empty dict.
        client_args (Optional[Dict], optional): Additional arguments for the client.
            Defaults to None, which is treated as {"cors_all": False}.
        basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
        **kwargs (Any): Additional keyword arguments forwarded to ``Map.add_raster()``.

    Returns:
        leafmap.Map: The map object with the raster layer added.
    """
    # Imported lazily so that importing this module does not require leafmap.
    import leafmap

    # Build fresh dicts on every call: a dict literal used as a default value
    # is created once and shared by all calls, so any mutation downstream
    # would leak into later calls. Caller-supplied dicts are copied for the
    # same reason.
    array_args = {} if array_args is None else dict(array_args)
    client_args = {"cors_all": False} if client_args is None else dict(client_args)

    m = leafmap.Map(basemap=basemap)

    m.add_raster(
        source=source,
        indexes=indexes,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        nodata=nodata,
        attribution=attribution,
        layer_name=layer_name,
        layer_index=layer_index,
        zoom_to_layer=zoom_to_layer,
        visible=visible,
        opacity=opacity,
        array_args=array_args,
        client_args=client_args,
        **kwargs,
    )
    return m
80
+
81
+
82
def viz_image(
    image: Union[np.ndarray, torch.Tensor],
    transpose: bool = False,
    bdx: Optional[int] = None,
    scale_factor: float = 1.0,
    figsize: Tuple[int, int] = (10, 5),
    axis_off: bool = True,
    **kwargs: Any,
) -> None:
    """
    Visualize an image using matplotlib.

    Args:
        image (Union[np.ndarray, torch.Tensor]): The image to visualize.
        transpose (bool, optional): Whether to reorder the axes with
            ``transpose(1, 2, 0)`` before plotting. Defaults to False.
        bdx (Optional[int], optional): The band index to visualize, taken along the
            last axis (after the optional transpose). Defaults to None.
        scale_factor (float, optional): The scale factor to apply to the image. When it
            differs from 1.0 the scaled values are clipped to [0, 1]. Defaults to 1.0.
        figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
        axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
        **kwargs (Any): Additional keyword arguments for plt.imshow().

    Returns:
        None
    """
    if isinstance(image, torch.Tensor):
        # detach() first: numpy() refuses tensors that are tracking gradients.
        image = image.detach().cpu().numpy()

    if transpose:
        image = image.transpose(1, 2, 0)

    if bdx is not None:
        image = image[:, :, bdx]

    # Only the first three bands of a multi-band image are displayed.
    if len(image.shape) > 2 and image.shape[2] > 3:
        image = image[:, :, 0:3]

    if scale_factor != 1.0:
        image = np.clip(image * scale_factor, 0, 1)

    # The figure is created only after the array has been prepared, so a
    # failure above does not leave an empty figure behind.
    plt.figure(figsize=figsize)
    plt.imshow(image, **kwargs)
    if axis_off:
        plt.axis("off")
    plt.show()
    plt.close()
129
+
130
+
131
def plot_images(
    images: Iterable[torch.Tensor],
    axs: Iterable[plt.Axes],
    chnls: List[int] = [2, 1, 0],
    bright: float = 1.0,
) -> None:
    """
    Plot a list of images.

    Each image is multiplied by ``bright``, clamped to [0, 1], reordered with
    ``transpose(1, 2, 0)`` and reduced to the channels listed in ``chnls``
    before being drawn on the matching axis.

    Args:
        images (Iterable[torch.Tensor]): The images to plot.
        axs (Iterable[plt.Axes]): The axes to plot the images on.
        chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
        bright (float, optional): The brightness factor. Defaults to 1.0.

    Returns:
        None
    """
    # NOTE: the list default for `chnls` is only read, never mutated, so
    # sharing it between calls is harmless here.
    for img, ax in zip(images, axs):
        scaled = torch.clamp(bright * img, min=0, max=1)
        # detach()/cpu() make the conversion work for tensors that track
        # gradients or live on an accelerator; numpy() alone raises for those.
        arr = scaled.detach().cpu().numpy()
        rgb = arr.transpose(1, 2, 0)[:, :, chnls]
        ax.imshow(rgb)
        ax.axis("off")
154
+
155
+
156
def plot_masks(
    masks: Iterable[torch.Tensor], axs: Iterable[plt.Axes], cmap: str = "Blues"
) -> None:
    """
    Plot a list of masks.

    Args:
        masks (Iterable[torch.Tensor]): The masks to plot. Singleton dimensions are
            squeezed away before plotting.
        axs (Iterable[plt.Axes]): The axes to plot the masks on.
        cmap (str, optional): The colormap to use. Defaults to "Blues".

    Returns:
        None
    """
    for mask, ax in zip(masks, axs):
        # detach()/cpu() make the conversion work for tensors that track
        # gradients or live on an accelerator; numpy() alone raises for those.
        ax.imshow(mask.squeeze().detach().cpu().numpy(), cmap=cmap)
        ax.axis("off")
173
+
174
+
175
def plot_batch(
    batch: Dict[str, Any],
    bright: float = 1.0,
    cols: int = 4,
    width: int = 5,
    chnls: List[int] = [2, 1, 0],
    cmap: str = "Blues",
) -> None:
    """
    Plot a batch of images and masks. This function is adapted from the plot_batch()
    function in the torchgeo library at
    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html
    Credit to the torchgeo developers for the original implementation.

    When the batch holds both "image" and "mask" entries they are interleaved:
    images go on the even-numbered axes and masks on the odd-numbered ones.

    Args:
        batch (Dict[str, Any]): The batch containing images and masks.
        bright (float, optional): The brightness factor. Defaults to 1.0.
        cols (int, optional): The number of columns in the plot grid. Defaults to 4.
        width (int, optional): The width of each plot. Defaults to 5.
        chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
        cmap (str, optional): The colormap to use for masks. Defaults to "Blues".

    Returns:
        None
    """
    # Split the batch into per-sample dicts (on a shallow copy of the batch).
    samples = unbind_samples(batch.copy())

    has_image = "image" in batch
    has_mask = "mask" in batch

    # With both images and masks, every sample needs two panels.
    n = 2 * len(samples) if has_image and has_mask else len(samples)

    # Number of grid rows: ceil(n / cols).
    rows = (n + cols - 1) // cols

    # squeeze=False guarantees a 2-D array of axes; without it a 1x1 grid
    # comes back as a bare Axes object that has no reshape().
    _, axs = plt.subplots(
        rows, cols, figsize=(cols * width, rows * width), squeeze=False
    )
    flat_axs = axs.reshape(-1)

    if has_image and has_mask:
        # Images on the even axes.
        plot_images(
            images=(sample["image"] for sample in samples),
            axs=flat_axs[::2],
            chnls=chnls,
            bright=bright,
        )
        # Masks on the odd axes; cmap is forwarded so the caller's choice is
        # honoured in this branch as well.
        plot_masks(
            masks=(sample["mask"] for sample in samples),
            axs=flat_axs[1::2],
            cmap=cmap,
        )
    elif has_image:
        plot_images(
            images=(sample["image"] for sample in samples),
            axs=flat_axs,
            chnls=chnls,
            bright=bright,
        )
    elif has_mask:
        plot_masks(
            masks=(sample["mask"] for sample in samples),
            axs=flat_axs,
            cmap=cmap,
        )
237
+
238
+
239
def calc_stats(
    dataset: RasterDataset, divide_by: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate the statistics (mean and std) for the entire dataset.

    This function is adapted from the statistics helper in the torchgeo tutorial at
    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html.
    Credit to the torchgeo developers for the original implementation.

    Warning: This is an approximation. The per-file means and standard deviations
    are simply averaged; the correct value should take into account the mean for
    the whole dataset when computing individual stds.

    Args:
        dataset (RasterDataset): The dataset to calculate statistics for.
        divide_by (float, optional): The value to divide the image data by. Defaults to 1.0.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.

    Raises:
        ValueError: If the dataset index contains no files.
    """
    # Imported lazily so that importing this module does not require rasterio.
    import rasterio as rio

    # To avoid loading the entire dataset in memory, loop through each file.
    # The filenames are retrieved from the dataset's spatial index.
    files = [
        item.object
        for item in dataset.index.intersection(dataset.index.bounds, objects=True)
    ]
    if not files:
        raise ValueError("Cannot compute statistics: the dataset contains no files.")

    accum_mean = 0
    accum_std = 0

    for file in files:
        # The context manager closes the file handle as soon as it is read.
        with rio.open(file) as src:
            img = src.read() / divide_by  # type: ignore
        # Flatten everything except the first (band) axis.
        bands = img.reshape((img.shape[0], -1))
        accum_mean += bands.mean(axis=1)
        accum_std += bands.std(axis=1)

    # Two vectors with one entry per band, averaged over the number of files.
    return accum_mean / len(files), accum_std / len(files)
geoai/download.py ADDED
@@ -0,0 +1,386 @@
1
+ """This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
2
+
3
+ import os
4
+ from typing import List, Tuple, Optional, Dict, Any
5
+ import rioxarray
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from pystac_client import Client
9
+ import planetary_computer as pc
10
+ import geopandas as gpd
11
+ from shapely.geometry import box
12
+ from tqdm import tqdm
13
+ import requests
14
+ import subprocess
15
+ import logging
16
+
17
# Configure logging.
# NOTE(review): basicConfig() runs at import time, so merely importing this
# module configures the root logger (INFO level, timestamped format) when the
# application has not configured logging itself -- confirm this is intended
# for a library module.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the download helpers below.
logger = logging.getLogger(__name__)
22
+
23
+
24
def download_naip(
    bbox: Tuple[float, float, float, float],
    output_dir: str,
    year: Optional[int] = None,
    max_items: int = 10,
    preview: bool = False,
    **kwargs: Any,
) -> List[str]:
    """Download NAIP imagery from Planetary Computer based on a bounding box.

    This function searches for NAIP (National Agriculture Imagery Program) imagery
    from Microsoft's Planetary Computer that intersects with the specified bounding box.
    It downloads the imagery and saves it as GeoTIFF files.

    Failures while downloading an individual item are caught, reported with
    ``print`` and skipped; the item is then left out of the returned list.
    Errors raised while opening the catalog or running the search are not caught.

    Args:
        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat) in WGS84 coordinates.
        output_dir: Directory to save the downloaded imagery.
        year: Specific year of NAIP imagery to download (e.g., 2020). If None, returns imagery from all available years.
        max_items: Maximum number of items to download.
        preview: If True, display a preview of the downloaded imagery.
        **kwargs: Extra search parameters passed to the STAC search. They
            override the parameters built by this function when keys collide.

    Returns:
        List of downloaded file paths.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Create a geometry from the bounding box
    geometry = box(*bbox)

    # Connect to Planetary Computer STAC API
    catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")

    # Build query for NAIP data. "limit" is only the page size requested from
    # the service; "max_items" is what caps the total number of results.
    search_params = {
        "collections": ["naip"],
        "intersects": geometry,
        "limit": max_items,
        "max_items": max_items,
    }

    # Add year filter if specified
    if year:
        search_params["query"] = {"naip:year": {"eq": year}}

    search_params.update(kwargs)

    # Search for NAIP imagery
    search_results = catalog.search(**search_params)
    items = list(search_results.items())

    if not items:
        print("No NAIP imagery found for the specified region and parameters.")
        return []

    print(f"Found {len(items)} NAIP items.")

    # Download and save each item
    downloaded_files = []
    for i, item in enumerate(items):
        # Sign the assets (required for Planetary Computer)
        signed_item = pc.sign(item)

        # Get the RGB asset URL
        rgb_asset = signed_item.assets.get("image")
        if not rgb_asset:
            print(f"No RGB asset found for item {i+1}")
            continue

        # Use the original filename from the asset, without the query
        # parameters that signing appends to the URL.
        original_filename = os.path.basename(rgb_asset.href.split("?")[0])
        output_path = os.path.join(output_dir, original_filename)

        print(f"Downloading item {i+1}/{len(items)}: {original_filename}")

        try:
            if rgb_asset.href.startswith("http"):
                # Direct file download with a progress bar.
                download_with_progress(rgb_asset.href, output_path)
            else:
                # Fallback to direct rioxarray opening (less common case)
                data = rioxarray.open_rasterio(rgb_asset.href)
                data.rio.to_raster(output_path)

            downloaded_files.append(output_path)
            print(f"Successfully saved to {output_path}")

            # Optionally display a preview of the file that was just saved.
            if preview:
                data = rioxarray.open_rasterio(output_path)
                preview_raster(data)

        except Exception as e:
            # Best effort: report the failure and carry on with the next item.
            print(f"Error downloading item {i+1}: {str(e)}")

    return downloaded_files
127
+
128
+
129
def download_with_progress(url: str, output_path: str) -> None:
    """Download a file with a progress bar.

    The response is streamed to ``output_path`` in 1 KiB blocks while a tqdm
    bar tracks the number of bytes written.

    Args:
        url: URL of the file to download.
        output_path: Path where the file will be saved.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (4xx/5xx). Nothing is written to ``output_path`` in that case.
    """
    block_size = 1024  # 1 Kibibyte

    # The context manager releases the connection when the download finishes
    # or fails.
    with requests.get(url, stream=True) as response:
        # Fail early instead of saving an error page as if it were the file.
        response.raise_for_status()

        # A missing content-length header yields a total of 0.
        total_size = int(response.headers.get("content-length", 0))

        with open(output_path, "wb") as file, tqdm(
            desc=os.path.basename(output_path),
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                size = file.write(data)
                bar.update(size)
153
+
154
+
155
def preview_raster(data: Any, title: Optional[str] = None) -> None:
    """Display a preview of the downloaded imagery.

    This function creates a visualization of the downloaded NAIP imagery
    by converting it to an RGB array and displaying it with matplotlib.

    Args:
        data: The raster data as a rioxarray object. It must support
            ``transpose("y", "x", "band")`` and expose ``.values``.
        title: The title for the preview plot. No title is set when None.
    """
    # Reorder to (y, x, band) and keep at most the first three bands.
    rgb_data = data.transpose("y", "x", "band").values[:, :, 0:3]

    # Convert to 8-bit for display. Clipping on both sides keeps negative
    # values from wrapping around in the unsigned cast.
    rgb_data = np.clip(rgb_data, 0, 255).astype(np.uint8)

    plt.figure(figsize=(10, 10))
    plt.imshow(rgb_data)
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.show()
175
+
176
+
177
+ # Helper function to convert NumPy types to native Python types for JSON serialization
178
def json_serializable(obj: Any) -> Any:
    """Convert NumPy types to native Python types for JSON serialization.

    Intended for use as the ``default=`` hook of ``json.dumps``.

    Args:
        obj: Any object to convert.

    Returns:
        ``bool`` for NumPy booleans, ``int`` for NumPy integers, ``float`` for
        NumPy floats and a (nested) list for NumPy arrays. Any other object is
        returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # the json module cannot encode it directly.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
195
+
196
+
197
def download_overture_buildings(
    bbox: Tuple[float, float, float, float],
    output_file: str,
    output_format: str = "geojson",
    data_type: str = "building",
    verbose: bool = True,
) -> Optional[str]:
    """Download building data from Overture Maps for a given bounding box using the overturemaps CLI tool.

    Args:
        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat) in WGS84 coordinates.
        output_file: Path to save the output file.
        output_format: Format to save the output, one of "geojson", "geojsonseq", or "geoparquet".
        data_type: The Overture Maps data type to download (building, place, etc.).
        verbose: Whether to print verbose output.

    Returns:
        Path to the output file, or None when the command finished without
        creating the file.

    Raises:
        RuntimeError: If the overturemaps command exits with a non-zero status.
        Exception: Any other error is logged and re-raised unchanged.
    """
    # Create output directory if needed
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Format the bounding box string for the command
    west, south, east, north = bbox
    bbox_str = f"{west},{south},{east},{north}"

    # Build the command as an argument list (no shell involved).
    cmd = [
        "overturemaps",
        "download",
        "--bbox",
        bbox_str,
        "-f",
        output_format,
        "--type",
        data_type,
        "--output",
        output_file,
    ]

    if verbose:
        logger.info("Running command: %s", " ".join(cmd))
        logger.info("Downloading %s data for area: %s", data_type, bbox_str)

    try:
        # In verbose mode stdout is inherited so the tool's own output is
        # shown; stderr is always captured for error reporting.
        result = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE if not verbose else None,
            stderr=subprocess.PIPE,
            text=True,
        )

        if not os.path.exists(output_file):
            logger.error("Command completed but file %s was not created", output_file)
            if result.stderr:
                logger.error("Command error output: %s", result.stderr)
            return None

        file_size = os.path.getsize(output_file) / (1024 * 1024)  # Size in MB
        logger.info(
            "Successfully downloaded data to %s (%.2f MB)", output_file, file_size
        )

        # Optionally show some stats about the downloaded data
        if output_format == "geojson" and os.path.getsize(output_file) > 0:
            try:
                gdf = gpd.read_file(output_file)
                logger.info("Downloaded %d features", len(gdf))

                if len(gdf) > 0 and verbose:
                    # Show a sample of the attribute names
                    attrs = [col for col in gdf.columns if col != "geometry"]
                    logger.info("Available attributes: %s...", ", ".join(attrs[:10]))
            except Exception as e:
                # The stats are informational only; the download succeeded.
                logger.warning("Could not read the GeoJSON file: %s", e)

        return output_file

    except subprocess.CalledProcessError as e:
        logger.error("Error running overturemaps command: %s", e)
        if e.stderr:
            logger.error("Command error output: %s", e.stderr)
        # Chain the original error so the traceback keeps the root cause.
        raise RuntimeError(f"Failed to download Overture Maps data: {str(e)}") from e
    except Exception as e:
        logger.error("Unexpected error: %s", e)
        raise
289
+
290
+
291
def convert_vector_format(
    input_file: str,
    output_format: str = "geojson",
    filter_expression: Optional[str] = None,
) -> str:
    """Convert the downloaded data to a different format or filter it.

    The output is written next to the input file, using the input path with
    its extension replaced by the one for ``output_format``. If that path is
    the same as ``input_file``, the input is overwritten.

    Args:
        input_file: Path to the input file.
        output_format: Format to convert to, one of "geojson", "parquet", "shapefile", "csv".
        filter_expression: Optional GeoDataFrame query expression to filter the data.

    Returns:
        Path to the converted file.

    Raises:
        ValueError: If ``output_format`` is not supported.
        Exception: Any read, filter or write error is logged and re-raised.
    """
    extensions = {
        "geojson": ".geojson",
        "parquet": ".parquet",
        "shapefile": ".shp",
        "csv": ".csv",
    }

    try:
        # Reject an unknown format before paying for the file read.
        if output_format not in extensions:
            raise ValueError(f"Unsupported output format: {output_format}")

        # Read the input file
        logger.info(f"Reading {input_file}")
        gdf = gpd.read_file(input_file)

        # Apply filter if specified
        if filter_expression:
            logger.info(f"Filtering data using expression: {filter_expression}")
            gdf = gdf.query(filter_expression)
            logger.info(f"After filtering: {len(gdf)} features")

        # Define output file path
        output_file = os.path.splitext(input_file)[0] + extensions[output_format]

        if output_format == "geojson":
            logger.info(f"Converting to GeoJSON: {output_file}")
            gdf.to_file(output_file, driver="GeoJSON")
        elif output_format == "parquet":
            logger.info(f"Converting to Parquet: {output_file}")
            gdf.to_parquet(output_file)
        elif output_format == "shapefile":
            logger.info(f"Converting to Shapefile: {output_file}")
            gdf.to_file(output_file)
        else:  # "csv"
            logger.info(f"Converting to CSV: {output_file}")

            # CSV has no geometry type, so the geometry is stored as WKT text.
            # assign() builds a new frame, leaving the (possibly filtered)
            # frame read above untouched.
            geometry_wkt = gdf.geometry.apply(lambda g: g.wkt)
            table = gdf.drop(columns=["geometry"]).assign(geometry_wkt=geometry_wkt)
            table.to_csv(output_file, index=False)

        return output_file

    except Exception as e:
        logger.error(f"Error converting data: {str(e)}")
        raise
349
+
350
+
351
def extract_building_stats(geojson_file: str) -> Dict[str, Any]:
    """Summarize the building data stored in a GeoJSON file.

    Args:
        geojson_file: Path to the GeoJSON file.

    Returns:
        Dictionary with the keys "total_buildings", "has_height", "has_name"
        and "bbox". If anything goes wrong, the error is logged and a
        dictionary with the single key "error" is returned instead.
    """
    try:
        buildings = gpd.read_file(geojson_file)
        columns = buildings.columns

        # Non-null counts for the optional attributes; 0 when the column is
        # absent from the file.
        height_count = 0
        if "height" in columns:
            height_count = int(buildings["height"].notna().sum())

        name_count = 0
        if "names.common.value" in columns:
            name_count = int(buildings["names.common.value"].notna().sum())

        # Overall bounds as plain Python floats rather than NumPy scalars.
        bounds = [float(value) for value in buildings.total_bounds.tolist()]

        return {
            "total_buildings": int(len(buildings)),
            "has_height": height_count,
            "has_name": name_count,
            "bbox": bounds,
        }

    except Exception as e:
        logger.error(f"Error extracting statistics: {str(e)}")
        return {"error": str(e)}
geoai/geoai.py CHANGED
@@ -1 +1,3 @@
1
1
  """Main module."""
2
+
3
+ from .common import viz_raster, viz_image, plot_batch, calc_stats
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: geoai-py
3
- Version: 0.1.5
3
+ Version: 0.1.7
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
7
- Project-URL: Homepage, https://github.com/giswqs/geoai
7
+ Project-URL: Homepage, https://github.com/opengeos/geoai
8
8
  Keywords: geoai
9
9
  Classifier: Intended Audience :: Developers
10
10
  Classifier: License :: OSI Approved :: MIT License
@@ -18,14 +18,21 @@ Requires-Python: >=3.9
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
20
  Requires-Dist: albumentations
21
+ Requires-Dist: jupyter-server-proxy
22
+ Requires-Dist: leafmap
23
+ Requires-Dist: localtileserver
21
24
  Requires-Dist: scikit-learn
22
25
  Requires-Dist: segment-geospatial
23
26
  Requires-Dist: torch
27
+ Requires-Dist: torchgeo
24
28
  Requires-Dist: transformers
29
+ Provides-Extra: download
30
+ Requires-Dist: pystac_client; extra == "download"
31
+ Requires-Dist: planetary_computer; extra == "download"
32
+ Requires-Dist: tqdm; extra == "download"
33
+ Requires-Dist: overturemaps; extra == "download"
25
34
  Provides-Extra: all
26
- Requires-Dist: geoai[extra]; extra == "all"
27
- Provides-Extra: extra
28
- Requires-Dist: pandas; extra == "extra"
35
+ Requires-Dist: geoai[download]; extra == "all"
29
36
 
30
37
  # geoai
31
38
 
@@ -0,0 +1,11 @@
1
+ geoai/__init__.py,sha256=pHMpntXI_R7c_vTJ6VbyS8iAx1uVyDP5HwYauOonZ3s,143
2
+ geoai/common.py,sha256=6h6mtUBO428P3IZppWyCVo04Ohzc3VhmnH0tvVh479g,9675
3
+ geoai/download.py,sha256=wzmLvrbrX9tTAkLYZInFnW-Yextr3Appcg0DPhQWHYU,12738
4
+ geoai/geoai.py,sha256=TmR7x1uL51G5oAjw0AQWnC5VQtLWDygyFLrDIj46xNc,86
5
+ geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
6
+ geoai_py-0.1.7.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
7
+ geoai_py-0.1.7.dist-info/METADATA,sha256=I3hYR-8dPERNmz_FsLLUeVVFXsPzAHK09HNZMRxss0A,1886
8
+ geoai_py-0.1.7.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
9
+ geoai_py-0.1.7.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
10
+ geoai_py-0.1.7.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
11
+ geoai_py-0.1.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (75.8.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py2-none-any
5
5
  Tag: py3-none-any
@@ -1,10 +0,0 @@
1
- geoai/__init__.py,sha256=xkyHrnU3iebQB2V0bl1Xd9s8YKI5-UTS-tMbjgnJXNY,120
2
- geoai/common.py,sha256=Rw6d9qmZDu3dUGTyJto1Y97S7-QA-m2p-pbCNvMDrm4,184
3
- geoai/geoai.py,sha256=h0hwdogXGFqerm-5ZPeT-irPn91pCcQRjiHThXsRzEk,19
4
- geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
5
- geoai_py-0.1.5.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
6
- geoai_py-0.1.5.dist-info/METADATA,sha256=oQrREZxyg5_OgaVK4Pkn9txkaqUYcUtvigo13MfFNo0,1609
7
- geoai_py-0.1.5.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
8
- geoai_py-0.1.5.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
9
- geoai_py-0.1.5.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
10
- geoai_py-0.1.5.dist-info/RECORD,,