geoai-py 0.2.1__py2.py3-none-any.whl → 0.2.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.2.1"
5
+ __version__ = "0.2.3"
6
6
 
7
7
 
8
8
  import os
geoai/common.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  from collections.abc import Iterable
5
5
  from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
6
6
  import matplotlib.pyplot as plt
7
-
7
+ import geopandas as gpd
8
8
  import leafmap
9
9
  import torch
10
10
  import numpy as np
@@ -17,7 +17,7 @@ from torchgeo.samplers import RandomGeoSampler, Units
17
17
  from torchgeo.transforms import indices
18
18
 
19
19
 
20
- def viz_raster(
20
+ def view_raster(
21
21
  source: str,
22
22
  indexes: Optional[int] = None,
23
23
  colormap: Optional[str] = None,
@@ -85,7 +85,7 @@ def viz_raster(
85
85
  return m
86
86
 
87
87
 
88
- def viz_image(
88
+ def view_image(
89
89
  image: Union[np.ndarray, torch.Tensor],
90
90
  transpose: bool = False,
91
91
  bdx: Optional[int] = None,
@@ -114,6 +114,8 @@ def viz_image(
114
114
 
115
115
  if isinstance(image, torch.Tensor):
116
116
  image = image.cpu().numpy()
117
+ elif isinstance(image, str):
118
+ image = rio.open(image).read().transpose(1, 2, 0)
117
119
 
118
120
  plt.figure(figsize=figsize)
119
121
 
@@ -434,3 +436,188 @@ def dict_to_image(
434
436
  else:
435
437
  image = leafmap.array_to_image(da, **kwargs)
436
438
  return image
439
+
440
+
441
def view_vector(
    vector_data,
    column=None,
    cmap="viridis",
    figsize=(10, 10),
    title=None,
    legend=True,
    basemap=False,
    alpha=0.7,
    edge_color="black",
    classification="quantiles",
    n_classes=5,
    highlight_index=None,
    highlight_color="red",
    scheme=None,
    save_path=None,
    dpi=300,
):
    """
    Visualize vector datasets with options for styling, classification, basemaps and more.

    This function visualizes GeoDataFrame objects with customizable symbology.
    It supports different vector types (points, lines, polygons), attribute-based
    classification, and background basemaps.

    Args:
        vector_data (geopandas.GeoDataFrame | str | os.PathLike): The vector
            dataset to visualize, or a path to a file that
            ``geopandas.read_file`` can open.
        column (str, optional): Column to use for choropleth mapping. If None,
            a single color will be used. Numeric columns are classified using
            ``scheme``/``classification``. Non-numeric columns are drawn as
            categories unless ``scheme`` is given explicitly. Defaults to None.
        cmap (str or matplotlib.colors.Colormap, optional): Colormap to use for
            choropleth mapping. Defaults to "viridis".
        figsize (tuple, optional): Figure size as (width, height) in inches.
            Defaults to (10, 10).
        title (str, optional): Title for the plot. Defaults to None.
        legend (bool, optional): Whether to display a legend when ``column``
            is given. Defaults to True.
        basemap (bool, optional): Whether to add a web basemap. Requires the
            optional ``contextily`` package. Defaults to False.
        alpha (float, optional): Transparency of the vector features, between 0-1.
            Defaults to 0.7.
        edge_color (str, optional): Color for point and polygon edges.
            Defaults to "black".
        classification (str, optional): Classification method for numeric
            choropleth maps. Options: "quantiles", "equal_interval",
            "natural_breaks". Ignored when ``scheme`` is provided or the column
            is not numeric. Defaults to "quantiles".
        n_classes (int, optional): Number of classes for classified choropleth
            maps. Defaults to 5.
        highlight_index (int or list of int, optional): Positional index, or
            list of positional indices (as used by ``iloc``), of features to
            highlight. Defaults to None.
        highlight_color (str, optional): Color to use for highlighted features.
            Defaults to "red".
        scheme (str, optional): MapClassify classification scheme. Overrides
            the ``classification`` parameter if provided. Defaults to None.
        save_path (str, optional): Path to save the figure. If None, the figure
            is not saved. Defaults to None.
        dpi (int, optional): DPI for saved figure. Defaults to 300.

    Returns:
        matplotlib.axes.Axes: The Axes object containing the plot.

    Raises:
        TypeError: If ``vector_data`` is neither a GeoDataFrame nor a path to one.
        ValueError: If ``classification`` is not one of the supported options
            and it would be used (numeric ``column``, no ``scheme``).
        ImportError: If ``basemap`` is True and contextily is not installed.

    Examples:
        >>> import geopandas as gpd
        >>> cities = gpd.read_file("cities.shp")
        >>> view_vector(cities, "population", cmap="Reds", basemap=True)

        >>> roads = gpd.read_file("roads.shp")
        >>> view_vector(roads, "type", basemap=True, figsize=(12, 8))
    """
    import numbers
    import os
    import warnings

    from pandas.api.types import is_numeric_dtype

    # Friendly names for the mapclassify schemes accepted by GeoDataFrame.plot.
    classification_schemes = {
        "quantiles": "quantiles",
        "equal_interval": "equal_interval",
        "natural_breaks": "fisher_jenks",
    }

    if isinstance(vector_data, (str, os.PathLike)):
        vector_data = gpd.read_file(vector_data)

    # Check if input is a GeoDataFrame
    if not isinstance(vector_data, gpd.GeoDataFrame):
        raise TypeError("Input data must be a GeoDataFrame")

    # Make a copy to avoid changing the original data
    gdf = vector_data.copy()

    fig, ax = plt.subplots(figsize=figsize)

    # Pick type-specific styling from the first non-missing geometry. An empty
    # frame (or one whose geometries are all missing) gets no extra styling
    # instead of failing with IndexError/AttributeError.
    geom_types = gdf.geometry.geom_type.dropna()
    geom_type = geom_types.iloc[0] if not geom_types.empty else ""

    plot_kwargs = {"alpha": alpha, "ax": ax}

    # Substring checks also match the Multi* variants (e.g. "MultiPolygon").
    if "Point" in geom_type:
        plot_kwargs["markersize"] = 50
        plot_kwargs["edgecolor"] = edge_color
    elif "Line" in geom_type:
        plot_kwargs["linewidth"] = 1
    elif "Polygon" in geom_type:
        plot_kwargs["edgecolor"] = edge_color

    if column is not None:
        plot_kwargs["cmap"] = cmap
        plot_kwargs["column"] = column
        plot_kwargs["legend"] = legend

        if scheme is not None:
            # An explicit mapclassify scheme always wins.
            plot_kwargs["scheme"] = scheme
            plot_kwargs["k"] = n_classes
        elif is_numeric_dtype(gdf[column]):
            if classification not in classification_schemes:
                raise ValueError(
                    f"Unknown classification {classification!r}; expected one of "
                    f"{sorted(classification_schemes)}"
                )
            plot_kwargs["scheme"] = classification_schemes[classification]
            plot_kwargs["k"] = n_classes
        else:
            # mapclassify cannot bin strings/objects; draw them as categories.
            plot_kwargs["categorical"] = True

    gdf.plot(**plot_kwargs)

    if highlight_index is not None:
        # A bare integer would make iloc return a Series, not a GeoDataFrame.
        if isinstance(highlight_index, numbers.Integral):
            highlight_index = [highlight_index]
        gdf.iloc[highlight_index].plot(
            ax=ax, color=highlight_color, edgecolor="black", linewidth=2, zorder=5
        )

    if basemap:
        # contextily is an optional dependency only needed for basemaps.
        import contextily as ctx

        try:
            ctx.add_basemap(ax, crs=gdf.crs, source=ctx.providers.OpenStreetMap.Mapnik)
        except Exception as e:
            # Tile download or CRS problems should not discard the plot itself.
            warnings.warn(f"Could not add basemap: {e}")

    if title:
        ax.set_title(title, fontsize=14)

    ax.set_axis_off()

    # Operate on this function's figure rather than the implicit current one.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return ax
587
+
588
+
589
def view_vector_interactive(
    vector_data,
    **kwargs,
):
    """
    Visualize a vector dataset on an interactive web map.

    This is a thin wrapper around ``geopandas.GeoDataFrame.explore``. Styling,
    attribute coloring and tile options are controlled entirely through
    ``kwargs``. This function adds only input loading and type checking.

    Args:
        vector_data (geopandas.GeoDataFrame | str | os.PathLike): The vector
            dataset to visualize, or a path to a file that
            ``geopandas.read_file`` can open.
        **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
            See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html

    Returns:
        folium.Map: The map object returned by ``GeoDataFrame.explore``.

    Raises:
        TypeError: If ``vector_data`` is neither a GeoDataFrame nor a path to one.

    Examples:
        >>> import geopandas as gpd
        >>> cities = gpd.read_file("cities.shp")
        >>> view_vector_interactive(cities)

        >>> roads = gpd.read_file("roads.shp")
        >>> view_vector_interactive(roads, column="type", width=800, height=600)
    """
    import os

    if isinstance(vector_data, (str, os.PathLike)):
        vector_data = gpd.read_file(vector_data)

    # Check if input is a GeoDataFrame
    if not isinstance(vector_data, gpd.GeoDataFrame):
        raise TypeError("Input data must be a GeoDataFrame")

    return vector_data.explore(**kwargs)
geoai/extract.py CHANGED
@@ -198,7 +198,7 @@ class BuildingFootprintExtractor:
198
198
 
199
199
  # Define the repository ID and model filename
200
200
  repo_id = "giswqs/geoai" # Update with your actual username/repo
201
- filename = "usa_building_footprints.pth"
201
+ filename = "building_footprints_usa.pth"
202
202
 
203
203
  # Ensure cache directory exists
204
204
  # cache_dir = os.path.join(
@@ -718,7 +718,7 @@ class BuildingFootprintExtractor:
718
718
  if "confidence" in gdf.columns:
719
719
  # Create a colorbar legend
720
720
  sm = plt.cm.ScalarMappable(
721
- cmap=plt.cm.viridis,
721
+ cmap=plt.get_cmap("viridis"),
722
722
  norm=plt.Normalize(gdf.confidence.min(), gdf.confidence.max()),
723
723
  )
724
724
  sm.set_array([])
geoai/geoai.py CHANGED
@@ -1,3 +1,5 @@
1
1
  """Main module."""
2
2
 
3
- from .common import viz_raster, viz_image, plot_batch, calc_stats
3
+ from .common import *
4
+ from .preprocess import *
5
+ from .extract import *