geoai-py 0.4.3__py2.py3-none-any.whl → 0.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/download.py CHANGED
@@ -207,98 +207,130 @@ def json_serializable(obj: Any) -> Any:
207
207
  return obj
208
208
 
209
209
 
210
def get_overture_latest_release(patch=True) -> str:
    """Retrieves the value of the 'latest' key from the Overture Maps release JSON file.

    Args:
        patch (bool): If True, returns the full version string (e.g., "2025-02-19.0").
            If False, returns the version without the trailing patch component
            (e.g., "2025-02-19").

    Returns:
        str: The latest Overture Maps release version.

    Raises:
        requests.RequestException: If there's an issue with the HTTP request.
        KeyError: If the 'latest' key is not found in the JSON data.
        json.JSONDecodeError: If the response cannot be parsed as JSON.
    """
    url = "https://labs.overturemaps.org/data/releases.json"

    try:
        # Bound the request so a stalled connection fails instead of hanging forever.
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # Raise an exception for HTTP errors

        data = response.json()
        latest_release = data.get("latest")

        # Validate BEFORE any string manipulation: calling .split() on a missing
        # (None) value would raise AttributeError instead of the documented KeyError.
        if latest_release is None:
            raise KeyError("The 'latest' key was not found in the releases.json file")

        if not patch:
            # Drop the trailing ".N" patch suffix, keeping only the date portion.
            latest_release = latest_release.split(".")[0]

        return latest_release

    except requests.RequestException as e:
        print(f"Error making the request: {e}")
        raise
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON response: {e}")
        raise
    except KeyError as e:
        print(f"Key error: {e}")
        raise
253
+
254
+
255
def get_all_overture_types():
    """List every Overture Maps data type known to the overturemaps package.

    Returns:
        list: Names of the available Overture Maps data types.
    """
    # Imported lazily so the module loads even when overturemaps is absent.
    from overturemaps import core as overture_core

    available_types = overture_core.get_all_overture_types()
    return available_types
264
+
265
+
210
266
def download_overture_buildings(
    bbox: Tuple[float, float, float, float],
    output: str,
    overture_type: str = "building",
    **kwargs: Any,
) -> "gpd.GeoDataFrame":
    """Download building data from Overture Maps for a given bounding box.

    Thin convenience wrapper around :func:`get_overture_data` kept for
    backward compatibility with earlier releases.

    Args:
        bbox: Bounding box in the format (min_lon, min_lat, max_lon, max_lat)
            in WGS84 coordinates.
        output: Path to save the output file.
        overture_type: The Overture Maps data type to download (building, place, etc.).
        **kwargs: Additional keyword arguments forwarded to get_overture_data
            (and ultimately to GeoDataFrame.to_file).

    Returns:
        gpd.GeoDataFrame: The downloaded data. Note: despite the function name
        suggesting a file path, the underlying get_overture_data returns the
        GeoDataFrame; the data is also written to ``output``.
    """
    return get_overture_data(
        overture_type=overture_type, bbox=bbox, output=output, **kwargs
    )
286
+
287
+
288
def get_overture_data(
    overture_type: str,
    bbox: Tuple[float, float, float, float] = None,
    columns: List[str] = None,
    output: str = None,
    **kwargs: Any,
) -> "gpd.GeoDataFrame":
    """Fetches overture data and returns it as a GeoDataFrame.

    Args:
        overture_type (str): The type of overture data to fetch. It can be one of the following:
            address|building|building_part|division|division_area|division_boundary|place|
            segment|connector|infrastructure|land|land_cover|land_use|water
        bbox (Tuple[float, float, float, float], optional): The bounding box to
            filter the data. Defaults to None.
        columns (List[str], optional): The columns to include in the output.
            Defaults to None.
        output (str, optional): The file path to save the output GeoDataFrame.
            Defaults to None, in which case nothing is written to disk.
        **kwargs: Additional keyword arguments passed to GeoDataFrame.to_file().

    Returns:
        gpd.GeoDataFrame: The fetched overture data as a GeoDataFrame.

    Raises:
        ImportError: If the overturemaps package is not installed.
    """
    try:
        from overturemaps import core
    except ImportError:
        raise ImportError("The overturemaps package is required to use this function")

    gdf = core.geodataframe(overture_type, bbox=bbox)
    if columns is not None:
        gdf = gdf[columns]

    # Overture data is distributed in WGS84; tag the frame accordingly.
    gdf.crs = "EPSG:4326"

    if output is not None:
        # Only touch the filesystem when a path was actually supplied:
        # os.path.abspath(None) would raise TypeError, which previously made
        # the documented default output=None unusable.
        out_dir = os.path.dirname(os.path.abspath(output))
        if not os.path.exists(out_dir):
            os.makedirs(out_dir, exist_ok=True)
        gdf.to_file(output, **kwargs)

    return gdf
302
334
 
303
335
 
304
336
  def convert_vector_format(
@@ -361,18 +393,23 @@ def convert_vector_format(
361
393
  raise
362
394
 
363
395
 
364
- def extract_building_stats(geojson_file: str) -> Dict[str, Any]:
396
+ def extract_building_stats(data: str) -> Dict[str, Any]:
365
397
  """Extract statistics from the building data.
366
398
 
367
399
  Args:
368
- geojson_file: Path to the GeoJSON file.
400
+ data: Path to the GeoJSON file or GeoDataFrame containing building data.
369
401
 
370
402
  Returns:
371
403
  Dictionary with statistics.
372
404
  """
373
405
  try:
374
406
  # Read the GeoJSON file
375
- gdf = gpd.read_file(geojson_file)
407
+
408
+ if isinstance(data, gpd.GeoDataFrame):
409
+ gdf = data
410
+ else:
411
+
412
+ gdf = gpd.read_file(data)
376
413
 
377
414
  # Calculate statistics
378
415
  bbox = gdf.total_bounds.tolist()
@@ -903,7 +940,7 @@ def pc_stac_download(
903
940
  from concurrent.futures import ThreadPoolExecutor, as_completed
904
941
 
905
942
  # Handle single item case
906
- if isinstance(items, pystac.Item):
943
+ if isinstance(items, pystac.Item) or isinstance(items, str):
907
944
  items = [items]
908
945
  elif not isinstance(items, list):
909
946
  raise TypeError("items must be a STAC Item or list of STAC Items")
@@ -973,6 +1010,8 @@ def pc_stac_download(
973
1010
 
974
1011
  for item in items:
975
1012
  item_assets = {}
1013
+ if isinstance(item, str):
1014
+ item = pystac.Item.from_file(item)
976
1015
  item_id = item.id
977
1016
  print(f"Processing STAC item: {item_id}")
978
1017
 
geoai/geoai.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Main module."""
2
2
 
3
3
  import leafmap
4
+ import leafmap.maplibregl as maplibregl
4
5
 
5
6
  from .download import (
6
7
  download_naip,
@@ -13,10 +14,11 @@ from .download import (
13
14
  read_pc_item_asset,
14
15
  view_pc_item,
15
16
  )
17
+ from .classify import train_classifier, classify_image, classify_images
16
18
  from .extract import *
17
19
  from .hf import *
18
20
  from .segment import *
19
- from .train import object_detection, train_MaskRCNN_model
21
+ from .train import object_detection, object_detection_batch, train_MaskRCNN_model
20
22
  from .utils import *
21
23
 
22
24
 
@@ -26,3 +28,93 @@ class Map(leafmap.Map):
26
28
  def __init__(self, *args, **kwargs):
27
29
  """Initialize the Map class."""
28
30
  super().__init__(*args, **kwargs)
31
+
32
+
33
class MapLibre(maplibregl.Map):
    """A subclass of maplibregl.Map for GeoAI applications.

    Inherits all behavior from ``leafmap.maplibregl.Map`` unchanged; exists so
    GeoAI has a single place to extend the MapLibre-based map class later.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the MapLibre class.

        Args:
            *args: Positional arguments forwarded to maplibregl.Map.
            **kwargs: Keyword arguments forwarded to maplibregl.Map.
        """
        super().__init__(*args, **kwargs)
40
+
41
def create_vector_data(
    m: Optional[Map] = None,
    properties: Optional[Dict[str, List[Any]]] = None,
    time_format: str = "%Y%m%dT%H%M%S",
    column_widths: Optional[List[int]] = (9, 3),
    map_height: str = "600px",
    out_dir: Optional[str] = None,
    filename_prefix: str = "",
    file_ext: str = "geojson",
    add_mapillary: bool = False,
    style: str = "photo",
    radius: float = 0.00005,
    width: int = 300,
    height: int = 420,
    frame_border: int = 0,
    **kwargs: Any,
):
    """Generates a widget-based interface for creating and managing vector data on a map.

    Builds an interactive widget layout (map plus sidebar) that lets users draw
    features (points, lines, polygons), attach property values to them, and
    export the result as GeoJSON. Delegates entirely to
    ``leafmap.maplibregl.create_vector_data``.

    Args:
        m (Map, optional): An existing Map object. If None, a default map with
            basemaps and drawing controls is created. Defaults to None.
        properties (Dict[str, List[Any]], optional): Mapping of property names to
            the list of allowed values users can assign to drawn features.
            Defaults to None.
        time_format (str, optional): Timestamp format used in exported filenames.
            Defaults to "%Y%m%dT%H%M%S".
        column_widths (Optional[List[int]], optional): Relative widths of the map
            and sidebar columns. Defaults to (9, 3).
        map_height (str, optional): Height of the map widget. Defaults to "600px".
        out_dir (str, optional): Directory for exported GeoJSON files; current
            working directory when None. Defaults to None.
        filename_prefix (str, optional): Prefix for exported filenames. Defaults to "".
        file_ext (str, optional): File extension for exports. Defaults to "geojson".
        add_mapillary (bool, optional): Whether to show a Mapillary image widget
            for the nearest image to a clicked point. Defaults to False.
        style (str, optional): Mapillary widget style: "classic", "photo", or
            "split". Defaults to "photo".
        radius (float, optional): Search radius (degrees) for the nearest
            Mapillary image. Defaults to 0.00005.
        width (int, optional): Width of the Mapillary widget. Defaults to 300.
        height (int, optional): Height of the Mapillary widget. Defaults to 420.
        frame_border (int, optional): Frame border width for the Mapillary
            widget. Defaults to 0.
        **kwargs (Any): Additional keyword arguments forwarded unchanged.

    Returns:
        widgets.VBox: A vertical box widget containing the map, sidebar, and
        control buttons.

    Example:
        >>> properties = {
        ...     "Type": ["Residential", "Commercial", "Industrial"],
        ...     "Area": [100, 200, 300],
        ... }
        >>> widget = create_vector_data(properties=properties)
        >>> display(widget)  # Display the widget in a Jupyter notebook
    """
    # Collect the named options once, then forward them together with any
    # caller-supplied extras (duplicates still raise TypeError, as before).
    options = dict(
        m=m,
        properties=properties,
        time_format=time_format,
        column_widths=column_widths,
        map_height=map_height,
        out_dir=out_dir,
        filename_prefix=filename_prefix,
        file_ext=file_ext,
        add_mapillary=add_mapillary,
        style=style,
        radius=radius,
        width=width,
        height=height,
        frame_border=frame_border,
    )
    return maplibregl.create_vector_data(**options, **kwargs)
geoai/train.py CHANGED
@@ -1,3 +1,4 @@
1
+ import glob
1
2
  import math
2
3
  import os
3
4
  import random
@@ -20,6 +21,8 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
20
21
  from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
21
22
  from tqdm import tqdm
22
23
 
24
+ from .utils import download_model_from_hf
25
+
23
26
 
24
27
  def get_instance_segmentation_model(num_classes=2, num_channels=3, pretrained=True):
25
28
  """
@@ -352,7 +355,9 @@ def collate_fn(batch):
352
355
  return tuple(zip(*batch))
353
356
 
354
357
 
355
- def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
358
+ def train_one_epoch(
359
+ model, optimizer, data_loader, device, epoch, print_freq=10, verbose=True
360
+ ):
356
361
  """
357
362
  Train the model for one epoch.
358
363
 
@@ -363,6 +368,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
363
368
  device (torch.device): Device to train on.
364
369
  epoch (int): Current epoch number.
365
370
  print_freq (int): How often to print progress.
371
+ verbose (bool): Whether to print detailed progress.
366
372
 
367
373
  Returns:
368
374
  float: Average loss for the epoch.
@@ -392,9 +398,10 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
392
398
  # Print progress
393
399
  if i % print_freq == 0:
394
400
  elapsed_time = time.time() - start_time
395
- print(
396
- f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
397
- )
401
+ if verbose:
402
+ print(
403
+ f"Epoch: {epoch}, Batch: {i}/{len(data_loader)}, Loss: {losses.item():.4f}, Time: {elapsed_time:.2f}s"
404
+ )
398
405
  start_time = time.time()
399
406
 
400
407
  # Calculate average loss
@@ -579,6 +586,8 @@ def train_MaskRCNN_model(
579
586
  val_split=0.2,
580
587
  visualize=False,
581
588
  resume_training=False,
589
+ print_freq=10,
590
+ verbose=True,
582
591
  ):
583
592
  """Train and evaluate Mask R-CNN model for instance segmentation.
584
593
 
@@ -605,7 +614,8 @@ def train_MaskRCNN_model(
605
614
  Defaults to False.
606
615
  resume_training (bool): If True and pretrained_model_path is provided,
607
616
  will try to load optimizer and scheduler states as well. Defaults to False.
608
-
617
+ print_freq (int): Frequency of printing training progress. Defaults to 10.
618
+ verbose (bool): If True, prints detailed training progress. Defaults to True.
609
619
  Returns:
610
620
  None: Model weights are saved to output_dir.
611
621
 
@@ -756,7 +766,9 @@ def train_MaskRCNN_model(
756
766
  # Training loop
757
767
  for epoch in range(start_epoch, num_epochs):
758
768
  # Train one epoch
759
- train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
769
+ train_loss = train_one_epoch(
770
+ model, optimizer, train_loader, device, epoch, print_freq, verbose
771
+ )
760
772
 
761
773
  # Update learning rate
762
774
  lr_scheduler.step()
@@ -1107,6 +1119,13 @@ def object_detection(
1107
1119
  model = get_instance_segmentation_model(
1108
1120
  num_classes=2, num_channels=num_channels, pretrained=pretrained
1109
1121
  )
1122
+
1123
+ if not os.path.exists(model_path):
1124
+ try:
1125
+ model_path = download_model_from_hf(model_path)
1126
+ except Exception as e:
1127
+ raise FileNotFoundError(f"Model file not found: {model_path}")
1128
+
1110
1129
  model.load_state_dict(torch.load(model_path, map_location=device))
1111
1130
  model.to(device)
1112
1131
  model.eval()
@@ -1123,3 +1142,93 @@ def object_detection(
1123
1142
  device=device,
1124
1143
  **kwargs,
1125
1144
  )
1145
+
1146
+
1147
+ def object_detection_batch(
1148
+ input_paths,
1149
+ output_dir,
1150
+ model_path,
1151
+ filenames=None,
1152
+ window_size=512,
1153
+ overlap=256,
1154
+ confidence_threshold=0.5,
1155
+ batch_size=4,
1156
+ num_channels=3,
1157
+ pretrained=True,
1158
+ device=None,
1159
+ **kwargs,
1160
+ ):
1161
+ """
1162
+ Perform object detection on a GeoTIFF using a pre-trained Mask R-CNN model.
1163
+
1164
+ Args:
1165
+ input_paths (str or list): Path(s) to input GeoTIFF file(s). If a directory is provided,
1166
+ all .tif files in that directory will be processed.
1167
+ output_dir (str): Directory to save output mask GeoTIFF files.
1168
+ model_path (str): Path to trained model weights.
1169
+ filenames (list, optional): List of output filenames. If None, defaults to
1170
+ "<input_filename>_mask.tif" for each input file.
1171
+ If provided, must match the number of input files.
1172
+ window_size (int): Size of sliding window for inference.
1173
+ overlap (int): Overlap between adjacent windows.
1174
+ confidence_threshold (float): Confidence threshold for predictions (0-1).
1175
+ batch_size (int): Batch size for inference.
1176
+ num_channels (int): Number of channels in the input image and model.
1177
+ pretrained (bool): Whether to use pretrained backbone for model loading.
1178
+ device (torch.device, optional): Device to run inference on. If None, uses CUDA if available.
1179
+ **kwargs: Additional arguments passed to inference_on_geotiff.
1180
+
1181
+ Returns:
1182
+ None: Output mask is saved to output_path.
1183
+ """
1184
+ # Load your trained model
1185
+ if device is None:
1186
+ device = (
1187
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
1188
+ )
1189
+ model = get_instance_segmentation_model(
1190
+ num_classes=2, num_channels=num_channels, pretrained=pretrained
1191
+ )
1192
+
1193
+ if not os.path.exists(output_dir):
1194
+ os.makedirs(output_dir, exist_ok=True)
1195
+
1196
+ if not os.path.exists(model_path):
1197
+ try:
1198
+ model_path = download_model_from_hf(model_path)
1199
+ except Exception as e:
1200
+ raise FileNotFoundError(f"Model file not found: {model_path}")
1201
+
1202
+ model.load_state_dict(torch.load(model_path, map_location=device))
1203
+ model.to(device)
1204
+ model.eval()
1205
+
1206
+ if isinstance(input_paths, str) and (not input_paths.endswith(".tif")):
1207
+ files = glob.glob(os.path.join(input_paths, "*.tif"))
1208
+ files.sort()
1209
+ elif isinstance(input_paths, str):
1210
+ files = [input_paths]
1211
+
1212
+ if filenames is None:
1213
+ filenames = [
1214
+ os.path.join(output_dir, os.path.basename(f).replace(".tif", "_mask.tif"))
1215
+ for f in files
1216
+ ]
1217
+ else:
1218
+ if len(filenames) != len(files):
1219
+ raise ValueError("Number of filenames must match number of input files.")
1220
+
1221
+ for index, file in enumerate(files):
1222
+ print(f"Processing file {index + 1}/{len(files)}: {file}")
1223
+ inference_on_geotiff(
1224
+ model=model,
1225
+ geotiff_path=file,
1226
+ output_path=filenames[index],
1227
+ window_size=window_size, # Adjust based on your model and memory
1228
+ overlap=overlap, # Overlap to avoid edge artifacts
1229
+ confidence_threshold=confidence_threshold,
1230
+ batch_size=batch_size, # Adjust based on your GPU memory
1231
+ num_channels=num_channels,
1232
+ device=device,
1233
+ **kwargs,
1234
+ )