geoai-py 0.3.4__py2.py3-none-any.whl → 0.3.6__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/extract.py +205 -104
- geoai/geoai.py +1 -0
- geoai/segment.py +305 -0
- geoai/utils.py +861 -18
- {geoai_py-0.3.4.dist-info → geoai_py-0.3.6.dist-info}/METADATA +5 -1
- geoai_py-0.3.6.dist-info/RECORD +13 -0
- {geoai_py-0.3.4.dist-info → geoai_py-0.3.6.dist-info}/WHEEL +1 -1
- geoai/preprocess.py +0 -3021
- geoai_py-0.3.4.dist-info/RECORD +0 -13
- {geoai_py-0.3.4.dist-info → geoai_py-0.3.6.dist-info}/LICENSE +0 -0
- {geoai_py-0.3.4.dist-info → geoai_py-0.3.6.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.3.4.dist-info → geoai_py-0.3.6.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/extract.py
CHANGED
|
@@ -1,21 +1,29 @@
|
|
|
1
|
+
"""This module provides a dataset class for object extraction from raster data"""
|
|
2
|
+
|
|
3
|
+
# Standard Library
|
|
1
4
|
import os
|
|
5
|
+
|
|
6
|
+
# Third-Party Libraries
|
|
7
|
+
import cv2
|
|
8
|
+
import geopandas as gpd
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
2
10
|
import numpy as np
|
|
11
|
+
import rasterio
|
|
12
|
+
import scipy.ndimage as ndimage
|
|
3
13
|
import torch
|
|
4
|
-
|
|
14
|
+
from huggingface_hub import hf_hub_download
|
|
15
|
+
from rasterio.windows import Window
|
|
5
16
|
from shapely.geometry import Polygon, box
|
|
6
|
-
import geopandas as gpd
|
|
7
17
|
from tqdm import tqdm
|
|
18
|
+
from torchvision.models.detection import (
|
|
19
|
+
maskrcnn_resnet50_fpn,
|
|
20
|
+
fasterrcnn_resnet50_fpn_v2,
|
|
21
|
+
)
|
|
8
22
|
|
|
9
|
-
|
|
10
|
-
from torchvision.models.detection import maskrcnn_resnet50_fpn
|
|
11
|
-
import torchvision.transforms as T
|
|
12
|
-
import rasterio
|
|
13
|
-
from rasterio.windows import Window
|
|
14
|
-
from rasterio.features import shapes
|
|
15
|
-
from huggingface_hub import hf_hub_download
|
|
16
|
-
import scipy.ndimage as ndimage
|
|
23
|
+
# Local Imports
|
|
17
24
|
from .utils import get_raster_stats
|
|
18
25
|
|
|
26
|
+
|
|
19
27
|
try:
|
|
20
28
|
from torchgeo.datasets import NonGeoDataset
|
|
21
29
|
except ImportError as e:
|
|
@@ -60,6 +68,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
60
68
|
chip_size=(512, 512),
|
|
61
69
|
overlap=0.5,
|
|
62
70
|
transforms=None,
|
|
71
|
+
band_indexes=None,
|
|
63
72
|
verbose=False,
|
|
64
73
|
):
|
|
65
74
|
"""
|
|
@@ -70,6 +79,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
70
79
|
chip_size: Size of image chips to extract (height, width). Default is (512, 512).
|
|
71
80
|
overlap: Amount of overlap between adjacent tiles (0.0-1.0). Default is 0.5 (50%).
|
|
72
81
|
transforms: Transforms to apply to the image. Default is None.
|
|
82
|
+
band_indexes: List of band indexes to use. Default is None (use all bands).
|
|
73
83
|
verbose: Whether to print detailed processing information. Default is False.
|
|
74
84
|
|
|
75
85
|
Raises:
|
|
@@ -82,6 +92,7 @@ class CustomDataset(NonGeoDataset):
|
|
|
82
92
|
self.chip_size = chip_size
|
|
83
93
|
self.overlap = overlap
|
|
84
94
|
self.transforms = transforms
|
|
95
|
+
self.band_indexes = band_indexes
|
|
85
96
|
self.verbose = verbose
|
|
86
97
|
self.warned_about_bands = False
|
|
87
98
|
|
|
@@ -191,7 +202,10 @@ class CustomDataset(NonGeoDataset):
|
|
|
191
202
|
if not self.warned_about_bands and self.verbose:
|
|
192
203
|
print(f"Image has {image.shape[0]} bands, using first 3 bands only")
|
|
193
204
|
self.warned_about_bands = True
|
|
194
|
-
|
|
205
|
+
if self.band_indexes is not None:
|
|
206
|
+
image = image[self.band_indexes]
|
|
207
|
+
else:
|
|
208
|
+
image = image[:3]
|
|
195
209
|
elif image.shape[0] < 3:
|
|
196
210
|
# If image has fewer than 3 bands, duplicate the last band to make 3
|
|
197
211
|
if not self.warned_about_bands and self.verbose:
|
|
@@ -594,7 +608,7 @@ class ObjectDetector:
|
|
|
594
608
|
|
|
595
609
|
Args:
|
|
596
610
|
mask_path: Path to the object masks GeoTIFF
|
|
597
|
-
output_path: Path to save the output GeoJSON (default: mask_path with .geojson extension)
|
|
611
|
+
output_path: Path to save the output GeoJSON or Parquet file (default: mask_path with .geojson extension)
|
|
598
612
|
simplify_tolerance: Tolerance for polygon simplification (default: self.simplify_tolerance)
|
|
599
613
|
mask_threshold: Threshold for mask binarization (default: self.mask_threshold)
|
|
600
614
|
min_object_area: Minimum area in pixels to keep an object (default: self.min_object_area)
|
|
@@ -779,7 +793,10 @@ class ObjectDetector:
|
|
|
779
793
|
|
|
780
794
|
# Save to file
|
|
781
795
|
if output_path:
|
|
782
|
-
|
|
796
|
+
if output_path.endswith(".parquet"):
|
|
797
|
+
gdf.to_parquet(output_path)
|
|
798
|
+
else:
|
|
799
|
+
gdf.to_file(output_path)
|
|
783
800
|
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
784
801
|
|
|
785
802
|
return gdf
|
|
@@ -792,6 +809,7 @@ class ObjectDetector:
|
|
|
792
809
|
batch_size=4,
|
|
793
810
|
filter_edges=True,
|
|
794
811
|
edge_buffer=20,
|
|
812
|
+
band_indexes=None,
|
|
795
813
|
**kwargs,
|
|
796
814
|
):
|
|
797
815
|
"""
|
|
@@ -799,10 +817,11 @@ class ObjectDetector:
|
|
|
799
817
|
|
|
800
818
|
Args:
|
|
801
819
|
raster_path: Path to input raster file
|
|
802
|
-
output_path: Path to output GeoJSON file (optional)
|
|
820
|
+
output_path: Path to output GeoJSON or Parquet file (optional)
|
|
803
821
|
batch_size: Batch size for processing
|
|
804
822
|
filter_edges: Whether to filter out objects at the edges of the image
|
|
805
823
|
edge_buffer: Size of edge buffer in pixels to filter out objects (if filter_edges=True)
|
|
824
|
+
band_indexes: List of band indexes to use (if None, use all bands)
|
|
806
825
|
**kwargs: Additional parameters:
|
|
807
826
|
confidence_threshold: Minimum confidence score to keep a detection (0.0-1.0)
|
|
808
827
|
overlap: Overlap between adjacent tiles (0.0-1.0)
|
|
@@ -843,7 +862,10 @@ class ObjectDetector:
|
|
|
843
862
|
|
|
844
863
|
# Create dataset
|
|
845
864
|
dataset = CustomDataset(
|
|
846
|
-
raster_path=raster_path,
|
|
865
|
+
raster_path=raster_path,
|
|
866
|
+
chip_size=chip_size,
|
|
867
|
+
overlap=overlap,
|
|
868
|
+
band_indexes=band_indexes,
|
|
847
869
|
)
|
|
848
870
|
self.raster_stats = dataset.raster_stats
|
|
849
871
|
|
|
@@ -1021,7 +1043,10 @@ class ObjectDetector:
|
|
|
1021
1043
|
|
|
1022
1044
|
# Save to file if requested
|
|
1023
1045
|
if output_path:
|
|
1024
|
-
|
|
1046
|
+
if output_path.endswith(".parquet"):
|
|
1047
|
+
gdf.to_parquet(output_path)
|
|
1048
|
+
else:
|
|
1049
|
+
gdf.to_file(output_path, driver="GeoJSON")
|
|
1025
1050
|
print(f"Saved {len(gdf)} objects to {output_path}")
|
|
1026
1051
|
|
|
1027
1052
|
return gdf
|
|
@@ -1881,108 +1906,36 @@ class ObjectDetector:
|
|
|
1881
1906
|
plt.savefig(sample_output, dpi=300, bbox_inches="tight")
|
|
1882
1907
|
print(f"Sample visualization saved to {sample_output}")
|
|
1883
1908
|
|
|
1884
|
-
|
|
1885
|
-
class BuildingFootprintExtractor(ObjectDetector):
|
|
1886
|
-
"""
|
|
1887
|
-
Building footprint extraction using a pre-trained Mask R-CNN model.
|
|
1888
|
-
|
|
1889
|
-
This class extends the
|
|
1890
|
-
`ObjectDetector` class with additional methods for building footprint extraction."
|
|
1891
|
-
"""
|
|
1892
|
-
|
|
1893
|
-
def __init__(
|
|
1894
|
-
self,
|
|
1895
|
-
model_path="building_footprints_usa.pth",
|
|
1896
|
-
repo_id=None,
|
|
1897
|
-
model=None,
|
|
1898
|
-
device=None,
|
|
1899
|
-
):
|
|
1900
|
-
"""
|
|
1901
|
-
Initialize the object extractor.
|
|
1902
|
-
|
|
1903
|
-
Args:
|
|
1904
|
-
model_path: Path to the .pth model file.
|
|
1905
|
-
repo_id: Repo ID for loading models from the Hub.
|
|
1906
|
-
model: Custom model to use for inference.
|
|
1907
|
-
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
1908
|
-
"""
|
|
1909
|
-
super().__init__(
|
|
1910
|
-
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
1911
|
-
)
|
|
1912
|
-
|
|
1913
|
-
def regularize_buildings(
|
|
1914
|
-
self,
|
|
1915
|
-
gdf,
|
|
1916
|
-
min_area=10,
|
|
1917
|
-
angle_threshold=15,
|
|
1918
|
-
orthogonality_threshold=0.3,
|
|
1919
|
-
rectangularity_threshold=0.7,
|
|
1920
|
-
):
|
|
1921
|
-
"""
|
|
1922
|
-
Regularize building footprints to enforce right angles and rectangular shapes.
|
|
1923
|
-
|
|
1924
|
-
Args:
|
|
1925
|
-
gdf: GeoDataFrame with building footprints
|
|
1926
|
-
min_area: Minimum area in square units to keep a building
|
|
1927
|
-
angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
|
|
1928
|
-
orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
|
|
1929
|
-
rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
|
|
1930
|
-
|
|
1931
|
-
Returns:
|
|
1932
|
-
GeoDataFrame with regularized building footprints
|
|
1933
|
-
"""
|
|
1934
|
-
return self.regularize_objects(
|
|
1935
|
-
gdf,
|
|
1936
|
-
min_area=min_area,
|
|
1937
|
-
angle_threshold=angle_threshold,
|
|
1938
|
-
orthogonality_threshold=orthogonality_threshold,
|
|
1939
|
-
rectangularity_threshold=rectangularity_threshold,
|
|
1940
|
-
)
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
class CarDetector(ObjectDetector):
|
|
1944
|
-
"""
|
|
1945
|
-
Car detection using a pre-trained Mask R-CNN model.
|
|
1946
|
-
|
|
1947
|
-
This class extends the `ObjectDetector` class with additional methods for car detection.
|
|
1948
|
-
"""
|
|
1949
|
-
|
|
1950
|
-
def __init__(
|
|
1951
|
-
self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
|
|
1952
|
-
):
|
|
1953
|
-
"""
|
|
1954
|
-
Initialize the object extractor.
|
|
1955
|
-
|
|
1956
|
-
Args:
|
|
1957
|
-
model_path: Path to the .pth model file.
|
|
1958
|
-
repo_id: Repo ID for loading models from the Hub.
|
|
1959
|
-
model: Custom model to use for inference.
|
|
1960
|
-
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
1961
|
-
"""
|
|
1962
|
-
super().__init__(
|
|
1963
|
-
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
1964
|
-
)
|
|
1965
|
-
|
|
1966
1909
|
def generate_masks(
|
|
1967
1910
|
self,
|
|
1968
1911
|
raster_path,
|
|
1969
1912
|
output_path=None,
|
|
1970
1913
|
confidence_threshold=None,
|
|
1971
1914
|
mask_threshold=None,
|
|
1915
|
+
min_object_area=10,
|
|
1916
|
+
max_object_area=float("inf"),
|
|
1972
1917
|
overlap=0.25,
|
|
1973
1918
|
batch_size=4,
|
|
1919
|
+
band_indexes=None,
|
|
1974
1920
|
verbose=False,
|
|
1921
|
+
**kwargs,
|
|
1975
1922
|
):
|
|
1976
1923
|
"""
|
|
1977
1924
|
Save masks with confidence values as a multi-band GeoTIFF.
|
|
1978
1925
|
|
|
1926
|
+
Objects with area smaller than min_object_area or larger than max_object_area
|
|
1927
|
+
will be filtered out.
|
|
1928
|
+
|
|
1979
1929
|
Args:
|
|
1980
1930
|
raster_path: Path to input raster
|
|
1981
1931
|
output_path: Path for output GeoTIFF
|
|
1982
1932
|
confidence_threshold: Minimum confidence score (0.0-1.0)
|
|
1983
1933
|
mask_threshold: Threshold for mask binarization (0.0-1.0)
|
|
1934
|
+
min_object_area: Minimum area (in pixels) for an object to be included
|
|
1935
|
+
max_object_area: Maximum area (in pixels) for an object to be included
|
|
1984
1936
|
overlap: Overlap between tiles (0.0-1.0)
|
|
1985
1937
|
batch_size: Batch size for processing
|
|
1938
|
+
band_indexes: List of band indexes to use (default: all bands)
|
|
1986
1939
|
verbose: Whether to print detailed processing information
|
|
1987
1940
|
|
|
1988
1941
|
Returns:
|
|
@@ -1994,6 +1947,8 @@ class CarDetector(ObjectDetector):
|
|
|
1994
1947
|
if mask_threshold is None:
|
|
1995
1948
|
mask_threshold = self.mask_threshold
|
|
1996
1949
|
|
|
1950
|
+
chip_size = kwargs.get("chip_size", self.chip_size)
|
|
1951
|
+
|
|
1997
1952
|
# Default output path
|
|
1998
1953
|
if output_path is None:
|
|
1999
1954
|
output_path = os.path.splitext(raster_path)[0] + "_masks_conf.tif"
|
|
@@ -2003,8 +1958,9 @@ class CarDetector(ObjectDetector):
|
|
|
2003
1958
|
# Create dataset with the specified overlap
|
|
2004
1959
|
dataset = CustomDataset(
|
|
2005
1960
|
raster_path=raster_path,
|
|
2006
|
-
chip_size=
|
|
1961
|
+
chip_size=chip_size,
|
|
2007
1962
|
overlap=overlap,
|
|
1963
|
+
band_indexes=band_indexes,
|
|
2008
1964
|
verbose=verbose,
|
|
2009
1965
|
)
|
|
2010
1966
|
|
|
@@ -2091,6 +2047,21 @@ class CarDetector(ObjectDetector):
|
|
|
2091
2047
|
for mask_idx, mask in enumerate(masks):
|
|
2092
2048
|
# Convert to binary mask
|
|
2093
2049
|
binary_mask = (mask[0] > mask_threshold).astype(np.uint8) * 255
|
|
2050
|
+
|
|
2051
|
+
# Check object area - calculate number of pixels in the mask
|
|
2052
|
+
object_area = np.sum(binary_mask > 0)
|
|
2053
|
+
|
|
2054
|
+
# Skip objects that don't meet area criteria
|
|
2055
|
+
if (
|
|
2056
|
+
object_area < min_object_area
|
|
2057
|
+
or object_area > max_object_area
|
|
2058
|
+
):
|
|
2059
|
+
if verbose:
|
|
2060
|
+
print(
|
|
2061
|
+
f"Filtering out object with area {object_area} pixels"
|
|
2062
|
+
)
|
|
2063
|
+
continue
|
|
2064
|
+
|
|
2094
2065
|
conf_value = int(scores[mask_idx] * 255) # Scale to 0-255
|
|
2095
2066
|
|
|
2096
2067
|
# Update the mask and confidence arrays
|
|
@@ -2135,13 +2106,24 @@ class CarDetector(ObjectDetector):
|
|
|
2135
2106
|
print(f"Masks with confidence values saved to {output_path}")
|
|
2136
2107
|
return output_path
|
|
2137
2108
|
|
|
2138
|
-
def vectorize_masks(
|
|
2109
|
+
def vectorize_masks(
|
|
2110
|
+
self,
|
|
2111
|
+
masks_path,
|
|
2112
|
+
output_path=None,
|
|
2113
|
+
confidence_threshold=0.5,
|
|
2114
|
+
min_object_area=100,
|
|
2115
|
+
max_object_size=None,
|
|
2116
|
+
**kwargs,
|
|
2117
|
+
):
|
|
2139
2118
|
"""
|
|
2140
2119
|
Convert masks with confidence to vector polygons.
|
|
2141
2120
|
|
|
2142
2121
|
Args:
|
|
2143
|
-
masks_path: Path to masks GeoTIFF with confidence band
|
|
2144
|
-
output_path: Path for output GeoJSON
|
|
2122
|
+
masks_path: Path to masks GeoTIFF with confidence band.
|
|
2123
|
+
output_path: Path for output GeoJSON.
|
|
2124
|
+
confidence_threshold: Minimum confidence score (0.0-1.0). Default: 0.5
|
|
2125
|
+
min_object_area: Minimum area in pixels to keep an object. Default: 100
|
|
2126
|
+
max_object_size: Maximum area in pixels to keep an object. Default: None
|
|
2145
2127
|
**kwargs: Additional parameters
|
|
2146
2128
|
|
|
2147
2129
|
Returns:
|
|
@@ -2182,6 +2164,10 @@ class CarDetector(ObjectDetector):
|
|
|
2182
2164
|
else:
|
|
2183
2165
|
confidence = 0.0
|
|
2184
2166
|
|
|
2167
|
+
# Skip if confidence is below threshold
|
|
2168
|
+
if confidence < confidence_threshold:
|
|
2169
|
+
continue
|
|
2170
|
+
|
|
2185
2171
|
# Find contours
|
|
2186
2172
|
contours, _ = cv2.findContours(
|
|
2187
2173
|
component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
|
|
@@ -2190,9 +2176,13 @@ class CarDetector(ObjectDetector):
|
|
|
2190
2176
|
for contour in contours:
|
|
2191
2177
|
# Filter by size
|
|
2192
2178
|
area = cv2.contourArea(contour)
|
|
2193
|
-
if area <
|
|
2179
|
+
if area < min_object_area:
|
|
2194
2180
|
continue
|
|
2195
2181
|
|
|
2182
|
+
if max_object_size is not None:
|
|
2183
|
+
if area > max_object_size:
|
|
2184
|
+
continue
|
|
2185
|
+
|
|
2196
2186
|
# Get minimum area rectangle
|
|
2197
2187
|
rect = cv2.minAreaRect(contour)
|
|
2198
2188
|
box_points = cv2.boxPoints(rect)
|
|
@@ -2224,7 +2214,7 @@ class CarDetector(ObjectDetector):
|
|
|
2224
2214
|
# Save to file if requested
|
|
2225
2215
|
if output_path:
|
|
2226
2216
|
gdf.to_file(output_path, driver="GeoJSON")
|
|
2227
|
-
print(f"Saved {len(gdf)}
|
|
2217
|
+
print(f"Saved {len(gdf)} objects with confidence to {output_path}")
|
|
2228
2218
|
|
|
2229
2219
|
return gdf
|
|
2230
2220
|
else:
|
|
@@ -2232,6 +2222,88 @@ class CarDetector(ObjectDetector):
|
|
|
2232
2222
|
return None
|
|
2233
2223
|
|
|
2234
2224
|
|
|
2225
|
+
class BuildingFootprintExtractor(ObjectDetector):
|
|
2226
|
+
"""
|
|
2227
|
+
Building footprint extraction using a pre-trained Mask R-CNN model.
|
|
2228
|
+
|
|
2229
|
+
This class extends the
|
|
2230
|
+
`ObjectDetector` class with additional methods for building footprint extraction."
|
|
2231
|
+
"""
|
|
2232
|
+
|
|
2233
|
+
def __init__(
|
|
2234
|
+
self,
|
|
2235
|
+
model_path="building_footprints_usa.pth",
|
|
2236
|
+
repo_id=None,
|
|
2237
|
+
model=None,
|
|
2238
|
+
device=None,
|
|
2239
|
+
):
|
|
2240
|
+
"""
|
|
2241
|
+
Initialize the object extractor.
|
|
2242
|
+
|
|
2243
|
+
Args:
|
|
2244
|
+
model_path: Path to the .pth model file.
|
|
2245
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
2246
|
+
model: Custom model to use for inference.
|
|
2247
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
2248
|
+
"""
|
|
2249
|
+
super().__init__(
|
|
2250
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
2251
|
+
)
|
|
2252
|
+
|
|
2253
|
+
def regularize_buildings(
|
|
2254
|
+
self,
|
|
2255
|
+
gdf,
|
|
2256
|
+
min_area=10,
|
|
2257
|
+
angle_threshold=15,
|
|
2258
|
+
orthogonality_threshold=0.3,
|
|
2259
|
+
rectangularity_threshold=0.7,
|
|
2260
|
+
):
|
|
2261
|
+
"""
|
|
2262
|
+
Regularize building footprints to enforce right angles and rectangular shapes.
|
|
2263
|
+
|
|
2264
|
+
Args:
|
|
2265
|
+
gdf: GeoDataFrame with building footprints
|
|
2266
|
+
min_area: Minimum area in square units to keep a building
|
|
2267
|
+
angle_threshold: Maximum deviation from 90 degrees to consider an angle as orthogonal (degrees)
|
|
2268
|
+
orthogonality_threshold: Percentage of angles that must be orthogonal for a building to be regularized
|
|
2269
|
+
rectangularity_threshold: Minimum area ratio to building's oriented bounding box for rectangular simplification
|
|
2270
|
+
|
|
2271
|
+
Returns:
|
|
2272
|
+
GeoDataFrame with regularized building footprints
|
|
2273
|
+
"""
|
|
2274
|
+
return self.regularize_objects(
|
|
2275
|
+
gdf,
|
|
2276
|
+
min_area=min_area,
|
|
2277
|
+
angle_threshold=angle_threshold,
|
|
2278
|
+
orthogonality_threshold=orthogonality_threshold,
|
|
2279
|
+
rectangularity_threshold=rectangularity_threshold,
|
|
2280
|
+
)
|
|
2281
|
+
|
|
2282
|
+
|
|
2283
|
+
class CarDetector(ObjectDetector):
|
|
2284
|
+
"""
|
|
2285
|
+
Car detection using a pre-trained Mask R-CNN model.
|
|
2286
|
+
|
|
2287
|
+
This class extends the `ObjectDetector` class with additional methods for car detection.
|
|
2288
|
+
"""
|
|
2289
|
+
|
|
2290
|
+
def __init__(
|
|
2291
|
+
self, model_path="car_detection_usa.pth", repo_id=None, model=None, device=None
|
|
2292
|
+
):
|
|
2293
|
+
"""
|
|
2294
|
+
Initialize the object extractor.
|
|
2295
|
+
|
|
2296
|
+
Args:
|
|
2297
|
+
model_path: Path to the .pth model file.
|
|
2298
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
2299
|
+
model: Custom model to use for inference.
|
|
2300
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
2301
|
+
"""
|
|
2302
|
+
super().__init__(
|
|
2303
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
2304
|
+
)
|
|
2305
|
+
|
|
2306
|
+
|
|
2235
2307
|
class ShipDetector(ObjectDetector):
|
|
2236
2308
|
"""
|
|
2237
2309
|
Ship detection using a pre-trained Mask R-CNN model.
|
|
@@ -2255,3 +2327,32 @@ class ShipDetector(ObjectDetector):
|
|
|
2255
2327
|
super().__init__(
|
|
2256
2328
|
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
2257
2329
|
)
|
|
2330
|
+
|
|
2331
|
+
|
|
2332
|
+
class SolarPanelDetector(ObjectDetector):
|
|
2333
|
+
"""
|
|
2334
|
+
Solar panel detection using a pre-trained Mask R-CNN model.
|
|
2335
|
+
|
|
2336
|
+
This class extends the
|
|
2337
|
+
`ObjectDetector` class with additional methods for solar panel detection."
|
|
2338
|
+
"""
|
|
2339
|
+
|
|
2340
|
+
def __init__(
|
|
2341
|
+
self,
|
|
2342
|
+
model_path="solar_panel_detection.pth",
|
|
2343
|
+
repo_id=None,
|
|
2344
|
+
model=None,
|
|
2345
|
+
device=None,
|
|
2346
|
+
):
|
|
2347
|
+
"""
|
|
2348
|
+
Initialize the object extractor.
|
|
2349
|
+
|
|
2350
|
+
Args:
|
|
2351
|
+
model_path: Path to the .pth model file.
|
|
2352
|
+
repo_id: Repo ID for loading models from the Hub.
|
|
2353
|
+
model: Custom model to use for inference.
|
|
2354
|
+
device: Device to use for inference ('cuda:0', 'cpu', etc.).
|
|
2355
|
+
"""
|
|
2356
|
+
super().__init__(
|
|
2357
|
+
model_path=model_path, repo_id=repo_id, model=model, device=device
|
|
2358
|
+
)
|
geoai/geoai.py
CHANGED