geoai-py 0.1.0__py2.py3-none-any.whl → 0.1.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +5 -2
- geoai/common.py +277 -5
- geoai/geoai.py +2 -0
- geoai/segmentation.py +349 -0
- {geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/METADATA +17 -10
- geoai_py-0.1.6.dist-info/RECORD +10 -0
- {geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/WHEEL +1 -1
- geoai_py-0.1.6.dist-info/entry_points.txt +2 -0
- geoai_py-0.1.0.dist-info/RECORD +0 -9
- geoai_py-0.1.0.dist-info/dependency_links.txt +0 -1
- {geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/LICENSE +0 -0
- {geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/common.py
CHANGED
@@ -1,7 +1,279 @@
-"""The common module contains common functions and classes used by the other modules.
-"""
+"""The common module contains common functions and classes used by the other modules."""
 
-
-
+import os
+from collections.abc import Iterable
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
+import matplotlib.pyplot as plt
+
+import torch
+import numpy as np
+from torch.utils.data import DataLoader
+from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
+from torchgeo.samplers import RandomGeoSampler, Units
+from torchgeo.transforms import indices
+
+
+def viz_raster(
+    source: str,
+    indexes: Optional[int] = None,
+    colormap: Optional[str] = None,
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
+    nodata: Optional[float] = None,
+    attribution: Optional[str] = None,
+    layer_name: Optional[str] = "Raster",
+    layer_index: Optional[int] = None,
+    zoom_to_layer: Optional[bool] = True,
+    visible: Optional[bool] = True,
+    opacity: Optional[float] = 1.0,
+    array_args: Optional[Dict] = {},
+    client_args: Optional[Dict] = {"cors_all": False},
+    basemap: Optional[str] = "OpenStreetMap",
+    **kwargs,
+):
+    """
+    Visualize a raster using leafmap.
+
+    Args:
+        source (str): The source of the raster.
+        indexes (Optional[int], optional): The band indexes to visualize. Defaults to None.
+        colormap (Optional[str], optional): The colormap to apply. Defaults to None.
+        vmin (Optional[float], optional): The minimum value for colormap scaling. Defaults to None.
+        vmax (Optional[float], optional): The maximum value for colormap scaling. Defaults to None.
+        nodata (Optional[float], optional): The nodata value. Defaults to None.
+        attribution (Optional[str], optional): The attribution for the raster. Defaults to None.
+        layer_name (Optional[str], optional): The name of the layer. Defaults to "Raster".
+        layer_index (Optional[int], optional): The index of the layer. Defaults to None.
+        zoom_to_layer (Optional[bool], optional): Whether to zoom to the layer. Defaults to True.
+        visible (Optional[bool], optional): Whether the layer is visible. Defaults to True.
+        opacity (Optional[float], optional): The opacity of the layer. Defaults to 1.0.
+        array_args (Optional[Dict], optional): Additional arguments for array processing. Defaults to {}.
+        client_args (Optional[Dict], optional): Additional arguments for the client. Defaults to {"cors_all": False}.
+        basemap (Optional[str], optional): The basemap to use. Defaults to "OpenStreetMap".
+        **kwargs (Any): Additional keyword arguments.
+
+    Returns:
+        leafmap.Map: The map object with the raster layer added.
+    """
+    import leafmap
+
+    m = leafmap.Map(basemap=basemap)
+
+    m.add_raster(
+        source=source,
+        indexes=indexes,
+        colormap=colormap,
+        vmin=vmin,
+        vmax=vmax,
+        nodata=nodata,
+        attribution=attribution,
+        layer_name=layer_name,
+        layer_index=layer_index,
+        zoom_to_layer=zoom_to_layer,
+        visible=visible,
+        opacity=opacity,
+        array_args=array_args,
+        client_args=client_args,
+        **kwargs,
+    )
+    return m
+
+
+def viz_image(
+    image: Union[np.ndarray, torch.Tensor],
+    transpose: bool = False,
+    bdx: Optional[int] = None,
+    scale_factor: float = 1.0,
+    figsize: Tuple[int, int] = (10, 5),
+    axis_off: bool = True,
+    **kwargs: Any,
+) -> None:
+    """
+    Visualize an image using matplotlib.
+
+    Args:
+        image (Union[np.ndarray, torch.Tensor]): The image to visualize.
+        transpose (bool, optional): Whether to transpose the image. Defaults to False.
+        bdx (Optional[int], optional): The band index to visualize. Defaults to None.
+        scale_factor (float, optional): The scale factor to apply to the image. Defaults to 1.0.
+        figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
+        axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
+        **kwargs (Any): Additional keyword arguments for plt.imshow().
+
+    Returns:
+        None
+    """
+
+    if isinstance(image, torch.Tensor):
+        image = image.cpu().numpy()
+
+    plt.figure(figsize=figsize)
+
+    if transpose:
+        image = image.transpose(1, 2, 0)
+
+    if bdx is not None:
+        image = image[:, :, bdx]
+
+    if len(image.shape) > 2 and image.shape[2] > 3:
+        image = image[:, :, 0:3]
+
+    if scale_factor != 1.0:
+        image = np.clip(image * scale_factor, 0, 1)
+
+    plt.imshow(image, **kwargs)
+    if axis_off:
+        plt.axis("off")
+    plt.show()
+    plt.close()
+
+
+def plot_images(
+    images: Iterable[torch.Tensor],
+    axs: Iterable[plt.Axes],
+    chnls: List[int] = [2, 1, 0],
+    bright: float = 1.0,
+) -> None:
+    """
+    Plot a list of images.
+
+    Args:
+        images (Iterable[torch.Tensor]): The images to plot.
+        axs (Iterable[plt.Axes]): The axes to plot the images on.
+        chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
+        bright (float, optional): The brightness factor. Defaults to 1.0.
+
+    Returns:
+        None
+    """
+    for img, ax in zip(images, axs):
+        arr = torch.clamp(bright * img, min=0, max=1).numpy()
+        rgb = arr.transpose(1, 2, 0)[:, :, chnls]
+        ax.imshow(rgb)
+        ax.axis("off")
+
+
+def plot_masks(
+    masks: Iterable[torch.Tensor], axs: Iterable[plt.Axes], cmap: str = "Blues"
+) -> None:
+    """
+    Plot a list of masks.
+
+    Args:
+        masks (Iterable[torch.Tensor]): The masks to plot.
+        axs (Iterable[plt.Axes]): The axes to plot the masks on.
+        cmap (str, optional): The colormap to use. Defaults to "Blues".
+
+    Returns:
+        None
     """
-
+    for mask, ax in zip(masks, axs):
+        ax.imshow(mask.squeeze().numpy(), cmap=cmap)
+        ax.axis("off")
+
+
+def plot_batch(
+    batch: Dict[str, Any],
+    bright: float = 1.0,
+    cols: int = 4,
+    width: int = 5,
+    chnls: List[int] = [2, 1, 0],
+    cmap: str = "Blues",
+) -> None:
+    """
+    Plot a batch of images and masks. This function is adapted from the plot_batch()
+    function in the torchgeo library at
+    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html
+    Credit to the torchgeo developers for the original implementation.
+
+    Args:
+        batch (Dict[str, Any]): The batch containing images and masks.
+        bright (float, optional): The brightness factor. Defaults to 1.0.
+        cols (int, optional): The number of columns in the plot grid. Defaults to 4.
+        width (int, optional): The width of each plot. Defaults to 5.
+        chnls (List[int], optional): The channels to use for RGB. Defaults to [2, 1, 0].
+        cmap (str, optional): The colormap to use for masks. Defaults to "Blues".
+
+    Returns:
+        None
+    """
+    # Get the samples and the number of items in the batch
+    samples = unbind_samples(batch.copy())
+
+    # if batch contains images and masks, the number of images will be doubled
+    n = 2 * len(samples) if ("image" in batch) and ("mask" in batch) else len(samples)
+
+    # calculate the number of rows in the grid
+    rows = n // cols + (1 if n % cols != 0 else 0)
+
+    # create a grid
+    _, axs = plt.subplots(rows, cols, figsize=(cols * width, rows * width))
+
+    if ("image" in batch) and ("mask" in batch):
+        # plot the images on the even axis
+        plot_images(
+            images=map(lambda x: x["image"], samples),
+            axs=axs.reshape(-1)[::2],
+            chnls=chnls,
+            bright=bright,
+        )
+
+        # plot the masks on the odd axis
+        plot_masks(masks=map(lambda x: x["mask"], samples), axs=axs.reshape(-1)[1::2])
+
+    else:
+        if "image" in batch:
+            plot_images(
+                images=map(lambda x: x["image"], samples),
+                axs=axs.reshape(-1),
+                chnls=chnls,
+                bright=bright,
+            )
+
+        elif "mask" in batch:
+            plot_masks(
+                masks=map(lambda x: x["mask"], samples), axs=axs.reshape(-1), cmap=cmap
+            )
+
+
+def calc_stats(
+    dataset: RasterDataset, divide_by: float = 1.0
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Calculate the statistics (mean and std) for the entire dataset.
+
+    This function is adapted from the plot_batch() function in the torchgeo library at
+    https://torchgeo.readthedocs.io/en/stable/tutorials/earth_surface_water.html.
+    Credit to the torchgeo developers for the original implementation.
+
+    Warning: This is an approximation. The correct value should take into account the
+    mean for the whole dataset for computing individual stds.
+
+    Args:
+        dataset (RasterDataset): The dataset to calculate statistics for.
+        divide_by (float, optional): The value to divide the image data by. Defaults to 1.0.
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: The mean and standard deviation for each band.
+    """
+    import rasterio as rio
+
+    # To avoid loading the entire dataset in memory, we will loop through each img
+    # The filenames will be retrieved from the dataset's rtree index
+    files = [
+        item.object
+        for item in dataset.index.intersection(dataset.index.bounds, objects=True)
+    ]
+
+    # Resetting statistics
+    accum_mean = 0
+    accum_std = 0
+
+    for file in files:
+        img = rio.open(file).read() / divide_by  # type: ignore
+        accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)
+        accum_std += img.reshape((img.shape[0], -1)).std(axis=1)
+
+    # at the end, we shall have 2 vectors with length n=chnls
+    # we will average them considering the number of images
+    return accum_mean / len(files), accum_std / len(files)
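The new common.py helpers are thin wrappers: viz_raster delegates to leafmap.Map.add_raster, and calc_stats accumulates per-band mean/std over the files indexed by a torchgeo RasterDataset. A minimal usage sketch, assuming a Jupyter environment with leafmap and torchgeo installed; "naip_tile.tif" and "data/rasters" are hypothetical paths, not files shipped with the package:

from torchgeo.datasets import RasterDataset

from geoai.common import calc_stats, viz_raster

# Show a local GeoTIFF on an interactive map (renders in a notebook cell).
m = viz_raster("naip_tile.tif", colormap="terrain", layer_name="NAIP")
m

# Per-band normalization statistics; divide_by rescales e.g. 8-bit data to [0, 1].
class TiffDataset(RasterDataset):
    filename_glob = "*.tif"  # match all GeoTIFFs under the root directory

dataset = TiffDataset("data/rasters")  # root/paths argument is torchgeo-version dependent
mean, std = calc_stats(dataset, divide_by=255.0)
print(mean, std)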
geoai/geoai.py
CHANGED
geoai/segmentation.py
ADDED
@@ -0,0 +1,349 @@
+import os
+import numpy as np
+from PIL import Image
+import torch
+import matplotlib.pyplot as plt
+from torch.utils.data import Dataset, Subset
+import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+from transformers import (
+    Trainer,
+    TrainingArguments,
+    SegformerForSemanticSegmentation,
+    DefaultDataCollator,
+)
+
+
+class CustomDataset(Dataset):
+    """Custom Dataset for loading images and masks."""
+
+    def __init__(
+        self,
+        images_dir: str,
+        masks_dir: str,
+        transform: A.Compose = None,
+        target_size: tuple = (256, 256),
+        num_classes: int = 2,
+    ):
+        """
+        Args:
+            images_dir (str): Directory containing images.
+            masks_dir (str): Directory containing masks.
+            transform (A.Compose, optional): Transformations to be applied on the images and masks.
+            target_size (tuple, optional): Target size for resizing images and masks.
+            num_classes (int, optional): Number of classes in the masks.
+        """
+        self.images_dir = images_dir
+        self.masks_dir = masks_dir
+        self.transform = transform
+        self.target_size = target_size
+        self.num_classes = num_classes
+        self.images = sorted(os.listdir(images_dir))
+        self.masks = sorted(os.listdir(masks_dir))
+
+    def __len__(self) -> int:
+        """Returns the total number of samples."""
+        return len(self.images)
+
+    def __getitem__(self, idx: int) -> dict:
+        """
+        Args:
+            idx (int): Index of the sample to fetch.
+
+        Returns:
+            dict: A dictionary with 'pixel_values' and 'labels'.
+        """
+        img_path = os.path.join(self.images_dir, self.images[idx])
+        mask_path = os.path.join(self.masks_dir, self.masks[idx])
+        image = Image.open(img_path).convert("RGB")
+        mask = Image.open(mask_path).convert("L")
+
+        image = image.resize(self.target_size)
+        mask = mask.resize(self.target_size)
+
+        image = np.array(image)
+        mask = np.array(mask)
+
+        mask = (mask > 127).astype(np.uint8)
+
+        if self.transform:
+            transformed = self.transform(image=image, mask=mask)
+            image = transformed["image"]
+            mask = transformed["mask"]
+
+        assert (
+            mask.max() < self.num_classes
+        ), f"Mask values should be less than {self.num_classes}, but found {mask.max()}"
+        assert (
+            mask.min() >= 0
+        ), f"Mask values should be greater than or equal to 0, but found {mask.min()}"
+
+        mask = mask.clone().detach().long()
+
+        return {"pixel_values": image, "labels": mask}
+
+
+def get_transform() -> A.Compose:
+    """
+    Returns:
+        A.Compose: A composition of image transformations.
+    """
+    return A.Compose(
+        [
+            A.Resize(256, 256),
+            A.HorizontalFlip(p=0.5),
+            A.VerticalFlip(p=0.5),
+            A.RandomRotate90(p=0.5),
+            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+            ToTensorV2(),
+        ]
+    )
+
+
+def prepare_datasets(
+    images_dir: str,
+    masks_dir: str,
+    transform: A.Compose,
+    test_size: float = 0.2,
+    random_state: int = 42,
+) -> tuple:
+    """
+    Args:
+        images_dir (str): Directory containing images.
+        masks_dir (str): Directory containing masks.
+        transform (A.Compose): Transformations to be applied.
+        test_size (float, optional): Proportion of the dataset to include in the validation split.
+        random_state (int, optional): Random seed for shuffling the dataset.
+
+    Returns:
+        tuple: Training and validation datasets.
+    """
+    dataset = CustomDataset(images_dir, masks_dir, transform)
+    train_indices, val_indices = train_test_split(
+        list(range(len(dataset))), test_size=test_size, random_state=random_state
+    )
+    train_dataset = Subset(dataset, train_indices)
+    val_dataset = Subset(dataset, val_indices)
+    return train_dataset, val_dataset
+
+
+def train_model(
+    train_dataset: Dataset,
+    val_dataset: Dataset,
+    pretrained_model: str = "nvidia/segformer-b0-finetuned-ade-512-512",
+    model_save_path: str = "./model",
+    output_dir: str = "./results",
+    num_epochs: int = 10,
+    batch_size: int = 8,
+    learning_rate: float = 5e-5,
+) -> str:
+    """
+    Trains the model and saves the fine-tuned model to the specified path.
+
+    Args:
+        train_dataset (Dataset): Training dataset.
+        val_dataset (Dataset): Validation dataset.
+        pretrained_model (str, optional): Pretrained model to fine-tune.
+        model_save_path (str): Path to save the fine-tuned model. Defaults to './model'.
+        output_dir (str, optional): Directory to save training outputs.
+        num_epochs (int, optional): Number of training epochs.
+        batch_size (int, optional): Batch size for training and evaluation.
+        learning_rate (float, optional): Learning rate for training.
+
+    Returns:
+        str: Path to the saved fine-tuned model.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model).to(
+        device
+    )
+    data_collator = DefaultDataCollator(return_tensors="pt")
+
+    training_args = TrainingArguments(
+        output_dir=output_dir,
+        num_train_epochs=num_epochs,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        eval_strategy="epoch",
+        save_strategy="epoch",
+        logging_dir="./logs",
+        learning_rate=learning_rate,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        data_collator=data_collator,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+    )
+
+    trainer.train()
+    model.save_pretrained(model_save_path)
+    print(f"Model saved to {model_save_path}")
+    return model_save_path
+
+
+def load_model(
+    model_path: str, device: torch.device
+) -> SegformerForSemanticSegmentation:
+    """
+    Loads the fine-tuned model from the specified path.
+
+    Args:
+        model_path (str): Path to the model.
+        device (torch.device): Device to load the model on.
+
+    Returns:
+        SegformerForSemanticSegmentation: Loaded model.
+    """
+    model = SegformerForSemanticSegmentation.from_pretrained(model_path)
+    model.to(device)
+    model.eval()
+    return model
+
+
+def preprocess_image(image_path: str, target_size: tuple = (256, 256)) -> torch.Tensor:
+    """
+    Preprocesses the input image for prediction.
+
+    Args:
+        image_path (str): Path to the input image.
+        target_size (tuple, optional): Target size for resizing the image.
+
+    Returns:
+        torch.Tensor: Preprocessed image tensor.
+    """
+    image = Image.open(image_path).convert("RGB")
+    transform = A.Compose(
+        [
+            A.Resize(target_size[0], target_size[1]),
+            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+            ToTensorV2(),
+        ]
+    )
+    image = np.array(image)
+    transformed = transform(image=image)
+    return transformed["image"].unsqueeze(0)
+
+
+def predict_image(
+    model: SegformerForSemanticSegmentation,
+    image_tensor: torch.Tensor,
+    original_size: tuple,
+    device: torch.device,
+) -> np.ndarray:
+    """
+    Predicts the segmentation mask for the input image.
+
+    Args:
+        model (SegformerForSemanticSegmentation): Fine-tuned model.
+        image_tensor (torch.Tensor): Preprocessed image tensor.
+        original_size (tuple): Original size of the image (width, height).
+        device (torch.device): Device to perform inference on.
+
+    Returns:
+        np.ndarray: Predicted segmentation mask.
+    """
+    with torch.no_grad():
+        image_tensor = image_tensor.to(device)
+        outputs = model(pixel_values=image_tensor)
+        logits = outputs.logits
+        upsampled_logits = F.interpolate(
+            logits, size=original_size[::-1], mode="bilinear", align_corners=False
+        )
+        predictions = torch.argmax(upsampled_logits, dim=1).cpu().numpy()
+    return predictions[0]
+
+
+def segment_image(
+    image_path: str,
+    model_path: str,
+    target_size: tuple = (256, 256),
+    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+) -> np.ndarray:
+    """
+    Segments the input image using the fine-tuned model.
+
+    Args:
+        image_path (str): Path to the input image.
+        model_path (str): Path to the fine-tuned model.
+        target_size (tuple, optional): Target size for resizing the image.
+        device (torch.device, optional): Device to perform inference on.
+
+    Returns:
+        np.ndarray: Predicted segmentation mask.
+    """
+    model = load_model(model_path, device)
+    image = Image.open(image_path).convert("RGB")
+    original_size = image.size
+    image_tensor = preprocess_image(image_path, target_size)
+    predictions = predict_image(model, image_tensor, original_size, device)
+    return predictions
+
+
+def visualize_predictions(
+    image_path: str,
+    segmented_mask: np.ndarray,
+    target_size: tuple = (256, 256),
+    reference_image_path: str = None,
+) -> None:
+    """
+    Visualizes the original image, segmented mask, and optionally the reference image.
+
+    Args:
+        image_path (str): Path to the original image.
+        segmented_mask (np.ndarray): Predicted segmentation mask.
+        target_size (tuple, optional): Target size for resizing images.
+        reference_image_path (str, optional): Path to the reference image.
+    """
+    original_image = Image.open(image_path).convert("RGB")
+    original_image = original_image.resize(target_size)
+    segmented_image = Image.fromarray((segmented_mask * 255).astype(np.uint8))
+
+    if reference_image_path:
+        reference_image = Image.open(reference_image_path).convert("RGB")
+        reference_image = reference_image.resize(target_size)
+        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
+        axes[1].imshow(reference_image)
+        axes[1].set_title("Reference Image")
+        axes[1].axis("off")
+    else:
+        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
+
+    axes[0].imshow(original_image)
+    axes[0].set_title("Original Image")
+    axes[0].axis("off")
+
+    if reference_image_path:
+        axes[2].imshow(segmented_image, cmap="gray")
+        axes[2].set_title("Segmented Image")
+        axes[2].axis("off")
+    else:
+        axes[1].imshow(segmented_image, cmap="gray")
+        axes[1].set_title("Segmented Image")
+        axes[1].axis("off")
+
+    plt.tight_layout()
+    plt.show()
+
+
+# Example usage
+if __name__ == "__main__":
+    images_dir = "../datasets/Water-Bodies-Dataset/Images"
+    masks_dir = "../datasets/Water-Bodies-Dataset/Masks"
+    transform = get_transform()
+    train_dataset, val_dataset = prepare_datasets(images_dir, masks_dir, transform)
+
+    model_save_path = "./fine_tuned_model"
+    train_model(train_dataset, val_dataset, model_save_path)
+
+    image_path = "../datasets/Water-Bodies-Dataset/Images/water_body_44.jpg"
+    reference_image_path = image_path.replace("Images", "Masks")
+    segmented_mask = segment_image(image_path, model_save_path)
+
+    visualize_predictions(
+        image_path, segmented_mask, reference_image_path=reference_image_path
+    )
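One caveat in the example usage above: train_model(train_dataset, val_dataset, model_save_path) passes model_save_path as the third positional argument, where it binds to pretrained_model, so the fine-tuned weights would actually land in the default "./model". Calling with keywords sidesteps the mix-up; a sketch of the intended pipeline under the same assumed dataset layout:

from geoai.segmentation import (
    get_transform,
    prepare_datasets,
    segment_image,
    train_model,
    visualize_predictions,
)

images_dir = "../datasets/Water-Bodies-Dataset/Images"  # layout from the example above
masks_dir = "../datasets/Water-Bodies-Dataset/Masks"

transform = get_transform()
train_ds, val_ds = prepare_datasets(images_dir, masks_dir, transform)

# Keyword call: pretrained_model keeps its default SegFormer checkpoint,
# and the fine-tuned weights really go to ./fine_tuned_model.
saved = train_model(train_ds, val_ds, model_save_path="./fine_tuned_model")

image_path = "../datasets/Water-Bodies-Dataset/Images/water_body_44.jpg"
mask = segment_image(image_path, saved)
visualize_predictions(image_path, mask, reference_image_path=image_path.replace("Images", "Masks"))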
{geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/METADATA
CHANGED
@@ -1,29 +1,35 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: geoai-py
-Version: 0.1.0
+Version: 0.1.6
 Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
-
-
-
-License: MIT license
+Author-email: Qiusheng Wu <giswqs@gmail.com>
+License: MIT License
+Project-URL: Homepage, https://github.com/opengeos/geoai
 Keywords: geoai
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Natural Language :: English
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
-
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: albumentations
+Requires-Dist: jupyter-server-proxy
+Requires-Dist: leafmap
+Requires-Dist: localtileserver
+Requires-Dist: scikit-learn
 Requires-Dist: segment-geospatial
+Requires-Dist: torch
+Requires-Dist: torchgeo
+Requires-Dist: transformers
 
 # geoai
 
 [![image](https://img.shields.io/pypi/v/geoai-py.svg)](https://pypi.python.org/pypi/geoai-py)
-
 [![image](https://img.shields.io/conda/vn/conda-forge/geoai.svg)](https://anaconda.org/conda-forge/geoai)
 
 **A Python package for using Artificial Intelligence (AI) with geospatial data**
@@ -33,5 +39,6 @@ Requires-Dist: segment-geospatial
 
 ## Features
 
+- Visualizing geospatial data, including vector, raster, and LiDAR data
 - Segmenting remote sensing imagery with the Segment Anything Model
 - Classifying remote sensing imagery with deep learning models
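The 0.1.6 metadata moves the dependency declarations into METADATA as Requires-Dist entries (0.1.0 declared only segment-geospatial, plus a now-removed dependency_links.txt). A quick way to confirm what an installed copy declares, using only the standard library:

from importlib.metadata import metadata, requires, version

print(version("geoai-py"))                        # expected: 0.1.6
print(metadata("geoai-py")["Metadata-Version"])   # expected: 2.2
for req in requires("geoai-py") or []:
    print(req)  # the Requires-Dist entries shown in the diff above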
geoai_py-0.1.6.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
+geoai/__init__.py,sha256=FBqWQ-gHNvzP7beS6Vdp5upQnAVuYHa4CZLTvgpXfqA,143
+geoai/common.py,sha256=6h6mtUBO428P3IZppWyCVo04Ohzc3VhmnH0tvVh479g,9675
+geoai/geoai.py,sha256=TmR7x1uL51G5oAjw0AQWnC5VQtLWDygyFLrDIj46xNc,86
+geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
+geoai_py-0.1.6.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
+geoai_py-0.1.6.dist-info/METADATA,sha256=ByO92YNy910hBcIhcHtm4_6A1fi7WEElrzldHHv-YmU,1599
+geoai_py-0.1.6.dist-info/WHEEL,sha256=9Hm2OB-j1QcCUq9Jguht7ayGIIZBRTdOXD1qg9cCgPM,109
+geoai_py-0.1.6.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
+geoai_py-0.1.6.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
+geoai_py-0.1.6.dist-info/RECORD,,
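Each RECORD row is path,algorithm=digest,size, where the digest is an unpadded urlsafe-base64 SHA-256 per the wheel spec (PEP 427). A small sketch for recomputing a digest and comparing it with the entries above; the path argument is whichever installed file you want to verify:

import base64
import hashlib

def record_digest(path: str) -> str:
    # Unpadded urlsafe-base64 SHA-256, the encoding wheel RECORD files use.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# For an unmodified install, record_digest(".../geoai/common.py") should return
# "sha256=6h6mtUBO428P3IZppWyCVo04Ohzc3VhmnH0tvVh479g", matching the RECORD above.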
geoai_py-0.1.0.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-geoai/__init__.py,sha256=GX7AiSmG5X4zGHQ3PijrWbmSLhPtyB5Gw1BX0gBVZQE,120
-geoai/common.py,sha256=Si4ZvbkkaVGyKrOa7gmsVNWS2Abtji-V5nEsJhrIM3M,188
-geoai/geoai.py,sha256=h0hwdogXGFqerm-5ZPeT-irPn91pCcQRjiHThXsRzEk,19
-geoai_py-0.1.0.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
-geoai_py-0.1.0.dist-info/METADATA,sha256=LSE7RxfY84Sutf-Rlrb9640WiIY3ZGSOiLVjqGVL89Y,1295
-geoai_py-0.1.0.dist-info/WHEEL,sha256=iYlv5fX357PQyRT2o6tw1bN-YcKFFHKqB_LwHO5wP-g,110
-geoai_py-0.1.0.dist-info/dependency_links.txt,sha256=FfEs5pM2sDUo8djMXZayv60m1an89YQIFGXl8UUNdmE,19
-geoai_py-0.1.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
-geoai_py-0.1.0.dist-info/RECORD,,
geoai_py-0.1.0.dist-info/dependency_links.txt
DELETED
@@ -1 +0,0 @@
-segment-geospatial
{geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/LICENSE
File without changes
{geoai_py-0.1.0.dist-info → geoai_py-0.1.6.dist-info}/top_level.txt
File without changes