geoai-py 0.2.2__py2.py3-none-any.whl → 0.3.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: geoai-py
3
- Version: 0.2.2
3
+ Version: 0.3.0
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -18,11 +18,13 @@ Requires-Python: >=3.9
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
20
  Requires-Dist: albumentations
21
+ Requires-Dist: contextily
21
22
  Requires-Dist: geopandas
22
23
  Requires-Dist: huggingface_hub
23
24
  Requires-Dist: jupyter-server-proxy
24
25
  Requires-Dist: leafmap
25
26
  Requires-Dist: localtileserver
27
+ Requires-Dist: mapclassify
26
28
  Requires-Dist: overturemaps
27
29
  Requires-Dist: planetary-computer
28
30
  Requires-Dist: pystac-client
@@ -0,0 +1,13 @@
1
+ geoai/__init__.py,sha256=rJod2PDa1AiRHE8ugVp0Bfiky7ZWBhqbh2kZ45WiggA,923
2
+ geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
3
+ geoai/extract.py,sha256=9oLbrSg_aHcimpYxfk0jLOIHQWVULRsdiAGUsPLC-qk,71708
4
+ geoai/geoai.py,sha256=wNwKIqwOT10tU4uiWTcNp5Gd598rRFMANIfJsGdOWKM,90
5
+ geoai/preprocess.py,sha256=teV-W7ykXnoru0Y_d0V9ANdO6jMyETeGbqr1_8H-Yh0,118523
6
+ geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
7
+ geoai/utils.py,sha256=3vXFDdFqZeg4kgeNt6-Hp28RfNoQcDOH7BjrlJ6L0UE,37521
8
+ geoai_py-0.3.0.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
+ geoai_py-0.3.0.dist-info/METADATA,sha256=L62RHKj0Yqno8LDYVrL50YyMfO1ybRYs2NI15WHiJMQ,5754
10
+ geoai_py-0.3.0.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
+ geoai_py-0.3.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
+ geoai_py-0.3.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
+ geoai_py-0.3.0.dist-info/RECORD,,
geoai/common.py DELETED
@@ -1,438 +0,0 @@
1
- """The common module contains common functions and classes used by the other modules."""
2
-
3
- import os
4
- from collections.abc import Iterable
5
- from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
6
- import matplotlib.pyplot as plt
7
-
8
- import leafmap
9
- import torch
10
- import numpy as np
11
- import xarray as xr
12
- import rioxarray
13
- import rasterio as rio
14
- from torch.utils.data import DataLoader
15
- from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
16
- from torchgeo.samplers import RandomGeoSampler, Units
17
- from torchgeo.transforms import indices
18
-
19
-
20
- def viz_raster(
21
- source: str,
22
- indexes: Optional[int] = None,
23
- colormap: Optional[str] = None,
24
- vmin: Optional[float] = None,
25
- vmax: Optional[float] = None,
26
- nodata: Optional[float] = None,
27
- attribution: Optional[str] = None,
28
- layer_name: Optional[str] = "Raster",
29
- layer_index: Optional[int] = None,
30
- zoom_to_layer: Optional[bool] = True,
31
- visible: Optional[bool] = True,
32
- opacity: Optional[float] = 1.0,
33
- array_args: Optional[Dict] = {},
34
- client_args: Optional[Dict] = {"cors_all": False},
35
- basemap: Optional[str] = "OpenStreetMap",
36
- **kwargs,
37
- ):
38
- """
39
- Visualize a raster using leafmap.
40
-
41
- Args:
42
- source (str): The source of the raster.
43
- indexes (Optional[int], optional): The band indexes to visualize. Defaults to None.
44
- colormap (Optional[str], optional): The colormap to apply. Defaults to None.
45
- vmin (Optional[float], optional): The minimum value for colormap scaling. Defaults to None.
46
- vmax (Optional[float], optional): The maximum value for colormap scaling. Defaults to None.
47
- nodata (Optional[float], optional): The nodata value. Defaults to None.
48
- attribution (Optional[str], optional): The attribution for the raster. Defaults to None.
49
- layer_name (Optional[str], optional): The name of the layer. Defaults to "Raster".
50
- layer_index (Optional[int], optional): The index of the layer. Defaults to None.
51
- zoom_to_layer (Optional[bool], optional): Whether to zoom to the layer. Defaults to True.
52
- visible (Optional[bool], optional): Whether the layer is visible. Defaults to True.
53
- opacity (Optional[float], optional): The opacity of the layer. Defaults to 1.0.
54
- array_args (Optional[Dict], optional): Additional arguments for array processing. Defaults to {}.
55
- client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
56
- basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
57
- **kwargs (Any): Additional keyword arguments.
58
-
59
- Returns:
60
- leafmap.Map: The map object with the raster layer added.
61
- """
62
-
63
- m = leafmap.Map(basemap=basemap)
64
-
65
- if isinstance(source, dict):
66
- source = dict_to_image(source)
67
-
68
- m.add_raster(
69
- source=source,
70
- indexes=indexes,
71
- colormap=colormap,
72
- vmin=vmin,
73
- vmax=vmax,
74
- nodata=nodata,
75
- attribution=attribution,
76
- layer_name=layer_name,
77
- layer_index=layer_index,
78
- zoom_to_layer=zoom_to_layer,
79
- visible=visible,
80
- opacity=opacity,
81
- array_args=array_args,
82
- client_args=client_args,
83
- **kwargs,
84
- )
85
- return m
86
-
87
-
88
- def viz_image(
89
- image: Union[np.ndarray, torch.Tensor],
90
- transpose: bool = False,
91
- bdx: Optional[int] = None,
92
- scale_factor: float = 1.0,
93
- figsize: Tuple[int, int] = (10, 5),
94
- axis_off: bool = True,
95
- title: Optional[str] = None,
96
- **kwargs: Any,
97
- ) -> None:
98
- """
99
- Visualize an image using matplotlib.
100
-
101
- Args:
102
- image (Union[np.ndarray, torch.Tensor]): The image to visualize.
103
- transpose (bool, optional): Whether to transpose the image. Defaults to False.
104
- bdx (Optional[int], optional): The band index to visualize. Defaults to None.
105
- scale_factor (float, optional): The scale factor to apply to the image. Defaults to 1.0.
106
- figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
107
- axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
108
- title (Optional[str], optional): The title of the plot. Defaults to None.
109
- **kwargs (Any): Additional keyword arguments for plt.imshow().
110
-
111
- Returns:
112
- None
113
- """
114
-
115
- if isinstance(image, torch.Tensor):
116
- image = image.cpu().numpy()
117
- elif isinstance(image, str):
118
- image = rio.open(image).read().transpose(1, 2, 0)
119
-
120
- plt.figure(figsize=figsize)
121
-
122
- if transpose:
123
- image = image.transpose(1, 2, 0)
124
-
125
- if bdx is not None:
126
- image = image[:, :, bdx]
127
-
128
- if len(image.shape) > 2 and image.shape[2] > 3:
129
- image = image[:, :, 0:3]
130
-
131
- if scale_factor != 1.0:
132
- image = np.clip(image * scale_factor, 0, 1)
133
-
134
- plt.imshow(image, **kwargs)
135
- if axis_off:
136
- plt.axis("off")
137
- if title is not None:
138
- plt.title(title)
139
- plt.show()
140
- plt.close()
141
-
142
-
143
- def plot_images(
144
- images: Iterable[torch.Tensor],
145
- axs: Iterable[plt.Axes],
146
- chnls: List[int] = [2, 1, 0],
147
- bright: float = 1.0,
148
- ) -> None:
149
- """
150
- Plot a list of images.
151
-
152
- Args:
153
- images (Iterable[torch.Tensor]): The images to plot.
154
- axs (Iterable[plt.Axes]): The axes to plot the images on.
155
- chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
156
- bright (float, optional): The brightness factor. Defaults to 1.0.
157
-
158
- Returns:
159
- None
160
- """
161
- for img, ax in zip(images, axs):
162
- arr = torch.clamp(bright * img, min=0, max=1).numpy()
163
- rgb = arr.transpose(1, 2, 0)[:, :, chnls]
164
- ax.imshow(rgb)
165
- ax.axis("off")
166
-
167
-
168
- def plot_masks(
169
- masks: Iterable[torch.Tensor], axs: Iterable[plt.Axes], cmap: str = "Blues"
170
- ) -> None:
171
- """
172
- Plot a list of masks.
173
-
174
- Args:
175
- masks (Iterable[torch.Tensor]): The masks to plot.
176
- axs (Iterable[plt.Axes]): The axes to plot the masks on.
177
- cmap (str, optional): The colormap to use. Defaults to "Blues".
178
-
179
- Returns:
180
- None
181
- """
182
- for mask, ax in zip(masks, axs):
183
- ax.imshow(mask.squeeze().numpy(), cmap=cmap)
184
- ax.axis("off")
185
-
186
-
187
- def plot_batch(
188
- batch: Dict[str, Any],
189
- bright: float = 1.0,
190
- cols: int = 4,
191
- width: int = 5,
192
- chnls: List[int] = [2, 1, 0],
193
- cmap: str = "Blues",
194
- ) -> None:
195
- """
196
- Plot a batch of images and masks. This function is adapted from the plot_batch()
197
- function in the torchgeo library at
198
- https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html
199
- Credit to the torchgeo developers for the original implementation.
200
-
201
- Args:
202
- batch (Dict[str, Any]): The batch containing images and masks.
203
- bright (float, optional): The brightness factor. Defaults to 1.0.
204
- cols (int, optional): The number of columns in the plot grid. Defaults to 4.
205
- width (int, optional): The width of each plot. Defaults to 5.
206
- chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
207
- cmap (str, optional): The colormap to use for masks. Defaults to "Blues".
208
-
209
- Returns:
210
- None
211
- """
212
- # Get the samples and the number of items in the batch
213
- samples = unbind_samples(batch.copy())
214
-
215
- # if batch contains images and masks, the number of images will be doubled
216
- n = 2 * len(samples) if ("image" in batch) and ("mask" in batch) else len(samples)
217
-
218
- # calculate the number of rows in the grid
219
- rows = n // cols + (1 if n % cols != 0 else 0)
220
-
221
- # create a grid
222
- _, axs = plt.subplots(rows, cols, figsize=(cols * width, rows * width))
223
-
224
- if ("image" in batch) and ("mask" in batch):
225
- # plot the images on the even axis
226
- plot_images(
227
- images=map(lambda x: x["image"], samples),
228
- axs=axs.reshape(-1)[::2],
229
- chnls=chnls,
230
- bright=bright,
231
- )
232
-
233
- # plot the masks on the odd axis
234
- plot_masks(masks=map(lambda x: x["mask"], samples), axs=axs.reshape(-1)[1::2])
235
-
236
- else:
237
- if "image" in batch:
238
- plot_images(
239
- images=map(lambda x: x["image"], samples),
240
- axs=axs.reshape(-1),
241
- chnls=chnls,
242
- bright=bright,
243
- )
244
-
245
- elif "mask" in batch:
246
- plot_masks(
247
- masks=map(lambda x: x["mask"], samples), axs=axs.reshape(-1), cmap=cmap
248
- )
249
-
250
-
251
- def calc_stats(
252
- dataset: RasterDataset, divide_by: float = 1.0
253
- ) -> Tuple[np.ndarray, np.ndarray]:
254
- """
255
- Calculate the statistics (mean and std) for the entire dataset.
256
-
257
- This function is adapted from the plot_batch() function in the torchgeo library at
258
- https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html.
259
- Credit to the torchgeo developers for the original implementation.
260
-
261
- Warning: This is an approximation. The correct value should take into account the
262
- mean for the whole dataset for computing individual stds.
263
-
264
- Args:
265
- dataset (RasterDataset): The dataset to calculate statistics for.
266
- divide_by (float, optional): The value to divide the image data by. Defaults to 1.0.
267
-
268
- Returns:
269
- Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.
270
- """
271
- import rasterio as rio
272
-
273
- # To avoid loading the entire dataset in memory, we will loop through each img
274
- # The filenames will be retrieved from the dataset's rtree index
275
- files = [
276
- item.object
277
- for item in dataset.index.intersection(dataset.index.bounds, objects=True)
278
- ]
279
-
280
- # Resetting statistics
281
- accum_mean = 0
282
- accum_std = 0
283
-
284
- for file in files:
285
- img = rio.open(file).read() / divide_by # type: ignore
286
- accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)
287
- accum_std += img.reshape((img.shape[0], -1)).std(axis=1)
288
-
289
- # at the end, we shall have 2 vectors with length n=chnls
290
- # we will average them considering the number of images
291
- return accum_mean / len(files), accum_std / len(files)
292
-
293
-
294
- def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
295
- """Convert a dictionary to a xarray DataArray. The dictionary should contain the
296
- following keys: "crs", "bounds", and "image". It can be generated from a TorchGeo
297
- dataset sampler.
298
-
299
- Args:
300
- data_dict (Dict): The dictionary containing the data.
301
-
302
- Returns:
303
- xr.DataArray: The xarray DataArray.
304
- """
305
-
306
- from affine import Affine
307
-
308
- # Extract components from the dictionary
309
- crs = data_dict["crs"]
310
- bounds = data_dict["bounds"]
311
- image_tensor = data_dict["image"]
312
-
313
- # Convert tensor to numpy array if needed
314
- if hasattr(image_tensor, "numpy"):
315
- # For PyTorch tensors
316
- image_array = image_tensor.numpy()
317
- else:
318
- # If it's already a numpy array or similar
319
- image_array = np.array(image_tensor)
320
-
321
- # Calculate pixel resolution
322
- width = image_array.shape[2] # Width is the size of the last dimension
323
- height = image_array.shape[1] # Height is the size of the middle dimension
324
-
325
- res_x = (bounds.maxx - bounds.minx) / width
326
- res_y = (bounds.maxy - bounds.miny) / height
327
-
328
- # Create the transform matrix
329
- transform = Affine(res_x, 0.0, bounds.minx, 0.0, -res_y, bounds.maxy)
330
-
331
- # Create dimensions
332
- x_coords = np.linspace(bounds.minx + res_x / 2, bounds.maxx - res_x / 2, width)
333
- y_coords = np.linspace(bounds.maxy - res_y / 2, bounds.miny + res_y / 2, height)
334
-
335
- # If time dimension exists in the bounds
336
- if hasattr(bounds, "mint") and hasattr(bounds, "maxt"):
337
- # Create a single time value or range if needed
338
- t_coords = [
339
- bounds.mint
340
- ] # Or np.linspace(bounds.mint, bounds.maxt, num_time_steps)
341
-
342
- # Create DataArray with time dimension
343
- dims = (
344
- ("band", "y", "x")
345
- if image_array.shape[0] <= 10
346
- else ("time", "band", "y", "x")
347
- )
348
-
349
- if dims[0] == "band":
350
- # For multi-band single time
351
- da = xr.DataArray(
352
- image_array,
353
- dims=dims,
354
- coords={
355
- "band": np.arange(1, image_array.shape[0] + 1),
356
- "y": y_coords,
357
- "x": x_coords,
358
- },
359
- )
360
- else:
361
- # For multi-time multi-band
362
- da = xr.DataArray(
363
- image_array,
364
- dims=dims,
365
- coords={
366
- "time": t_coords,
367
- "band": np.arange(1, image_array.shape[1] + 1),
368
- "y": y_coords,
369
- "x": x_coords,
370
- },
371
- )
372
- else:
373
- # Create DataArray without time dimension
374
- da = xr.DataArray(
375
- image_array,
376
- dims=("band", "y", "x"),
377
- coords={
378
- "band": np.arange(1, image_array.shape[0] + 1),
379
- "y": y_coords,
380
- "x": x_coords,
381
- },
382
- )
383
-
384
- # Set spatial attributes
385
- da.rio.write_crs(crs, inplace=True)
386
- da.rio.write_transform(transform, inplace=True)
387
-
388
- return da
389
-
390
-
391
- def dict_to_image(
392
- data_dict: Dict[str, Any], output: Optional[str] = None, **kwargs
393
- ) -> rio.DatasetReader:
394
- """Convert a dictionary containing spatial data to a rasterio dataset or save it to
395
- a file. The dictionary should contain the following keys: "crs", "bounds", and "image".
396
- It can be generated from a TorchGeo dataset sampler.
397
-
398
- This function transforms a dictionary with CRS, bounding box, and image data
399
- into a rasterio DatasetReader using leafmap's array_to_image utility after
400
- first converting to a rioxarray DataArray.
401
-
402
- Args:
403
- data_dict: A dictionary containing:
404
- - 'crs': A pyproj CRS object
405
- - 'bounds': A BoundingBox object with minx, maxx, miny, maxy attributes
406
- and optionally mint, maxt for temporal bounds
407
- - 'image': A tensor or array-like object with image data
408
- output: Optional path to save the image to a file. If not provided, the image
409
- will be returned as a rasterio DatasetReader object.
410
- **kwargs: Additional keyword arguments to pass to leafmap.array_to_image.
411
- Common options include:
412
- - colormap: str, name of the colormap (e.g., 'viridis', 'terrain')
413
- - vmin: float, minimum value for colormap scaling
414
- - vmax: float, maximum value for colormap scaling
415
-
416
- Returns:
417
- A rasterio DatasetReader object that can be used for visualization or
418
- further processing.
419
-
420
- Examples:
421
- >>> image = dict_to_image(
422
- ... {'crs': CRS.from_epsg(26911), 'bounds': bbox, 'image': tensor},
423
- ... colormap='terrain'
424
- ... )
425
- >>> fig, ax = plt.subplots(figsize=(10, 10))
426
- >>> show(image, ax=ax)
427
- """
428
- da = dict_to_rioxarray(data_dict)
429
-
430
- if output is not None:
431
- out_dir = os.path.abspath(os.path.dirname(output))
432
- if not os.path.exists(out_dir):
433
- os.makedirs(out_dir, exist_ok=True)
434
- da.rio.to_raster(output)
435
- return output
436
- else:
437
- image = leafmap.array_to_image(da, **kwargs)
438
- return image
@@ -1,13 +0,0 @@
1
- geoai/__init__.py,sha256=yEbFyHPNijxgK-75tatrRELZ9TUdZVYo2uPlxCeBFDA,923
2
- geoai/common.py,sha256=NdfkQKMPHkwr0B5sDpH5Q_7Nt2AmYt9Gw-KE88NsQ5s,15222
3
- geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
4
- geoai/extract.py,sha256=Fh29d5Fj60YiqhMs62lzkd9T_ONTp2UZ4j98We769sg,31563
5
- geoai/geoai.py,sha256=BCEtHil0P5cettJdMIhblg1pRaV-vHNQFaYmBrtYP3g,68
6
- geoai/preprocess.py,sha256=pYtf3-eZY76SKd17MvEZ1qNUvblYW5kzQLvZ-ZM4Wwg,106833
7
- geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
8
- geoai_py-0.2.2.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
9
- geoai_py-0.2.2.dist-info/METADATA,sha256=baREpHpvCvfktqiMSWNI-FGOVme8NAj0UkaJhS6Bkm4,5701
10
- geoai_py-0.2.2.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
11
- geoai_py-0.2.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
12
- geoai_py-0.2.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
13
- geoai_py-0.2.2.dist-info/RECORD,,