geoai-py 0.3.5__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/segment.py ADDED
@@ -0,0 +1,306 @@
1
+ """This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ import rasterio
7
+ import torch
8
+ from PIL import Image
9
+ from rasterio.windows import Window
10
+ from tqdm import tqdm
11
+ from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
12
+
13
+
14
class CLIPSegmentation:
    """
    A class for segmenting high-resolution satellite imagery using text prompts with CLIP-based models.

    This segmenter utilizes the CLIP-Seg model to perform semantic segmentation based on
    text prompts. It can process large GeoTIFF files by tiling them and preserves the
    input georeferencing in the output.

    Args:
        model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
        device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
        tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
        overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.

    Attributes:
        processor (CLIPSegProcessor): The processor for the CLIP-Seg model.
        model (CLIPSegForImageSegmentation): The CLIP-Seg model for segmentation.
        device (str): The device being used ('cuda' or 'cpu').
        tile_size (int): Size of tiles for processing.
        overlap (int): Overlap between tiles.
    """

    def __init__(
        self,
        model_name="CIDAS/clipseg-rd64-refined",
        device=None,
        tile_size=512,
        overlap=32,
    ):
        """
        Initialize the segmenter with the specified model and tiling settings.

        Args:
            model_name (str): Name of the CLIP-Seg model to use. Defaults to "CIDAS/clipseg-rd64-refined".
            device (str): Device to run the model on ('cuda', 'cpu'). If None, will use CUDA if available.
            tile_size (int): Size of tiles to process the image in chunks. Defaults to 512.
            overlap (int): Overlap between tiles to avoid edge artifacts. Defaults to 32.
        """
        self.tile_size = tile_size
        self.overlap = overlap

        # Prefer CUDA when available unless the caller pinned a device.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # Load model and processor once; the model stays resident on self.device.
        self.processor = CLIPSegProcessor.from_pretrained(model_name)
        self.model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(
            self.device
        )

        print(f"Model loaded on {self.device}")

    @staticmethod
    def _tile_to_rgb_uint8(tile_data):
        """
        Convert a raster tile of shape (bands, height, width) to an (H, W, 3) uint8 RGB array.

        Tiles with three or more bands use their first three bands as RGB; tiles with
        fewer bands are expanded from the first band to three identical channels.
        Values are min-max stretched to the 0-255 range. Constant tiles (max == min)
        are returned unstretched to avoid a division by zero.

        Args:
            tile_data (numpy.ndarray): Raster tile as read by rasterio, shape (bands, H, W).

        Returns:
            numpy.ndarray: RGB image array of shape (H, W, 3), dtype uint8.
        """
        if tile_data.shape[0] >= 3:
            # Use the first three bands as an RGB representation.
            rgb = tile_data[:3].transpose(1, 2, 0)
        else:
            # Grayscale (or two-band) input: replicate the first band.
            rgb = np.repeat(tile_data[0][:, :, np.newaxis], 3, axis=2)

        rgb = rgb.astype(np.float32)
        value_range = float(rgb.max() - rgb.min())
        # Guard both all-zero and constant tiles; the original code divided by
        # (max - min) whenever max > 0, which is a division by zero for
        # constant non-zero tiles.
        if rgb.max() > 0 and value_range > 0:
            rgb = (rgb - rgb.min()) / value_range * 255
        return rgb.astype(np.uint8)

    def segment_image(
        self, input_path, output_path, text_prompt, threshold=0.5, smoothing_sigma=1.0
    ):
        """
        Segment a GeoTIFF image using the provided text prompt.

        The function processes the image in tiles and saves the result as a GeoTIFF with two bands:
        - Band 1: Binary segmentation mask (0 or 1)
        - Band 2: Probability scores (0.0 to 1.0)

        Args:
            input_path (str): Path to the input GeoTIFF file.
            output_path (str): Path where the output GeoTIFF will be saved.
            text_prompt (str): Text description of what to segment (e.g., "water", "buildings").
            threshold (float): Threshold for binary segmentation (0.0 to 1.0). Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.

        Returns:
            str: Path to the saved output file.
        """
        with rasterio.open(input_path) as src:
            height = src.height
            width = src.width

            # Output keeps the input georeferencing but always has two float32 bands.
            out_meta = src.meta.copy()
            out_meta.update({"count": 2, "dtype": "float32", "nodata": None})

            probabilities = np.zeros((height, width), dtype=np.float32)

            # Tiles are read with `overlap` extra pixels on each side and only the
            # interior is written back, hiding model edge artifacts. Clamp to >= 1
            # so a large overlap cannot produce a non-positive stride.
            effective_tile_size = max(1, self.tile_size - 2 * self.overlap)

            n_tiles_x = max(1, int(np.ceil(width / effective_tile_size)))
            n_tiles_y = max(1, int(np.ceil(height / effective_tile_size)))
            total_tiles = n_tiles_x * n_tiles_y

            with tqdm(total=total_tiles, desc="Processing tiles") as pbar:
                for y in range(n_tiles_y):
                    for x in range(n_tiles_x):
                        # Tile bounds in pixel coordinates, padded by the overlap
                        # and clamped to the raster extent.
                        x_start = max(0, x * effective_tile_size - self.overlap)
                        y_start = max(0, y * effective_tile_size - self.overlap)
                        x_end = min(width, (x + 1) * effective_tile_size + self.overlap)
                        y_end = min(
                            height, (y + 1) * effective_tile_size + self.overlap
                        )

                        tile_width = x_end - x_start
                        tile_height = y_end - y_start

                        # Read only this tile's window from the raster.
                        window = Window(x_start, y_start, tile_width, tile_height)
                        tile_data = src.read(window=window)

                        try:
                            pil_image = Image.fromarray(
                                self._tile_to_rgb_uint8(tile_data)
                            )

                            # Downscale (keeping aspect ratio) so the model input
                            # never exceeds tile_size in either dimension.
                            if (
                                pil_image.width > self.tile_size
                                or pil_image.height > self.tile_size
                            ):
                                pil_image.thumbnail(
                                    (self.tile_size, self.tile_size), Image.LANCZOS
                                )

                            # Run CLIP-Seg on the tile with the text prompt.
                            inputs = self.processor(
                                text=text_prompt, images=pil_image, return_tensors="pt"
                            ).to(self.device)

                            with torch.no_grad():
                                outputs = self.model(**inputs)

                            # Sigmoid turns raw logits into per-pixel probabilities.
                            probs = torch.sigmoid(outputs.logits[0]).cpu().numpy()

                            # The model emits a fixed-size map; resize back to the
                            # tile's pixel grid with bicubic interpolation for
                            # smoother results.
                            if probs.shape != (tile_height, tile_width):
                                probs_resized = np.array(
                                    Image.fromarray(probs).resize(
                                        (tile_width, tile_height), Image.BICUBIC
                                    )
                                )
                            else:
                                probs_resized = probs

                            # Optional Gaussian smoothing to reduce tile blockiness;
                            # silently skipped when scipy is unavailable.
                            try:
                                from scipy.ndimage import gaussian_filter

                                probs_resized = gaussian_filter(
                                    probs_resized, sigma=smoothing_sigma
                                )
                            except ImportError:
                                pass

                            # Keep only the non-overlapping interior (except at the
                            # raster edges) so adjacent tiles do not double-write.
                            valid_x_start = self.overlap if x > 0 else 0
                            valid_y_start = self.overlap if y > 0 else 0
                            valid_x_end = (
                                tile_width - self.overlap
                                if x < n_tiles_x - 1
                                else tile_width
                            )
                            valid_y_end = (
                                tile_height - self.overlap
                                if y < n_tiles_y - 1
                                else tile_height
                            )

                            dest_x_start = x_start + valid_x_start
                            dest_y_start = y_start + valid_y_start
                            dest_x_end = x_start + valid_x_end
                            dest_y_end = y_start + valid_y_end

                            probabilities[
                                dest_y_start:dest_y_end, dest_x_start:dest_x_end
                            ] = probs_resized[
                                valid_y_start:valid_y_end, valid_x_start:valid_x_end
                            ]

                        except Exception as e:
                            # Best-effort: a failed tile is reported and its pixels
                            # are left at probability 0.
                            print(f"Error processing tile at ({x}, {y}): {str(e)}")

                        pbar.update(1)

        # Threshold the stitched probabilities into the binary mask.
        segmentation = (probabilities >= threshold).astype(np.float32)

        # Write both bands with the input's georeferencing.
        with rasterio.open(output_path, "w", **out_meta) as dst:
            dst.write(segmentation, 1)
            dst.write(probabilities, 2)
            dst.set_band_description(1, "Binary Segmentation")
            dst.set_band_description(2, "Probability Scores")

        print(f"Segmentation saved to {output_path}")
        return output_path

    def segment_image_batch(
        self,
        input_paths,
        output_dir,
        text_prompt,
        threshold=0.5,
        smoothing_sigma=1.0,
        suffix="_segmented",
    ):
        """
        Segment multiple GeoTIFF images using the provided text prompt.

        Args:
            input_paths (list): List of paths to input GeoTIFF files.
            output_dir (str): Directory where output GeoTIFFs will be saved.
            text_prompt (str): Text description of what to segment.
            threshold (float): Threshold for binary segmentation. Defaults to 0.5.
            smoothing_sigma (float): Sigma value for Gaussian smoothing to reduce blockiness. Defaults to 1.0.
            suffix (str): Suffix to add to output filenames. Defaults to "_segmented".

        Returns:
            list: Paths to all saved output files.
        """
        # Create the output directory if it doesn't exist.
        os.makedirs(output_dir, exist_ok=True)

        output_paths = []

        for input_path in tqdm(input_paths, desc="Processing files"):
            # Derive the output name from the input name plus the suffix.
            base_name, ext = os.path.splitext(os.path.basename(input_path))
            output_path = os.path.join(output_dir, f"{base_name}{suffix}{ext}")

            result_path = self.segment_image(
                input_path, output_path, text_prompt, threshold, smoothing_sigma
            )
            output_paths.append(result_path)

        return output_paths
geoai/segmentation.py CHANGED
@@ -1,18 +1,19 @@
1
1
  import os
2
+
3
+ import albumentations as A
4
+ import matplotlib.pyplot as plt
2
5
  import numpy as np
3
- from PIL import Image
4
6
  import torch
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import Dataset, Subset
7
7
  import torch.nn.functional as F
8
- from sklearn.model_selection import train_test_split
9
- import albumentations as A
10
8
  from albumentations.pytorch import ToTensorV2
9
+ from PIL import Image
10
+ from sklearn.model_selection import train_test_split
11
+ from torch.utils.data import Dataset, Subset
11
12
  from transformers import (
13
+ DefaultDataCollator,
14
+ SegformerForSemanticSegmentation,
12
15
  Trainer,
13
16
  TrainingArguments,
14
- SegformerForSemanticSegmentation,
15
- DefaultDataCollator,
16
17
  )
17
18
 
18
19