geoai-py 0.3.0__py2.py3-none-any.whl → 0.3.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +260 -143
- geoai/preprocess.py +7 -1
- geoai/utils.py +138 -3
- {geoai_py-0.3.0.dist-info → geoai_py-0.3.2.dist-info}/METADATA +1 -2
- geoai_py-0.3.2.dist-info/RECORD +13 -0
- geoai_py-0.3.0.dist-info/RECORD +0 -13
- {geoai_py-0.3.0.dist-info → geoai_py-0.3.2.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.0.dist-info → geoai_py-0.3.2.dist-info}/WHEEL +0 -0
- {geoai_py-0.3.0.dist-info → geoai_py-0.3.2.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.0.dist-info → geoai_py-0.3.2.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/extract.py
CHANGED
|
@@ -7,7 +7,6 @@ import geopandas as gpd
|
|
|
7
7
|
from tqdm import tqdm
|
|
8
8
|
|
|
9
9
|
import cv2
|
|
10
|
-
from torchgeo.datasets import NonGeoDataset
|
|
11
10
|
from torchvision.models.detection import maskrcnn_resnet50_fpn
|
|
12
11
|
import torchvision.transforms as T
|
|
13
12
|
import rasterio
|
|
@@ -16,10 +15,17 @@ from rasterio.features import shapes
|
|
|
16
15
|
from huggingface_hub import hf_hub_download
|
|
17
16
|
from .preprocess import get_raster_stats
|
|
18
17
|
|
|
18
|
+
try:
|
|
19
|
+
from torchgeo.datasets import NonGeoDataset
|
|
20
|
+
except ImportError as e:
|
|
21
|
+
raise ImportError(
|
|
22
|
+
"Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
|
|
23
|
+
)
|
|
19
24
|
|
|
20
|
-
|
|
25
|
+
|
|
26
|
+
class CustomDataset(NonGeoDataset):
|
|
21
27
|
"""
|
|
22
|
-
A TorchGeo dataset for
|
|
28
|
+
A TorchGeo dataset for object extraction.
|
|
23
29
|
Using NonGeoDataset to avoid spatial indexing issues.
|
|
24
30
|
"""
|
|
25
31
|
|
|
@@ -164,18 +170,20 @@ class BuildingFootprintDataset(NonGeoDataset):
|
|
|
164
170
|
return self.rows * self.cols
|
|
165
171
|
|
|
166
172
|
|
|
167
|
-
class
|
|
173
|
+
class ObjectDetector:
|
|
168
174
|
"""
|
|
169
|
-
|
|
175
|
+
Object extraction using Mask R-CNN with TorchGeo.
|
|
170
176
|
"""
|
|
171
177
|
|
|
172
|
-
def __init__(self, model_path=None, device=None):
|
|
178
|
+
def __init__(self, model_path=None, repo_id=None, model=None, device=None):
|
|
173
179
|
"""
|
|
174
|
-
Initialize the
|
|
180
|
+
Initialize the object extractor.
|
|
175
181
|
|
|
176
182
|
Args:
|
|
177
|
-
model_path: Path to the .pth model file
|
|
178
|
-
|
|
183
|
+
model_path: Path to the .pth model file.
|
|
184
|
+
repo_id: Hugging Face repository ID for model download.
|
|
185
|
+
model: Pre-initialized model object (optional).
|
|
186
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
179
187
|
"""
|
|
180
188
|
# Set device
|
|
181
189
|
if device is None:
|
|
@@ -183,31 +191,35 @@ class BuildingFootprintExtractor:
|
|
|
183
191
|
else:
|
|
184
192
|
self.device = torch.device(device)
|
|
185
193
|
|
|
186
|
-
# Default parameters for
|
|
194
|
+
# Default parameters for object detection - these can be overridden in process_raster
|
|
187
195
|
self.chip_size = (512, 512) # Size of image chips for processing
|
|
188
196
|
self.overlap = 0.25 # Default overlap between tiles
|
|
189
197
|
self.confidence_threshold = 0.5 # Default confidence threshold
|
|
190
198
|
self.nms_iou_threshold = 0.5 # IoU threshold for non-maximum suppression
|
|
191
|
-
self.
|
|
199
|
+
self.small_object_area = 100 # Minimum area in pixels to keep an object
|
|
192
200
|
self.mask_threshold = 0.5 # Threshold for mask binarization
|
|
193
201
|
self.simplify_tolerance = 1.0 # Tolerance for polygon simplification
|
|
194
202
|
|
|
195
203
|
# Initialize model
|
|
196
|
-
self.model = self.
|
|
204
|
+
self.model = self.initialize_model(model)
|
|
197
205
|
|
|
198
206
|
# Download model if needed
|
|
199
|
-
if model_path is None:
|
|
200
|
-
model_path = self.
|
|
207
|
+
if model_path is None or (not os.path.exists(model_path)):
|
|
208
|
+
model_path = self.download_model_from_hf(model_path, repo_id)
|
|
201
209
|
|
|
202
210
|
# Load model weights
|
|
203
|
-
self.
|
|
211
|
+
self.load_weights(model_path)
|
|
204
212
|
|
|
205
213
|
# Set model to evaluation mode
|
|
206
214
|
self.model.eval()
|
|
207
215
|
|
|
208
|
-
def
|
|
216
|
+
def download_model_from_hf(self, model_path=None, repo_id=None):
|
|
209
217
|
"""
|
|
210
|
-
Download the
|
|
218
|
+
Download the object detection model from Hugging Face.
|
|
219
|
+
|
|
220
|
+
Args:
|
|
221
|
+
model_path: Path to the model file.
|
|
222
|
+
repo_id: Hugging Face repository ID.
|
|
211
223
|
|
|
212
224
|
Returns:
|
|
213
225
|
Path to the downloaded model file
|
|
@@ -217,17 +229,14 @@ class BuildingFootprintExtractor:
|
|
|
217
229
|
print("Model path not specified, downloading from Hugging Face...")
|
|
218
230
|
|
|
219
231
|
# Define the repository ID and model filename
|
|
220
|
-
repo_id
|
|
221
|
-
|
|
232
|
+
if repo_id is None:
|
|
233
|
+
repo_id = "giswqs/geoai"
|
|
222
234
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
# os.path.expanduser("~"), ".cache", "building_footprints"
|
|
226
|
-
# )
|
|
227
|
-
# os.makedirs(cache_dir, exist_ok=True)
|
|
235
|
+
if model_path is None:
|
|
236
|
+
model_path = "building_footprints_usa.pth"
|
|
228
237
|
|
|
229
238
|
# Download the model
|
|
230
|
-
model_path = hf_hub_download(repo_id=repo_id, filename=
|
|
239
|
+
model_path = hf_hub_download(repo_id=repo_id, filename=model_path)
|
|
231
240
|
print(f"Model downloaded to: {model_path}")
|
|
232
241
|
|
|
233
242
|
return model_path
|
|
@@ -237,28 +246,36 @@ class BuildingFootprintExtractor:
|
|
|
237
246
|
print("Please specify a local model path or ensure internet connectivity.")
|
|
238
247
|
raise
|
|
239
248
|
|
|
240
|
-
def
|
|
241
|
-
"""Initialize
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
249
|
+
def initialize_model(self, model):
|
|
250
|
+
"""Initialize a deep learning model for object detection.
|
|
251
|
+
|
|
252
|
+
Args:
|
|
253
|
+
model (torch.nn.Module): A pre-initialized model object.
|
|
254
|
+
|
|
255
|
+
Returns:
|
|
256
|
+
torch.nn.Module: A deep learning model for object detection.
|
|
257
|
+
"""
|
|
258
|
+
|
|
259
|
+
if model is None: # Initialize Mask R-CNN model with ResNet50 backbone.
|
|
260
|
+
# Standard image mean and std for pre-trained models
|
|
261
|
+
image_mean = [0.485, 0.456, 0.406]
|
|
262
|
+
image_std = [0.229, 0.224, 0.225]
|
|
263
|
+
|
|
264
|
+
# Create model with explicit normalization parameters
|
|
265
|
+
model = maskrcnn_resnet50_fpn(
|
|
266
|
+
weights=None,
|
|
267
|
+
progress=False,
|
|
268
|
+
num_classes=2, # Background + object
|
|
269
|
+
weights_backbone=None,
|
|
270
|
+
# These parameters ensure consistent normalization
|
|
271
|
+
image_mean=image_mean,
|
|
272
|
+
image_std=image_std,
|
|
273
|
+
)
|
|
257
274
|
|
|
258
275
|
model.to(self.device)
|
|
259
276
|
return model
|
|
260
277
|
|
|
261
|
-
def
|
|
278
|
+
def load_weights(self, model_path):
|
|
262
279
|
"""
|
|
263
280
|
Load weights from file with error handling for different formats.
|
|
264
281
|
|
|
@@ -300,7 +317,7 @@ class BuildingFootprintExtractor:
|
|
|
300
317
|
except Exception as e:
|
|
301
318
|
raise RuntimeError(f"Failed to load model: {e}")
|
|
302
319
|
|
|
303
|
-
def
|
|
320
|
+
def mask_to_polygons(self, mask, **kwargs):
|
|
304
321
|
"""
|
|
305
322
|
Convert binary mask to polygon contours using OpenCV.
|
|
306
323
|
|
|
@@ -309,7 +326,7 @@ class BuildingFootprintExtractor:
|
|
|
309
326
|
**kwargs: Optional parameters:
|
|
310
327
|
simplify_tolerance: Tolerance for polygon simplification
|
|
311
328
|
mask_threshold: Threshold for mask binarization
|
|
312
|
-
|
|
329
|
+
small_object_area: Minimum area in pixels to keep an object
|
|
313
330
|
|
|
314
331
|
Returns:
|
|
315
332
|
List of polygons as lists of (x, y) coordinates
|
|
@@ -318,9 +335,7 @@ class BuildingFootprintExtractor:
|
|
|
318
335
|
# Get parameters from kwargs or use instance defaults
|
|
319
336
|
simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
|
|
320
337
|
mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
|
|
321
|
-
|
|
322
|
-
"small_building_area", self.small_building_area
|
|
323
|
-
)
|
|
338
|
+
small_object_area = kwargs.get("small_object_area", self.small_object_area)
|
|
324
339
|
|
|
325
340
|
# Ensure binary mask
|
|
326
341
|
mask = (mask > mask_threshold).astype(np.uint8)
|
|
@@ -336,7 +351,7 @@ class BuildingFootprintExtractor:
|
|
|
336
351
|
polygons = []
|
|
337
352
|
for contour in contours:
|
|
338
353
|
# Filter out too small contours
|
|
339
|
-
if contour.shape[0] < 3 or cv2.contourArea(contour) <
|
|
354
|
+
if contour.shape[0] < 3 or cv2.contourArea(contour) < small_object_area:
|
|
340
355
|
continue
|
|
341
356
|
|
|
342
357
|
# Simplify contour if it has many points
|
|
@@ -350,7 +365,7 @@ class BuildingFootprintExtractor:
|
|
|
350
365
|
|
|
351
366
|
return polygons
|
|
352
367
|
|
|
353
|
-
def
|
|
368
|
+
def filter_overlapping_polygons(self, gdf, **kwargs):
|
|
354
369
|
"""
|
|
355
370
|
Filter overlapping polygons using non-maximum suppression.
|
|
356
371
|
|
|
@@ -407,26 +422,26 @@ class BuildingFootprintExtractor:
|
|
|
407
422
|
|
|
408
423
|
return gdf.iloc[keep_indices]
|
|
409
424
|
|
|
410
|
-
def
|
|
425
|
+
def filter_edge_objects(self, gdf, raster_path, edge_buffer=10):
|
|
411
426
|
"""
|
|
412
|
-
Filter out
|
|
427
|
+
Filter out object detections that fall in padding/edge areas of the image.
|
|
413
428
|
|
|
414
429
|
Args:
|
|
415
|
-
gdf: GeoDataFrame with
|
|
430
|
+
gdf: GeoDataFrame with object detections
|
|
416
431
|
raster_path: Path to the original raster file
|
|
417
432
|
edge_buffer: Buffer in pixels to consider as edge region
|
|
418
433
|
|
|
419
434
|
Returns:
|
|
420
|
-
GeoDataFrame with filtered
|
|
435
|
+
GeoDataFrame with filtered objects
|
|
421
436
|
"""
|
|
422
437
|
import rasterio
|
|
423
438
|
from shapely.geometry import box
|
|
424
439
|
|
|
425
|
-
# If no
|
|
440
|
+
# If no objects detected, return empty GeoDataFrame
|
|
426
441
|
if gdf is None or len(gdf) == 0:
|
|
427
442
|
return gdf
|
|
428
443
|
|
|
429
|
-
print(f"
|
|
444
|
+
print(f"Objects before filtering: {len(gdf)}")
|
|
430
445
|
|
|
431
446
|
with rasterio.open(raster_path) as src:
|
|
432
447
|
# Get raster bounds
|
|
@@ -455,18 +470,18 @@ class BuildingFootprintExtractor:
|
|
|
455
470
|
else:
|
|
456
471
|
inner_box = box(*inner_bounds)
|
|
457
472
|
|
|
458
|
-
# Filter out
|
|
473
|
+
# Filter out objects that intersect with the edge of the image
|
|
459
474
|
filtered_gdf = gdf[gdf.intersects(inner_box)]
|
|
460
475
|
|
|
461
|
-
# Additional check for
|
|
462
|
-
|
|
476
|
+
# Additional check for objects that have >50% of their area outside the valid region
|
|
477
|
+
valid_objects = []
|
|
463
478
|
for idx, row in filtered_gdf.iterrows():
|
|
464
479
|
if row.geometry.intersection(inner_box).area >= 0.5 * row.geometry.area:
|
|
465
|
-
|
|
480
|
+
valid_objects.append(idx)
|
|
466
481
|
|
|
467
|
-
filtered_gdf = filtered_gdf.loc[
|
|
482
|
+
filtered_gdf = filtered_gdf.loc[valid_objects]
|
|
468
483
|
|
|
469
|
-
print(f"
|
|
484
|
+
print(f"Objects after filtering: {len(filtered_gdf)}")
|
|
470
485
|
|
|
471
486
|
return filtered_gdf
|
|
472
487
|
|
|
@@ -476,28 +491,28 @@ class BuildingFootprintExtractor:
|
|
|
476
491
|
output_path=None,
|
|
477
492
|
simplify_tolerance=None,
|
|
478
493
|
mask_threshold=None,
|
|
479
|
-
|
|
494
|
+
small_object_area=None,
|
|
480
495
|
nms_iou_threshold=None,
|
|
481
496
|
regularize=True,
|
|
482
497
|
angle_threshold=15,
|
|
483
498
|
rectangularity_threshold=0.7,
|
|
484
499
|
):
|
|
485
500
|
"""
|
|
486
|
-
Convert
|
|
501
|
+
Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
|
|
487
502
|
|
|
488
503
|
Args:
|
|
489
|
-
mask_path: Path to the
|
|
504
|
+
mask_path: Path to the object masks GeoTIFF
|
|
490
505
|
output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
|
|
491
506
|
simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
|
|
492
507
|
mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
|
|
493
|
-
|
|
508
|
+
small_object_area: Minimum area in pixels to keep an object (default: self.small_object_area)
|
|
494
509
|
nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
|
|
495
|
-
regularize: Whether to regularize
|
|
510
|
+
regularize: Whether to regularize objects to right angles (default: True)
|
|
496
511
|
angle_threshold: Maximum deviation from 90 degrees for regularization (default: 15)
|
|
497
512
|
rectangularity_threshold: Threshold for rectangle simplification (default: 0.7)
|
|
498
513
|
|
|
499
514
|
Returns:
|
|
500
|
-
GeoDataFrame with
|
|
515
|
+
GeoDataFrame with objects
|
|
501
516
|
"""
|
|
502
517
|
# Use class defaults if parameters not provided
|
|
503
518
|
simplify_tolerance = (
|
|
@@ -508,10 +523,10 @@ class BuildingFootprintExtractor:
|
|
|
508
523
|
mask_threshold = (
|
|
509
524
|
mask_threshold if mask_threshold is not None else self.mask_threshold
|
|
510
525
|
)
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
if
|
|
514
|
-
else self.
|
|
526
|
+
small_object_area = (
|
|
527
|
+
small_object_area
|
|
528
|
+
if small_object_area is not None
|
|
529
|
+
else self.small_object_area
|
|
515
530
|
)
|
|
516
531
|
nms_iou_threshold = (
|
|
517
532
|
nms_iou_threshold
|
|
@@ -525,10 +540,10 @@ class BuildingFootprintExtractor:
|
|
|
525
540
|
|
|
526
541
|
print(f"Converting mask to GeoJSON with parameters:")
|
|
527
542
|
print(f"- Mask threshold: {mask_threshold}")
|
|
528
|
-
print(f"- Min
|
|
543
|
+
print(f"- Min object area: {small_object_area}")
|
|
529
544
|
print(f"- Simplify tolerance: {simplify_tolerance}")
|
|
530
545
|
print(f"- NMS IoU threshold: {nms_iou_threshold}")
|
|
531
|
-
print(f"- Regularize
|
|
546
|
+
print(f"- Regularize objects: {regularize}")
|
|
532
547
|
if regularize:
|
|
533
548
|
print(f"- Angle threshold: {angle_threshold}° from 90°")
|
|
534
549
|
print(f"- Rectangularity threshold: {rectangularity_threshold*100}%")
|
|
@@ -558,7 +573,7 @@ class BuildingFootprintExtractor:
|
|
|
558
573
|
)
|
|
559
574
|
|
|
560
575
|
print(
|
|
561
|
-
f"Found {num_labels-1} potential
|
|
576
|
+
f"Found {num_labels-1} potential objects"
|
|
562
577
|
) # Subtract 1 for background
|
|
563
578
|
|
|
564
579
|
# Create list to store polygons and confidence values
|
|
@@ -567,19 +582,19 @@ class BuildingFootprintExtractor:
|
|
|
567
582
|
|
|
568
583
|
# Process each component (skip the first one which is background)
|
|
569
584
|
for i in tqdm(range(1, num_labels)):
|
|
570
|
-
# Extract this
|
|
585
|
+
# Extract this object
|
|
571
586
|
area = stats[i, cv2.CC_STAT_AREA]
|
|
572
587
|
|
|
573
588
|
# Skip if too small
|
|
574
|
-
if area <
|
|
589
|
+
if area < small_object_area:
|
|
575
590
|
continue
|
|
576
591
|
|
|
577
|
-
# Create a mask for this
|
|
578
|
-
|
|
592
|
+
# Create a mask for this object
|
|
593
|
+
object_mask = (labels == i).astype(np.uint8)
|
|
579
594
|
|
|
580
595
|
# Find contours
|
|
581
596
|
contours, _ = cv2.findContours(
|
|
582
|
-
|
|
597
|
+
object_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
|
583
598
|
)
|
|
584
599
|
|
|
585
600
|
# Process each contour
|
|
@@ -627,17 +642,17 @@ class BuildingFootprintExtractor:
|
|
|
627
642
|
{
|
|
628
643
|
"geometry": all_polygons,
|
|
629
644
|
"confidence": all_confidences,
|
|
630
|
-
"class": 1, #
|
|
645
|
+
"class": 1, # Object class
|
|
631
646
|
},
|
|
632
647
|
crs=crs,
|
|
633
648
|
)
|
|
634
649
|
|
|
635
650
|
# Apply non-maximum suppression to remove overlapping polygons
|
|
636
|
-
gdf = self.
|
|
651
|
+
gdf = self.filter_overlapping_polygons(
|
|
637
652
|
gdf, nms_iou_threshold=nms_iou_threshold
|
|
638
653
|
)
|
|
639
654
|
|
|
640
|
-
print(f"
|
|
655
|
+
print(f"Object count after NMS filtering: {len(gdf)}")
|
|
641
656
|
|
|
642
657
|
# Apply regularization if requested
|
|
643
658
|
if regularize and len(gdf) > 0:
|
|
@@ -655,8 +670,8 @@ class BuildingFootprintExtractor:
|
|
|
655
670
|
# Use 10 pixels as minimum area in geographic units
|
|
656
671
|
min_geo_area = 10 * avg_pixel_area
|
|
657
672
|
|
|
658
|
-
# Regularize
|
|
659
|
-
gdf = self.
|
|
673
|
+
# Regularize objects
|
|
674
|
+
gdf = self.regularize_objects(
|
|
660
675
|
gdf,
|
|
661
676
|
min_area=min_geo_area,
|
|
662
677
|
angle_threshold=angle_threshold,
|
|
@@ -666,7 +681,7 @@ class BuildingFootprintExtractor:
|
|
|
666
681
|
# Save to file
|
|
667
682
|
if output_path:
|
|
668
683
|
gdf.to_file(output_path)
|
|
669
|
-
print(f"Saved {len(gdf)}
|
|
684
|
+
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
670
685
|
|
|
671
686
|
return gdf
|
|
672
687
|
|
|
@@ -681,25 +696,25 @@ class BuildingFootprintExtractor:
|
|
|
681
696
|
**kwargs,
|
|
682
697
|
):
|
|
683
698
|
"""
|
|
684
|
-
Process a raster file to extract
|
|
699
|
+
Process a raster file to extract objects with customizable parameters.
|
|
685
700
|
|
|
686
701
|
Args:
|
|
687
702
|
raster_path: Path to input raster file
|
|
688
703
|
output_path: Path to output GeoJSON file (optional)
|
|
689
704
|
batch_size: Batch size for processing
|
|
690
|
-
filter_edges: Whether to filter out
|
|
691
|
-
edge_buffer: Size of edge buffer in pixels to filter out
|
|
705
|
+
filter_edges: Whether to filter out objects at the edges of the image
|
|
706
|
+
edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
|
|
692
707
|
**kwargs: Additional parameters:
|
|
693
708
|
confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
|
|
694
709
|
overlap: Overlap between adjacent tiles (0.0-1.0)
|
|
695
710
|
chip_size: Size of image chips for processing (height, width)
|
|
696
711
|
nms_iou_threshold: IoU threshold for non-maximum suppression (0.0-1.0)
|
|
697
712
|
mask_threshold: Threshold for mask binarization (0.0-1.0)
|
|
698
|
-
|
|
713
|
+
small_object_area: Minimum area in pixels to keep an object
|
|
699
714
|
simplify_tolerance: Tolerance for polygon simplification
|
|
700
715
|
|
|
701
716
|
Returns:
|
|
702
|
-
GeoDataFrame with
|
|
717
|
+
GeoDataFrame with objects
|
|
703
718
|
"""
|
|
704
719
|
# Get parameters from kwargs or use instance defaults
|
|
705
720
|
confidence_threshold = kwargs.get(
|
|
@@ -709,9 +724,7 @@ class BuildingFootprintExtractor:
|
|
|
709
724
|
chip_size = kwargs.get("chip_size", self.chip_size)
|
|
710
725
|
nms_iou_threshold = kwargs.get("nms_iou_threshold", self.nms_iou_threshold)
|
|
711
726
|
mask_threshold = kwargs.get("mask_threshold", self.mask_threshold)
|
|
712
|
-
|
|
713
|
-
"small_building_area", self.small_building_area
|
|
714
|
-
)
|
|
727
|
+
small_object_area = kwargs.get("small_object_area", self.small_object_area)
|
|
715
728
|
simplify_tolerance = kwargs.get("simplify_tolerance", self.simplify_tolerance)
|
|
716
729
|
|
|
717
730
|
# Print parameters being used
|
|
@@ -721,14 +734,14 @@ class BuildingFootprintExtractor:
|
|
|
721
734
|
print(f"- Chip size: {chip_size}")
|
|
722
735
|
print(f"- NMS IoU threshold: {nms_iou_threshold}")
|
|
723
736
|
print(f"- Mask threshold: {mask_threshold}")
|
|
724
|
-
print(f"- Min
|
|
737
|
+
print(f"- Min object area: {small_object_area}")
|
|
725
738
|
print(f"- Simplify tolerance: {simplify_tolerance}")
|
|
726
|
-
print(f"- Filter edge
|
|
739
|
+
print(f"- Filter edge objects: {filter_edges}")
|
|
727
740
|
if filter_edges:
|
|
728
741
|
print(f"- Edge buffer size: {edge_buffer} pixels")
|
|
729
742
|
|
|
730
743
|
# Create dataset
|
|
731
|
-
dataset =
|
|
744
|
+
dataset = CustomDataset(raster_path=raster_path, chip_size=chip_size)
|
|
732
745
|
self.raster_stats = dataset.raster_stats
|
|
733
746
|
|
|
734
747
|
# Custom collate function to handle Shapely objects
|
|
@@ -848,11 +861,11 @@ class BuildingFootprintExtractor:
|
|
|
848
861
|
binary_mask = mask[0] # Get binary mask
|
|
849
862
|
|
|
850
863
|
# Convert mask to polygon with custom parameters
|
|
851
|
-
contours = self.
|
|
864
|
+
contours = self.mask_to_polygons(
|
|
852
865
|
binary_mask,
|
|
853
866
|
simplify_tolerance=simplify_tolerance,
|
|
854
867
|
mask_threshold=mask_threshold,
|
|
855
|
-
|
|
868
|
+
small_object_area=small_object_area,
|
|
856
869
|
)
|
|
857
870
|
|
|
858
871
|
# Skip if no valid polygons
|
|
@@ -890,24 +903,22 @@ class BuildingFootprintExtractor:
|
|
|
890
903
|
{
|
|
891
904
|
"geometry": all_polygons,
|
|
892
905
|
"confidence": all_scores,
|
|
893
|
-
"class": 1, #
|
|
906
|
+
"class": 1, # Object class
|
|
894
907
|
},
|
|
895
908
|
crs=dataset.crs,
|
|
896
909
|
)
|
|
897
910
|
|
|
898
911
|
# Remove overlapping polygons with custom threshold
|
|
899
|
-
gdf = self.
|
|
900
|
-
gdf, nms_iou_threshold=nms_iou_threshold
|
|
901
|
-
)
|
|
912
|
+
gdf = self.filter_overlapping_polygons(gdf, nms_iou_threshold=nms_iou_threshold)
|
|
902
913
|
|
|
903
|
-
# Filter edge
|
|
914
|
+
# Filter edge objects if requested
|
|
904
915
|
if filter_edges:
|
|
905
|
-
gdf = self.
|
|
916
|
+
gdf = self.filter_edge_objects(gdf, raster_path, edge_buffer=edge_buffer)
|
|
906
917
|
|
|
907
918
|
# Save to file if requested
|
|
908
919
|
if output_path:
|
|
909
920
|
gdf.to_file(output_path, driver="GeoJSON")
|
|
910
|
-
print(f"Saved {len(gdf)}
|
|
921
|
+
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
911
922
|
|
|
912
923
|
return gdf
|
|
913
924
|
|
|
@@ -915,7 +926,7 @@ class BuildingFootprintExtractor:
|
|
|
915
926
|
self, raster_path, output_path=None, batch_size=4, verbose=False, **kwargs
|
|
916
927
|
):
|
|
917
928
|
"""
|
|
918
|
-
Process a raster file to extract
|
|
929
|
+
Process a raster file to extract object masks and save as GeoTIFF.
|
|
919
930
|
|
|
920
931
|
Args:
|
|
921
932
|
raster_path: Path to input raster file
|
|
@@ -949,7 +960,7 @@ class BuildingFootprintExtractor:
|
|
|
949
960
|
print(f"- Mask threshold: {mask_threshold}")
|
|
950
961
|
|
|
951
962
|
# Create dataset
|
|
952
|
-
dataset =
|
|
963
|
+
dataset = CustomDataset(
|
|
953
964
|
raster_path=raster_path, chip_size=chip_size, verbose=verbose
|
|
954
965
|
)
|
|
955
966
|
|
|
@@ -966,7 +977,7 @@ class BuildingFootprintExtractor:
|
|
|
966
977
|
output_profile = src.profile.copy()
|
|
967
978
|
output_profile.update(
|
|
968
979
|
dtype=rasterio.uint8,
|
|
969
|
-
count=1, # Single band for
|
|
980
|
+
count=1, # Single band for object mask
|
|
970
981
|
compress="lzw",
|
|
971
982
|
nodata=0,
|
|
972
983
|
)
|
|
@@ -1138,10 +1149,10 @@ class BuildingFootprintExtractor:
|
|
|
1138
1149
|
# Write the final mask to the output file
|
|
1139
1150
|
dst.write(mask_array, 1)
|
|
1140
1151
|
|
|
1141
|
-
print(f"
|
|
1152
|
+
print(f"Object masks saved to {output_path}")
|
|
1142
1153
|
return output_path
|
|
1143
1154
|
|
|
1144
|
-
def
|
|
1155
|
+
def regularize_objects(
|
|
1145
1156
|
self,
|
|
1146
1157
|
gdf,
|
|
1147
1158
|
min_area=10,
|
|
@@ -1150,17 +1161,17 @@ class BuildingFootprintExtractor:
|
|
|
1150
1161
|
rectangularity_threshold=0.7,
|
|
1151
1162
|
):
|
|
1152
1163
|
"""
|
|
1153
|
-
Regularize
|
|
1164
|
+
Regularize objects to enforce right angles and rectangular shapes.
|
|
1154
1165
|
|
|
1155
1166
|
Args:
|
|
1156
|
-
gdf: GeoDataFrame with
|
|
1157
|
-
min_area: Minimum area in square units to keep
|
|
1167
|
+
gdf: GeoDataFrame with objects
|
|
1168
|
+
min_area: Minimum area in square units to keep an object
|
|
1158
1169
|
angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
|
|
1159
|
-
orthogonality_threshold: Percentage of angles that must be orthogonal for
|
|
1160
|
-
rectangularity_threshold: Minimum area ratio to
|
|
1170
|
+
orthogonality_threshold: Percentage of angles that must be orthogonal for an object to be regularized
|
|
1171
|
+
rectangularity_threshold: Minimum area ratio to the object's oriented bounding box for rectangular simplification
|
|
1161
1172
|
|
|
1162
1173
|
Returns:
|
|
1163
|
-
GeoDataFrame with regularized
|
|
1174
|
+
GeoDataFrame with regularized objects
|
|
1164
1175
|
"""
|
|
1165
1176
|
import numpy as np
|
|
1166
1177
|
from shapely.geometry import Polygon, MultiPolygon, box
|
|
@@ -1275,10 +1286,10 @@ class BuildingFootprintExtractor:
|
|
|
1275
1286
|
return rect
|
|
1276
1287
|
|
|
1277
1288
|
if gdf is None or len(gdf) == 0:
|
|
1278
|
-
print("No
|
|
1289
|
+
print("No Objects to regularize")
|
|
1279
1290
|
return gdf
|
|
1280
1291
|
|
|
1281
|
-
print(f"Regularizing {len(gdf)}
|
|
1292
|
+
print(f"Regularizing {len(gdf)} objects...")
|
|
1282
1293
|
print(f"- Angle threshold: {angle_threshold}° from 90°")
|
|
1283
1294
|
print(f"- Min orthogonality: {orthogonality_threshold*100}% of angles")
|
|
1284
1295
|
print(
|
|
@@ -1289,11 +1300,11 @@ class BuildingFootprintExtractor:
|
|
|
1289
1300
|
result_gdf = gdf.copy()
|
|
1290
1301
|
|
|
1291
1302
|
# Track statistics
|
|
1292
|
-
|
|
1303
|
+
total_objects = len(gdf)
|
|
1293
1304
|
regularized_count = 0
|
|
1294
1305
|
rectangularized_count = 0
|
|
1295
1306
|
|
|
1296
|
-
# Process each
|
|
1307
|
+
# Process each object
|
|
1297
1308
|
for idx, row in tqdm(gdf.iterrows(), total=len(gdf)):
|
|
1298
1309
|
geom = row.geometry
|
|
1299
1310
|
|
|
@@ -1308,7 +1319,7 @@ class BuildingFootprintExtractor:
|
|
|
1308
1319
|
continue
|
|
1309
1320
|
geom = list(geom.geoms)[np.argmax(areas)]
|
|
1310
1321
|
|
|
1311
|
-
# Filter out tiny
|
|
1322
|
+
# Filter out tiny objects
|
|
1312
1323
|
if geom.area < min_area:
|
|
1313
1324
|
continue
|
|
1314
1325
|
|
|
@@ -1325,33 +1336,33 @@ class BuildingFootprintExtractor:
|
|
|
1325
1336
|
|
|
1326
1337
|
# Decide how to regularize
|
|
1327
1338
|
if rectangularity >= rectangularity_threshold:
|
|
1328
|
-
#
|
|
1339
|
+
# Object is already quite rectangular, simplify to a rectangle
|
|
1329
1340
|
result_gdf.at[idx, "geometry"] = oriented_box
|
|
1330
1341
|
result_gdf.at[idx, "regularized"] = "rectangle"
|
|
1331
1342
|
rectangularized_count += 1
|
|
1332
1343
|
elif orthogonality >= orthogonality_threshold:
|
|
1333
|
-
#
|
|
1344
|
+
# Object has many orthogonal angles but isn't rectangular
|
|
1334
1345
|
# Could implement more sophisticated regularization here
|
|
1335
1346
|
# For now, we'll still use the oriented rectangle
|
|
1336
1347
|
result_gdf.at[idx, "geometry"] = oriented_box
|
|
1337
1348
|
result_gdf.at[idx, "regularized"] = "orthogonal"
|
|
1338
1349
|
regularized_count += 1
|
|
1339
1350
|
else:
|
|
1340
|
-
#
|
|
1351
|
+
# Object doesn't have clear orthogonal structure
|
|
1341
1352
|
# Keep original but flag as unmodified
|
|
1342
1353
|
result_gdf.at[idx, "regularized"] = "original"
|
|
1343
1354
|
|
|
1344
1355
|
# Report statistics
|
|
1345
1356
|
print(f"Regularization completed:")
|
|
1346
|
-
print(f"- Total
|
|
1357
|
+
print(f"- Total objects: {total_objects}")
|
|
1347
1358
|
print(
|
|
1348
|
-
f"- Rectangular
|
|
1359
|
+
f"- Rectangular objects: {rectangularized_count} ({rectangularized_count/total_objects*100:.1f}%)"
|
|
1349
1360
|
)
|
|
1350
1361
|
print(
|
|
1351
|
-
f"- Other regularized
|
|
1362
|
+
f"- Other regularized objects: {regularized_count} ({regularized_count/total_objects*100:.1f}%)"
|
|
1352
1363
|
)
|
|
1353
1364
|
print(
|
|
1354
|
-
f"- Unmodified
|
|
1365
|
+
f"- Unmodified objects: {total_objects-rectangularized_count-regularized_count} ({(total_objects-rectangularized_count-regularized_count)/total_objects*100:.1f}%)"
|
|
1355
1366
|
)
|
|
1356
1367
|
|
|
1357
1368
|
return result_gdf
|
|
@@ -1360,14 +1371,14 @@ class BuildingFootprintExtractor:
|
|
|
1360
1371
|
self, raster_path, gdf=None, output_path=None, figsize=(12, 12)
|
|
1361
1372
|
):
|
|
1362
1373
|
"""
|
|
1363
|
-
Visualize
|
|
1374
|
+
Visualize object detection results with proper coordinate transformation.
|
|
1364
1375
|
|
|
1365
|
-
This function displays
|
|
1376
|
+
This function displays objects on top of the raster image,
|
|
1366
1377
|
ensuring proper alignment between the GeoDataFrame polygons and the image.
|
|
1367
1378
|
|
|
1368
1379
|
Args:
|
|
1369
1380
|
raster_path: Path to input raster
|
|
1370
|
-
gdf: GeoDataFrame with
|
|
1381
|
+
gdf: GeoDataFrame with object polygons (optional)
|
|
1371
1382
|
output_path: Path to save visualization (optional)
|
|
1372
1383
|
figsize: Figure size (width, height) in inches
|
|
1373
1384
|
|
|
@@ -1384,7 +1395,7 @@ class BuildingFootprintExtractor:
|
|
|
1384
1395
|
gdf = self.process_raster(raster_path)
|
|
1385
1396
|
|
|
1386
1397
|
if gdf is None or len(gdf) == 0:
|
|
1387
|
-
print("No
|
|
1398
|
+
print("No objects to visualize")
|
|
1388
1399
|
return False
|
|
1389
1400
|
|
|
1390
1401
|
# Check if confidence column exists in the GeoDataFrame
|
|
@@ -1525,7 +1536,7 @@ class BuildingFootprintExtractor:
|
|
|
1525
1536
|
print(f"Unsupported geometry type: {geometry.geom_type}")
|
|
1526
1537
|
return None
|
|
1527
1538
|
|
|
1528
|
-
# Plot each
|
|
1539
|
+
# Plot each object
|
|
1529
1540
|
for idx, row in gdf.iterrows():
|
|
1530
1541
|
try:
|
|
1531
1542
|
# Convert polygon to pixel coordinates
|
|
@@ -1587,7 +1598,7 @@ class BuildingFootprintExtractor:
|
|
|
1587
1598
|
# Remove axes
|
|
1588
1599
|
ax.set_xticks([])
|
|
1589
1600
|
ax.set_yticks([])
|
|
1590
|
-
ax.set_title(f"
|
|
1601
|
+
ax.set_title(f"objects (Found: {len(gdf)})")
|
|
1591
1602
|
|
|
1592
1603
|
# Save if requested
|
|
1593
1604
|
if output_path:
|
|
@@ -1597,21 +1608,21 @@ class BuildingFootprintExtractor:
|
|
|
1597
1608
|
|
|
1598
1609
|
plt.close()
|
|
1599
1610
|
|
|
1600
|
-
# Create a simpler visualization focused just on a subset of
|
|
1611
|
+
# Create a simpler visualization focused just on a subset of objects
|
|
1601
1612
|
if len(gdf) > 0:
|
|
1602
1613
|
plt.figure(figsize=figsize)
|
|
1603
1614
|
ax = plt.gca()
|
|
1604
1615
|
|
|
1605
1616
|
# Choose a subset of the image to show
|
|
1606
1617
|
with rasterio.open(raster_path) as src:
|
|
1607
|
-
# Get centroid of first
|
|
1618
|
+
# Get centroid of first object
|
|
1608
1619
|
sample_geom = gdf.iloc[0].geometry
|
|
1609
1620
|
centroid = sample_geom.centroid
|
|
1610
1621
|
|
|
1611
1622
|
# Convert to pixel coordinates
|
|
1612
1623
|
center_x, center_y = ~src.transform * (centroid.x, centroid.y)
|
|
1613
1624
|
|
|
1614
|
-
# Define a window around this
|
|
1625
|
+
# Define a window around this object
|
|
1615
1626
|
window_size = 500 # pixels
|
|
1616
1627
|
window = rasterio.windows.Window(
|
|
1617
1628
|
max(0, int(center_x - window_size / 2)),
|
|
@@ -1648,7 +1659,7 @@ class BuildingFootprintExtractor:
|
|
|
1648
1659
|
window_bounds = rasterio.windows.bounds(window, src.transform)
|
|
1649
1660
|
window_box = box(*window_bounds)
|
|
1650
1661
|
|
|
1651
|
-
# Filter
|
|
1662
|
+
# Filter objects that intersect with this window
|
|
1652
1663
|
visible_gdf = gdf[gdf.intersects(window_box)]
|
|
1653
1664
|
|
|
1654
1665
|
# Set up colors for sample view if confidence data exists
|
|
@@ -1670,7 +1681,7 @@ class BuildingFootprintExtractor:
|
|
|
1670
1681
|
except Exception as e:
|
|
1671
1682
|
print(f"Error setting up sample confidence visualization: {e}")
|
|
1672
1683
|
|
|
1673
|
-
# Plot
|
|
1684
|
+
# Plot objects in sample view
|
|
1674
1685
|
for idx, row in visible_gdf.iterrows():
|
|
1675
1686
|
try:
|
|
1676
1687
|
# Get window-relative pixel coordinates
|
|
@@ -1745,9 +1756,7 @@ class BuildingFootprintExtractor:
|
|
|
1745
1756
|
print(f"Error plotting polygon in sample view: {e}")
|
|
1746
1757
|
|
|
1747
1758
|
# Set title
|
|
1748
|
-
ax.set_title(
|
|
1749
|
-
f"Sample Area - Building Footprints (Showing: {len(visible_gdf)})"
|
|
1750
|
-
)
|
|
1759
|
+
ax.set_title(f"Sample Area - objects (Showing: {len(visible_gdf)})")
|
|
1751
1760
|
|
|
1752
1761
|
# Remove axes
|
|
1753
1762
|
ax.set_xticks([])
|
|
@@ -1763,3 +1772,111 @@ class BuildingFootprintExtractor:
|
|
|
1763
1772
|
plt.tight_layout()
|
|
1764
1773
|
plt.savefig(sample_output, dpi=300, bbox_inches="tight")
|
|
1765
1774
|
print(f"Sample visualization saved to {sample_output}")
|
|
1775
|
+
|
|
1776
|
+
|
|
1777
|
+
class BuildingFootprintExtractor(ObjectDetector):
|
|
1778
|
+
"""
|
|
1779
|
+
Building footprint extraction using a pre-trained Mask R-CNN model.
|
|
1780
|
+
|
|
1781
|
+
This class extends the
|
|
1782
|
+
`ObjectDetector` class with additional methods for building footprint extraction."
|
|
1783
|
+
"""
|
|
1784
|
+
|
|
1785
|
+
def __init__(
|
|
1786
|
+
self,
|
|
1787
|
+
model_path="building_footprints_usa.pth",
|
|
1788
|
+
repo_id=None,
|
|
1789
|
+
model=None,
|
|
1790
|
+
device=None,
|
|
1791
|
+
):
|
|
1792
|
+
"""
|
|
1793
|
+
Initialize the object extractor.
|
|
1794
|
+
|
|
1795
|
+
Args:
|
|
1796
|
+
model_path: Path to the .pth model file.
|
|
1797
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
1798
|
+
model: Custom model to use for inference.
|
|
1799
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
1800
|
+
"""
|
|
1801
|
+
super().__init__(
|
|
1802
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
1803
|
+
)
|
|
1804
|
+
|
|
1805
|
+
def regularize_buildings(
|
|
1806
|
+
self,
|
|
1807
|
+
gdf,
|
|
1808
|
+
min_area=10,
|
|
1809
|
+
angle_threshold=15,
|
|
1810
|
+
orthogonality_threshold=0.3,
|
|
1811
|
+
rectangularity_threshold=0.7,
|
|
1812
|
+
):
|
|
1813
|
+
"""
|
|
1814
|
+
Regularize building footprints to enforce right angles and rectangular shapes.
|
|
1815
|
+
|
|
1816
|
+
Args:
|
|
1817
|
+
gdf: GeoDataFrame with building footprints
|
|
1818
|
+
min_area: Minimum area in square units to keep a building
|
|
1819
|
+
angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
|
|
1820
|
+
orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
|
|
1821
|
+
rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
|
|
1822
|
+
|
|
1823
|
+
Returns:
|
|
1824
|
+
GeoDataFrame with regularized building footprints
|
|
1825
|
+
"""
|
|
1826
|
+
return self.regularize_objects(
|
|
1827
|
+
gdf,
|
|
1828
|
+
min_area=min_area,
|
|
1829
|
+
angle_threshold=angle_threshold,
|
|
1830
|
+
orthogonality_threshold=orthogonality_threshold,
|
|
1831
|
+
rectangularity_threshold=rectangularity_threshold,
|
|
1832
|
+
)
|
|
1833
|
+
|
|
1834
|
+
|
|
1835
|
+
class CarDetector(ObjectDetector):
|
|
1836
|
+
"""
|
|
1837
|
+
Car detection using a pre-trained Mask R-CNN model.
|
|
1838
|
+
|
|
1839
|
+
This class extends the
|
|
1840
|
+
`ObjectDetector` class with additional methods for car detection."
|
|
1841
|
+
"""
|
|
1842
|
+
|
|
1843
|
+
def __init__(
|
|
1844
|
+
self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
|
|
1845
|
+
):
|
|
1846
|
+
"""
|
|
1847
|
+
Initialize the object extractor.
|
|
1848
|
+
|
|
1849
|
+
Args:
|
|
1850
|
+
model_path: Path to the .pth model file.
|
|
1851
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
1852
|
+
model: Custom model to use for inference.
|
|
1853
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
1854
|
+
"""
|
|
1855
|
+
super().__init__(
|
|
1856
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
1857
|
+
)
|
|
1858
|
+
|
|
1859
|
+
|
|
1860
|
+
class ShipDetector(ObjectDetector):
|
|
1861
|
+
"""
|
|
1862
|
+
Ship detection using a pre-trained Mask R-CNN model.
|
|
1863
|
+
|
|
1864
|
+
This class extends the
|
|
1865
|
+
`ObjectDetector` class with additional methods for ship detection."
|
|
1866
|
+
"""
|
|
1867
|
+
|
|
1868
|
+
def __init__(
|
|
1869
|
+
self, model_path="ship_detection.pth", repo_id=None, model=None, device=None
|
|
1870
|
+
):
|
|
1871
|
+
"""
|
|
1872
|
+
Initialize the object extractor.
|
|
1873
|
+
|
|
1874
|
+
Args:
|
|
1875
|
+
model_path: Path to the .pth model file.
|
|
1876
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
1877
|
+
model: Custom model to use for inference.
|
|
1878
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
1879
|
+
"""
|
|
1880
|
+
super().__init__(
|
|
1881
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
1882
|
+
)
|
geoai/preprocess.py
CHANGED
|
@@ -18,10 +18,16 @@ import matplotlib.pyplot as plt
|
|
|
18
18
|
from tqdm import tqdm
|
|
19
19
|
from torchvision.transforms import RandomRotation
|
|
20
20
|
from shapely.affinity import rotate
|
|
21
|
-
import torchgeo
|
|
22
21
|
import torch
|
|
23
22
|
import cv2
|
|
24
23
|
|
|
24
|
+
try:
|
|
25
|
+
import torchgeo
|
|
26
|
+
except ImportError as e:
|
|
27
|
+
raise ImportError(
|
|
28
|
+
"Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
|
|
29
|
+
)
|
|
30
|
+
|
|
25
31
|
|
|
26
32
|
def download_file(url, output_path=None, overwrite=False):
|
|
27
33
|
"""
|
geoai/utils.py
CHANGED
|
@@ -12,9 +12,15 @@ import xarray as xr
|
|
|
12
12
|
import rioxarray
|
|
13
13
|
import rasterio as rio
|
|
14
14
|
from torch.utils.data import DataLoader
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
from torchgeo.
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
|
|
18
|
+
from torchgeo.samplers import RandomGeoSampler, Units
|
|
19
|
+
from torchgeo.transforms import indices
|
|
20
|
+
except ImportError as e:
|
|
21
|
+
raise ImportError(
|
|
22
|
+
"Your torchgeo version is too old. Please upgrade to the latest version using 'pip install -U torchgeo'."
|
|
23
|
+
)
|
|
18
24
|
|
|
19
25
|
|
|
20
26
|
def view_raster(
|
|
@@ -1039,3 +1045,132 @@ def adaptive_regularization(
|
|
|
1039
1045
|
return gpd.GeoDataFrame(geometry=results, crs=building_polygons.crs)
|
|
1040
1046
|
else:
|
|
1041
1047
|
return results
|
|
1048
|
+
|
|
1049
|
+
|
|
1050
|
+
def install_package(package):
|
|
1051
|
+
"""Install a Python package.
|
|
1052
|
+
|
|
1053
|
+
Args:
|
|
1054
|
+
package (str | list): The package name or a GitHub URL or a list of package names or GitHub URLs.
|
|
1055
|
+
"""
|
|
1056
|
+
import subprocess
|
|
1057
|
+
|
|
1058
|
+
if isinstance(package, str):
|
|
1059
|
+
packages = [package]
|
|
1060
|
+
elif isinstance(package, list):
|
|
1061
|
+
packages = package
|
|
1062
|
+
else:
|
|
1063
|
+
raise ValueError("The package argument must be a string or a list of strings.")
|
|
1064
|
+
|
|
1065
|
+
for package in packages:
|
|
1066
|
+
if package.startswith("https"):
|
|
1067
|
+
package = f"git+{package}"
|
|
1068
|
+
|
|
1069
|
+
# Execute pip install command and show output in real-time
|
|
1070
|
+
command = f"pip install {package}"
|
|
1071
|
+
process = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
|
|
1072
|
+
|
|
1073
|
+
# Print output in real-time
|
|
1074
|
+
while True:
|
|
1075
|
+
output = process.stdout.readline()
|
|
1076
|
+
if output == b"" and process.poll() is not None:
|
|
1077
|
+
break
|
|
1078
|
+
if output:
|
|
1079
|
+
print(output.decode("utf-8").strip())
|
|
1080
|
+
|
|
1081
|
+
# Wait for process to complete
|
|
1082
|
+
process.wait()
|
|
1083
|
+
|
|
1084
|
+
|
|
1085
|
+
def create_split_map(
|
|
1086
|
+
left_layer: Optional[str] = "TERRAIN",
|
|
1087
|
+
right_layer: Optional[str] = "OpenTopoMap",
|
|
1088
|
+
left_args: Optional[dict] = None,
|
|
1089
|
+
right_args: Optional[dict] = None,
|
|
1090
|
+
left_array_args: Optional[dict] = None,
|
|
1091
|
+
right_array_args: Optional[dict] = None,
|
|
1092
|
+
zoom_control: Optional[bool] = True,
|
|
1093
|
+
fullscreen_control: Optional[bool] = True,
|
|
1094
|
+
layer_control: Optional[bool] = True,
|
|
1095
|
+
add_close_button: Optional[bool] = False,
|
|
1096
|
+
left_label: Optional[str] = None,
|
|
1097
|
+
right_label: Optional[str] = None,
|
|
1098
|
+
left_position: Optional[str] = "bottomleft",
|
|
1099
|
+
right_position: Optional[str] = "bottomright",
|
|
1100
|
+
widget_layout: Optional[dict] = None,
|
|
1101
|
+
draggable: Optional[bool] = True,
|
|
1102
|
+
center: Optional[List[float]] = [20, 0],
|
|
1103
|
+
zoom: Optional[int] = 2,
|
|
1104
|
+
height: Optional[int] = "600px",
|
|
1105
|
+
basemap: Optional[str] = None,
|
|
1106
|
+
basemap_args: Optional[dict] = None,
|
|
1107
|
+
m=None,
|
|
1108
|
+
**kwargs,
|
|
1109
|
+
) -> None:
|
|
1110
|
+
"""Adds split map.
|
|
1111
|
+
|
|
1112
|
+
Args:
|
|
1113
|
+
left_layer (str, optional): The left tile layer. Can be a local file path, HTTP URL, or a basemap name. Defaults to 'TERRAIN'.
|
|
1114
|
+
right_layer (str, optional): The right tile layer. Can be a local file path, HTTP URL, or a basemap name. Defaults to 'OpenTopoMap'.
|
|
1115
|
+
left_args (dict, optional): The arguments for the left tile layer. Defaults to {}.
|
|
1116
|
+
right_args (dict, optional): The arguments for the right tile layer. Defaults to {}.
|
|
1117
|
+
left_array_args (dict, optional): The arguments for array_to_image for the left layer. Defaults to {}.
|
|
1118
|
+
right_array_args (dict, optional): The arguments for array_to_image for the right layer. Defaults to {}.
|
|
1119
|
+
zoom_control (bool, optional): Whether to add zoom control. Defaults to True.
|
|
1120
|
+
fullscreen_control (bool, optional): Whether to add fullscreen control. Defaults to True.
|
|
1121
|
+
layer_control (bool, optional): Whether to add layer control. Defaults to True.
|
|
1122
|
+
add_close_button (bool, optional): Whether to add a close button. Defaults to False.
|
|
1123
|
+
left_label (str, optional): The label for the left layer. Defaults to None.
|
|
1124
|
+
right_label (str, optional): The label for the right layer. Defaults to None.
|
|
1125
|
+
left_position (str, optional): The position for the left label. Defaults to "bottomleft".
|
|
1126
|
+
right_position (str, optional): The position for the right label. Defaults to "bottomright".
|
|
1127
|
+
widget_layout (dict, optional): The layout for the widget. Defaults to None.
|
|
1128
|
+
draggable (bool, optional): Whether the split map is draggable. Defaults to True.
|
|
1129
|
+
"""
|
|
1130
|
+
|
|
1131
|
+
if left_args is None:
|
|
1132
|
+
left_args = {}
|
|
1133
|
+
|
|
1134
|
+
if right_args is None:
|
|
1135
|
+
right_args = {}
|
|
1136
|
+
|
|
1137
|
+
if left_array_args is None:
|
|
1138
|
+
left_array_args = {}
|
|
1139
|
+
|
|
1140
|
+
if right_array_args is None:
|
|
1141
|
+
right_array_args = {}
|
|
1142
|
+
|
|
1143
|
+
if basemap_args is None:
|
|
1144
|
+
basemap_args = {}
|
|
1145
|
+
|
|
1146
|
+
if m is None:
|
|
1147
|
+
m = leafmap.Map(center=center, zoom=zoom, height=height, **kwargs)
|
|
1148
|
+
m.clear_layers()
|
|
1149
|
+
if isinstance(basemap, str):
|
|
1150
|
+
if basemap.endswith(".tif"):
|
|
1151
|
+
if basemap.startswith("http"):
|
|
1152
|
+
m.add_cog_layer(basemap, name="Basemap", **basemap_args)
|
|
1153
|
+
else:
|
|
1154
|
+
m.add_raster(basemap, name="Basemap", **basemap_args)
|
|
1155
|
+
else:
|
|
1156
|
+
m.add_basemap(basemap)
|
|
1157
|
+
m.split_map(
|
|
1158
|
+
left_layer=left_layer,
|
|
1159
|
+
right_layer=right_layer,
|
|
1160
|
+
left_args=left_args,
|
|
1161
|
+
right_args=right_args,
|
|
1162
|
+
left_array_args=left_array_args,
|
|
1163
|
+
right_array_args=right_array_args,
|
|
1164
|
+
zoom_control=zoom_control,
|
|
1165
|
+
fullscreen_control=fullscreen_control,
|
|
1166
|
+
layer_control=layer_control,
|
|
1167
|
+
add_close_button=add_close_button,
|
|
1168
|
+
left_label=left_label,
|
|
1169
|
+
right_label=right_label,
|
|
1170
|
+
left_position=left_position,
|
|
1171
|
+
right_position=right_position,
|
|
1172
|
+
widget_layout=widget_layout,
|
|
1173
|
+
draggable=draggable,
|
|
1174
|
+
)
|
|
1175
|
+
|
|
1176
|
+
return m
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: geoai-py
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
|
|
5
5
|
Author-email: Qiusheng Wu <giswqs@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -31,7 +31,6 @@ Requires-Dist: pystac-client
|
|
|
31
31
|
Requires-Dist: rasterio
|
|
32
32
|
Requires-Dist: rioxarray
|
|
33
33
|
Requires-Dist: scikit-learn
|
|
34
|
-
Requires-Dist: segment-geospatial
|
|
35
34
|
Requires-Dist: torch
|
|
36
35
|
Requires-Dist: torchgeo
|
|
37
36
|
Requires-Dist: tqdm
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
geoai/__init__.py,sha256=XvH0-HD7pv6gUBC8VJzToLqq7uCcWp2Orhp9sTR7Dd4,923
|
|
2
|
+
geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
|
|
3
|
+
geoai/extract.py,sha256=DQEvZXXmSxwoAUhJsXActM1IbFOg1ljg5fArqM5JdSk,75240
|
|
4
|
+
geoai/geoai.py,sha256=wNwKIqwOT10tU4uiWTcNp5Gd598rRFMANIfJsGdOWKM,90
|
|
5
|
+
geoai/preprocess.py,sha256=zQynNxQ_nxDkCEQU-h4G1SrgqxV1c5EREMV3JeS0cC0,118701
|
|
6
|
+
geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
|
|
7
|
+
geoai/utils.py,sha256=DCHfL2G5jH6RWDJUyJYORMa5uTaQp56Kj8sQXlG-0ck,42678
|
|
8
|
+
geoai_py-0.3.2.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
|
|
9
|
+
geoai_py-0.3.2.dist-info/METADATA,sha256=-j8ngaO_hQnWueYEoZsLbbIDN2cSj-uWpbBmD7m-2mU,5720
|
|
10
|
+
geoai_py-0.3.2.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
|
|
11
|
+
geoai_py-0.3.2.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
|
|
12
|
+
geoai_py-0.3.2.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
|
|
13
|
+
geoai_py-0.3.2.dist-info/RECORD,,
|
geoai_py-0.3.0.dist-info/RECORD
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
geoai/__init__.py,sha256=rJod2PDa1AiRHE8ugVp0Bfiky7ZWBhqbh2kZ45WiggA,923
|
|
2
|
-
geoai/download.py,sha256=4GiDmLrp2wKslgfm507WeZrwOdYcMekgQXxWGbl5cBw,13094
|
|
3
|
-
geoai/extract.py,sha256=9oLbrSg_aHcimpYxfk0jLOIHQWVULRsdiAGUsPLC-qk,71708
|
|
4
|
-
geoai/geoai.py,sha256=wNwKIqwOT10tU4uiWTcNp5Gd598rRFMANIfJsGdOWKM,90
|
|
5
|
-
geoai/preprocess.py,sha256=teV-W7ykXnoru0Y_d0V9ANdO6jMyETeGbqr1_8H-Yh0,118523
|
|
6
|
-
geoai/segmentation.py,sha256=Vcymnhwl_xikt4v9x8CYJq_vId9R1gB7-YzLfwg-F9M,11372
|
|
7
|
-
geoai/utils.py,sha256=3vXFDdFqZeg4kgeNt6-Hp28RfNoQcDOH7BjrlJ6L0UE,37521
|
|
8
|
-
geoai_py-0.3.0.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
|
|
9
|
-
geoai_py-0.3.0.dist-info/METADATA,sha256=L62RHKj0Yqno8LDYVrL50YyMfO1ybRYs2NI15WHiJMQ,5754
|
|
10
|
-
geoai_py-0.3.0.dist-info/WHEEL,sha256=rF4EZyR2XVS6irmOHQIJx2SUqXLZKRMUrjsg8UwN-XQ,109
|
|
11
|
-
geoai_py-0.3.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
|
|
12
|
-
geoai_py-0.3.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
|
|
13
|
-
geoai_py-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|