geoai-py 0.8.3__py2.py3-none-any.whl → 0.9.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/change_detection.py +1568 -0
- geoai/classify.py +58 -57
- geoai/detectron2.py +466 -0
- geoai/download.py +74 -68
- geoai/extract.py +186 -141
- geoai/geoai.py +13 -11
- geoai/hf.py +14 -12
- geoai/segment.py +44 -39
- geoai/segmentation.py +10 -9
- geoai/train.py +372 -241
- geoai/utils.py +198 -123
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/METADATA +5 -1
- geoai_py-0.9.1.dist-info/RECORD +19 -0
- geoai_py-0.8.3.dist-info/RECORD +0 -17
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/WHEEL +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.8.3.dist-info → geoai_py-0.9.1.dist-info}/top_level.txt +0 -0
geoai/extract.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
# Standard Library
|
4
4
|
import os
|
5
5
|
import time
|
6
|
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
6
7
|
|
7
8
|
# Third-Party Libraries
|
8
9
|
import cv2
|
@@ -64,13 +65,13 @@ class CustomDataset(NonGeoDataset):
|
|
64
65
|
|
65
66
|
def __init__(
|
66
67
|
self,
|
67
|
-
raster_path,
|
68
|
-
chip_size=(512, 512),
|
69
|
-
overlap=0.5,
|
70
|
-
transforms=None,
|
71
|
-
band_indexes=None,
|
72
|
-
verbose=False,
|
73
|
-
):
|
68
|
+
raster_path: str,
|
69
|
+
chip_size: Tuple[int, int] = (512, 512),
|
70
|
+
overlap: float = 0.5,
|
71
|
+
transforms: Optional[Any] = None,
|
72
|
+
band_indexes: Optional[List[int]] = None,
|
73
|
+
verbose: bool = False,
|
74
|
+
) -> None:
|
74
75
|
"""
|
75
76
|
Initialize the dataset with overlapping tiles.
|
76
77
|
|
@@ -163,7 +164,7 @@ class CustomDataset(NonGeoDataset):
|
|
163
164
|
# Get raster stats
|
164
165
|
self.raster_stats = get_raster_stats(raster_path, divide_by=255)
|
165
166
|
|
166
|
-
def __getitem__(self, idx):
|
167
|
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
167
168
|
"""
|
168
169
|
Get an image chip from the dataset by index.
|
169
170
|
|
@@ -255,7 +256,7 @@ class CustomDataset(NonGeoDataset):
|
|
255
256
|
), # Consistent format
|
256
257
|
}
|
257
258
|
|
258
|
-
def __len__(self):
|
259
|
+
def __len__(self) -> int:
|
259
260
|
"""
|
260
261
|
Return the number of samples in the dataset.
|
261
262
|
|
@@ -271,8 +272,13 @@ class ObjectDetector:
|
|
271
272
|
"""
|
272
273
|
|
273
274
|
def __init__(
|
274
|
-
self,
|
275
|
-
|
275
|
+
self,
|
276
|
+
model_path: Optional[str] = None,
|
277
|
+
repo_id: Optional[str] = None,
|
278
|
+
model: Optional[Any] = None,
|
279
|
+
num_classes: int = 2,
|
280
|
+
device: Optional[str] = None,
|
281
|
+
) -> None:
|
276
282
|
"""
|
277
283
|
Initialize the object extractor.
|
278
284
|
|
@@ -312,7 +318,9 @@ class ObjectDetector:
|
|
312
318
|
# Set model to evaluation mode
|
313
319
|
self.model.eval()
|
314
320
|
|
315
|
-
def download_model_from_hf(
|
321
|
+
def download_model_from_hf(
|
322
|
+
self, model_path: Optional[str] = None, repo_id: Optional[str] = None
|
323
|
+
) -> str:
|
316
324
|
"""
|
317
325
|
Download the object detection model from Hugging Face.
|
318
326
|
|
@@ -345,7 +353,7 @@ class ObjectDetector:
|
|
345
353
|
print("Please specify a local model path or ensure internet connectivity.")
|
346
354
|
raise
|
347
355
|
|
348
|
-
def initialize_model(self, model, num_classes=2):
|
356
|
+
def initialize_model(self, model: Optional[Any], num_classes: int = 2) -> Any:
|
349
357
|
"""Initialize a deep learning model for object detection.
|
350
358
|
|
351
359
|
Args:
|
@@ -375,7 +383,7 @@ class ObjectDetector:
|
|
375
383
|
model.to(self.device)
|
376
384
|
return model
|
377
385
|
|
378
|
-
def load_weights(self, model_path):
|
386
|
+
def load_weights(self, model_path: str) -> None:
|
379
387
|
"""
|
380
388
|
Load weights from file with error handling for different formats.
|
381
389
|
|
@@ -417,7 +425,7 @@ class ObjectDetector:
|
|
417
425
|
except Exception as e:
|
418
426
|
raise RuntimeError(f"Failed to load model: {e}")
|
419
427
|
|
420
|
-
def mask_to_polygons(self, mask, **kwargs):
|
428
|
+
def mask_to_polygons(self, mask: np.ndarray, **kwargs: Any) -> List[Polygon]:
|
421
429
|
"""
|
422
430
|
Convert binary mask to polygon contours using OpenCV.
|
423
431
|
|
@@ -474,7 +482,9 @@ class ObjectDetector:
|
|
474
482
|
|
475
483
|
return polygons
|
476
484
|
|
477
|
-
def filter_overlapping_polygons(
|
485
|
+
def filter_overlapping_polygons(
|
486
|
+
self, gdf: gpd.GeoDataFrame, **kwargs: Any
|
487
|
+
) -> gpd.GeoDataFrame:
|
478
488
|
"""
|
479
489
|
Filter overlapping polygons using non-maximum suppression.
|
480
490
|
|
@@ -531,7 +541,9 @@ class ObjectDetector:
|
|
531
541
|
|
532
542
|
return gdf.iloc[keep_indices]
|
533
543
|
|
534
|
-
def filter_edge_objects(
|
544
|
+
def filter_edge_objects(
|
545
|
+
self, gdf: gpd.GeoDataFrame, raster_path: str, edge_buffer: int = 10
|
546
|
+
) -> gpd.GeoDataFrame:
|
535
547
|
"""
|
536
548
|
Filter out object detections that fall in padding/edge areas of the image.
|
537
549
|
|
@@ -596,17 +608,17 @@ class ObjectDetector:
|
|
596
608
|
|
597
609
|
def masks_to_vector(
|
598
610
|
self,
|
599
|
-
mask_path,
|
600
|
-
output_path=None,
|
601
|
-
simplify_tolerance=None,
|
602
|
-
mask_threshold=None,
|
603
|
-
min_object_area=None,
|
604
|
-
max_object_area=None,
|
605
|
-
nms_iou_threshold=None,
|
606
|
-
regularize=True,
|
607
|
-
angle_threshold=15,
|
608
|
-
rectangularity_threshold=0.7,
|
609
|
-
):
|
611
|
+
mask_path: str,
|
612
|
+
output_path: Optional[str] = None,
|
613
|
+
simplify_tolerance: Optional[float] = None,
|
614
|
+
mask_threshold: Optional[float] = None,
|
615
|
+
min_object_area: Optional[int] = None,
|
616
|
+
max_object_area: Optional[int] = None,
|
617
|
+
nms_iou_threshold: Optional[float] = None,
|
618
|
+
regularize: bool = True,
|
619
|
+
angle_threshold: int = 15,
|
620
|
+
rectangularity_threshold: float = 0.7,
|
621
|
+
) -> gpd.GeoDataFrame:
|
610
622
|
"""
|
611
623
|
Convert an object mask GeoTIFF to vector polygons and save as GeoJSON.
|
612
624
|
|
@@ -808,14 +820,14 @@ class ObjectDetector:
|
|
808
820
|
@torch.no_grad()
|
809
821
|
def process_raster(
|
810
822
|
self,
|
811
|
-
raster_path,
|
812
|
-
output_path=None,
|
813
|
-
batch_size=4,
|
814
|
-
filter_edges=True,
|
815
|
-
edge_buffer=20,
|
816
|
-
band_indexes=None,
|
817
|
-
**kwargs,
|
818
|
-
):
|
823
|
+
raster_path: str,
|
824
|
+
output_path: Optional[str] = None,
|
825
|
+
batch_size: int = 4,
|
826
|
+
filter_edges: bool = True,
|
827
|
+
edge_buffer: int = 20,
|
828
|
+
band_indexes: Optional[List[int]] = None,
|
829
|
+
**kwargs: Any,
|
830
|
+
) -> "gpd.GeoDataFrame":
|
819
831
|
"""
|
820
832
|
Process a raster file to extract objects with customizable parameters.
|
821
833
|
|
@@ -874,7 +886,7 @@ class ObjectDetector:
|
|
874
886
|
self.raster_stats = dataset.raster_stats
|
875
887
|
|
876
888
|
# Custom collate function to handle Shapely objects
|
877
|
-
def custom_collate(batch):
|
889
|
+
def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
878
890
|
"""
|
879
891
|
Custom collate function that handles Shapely geometries
|
880
892
|
by keeping them as Python objects rather than trying to collate them.
|
@@ -1056,8 +1068,13 @@ class ObjectDetector:
|
|
1056
1068
|
return gdf
|
1057
1069
|
|
1058
1070
|
def save_masks_as_geotiff(
|
1059
|
-
self,
|
1060
|
-
|
1071
|
+
self,
|
1072
|
+
raster_path: str,
|
1073
|
+
output_path: Optional[str] = None,
|
1074
|
+
batch_size: int = 4,
|
1075
|
+
verbose: bool = False,
|
1076
|
+
**kwargs: Any,
|
1077
|
+
) -> str:
|
1061
1078
|
"""
|
1062
1079
|
Process a raster file to extract object masks and save as GeoTIFF.
|
1063
1080
|
|
@@ -1125,7 +1142,7 @@ class ObjectDetector:
|
|
1125
1142
|
mask_array = np.zeros((src.height, src.width), dtype=np.uint8)
|
1126
1143
|
|
1127
1144
|
# Custom collate function to handle Shapely objects
|
1128
|
-
def custom_collate(batch):
|
1145
|
+
def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
1129
1146
|
"""Custom collate function for DataLoader"""
|
1130
1147
|
elem = batch[0]
|
1131
1148
|
if isinstance(elem, dict):
|
@@ -1291,12 +1308,12 @@ class ObjectDetector:
|
|
1291
1308
|
|
1292
1309
|
def regularize_objects(
|
1293
1310
|
self,
|
1294
|
-
gdf,
|
1295
|
-
min_area=10,
|
1296
|
-
angle_threshold=15,
|
1297
|
-
orthogonality_threshold=0.3,
|
1298
|
-
rectangularity_threshold=0.7,
|
1299
|
-
):
|
1311
|
+
gdf: gpd.GeoDataFrame,
|
1312
|
+
min_area: int = 10,
|
1313
|
+
angle_threshold: int = 15,
|
1314
|
+
orthogonality_threshold: float = 0.3,
|
1315
|
+
rectangularity_threshold: float = 0.7,
|
1316
|
+
) -> gpd.GeoDataFrame:
|
1300
1317
|
"""
|
1301
1318
|
Regularize objects to enforce right angles and rectangular shapes.
|
1302
1319
|
|
@@ -1319,7 +1336,9 @@ class ObjectDetector:
|
|
1319
1336
|
from shapely.geometry import MultiPolygon, Polygon, box
|
1320
1337
|
from tqdm import tqdm
|
1321
1338
|
|
1322
|
-
def get_angle(
|
1339
|
+
def get_angle(
|
1340
|
+
p1: Tuple[float, float], p2: Tuple[float, float], p3: Tuple[float, float]
|
1341
|
+
) -> float:
|
1323
1342
|
"""Calculate angle between three points in degrees (0-180)"""
|
1324
1343
|
a = np.array(p1)
|
1325
1344
|
b = np.array(p2)
|
@@ -1335,11 +1354,11 @@ class ObjectDetector:
|
|
1335
1354
|
|
1336
1355
|
return angle
|
1337
1356
|
|
1338
|
-
def is_orthogonal(angle, threshold=angle_threshold):
|
1357
|
+
def is_orthogonal(angle: float, threshold: int = angle_threshold) -> bool:
|
1339
1358
|
"""Check if angle is close to 90 degrees"""
|
1340
1359
|
return abs(angle - 90) <= threshold
|
1341
1360
|
|
1342
|
-
def calculate_dominant_direction(polygon):
|
1361
|
+
def calculate_dominant_direction(polygon: Polygon) -> float:
|
1343
1362
|
"""Find the dominant direction of a polygon using PCA"""
|
1344
1363
|
# Extract coordinates
|
1345
1364
|
coords = np.array(polygon.exterior.coords)
|
@@ -1368,7 +1387,7 @@ class ObjectDetector:
|
|
1368
1387
|
|
1369
1388
|
return angle_deg
|
1370
1389
|
|
1371
|
-
def create_oriented_envelope(polygon, angle_deg):
|
1390
|
+
def create_oriented_envelope(polygon: Polygon, angle_deg: float) -> Polygon:
|
1372
1391
|
"""Create an oriented minimum area rectangle for the polygon"""
|
1373
1392
|
# Create a rotated rectangle using OpenCV method (more robust than Shapely methods)
|
1374
1393
|
coords = np.array(polygon.exterior.coords)[:-1].astype(
|
@@ -1384,13 +1403,13 @@ class ObjectDetector:
|
|
1384
1403
|
|
1385
1404
|
return oriented_box
|
1386
1405
|
|
1387
|
-
def get_rectangularity(polygon, oriented_box):
|
1406
|
+
def get_rectangularity(polygon: Polygon, oriented_box: Polygon) -> float:
|
1388
1407
|
"""Calculate the rectangularity (area ratio to its oriented bounding box)"""
|
1389
1408
|
if oriented_box.area == 0:
|
1390
1409
|
return 0
|
1391
1410
|
return polygon.area / oriented_box.area
|
1392
1411
|
|
1393
|
-
def check_orthogonality(polygon):
|
1412
|
+
def check_orthogonality(polygon: Polygon) -> float:
|
1394
1413
|
"""Check what percentage of angles in the polygon are orthogonal"""
|
1395
1414
|
coords = list(polygon.exterior.coords)
|
1396
1415
|
if len(coords) <= 4: # Triangle or point
|
@@ -1413,7 +1432,7 @@ class ObjectDetector:
|
|
1413
1432
|
|
1414
1433
|
return orthogonal_count / total_angles
|
1415
1434
|
|
1416
|
-
def simplify_to_rectangle(polygon):
|
1435
|
+
def simplify_to_rectangle(polygon: Polygon) -> Polygon:
|
1417
1436
|
"""Simplify a polygon to a rectangle using its oriented bounding box"""
|
1418
1437
|
# Get dominant direction
|
1419
1438
|
angle = calculate_dominant_direction(polygon)
|
@@ -1506,8 +1525,12 @@ class ObjectDetector:
|
|
1506
1525
|
return result_gdf
|
1507
1526
|
|
1508
1527
|
def visualize_results(
|
1509
|
-
self,
|
1510
|
-
|
1528
|
+
self,
|
1529
|
+
raster_path: str,
|
1530
|
+
gdf: Optional[gpd.GeoDataFrame] = None,
|
1531
|
+
output_path: Optional[str] = None,
|
1532
|
+
figsize: Tuple[int, int] = (12, 12),
|
1533
|
+
) -> bool:
|
1511
1534
|
"""
|
1512
1535
|
Visualize object detection results with proper coordinate transformation.
|
1513
1536
|
|
@@ -1653,7 +1676,9 @@ class ObjectDetector:
|
|
1653
1676
|
has_confidence = False
|
1654
1677
|
|
1655
1678
|
# Function to convert coordinates
|
1656
|
-
def geo_to_pixel(
|
1679
|
+
def geo_to_pixel(
|
1680
|
+
geometry: Any, transform: Any
|
1681
|
+
) -> Optional[Tuple[List[float], List[float]]]:
|
1657
1682
|
"""Convert geometry to pixel coordinates using the provided transform."""
|
1658
1683
|
if geometry.is_empty:
|
1659
1684
|
return None
|
@@ -1913,18 +1938,18 @@ class ObjectDetector:
|
|
1913
1938
|
|
1914
1939
|
def generate_masks(
|
1915
1940
|
self,
|
1916
|
-
raster_path,
|
1917
|
-
output_path=None,
|
1918
|
-
confidence_threshold=None,
|
1919
|
-
mask_threshold=None,
|
1920
|
-
min_object_area=10,
|
1921
|
-
max_object_area=float("inf"),
|
1922
|
-
overlap=0.25,
|
1923
|
-
batch_size=4,
|
1924
|
-
band_indexes=None,
|
1925
|
-
verbose=False,
|
1926
|
-
**kwargs,
|
1927
|
-
):
|
1941
|
+
raster_path: str,
|
1942
|
+
output_path: Optional[str] = None,
|
1943
|
+
confidence_threshold: Optional[float] = None,
|
1944
|
+
mask_threshold: Optional[float] = None,
|
1945
|
+
min_object_area: int = 10,
|
1946
|
+
max_object_area: float = float("inf"),
|
1947
|
+
overlap: float = 0.25,
|
1948
|
+
batch_size: int = 4,
|
1949
|
+
band_indexes: Optional[List[int]] = None,
|
1950
|
+
verbose: bool = False,
|
1951
|
+
**kwargs: Any,
|
1952
|
+
) -> str:
|
1928
1953
|
"""
|
1929
1954
|
Save masks with confidence values as a multi-band GeoTIFF.
|
1930
1955
|
|
@@ -1983,7 +2008,7 @@ class ObjectDetector:
|
|
1983
2008
|
conf_array = np.zeros((src.height, src.width), dtype=np.uint8)
|
1984
2009
|
|
1985
2010
|
# Define custom collate function to handle Shapely objects
|
1986
|
-
def custom_collate(batch):
|
2011
|
+
def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
1987
2012
|
"""
|
1988
2013
|
Custom collate function that handles Shapely geometries
|
1989
2014
|
by keeping them as Python objects rather than trying to collate them.
|
@@ -2113,14 +2138,14 @@ class ObjectDetector:
|
|
2113
2138
|
|
2114
2139
|
def vectorize_masks(
|
2115
2140
|
self,
|
2116
|
-
masks_path,
|
2117
|
-
output_path=None,
|
2118
|
-
confidence_threshold=0.5,
|
2119
|
-
min_object_area=100,
|
2120
|
-
max_object_area=None,
|
2121
|
-
n_workers=None,
|
2122
|
-
**kwargs,
|
2123
|
-
):
|
2141
|
+
masks_path: str,
|
2142
|
+
output_path: Optional[str] = None,
|
2143
|
+
confidence_threshold: float = 0.5,
|
2144
|
+
min_object_area: int = 100,
|
2145
|
+
max_object_area: Optional[int] = None,
|
2146
|
+
n_workers: Optional[int] = None,
|
2147
|
+
**kwargs: Any,
|
2148
|
+
) -> gpd.GeoDataFrame:
|
2124
2149
|
"""
|
2125
2150
|
Convert masks with confidence to vector polygons.
|
2126
2151
|
|
@@ -2142,13 +2167,13 @@ class ObjectDetector:
|
|
2142
2167
|
"""
|
2143
2168
|
|
2144
2169
|
def _process_single_component(
|
2145
|
-
component_mask,
|
2146
|
-
conf_data,
|
2147
|
-
transform,
|
2148
|
-
confidence_threshold,
|
2149
|
-
min_object_area,
|
2150
|
-
max_object_area,
|
2151
|
-
):
|
2170
|
+
component_mask: np.ndarray,
|
2171
|
+
conf_data: np.ndarray,
|
2172
|
+
transform: Any,
|
2173
|
+
confidence_threshold: float,
|
2174
|
+
min_object_area: int,
|
2175
|
+
max_object_area: Optional[int],
|
2176
|
+
) -> Optional[Dict[str, Any]]:
|
2152
2177
|
# Get confidence value
|
2153
2178
|
conf_region = conf_data[component_mask > 0]
|
2154
2179
|
if len(conf_region) > 0:
|
@@ -2195,7 +2220,9 @@ class ObjectDetector:
|
|
2195
2220
|
import concurrent.futures
|
2196
2221
|
from functools import partial
|
2197
2222
|
|
2198
|
-
def process_component(
|
2223
|
+
def process_component(
|
2224
|
+
args: Tuple[int, np.ndarray, np.ndarray, Any, float, int, Optional[int]],
|
2225
|
+
) -> Optional[Dict[str, Any]]:
|
2199
2226
|
"""
|
2200
2227
|
Helper function to process a single component
|
2201
2228
|
"""
|
@@ -2346,11 +2373,11 @@ class BuildingFootprintExtractor(ObjectDetector):
|
|
2346
2373
|
|
2347
2374
|
def __init__(
|
2348
2375
|
self,
|
2349
|
-
model_path="building_footprints_usa.pth",
|
2350
|
-
repo_id=None,
|
2351
|
-
model=None,
|
2352
|
-
device=None,
|
2353
|
-
):
|
2376
|
+
model_path: str = "building_footprints_usa.pth",
|
2377
|
+
repo_id: Optional[str] = None,
|
2378
|
+
model: Optional[Any] = None,
|
2379
|
+
device: Optional[str] = None,
|
2380
|
+
) -> None:
|
2354
2381
|
"""
|
2355
2382
|
Initialize the object extractor.
|
2356
2383
|
|
@@ -2366,12 +2393,12 @@ class BuildingFootprintExtractor(ObjectDetector):
|
|
2366
2393
|
|
2367
2394
|
def regularize_buildings(
|
2368
2395
|
self,
|
2369
|
-
gdf,
|
2370
|
-
min_area=10,
|
2371
|
-
angle_threshold=15,
|
2372
|
-
orthogonality_threshold=0.3,
|
2373
|
-
rectangularity_threshold=0.7,
|
2374
|
-
):
|
2396
|
+
gdf: gpd.GeoDataFrame,
|
2397
|
+
min_area: int = 10,
|
2398
|
+
angle_threshold: int = 15,
|
2399
|
+
orthogonality_threshold: float = 0.3,
|
2400
|
+
rectangularity_threshold: float = 0.7,
|
2401
|
+
) -> gpd.GeoDataFrame:
|
2375
2402
|
"""
|
2376
2403
|
Regularize building footprints to enforce right angles and rectangular shapes.
|
2377
2404
|
|
@@ -2402,8 +2429,12 @@ class CarDetector(ObjectDetector):
|
|
2402
2429
|
"""
|
2403
2430
|
|
2404
2431
|
def __init__(
|
2405
|
-
self,
|
2406
|
-
|
2432
|
+
self,
|
2433
|
+
model_path: str = "car_detection_usa.pth",
|
2434
|
+
repo_id: Optional[str] = None,
|
2435
|
+
model: Optional[Any] = None,
|
2436
|
+
device: Optional[str] = None,
|
2437
|
+
) -> None:
|
2407
2438
|
"""
|
2408
2439
|
Initialize the object extractor.
|
2409
2440
|
|
@@ -2427,8 +2458,12 @@ class ShipDetector(ObjectDetector):
|
|
2427
2458
|
"""
|
2428
2459
|
|
2429
2460
|
def __init__(
|
2430
|
-
self,
|
2431
|
-
|
2461
|
+
self,
|
2462
|
+
model_path: str = "ship_detection.pth",
|
2463
|
+
repo_id: Optional[str] = None,
|
2464
|
+
model: Optional[Any] = None,
|
2465
|
+
device: Optional[str] = None,
|
2466
|
+
) -> None:
|
2432
2467
|
"""
|
2433
2468
|
Initialize the object extractor.
|
2434
2469
|
|
@@ -2453,11 +2488,11 @@ class SolarPanelDetector(ObjectDetector):
|
|
2453
2488
|
|
2454
2489
|
def __init__(
|
2455
2490
|
self,
|
2456
|
-
model_path="solar_panel_detection.pth",
|
2457
|
-
repo_id=None,
|
2458
|
-
model=None,
|
2459
|
-
device=None,
|
2460
|
-
):
|
2491
|
+
model_path: str = "solar_panel_detection.pth",
|
2492
|
+
repo_id: Optional[str] = None,
|
2493
|
+
model: Optional[Any] = None,
|
2494
|
+
device: Optional[str] = None,
|
2495
|
+
) -> None:
|
2461
2496
|
"""
|
2462
2497
|
Initialize the object extractor.
|
2463
2498
|
|
@@ -2481,12 +2516,12 @@ class ParkingSplotDetector(ObjectDetector):
|
|
2481
2516
|
|
2482
2517
|
def __init__(
|
2483
2518
|
self,
|
2484
|
-
model_path="parking_spot_detection.pth",
|
2485
|
-
repo_id=None,
|
2486
|
-
model=None,
|
2487
|
-
num_classes=3,
|
2488
|
-
device=None,
|
2489
|
-
):
|
2519
|
+
model_path: str = "parking_spot_detection.pth",
|
2520
|
+
repo_id: Optional[str] = None,
|
2521
|
+
model: Optional[Any] = None,
|
2522
|
+
num_classes: int = 3,
|
2523
|
+
device: Optional[str] = None,
|
2524
|
+
) -> None:
|
2490
2525
|
"""
|
2491
2526
|
Initialize the object extractor.
|
2492
2527
|
|
@@ -2521,13 +2556,13 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2521
2556
|
|
2522
2557
|
def __init__(
|
2523
2558
|
self,
|
2524
|
-
model_path="field_boundary_detector.pth",
|
2525
|
-
repo_id=None,
|
2526
|
-
model=None,
|
2527
|
-
device=None,
|
2528
|
-
band_selection=None,
|
2529
|
-
use_ndvi=False,
|
2530
|
-
):
|
2559
|
+
model_path: str = "field_boundary_detector.pth",
|
2560
|
+
repo_id: Optional[str] = None,
|
2561
|
+
model: Optional[Any] = None,
|
2562
|
+
device: Optional[str] = None,
|
2563
|
+
band_selection: Optional[List[int]] = None,
|
2564
|
+
use_ndvi: bool = False,
|
2565
|
+
) -> None:
|
2531
2566
|
"""
|
2532
2567
|
Initialize the field boundary delineator.
|
2533
2568
|
|
@@ -2603,7 +2638,7 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2603
2638
|
self.min_object_area = 1000 # Minimum area in pixels for field detection
|
2604
2639
|
self.simplify_tolerance = 2.0 # Higher tolerance for field boundaries
|
2605
2640
|
|
2606
|
-
def initialize_sentinel2_model(self, model=None):
|
2641
|
+
def initialize_sentinel2_model(self, model: Optional[Any] = None) -> Any:
|
2607
2642
|
"""
|
2608
2643
|
Initialize a Mask R-CNN model with a modified first layer to accept Sentinel-2 data.
|
2609
2644
|
|
@@ -2656,7 +2691,12 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2656
2691
|
model.to(self.device)
|
2657
2692
|
return model
|
2658
2693
|
|
2659
|
-
def preprocess_sentinel_bands(
|
2694
|
+
def preprocess_sentinel_bands(
|
2695
|
+
self,
|
2696
|
+
image_data: np.ndarray,
|
2697
|
+
band_selection: Optional[List[int]] = None,
|
2698
|
+
use_ndvi: Optional[bool] = None,
|
2699
|
+
) -> torch.Tensor:
|
2660
2700
|
"""
|
2661
2701
|
Preprocess Sentinel-2 band data for model input.
|
2662
2702
|
|
@@ -2718,7 +2758,12 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2718
2758
|
|
2719
2759
|
return image_tensor
|
2720
2760
|
|
2721
|
-
def update_band_stats(
|
2761
|
+
def update_band_stats(
|
2762
|
+
self,
|
2763
|
+
raster_path: str,
|
2764
|
+
band_selection: Optional[List[int]] = None,
|
2765
|
+
sample_size: int = 1000,
|
2766
|
+
) -> Dict[str, List[float]]:
|
2722
2767
|
"""
|
2723
2768
|
Update band statistics from the input Sentinel-2 raster.
|
2724
2769
|
|
@@ -2782,15 +2827,15 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2782
2827
|
|
2783
2828
|
def process_sentinel_raster(
|
2784
2829
|
self,
|
2785
|
-
raster_path,
|
2786
|
-
output_path=None,
|
2787
|
-
batch_size=4,
|
2788
|
-
band_selection=None,
|
2789
|
-
use_ndvi=None,
|
2790
|
-
filter_edges=True,
|
2791
|
-
edge_buffer=20,
|
2792
|
-
**kwargs,
|
2793
|
-
):
|
2830
|
+
raster_path: str,
|
2831
|
+
output_path: Optional[str] = None,
|
2832
|
+
batch_size: int = 4,
|
2833
|
+
band_selection: Optional[List[int]] = None,
|
2834
|
+
use_ndvi: Optional[bool] = None,
|
2835
|
+
filter_edges: bool = True,
|
2836
|
+
edge_buffer: int = 20,
|
2837
|
+
**kwargs: Any,
|
2838
|
+
) -> gpd.GeoDataFrame:
|
2794
2839
|
"""
|
2795
2840
|
Process a Sentinel-2 raster to extract field boundaries.
|
2796
2841
|
|
@@ -2840,14 +2885,14 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2840
2885
|
class Sentinel2Dataset(torch.utils.data.Dataset):
|
2841
2886
|
def __init__(
|
2842
2887
|
self,
|
2843
|
-
raster_path,
|
2844
|
-
chip_size,
|
2845
|
-
stride_x,
|
2846
|
-
stride_y,
|
2847
|
-
band_selection,
|
2848
|
-
use_ndvi,
|
2849
|
-
field_delineator,
|
2850
|
-
):
|
2888
|
+
raster_path: str,
|
2889
|
+
chip_size: Tuple[int, int],
|
2890
|
+
stride_x: int,
|
2891
|
+
stride_y: int,
|
2892
|
+
band_selection: List[int],
|
2893
|
+
use_ndvi: bool,
|
2894
|
+
field_delineator: Any,
|
2895
|
+
) -> None:
|
2851
2896
|
self.raster_path = raster_path
|
2852
2897
|
self.chip_size = chip_size
|
2853
2898
|
self.stride_x = stride_x
|
@@ -2901,10 +2946,10 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2901
2946
|
print(f"Image dimensions: {self.width} x {self.height} pixels")
|
2902
2947
|
print(f"Chip size: {self.chip_size[1]} x {self.chip_size[0]} pixels")
|
2903
2948
|
|
2904
|
-
def __len__(self):
|
2949
|
+
def __len__(self) -> int:
|
2905
2950
|
return self.rows * self.cols
|
2906
2951
|
|
2907
|
-
def __getitem__(self, idx):
|
2952
|
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
2908
2953
|
# Convert flat index to grid position
|
2909
2954
|
row = idx // self.cols
|
2910
2955
|
col = idx % self.cols
|
@@ -2971,7 +3016,7 @@ class AgricultureFieldDelineator(ObjectDetector):
|
|
2971
3016
|
)
|
2972
3017
|
|
2973
3018
|
# Define custom collate function
|
2974
|
-
def custom_collate(batch):
|
3019
|
+
def custom_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
2975
3020
|
elem = batch[0]
|
2976
3021
|
if isinstance(elem, dict):
|
2977
3022
|
result = {}
|