geoai-py 0.4.3__py2.py3-none-any.whl → 0.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +6 -1
- geoai/classify.py +933 -0
- geoai/download.py +119 -80
- geoai/geoai.py +93 -1
- geoai/train.py +115 -6
- geoai/utils.py +196 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/METADATA +9 -1
- geoai_py-0.5.1.dist-info/RECORD +16 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/WHEEL +1 -1
- geoai_py-0.4.3.dist-info/RECORD +0 -15
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.4.3.dist-info → geoai_py-0.5.1.dist-info}/top_level.txt +0 -0
geoai/download.py
CHANGED
|
@@ -207,98 +207,130 @@ def json_serializable(obj: Any) -> Any:
|
|
|
207
207
|
return obj
|
|
208
208
|
|
|
209
209
|
|
|
210
|
+
def get_overture_latest_release(patch=True) -> str:
    """Return the identifier of the latest Overture Maps release.

    Downloads the Overture Maps ``releases.json`` document and reads its
    ``"latest"`` key.

    Args:
        patch (bool): If True, returns the full version string
            (e.g., "2025-02-19.0"). If False, returns only the part of the
            string before the first "." (e.g., "2025-02-19").

    Returns:
        str: The value of the 'latest' key, optionally without the patch suffix.

    Raises:
        requests.RequestException: If there's an issue with the HTTP request
            (including a timeout).
        KeyError: If the 'latest' key is missing (or null) in the JSON data.
        json.JSONDecodeError: If the response cannot be parsed as JSON.
    """
    url = "https://labs.overturemaps.org/data/releases.json"

    try:
        # A timeout keeps the call from blocking forever on a stalled server.
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # Raise an exception for HTTP errors

        data = response.json()
        latest_release = data.get("latest")

        # Validate before touching the value: calling .split() on a missing
        # key would raise AttributeError instead of the documented KeyError.
        if latest_release is None:
            raise KeyError("The 'latest' key was not found in the releases.json file")

        if not patch:
            # Drop the patch suffix, keeping only the release date part.
            latest_release = latest_release.split(".")[0]

        return latest_release

    except requests.RequestException as e:
        print(f"Error making the request: {e}")
        raise
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        raise
    except KeyError as e:
        print(f"Key error: {e}")
        raise
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_all_overture_types():
    """List every Overture Maps data type known to the ``overturemaps`` package.

    The ``overturemaps`` package is imported lazily, so it is only required
    when this function is actually called.

    Returns:
        list: List of available Overture Maps data types, as reported by
            ``overturemaps.core.get_all_overture_types()``.
    """
    from overturemaps import core

    available_types = core.get_all_overture_types()
    return available_types
|
|
264
|
+
|
|
265
|
+
|
|
210
266
|
def download_overture_buildings(
    bbox: Tuple[float, float, float, float],
    output: str,
    overture_type: str = "building",
    **kwargs: Any,
) -> "gpd.GeoDataFrame":
    """Download Overture Maps data (buildings by default) for a bounding box.

    This is a thin convenience wrapper around :func:`get_overture_data`; it
    forwards all arguments and returns that function's result unchanged.

    Args:
        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat) in WGS84 coordinates.
        output: Path to save the output file.
        overture_type: The Overture Maps data type to download (building, place, etc.).
        **kwargs: Additional keyword arguments forwarded to
            :func:`get_overture_data`.

    Returns:
        The value returned by :func:`get_overture_data` (the fetched data as a
        GeoDataFrame), not the output path. The data is also written to
        ``output``.
    """

    return get_overture_data(
        overture_type=overture_type, bbox=bbox, output=output, **kwargs
    )
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def get_overture_data(
    overture_type: str,
    bbox: Tuple[float, float, float, float] = None,
    columns: List[str] = None,
    output: str = None,
    **kwargs: Any,
) -> "gpd.GeoDataFrame":
    """Fetches overture data and returns it as a GeoDataFrame.

    Args:
        overture_type (str): The type of overture data to fetch. It can be one of the following:
            address|building|building_part|division|division_area|division_boundary|place|
            segment|connector|infrastructure|land|land_cover|land_use|water
        bbox (Tuple[float, float, float, float], optional): The bounding box to
            filter the data. Defaults to None.
        columns (List[str], optional): The columns to include in the output.
            Defaults to None (keep all columns).
        output (str, optional): The file path to save the output GeoDataFrame.
            Missing parent directories are created. Defaults to None (nothing
            is written to disk).
        **kwargs: Additional keyword arguments passed to ``gdf.to_file`` when
            ``output`` is given.

    Returns:
        gpd.GeoDataFrame: The fetched overture data as a GeoDataFrame.

    Raises:
        ImportError: If the overture package is not installed.
    """

    try:
        from overturemaps import core
    except ImportError as err:
        raise ImportError(
            "The overturemaps package is required to use this function"
        ) from err

    gdf = core.geodataframe(overture_type, bbox=bbox)
    if columns is not None:
        # NOTE(review): presumably `columns` must include the geometry column
        # for the result to remain a GeoDataFrame - confirm with callers.
        gdf = gdf[columns]

    gdf.crs = "EPSG:4326"

    if output is not None:
        # Only resolve/create the target directory when a path was given;
        # os.path.abspath(None) raises TypeError, which previously made the
        # default output=None unusable.
        out_dir = os.path.dirname(os.path.abspath(output))
        os.makedirs(out_dir, exist_ok=True)
        gdf.to_file(output, **kwargs)

    return gdf
|
|
302
334
|
|
|
303
335
|
|
|
304
336
|
def convert_vector_format(
|
|
@@ -361,18 +393,23 @@ def convert_vector_format(
|
|
|
361
393
|
raise
|
|
362
394
|
|
|
363
395
|
|
|
364
|
-
def extract_building_stats(
|
|
396
|
+
def extract_building_stats(data: str) -> Dict[str, Any]:
|
|
365
397
|
"""Extract statistics from the building data.
|
|
366
398
|
|
|
367
399
|
Args:
|
|
368
|
-
|
|
400
|
+
data: Path to the GeoJSON file or GeoDataFrame containing building data.
|
|
369
401
|
|
|
370
402
|
Returns:
|
|
371
403
|
Dictionary with statistics.
|
|
372
404
|
"""
|
|
373
405
|
try:
|
|
374
406
|
# Read the GeoJSON file
|
|
375
|
-
|
|
407
|
+
|
|
408
|
+
if isinstance(data, gpd.GeoDataFrame):
|
|
409
|
+
gdf = data
|
|
410
|
+
else:
|
|
411
|
+
|
|
412
|
+
gdf = gpd.read_file(data)
|
|
376
413
|
|
|
377
414
|
# Calculate statistics
|
|
378
415
|
bbox = gdf.total_bounds.tolist()
|
|
@@ -903,7 +940,7 @@ def pc_stac_download(
|
|
|
903
940
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
904
941
|
|
|
905
942
|
# Handle single item case
|
|
906
|
-
if isinstance(items, pystac.Item):
|
|
943
|
+
if isinstance(items, pystac.Item) or isinstance(items, str):
|
|
907
944
|
items = [items]
|
|
908
945
|
elif not isinstance(items, list):
|
|
909
946
|
raise TypeError("items must be a STAC Item or list of STAC Items")
|
|
@@ -973,6 +1010,8 @@ def pc_stac_download(
|
|
|
973
1010
|
|
|
974
1011
|
for item in items:
|
|
975
1012
|
item_assets = {}
|
|
1013
|
+
if isinstance(item, str):
|
|
1014
|
+
item = pystac.Item.from_file(item)
|
|
976
1015
|
item_id = item.id
|
|
977
1016
|
print(f"Processing STAC item: {item_id}")
|
|
978
1017
|
|
geoai/geoai.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Main module."""
|
|
2
2
|
|
|
3
3
|
import leafmap
|
|
4
|
+
import leafmap.maplibregl as maplibregl
|
|
4
5
|
|
|
5
6
|
from .download import (
|
|
6
7
|
download_naip,
|
|
@@ -13,10 +14,11 @@ from .download import (
|
|
|
13
14
|
read_pc_item_asset,
|
|
14
15
|
view_pc_item,
|
|
15
16
|
)
|
|
17
|
+
from .classify import train_classifier, classify_image, classify_images
|
|
16
18
|
from .extract import *
|
|
17
19
|
from .hf import *
|
|
18
20
|
from .segment import *
|
|
19
|
-
from .train import object_detection, train_MaskRCNN_model
|
|
21
|
+
from .train import object_detection, object_detection_batch, train_MaskRCNN_model
|
|
20
22
|
from .utils import *
|
|
21
23
|
|
|
22
24
|
|
|
@@ -26,3 +28,93 @@ class Map(leafmap.Map):
|
|
|
26
28
|
def __init__(self, *args, **kwargs):
|
|
27
29
|
"""Initialize the Map class."""
|
|
28
30
|
super().__init__(*args, **kwargs)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MapLibre(maplibregl.Map):
    """GeoAI map class built on ``leafmap.maplibregl.Map``.

    Adds no behavior of its own; it exists so GeoAI code has its own
    MapLibre map type to instantiate and extend.
    """

    def __init__(self, *args, **kwargs):
        """Create the map, forwarding every argument to ``maplibregl.Map``."""
        super().__init__(*args, **kwargs)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def create_vector_data(
    m: Optional[Map] = None,
    properties: Optional[Dict[str, List[Any]]] = None,
    time_format: str = "%Y%m%dT%H%M%S",
    column_widths: Optional[List[int]] = (9, 3),
    map_height: str = "600px",
    out_dir: Optional[str] = None,
    filename_prefix: str = "",
    file_ext: str = "geojson",
    add_mapillary: bool = False,
    style: str = "photo",
    radius: float = 0.00005,
    width: int = 300,
    height: int = 420,
    frame_border: int = 0,
    **kwargs: Any,
):
    """Build an interactive widget for drawing and exporting vector features.

    This is a pass-through to ``leafmap.maplibregl.create_vector_data``: every
    argument is forwarded unchanged and the resulting widget is returned. The
    widget lets users draw features (points, lines, polygons) on a map, assign
    properties to them, and export them as GeoJSON files; it combines a map, a
    sidebar for property management, and buttons for saving, exporting, and
    resetting the data.

    Args:
        m (Map, optional): An existing Map object. If not provided, a default map with
            basemaps and drawing controls will be created. Defaults to None.
        properties (Dict[str, List[Any]], optional): Maps each property name to the list
            of values that can be chosen for it when annotating a drawn feature.
            Defaults to None.
        time_format (str, optional): Format string for the timestamp used in the exported
            filename. Defaults to "%Y%m%dT%H%M%S".
        column_widths (Optional[List[int]], optional): Two integers giving the relative
            widths of the map and sidebar columns. Defaults to (9, 3).
        map_height (str, optional): The height of the map widget. Defaults to "600px".
        out_dir (str, optional): Directory where exported files are saved. If not
            provided, the current working directory is used. Defaults to None.
        filename_prefix (str, optional): Prefix added to the exported filename.
            Defaults to "".
        file_ext (str, optional): File extension for the exported file. Defaults to "geojson".
        add_mapillary (bool, optional): Whether to add a Mapillary image widget showing the
            nearest image to the clicked point on the map. Defaults to False.
        style (str, optional): Style of the Mapillary image widget. Can be "classic",
            "photo", or "split". Defaults to "photo".
        radius (float, optional): Search radius (in degrees) for the nearest Mapillary
            image. Defaults to 0.00005 degrees.
        width (int, optional): Width of the Mapillary image widget. Defaults to 300.
        height (int, optional): Height of the Mapillary image widget. Defaults to 420.
        frame_border (int, optional): Width of the frame border for the Mapillary image
            widget. Defaults to 0.
        **kwargs (Any): Additional keyword arguments forwarded to
            ``maplibregl.create_vector_data``.

    Returns:
        widgets.VBox: A vertical box widget containing the map, sidebar, and control buttons.

    Example:
        >>> properties = {
        ...     "Type": ["Residential", "Commercial", "Industrial"],
        ...     "Area": [100, 200, 300],
        ... }
        >>> widget = create_vector_data(properties=properties)
        >>> display(widget)  # Display the widget in a Jupyter notebook
    """
    # Collect the explicitly named options once, then hand everything over to
    # the leafmap implementation together with any extra keyword arguments.
    named_options = {
        "m": m,
        "properties": properties,
        "time_format": time_format,
        "column_widths": column_widths,
        "map_height": map_height,
        "out_dir": out_dir,
        "filename_prefix": filename_prefix,
        "file_ext": file_ext,
        "add_mapillary": add_mapillary,
        "style": style,
        "radius": radius,
        "width": width,
        "height": height,
        "frame_border": frame_border,
    }
    return maplibregl.create_vector_data(**named_options, **kwargs)
|
geoai/train.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import glob
|
|
1
2
|
import math
|
|
2
3
|
import os
|
|
3
4
|
import random
|
|
@@ -20,6 +21,8 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
|
20
21
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
|
21
22
|
from tqdm import tqdm
|
|
22
23
|
|
|
24
|
+
from .utils import download_model_from_hf
|
|
25
|
+
|
|
23
26
|
|
|
24
27
|
def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
|
|
25
28
|
"""
|
|
@@ -352,7 +355,9 @@ def collate_fn(batch):
|
|
|
352
355
|
return tuple(zip(*batch))
|
|
353
356
|
|
|
354
357
|
|
|
355
|
-
def train_one_epoch(
|
|
358
|
+
def train_one_epoch(
|
|
359
|
+
model, optimizer, data_loader, device, epoch, print_freq=10, verbose=True
|
|
360
|
+
):
|
|
356
361
|
"""
|
|
357
362
|
Train the model for one epoch.
|
|
358
363
|
|
|
@@ -363,6 +368,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
|
|
|
363
368
|
device (torch.device): Device to train on.
|
|
364
369
|
epoch (int): Current epoch number.
|
|
365
370
|
print_freq (int): How often to print progress.
|
|
371
|
+
verbose (bool): Whether to print detailed progress.
|
|
366
372
|
|
|
367
373
|
Returns:
|
|
368
374
|
float: Average loss for the epoch.
|
|
@@ -392,9 +398,10 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
|
|
|
392
398
|
# Print progress
|
|
393
399
|
if i % print_freq == 0:
|
|
394
400
|
elapsed_time = time.time() - start_time
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
401
|
+
if verbose:
|
|
402
|
+
print(
|
|
403
|
+
f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
|
|
404
|
+
)
|
|
398
405
|
start_time = time.time()
|
|
399
406
|
|
|
400
407
|
# Calculate average loss
|
|
@@ -579,6 +586,8 @@ def train_MaskRCNN_model(
|
|
|
579
586
|
val_split=0.2,
|
|
580
587
|
visualize=False,
|
|
581
588
|
resume_training=False,
|
|
589
|
+
print_freq=10,
|
|
590
|
+
verbose=True,
|
|
582
591
|
):
|
|
583
592
|
"""Train and evaluate Mask R-CNN model for instance segmentation.
|
|
584
593
|
|
|
@@ -605,7 +614,8 @@ def train_MaskRCNN_model(
|
|
|
605
614
|
Defaults to False.
|
|
606
615
|
resume_training (bool): If True and pretrained_model_path is provided,
|
|
607
616
|
will try to load optimizer and scheduler states as well. Defaults to False.
|
|
608
|
-
|
|
617
|
+
print_freq (int): Frequency of printing training progress. Defaults to 10.
|
|
618
|
+
verbose (bool): If True, prints detailed training progress. Defaults to True.
|
|
609
619
|
Returns:
|
|
610
620
|
None: Model weights are saved to output_dir.
|
|
611
621
|
|
|
@@ -756,7 +766,9 @@ def train_MaskRCNN_model(
|
|
|
756
766
|
# Training loop
|
|
757
767
|
for epoch in range(start_epoch, num_epochs):
|
|
758
768
|
# Train one epoch
|
|
759
|
-
train_loss = train_one_epoch(
|
|
769
|
+
train_loss = train_one_epoch(
|
|
770
|
+
model, optimizer, train_loader, device, epoch, print_freq, verbose
|
|
771
|
+
)
|
|
760
772
|
|
|
761
773
|
# Update learning rate
|
|
762
774
|
lr_scheduler.step()
|
|
@@ -1107,6 +1119,13 @@ def object_detection(
|
|
|
1107
1119
|
model = get_instance_segmentation_model(
|
|
1108
1120
|
num_classes=2, num_channels=num_channels, pretrained=pretrained
|
|
1109
1121
|
)
|
|
1122
|
+
|
|
1123
|
+
if not os.path.exists(model_path):
|
|
1124
|
+
try:
|
|
1125
|
+
model_path = download_model_from_hf(model_path)
|
|
1126
|
+
except Exception as e:
|
|
1127
|
+
raise FileNotFoundError(f"Model file not found: {model_path}")
|
|
1128
|
+
|
|
1110
1129
|
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
1111
1130
|
model.to(device)
|
|
1112
1131
|
model.eval()
|
|
@@ -1123,3 +1142,93 @@ def object_detection(
|
|
|
1123
1142
|
device=device,
|
|
1124
1143
|
**kwargs,
|
|
1125
1144
|
)
|
|
1145
|
+
|
|
1146
|
+
|
|
1147
|
+
def object_detection_batch(
    input_paths,
    output_dir,
    model_path,
    filenames=None,
    window_size=512,
    overlap=256,
    confidence_threshold=0.5,
    batch_size=4,
    num_channels=3,
    pretrained=True,
    device=None,
    **kwargs,
):
    """
    Perform object detection on one or more GeoTIFFs using a pre-trained Mask R-CNN model.

    The model is loaded once and then applied to every input file in turn.

    Args:
        input_paths (str or list): Path(s) to input GeoTIFF file(s). If a directory is provided,
            all .tif files in that directory will be processed (in sorted order).
            A list/tuple of file paths is processed in the given order.
        output_dir (str): Directory to save output mask GeoTIFF files. Created if missing.
        model_path (str): Path to trained model weights. If the path does not exist
            locally, it is passed to ``download_model_from_hf`` to try to fetch it.
        filenames (list, optional): List of output file paths, used as given. If None,
            defaults to "<input_filename>_mask<ext>" inside ``output_dir`` for each input file.
            If provided, must match the number of input files.
        window_size (int): Size of sliding window for inference.
        overlap (int): Overlap between adjacent windows.
        confidence_threshold (float): Confidence threshold for predictions (0-1).
        batch_size (int): Batch size for inference.
        num_channels (int): Number of channels in the input image and model.
        pretrained (bool): Whether to use pretrained backbone for model loading.
        device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
        **kwargs: Additional arguments passed to inference_on_geotiff.

    Returns:
        None: Output masks are saved to output_dir (or to the paths in ``filenames``).

    Raises:
        FileNotFoundError: If the model weights are not found locally and cannot be downloaded.
        ValueError: If ``filenames`` is given but its length differs from the number of inputs.
    """
    # Load your trained model
    if device is None:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
    model = get_instance_segmentation_model(
        num_classes=2, num_channels=num_channels, pretrained=pretrained
    )

    os.makedirs(output_dir, exist_ok=True)

    if not os.path.exists(model_path):
        try:
            model_path = download_model_from_hf(model_path)
        except Exception as e:
            # Keep the original failure attached so the real cause is visible.
            raise FileNotFoundError(f"Model file not found: {model_path}") from e

    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Resolve the inputs into a concrete list of files.
    if isinstance(input_paths, str):
        looks_like_tiff = input_paths.lower().endswith((".tif", ".tiff"))
        if os.path.isdir(input_paths) or not looks_like_tiff:
            # Directory (or anything that is not a GeoTIFF path): process the
            # .tif files it contains.
            files = sorted(glob.glob(os.path.join(input_paths, "*.tif")))
        else:
            files = [input_paths]
    else:
        # A list/tuple of paths; previously this case left `files` undefined.
        files = list(input_paths)

    if filenames is None:
        # Insert "_mask" before the extension. Unlike str.replace(".tif", ...),
        # this also changes names that do not contain ".tif" (e.g. ".TIF"), so
        # the output can never share the input's file name.
        filenames = []
        for f in files:
            stem, ext = os.path.splitext(os.path.basename(f))
            filenames.append(os.path.join(output_dir, f"{stem}_mask{ext}"))
    else:
        if len(filenames) != len(files):
            raise ValueError("Number of filenames must match number of input files.")

    for index, file in enumerate(files):
        print(f"Processing file {index + 1}/{len(files)}: {file}")
        inference_on_geotiff(
            model=model,
            geotiff_path=file,
            output_path=filenames[index],
            window_size=window_size,  # Adjust based on your model and memory
            overlap=overlap,  # Overlap to avoid edge artifacts
            confidence_threshold=confidence_threshold,
            batch_size=batch_size,  # Adjust based on your GPU memory
            num_channels=num_channels,
            device=device,
            **kwargs,
        )
|