geoai-py 0.3.6__py2.py3-none-any.whl → 0.4.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +76 -14
- geoai/download.py +9 -8
- geoai/extract.py +65 -24
- geoai/geoai.py +3 -1
- geoai/hf.py +447 -0
- geoai/segment.py +4 -3
- geoai/segmentation.py +8 -7
- geoai/train.py +1039 -0
- geoai/utils.py +32 -28
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/METADATA +3 -8
- geoai_py-0.4.1.dist-info/RECORD +15 -0
- geoai_py-0.3.6.dist-info/RECORD +0 -13
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.1.dist-info}/top_level.txt +0 -0
geoai/hf.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
"""This module contains utility functions for working with Hugging Face models."""
|
|
2
|
+
|
|
3
|
+
import csv
|
|
4
|
+
import os
|
|
5
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import rasterio
|
|
10
|
+
from PIL import Image
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
from transformers import AutoConfig, AutoModelForMaskedImageModeling, pipeline
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_model_config(model_id):
    """
    Fetch the configuration object of a Hugging Face model.

    The configuration is resolved through ``AutoConfig.from_pretrained``, which
    picks the concrete config class matching ``model_id``.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        transformers.configuration_utils.PretrainedConfig: The model configuration.
    """
    load_config = AutoConfig.from_pretrained
    model_config = load_config(model_id)
    return model_config
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_model_input_channels(model_id):
    """
    Check the number of input channels supported by a Hugging Face model.

    Sources are tried from cheapest to most expensive:

    1. ``num_channels`` on ``config.backbone_config`` (composite models such as
       Mask2Former).
    2. ``num_channels`` on the top-level configuration (plain vision models).
    3. Loading the model with ``AutoModelForMaskedImageModeling`` and reading the
       patch-embedding layer of its ``backbone``.

    If none of these yields a value, 3 (RGB) is returned.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        int: The number of input channels the model accepts, or 3 when it cannot
        be determined.

    Note:
        Errors raised while loading the configuration itself propagate to the
        caller; only failures during the model-inspection step are caught.
    """
    # Load the model configuration
    config = AutoConfig.from_pretrained(model_id)

    # Composite models (e.g. Mask2Former) keep the input channels on the backbone.
    backbone_config = getattr(config, "backbone_config", None)
    channels = getattr(backbone_config, "num_channels", None)
    if channels is not None:
        return int(channels)

    # Plain vision configs expose the channel count directly.
    channels = getattr(config, "num_channels", None)
    if channels is not None:
        return int(channels)

    # Fall back to loading the model and inspecting its architecture.
    try:
        model = AutoModelForMaskedImageModeling.from_pretrained(model_id)

        backbone = getattr(model, "backbone", None)
        embeddings = getattr(backbone, "embeddings", None)
        patch_embeddings = getattr(embeddings, "patch_embeddings", None)
        if patch_embeddings is not None:
            # Patch-embedding modules may record the count as ``num_channels``,
            # on their input projection conv (``projection.in_channels``), or
            # directly as ``in_channels``; accept whichever is present.
            channels = getattr(patch_embeddings, "num_channels", None)
            if channels is None:
                projection = getattr(patch_embeddings, "projection", None)
                channels = getattr(projection, "in_channels", None)
            if channels is None:
                channels = getattr(patch_embeddings, "in_channels", None)
            if channels is not None:
                return int(channels)
    except Exception as e:
        print(f"Couldn't inspect model architecture: {e}")

    # Default for most vision models
    return 3
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def image_segmentation(
    tif_path,
    output_path,
    labels_to_extract=None,
    dtype="uint8",
    model_name=None,
    segmenter_args=None,
    **kwargs,
):
    """
    Segments an image with a Hugging Face segmentation model and saves the results
    as a single georeferenced image where each class has a unique integer value.

    Each distinct label gets one integer value (1, 2, ... in order of first
    appearance). Segments that share a label, such as several instances from a
    panoptic or instance model, are written with that same value. The highest
    reported score is kept for the label. Where segments overlap, the earlier
    segment keeps the pixel.

    A CSV colormap named ``<output stem>_colormap.csv`` is written next to the
    output raster. It lists the value, label and score of each class.

    Args:
        tif_path (str): Path to the input georeferenced TIF file.
        output_path (str): Path where the output georeferenced segmentation will be saved.
        labels_to_extract (list or str, optional): Label(s) to extract. If None,
            extracts all labels. A single string is treated as a one-item list.
        dtype (str, optional): Data type of the output mask. It is used for both
            the in-memory array and the GeoTIFF. Use e.g. "uint16" for more than
            255 classes. Defaults to "uint8".
        model_name (str, optional): Name of the Hugging Face model to use for
            segmentation, such as
            "facebook/mask2former-swin-large-cityscapes-semantic". If None, that
            model is used.
            See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.
        segmenter_args (dict, optional): Additional arguments to pass when calling
            the segmenter. Defaults to None.
        **kwargs: Additional keyword arguments to pass to ``pipeline()``. The
            ``task`` key is always set to "image-segmentation".

    Returns:
        tuple: The elements are:

        - Path to the saved image.
        - Dictionary mapping label names to their assigned values.
        - Dictionary mapping label names to confidence scores. A score is None
          when the pipeline did not report one.

    Raises:
        ValueError: If more labels are extracted than an integer ``dtype`` can
            represent.
    """
    # Only metadata and the raster size are needed from the input here.
    with rasterio.open(tif_path) as src:
        meta = src.meta.copy()
        height = src.height
        width = src.width

    # Initialize the segmentation pipeline
    if model_name is None:
        model_name = "facebook/mask2former-swin-large-cityscapes-semantic"

    kwargs["task"] = "image-segmentation"
    segmenter = pipeline(model=model_name, **kwargs)

    # Run the segmentation on the GeoTIFF
    if segmenter_args is None:
        segmenter_args = {}
    segments = segmenter(tif_path, **segmenter_args)

    # If no specific labels are requested, extract all available ones
    if labels_to_extract is None:
        labels_to_extract = [segment["label"] for segment in segments]
    elif isinstance(labels_to_extract, str):
        # Avoid substring matching against a bare string.
        labels_to_extract = [labels_to_extract]
    wanted_labels = set(labels_to_extract)

    # The working mask uses the requested dtype so class values are not
    # truncated before being written.
    combined_mask = np.zeros((height, width), dtype=dtype)
    if np.issubdtype(combined_mask.dtype, np.integer):
        max_class_value = np.iinfo(combined_mask.dtype).max
    else:
        max_class_value = None

    # Map each label to one value and keep its best score
    label_to_value = {}
    label_to_score = {}

    for segment in segments:
        label = segment["label"]
        if label not in wanted_labels:
            continue

        if label not in label_to_value:
            value = len(label_to_value) + 1  # 0 is reserved for "no class"
            if max_class_value is not None and value > max_class_value:
                raise ValueError(
                    f"Too many classes ({value}) for dtype '{dtype}'; "
                    "use a wider dtype such as 'uint16'."
                )
            label_to_value[label] = value
        value = label_to_value[label]

        # Pipelines may report no score (None); keep the highest one seen.
        score = segment.get("score")
        previous_score = label_to_score.get(label)
        if previous_score is None or (score is not None and score > previous_score):
            label_to_score[label] = score

        # Convert PIL image to numpy array
        mask = np.asarray(segment["mask"])

        # Apply a threshold if it's a probability mask (any float width)
        if np.issubdtype(mask.dtype, np.floating):
            mask = (mask > 0.5).astype(np.uint8)

        # Resize if needed to match original dimensions. Nearest-neighbour keeps
        # the mask crisp; smoothing filters would add interpolated edge pixels
        # that the > 0 test below would count as part of the class.
        if mask.shape != (height, width):
            mask_img = Image.fromarray(mask).resize((width, height), Image.NEAREST)
            mask = np.array(mask_img)

        # Only fill pixels not already claimed, so earlier segments win overlaps.
        combined_mask[(mask > 0) & (combined_mask == 0)] = value

    # Update metadata for the output raster
    meta.update(
        {
            "count": 1,  # One band for the mask
            "dtype": dtype,
            "nodata": 0,  # 0 represents no class
        }
    )

    # Save the mask as a new georeferenced GeoTIFF
    with rasterio.open(output_path, "w", **meta) as dst:
        dst.write(combined_mask[np.newaxis, :, :])  # Add channel dimension

    # Create a CSV colormap file with scores included
    csv_path = os.path.splitext(output_path)[0] + "_colormap.csv"
    with open(csv_path, "w", newline="") as csvfile:
        fieldnames = ["ClassValue", "ClassName", "ConfidenceScore"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for label, value in label_to_value.items():
            score = label_to_score.get(label)
            writer.writerow(
                {
                    "ClassValue": value,
                    "ClassName": label,
                    # Leave the cell blank when no score was reported.
                    "ConfidenceScore": "" if score is None else f"{score:.4f}",
                }
            )

    return output_path, label_to_value, label_to_score
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _scale_to_uint8(image_data):
    """
    Scale a numeric array to uint8 by dividing by its maximum value.

    Computes ``value / max * 255``. NaN and infinite values are treated as 0, and
    negative results are clipped to 0. An array whose maximum is not positive
    (including an empty array) becomes all zeros instead of triggering a
    division by zero.

    Args:
        image_data (numpy.ndarray): Array of any numeric dtype.

    Returns:
        numpy.ndarray: Array of the same shape with dtype uint8.
    """
    data = np.nan_to_num(
        np.asarray(image_data, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0
    )
    max_value = data.max() if data.size else 0.0
    if max_value > 0:
        data = data / max_value * 255
    return np.clip(data, 0, 255).astype(np.uint8)


def _read_rgb_image(input_path, band_indices):
    """
    Read a GeoTIFF into an (H, W, C) uint8 array suitable for a PIL image.

    Args:
        input_path (str): Path to the input GeoTIFF.
        band_indices (list of int or None): Zero-based band indices to read, or
            None to read every band.

    Returns:
        tuple: ``(image_data, profile)``. ``image_data`` is the (H, W, C) uint8
        array and ``profile`` is the rasterio profile of the source file.
    """
    print(f"Reading input GeoTIFF: {input_path}")
    with rasterio.open(input_path) as src:
        profile = src.profile
        if band_indices is not None:
            print(f"Using specified bands: {band_indices}")
            # band_indices are 0-based; rasterio numbers bands from 1.
            image_data = np.stack([src.read(i + 1) for i in band_indices])
        else:
            print("Using all bands")
            image_data = src.read()

    # Reduce to three channels for an RGB image: keep the first three bands, or
    # replicate a single band. Two-band input is passed through unchanged.
    if image_data.shape[0] > 3:
        print(
            f"Converting {image_data.shape[0]} bands to RGB (using first 3 bands)"
        )
        image_data = image_data[:3]
    elif image_data.shape[0] == 1:
        print("Duplicating single band to create 3-band image")
        image_data = np.vstack([image_data] * 3)

    # Band-first (C, H, W) -> HWC for PIL / the model.
    image_data = np.transpose(image_data, (1, 2, 0))

    if image_data.dtype != np.uint8:
        print(f"Normalizing image from {image_data.dtype} to uint8")
        image_data = _scale_to_uint8(image_data)

    return image_data, profile


def _mask_entries(mask_results):
    """
    Normalize mask-generator output into a list of ``(mask_data, score)`` pairs.

    Two layouts are accepted:

    - A dict with parallel ``"masks"`` and ``"scores"`` sequences.
    - A list. Each dict item supplies its mask under ``"segmentation"`` or
      ``"mask"``. Its score comes from the first present key among ``"score"``,
      ``"predicted_iou"`` and ``"stability_score"``, defaulting to 1.0.
      Non-dict items are taken as raw masks with score 1.0.

    Dict items without mask data become ``None``. This keeps positions, and
    therefore mask IDs, aligned with the generator output.

    Args:
        mask_results: Output of the mask-generation pipeline.

    Returns:
        list: ``(mask_data, score)`` tuples or ``None`` placeholders.

    Raises:
        ValueError: If ``mask_results`` matches neither layout.
    """
    if (
        isinstance(mask_results, dict)
        and "masks" in mask_results
        and "scores" in mask_results
    ):
        return list(zip(mask_results["masks"], mask_results["scores"]))

    if not isinstance(mask_results, list):
        # If we couldn't figure out the format, raise an error
        raise ValueError(f"Unexpected format for mask_results: {type(mask_results)}")

    entries = []
    for item in mask_results:
        if not isinstance(item, dict):
            # The item might be the mask itself.
            entries.append((item, 1.0))
            continue
        mask_data = next(
            (item[key] for key in ("segmentation", "mask") if key in item), None
        )
        score = next(
            (
                item[key]
                for key in ("score", "predicted_iou", "stability_score")
                if key in item
            ),
            1.0,
        )
        entries.append(None if mask_data is None else (mask_data, score))
    return entries


def mask_generation(
    input_path: str,
    output_mask_path: str,
    output_csv_path: str,
    model: str = "facebook/sam-vit-base",
    confidence_threshold: float = 0.5,
    points_per_side: int = 32,
    crop_size: Optional[int] = None,
    batch_size: int = 1,
    band_indices: Optional[List[int]] = None,
    min_object_size: int = 0,
    generator_kwargs: Optional[Dict] = None,
    **kwargs,
) -> Tuple[str, str]:
    """
    Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.

    The function performs these steps:

    1. Read the GeoTIFF image.
    2. Apply the SAM mask generator from the Hugging Face transformers pipeline.
    3. Rasterize the resulting masks into a labeled uint32 mask GeoTIFF.
    4. Save the mask scores and pixel areas to a CSV file.

    Mask IDs are the 1-based position of each mask in the generator output, so
    skipped masks leave gaps in the ID sequence. Where masks overlap, later masks
    overwrite earlier ones in the raster. The CSV ``area_pixels`` value is the
    size of each mask before any overwriting.

    Args:
        input_path: Path to the input GeoTIFF image.
        output_mask_path: Path where the output mask GeoTIFF will be saved.
        output_csv_path: Path where the mask scores CSV will be saved.
        model: HuggingFace model checkpoint for the SAM model.
        confidence_threshold: Minimum confidence score for masks to be included.
            It is used as both the predicted-IoU and the stability-score threshold.
        points_per_side: Number of points to sample along each side of the image.
            It is passed to the pipeline as ``points_per_crop``.
        crop_size: Only compared against None. Any non-None value enables one
            extra crop layer (``crops_n_layers=1``, downscale factor 2). The
            number itself is not used as a pixel size.
        batch_size: Batch size for inference.
        band_indices: List of zero-based band indices to use. If None, use all bands.
        min_object_size: Minimum size in pixels for objects to be included. Smaller
            masks will be filtered out.
        generator_kwargs: Additional keyword arguments for the mask generator call.
            Keys given here override the defaults derived from the arguments above.
        **kwargs: Additional keyword arguments passed to ``pipeline()``.

    Returns:
        Tuple containing the paths to the saved mask GeoTIFF and CSV file.

    Raises:
        ValueError: If the input file cannot be opened or processed, or if the
            generator output has an unrecognized format.
        RuntimeError: If mask generation fails.
    """
    # Set up the mask generator
    print("Setting up mask generator...")
    mask_generator = pipeline(model=model, task="mask-generation", **kwargs)

    try:
        image_data, profile = _read_rgb_image(input_path, band_indices)
    except Exception as e:
        raise ValueError(f"Failed to open or process input GeoTIFF: {e}") from e

    # Defaults for the generator call; generator_kwargs may override any of them
    # without raising a duplicate-keyword error.
    call_kwargs = {
        # The transformers mask-generation pipeline takes the per-side point
        # grid density as ``points_per_crop`` (``points_per_side`` is not one
        # of its parameters).
        "points_per_crop": points_per_side,
        "crop_n_points_downscale_factor": 1 if crop_size is None else 2,
        "pred_iou_thresh": confidence_threshold,
        "stability_score_thresh": confidence_threshold,
        "crops_n_layers": 0 if crop_size is None else 1,
        "crop_overlap_ratio": 0.5,
        "batch_size": batch_size,
    }
    if generator_kwargs:
        call_kwargs.update(generator_kwargs)

    # Process the image with the mask generator
    try:
        print("Converting to PIL Image for mask generation")
        pil_image = Image.fromarray(image_data)

        print("Running mask generation...")
        mask_results = mask_generator(pil_image, **call_kwargs)

        initial_count = (
            len(mask_results["masks"])
            if isinstance(mask_results, dict) and "masks" in mask_results
            else len(mask_results)
        )
        print(f"Number of initial masks: {initial_count}")
    except Exception as e:
        raise RuntimeError(f"Mask generation failed: {e}") from e

    entries = _mask_entries(mask_results)

    # Create a mask raster with unique IDs for each mask
    mask_raster = np.zeros(image_data.shape[:2], dtype=np.uint32)
    mask_records = []

    print("Processing masks...")
    for i, entry in enumerate(tqdm(entries, desc="Processing masks")):
        if entry is None:
            continue
        mask_data, score = entry
        mask_id = i + 1  # Start IDs at 1

        try:
            mask_binary = np.asarray(mask_data).astype(bool)
        except Exception:
            print(f"Could not convert mask at index {i} to numpy array")
            continue

        area_pixels = int(mask_binary.sum())
        # Skip if mask is smaller than the minimum size
        if area_pixels < min_object_size:
            continue

        mask_raster[mask_binary] = mask_id
        mask_records.append(
            {"mask_id": mask_id, "score": float(score), "area_pixels": area_pixels}
        )

    print(f"Number of final masks (after size filtering): {len(mask_records)}")

    # Save the mask raster as a GeoTIFF
    print(f"Saving mask GeoTIFF to {output_mask_path}")
    output_profile = profile.copy()
    output_profile.update(dtype=rasterio.uint32, count=1, compress="lzw", nodata=0)

    with rasterio.open(output_mask_path, "w", **output_profile) as dst:
        dst.write(mask_raster, 1)

    # Save the mask data as a CSV; explicit columns keep the header even when
    # no mask survives filtering.
    print(f"Saving mask metadata to {output_csv_path}")
    mask_df = pd.DataFrame(mask_records, columns=["mask_id", "score", "area_pixels"])
    mask_df.to_csv(output_csv_path, index=False)

    print("Processing complete!")
    return output_mask_path, output_csv_path
|
geoai/segment.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
"""This module provides functionality for segmenting high-resolution satellite imagery using vision-language models."""
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
+
|
|
4
5
|
import numpy as np
|
|
6
|
+
import rasterio
|
|
5
7
|
import torch
|
|
6
|
-
from tqdm import tqdm
|
|
7
8
|
from PIL import Image
|
|
8
|
-
import rasterio
|
|
9
9
|
from rasterio.windows import Window
|
|
10
|
-
from
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class CLIPSegmentation:
|
geoai/segmentation.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import os
|
|
2
|
+
|
|
3
|
+
import albumentations as A
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
2
5
|
import numpy as np
|
|
3
|
-
from PIL import Image
|
|
4
6
|
import torch
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
from torch.utils.data import Dataset, Subset
|
|
7
7
|
import torch.nn.functional as F
|
|
8
|
-
from sklearn.model_selection import train_test_split
|
|
9
|
-
import albumentations as A
|
|
10
8
|
from albumentations.pytorch import ToTensorV2
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from sklearn.model_selection import train_test_split
|
|
11
|
+
from torch.utils.data import Dataset, Subset
|
|
11
12
|
from transformers import (
|
|
13
|
+
DefaultDataCollator,
|
|
14
|
+
SegformerForSemanticSegmentation,
|
|
12
15
|
Trainer,
|
|
13
16
|
TrainingArguments,
|
|
14
|
-
SegformerForSemanticSegmentation,
|
|
15
|
-
DefaultDataCollator,
|
|
16
17
|
)
|
|
17
18
|
|
|
18
19
|
|