geoai-py 0.1.7__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +28 -1
- geoai/common.py +158 -1
- geoai/download.py +9 -0
- geoai/extract.py +832 -0
- geoai/preprocess.py +2008 -0
- geoai_py-0.2.1.dist-info/METADATA +136 -0
- geoai_py-0.2.1.dist-info/RECORD +13 -0
- geoai_py-0.1.7.dist-info/METADATA +0 -51
- geoai_py-0.1.7.dist-info/RECORD +0 -11
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/LICENSE +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.1.7.dist-info → geoai_py-0.2.1.dist-info}/top_level.txt +0 -0
geoai/extract.py
ADDED
|
@@ -0,0 +1,832 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
from shapely.geometry import Polygon, box
|
|
6
|
+
import geopandas as gpd
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
import cv2
|
|
10
|
+
from torchgeo.datasets import NonGeoDataset
|
|
11
|
+
from torchvision.models.detection import maskrcnn_resnet50_fpn
|
|
12
|
+
import torchvision.transforms as T
|
|
13
|
+
import rasterio
|
|
14
|
+
from rasterio.windows import Window
|
|
15
|
+
from rasterio.features import shapes
|
|
16
|
+
from huggingface_hub import hf_hub_download
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BuildingFootprintDataset(NonGeoDataset):
    """TorchGeo-compatible dataset that tiles a raster into fixed-size chips.

    Built on ``NonGeoDataset`` so no spatial indexing is involved; chips are
    addressed by a flat integer index laid out row-major over the tile grid.
    """

    def __init__(self, raster_path, chip_size=(512, 512), transforms=None):
        """Create the dataset.

        Args:
            raster_path: Path to the input raster file.
            chip_size: (height, width) of each image chip in pixels.
            transforms: Optional callable applied to every image tensor.
        """
        super().__init__()

        self.raster_path = raster_path
        self.chip_size = chip_size
        self.transforms = transforms

        # Cache the metadata we need up front so later reads only need a
        # short-lived file handle per chip.
        with rasterio.open(self.raster_path) as src:
            self.crs = src.crs
            self.transform = src.transform
            self.height = src.height
            self.width = src.width
            self.count = src.count

            # Geographic extent of the raster, kept both as a tuple and as a
            # shapely box for convenience.
            west, south, east, north = src.bounds
            self.bounds = (west, south, east, north)
            self.roi = box(*self.bounds)

            # Tile-grid dimensions. Floor division means partial tiles at the
            # right/bottom edges are not enumerated.
            self.rows = self.height // self.chip_size[0]
            self.cols = self.width // self.chip_size[1]

            print(
                f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
            )
            if src.crs:
                print(f"CRS: {src.crs}")

    def __getitem__(self, idx):
        """Return the image chip at flat index ``idx``.

        Args:
            idx: Flat chip index (row-major over the tile grid).

        Returns:
            Dict with the image tensor, its geographic bbox, the pixel
            offset of the window, and the (possibly clipped) window size.
        """
        # Row-major layout: flat index -> (grid row, grid column).
        grid_row, grid_col = divmod(idx, self.cols)

        # Pixel offset of the chip's top-left corner in the full raster.
        x_off = grid_col * self.chip_size[1]
        y_off = grid_row * self.chip_size[0]

        with rasterio.open(self.raster_path) as src:
            # Clip the window so it never extends past the raster edge.
            win_w = min(self.chip_size[1], self.width - x_off)
            win_h = min(self.chip_size[0], self.height - y_off)
            chip = src.read(window=Window(x_off, y_off, win_w, win_h))

        # Force exactly 3 bands: drop extras, or repeat the last one.
        if chip.shape[0] > 3:
            print(f"Image has {chip.shape[0]} bands, using first 3 bands only")
            chip = chip[:3]
        elif chip.shape[0] < 3:
            print(f"Image has {chip.shape[0]} bands, duplicating bands to make 3")
            filled = np.zeros((3, chip.shape[1], chip.shape[2]), dtype=chip.dtype)
            for band in range(3):
                filled[band] = chip[min(band, chip.shape[0] - 1)]
            chip = filled

        # Zero-pad partial edge windows up to the full chip size.
        if chip.shape[1] != self.chip_size[0] or chip.shape[2] != self.chip_size[1]:
            padded = np.zeros(
                (chip.shape[0], self.chip_size[0], self.chip_size[1]),
                dtype=chip.dtype,
            )
            padded[:, : chip.shape[1], : chip.shape[2]] = chip
            chip = padded

        # (C, H, W) float tensor, scaled to [0, 1] when it looks like 0-255 data.
        tensor = torch.from_numpy(chip).float()
        if tensor.max() > 1:
            tensor = tensor / 255.0

        if self.transforms is not None:
            tensor = self.transforms(tensor)

        # Geographic bounding box of the (unpadded) window.
        minx, miny = self.transform * (x_off, y_off + win_h)
        maxx, maxy = self.transform * (x_off + win_w, y_off)
        chip_bbox = box(minx, miny, maxx, maxy)

        return {
            "image": tensor,
            "bbox": chip_bbox,
            "coords": torch.tensor([x_off, y_off], dtype=torch.long),  # Consistent format
            "window_size": torch.tensor(
                [win_w, win_h], dtype=torch.long
            ),  # Consistent format
        }

    def __len__(self):
        """Return the number of full chips in the tile grid."""
        return self.rows * self.cols
class BuildingFootprintExtractor:
    """Building footprint extraction using Mask R-CNN with TorchGeo.

    Wraps a torchvision Mask R-CNN (ResNet50-FPN backbone, 2 classes:
    background + building). Handles weight loading (optionally from the
    Hugging Face Hub), tiled inference over a raster, mask-to-polygon
    conversion, and a simple IoU-based suppression of duplicate geometries.
    """

    def __init__(self, model_path=None, device=None):
        """Initialize the building footprint extractor.

        Args:
            model_path: Path to the .pth model file. If None, a default model
                is downloaded from the Hugging Face Hub.
            device: Device to use for inference ('cuda:0', 'cpu', etc.).
                Defaults to the first CUDA device when available.
        """
        # Set device
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Default parameters for building detection - these can be overridden
        # per call via **kwargs in process_raster.
        self.chip_size = (512, 512)  # Size of image chips for processing
        # NOTE(review): overlap is accepted and printed by process_raster, but
        # BuildingFootprintDataset tiles the raster without overlap, so this
        # value currently has no effect on the tiling - TODO confirm intent.
        self.overlap = 0.25  # Default overlap between tiles
        self.confidence_threshold = 0.5  # Default confidence threshold
        self.nms_iou_threshold = 0.5  # IoU threshold for non-maximum suppression
        self.small_building_area = 100  # Minimum area in pixels to keep a building
        self.mask_threshold = 0.5  # Threshold for mask binarization
        self.simplify_tolerance = 1.0  # Tolerance for polygon simplification

        # Initialize model architecture (weights come next).
        self.model = self._initialize_model()

        # Download model if needed
        if model_path is None:
            model_path = self._download_model_from_hf()

        # Load model weights and switch to inference mode.
        self._load_weights(model_path)
        self.model.eval()

    def _download_model_from_hf(self):
        """Download the USA building footprints model from Hugging Face.

        Returns:
            Path to the downloaded model file (hf_hub_download's local cache).

        Raises:
            Exception: Re-raises whatever hf_hub_download raised (e.g. no
                network) after printing a hint.
        """
        try:
            print("Model path not specified, downloading from Hugging Face...")

            # Repository and filename of the pretrained checkpoint.
            repo_id = "giswqs/geoai"  # Update with your actual username/repo
            filename = "usa_building_footprints.pth"

            # hf_hub_download manages its own cache directory.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")

            return model_path

        except Exception as e:
            print(f"Error downloading model from Hugging Face: {e}")
            print("Please specify a local model path or ensure internet connectivity.")
            raise

    def _initialize_model(self):
        """Initialize Mask R-CNN model with ResNet50 backbone.

        Returns:
            The (randomly initialized) model, already moved to self.device.
        """
        # Standard ImageNet mean/std so normalization matches training.
        image_mean = [0.485, 0.456, 0.406]
        image_std = [0.229, 0.224, 0.225]

        # Create model with explicit normalization parameters; no pretrained
        # weights are loaded here - _load_weights supplies them.
        model = maskrcnn_resnet50_fpn(
            weights=None,
            progress=False,
            num_classes=2,  # Background + building
            weights_backbone=None,
            image_mean=image_mean,
            image_std=image_std,
        )

        model.to(self.device)
        return model

    def _load_weights(self, model_path):
        """Load weights from file with error handling for different formats.

        Accepts plain state dicts as well as checkpoints that nest the weights
        under a "model" or "state_dict" key, and retries with the
        DataParallel "module." prefix stripped.

        Args:
            model_path: Path to model weights.

        Raises:
            FileNotFoundError: If model_path does not exist.
            RuntimeError: If the state dict cannot be loaded in any form.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        try:
            state_dict = torch.load(model_path, map_location=self.device)

            # Unwrap common checkpoint containers.
            if isinstance(state_dict, dict):
                if "model" in state_dict:
                    state_dict = state_dict["model"]
                elif "state_dict" in state_dict:
                    state_dict = state_dict["state_dict"]

            # Try to load the state dict as-is first.
            try:
                self.model.load_state_dict(state_dict)
                print("Model loaded successfully")
            except Exception as e:
                print(f"Error loading model: {e}")
                print("Attempting to fix state_dict keys...")

                # Strip the DataParallel "module." prefix where present.
                new_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith("module."):
                        new_state_dict[k[7:]] = v
                    else:
                        new_state_dict[k] = v

                self.model.load_state_dict(new_state_dict)
                print("Model loaded successfully after key fixing")

        except Exception as e:
            # Chain the cause so the original load error is not lost.
            raise RuntimeError(f"Failed to load model: {e}") from e

    def _mask_to_polygons(self, mask, **kwargs):
        """Convert a predicted binary mask to polygon contours using OpenCV.

        Args:
            mask: Mask as a 2-D numpy array of scores/probabilities.
            **kwargs: Optional parameters:
                simplify_tolerance: Tolerance for polygon simplification.
                mask_threshold: Threshold for mask binarization.
                small_building_area: Minimum contour area in pixels to keep.

        Returns:
            List of polygons, each a list of [x, y] pixel coordinates.
        """
        # Get parameters from kwargs or use instance defaults.
        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
        small_building_area = kwargs.get(
            "small_building_area", self.small_building_area
        )

        # Binarize at the requested threshold.
        mask = (mask > mask_threshold).astype(np.uint8)

        # Morphological closing fills small holes/gaps before contouring.
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

        # Outer contours only; interior holes are discarded.
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        polygons = []
        for contour in contours:
            # Drop degenerate or too-small contours.
            if contour.shape[0] < 3 or cv2.contourArea(contour) < small_building_area:
                continue

            # Simplify only verbose contours; epsilon scales with perimeter.
            if contour.shape[0] > 50:
                epsilon = simplify_tolerance * cv2.arcLength(contour, True)
                contour = cv2.approxPolyDP(contour, epsilon, True)

            # Flatten OpenCV's (N, 1, 2) layout into a list of [x, y] pairs.
            polygon = contour.reshape(-1, 2).tolist()
            polygons.append(polygon)

        return polygons

    def _filter_overlapping_polygons(self, gdf, **kwargs):
        """Filter overlapping polygons using non-maximum suppression.

        Polygons are visited in descending confidence order; a polygon is
        dropped when its IoU with an already-kept polygon exceeds the
        threshold.

        Args:
            gdf: GeoDataFrame with polygons and a "confidence" column.
            **kwargs: Optional parameters:
                nms_iou_threshold: IoU threshold for suppression.

        Returns:
            Filtered GeoDataFrame.
        """
        if len(gdf) <= 1:
            return gdf

        # Get parameters from kwargs or use instance defaults.
        iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)

        # Highest-confidence polygons are considered (and kept) first.
        gdf = gdf.sort_values("confidence", ascending=False)

        # Fix any invalid geometries (buffer(0) is the standard repair trick).
        gdf["geometry"] = gdf["geometry"].apply(
            lambda geom: geom.buffer(0) if not geom.is_valid else geom
        )

        keep_indices = []
        polygons = gdf.geometry.values

        for i in range(len(polygons)):
            if i in keep_indices:
                continue

            keep = True
            for j in keep_indices:
                # Skip invalid geometries rather than failing the whole pass.
                if not polygons[i].is_valid or not polygons[j].is_valid:
                    continue

                try:
                    intersection = polygons[i].intersection(polygons[j]).area
                    union = polygons[i].area + polygons[j].area - intersection
                    iou = intersection / union if union > 0 else 0

                    if iou > iou_threshold:
                        keep = False
                        break
                except Exception:
                    # Skip on topology exceptions.
                    continue

            if keep:
                keep_indices.append(i)

        return gdf.iloc[keep_indices]

    @torch.no_grad()
    def process_raster(self, raster_path, output_path=None, batch_size=4, **kwargs):
        """Process a raster file to extract building footprints.

        Args:
            raster_path: Path to input raster file.
            output_path: Path to output GeoJSON file (optional).
            batch_size: Batch size for inference.
            **kwargs: Optional overrides of the instance defaults:
                confidence_threshold: Minimum score to keep a detection (0-1).
                overlap: Overlap between adjacent tiles (0-1); see the NOTE in
                    __init__ - currently not applied by the dataset tiling.
                chip_size: Size of image chips (height, width).
                nms_iou_threshold: IoU threshold for NMS (0-1).
                mask_threshold: Threshold for mask binarization (0-1).
                small_building_area: Minimum area in pixels to keep a building.
                simplify_tolerance: Tolerance for polygon simplification.

        Returns:
            GeoDataFrame with building footprints, or None if no valid
            polygons were found.
        """
        # Get parameters from kwargs or use instance defaults.
        confidence_threshold = kwargs.get(
            "confidence_threshold", self.confidence_threshold
        )
        overlap = kwargs.get("overlap", self.overlap)
        chip_size = kwargs.get("chip_size", self.chip_size)
        nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
        small_building_area = kwargs.get(
            "small_building_area", self.small_building_area
        )
        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)

        # Print parameters being used.
        print("Processing with parameters:")
        print(f"- Confidence threshold: {confidence_threshold}")
        print(f"- Tile overlap: {overlap}")
        print(f"- Chip size: {chip_size}")
        print(f"- NMS IoU threshold: {nms_iou_threshold}")
        print(f"- Mask threshold: {mask_threshold}")
        print(f"- Min building area: {small_building_area}")
        print(f"- Simplify tolerance: {simplify_tolerance}")

        # Create dataset
        dataset = BuildingFootprintDataset(raster_path=raster_path, chip_size=chip_size)

        # Custom collate function to handle Shapely objects.
        def custom_collate(batch):
            """Keep Shapely geometries as Python lists; default-collate the rest."""
            elem = batch[0]
            if isinstance(elem, dict):
                result = {}
                for key in elem:
                    if key == "bbox":
                        # Don't collate shapely objects, keep as list.
                        result[key] = [d[key] for d in batch]
                    else:
                        # For tensors and other collatable types.
                        try:
                            result[key] = (
                                torch.utils.data._utils.collate.default_collate(
                                    [d[key] for d in batch]
                                )
                            )
                        except TypeError:
                            # Fall back to list for non-collatable types.
                            result[key] = [d[key] for d in batch]
                return result
            else:
                # Default collate for non-dict types.
                return torch.utils.data._utils.collate.default_collate(batch)

        # Create dataloader with simple indexing and custom collate.
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=custom_collate,
        )

        # Pixel -> geographic affine transform. Read once from the dataset
        # (which cached src.transform at init) instead of re-opening the
        # raster inside the per-mask loop as before.
        raster_transform = dataset.transform

        # Process batches.
        all_polygons = []
        all_scores = []

        print(f"Processing raster with {len(dataloader)} batches")
        for batch in tqdm(dataloader):
            # Move images to device.
            images = batch["image"].to(self.device)
            coords = batch["coords"]  # (i, j) pixel offsets of each chip

            # Run inference.
            predictions = self.model(images)

            # Process predictions.
            for idx, prediction in enumerate(predictions):
                masks = prediction["masks"].cpu().numpy()
                scores = prediction["scores"].cpu().numpy()
                labels = prediction["labels"].cpu().numpy()

                # Skip if no predictions.
                if len(scores) == 0:
                    continue

                # Filter by confidence threshold.
                valid_indices = scores >= confidence_threshold
                masks = masks[valid_indices]
                scores = scores[valid_indices]
                labels = labels[valid_indices]

                # Skip if no valid predictions.
                if len(scores) == 0:
                    continue

                # Get window coordinates. The coords might be in different
                # formats depending on batch handling.
                if isinstance(coords, list):
                    coord_item = coords[idx]
                    if isinstance(coord_item, tuple) and len(coord_item) == 2:
                        i, j = coord_item
                    elif isinstance(coord_item, torch.Tensor):
                        i, j = coord_item.cpu().numpy().tolist()
                    else:
                        print(f"Unexpected coords format: {type(coord_item)}")
                        continue
                elif isinstance(coords, torch.Tensor):
                    # coords is a tensor of shape [batch_size, 2].
                    i, j = coords[idx].cpu().numpy().tolist()
                else:
                    print(f"Unexpected coords type: {type(coords)}")
                    continue

                # Get window size, handling the same format variations.
                if isinstance(batch["window_size"], list):
                    window_item = batch["window_size"][idx]
                    if isinstance(window_item, tuple) and len(window_item) == 2:
                        window_width, window_height = window_item
                    elif isinstance(window_item, torch.Tensor):
                        window_width, window_height = window_item.cpu().numpy().tolist()
                    else:
                        print(f"Unexpected window_size format: {type(window_item)}")
                        continue
                elif isinstance(batch["window_size"], torch.Tensor):
                    window_width, window_height = (
                        batch["window_size"][idx].cpu().numpy().tolist()
                    )
                else:
                    print(f"Unexpected window_size type: {type(batch['window_size'])}")
                    continue

                # Process masks to polygons.
                for mask_idx, mask in enumerate(masks):
                    # Each mask has shape (1, H, W); take the single channel.
                    binary_mask = mask[0]

                    # Convert mask to polygons with the resolved parameters.
                    contours = self._mask_to_polygons(
                        binary_mask,
                        simplify_tolerance=simplify_tolerance,
                        mask_threshold=mask_threshold,
                        small_building_area=small_building_area,
                    )

                    # Skip if no valid polygons.
                    if not contours:
                        continue

                    for contour in contours:
                        # Shift chip-local pixel coords by the window offset,
                        # then map to geographic coordinates.
                        global_polygon = []
                        for x, y in contour:
                            gx, gy = raster_transform * (i + x, j + y)
                            global_polygon.append((gx, gy))

                        # Create Shapely polygon.
                        if len(global_polygon) >= 3:
                            try:
                                shapely_poly = Polygon(global_polygon)
                                if shapely_poly.is_valid and shapely_poly.area > 0:
                                    all_polygons.append(shapely_poly)
                                    all_scores.append(float(scores[mask_idx]))
                            except Exception as e:
                                print(f"Error creating polygon: {e}")

        # Create GeoDataFrame.
        if not all_polygons:
            print("No valid polygons found")
            return None

        gdf = gpd.GeoDataFrame(
            {
                "geometry": all_polygons,
                "confidence": all_scores,
                "class": 1,  # Building class
            },
            crs=dataset.crs,
        )

        # Remove overlapping polygons with custom threshold.
        gdf = self._filter_overlapping_polygons(
            gdf, nms_iou_threshold=nms_iou_threshold
        )

        # Save to file if requested.
        if output_path:
            gdf.to_file(output_path, driver="GeoJSON")
            print(f"Saved {len(gdf)} building footprints to {output_path}")

        return gdf

    def visualize_results(
        self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
    ):
        """Visualize building detection results.

        Renders two figures: an overview of the (possibly downsampled) raster
        with all footprints, and a 500-px sample window centered on the first
        building. The overview is closed after saving; the sample figure is
        left open for interactive display.

        Args:
            raster_path: Path to input raster.
            gdf: GeoDataFrame with building polygons (optional; computed via
                process_raster when omitted).
            output_path: Path to save visualization (optional); the sample
                view is saved alongside with a "_sample" suffix.
            figsize: Figure size (width, height) in inches.

        Returns:
            True on success, or None when there is nothing to visualize.
        """
        # Check if raster file exists.
        if not os.path.exists(raster_path):
            print(f"Error: Raster file '{raster_path}' not found.")
            return

        # Process raster if GeoDataFrame not provided.
        if gdf is None:
            gdf = self.process_raster(raster_path)

        if gdf is None or len(gdf) == 0:
            print("No buildings to visualize")
            return

        # Read raster for visualization.
        with rasterio.open(raster_path) as src:
            # Read the entire image, or a downsampled version if very large.
            if src.height > 2000 or src.width > 2000:
                # Scale so the longer side fits in ~2000 px.
                scale = min(2000 / src.height, 2000 / src.width)
                out_shape = (
                    int(src.count),
                    int(src.height * scale),
                    int(src.width * scale),
                )

                image = src.read(
                    out_shape=out_shape, resampling=rasterio.enums.Resampling.bilinear
                )
            else:
                image = src.read()

            # Convert to RGB for display.
            if image.shape[0] > 3:
                image = image[:3]
            elif image.shape[0] == 1:
                image = np.repeat(image, 3, axis=0)

            # Normalize image for display.
            image = image.transpose(1, 2, 0)  # CHW to HWC
            image = image.astype(np.float32)

            if image.max() > 10:  # Likely 0-255 range
                image = image / 255.0

            image = np.clip(image, 0, 1)

        # Create figure with appropriate aspect ratio.
        aspect_ratio = image.shape[1] / image.shape[0]  # width / height
        plt.figure(figsize=(figsize[0], figsize[0] / aspect_ratio))
        ax = plt.gca()

        # Display image.
        ax.imshow(image)

        # Convert each footprint to pixel coordinates for plotting.
        with rasterio.open(raster_path) as src:

            def geo_to_pixel(x, y):
                # Inverse affine maps geographic -> pixel coordinates.
                return ~src.transform * (x, y)

            for _, row in gdf.iterrows():
                geom = row.geometry
                if geom.is_empty:
                    continue

                try:
                    # Get polygon exterior coordinates.
                    x, y = geom.exterior.xy

                    # Convert to pixel coordinates.
                    pixel_coords = [geo_to_pixel(x[i], y[i]) for i in range(len(x))]
                    pixel_x = [coord[0] for coord in pixel_coords]
                    pixel_y = [coord[1] for coord in pixel_coords]

                    # Plot polygon outline.
                    ax.plot(pixel_x, pixel_y, color="red", linewidth=1)
                except Exception as e:
                    print(f"Error plotting polygon: {e}")

        # Remove axes.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Building Footprints (Found: {len(gdf)})")

        # Add colorbar for confidence if available.
        if "confidence" in gdf.columns:
            sm = plt.cm.ScalarMappable(
                cmap=plt.cm.viridis,
                norm=plt.Normalize(gdf.confidence.min(), gdf.confidence.max()),
            )
            sm.set_array([])
            cbar = plt.colorbar(sm, ax=ax, orientation="vertical", shrink=0.7)
            cbar.set_label("Confidence")

        # Save if requested.
        if output_path:
            plt.tight_layout()
            plt.savefig(output_path, dpi=300, bbox_inches="tight")
            print(f"Visualization saved to {output_path}")

        plt.close()

        # Second figure: a zoomed-in sample view, useful when the raster is
        # very large.
        plt.figure(figsize=figsize)
        ax = plt.gca()

        with rasterio.open(raster_path) as src:
            if len(gdf) > 0:
                # Center the sample window on the first building's centroid.
                sample_geom = gdf.iloc[0].geometry
                centroid = sample_geom.centroid

                # Convert to pixel coordinates.
                center_x, center_y = ~src.transform * (centroid.x, centroid.y)

                # Define a window around this building.
                window_size = 500  # pixels
                window = rasterio.windows.Window(
                    max(0, int(center_x - window_size / 2)),
                    max(0, int(center_y - window_size / 2)),
                    min(window_size, src.width - int(center_x - window_size / 2)),
                    min(window_size, src.height - int(center_y - window_size / 2)),
                )

                # Read this window.
                sample_image = src.read(window=window)

                # Convert to RGB for display.
                if sample_image.shape[0] > 3:
                    sample_image = sample_image[:3]
                elif sample_image.shape[0] == 1:
                    sample_image = np.repeat(sample_image, 3, axis=0)

                # Normalize image for display.
                sample_image = sample_image.transpose(1, 2, 0)  # CHW to HWC
                sample_image = sample_image.astype(np.float32)

                if sample_image.max() > 10:  # Likely 0-255 range
                    sample_image = sample_image / 255.0

                sample_image = np.clip(sample_image, 0, 1)

                # Affine transform local to this window.
                window_transform = src.window_transform(window)

                # Display sample image.
                ax.imshow(sample_image)

                # Filter buildings that intersect with this window.
                window_bounds = rasterio.windows.bounds(window, src.transform)
                window_box = box(*window_bounds)
                visible_gdf = gdf[gdf.intersects(window_box)]

                # Plot building footprints in this view.
                for _, row in visible_gdf.iterrows():
                    try:
                        geom = row.geometry
                        if geom.is_empty:
                            continue

                        x, y = geom.exterior.xy

                        # Convert to pixel coordinates relative to window.
                        pixel_coords = [
                            ~window_transform * (x[i], y[i]) for i in range(len(x))
                        ]
                        pixel_x = [coord[0] for coord in pixel_coords]
                        pixel_y = [coord[1] for coord in pixel_coords]

                        # Plot polygon outline.
                        ax.plot(pixel_x, pixel_y, color="red", linewidth=1.5)
                    except Exception as e:
                        print(f"Error plotting polygon in sample view: {e}")

                # Set title.
                ax.set_title(
                    f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
                )

        # Remove axes.
        ax.set_xticks([])
        ax.set_yticks([])

        # Save if requested.
        if output_path:
            sample_output = (
                os.path.splitext(output_path)[0]
                + "_sample"
                + os.path.splitext(output_path)[1]
            )
            plt.tight_layout()
            plt.savefig(sample_output, dpi=300, bbox_inches="tight")
            print(f"Sample visualization saved to {sample_output}")

        return True