geoai-py 0.1.7__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
geoai/extract.py ADDED
@@ -0,0 +1,832 @@
+import os
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+from shapely.geometry import Polygon, box
+import geopandas as gpd
+from tqdm import tqdm
+
+import cv2
+from torchgeo.datasets import NonGeoDataset
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+import torchvision.transforms as T
+import rasterio
+from rasterio.windows import Window
+from rasterio.features import shapes
+from huggingface_hub import hf_hub_download
+
+
+class BuildingFootprintDataset(NonGeoDataset):
+    """
+    A TorchGeo dataset for building footprint extraction.
+    Using NonGeoDataset to avoid spatial indexing issues.
+    """
+
+    def __init__(self, raster_path, chip_size=(512, 512), transforms=None):
+        """
+        Initialize the dataset.
+
+        Args:
+            raster_path: Path to the input raster file
+            chip_size: Size of image chips to extract (height, width)
+            transforms: Transforms to apply to the image
+        """
+        super().__init__()
+
+        # Initialize parameters
+        self.raster_path = raster_path
+        self.chip_size = chip_size
+        self.transforms = transforms
+
+        # Open raster and get metadata
+        with rasterio.open(self.raster_path) as src:
+            self.crs = src.crs
+            self.transform = src.transform
+            self.height = src.height
+            self.width = src.width
+            self.count = src.count
+
+            # Define the bounds of the dataset
+            west, south, east, north = src.bounds
+            self.bounds = (west, south, east, north)
+
+            # Define the ROI for the dataset
+            self.roi = box(*self.bounds)
+
+            # Calculate number of chips in each dimension
+            self.rows = self.height // self.chip_size[0]
+            self.cols = self.width // self.chip_size[1]
+
+            print(
+                f"Dataset initialized with {self.rows} rows and {self.cols} columns of chips"
+            )
+            if src.crs:
+                print(f"CRS: {src.crs}")
+
+    def __getitem__(self, idx):
+        """
+        Get an image chip from the dataset by index.
+
+        Args:
+            idx: Index of the chip
+
+        Returns:
+            Dict containing image tensor
+        """
+        # Convert flat index to grid position
+        row = idx // self.cols
+        col = idx % self.cols
+
+        # Calculate pixel coordinates
+        i = col * self.chip_size[1]
+        j = row * self.chip_size[0]
+
+        # Read window from raster
+        with rasterio.open(self.raster_path) as src:
+            # Make sure we don't read outside the image
+            width = min(self.chip_size[1], self.width - i)
+            height = min(self.chip_size[0], self.height - j)
+
+            window = Window(i, j, width, height)
+            image = src.read(window=window)
+
+        # Handle RGBA or multispectral images - keep only first 3 bands
+        if image.shape[0] > 3:
+            print(f"Image has {image.shape[0]} bands, using first 3 bands only")
+            image = image[:3]
+        elif image.shape[0] < 3:
+            # If image has fewer than 3 bands, duplicate the last band to make 3
+            print(f"Image has {image.shape[0]} bands, duplicating bands to make 3")
+            temp = np.zeros((3, image.shape[1], image.shape[2]), dtype=image.dtype)
+            for c in range(3):
+                temp[c] = image[min(c, image.shape[0] - 1)]
+            image = temp
+
+        # Handle partial windows at edges by padding
+        if (
+            image.shape[1] != self.chip_size[0]
+            or image.shape[2] != self.chip_size[1]
+        ):
+            temp = np.zeros(
+                (image.shape[0], self.chip_size[0], self.chip_size[1]),
+                dtype=image.dtype,
+            )
+            temp[:, : image.shape[1], : image.shape[2]] = image
+            image = temp
+
+        # Convert to format expected by model (C,H,W)
+        image = torch.from_numpy(image).float()
+
+        # Normalize to [0, 1]
+        if image.max() > 1:
+            image = image / 255.0
+
+        # Apply transforms if any
+        if self.transforms is not None:
+            image = self.transforms(image)
+
+        # Create geographic bounding box for the window
+        minx, miny = self.transform * (i, j + height)
+        maxx, maxy = self.transform * (i + width, j)
+        bbox = box(minx, miny, maxx, maxy)
+
+        return {
+            "image": image,
+            "bbox": bbox,
+            "coords": torch.tensor([i, j], dtype=torch.long),  # Consistent format
+            "window_size": torch.tensor(
+                [width, height], dtype=torch.long
+            ),  # Consistent format
+        }
+
+    def __len__(self):
+        """Return the number of samples in the dataset."""
+        return self.rows * self.cols
+
+
+class BuildingFootprintExtractor:
+    """
+    Building footprint extraction using Mask R-CNN with TorchGeo.
+    """
+
+    def __init__(self, model_path=None, device=None):
+        """
+        Initialize the building footprint extractor.
+
+        Args:
+            model_path: Path to the .pth model file
+            device: Device to use for inference ('cuda:0', 'cpu', etc.)
+        """
+        # Set device
+        if device is None:
+            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        else:
+            self.device = torch.device(device)
+
+        # Default parameters for building detection - these can be overridden in process_raster
+        self.chip_size = (512, 512)  # Size of image chips for processing
+        self.overlap = 0.25  # Default overlap between tiles
+        self.confidence_threshold = 0.5  # Default confidence threshold
+        self.nms_iou_threshold = 0.5  # IoU threshold for non-maximum suppression
+        self.small_building_area = 100  # Minimum area in pixels to keep a building
+        self.mask_threshold = 0.5  # Threshold for mask binarization
+        self.simplify_tolerance = 1.0  # Tolerance for polygon simplification
+
+        # Initialize model
+        self.model = self._initialize_model()
+
+        # Download model if needed
+        if model_path is None:
+            model_path = self._download_model_from_hf()
+
+        # Load model weights
+        self._load_weights(model_path)
+
+        # Set model to evaluation mode
+        self.model.eval()
+
+    def _download_model_from_hf(self):
+        """
+        Download the USA building footprints model from Hugging Face.
+
+        Returns:
+            Path to the downloaded model file
+        """
+        try:
+
+            print("Model path not specified, downloading from Hugging Face...")
+
+            # Define the repository ID and model filename
+            repo_id = "giswqs/geoai"  # Update with your actual username/repo
+            filename = "usa_building_footprints.pth"
+
+            # Ensure cache directory exists
+            # cache_dir = os.path.join(
+            #     os.path.expanduser("~"), ".cache", "building_footprints"
+            # )
+            # os.makedirs(cache_dir, exist_ok=True)
+
+            # Download the model
+            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
+            print(f"Model downloaded to: {model_path}")
+
+            return model_path
+
+        except Exception as e:
+            print(f"Error downloading model from Hugging Face: {e}")
+            print("Please specify a local model path or ensure internet connectivity.")
+            raise
+
+    def _initialize_model(self):
+        """Initialize Mask R-CNN model with ResNet50 backbone."""
+        # Standard image mean and std for pre-trained models
+        # Note: This would normally come from your config file
+        image_mean = [0.485, 0.456, 0.406]
+        image_std = [0.229, 0.224, 0.225]
+
+        # Create model with explicit normalization parameters
+        model = maskrcnn_resnet50_fpn(
+            weights=None,
+            progress=False,
+            num_classes=2,  # Background + building
+            weights_backbone=None,
+            # These parameters ensure consistent normalization
+            image_mean=image_mean,
+            image_std=image_std,
+        )
+
+        model.to(self.device)
+        return model
+
+    def _load_weights(self, model_path):
+        """
+        Load weights from file with error handling for different formats.
+
+        Args:
+            model_path: Path to model weights
+        """
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+        try:
+            state_dict = torch.load(model_path, map_location=self.device)
+
+            # Handle different state dict formats
+            if isinstance(state_dict, dict):
+                if "model" in state_dict:
+                    state_dict = state_dict["model"]
+                elif "state_dict" in state_dict:
+                    state_dict = state_dict["state_dict"]
+
+            # Try to load state dict
+            try:
+                self.model.load_state_dict(state_dict)
+                print("Model loaded successfully")
+            except Exception as e:
+                print(f"Error loading model: {e}")
+                print("Attempting to fix state_dict keys...")
+
+                # Try to fix state_dict keys (remove module prefix if needed)
+                new_state_dict = {}
+                for k, v in state_dict.items():
+                    if k.startswith("module."):
+                        new_state_dict[k[7:]] = v
+                    else:
+                        new_state_dict[k] = v
+
+                self.model.load_state_dict(new_state_dict)
+                print("Model loaded successfully after key fixing")
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load model: {e}")
+
+    def _mask_to_polygons(self, mask, **kwargs):
+        """
+        Convert binary mask to polygon contours using OpenCV.
+
+        Args:
+            mask: Binary mask as numpy array
+            **kwargs: Optional parameters:
+                simplify_tolerance: Tolerance for polygon simplification
+                mask_threshold: Threshold for mask binarization
+                small_building_area: Minimum area in pixels to keep a building
+
+        Returns:
+            List of polygons as lists of (x, y) coordinates
+        """
+
+        # Get parameters from kwargs or use instance defaults
+        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
+        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
+        small_building_area = kwargs.get(
+            "small_building_area", self.small_building_area
+        )
+
+        # Ensure binary mask
+        mask = (mask > mask_threshold).astype(np.uint8)
+
+        # Optional: apply morphological operations to improve mask quality
+        kernel = np.ones((3, 3), np.uint8)
+        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
+
+        # Find contours
+        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+        # Convert to list of [x, y] coordinates
+        polygons = []
+        for contour in contours:
+            # Filter out too small contours
+            if contour.shape[0] < 3 or cv2.contourArea(contour) < small_building_area:
+                continue
+
+            # Simplify contour if it has many points
+            if contour.shape[0] > 50:
+                epsilon = simplify_tolerance * cv2.arcLength(contour, True)
+                contour = cv2.approxPolyDP(contour, epsilon, True)
+
+            # Convert to list of [x, y] coordinates
+            polygon = contour.reshape(-1, 2).tolist()
+            polygons.append(polygon)
+
+        return polygons
+
+    def _filter_overlapping_polygons(self, gdf, **kwargs):
+        """
+        Filter overlapping polygons using non-maximum suppression.
+
+        Args:
+            gdf: GeoDataFrame with polygons
+            **kwargs: Optional parameters:
+                nms_iou_threshold: IoU threshold for filtering
+
+        Returns:
+            Filtered GeoDataFrame
+        """
+        if len(gdf) <= 1:
+            return gdf
+
+        # Get parameters from kwargs or use instance defaults
+        iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
+
+        # Sort by confidence
+        gdf = gdf.sort_values("confidence", ascending=False)
+
+        # Fix any invalid geometries
+        gdf["geometry"] = gdf["geometry"].apply(
+            lambda geom: geom.buffer(0) if not geom.is_valid else geom
+        )
+
+        keep_indices = []
+        polygons = gdf.geometry.values
+
+        for i in range(len(polygons)):
+            if i in keep_indices:
+                continue
+
+            keep = True
+            for j in keep_indices:
+                # Skip invalid geometries
+                if not polygons[i].is_valid or not polygons[j].is_valid:
+                    continue
+
+                # Calculate IoU
+                try:
+                    intersection = polygons[i].intersection(polygons[j]).area
+                    union = polygons[i].area + polygons[j].area - intersection
+                    iou = intersection / union if union > 0 else 0
+
+                    if iou > iou_threshold:
+                        keep = False
+                        break
+                except Exception:
+                    # Skip on topology exceptions
+                    continue
+
+            if keep:
+                keep_indices.append(i)
+
+        return gdf.iloc[keep_indices]
+
+    @torch.no_grad()
+    def process_raster(self, raster_path, output_path=None, batch_size=4, **kwargs):
+        """
+        Process a raster file to extract building footprints with customizable parameters.
+
+        Args:
+            raster_path: Path to input raster file
+            output_path: Path to output GeoJSON file (optional)
+            batch_size: Batch size for processing
+            **kwargs: Additional parameters:
+                confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
+                overlap: Overlap between adjacent tiles (0.0-1.0)
+                chip_size: Size of image chips for processing (height, width)
+                nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
+                mask_threshold: Threshold for mask binarization (0.0-1.0)
+                small_building_area: Minimum area in pixels to keep a building
+                simplify_tolerance: Tolerance for polygon simplification
+
+        Returns:
+            GeoDataFrame with building footprints
+        """
+        # Get parameters from kwargs or use instance defaults
+        confidence_threshold = kwargs.get(
+            "confidence_threshold", self.confidence_threshold
+        )
+        overlap = kwargs.get("overlap", self.overlap)
+        chip_size = kwargs.get("chip_size", self.chip_size)
+        nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
+        mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
+        small_building_area = kwargs.get(
+            "small_building_area", self.small_building_area
+        )
+        simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
+
+        # Print parameters being used
+        print("Processing with parameters:")
+        print(f"- Confidence threshold: {confidence_threshold}")
+        print(f"- Tile overlap: {overlap}")
+        print(f"- Chip size: {chip_size}")
+        print(f"- NMS IoU threshold: {nms_iou_threshold}")
+        print(f"- Mask threshold: {mask_threshold}")
+        print(f"- Min building area: {small_building_area}")
+        print(f"- Simplify tolerance: {simplify_tolerance}")
+
+        # Create dataset
+        dataset = BuildingFootprintDataset(raster_path=raster_path, chip_size=chip_size)
+
+        # Custom collate function to handle Shapely objects
+        def custom_collate(batch):
+            """
+            Custom collate function that handles Shapely geometries
+            by keeping them as Python objects rather than trying to collate them.
+            """
+            elem = batch[0]
+            if isinstance(elem, dict):
+                result = {}
+                for key in elem:
+                    if key == "bbox":
+                        # Don't collate shapely objects, keep as list
+                        result[key] = [d[key] for d in batch]
+                    else:
+                        # For tensors and other collatable types
+                        try:
+                            result[key] = (
+                                torch.utils.data._utils.collate.default_collate(
+                                    [d[key] for d in batch]
+                                )
+                            )
+                        except TypeError:
+                            # Fall back to list for non-collatable types
+                            result[key] = [d[key] for d in batch]
+                return result
+            else:
+                # Default collate for non-dict types
+                return torch.utils.data._utils.collate.default_collate(batch)
+
+        # Create dataloader with simple indexing and custom collate
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=0,
+            collate_fn=custom_collate,
+        )
+
+        # Process batches
+        all_polygons = []
+        all_scores = []
+
+        print(f"Processing raster with {len(dataloader)} batches")
+        for batch in tqdm(dataloader):
+            # Move images to device
+            images = batch["image"].to(self.device)
+            coords = batch["coords"]  # (i, j) coordinates in pixels
+            bboxes = batch[
+                "bbox"
+            ]  # Geographic bounding boxes - now a list, not a tensor
+
+            # Run inference
+            predictions = self.model(images)
+
+            # Process predictions
+            for idx, prediction in enumerate(predictions):
+                masks = prediction["masks"].cpu().numpy()
+                scores = prediction["scores"].cpu().numpy()
+                labels = prediction["labels"].cpu().numpy()
+
+                # Skip if no predictions
+                if len(scores) == 0:
+                    continue
+
+                # Filter by confidence threshold
+                valid_indices = scores >= confidence_threshold
+                masks = masks[valid_indices]
+                scores = scores[valid_indices]
+                labels = labels[valid_indices]
+
+                # Skip if no valid predictions
+                if len(scores) == 0:
+                    continue
+
+                # Get window coordinates
+                # The coords might be in different formats depending on batch handling
+                if isinstance(coords, list):
+                    # If coords is a list of tuples
+                    coord_item = coords[idx]
+                    if isinstance(coord_item, tuple) and len(coord_item) == 2:
+                        i, j = coord_item
+                    elif isinstance(coord_item, torch.Tensor):
+                        i, j = coord_item.cpu().numpy().tolist()
+                    else:
+                        print(f"Unexpected coords format: {type(coord_item)}")
+                        continue
+                elif isinstance(coords, torch.Tensor):
+                    # If coords is a tensor of shape [batch_size, 2]
+                    i, j = coords[idx].cpu().numpy().tolist()
+                else:
+                    print(f"Unexpected coords type: {type(coords)}")
+                    continue
+
+                # Get window size
+                if isinstance(batch["window_size"], list):
+                    window_item = batch["window_size"][idx]
+                    if isinstance(window_item, tuple) and len(window_item) == 2:
+                        window_width, window_height = window_item
+                    elif isinstance(window_item, torch.Tensor):
+                        window_width, window_height = window_item.cpu().numpy().tolist()
+                    else:
+                        print(f"Unexpected window_size format: {type(window_item)}")
+                        continue
+                elif isinstance(batch["window_size"], torch.Tensor):
+                    window_width, window_height = (
+                        batch["window_size"][idx].cpu().numpy().tolist()
+                    )
+                else:
+                    print(f"Unexpected window_size type: {type(batch['window_size'])}")
+                    continue
+
+                # Process masks to polygons
+                for mask_idx, mask in enumerate(masks):
+                    # Get binary mask
+                    binary_mask = mask[0]  # Get binary mask
+
+                    # Convert mask to polygon with custom parameters
+                    contours = self._mask_to_polygons(
+                        binary_mask,
+                        simplify_tolerance=simplify_tolerance,
+                        mask_threshold=mask_threshold,
+                        small_building_area=small_building_area,
+                    )
+
+                    # Skip if no valid polygons
+                    if not contours:
+                        continue
+
+                    # Transform polygons to geographic coordinates
+                    with rasterio.open(raster_path) as src:
+                        transform = src.transform
+
+                    for contour in contours:
+                        # Convert polygon to global coordinates
+                        global_polygon = []
+                        for x, y in contour:
+                            # Adjust coordinates based on window position
+                            gx, gy = transform * (i + x, j + y)
+                            global_polygon.append((gx, gy))
+
+                        # Create Shapely polygon
+                        if len(global_polygon) >= 3:
+                            try:
+                                shapely_poly = Polygon(global_polygon)
+                                if shapely_poly.is_valid and shapely_poly.area > 0:
+                                    all_polygons.append(shapely_poly)
+                                    all_scores.append(float(scores[mask_idx]))
+                            except Exception as e:
+                                print(f"Error creating polygon: {e}")
+
+        # Create GeoDataFrame
+        if not all_polygons:
+            print("No valid polygons found")
+            return None
+
+        gdf = gpd.GeoDataFrame(
+            {
+                "geometry": all_polygons,
+                "confidence": all_scores,
+                "class": 1,  # Building class
+            },
+            crs=dataset.crs,
+        )
+
+        # Remove overlapping polygons with custom threshold
+        gdf = self._filter_overlapping_polygons(
+            gdf, nms_iou_threshold=nms_iou_threshold
+        )
+
+        # Save to file if requested
+        if output_path:
+            gdf.to_file(output_path, driver="GeoJSON")
+            print(f"Saved {len(gdf)} building footprints to {output_path}")
+
+        return gdf
+
+    def visualize_results(
+        self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
+    ):
+        """
+        Visualize building detection results.
+
+        Args:
+            raster_path: Path to input raster
+            gdf: GeoDataFrame with building polygons (optional)
+            output_path: Path to save visualization (optional)
+            figsize: Figure size (width, height) in inches
+        """
+        # Check if raster file exists
+        if not os.path.exists(raster_path):
+            print(f"Error: Raster file '{raster_path}' not found.")
+            return
+
+        # Process raster if GeoDataFrame not provided
+        if gdf is None:
+            gdf = self.process_raster(raster_path)
+
+        if gdf is None or len(gdf) == 0:
+            print("No buildings to visualize")
+            return
+
+        # Read raster for visualization
+        with rasterio.open(raster_path) as src:
+            # Read the entire image or a subset if it's very large
+            if src.height > 2000 or src.width > 2000:
+                # Calculate scale factor to reduce size
+                scale = min(2000 / src.height, 2000 / src.width)
+                out_shape = (
+                    int(src.count),
+                    int(src.height * scale),
+                    int(src.width * scale),
+                )
+
+                # Read and resample
+                image = src.read(
+                    out_shape=out_shape, resampling=rasterio.enums.Resampling.bilinear
+                )
+            else:
+                image = src.read()
+
+            # Convert to RGB for display
+            if image.shape[0] > 3:
+                image = image[:3]
+            elif image.shape[0] == 1:
+                image = np.repeat(image, 3, axis=0)
+
+            # Normalize image for display
+            image = image.transpose(1, 2, 0)  # CHW to HWC
+            image = image.astype(np.float32)
+
+            if image.max() > 10:  # Likely 0-255 range
+                image = image / 255.0
+
+            image = np.clip(image, 0, 1)
+
+            # Get image bounds
+            bounds = src.bounds
+
+        # Create figure with appropriate aspect ratio
+        aspect_ratio = image.shape[1] / image.shape[0]  # width / height
+        plt.figure(figsize=(figsize[0], figsize[0] / aspect_ratio))
+
+        # Create axis with the right projection if CRS is available
+        ax = plt.gca()
+
+        # Display image
+        ax.imshow(image)
+
+        # Convert GeoDataFrame to pixel coordinates for plotting
+        with rasterio.open(raster_path) as src:
+
+            def geo_to_pixel(x, y):
+                return ~src.transform * (x, y)
+
+            # Plot each building footprint
+            for _, row in gdf.iterrows():
+                # Convert polygon to pixel coordinates
+                geom = row.geometry
+                if geom.is_empty:
+                    continue
+
+                try:
+                    # Get polygon exterior coordinates
+                    x, y = geom.exterior.xy
+
+                    # Convert to pixel coordinates
+                    pixel_coords = [geo_to_pixel(x[i], y[i]) for i in range(len(x))]
+                    pixel_x = [coord[0] for coord in pixel_coords]
+                    pixel_y = [coord[1] for coord in pixel_coords]
+
+                    # Plot polygon
+                    ax.plot(pixel_x, pixel_y, color="red", linewidth=1)
+                except Exception as e:
+                    print(f"Error plotting polygon: {e}")
+
+        # Remove axes
+        ax.set_xticks([])
+        ax.set_yticks([])
+        ax.set_title(f"Building Footprints (Found: {len(gdf)})")
+
+        # Add colorbar for confidence if available
+        if "confidence" in gdf.columns:
+            # Create a colorbar legend
+            sm = plt.cm.ScalarMappable(
+                cmap=plt.cm.viridis,
+                norm=plt.Normalize(gdf.confidence.min(), gdf.confidence.max()),
+            )
+            sm.set_array([])
+            cbar = plt.colorbar(sm, ax=ax, orientation="vertical", shrink=0.7)
+            cbar.set_label("Confidence")
+
+        # Save if requested
+        if output_path:
+            plt.tight_layout()
+            plt.savefig(output_path, dpi=300, bbox_inches="tight")
+            print(f"Visualization saved to {output_path}")
+
+        plt.close()
+
+        # Create a simpler visualization focused just on a subset of buildings
+        # This helps when the raster is very large
+        plt.figure(figsize=figsize)
+        ax = plt.gca()
+
+        # Choose a subset of the image to show
+        with rasterio.open(raster_path) as src:
+            # Get a sample window based on the first few buildings
+            if len(gdf) > 0:
+                # Get centroid of first building
+                sample_geom = gdf.iloc[0].geometry
+                centroid = sample_geom.centroid
+
+                # Convert to pixel coordinates
+                center_x, center_y = ~src.transform * (centroid.x, centroid.y)
+
+                # Define a window around this building
+                window_size = 500  # pixels
+                window = rasterio.windows.Window(
+                    max(0, int(center_x - window_size / 2)),
+                    max(0, int(center_y - window_size / 2)),
+                    min(window_size, src.width - int(center_x - window_size / 2)),
+                    min(window_size, src.height - int(center_y - window_size / 2)),
+                )
+
+                # Read this window
+                sample_image = src.read(window=window)
+
+                # Convert to RGB for display
+                if sample_image.shape[0] > 3:
+                    sample_image = sample_image[:3]
+                elif sample_image.shape[0] == 1:
+                    sample_image = np.repeat(sample_image, 3, axis=0)
+
+                # Normalize image for display
+                sample_image = sample_image.transpose(1, 2, 0)  # CHW to HWC
+                sample_image = sample_image.astype(np.float32)
+
+                if sample_image.max() > 10:  # Likely 0-255 range
+                    sample_image = sample_image / 255.0
+
+                sample_image = np.clip(sample_image, 0, 1)
+
+                # Get transform for this window
+                window_transform = src.window_transform(window)
+
+                # Display sample image
+                ax.imshow(sample_image)
+
+                # Filter buildings that intersect with this window
+                window_bounds = rasterio.windows.bounds(window, src.transform)
+                window_box = box(*window_bounds)
+                visible_gdf = gdf[gdf.intersects(window_box)]
+
+                # Plot building footprints in this view
+                for _, row in visible_gdf.iterrows():
+                    try:
+                        # Get polygon exterior coordinates
+                        geom = row.geometry
+                        if geom.is_empty:
+                            continue
+
+                        x, y = geom.exterior.xy
+
+                        # Convert to pixel coordinates relative to window
+                        pixel_coords = [
+                            ~window_transform * (x[i], y[i]) for i in range(len(x))
+                        ]
+                        pixel_x = [coord[0] for coord in pixel_coords]
+                        pixel_y = [coord[1] for coord in pixel_coords]
+
+                        # Plot polygon
+                        ax.plot(pixel_x, pixel_y, color="red", linewidth=1.5)
+                    except Exception as e:
+                        print(f"Error plotting polygon in sample view: {e}")
+
+                # Set title
+                ax.set_title(
+                    f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
+                )
+
+        # Remove axes
+        ax.set_xticks([])
+        ax.set_yticks([])
+
+        # Save if requested
+        if output_path:
+            sample_output = (
+                os.path.splitext(output_path)[0]
+                + "_sample"
+                + os.path.splitext(output_path)[1]
+            )
+            plt.tight_layout()
+            plt.savefig(sample_output, dpi=300, bbox_inches="tight")
+            print(f"Sample visualization saved to {sample_output}")
+
+        return True
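
For context, a minimal usage sketch of the new module (not part of the released file). It assumes the module is importable as geoai.extract from the installed wheel; the raster and output file names are placeholders, and the keyword arguments mirror the process_raster kwargs documented above. When model_path is omitted, the extractor downloads usa_building_footprints.pth from Hugging Face.

    # Illustrative sketch only; file names are placeholders.
    from geoai.extract import BuildingFootprintExtractor

    # Instantiate on GPU if available; weights are fetched automatically
    # when no local model_path is supplied.
    extractor = BuildingFootprintExtractor()

    # Tile the raster, run Mask R-CNN on each chip, and merge detections.
    buildings = extractor.process_raster(
        "input_tile.tif",                # placeholder input raster
        output_path="buildings.geojson",
        confidence_threshold=0.5,
        chip_size=(512, 512),
    )

    # process_raster returns None when nothing was detected.
    if buildings is not None:
        extractor.visualize_results(
            "input_tile.tif", gdf=buildings, output_path="buildings.png"
        )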