geoai-py 0.1.5__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,4 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.1.5"
5
+ __version__ = "0.1.6"
6
+
7
+
8
+ from .geoai import *
geoai/common.py CHANGED
@@ -1,6 +1,279 @@
1
1
  """The common module contains common functions and classes used by the other modules."""
2
2
 
3
+ import os
4
+ from collections.abc import Iterable
5
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
6
+ import matplotlib.pyplot as plt
3
7
 
4
- def hello_world():
5
- """Prints "Hello World!" to the console."""
6
- print("Hello World!")
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader
11
+ from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
12
+ from torchgeo.samplers import RandomGeoSampler, Units
13
+ from torchgeo.transforms import indices
14
+
15
+
16
def viz_raster(
    source: str,
    indexes: Optional[int] = None,
    colormap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    nodata: Optional[float] = None,
    attribution: Optional[str] = None,
    layer_name: Optional[str] = "Raster",
    layer_index: Optional[int] = None,
    zoom_to_layer: Optional[bool] = True,
    visible: Optional[bool] = True,
    opacity: Optional[float] = 1.0,
    array_args: Optional[Dict] = None,
    client_args: Optional[Dict] = None,
    basemap: Optional[str] = "OpenStreetMap",
    **kwargs,
):
    """
    Visualize a raster using leafmap.

    Creates a new ``leafmap.Map`` with the requested basemap and adds the raster
    to it via ``Map.add_raster``, forwarding every argument unchanged.

    Args:
        source (str): The source of the raster.
        indexes (Optional[int], optional): The band indexes to visualize. Defaults to None.
        colormap (Optional[str], optional): The colormap to apply. Defaults to None.
        vmin (Optional[float], optional): The minimum value for colormap scaling. Defaults to None.
        vmax (Optional[float], optional): The maximum value for colormap scaling. Defaults to None.
        nodata (Optional[float], optional): The nodata value. Defaults to None.
        attribution (Optional[str], optional): The attribution for the raster. Defaults to None.
        layer_name (Optional[str], optional): The name of the layer. Defaults to "Raster".
        layer_index (Optional[int], optional): The index of the layer. Defaults to None.
        zoom_to_layer (Optional[bool], optional): Whether to zoom to the layer. Defaults to True.
        visible (Optional[bool], optional): Whether the layer is visible. Defaults to True.
        opacity (Optional[float], optional): The opacity of the layer. Defaults to 1.0.
        array_args (Optional[Dict], optional): Additional arguments for array processing.
            Defaults to None, which is treated as an empty dict.
        client_args (Optional[Dict], optional): Additional arguments for the client.
            Defaults to None, which is treated as {"cors_all": False}.
        basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
        **kwargs (Any): Additional keyword arguments passed to ``Map.add_raster``.

    Returns:
        leafmap.Map: The map object with the raster layer added.
    """
    # Imported lazily so that importing this module does not require leafmap.
    import leafmap

    # Build fresh dicts per call. A dict literal used as a default argument is
    # created once and shared by every call, so any mutation made downstream
    # would leak into later calls.
    if array_args is None:
        array_args = {}
    if client_args is None:
        client_args = {"cors_all": False}

    m = leafmap.Map(basemap=basemap)

    m.add_raster(
        source=source,
        indexes=indexes,
        colormap=colormap,
        vmin=vmin,
        vmax=vmax,
        nodata=nodata,
        attribution=attribution,
        layer_name=layer_name,
        layer_index=layer_index,
        zoom_to_layer=zoom_to_layer,
        visible=visible,
        opacity=opacity,
        array_args=array_args,
        client_args=client_args,
        **kwargs,
    )
    return m
80
+
81
+
82
def viz_image(
    image: Union[np.ndarray, torch.Tensor],
    transpose: bool = False,
    bdx: Optional[int] = None,
    scale_factor: float = 1.0,
    figsize: Tuple[int, int] = (10, 5),
    axis_off: bool = True,
    **kwargs: Any,
) -> None:
    """
    Visualize an image using matplotlib.

    Args:
        image (Union[np.ndarray, torch.Tensor]): The image to visualize. Tensors
            are detached from any autograd graph and moved to the CPU first.
        transpose (bool, optional): Whether to move the first axis to the end
            (``transpose(1, 2, 0)``), e.g. channels-first to channels-last.
            Defaults to False.
        bdx (Optional[int], optional): The band index (along the last axis,
            after any transpose) to visualize. Defaults to None.
        scale_factor (float, optional): The scale factor to apply to the image;
            when not 1.0 the scaled values are clipped to [0, 1]. Defaults to 1.0.
        figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
        axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
        **kwargs (Any): Additional keyword arguments for plt.imshow().

    Returns:
        None
    """
    if isinstance(image, torch.Tensor):
        # .numpy() raises for tensors that require grad (e.g. model outputs),
        # so detach from the autograd graph before converting.
        image = image.detach().cpu().numpy()

    if transpose:
        image = image.transpose(1, 2, 0)

    if bdx is not None:
        image = image[:, :, bdx]

    # Keep at most three channels so the array is displayed as RGB
    # (this also drops a fourth/alpha band).
    if len(image.shape) > 2 and image.shape[2] > 3:
        image = image[:, :, 0:3]

    if scale_factor != 1.0:
        image = np.clip(image * scale_factor, 0, 1)

    # Create the figure only after the array is prepared, so a shape error in
    # the steps above does not leave an orphaned open figure behind.
    plt.figure(figsize=figsize)
    plt.imshow(image, **kwargs)
    if axis_off:
        plt.axis("off")
    plt.show()
    plt.close()
129
+
130
+
131
def plot_images(
    images: Iterable[torch.Tensor],
    axs: Iterable[plt.Axes],
    chnls: Optional[List[int]] = None,
    bright: float = 1.0,
) -> None:
    """
    Plot a list of images.

    Each image is scaled by ``bright``, clamped to [0, 1], transposed with
    ``(1, 2, 0)`` (so channels-first tensors are presumably expected) and the
    ``chnls`` channels are selected as RGB.

    Args:
        images (Iterable[torch.Tensor]): The images to plot.
        axs (Iterable[plt.Axes]): The axes to plot the images on. Images and
            axes are paired with ``zip``, so extra items on either side are ignored.
        chnls (Optional[List[int]], optional): The channels to use for RGB.
            Defaults to None, which is treated as [2, 1, 0].
        bright (float, optional): The brightness factor. Defaults to 1.0.

    Returns:
        None
    """
    # list() so a tuple is used as a fancy index over the channel axis rather
    # than as a multi-dimensional index.
    channels = [2, 1, 0] if chnls is None else list(chnls)

    for img, ax in zip(images, axs):
        # detach()/cpu() so tensors on a GPU or attached to an autograd graph
        # can be converted to NumPy.
        arr = torch.clamp(bright * img.detach().cpu(), min=0, max=1).numpy()
        rgb = arr.transpose(1, 2, 0)[:, :, channels]
        ax.imshow(rgb)
        ax.axis("off")
154
+
155
+
156
def plot_masks(
    masks: Iterable[torch.Tensor], axs: Iterable[plt.Axes], cmap: str = "Blues"
) -> None:
    """
    Plot a list of masks.

    Each mask has its size-1 dimensions squeezed out before being drawn with
    ``imshow`` using the given colormap.

    Args:
        masks (Iterable[torch.Tensor]): The masks to plot.
        axs (Iterable[plt.Axes]): The axes to plot the masks on. Masks and axes
            are paired with ``zip``, so extra items on either side are ignored.
        cmap (str, optional): The colormap to use. Defaults to "Blues".

    Returns:
        None
    """
    for mask, ax in zip(masks, axs):
        # detach()/cpu() so masks on a GPU or attached to an autograd graph
        # can be converted to NumPy.
        ax.imshow(mask.squeeze().detach().cpu().numpy(), cmap=cmap)
        ax.axis("off")
173
+
174
+
175
def plot_batch(
    batch: Dict[str, Any],
    bright: float = 1.0,
    cols: int = 4,
    width: int = 5,
    chnls: Optional[List[int]] = None,
    cmap: str = "Blues",
) -> None:
    """
    Plot a batch of images and masks. This function is adapted from the plot_batch()
    function in the torchgeo library at
    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html
    Credit to the torchgeo developers for the original implementation.

    When the batch contains both "image" and "mask", images are drawn on the
    even grid cells and the matching masks on the odd cells. Grid cells beyond
    the last plotted item are hidden. An empty batch draws nothing.

    Args:
        batch (Dict[str, Any]): The batch containing images and masks.
        bright (float, optional): The brightness factor. Defaults to 1.0.
        cols (int, optional): The number of columns in the plot grid. Defaults to 4.
        width (int, optional): The width of each plot. Defaults to 5.
        chnls (Optional[List[int]], optional): The channels to use for RGB.
            Defaults to None, which is treated as [2, 1, 0].
        cmap (str, optional): The colormap to use for masks. Defaults to "Blues".

    Returns:
        None
    """
    if chnls is None:
        chnls = [2, 1, 0]

    has_image = "image" in batch
    has_mask = "mask" in batch

    # Pass a shallow copy so the caller's dict is left untouched in case
    # unbind_samples modifies its argument.
    samples = unbind_samples(batch.copy())

    # Each sample occupies two cells (image + mask) when both are present.
    n = 2 * len(samples) if (has_image and has_mask) else len(samples)
    if n == 0:
        # plt.subplots rejects a grid with zero rows; nothing to draw anyway.
        return

    rows = (n + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes; without it a 1x1 grid
    # yields a bare Axes object that has no .reshape().
    _, axs = plt.subplots(
        rows, cols, figsize=(cols * width, rows * width), squeeze=False
    )
    flat_axs = axs.reshape(-1)

    if has_image and has_mask:
        # images on the even cells, masks on the odd cells
        plot_images(
            images=[sample["image"] for sample in samples],
            axs=flat_axs[::2],
            chnls=chnls,
            bright=bright,
        )
        plot_masks(
            masks=[sample["mask"] for sample in samples],
            axs=flat_axs[1::2],
            cmap=cmap,
        )
    elif has_image:
        plot_images(
            images=[sample["image"] for sample in samples],
            axs=flat_axs,
            chnls=chnls,
            bright=bright,
        )
    elif has_mask:
        plot_masks(
            masks=[sample["mask"] for sample in samples],
            axs=flat_axs,
            cmap=cmap,
        )

    # Hide the trailing cells of the last row that received no plot.
    for ax in flat_axs[n:]:
        ax.axis("off")
237
+
238
+
239
def calc_stats(
    dataset: RasterDataset, divide_by: float = 1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate the statistics (mean and std) for the entire dataset.

    This function is adapted from the dataset-statistics helper in the torchgeo
    Earth surface water tutorial at
    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html.
    Credit to the torchgeo developers for the original implementation.

    Warning: This is an approximation. Each file's per-band mean and std are
    computed separately and then averaged with equal weight per file, so files
    of different sizes contribute equally. The std is also computed around each
    file's own mean instead of the dataset-wide mean.

    Args:
        dataset (RasterDataset): The dataset to calculate statistics for.
        divide_by (float, optional): The value to divide the image data by. Defaults to 1.0.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.

    Raises:
        ValueError: If the dataset's index yields no files.
    """
    import rasterio as rio

    # To avoid loading the entire dataset in memory, loop through one file at a
    # time. The filenames are retrieved from the dataset's spatial index by
    # querying it with its own full bounds.
    files = [
        item.object
        for item in dataset.index.intersection(dataset.index.bounds, objects=True)
    ]
    if not files:
        raise ValueError(
            "Cannot compute statistics: the dataset index contains no files."
        )

    accum_mean = 0
    accum_std = 0

    for file in files:
        # The context manager closes each file handle once it has been read.
        with rio.open(file) as src:
            img = src.read() / divide_by  # type: ignore
        # Flatten every axis after the first (band) axis.
        bands = img.reshape((img.shape[0], -1))
        accum_mean += bands.mean(axis=1)
        accum_std += bands.std(axis=1)

    # Both accumulators now hold one value per band. Average them over the
    # number of files.
    return accum_mean / len(files), accum_std / len(files)
geoai/geoai.py CHANGED
@@ -1 +1,3 @@
1
1
  """Main module."""
2
+
3
+ from .common import viz_raster, viz_image, plot_batch, calc_stats
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: geoai-py
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
7
- Project-URL: Homepage, https://github.com/giswqs/geoai
7
+ Project-URL: Homepage, https://github.com/opengeos/geoai
8
8
  Keywords: geoai
9
9
  Classifier: Intended Audience :: Developers
10
10
  Classifier: License :: OSI Approved :: MIT License
@@ -18,14 +18,14 @@ Requires-Python: >=3.9
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
20
  Requires-Dist: albumentations
21
+ Requires-Dist: jupyter-server-proxy
22
+ Requires-Dist: leafmap
23
+ Requires-Dist: localtileserver
21
24
  Requires-Dist: scikit-learn
22
25
  Requires-Dist: segment-geospatial
23
26
  Requires-Dist: torch
27
+ Requires-Dist: torchgeo
24
28
  Requires-Dist: transformers
25
- Provides-Extra: all
26
- Requires-Dist: geoai[extra]; extra == "all"
27
- Provides-Extra: extra
28
- Requires-Dist: pandas; extra == "extra"
29
29
 
30
30
  # geoai
31
31
 
@@ -0,0 +1,10 @@
1
+ geoai/__init__.py,sha256=FBqWQ-gHNvzP7beS6Vdp5upQnAVuYHa4CZLTvgpXfqA,143
2
+ geoai/common.py,sha256=6h6mtUBO428P3IZppWyCVo04Ohzc3VhmnH0tvVh479g,9675
3
+ geoai/geoai.py,sha256=TmR7x1uL51G5oAjw0AQWnC5VQtLWDygyFLrDIj46xNc,86
4
+ geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
5
+ geoai_py-0.1.6.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
6
+ geoai_py-0.1.6.dist-info/METADATA,sha256=ByO92YNy910hBcIhcHtm4_6A1fi7WEElrzldHHv-YmU,1599
7
+ geoai_py-0.1.6.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
8
+ geoai_py-0.1.6.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
9
+ geoai_py-0.1.6.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
10
+ geoai_py-0.1.6.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- geoai/__init__.py,sha256=xkyHrnU3iebQB2V0bl1Xd9s8YKI5-UTS-tMbjgnJXNY,120
2
- geoai/common.py,sha256=Rw6d9qmZDu3dUGTyJto1Y97S7-QA-m2p-pbCNvMDrm4,184
3
- geoai/geoai.py,sha256=h0hwdogXGFqerm-5ZPeT-irPn91pCcQRjiHThXsRzEk,19
4
- geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
5
- geoai_py-0.1.5.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
6
- geoai_py-0.1.5.dist-info/METADATA,sha256=oQrREZxyg5_OgaVK4Pkn9txkaqUYcUtvigo13MfFNo0,1609
7
- geoai_py-0.1.5.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
8
- geoai_py-0.1.5.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
9
- geoai_py-0.1.5.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
10
- geoai_py-0.1.5.dist-info/RECORD,,