geoai-py 0.3.6__py2.py3-none-any.whl → 0.4.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/download.py +9 -8
- geoai/extract.py +65 -24
- geoai/geoai.py +3 -1
- geoai/hf.py +447 -0
- geoai/segment.py +4 -3
- geoai/segmentation.py +8 -7
- geoai/train.py +1039 -0
- geoai/utils.py +12 -15
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.0.dist-info}/METADATA +1 -1
- geoai_py-0.4.0.dist-info/RECORD +15 -0
- geoai_py-0.3.6.dist-info/RECORD +0 -13
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.0.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.6.dist-info → geoai_py-0.4.0.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/download.py
CHANGED
```diff
@@ -1,18 +1,19 @@
 """This module provides functions to download data, including NAIP imagery and building data from Overture Maps."""
 
+import logging
 import os
-…
-import …
-…
+import subprocess
+from typing import Any, Dict, List, Optional, Tuple
+
+import geopandas as gpd
 import matplotlib.pyplot as plt
-…
+import numpy as np
 import planetary_computer as pc
-import …
+import requests
+import rioxarray
+from pystac_client import Client
 from shapely.geometry import box
 from tqdm import tqdm
-import requests
-import subprocess
-import logging
 
 # Configure logging
 logging.basicConfig(
```

(Lines shown as `…` here and in the hunks below were truncated by the registry's diff viewer and cannot be recovered.)
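The reorganized imports add `pystac_client` and `rioxarray` next to `planetary_computer`, the usual stack for STAC-based imagery access such as NAIP from the Planetary Computer. For orientation, a minimal sketch of that access pattern (illustrative only, not code from this module; the bounding box is an arbitrary example):

```python
import planetary_computer as pc
import rioxarray
from pystac_client import Client

# Open the public Planetary Computer STAC endpoint; sign asset URLs in place.
catalog = Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=pc.sign_inplace,
)
# Search for NAIP items intersecting an example bounding box.
search = catalog.search(collections=["naip"], bbox=[-122.3, 47.5, -122.2, 47.6])
item = next(search.items())
# Read the signed cloud-optimized GeoTIFF as a georeferenced array.
da = rioxarray.open_rasterio(item.assets["image"].href)
```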
geoai/extract.py
CHANGED
```diff
@@ -14,16 +14,15 @@ import torch
 from huggingface_hub import hf_hub_download
 from rasterio.windows import Window
 from shapely.geometry import Polygon, box
-from tqdm import tqdm
 from torchvision.models.detection import (
-    maskrcnn_resnet50_fpn,
     fasterrcnn_resnet50_fpn_v2,
+    maskrcnn_resnet50_fpn,
 )
+from tqdm import tqdm
 
 # Local Imports
 from .utils import get_raster_stats
 
-
 try:
     from torchgeo.datasets import NonGeoDataset
 except ImportError as e:
@@ -270,7 +269,9 @@ class ObjectDetector:
     Object extraction using Mask R-CNN with TorchGeo.
     """
 
-    def __init__(self, model_path=None, repo_id=None, model=None, device=None):
+    def __init__(
+        self, model_path=None, repo_id=None, model=None, num_classes=2, device=None
+    ):
         """
         Initialize the object extractor.
 
@@ -278,6 +279,7 @@ class ObjectDetector:
             model_path: Path to the .pth model file.
             repo_id: Hugging Face repository ID for model download.
             model: Pre-initialized model object (optional).
+            num_classes: Number of classes for detection (default: 2).
             device: Device to use for inference ('cuda:0', 'cpu', etc.).
         """
         # Set device
@@ -297,7 +299,7 @@ class ObjectDetector:
         self.simplify_tolerance = 1.0  # Tolerance for polygon simplification
 
         # Initialize model
-        self.model = self.initialize_model(model)
+        self.model = self.initialize_model(model, num_classes=num_classes)
 
         # Download model if needed
         if model_path is None or (not os.path.exists(model_path)):
@@ -342,11 +344,12 @@ class ObjectDetector:
             print("Please specify a local model path or ensure internet connectivity.")
             raise
 
-    def initialize_model(self, model):
+    def initialize_model(self, model, num_classes=2):
         """Initialize a deep learning model for object detection.
 
         Args:
             model (torch.nn.Module): A pre-initialized model object.
+            num_classes (int): Number of classes for detection.
 
         Returns:
             torch.nn.Module: A deep learning model for object detection.
@@ -361,7 +364,7 @@ class ObjectDetector:
         model = maskrcnn_resnet50_fpn(
             weights=None,
             progress=False,
-            num_classes=2,  # Background + object
+            num_classes=num_classes,  # Background + object
             weights_backbone=None,
             # These parameters ensure consistent normalization
             image_mean=image_mean,
```
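The `num_classes` argument introduced above is threaded from `__init__` through `initialize_model` into the Mask R-CNN constructor, so a detector is no longer hard-wired to background plus one object class. A minimal usage sketch (the checkpoint path and class count are hypothetical placeholders, not values shipped with the package):

```python
from geoai.extract import ObjectDetector

# Hypothetical three-class detector: background plus two object classes.
detector = ObjectDetector(
    model_path="my_checkpoint.pth",  # placeholder path
    num_classes=3,
    device="cpu",
)
```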
```diff
@@ -1306,13 +1309,14 @@ class ObjectDetector:
         Returns:
             GeoDataFrame with regularized objects
         """
+        import math
+
+        import cv2
+        import geopandas as gpd
         import numpy as np
-        from shapely.geometry import Polygon, MultiPolygon, box
         from shapely.affinity import rotate, translate
-        import geopandas as gpd
-        import math
+        from shapely.geometry import MultiPolygon, Polygon, box
         from tqdm import tqdm
-        import cv2
 
         def get_angle(p1, p2, p3):
             """Calculate angle between three points in degrees (0-180)"""
```
```diff
@@ -2112,7 +2116,7 @@ class ObjectDetector:
         output_path=None,
         confidence_threshold=0.5,
         min_object_area=100,
-        …
+        max_object_area=None,
         **kwargs,
     ):
         """
@@ -2123,7 +2127,7 @@ class ObjectDetector:
             output_path: Path for output GeoJSON.
             confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
             min_object_area: Minimum area in pixels to keep an object. Default: 100
-            …
+            max_object_area: Maximum area in pixels to keep an object. Default: None
             **kwargs: Additional parameters
 
         Returns:
@@ -2147,8 +2151,9 @@ class ObjectDetector:
         print(f"Found {num_features} connected components")
 
         # Process each component
-        …
-        …
+        polygons = []
+        confidences = []
+        pixels = []
 
         # Add progress bar
         for label in tqdm(range(1, num_features + 1), desc="Processing components"):
@@ -2179,8 +2184,8 @@ class ObjectDetector:
             if area < min_object_area:
                 continue
 
-            if …
-                if area > …
+            if max_object_area is not None:
+                if area > max_object_area:
                     continue
 
             # Get minimum area rectangle
@@ -2197,16 +2202,18 @@ class ObjectDetector:
                 poly = Polygon(geo_points)
 
                 # Add to lists
-                …
-                …
+                polygons.append(poly)
+                confidences.append(confidence)
+                pixels.append(area)
 
         # Create GeoDataFrame
-        if …
+        if polygons:
             gdf = gpd.GeoDataFrame(
                 {
-                    "geometry": …
-                    "confidence": …
-                    "class": [1] * len(…
+                    "geometry": polygons,
+                    "confidence": confidences,
+                    "class": [1] * len(polygons),
+                    "pixels": pixels,
                 },
                 crs=crs,
             )
@@ -2218,7 +2225,7 @@ class ObjectDetector:
 
             return gdf
         else:
-            print("No valid …
+            print("No valid polygons found")
             return None
 
 
```
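Together these hunks add an optional `max_object_area` ceiling to complement the `min_object_area` floor, and record a per-object `pixels` column in the resulting GeoDataFrame. A sketch of filtering on the new column downstream (file names are hypothetical; the method that produces the GeoDataFrame is not named in the hunks shown):

```python
import geopandas as gpd

gdf = gpd.read_file("objects.geojson")  # hypothetical saved output of the method above
# Keep confident detections within a plausible size range using the new "pixels" column.
subset = gdf[(gdf["confidence"] >= 0.7) & (gdf["pixels"].between(200, 5000))]
subset.to_file("objects_filtered.geojson", driver="GeoJSON")
```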
```diff
@@ -2356,3 +2363,37 @@ class SolarPanelDetector(ObjectDetector):
         super().__init__(
             model_path=model_path, repo_id=repo_id, model=model, device=device
         )
+
+
+class ParkingSplotDetector(ObjectDetector):
+    """
+    Car detection using a pre-trained Mask R-CNN model.
+
+    This class extends the `ObjectDetector` class with additional methods for car detection.
+    """
+
+    def __init__(
+        self,
+        model_path="parking_spot_detection.pth",
+        repo_id=None,
+        model=None,
+        num_classes=3,
+        device=None,
+    ):
+        """
+        Initialize the object extractor.
+
+        Args:
+            model_path: Path to the .pth model file.
+            repo_id: Repo ID for loading models from the Hub.
+            model: Custom model to use for inference.
+            num_classes: Number of classes for the model. Default: 3
+            device: Device to use for inference ('cuda:0', 'cpu', etc.).
+        """
+        super().__init__(
+            model_path=model_path,
+            repo_id=repo_id,
+            model=model,
+            num_classes=num_classes,
+            device=device,
+        )
```
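The new subclass only changes defaults (`model_path="parking_spot_detection.pth"`, `num_classes=3`) and forwards everything else to `ObjectDetector`. A usage sketch (the `repo_id` is a hypothetical placeholder, not a published model):

```python
from geoai.extract import ParkingSplotDetector  # class name as spelled in the package

detector = ParkingSplotDetector(
    repo_id="user/parking-spot-model",  # hypothetical Hugging Face repo ID
    device="cpu",
)
```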
geoai/geoai.py
CHANGED
geoai/hf.py
ADDED
@@ -0,0 +1,447 @@

```python
"""This module contains utility functions for working with Hugging Face models."""

import csv
import os
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import rasterio
from PIL import Image
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForMaskedImageModeling, pipeline


def get_model_config(model_id):
    """
    Get the model configuration for a Hugging Face model.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        transformers.configuration_utils.PretrainedConfig: The model configuration.
    """
    return AutoConfig.from_pretrained(model_id)


def get_model_input_channels(model_id):
    """
    Check the number of input channels supported by a Hugging Face model.

    Args:
        model_id (str): The Hugging Face model ID.

    Returns:
        int: The number of input channels the model accepts.

    Raises:
        ValueError: If unable to determine the number of input channels.
    """
    # Load the model configuration
    config = AutoConfig.from_pretrained(model_id)

    # For Mask2Former models
    if hasattr(config, "backbone_config"):
        if hasattr(config.backbone_config, "num_channels"):
            return config.backbone_config.num_channels

    # Try to load the model and inspect its architecture
    try:
        model = AutoModelForMaskedImageModeling.from_pretrained(model_id)

        # For Swin Transformer-based models like Mask2Former
        if hasattr(model, "backbone") and hasattr(model.backbone, "embeddings"):
            if hasattr(model.backbone.embeddings, "patch_embeddings"):
                # Swin models typically have patch embeddings that indicate channel count
                return model.backbone.embeddings.patch_embeddings.in_channels
    except Exception as e:
        print(f"Couldn't inspect model architecture: {e}")

    # Default for most vision models
    return 3
```
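A quick sketch of calling the two helpers above, using the Mask2Former checkpoint that this module also uses as its segmentation default:

```python
from geoai.hf import get_model_config, get_model_input_channels

model_id = "facebook/mask2former-swin-large-cityscapes-semantic"
config = get_model_config(model_id)
channels = get_model_input_channels(model_id)  # falls back to 3 if undetectable
print(config.model_type, channels)
```

The added file continues with two pipeline wrappers: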
```python
def image_segmentation(
    tif_path,
    output_path,
    labels_to_extract=None,
    dtype="uint8",
    model_name=None,
    segmenter_args=None,
    **kwargs,
):
    """
    Segments an image with a Hugging Face segmentation model and saves the results
    as a single georeferenced image where each class has a unique integer value.

    Args:
        tif_path (str): Path to the input georeferenced TIF file.
        output_path (str): Path where the output georeferenced segmentation will be saved.
        labels_to_extract (list, optional): List of labels to extract. If None, extracts all labels.
        dtype (str, optional): Data type to use for the output mask. Defaults to "uint8".
        model_name (str, optional): Name of the Hugging Face model to use for segmentation,
            such as "facebook/mask2former-swin-large-cityscapes-semantic". Defaults to None.
            See https://huggingface.co/models?pipeline_tag=image-segmentation&sort=trending for options.
        segmenter_args (dict, optional): Additional arguments to pass to the segmenter.
            Defaults to None.
        **kwargs: Additional keyword arguments to pass to the segmentation pipeline

    Returns:
        tuple: (Path to saved image, dictionary mapping label names to their assigned values,
            dictionary mapping label names to confidence scores)
    """
    # Load the original georeferenced image to extract metadata
    with rasterio.open(tif_path) as src:
        # Save the metadata for later use
        meta = src.meta.copy()
        # Get the dimensions
        height = src.height
        width = src.width
        # Get the transform and CRS for georeferencing
        # transform = src.transform
        # crs = src.crs

    # Initialize the segmentation pipeline
    if model_name is None:
        model_name = "facebook/mask2former-swin-large-cityscapes-semantic"

    kwargs["task"] = "image-segmentation"

    segmenter = pipeline(model=model_name, **kwargs)

    # Run the segmentation on the GeoTIFF
    if segmenter_args is None:
        segmenter_args = {}

    segments = segmenter(tif_path, **segmenter_args)

    # If no specific labels are requested, extract all available ones
    if labels_to_extract is None:
        labels_to_extract = [segment["label"] for segment in segments]

    # Create an empty mask to hold all the labels
    # Using uint8 for up to 255 classes, switch to uint16 for more
    combined_mask = np.zeros((height, width), dtype=np.uint8)

    # Create a dictionary to map labels to values and store scores
    label_to_value = {}
    label_to_score = {}

    # Process each segment we want to keep
    for i, segment in enumerate(
        [s for s in segments if s["label"] in labels_to_extract]
    ):
        # Assign a unique value to each label (starting from 1)
        value = i + 1
        label = segment["label"]
        score = segment["score"]

        label_to_value[label] = value
        label_to_score[label] = score

        # Convert PIL image to numpy array
        mask = np.array(segment["mask"])

        # Apply a threshold if it's a probability mask (not binary)
        if mask.dtype == float:
            mask = (mask > 0.5).astype(np.uint8)

        # Resize if needed to match original dimensions
        if mask.shape != (height, width):
            mask_img = Image.fromarray(mask)
            mask_img = mask_img.resize((width, height))
            mask = np.array(mask_img)

        # Add this class to the combined mask
        # Only overwrite if the pixel isn't already assigned to another class
        # This handles overlapping segments by giving priority to earlier segments
        combined_mask = np.where(
            (mask > 0) & (combined_mask == 0), value, combined_mask
        )

    # Update metadata for the output raster
    meta.update(
        {
            "count": 1,  # One band for the mask
            "dtype": dtype,  # Use uint8 for up to 255 classes
            "nodata": 0,  # 0 represents no class
        }
    )

    # Save the mask as a new georeferenced GeoTIFF
    with rasterio.open(output_path, "w", **meta) as dst:
        dst.write(combined_mask[np.newaxis, :, :])  # Add channel dimension

    # Create a CSV colormap file with scores included
    csv_path = os.path.splitext(output_path)[0] + "_colormap.csv"
    with open(csv_path, "w", newline="") as csvfile:
        fieldnames = ["ClassValue", "ClassName", "ConfidenceScore"]
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for label, value in label_to_value.items():
            writer.writerow(
                {
                    "ClassValue": value,
                    "ClassName": label,
                    "ConfidenceScore": f"{label_to_score[label]:.4f}",
                }
            )

    return output_path, label_to_value, label_to_score
```
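A usage sketch for `image_segmentation` as defined above (input/output paths and label names are hypothetical):

```python
from geoai.hf import image_segmentation

out_path, label_values, label_scores = image_segmentation(
    tif_path="naip_scene.tif",               # hypothetical input GeoTIFF
    output_path="naip_segmentation.tif",
    labels_to_extract=["road", "building"],  # hypothetical label names
)
# A "<output>_colormap.csv" sidecar with class values and scores is written alongside.
print(label_values, label_scores)
```

The file closes with the SAM mask-generation wrapper: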
```python
def mask_generation(
    input_path: str,
    output_mask_path: str,
    output_csv_path: str,
    model: str = "facebook/sam-vit-base",
    confidence_threshold: float = 0.5,
    points_per_side: int = 32,
    crop_size: Optional[int] = None,
    batch_size: int = 1,
    band_indices: Optional[List[int]] = None,
    min_object_size: int = 0,
    generator_kwargs: Optional[Dict] = None,
    **kwargs,
) -> Tuple[str, str]:
    """
    Process a GeoTIFF using SAM mask generation and save results as a GeoTIFF and CSV.

    The function reads a GeoTIFF image, applies the SAM mask generator from the
    Hugging Face transformers pipeline, rasterizes the resulting masks to create
    a labeled mask GeoTIFF, and saves mask scores and geometries to a CSV file.

    Args:
        input_path: Path to the input GeoTIFF image.
        output_mask_path: Path where the output mask GeoTIFF will be saved.
        output_csv_path: Path where the mask scores CSV will be saved.
        model: HuggingFace model checkpoint for the SAM model.
        confidence_threshold: Minimum confidence score for masks to be included.
        points_per_side: Number of points to sample along each side of the image.
        crop_size: Size of image crops for processing. If None, process the full image.
        band_indices: List of band indices to use. If None, use all bands.
        batch_size: Batch size for inference.
        min_object_size: Minimum size in pixels for objects to be included. Smaller masks will be filtered out.
        generator_kwargs: Additional keyword arguments to pass to the mask generator.

    Returns:
        Tuple containing the paths to the saved mask GeoTIFF and CSV file.

    Raises:
        ValueError: If the input file cannot be opened or processed.
        RuntimeError: If mask generation fails.
    """
    # Set up the mask generator
    print("Setting up mask generator...")
    mask_generator = pipeline(model=model, task="mask-generation", **kwargs)

    # Open the GeoTIFF file
    try:
        print(f"Reading input GeoTIFF: {input_path}")
        with rasterio.open(input_path) as src:
            # Read metadata
            profile = src.profile
            # transform = src.transform
            # crs = src.crs

            # Read the image data
            if band_indices is not None:
                print(f"Using specified bands: {band_indices}")
                image_data = np.stack([src.read(i + 1) for i in band_indices])
            else:
                print("Using all bands")
                image_data = src.read()

            # Handle image with more than 3 bands (convert to RGB for visualization)
            if image_data.shape[0] > 3:
                print(
                    f"Converting {image_data.shape[0]} bands to RGB (using first 3 bands)"
                )
                # Select first three bands or perform other band combination
                image_data = image_data[:3]
            elif image_data.shape[0] == 1:
                print("Duplicating single band to create 3-band image")
                # Duplicate single band to create a 3-band image
                image_data = np.vstack([image_data] * 3)

            # Transpose to HWC format for the model
            image_data = np.transpose(image_data, (1, 2, 0))

            # Normalize the image if needed
            if image_data.dtype != np.uint8:
                print(f"Normalizing image from {image_data.dtype} to uint8")
                image_data = (image_data / image_data.max() * 255).astype(np.uint8)
    except Exception as e:
        raise ValueError(f"Failed to open or process input GeoTIFF: {e}")

    # Process the image with the mask generator
    try:
        # Convert numpy array to PIL Image for the pipeline
        # Ensure the array is in the right format (HWC and uint8)
        if image_data.dtype != np.uint8:
            image_data = (image_data / image_data.max() * 255).astype(np.uint8)

        # Create a PIL Image from the numpy array
        print("Converting to PIL Image for mask generation")
        pil_image = Image.fromarray(image_data)

        # Use the SAM pipeline for mask generation
        if generator_kwargs is None:
            generator_kwargs = {}

        print("Running mask generation...")
        mask_results = mask_generator(
            pil_image,
            points_per_side=points_per_side,
            crop_n_points_downscale_factor=1 if crop_size is None else 2,
            point_grids=None,
            pred_iou_thresh=confidence_threshold,
            stability_score_thresh=confidence_threshold,
            crops_n_layers=0 if crop_size is None else 1,
            crop_overlap_ratio=0.5,
            batch_size=batch_size,
            **generator_kwargs,
        )

        print(
            f"Number of initial masks: {len(mask_results['masks']) if isinstance(mask_results, dict) and 'masks' in mask_results else len(mask_results)}"
        )

    except Exception as e:
        raise RuntimeError(f"Mask generation failed: {e}")

    # Create a mask raster with unique IDs for each mask
    mask_raster = np.zeros((image_data.shape[0], image_data.shape[1]), dtype=np.uint32)
    mask_records = []

    # Process each mask based on the structure of mask_results
    if (
        isinstance(mask_results, dict)
        and "masks" in mask_results
        and "scores" in mask_results
    ):
        # Handle dictionary with 'masks' and 'scores' lists
        print("Processing masks...")
        total_masks = len(mask_results["masks"])

        # Create progress bar
        for i, (mask_data, score) in enumerate(
            tqdm(
                zip(mask_results["masks"], mask_results["scores"]),
                total=total_masks,
                desc="Processing masks",
            )
        ):
            mask_id = i + 1  # Start IDs at 1

            # Convert to numpy if not already
            if not isinstance(mask_data, np.ndarray):
                # Try to convert from tensor or other format if needed
                try:
                    mask_data = np.array(mask_data)
                except:
                    print(f"Could not convert mask at index {i} to numpy array")
                    continue

            mask_binary = mask_data.astype(bool)
            area_pixels = np.sum(mask_binary)

            # Skip if mask is smaller than the minimum size
            if area_pixels < min_object_size:
                continue

            # Add the mask to the raster with a unique ID
            mask_raster[mask_binary] = mask_id

            # Create a record for the CSV - without geometry calculation
            mask_records.append(
                {"mask_id": mask_id, "score": float(score), "area_pixels": area_pixels}
            )
    elif isinstance(mask_results, list):
        # Handle list of dictionaries format (SAM original format)
        print("Processing masks...")
        total_masks = len(mask_results)

        # Create progress bar
        for i, mask_result in enumerate(tqdm(mask_results, desc="Processing masks")):
            mask_id = i + 1  # Start IDs at 1

            # Try different possible key names for masks and scores
            mask_data = None
            score = None

            if isinstance(mask_result, dict):
                # Try to find mask data
                if "segmentation" in mask_result:
                    mask_data = mask_result["segmentation"]
                elif "mask" in mask_result:
                    mask_data = mask_result["mask"]

                # Try to find score
                if "score" in mask_result:
                    score = mask_result["score"]
                elif "predicted_iou" in mask_result:
                    score = mask_result["predicted_iou"]
                elif "stability_score" in mask_result:
                    score = mask_result["stability_score"]
                else:
                    score = 1.0  # Default score if none found
            else:
                # If mask_result is not a dict, it might be the mask directly
                try:
                    mask_data = np.array(mask_result)
                    score = 1.0  # Default score
                except:
                    print(f"Could not process mask at index {i}")
                    continue

            if mask_data is not None:
                # Convert to numpy if not already
                if not isinstance(mask_data, np.ndarray):
                    try:
                        mask_data = np.array(mask_data)
                    except:
                        print(f"Could not convert mask at index {i} to numpy array")
                        continue

                mask_binary = mask_data.astype(bool)
                area_pixels = np.sum(mask_binary)

                # Skip if mask is smaller than the minimum size
                if area_pixels < min_object_size:
                    continue

                # Add the mask to the raster with a unique ID
                mask_raster[mask_binary] = mask_id

                # Create a record for the CSV - without geometry calculation
                mask_records.append(
                    {
                        "mask_id": mask_id,
                        "score": float(score),
                        "area_pixels": area_pixels,
                    }
                )
    else:
        # If we couldn't figure out the format, raise an error
        raise ValueError(f"Unexpected format for mask_results: {type(mask_results)}")

    print(f"Number of final masks (after size filtering): {len(mask_records)}")

    # Save the mask raster as a GeoTIFF
    print(f"Saving mask GeoTIFF to {output_mask_path}")
    output_profile = profile.copy()
    output_profile.update(dtype=rasterio.uint32, count=1, compress="lzw", nodata=0)

    with rasterio.open(output_mask_path, "w", **output_profile) as dst:
        dst.write(mask_raster.astype(rasterio.uint32), 1)

    # Save the mask data as a CSV
    print(f"Saving mask metadata to {output_csv_path}")
    mask_df = pd.DataFrame(mask_records)
    mask_df.to_csv(output_csv_path, index=False)

    print("Processing complete!")
    return output_mask_path, output_csv_path
```
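And a matching sketch for `mask_generation` (paths are hypothetical placeholders):

```python
from geoai.hf import mask_generation

mask_path, csv_path = mask_generation(
    input_path="naip_scene.tif",      # hypothetical input GeoTIFF
    output_mask_path="sam_masks.tif",
    output_csv_path="sam_masks.csv",
    min_object_size=100,              # drop masks smaller than 100 pixels
)
```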