geoai-py 0.3.6__py2.py3-none-any.whl → 0.4.1__py2.py3-none-any.whl

geoai/hf.py ADDED
@@ -0,0 +1,447 @@
+ """This module contains utility functions for working with Hugging Face models."""
+
+ import csv
+ import os
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ import rasterio
+ from PIL import Image
+ from tqdm import tqdm
+ from transformers import AutoConfig, AutoModelForMaskedImageModeling, pipeline
+
+
+ def get_model_config(model_id):
+     """
+     Get the model configuration for a Hugging Face model.
+
+     Args:
+         model_id (str): The Hugging Face model ID.
+
+     Returns:
+         transformers.configuration_utils.PretrainedConfig: The model configuration.
+     """
+     return AutoConfig.from_pretrained(model_id)
+
+
+ def get_model_input_channels(model_id):
+     """
+     Check the number of input channels supported by a Hugging Face model.
+
+     Args:
+         model_id (str): The Hugging Face model ID.
+
+     Returns:
+         int: The number of input channels the model accepts. Falls back to 3
+             (standard RGB) if the architecture cannot be inspected.
+     """
+     # Load the model configuration
+     config = AutoConfig.from_pretrained(model_id)
+
+     # For Mask2Former models, the channel count lives on the backbone config
+     if hasattr(config, "backbone_config"):
+         if hasattr(config.backbone_config, "num_channels"):
+             return config.backbone_config.num_channels
+
+     # Otherwise, try to load the model and inspect its architecture
+     try:
+         model = AutoModelForMaskedImageModeling.from_pretrained(model_id)
+
+         # Swin Transformer-based models (e.g. Mask2Former) record the expected
+         # channel count on their patch embeddings
+         if hasattr(model, "backbone") and hasattr(model.backbone, "embeddings"):
+             if hasattr(model.backbone.embeddings, "patch_embeddings"):
+                 return model.backbone.embeddings.patch_embeddings.in_channels
+     except Exception as e:
+         print(f"Couldn't inspect model architecture: {e}")
+
+     # Default for most vision models
+     return 3
+
+
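A quick usage sketch for the two helpers above (the model ID below is only an example; loading it requires access to the Hugging Face Hub):

    from geoai.hf import get_model_config, get_model_input_channels

    model_id = "facebook/mask2former-swin-large-cityscapes-semantic"  # example ID
    config = get_model_config(model_id)
    print(type(config).__name__)               # concrete PretrainedConfig subclass
    print(get_model_input_channels(model_id))  # typically 3 for RGB models
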
+ def image_segmentation(
+     tif_path,
+     output_path,
+     labels_to_extract=None,
+     dtype="uint8",
+     model_name=None,
+     segmenter_args=None,
+     **kwargs,
+ ):
+     """
+     Segments an image with a Hugging Face segmentation model and saves the results
+     as a single georeferenced image where each class has a unique integer value.
+
+     Args:
+         tif_path (str): Path to the input georeferenced TIF file.
+         output_path (str): Path where the output georeferenced segmentation will be saved.
+         labels_to_extract (list, optional): List of labels to extract. If None, extracts all labels.
+         dtype (str, optional): Data type to use for the output mask. Use "uint8" for up
+             to 255 classes, or a wider type such as "uint16" for more. Defaults to "uint8".
+         model_name (str, optional): Name of the Hugging Face model to use for segmentation,
+             such as "facebook/mask2former-swin-large-cityscapes-semantic". Defaults to None.
+             See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.
+         segmenter_args (dict, optional): Additional arguments to pass to the segmenter.
+             Defaults to None.
+         **kwargs: Additional keyword arguments to pass to the segmentation pipeline.
+
+     Returns:
+         tuple: (Path to saved image, dictionary mapping label names to their assigned values,
+             dictionary mapping label names to confidence scores)
+     """
+     # Load the original georeferenced image to extract metadata
+     with rasterio.open(tif_path) as src:
+         # Save the metadata for later use
+         meta = src.meta.copy()
+         # Get the dimensions
+         height = src.height
+         width = src.width
+
+     # Initialize the segmentation pipeline
+     if model_name is None:
+         model_name = "facebook/mask2former-swin-large-cityscapes-semantic"
+
+     kwargs["task"] = "image-segmentation"
+
+     segmenter = pipeline(model=model_name, **kwargs)
+
+     # Run the segmentation on the GeoTIFF
+     if segmenter_args is None:
+         segmenter_args = {}
+
+     segments = segmenter(tif_path, **segmenter_args)
+
+     # If no specific labels are requested, extract all available ones
+     if labels_to_extract is None:
+         labels_to_extract = [segment["label"] for segment in segments]
+
+     # Create an empty mask to hold all the labels, using the requested dtype
+     # (uint8 supports up to 255 classes; pass a wider dtype for more)
+     combined_mask = np.zeros((height, width), dtype=np.dtype(dtype))
+
+     # Create a dictionary to map labels to values and store scores
+     label_to_value = {}
+     label_to_score = {}
+
+     # Process each segment we want to keep
+     for i, segment in enumerate(
+         [s for s in segments if s["label"] in labels_to_extract]
+     ):
+         # Assign a unique value to each label (starting from 1)
+         value = i + 1
+         label = segment["label"]
+         score = segment["score"]
+
+         label_to_value[label] = value
+         label_to_score[label] = score
+
+         # Convert PIL image to numpy array
+         mask = np.array(segment["mask"])
+
+         # Apply a threshold if it's a probability mask (not binary)
+         if np.issubdtype(mask.dtype, np.floating):
+             mask = (mask > 0.5).astype(np.uint8)
+
+         # Resize if needed to match the original dimensions, using nearest
+         # neighbor so class values are not blended by interpolation
+         if mask.shape != (height, width):
+             mask_img = Image.fromarray(mask)
+             mask_img = mask_img.resize((width, height), Image.NEAREST)
+             mask = np.array(mask_img)
+
+         # Add this class to the combined mask, only overwriting pixels that
+         # aren't already assigned; overlapping segments thus give priority
+         # to earlier segments
+         combined_mask = np.where(
+             (mask > 0) & (combined_mask == 0), value, combined_mask
+         )
+
+     # Update metadata for the output raster
+     meta.update(
+         {
+             "count": 1,  # One band for the mask
+             "dtype": dtype,  # uint8 supports up to 255 classes
+             "nodata": 0,  # 0 represents no class
+         }
+     )
+
+     # Save the mask as a new georeferenced GeoTIFF
+     with rasterio.open(output_path, "w", **meta) as dst:
+         dst.write(combined_mask[np.newaxis, :, :])  # Add channel dimension
+
+     # Create a CSV colormap file with scores included
+     csv_path = os.path.splitext(output_path)[0] + "_colormap.csv"
+     with open(csv_path, "w", newline="") as csvfile:
+         fieldnames = ["ClassValue", "ClassName", "ConfidenceScore"]
+         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
+
+         writer.writeheader()
+         for label, value in label_to_value.items():
+             writer.writerow(
+                 {
+                     "ClassValue": value,
+                     "ClassName": label,
+                     "ConfidenceScore": f"{label_to_score[label]:.4f}",
+                 }
+             )
+
+     return output_path, label_to_value, label_to_score
+
+
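A minimal usage sketch for image_segmentation, assuming a local georeferenced TIF (the paths and labels below are hypothetical):

    from geoai.hf import image_segmentation

    out_path, label_to_value, label_to_score = image_segmentation(
        tif_path="input.tif",        # hypothetical input
        output_path="segments.tif",  # hypothetical output
        labels_to_extract=["road", "building"],
    )
    print(label_to_value)  # e.g. {"road": 1, "building": 2}
    print(label_to_score)  # per-label confidence scores

A *_colormap.csv mapping class values to labels and confidence scores is written next to the output raster.
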
+ def mask_generation(
+     input_path: str,
+     output_mask_path: str,
+     output_csv_path: str,
+     model: str = "facebook/sam-vit-base",
+     confidence_threshold: float = 0.5,
+     points_per_side: int = 32,
+     crop_size: Optional[int] = None,
+     batch_size: int = 1,
+     band_indices: Optional[List[int]] = None,
+     min_object_size: int = 0,
+     generator_kwargs: Optional[Dict] = None,
+     **kwargs,
+ ) -> Tuple[str, str]:
+     """
+     Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.
+
+     The function reads a GeoTIFF image, applies the SAM mask generator from the
+     Hugging Face transformers pipeline, rasterizes the resulting masks to create
+     a labeled mask GeoTIFF, and saves mask scores and pixel areas to a CSV file.
+
+     Args:
+         input_path: Path to the input GeoTIFF image.
+         output_mask_path: Path where the output mask GeoTIFF will be saved.
+         output_csv_path: Path where the mask scores CSV will be saved.
+         model: Hugging Face model checkpoint for the SAM model.
+         confidence_threshold: Minimum confidence score for masks to be included.
+         points_per_side: Number of points to sample along each side of the image.
+         crop_size: Size of image crops for processing. If None, process the full image.
+         batch_size: Batch size for inference.
+         band_indices: List of band indices to use. If None, use all bands.
+         min_object_size: Minimum size in pixels for objects to be included. Smaller
+             masks will be filtered out.
+         generator_kwargs: Additional keyword arguments to pass to the mask generator.
+         **kwargs: Additional keyword arguments to pass to the mask-generation pipeline.
+
+     Returns:
+         Tuple containing the paths to the saved mask GeoTIFF and CSV file.
+
+     Raises:
+         ValueError: If the input file cannot be opened or processed.
+         RuntimeError: If mask generation fails.
+     """
+     # Set up the mask generator
+     print("Setting up mask generator...")
+     mask_generator = pipeline(model=model, task="mask-generation", **kwargs)
+
+     # Open the GeoTIFF file
+     try:
+         print(f"Reading input GeoTIFF: {input_path}")
+         with rasterio.open(input_path) as src:
+             # Read metadata
+             profile = src.profile
+
+             # Read the image data (band_indices are zero-based; rasterio
+             # band indexing is one-based)
+             if band_indices is not None:
+                 print(f"Using specified bands: {band_indices}")
+                 image_data = np.stack([src.read(i + 1) for i in band_indices])
+             else:
+                 print("Using all bands")
+                 image_data = src.read()
+
+             # Handle an image with more than 3 bands (keep the first 3 as RGB)
+             if image_data.shape[0] > 3:
+                 print(
+                     f"Converting {image_data.shape[0]} bands to RGB (using first 3 bands)"
+                 )
+                 image_data = image_data[:3]
+             elif image_data.shape[0] == 1:
+                 # Duplicate the single band to create a 3-band image
+                 print("Duplicating single band to create 3-band image")
+                 image_data = np.vstack([image_data] * 3)
+
+             # Transpose to HWC format for the model
+             image_data = np.transpose(image_data, (1, 2, 0))
+
+             # Normalize the image to uint8 if needed
+             if image_data.dtype != np.uint8:
+                 print(f"Normalizing image from {image_data.dtype} to uint8")
+                 image_data = (image_data / image_data.max() * 255).astype(np.uint8)
+     except Exception as e:
+         raise ValueError(f"Failed to open or process input GeoTIFF: {e}") from e
+
+     # Process the image with the mask generator
+     try:
+         # Create a PIL Image from the numpy array (already HWC and uint8)
+         print("Converting to PIL Image for mask generation")
+         pil_image = Image.fromarray(image_data)
+
+         # Use the SAM pipeline for mask generation
+         if generator_kwargs is None:
+             generator_kwargs = {}
+
+         print("Running mask generation...")
+         mask_results = mask_generator(
+             pil_image,
+             points_per_side=points_per_side,
+             crop_n_points_downscale_factor=1 if crop_size is None else 2,
+             point_grids=None,
+             pred_iou_thresh=confidence_threshold,
+             stability_score_thresh=confidence_threshold,
+             crops_n_layers=0 if crop_size is None else 1,
+             crop_overlap_ratio=0.5,
+             batch_size=batch_size,
+             **generator_kwargs,
+         )
+
+         num_initial_masks = (
+             len(mask_results["masks"])
+             if isinstance(mask_results, dict) and "masks" in mask_results
+             else len(mask_results)
+         )
+         print(f"Number of initial masks: {num_initial_masks}")
+
+     except Exception as e:
+         raise RuntimeError(f"Mask generation failed: {e}") from e
+
+     # Create a mask raster with unique IDs for each mask
+     mask_raster = np.zeros((image_data.shape[0], image_data.shape[1]), dtype=np.uint32)
+     mask_records = []
+
+     # Process each mask based on the structure of mask_results
+     if (
+         isinstance(mask_results, dict)
+         and "masks" in mask_results
+         and "scores" in mask_results
+     ):
+         # Handle a dictionary with 'masks' and 'scores' lists
+         print("Processing masks...")
+         total_masks = len(mask_results["masks"])
+
+         # Iterate with a progress bar
+         for i, (mask_data, score) in enumerate(
+             tqdm(
+                 zip(mask_results["masks"], mask_results["scores"]),
+                 total=total_masks,
+                 desc="Processing masks",
+             )
+         ):
+             mask_id = i + 1  # Start IDs at 1
+
+             # Convert to numpy if not already (e.g. from a torch tensor)
+             if not isinstance(mask_data, np.ndarray):
+                 try:
+                     mask_data = np.array(mask_data)
+                 except Exception:
+                     print(f"Could not convert mask at index {i} to numpy array")
+                     continue
+
+             mask_binary = mask_data.astype(bool)
+             area_pixels = np.sum(mask_binary)
+
+             # Skip if the mask is smaller than the minimum size
+             if area_pixels < min_object_size:
+                 continue
+
+             # Add the mask to the raster with a unique ID
+             mask_raster[mask_binary] = mask_id
+
+             # Create a record for the CSV (no geometry calculation)
+             mask_records.append(
+                 {"mask_id": mask_id, "score": float(score), "area_pixels": area_pixels}
+             )
+     elif isinstance(mask_results, list):
+         # Handle a list of dictionaries (the original SAM output format)
+         print("Processing masks...")
+
+         # Iterate with a progress bar
+         for i, mask_result in enumerate(tqdm(mask_results, desc="Processing masks")):
+             mask_id = i + 1  # Start IDs at 1
+
+             # Try different possible key names for masks and scores
+             mask_data = None
+             score = None
+
+             if isinstance(mask_result, dict):
+                 # Try to find the mask data
+                 if "segmentation" in mask_result:
+                     mask_data = mask_result["segmentation"]
+                 elif "mask" in mask_result:
+                     mask_data = mask_result["mask"]
+
+                 # Try to find the score
+                 if "score" in mask_result:
+                     score = mask_result["score"]
+                 elif "predicted_iou" in mask_result:
+                     score = mask_result["predicted_iou"]
+                 elif "stability_score" in mask_result:
+                     score = mask_result["stability_score"]
+                 else:
+                     score = 1.0  # Default score if none found
+             else:
+                 # If mask_result is not a dict, it might be the mask directly
+                 try:
+                     mask_data = np.array(mask_result)
+                     score = 1.0  # Default score
+                 except Exception:
+                     print(f"Could not process mask at index {i}")
+                     continue
+
+             if mask_data is not None:
+                 # Convert to numpy if not already
+                 if not isinstance(mask_data, np.ndarray):
+                     try:
+                         mask_data = np.array(mask_data)
+                     except Exception:
+                         print(f"Could not convert mask at index {i} to numpy array")
+                         continue
+
+                 mask_binary = mask_data.astype(bool)
+                 area_pixels = np.sum(mask_binary)
+
+                 # Skip if the mask is smaller than the minimum size
+                 if area_pixels < min_object_size:
+                     continue
+
+                 # Add the mask to the raster with a unique ID
+                 mask_raster[mask_binary] = mask_id
+
+                 # Create a record for the CSV (no geometry calculation)
+                 mask_records.append(
+                     {
+                         "mask_id": mask_id,
+                         "score": float(score),
+                         "area_pixels": area_pixels,
+                     }
+                 )
+     else:
+         # If we couldn't determine the format, raise an error
+         raise ValueError(f"Unexpected format for mask_results: {type(mask_results)}")
+
+     print(f"Number of final masks (after size filtering): {len(mask_records)}")
+
+     # Save the mask raster as a GeoTIFF
+     print(f"Saving mask GeoTIFF to {output_mask_path}")
+     output_profile = profile.copy()
+     output_profile.update(dtype=rasterio.uint32, count=1, compress="lzw", nodata=0)
+
+     with rasterio.open(output_mask_path, "w", **output_profile) as dst:
+         dst.write(mask_raster.astype(rasterio.uint32), 1)
+
+     # Save the mask data as a CSV
+     print(f"Saving mask metadata to {output_csv_path}")
+     mask_df = pd.DataFrame(mask_records)
+     mask_df.to_csv(output_csv_path, index=False)
+
+     print("Processing complete!")
+     return output_mask_path, output_csv_path
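A minimal usage sketch for mask_generation, assuming a local GeoTIFF (the paths below are hypothetical; extra keyword arguments such as device are forwarded to the transformers pipeline):

    from geoai.hf import mask_generation

    mask_path, csv_path = mask_generation(
        input_path="scene.tif",            # hypothetical input
        output_mask_path="masks.tif",
        output_csv_path="mask_scores.csv",
        min_object_size=100,               # drop masks smaller than 100 pixels
    )
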
geoai/segment.py CHANGED
@@ -1,13 +1,14 @@
  """This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
 
  import os
+
  import numpy as np
+ import rasterio
  import torch
- from tqdm import tqdm
  from PIL import Image
- import rasterio
  from rasterio.windows import Window
- from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from tqdm import tqdm
+ from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
 
 
  class CLIPSegmentation:
geoai/segmentation.py CHANGED
@@ -1,18 +1,19 @@
  import os
+
+ import albumentations as A
+ import matplotlib.pyplot as plt
  import numpy as np
- from PIL import Image
  import torch
- import matplotlib.pyplot as plt
- from torch.utils.data import Dataset, Subset
  import torch.nn.functional as F
- from sklearn.model_selection import train_test_split
- import albumentations as A
  from albumentations.pytorch import ToTensorV2
+ from PIL import Image
+ from sklearn.model_selection import train_test_split
+ from torch.utils.data import Dataset, Subset
  from transformers import (
+     DefaultDataCollator,
+     SegformerForSemanticSegmentation,
      Trainer,
      TrainingArguments,
-     SegformerForSemanticSegmentation,
-     DefaultDataCollator,
  )
 