geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/download.py +9 -8
- geoai/extract.py +158 -38
- geoai/geoai.py +4 -1
- geoai/hf.py +447 -0
- geoai/segment.py +306 -0
- geoai/segmentation.py +8 -7
- geoai/train.py +1039 -0
- geoai/utils.py +863 -25
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/METADATA +5 -1
- geoai_py-0.4.0.dist-info/RECORD +15 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/WHEEL +1 -1
- geoai/preprocess.py +0 -3021
- geoai_py-0.3.5.dist-info/RECORD +0 -13
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.5.dist-info → geoai_py-0.4.0.dist-info}/top_level.txt +0 -0
geoai/segment.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
"""This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import rasterio
|
|
7
|
+
import torch
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from rasterio.windows import Window
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CLIPSegmentation:
    """
    A class for segmenting high-resolution satellite imagery using text prompts with CLIP-based models.

    This segmenter utilizes the CLIP-Seg model to perform semantic segmentation based on text prompts.
    It can process large GeoTIFF files by tiling them and handles proper georeferencing in the output.

    Args:
        model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
        device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
        tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
        overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.

    Attributes:
        processor (CLIPSegProcessor): The processor for the CLIP-Seg model.
        model (CLIPSegForImageSegmentation): The CLIP-Seg model for segmentation.
        device (str): The device being used ('cuda' or 'cpu').
        tile_size (int): Size of tiles for processing.
        overlap (int): Overlap between tiles.
    """

    def __init__(
        self,
        model_name="CIDAS/clipseg-rd64-refined",
        device=None,
        tile_size=512,
        overlap=32,
    ):
        """
        Initialize the segmenter with the specified model and tiling settings.

        Args:
            model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
            device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
            tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
            overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
        """
        self.tile_size = tile_size
        self.overlap = overlap

        # Prefer CUDA when available unless the caller pins a device explicitly.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load model and processor (downloads weights on first use).
        self.processor = CLIPSegProcessor.from_pretrained(model_name)
        self.model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(
            self.device
        )

        print(f"Model loaded on {self.device}")

    @staticmethod
    def _normalize_to_uint8(arr):
        """
        Min-max scale an array to uint8 in [0, 255].

        A constant array maps to all zeros instead of dividing by zero
        (the original inline normalization produced NaNs on constant tiles).

        Args:
            arr (np.ndarray): Input array of any numeric dtype.

        Returns:
            np.ndarray: uint8 array with the same shape.
        """
        arr = arr.astype(np.float32)
        lo = arr.min()
        hi = arr.max()
        if hi > lo:
            arr = (arr - lo) / (hi - lo) * 255
        else:
            # Constant tile (e.g. all nodata): emit zeros rather than NaN.
            arr = np.zeros_like(arr)
        return arr.astype(np.uint8)

    def _tile_to_rgb(self, tile_data):
        """
        Convert a raster tile of shape (bands, H, W) to an RGB uint8 (H, W, 3) array.

        Multi-band imagery keeps its first three bands; single- and two-band
        imagery replicates the first band into all three channels (the
        original code crashed on 2-band tiles).

        Args:
            tile_data (np.ndarray): Raster tile as read by rasterio, (bands, H, W).

        Returns:
            np.ndarray: uint8 RGB image, (H, W, 3).
        """
        if tile_data.shape[0] >= 3:
            rgb_tile = tile_data[:3].transpose(1, 2, 0)
        else:
            rgb_tile = np.repeat(tile_data[0][:, :, np.newaxis], 3, axis=2)
        # Always produce uint8 so PIL can consume the array regardless of
        # the source dtype.
        return self._normalize_to_uint8(rgb_tile)

    def _predict_tile(self, tile_data, text_prompt, tile_height, tile_width, smoothing_sigma):
        """
        Run CLIP-Seg on a single tile and return a probability map.

        Args:
            tile_data (np.ndarray): Raster tile, (bands, tile_height, tile_width).
            text_prompt (str): Text description of what to segment.
            tile_height (int): Height of the tile in pixels.
            tile_width (int): Width of the tile in pixels.
            smoothing_sigma (float): Sigma for optional Gaussian smoothing.

        Returns:
            np.ndarray: float probability map of shape (tile_height, tile_width).
        """
        pil_image = Image.fromarray(self._tile_to_rgb(tile_data))

        # Downscale (keeping aspect ratio) so the model input stays within
        # tile_size pixels per side.
        if pil_image.width > self.tile_size or pil_image.height > self.tile_size:
            pil_image.thumbnail((self.tile_size, self.tile_size), Image.LANCZOS)

        inputs = self.processor(
            text=text_prompt, images=pil_image, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Sigmoid converts the raw logits to per-pixel probabilities.
        probs = torch.sigmoid(outputs.logits[0]).cpu().numpy()

        # Resize back to the original tile size if the model output differs;
        # bicubic interpolation gives smoother boundaries than nearest.
        if probs.shape != (tile_height, tile_width):
            probs = np.array(
                Image.fromarray(probs).resize(
                    (tile_width, tile_height), Image.BICUBIC
                )
            )

        # Optional Gaussian blur to reduce tile blockiness; scipy is optional.
        try:
            from scipy.ndimage import gaussian_filter

            probs = gaussian_filter(probs, sigma=smoothing_sigma)
        except ImportError:
            pass  # Continue without smoothing if scipy is not available

        return probs

    def segment_image(
        self, input_path, output_path, text_prompt, threshold=0.5, smoothing_sigma=1.0
    ):
        """
        Segment a GeoTIFF image using the provided text prompt.

        The function processes the image in tiles and saves the result as a GeoTIFF with two bands:
        - Band 1: Binary segmentation mask (0 or 1)
        - Band 2: Probability scores (0.0 to 1.0)

        Args:
            input_path (str): Path to the input GeoTIFF file.
            output_path (str): Path where the output GeoTIFF will be saved.
            text_prompt (str): Text description of what to segment (e.g., "water", "buildings").
            threshold (float): Threshold for binary segmentation (0.0 to 1.0). Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.

        Returns:
            str: Path to the saved output file.
        """
        with rasterio.open(input_path) as src:
            meta = src.meta
            height = src.height
            width = src.width

            # Output keeps the source georeferencing; two float32 bands.
            out_meta = meta.copy()
            out_meta.update({"count": 2, "dtype": "float32", "nodata": None})

            probabilities = np.zeros((height, width), dtype=np.float32)

            # Tiles advance by tile_size - 2*overlap so neighbors share
            # context; the overlapping margin is discarded when stitching.
            effective_tile_size = self.tile_size - 2 * self.overlap

            n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
            n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
            total_tiles = n_tiles_x * n_tiles_y

            with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
                for y in range(n_tiles_y):
                    for x in range(n_tiles_x):
                        # Tile window with overlap, clamped to image bounds.
                        x_start = max(0, x * effective_tile_size - self.overlap)
                        y_start = max(0, y * effective_tile_size - self.overlap)
                        x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
                        y_end = min(
                            height, (y + 1) * effective_tile_size + self.overlap
                        )

                        tile_width = x_end - x_start
                        tile_height = y_end - y_start

                        window = Window(x_start, y_start, tile_width, tile_height)
                        tile_data = src.read(window=window)

                        try:
                            probs_resized = self._predict_tile(
                                tile_data,
                                text_prompt,
                                tile_height,
                                tile_width,
                                smoothing_sigma,
                            )

                            # Keep only the non-overlapping core of each tile
                            # (except at the image edges) to avoid seams.
                            valid_x_start = self.overlap if x > 0 else 0
                            valid_y_start = self.overlap if y > 0 else 0
                            valid_x_end = (
                                tile_width - self.overlap
                                if x < n_tiles_x - 1
                                else tile_width
                            )
                            valid_y_end = (
                                tile_height - self.overlap
                                if y < n_tiles_y - 1
                                else tile_height
                            )

                            dest_x_start = x_start + valid_x_start
                            dest_y_start = y_start + valid_y_start
                            dest_x_end = x_start + valid_x_end
                            dest_y_end = y_start + valid_y_end

                            probabilities[
                                dest_y_start:dest_y_end, dest_x_start:dest_x_end
                            ] = probs_resized[
                                valid_y_start:valid_y_end, valid_x_start:valid_x_end
                            ]

                        except Exception as e:
                            # Best-effort: report and continue with the next tile.
                            print(f"Error processing tile at ({x}, {y}): {str(e)}")

                        pbar.update(1)

            # Binarize once over the stitched probability map.
            segmentation = (probabilities >= threshold).astype(np.float32)

            with rasterio.open(output_path, "w", **out_meta) as dst:
                dst.write(segmentation, 1)
                dst.write(probabilities, 2)

                dst.set_band_description(1, "Binary Segmentation")
                dst.set_band_description(2, "Probability Scores")

        print(f"Segmentation saved to {output_path}")
        return output_path

    def segment_image_batch(
        self,
        input_paths,
        output_dir,
        text_prompt,
        threshold=0.5,
        smoothing_sigma=1.0,
        suffix="_segmented",
    ):
        """
        Segment multiple GeoTIFF images using the provided text prompt.

        Args:
            input_paths (list): List of paths to input GeoTIFF files.
            output_dir (str): Directory where output GeoTIFFs will be saved.
            text_prompt (str): Text description of what to segment.
            threshold (float): Threshold for binary segmentation. Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
            suffix (str): Suffix to add to output filenames. Defaults to "_segmented".

        Returns:
            list: Paths to all saved output files.
        """
        os.makedirs(output_dir, exist_ok=True)

        output_paths = []

        for input_path in tqdm(input_paths, desc="Processing files"):
            # Derive the output name from the input name plus the suffix.
            filename = os.path.basename(input_path)
            base_name, ext = os.path.splitext(filename)
            output_path = os.path.join(output_dir, f"{base_name}{suffix}{ext}")

            result_path = self.segment_image(
                input_path, output_path, text_prompt, threshold, smoothing_sigma
            )
            output_paths.append(result_path)

        return output_paths
|
geoai/segmentation.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
3
|
+
import albumentations as A
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
2
5
|
import numpy as np
|
|
3
|
-
from PIL import Image
|
|
4
6
|
import torch
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
from torch.utils.data import Dataset, Subset
|
|
7
7
|
import torch.nn.functional as F
|
|
8
|
-
from sklearn.model_selection import train_test_split
|
|
9
|
-
import albumentations as A
|
|
10
8
|
from albumentations.pytorch import ToTensorV2
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from sklearn.model_selection import train_test_split
|
|
11
|
+
from torch.utils.data import Dataset, Subset
|
|
11
12
|
from transformers import (
|
|
13
|
+
DefaultDataCollator,
|
|
14
|
+
SegformerForSemanticSegmentation,
|
|
12
15
|
Trainer,
|
|
13
16
|
TrainingArguments,
|
|
14
|
-
SegformerForSemanticSegmentation,
|
|
15
|
-
DefaultDataCollator,
|
|
16
17
|
)
|
|
17
18
|
|
|
18
19
|
|