geoai-py 0.3.4__py2.py3-none-any.whl → 0.3.6__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/segment.py ADDED
@@ -0,0 +1,305 @@
+ """This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
+
+ import os
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from PIL import Image
+ import rasterio
+ from rasterio.windows import Window
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+
+
+ class CLIPSegmentation:
+     """
+     A class for segmenting high-resolution satellite imagery using text prompts with CLIP-based models.
+
+     This segmenter utilizes the CLIP-Seg model to perform semantic segmentation based on text prompts.
+     It can process large GeoTIFF files by tiling them and handles proper georeferencing in the output.
+
+     Args:
+         model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
+         device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
+         tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
+         overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
+
+     Attributes:
+         processor (CLIPSegProcessor): The processor for the CLIP-Seg model.
+         model (CLIPSegForImageSegmentation): The CLIP-Seg model for segmentation.
+         device (str): The device being used ('cuda' or 'cpu').
+         tile_size (int): Size of tiles for processing.
+         overlap (int): Overlap between tiles.
+     """
+
+     def __init__(
+         self,
+         model_name="CIDAS/clipseg-rd64-refined",
+         device=None,
+         tile_size=512,
+         overlap=32,
+     ):
+         """
+         Initialize the CLIPSegmentation segmenter with the specified model and settings.
+
+         Args:
+             model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
+             device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
+             tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
+             overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
+         """
+         self.tile_size = tile_size
+         self.overlap = overlap
+
+         # Set device
+         if device is None:
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         else:
+             self.device = device
+
+         # Load model and processor
+         self.processor = CLIPSegProcessor.from_pretrained(model_name)
+         self.model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(
+             self.device
+         )
+
+         print(f"Model loaded on {self.device}")
+
+     def segment_image(
+         self, input_path, output_path, text_prompt, threshold=0.5, smoothing_sigma=1.0
+     ):
+         """
+         Segment a GeoTIFF image using the provided text prompt.
+
+         The function processes the image in tiles and saves the result as a GeoTIFF with two bands:
+         - Band 1: Binary segmentation mask (0 or 1)
+         - Band 2: Probability scores (0.0 to 1.0)
+
+         Args:
+             input_path (str): Path to the input GeoTIFF file.
+             output_path (str): Path where the output GeoTIFF will be saved.
+             text_prompt (str): Text description of what to segment (e.g., "water", "buildings").
+             threshold (float): Threshold for binary segmentation (0.0 to 1.0). Defaults to 0.5.
+             smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
+
+         Returns:
+             str: Path to the saved output file.
+         """
+         # Open the input GeoTIFF
+         with rasterio.open(input_path) as src:
+             # Get metadata
+             meta = src.meta
+             height = src.height
+             width = src.width
+
+             # Create output metadata
+             out_meta = meta.copy()
+             out_meta.update({"count": 2, "dtype": "float32", "nodata": None})
+
+             # Create arrays for results
+             segmentation = np.zeros((height, width), dtype=np.float32)
+             probabilities = np.zeros((height, width), dtype=np.float32)
+
+             # Calculate effective tile size (accounting for overlap)
+             effective_tile_size = self.tile_size - 2 * self.overlap
+
+             # Calculate number of tiles
+             n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
+             n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
+             total_tiles = n_tiles_x * n_tiles_y
+
+             # Process tiles with tqdm progress bar
+             with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
+                 # Iterate through tiles
+                 for y in range(n_tiles_y):
+                     for x in range(n_tiles_x):
+                         # Calculate tile coordinates with overlap
+                         x_start = max(0, x * effective_tile_size - self.overlap)
+                         y_start = max(0, y * effective_tile_size - self.overlap)
+                         x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
+                         y_end = min(
+                             height, (y + 1) * effective_tile_size + self.overlap
+                         )
+
+                         tile_width = x_end - x_start
+                         tile_height = y_end - y_start
+
+                         # Read the tile
+                         window = Window(x_start, y_start, tile_width, tile_height)
+                         tile_data = src.read(window=window)
+
+                         # Process the tile
+                         try:
+                             # Convert to RGB if necessary (handling different satellite bands)
+                             if tile_data.shape[0] > 3:
+                                 # Use first three bands for RGB representation
+                                 rgb_tile = tile_data[:3].transpose(1, 2, 0)
+                                 # Normalize data to 0-255 range if needed
+                                 if rgb_tile.max() > rgb_tile.min():
+                                     rgb_tile = (
+                                         (rgb_tile - rgb_tile.min())
+                                         / (rgb_tile.max() - rgb_tile.min())
+                                         * 255
+                                     ).astype(np.uint8)
+                             elif tile_data.shape[0] == 1:
+                                 # Create RGB from grayscale
+                                 rgb_tile = np.repeat(
+                                     tile_data[0][:, :, np.newaxis], 3, axis=2
+                                 )
+                                 # Normalize if needed
+                                 if rgb_tile.max() > rgb_tile.min():
+                                     rgb_tile = (
+                                         (rgb_tile - rgb_tile.min())
+                                         / (rgb_tile.max() - rgb_tile.min())
+                                         * 255
+                                     ).astype(np.uint8)
+                             else:
+                                 # Already 3-channel, assume RGB
+                                 rgb_tile = tile_data.transpose(1, 2, 0)
+                                 # Normalize if needed
+                                 if rgb_tile.max() > rgb_tile.min():
+                                     rgb_tile = (
+                                         (rgb_tile - rgb_tile.min())
+                                         / (rgb_tile.max() - rgb_tile.min())
+                                         * 255
+                                     ).astype(np.uint8)
+
+                             # Convert to PIL Image
+                             pil_image = Image.fromarray(rgb_tile)
+
+                             # Resize if needed to match model's requirements
+                             if (
+                                 pil_image.width > self.tile_size
+                                 or pil_image.height > self.tile_size
+                             ):
+                                 # Keep aspect ratio
+                                 pil_image.thumbnail(
+                                     (self.tile_size, self.tile_size), Image.LANCZOS
+                                 )
+
+                             # Process with CLIP-Seg
+                             inputs = self.processor(
+                                 text=text_prompt, images=pil_image, return_tensors="pt"
+                             ).to(self.device)
+
+                             # Forward pass
+                             with torch.no_grad():
+                                 outputs = self.model(**inputs)
+
+                             # Get logits and resize to original tile size
+                             logits = outputs.logits[0]
+
+                             # Convert logits to probabilities with sigmoid
+                             probs = torch.sigmoid(logits).cpu().numpy()
+
+                             # Resize back to original tile size if needed
+                             if probs.shape != (tile_height, tile_width):
+                                 # Use bicubic interpolation for smoother results
+                                 probs_resized = np.array(
+                                     Image.fromarray(probs).resize(
+                                         (tile_width, tile_height), Image.BICUBIC
+                                     )
+                                 )
+                             else:
+                                 probs_resized = probs
+
+                             # Apply gaussian blur to reduce blockiness
+                             try:
+                                 from scipy.ndimage import gaussian_filter
+
+                                 probs_resized = gaussian_filter(
+                                     probs_resized, sigma=smoothing_sigma
+                                 )
+                             except ImportError:
+                                 pass  # Continue without smoothing if scipy is not available
+
+                             # Store results in the full arrays
+                             # Only store the non-overlapping part (except at edges)
+                             valid_x_start = self.overlap if x > 0 else 0
+                             valid_y_start = self.overlap if y > 0 else 0
+                             valid_x_end = (
+                                 tile_width - self.overlap
+                                 if x < n_tiles_x - 1
+                                 else tile_width
+                             )
+                             valid_y_end = (
+                                 tile_height - self.overlap
+                                 if y < n_tiles_y - 1
+                                 else tile_height
+                             )
+
+                             dest_x_start = x_start + valid_x_start
+                             dest_y_start = y_start + valid_y_start
+                             dest_x_end = x_start + valid_x_end
+                             dest_y_end = y_start + valid_y_end
+
+                             # Store probabilities
+                             probabilities[
+                                 dest_y_start:dest_y_end, dest_x_start:dest_x_end
+                             ] = probs_resized[
+                                 valid_y_start:valid_y_end, valid_x_start:valid_x_end
+                             ]
+
+                         except Exception as e:
+                             print(f"Error processing tile at ({x}, {y}): {str(e)}")
+                             # Continue with next tile
+
+                         # Update progress bar
+                         pbar.update(1)
+
+         # Create binary segmentation from probabilities
+         segmentation = (probabilities >= threshold).astype(np.float32)
+
+         # Write the output GeoTIFF
+         with rasterio.open(output_path, "w", **out_meta) as dst:
+             dst.write(segmentation, 1)
+             dst.write(probabilities, 2)
+
+             # Add descriptions to bands
+             dst.set_band_description(1, "Binary Segmentation")
+             dst.set_band_description(2, "Probability Scores")
+
+         print(f"Segmentation saved to {output_path}")
+         return output_path
+
+     def segment_image_batch(
+         self,
+         input_paths,
+         output_dir,
+         text_prompt,
+         threshold=0.5,
+         smoothing_sigma=1.0,
+         suffix="_segmented",
+     ):
+         """
+         Segment multiple GeoTIFF images using the provided text prompt.
+
+         Args:
+             input_paths (list): List of paths to input GeoTIFF files.
+             output_dir (str): Directory where output GeoTIFFs will be saved.
+             text_prompt (str): Text description of what to segment.
+             threshold (float): Threshold for binary segmentation. Defaults to 0.5.
+             smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
+             suffix (str): Suffix to add to output filenames. Defaults to "_segmented".
+
+         Returns:
+             list: Paths to all saved output files.
+         """
+         # Create output directory if it doesn't exist
+         os.makedirs(output_dir, exist_ok=True)
+
+         output_paths = []
+
+         # Process each input file
+         for input_path in tqdm(input_paths, desc="Processing files"):
+             # Generate output path
+             filename = os.path.basename(input_path)
+             base_name, ext = os.path.splitext(filename)
+             output_path = os.path.join(output_dir, f"{base_name}{suffix}{ext}")
+
+             # Segment the image
+             result_path = self.segment_image(
+                 input_path, output_path, text_prompt, threshold, smoothing_sigma
+             )
+             output_paths.append(result_path)
+
+         return output_paths
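
For context, here is a minimal usage sketch of the new module. The class name, method signatures, defaults, and the `geoai.segment` import path come from the diff above; the GeoTIFF paths, output directory, and text prompts ("water", "buildings") are hypothetical examples.

```python
# Minimal usage sketch for the CLIPSegmentation class added in geoai/segment.py.
# File paths and prompts below are hypothetical placeholders.
from geoai.segment import CLIPSegmentation

# Loads "CIDAS/clipseg-rd64-refined" and uses CUDA when available.
segmenter = CLIPSegmentation(tile_size=512, overlap=32)

# Single image: writes a 2-band GeoTIFF (band 1 = binary mask, band 2 = probabilities).
segmenter.segment_image(
    "input.tif",            # hypothetical input path
    "input_segmented.tif",  # hypothetical output path
    text_prompt="water",
    threshold=0.5,
    smoothing_sigma=1.0,
)

# Batch mode: appends "_segmented" to each filename inside the output directory.
outputs = segmenter.segment_image_batch(
    ["scene1.tif", "scene2.tif"],  # hypothetical inputs
    output_dir="segmented",
    text_prompt="buildings",
)
print(outputs)
```

Both methods preserve the input's georeferencing in the output, since the output metadata is copied from the source raster before the band count and dtype are updated.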