geoai-py 0.2.1__py2.py3-none-any.whl → 0.2.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/common.py +190 -3
- geoai/extract.py +2 -2
- geoai/geoai.py +3 -1
- geoai/preprocess.py +872 -116
- {geoai_py-0.2.1.dist-info → geoai_py-0.2.3.dist-info}/METADATA +5 -1
- geoai_py-0.2.3.dist-info/RECORD +13 -0
- geoai_py-0.2.1.dist-info/RECORD +0 -13
- {geoai_py-0.2.1.dist-info → geoai_py-0.2.3.dist-info}/LICENSE +0 -0
- {geoai_py-0.2.1.dist-info → geoai_py-0.2.3.dist-info}/WHEEL +0 -0
- {geoai_py-0.2.1.dist-info → geoai_py-0.2.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.2.1.dist-info → geoai_py-0.2.3.dist-info}/top_level.txt +0 -0
geoai/__init__.py
CHANGED
geoai/common.py
CHANGED
|
@@ -4,7 +4,7 @@ import os
|
|
|
4
4
|
from collections.abc import Iterable
|
|
5
5
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Callable
|
|
6
6
|
import matplotlib.pyplot as plt
|
|
7
|
-
|
|
7
|
+
import geopandas as gpd
|
|
8
8
|
import leafmap
|
|
9
9
|
import torch
|
|
10
10
|
import numpy as np
|
|
@@ -17,7 +17,7 @@ from torchgeo.samplers import RandomGeoSampler, Units
|
|
|
17
17
|
from torchgeo.transforms import indices
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
def
|
|
20
|
+
def view_raster(
|
|
21
21
|
source: str,
|
|
22
22
|
indexes: Optional[int] = None,
|
|
23
23
|
colormap: Optional[str] = None,
|
|
@@ -85,7 +85,7 @@ def viz_raster(
|
|
|
85
85
|
return m
|
|
86
86
|
|
|
87
87
|
|
|
88
|
-
def
|
|
88
|
+
def view_image(
|
|
89
89
|
image: Union[np.ndarray, torch.Tensor],
|
|
90
90
|
transpose: bool = False,
|
|
91
91
|
bdx: Optional[int] = None,
|
|
@@ -114,6 +114,8 @@ def viz_image(
|
|
|
114
114
|
|
|
115
115
|
if isinstance(image, torch.Tensor):
|
|
116
116
|
image = image.cpu().numpy()
|
|
117
|
+
elif isinstance(image, str):
|
|
118
|
+
image = rio.open(image).read().transpose(1, 2, 0)
|
|
117
119
|
|
|
118
120
|
plt.figure(figsize=figsize)
|
|
119
121
|
|
|
@@ -434,3 +436,188 @@ def dict_to_image(
|
|
|
434
436
|
else:
|
|
435
437
|
image = leafmap.array_to_image(da, **kwargs)
|
|
436
438
|
return image
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def view_vector(
    vector_data,
    column=None,
    cmap="viridis",
    figsize=(10, 10),
    title=None,
    legend=True,
    basemap=False,
    alpha=0.7,
    edge_color="black",
    classification="quantiles",
    n_classes=5,
    highlight_index=None,
    highlight_color="red",
    scheme=None,
    save_path=None,
    dpi=300,
):
    """
    Visualize vector datasets with options for styling, classification, basemaps and more.

    This function visualizes GeoDataFrame objects with customizable symbology.
    It supports different vector types (points, lines, polygons), attribute-based
    classification, and background basemaps.

    Args:
        vector_data (geopandas.GeoDataFrame | str | os.PathLike): The vector dataset
            to visualize, or a path to a file readable by ``geopandas.read_file``.
        column (str, optional): Column to use for choropleth mapping. If None,
            a single color will be used. Numeric columns are classified with
            ``scheme``/``classification``; non-numeric columns are drawn as
            categories. Defaults to None.
        cmap (str or matplotlib.colors.Colormap, optional): Colormap to use for
            choropleth mapping. Defaults to "viridis".
        figsize (tuple, optional): Figure size as (width, height) in inches.
            Defaults to (10, 10).
        title (str, optional): Title for the plot. Defaults to None.
        legend (bool, optional): Whether to display a legend (only used when
            ``column`` is given). Defaults to True.
        basemap (bool, optional): Whether to add a web basemap. Requires contextily,
            which is imported only when this is True. Defaults to False.
        alpha (float, optional): Transparency of the vector features, between 0-1.
            Defaults to 0.7.
        edge_color (str, optional): Color for point and polygon edges.
            Defaults to "black".
        classification (str, optional): Classification method for numeric
            choropleth maps. Options: "quantiles", "equal_interval",
            "natural_breaks". Any other value yields an unclassified continuous
            colormap. Defaults to "quantiles".
        n_classes (int, optional): Number of classes for classified choropleth maps.
            Defaults to 5.
        highlight_index (int or list, optional): Positional index, or list of
            positional indices, of features to highlight. Defaults to None.
        highlight_color (str, optional): Color to use for highlighted features.
            Defaults to "red".
        scheme (str, optional): MapClassify classification scheme. Overrides
            classification parameter if provided. Defaults to None.
        save_path (str, optional): Path to save the figure. If None, the figure
            is not saved. Defaults to None.
        dpi (int, optional): DPI for saved figure. Defaults to 300.

    Returns:
        matplotlib.axes.Axes: The Axes object containing the plot.

    Raises:
        TypeError: If ``vector_data`` is neither a path nor a GeoDataFrame.
        ImportError: If ``basemap`` is True and contextily is not installed.

    Examples:
        >>> import geopandas as gpd
        >>> cities = gpd.read_file("cities.shp")
        >>> view_vector(cities, "population", cmap="Reds", basemap=True)

        >>> roads = gpd.read_file("roads.shp")
        >>> view_vector(roads, "type", basemap=True, figsize=(12, 8))
    """
    import os

    import pandas as pd

    if isinstance(vector_data, (str, os.PathLike)):
        vector_data = gpd.read_file(vector_data)

    # Check if input is a GeoDataFrame
    if not isinstance(vector_data, gpd.GeoDataFrame):
        raise TypeError("Input data must be a GeoDataFrame")

    # Make a copy to avoid changing the original data
    gdf = vector_data.copy()

    # Set up figure and axis
    fig, ax = plt.subplots(figsize=figsize)

    # Determine geometry type from the first non-missing geometry; indexing
    # row 0 directly fails on empty frames or when the first geometry is None.
    valid_geoms = gdf.geometry.dropna()
    geom_type = valid_geoms.iloc[0].geom_type if len(valid_geoms) else ""

    # Plotting parameters
    plot_kwargs = {"alpha": alpha, "ax": ax}

    # Set up keyword arguments based on geometry type
    if "Point" in geom_type:
        plot_kwargs["markersize"] = 50
        plot_kwargs["edgecolor"] = edge_color
    elif "Line" in geom_type:
        plot_kwargs["linewidth"] = 1
    elif "Polygon" in geom_type:
        plot_kwargs["edgecolor"] = edge_color

    # User-facing classification names -> mapclassify scheme names.
    classification_schemes = {
        "quantiles": "quantiles",
        "equal_interval": "equal_interval",
        "natural_breaks": "fisher_jenks",
    }

    # Classification options
    if column is not None:
        if scheme is not None:
            # An explicit mapclassify scheme always wins.
            plot_kwargs["scheme"] = scheme
        elif pd.api.types.is_numeric_dtype(gdf[column]):
            mapped_scheme = classification_schemes.get(classification)
            if mapped_scheme is not None:
                plot_kwargs["scheme"] = mapped_scheme
        else:
            # mapclassify cannot bin non-numeric values; draw them as categories.
            plot_kwargs["categorical"] = True

        if "scheme" in plot_kwargs:
            plot_kwargs["k"] = n_classes
        plot_kwargs["cmap"] = cmap
        plot_kwargs["column"] = column
        plot_kwargs["legend"] = legend

    # Plot the main data
    gdf.plot(**plot_kwargs)

    # Highlight specific features if requested
    if highlight_index is not None:
        # A scalar iloc returns a plain row Series (no geometry plot), so wrap it.
        if pd.api.types.is_integer(highlight_index):
            highlight_index = [highlight_index]
        gdf.iloc[highlight_index].plot(
            ax=ax, color=highlight_color, edgecolor="black", linewidth=2, zorder=5
        )

    # Add basemap if requested; contextily is only required in this case.
    if basemap:
        import contextily as ctx

        try:
            ctx.add_basemap(ax, crs=gdf.crs, source=ctx.providers.OpenStreetMap.Mapnik)
        except Exception as e:
            # Best-effort: tile download or CRS problems should not abort plotting.
            print(f"Could not add basemap: {e}")

    # Set title if provided
    if title:
        ax.set_title(title, fontsize=14)

    # Remove axes if not needed
    ax.set_axis_off()

    # Adjust layout
    fig.tight_layout()

    # Save figure if a path is provided
    if save_path:
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")

    return ax
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def view_vector_interactive(
    vector_data,
    **kwargs,
):
    """
    Visualize a vector dataset on an interactive web map.

    This is a thin wrapper around ``GeoDataFrame.explore()``. It loads the data
    from a file path if one is given, checks that the result is a GeoDataFrame,
    and forwards every keyword argument to ``explore()`` unchanged.

    Args:
        vector_data (geopandas.GeoDataFrame | str | os.PathLike): The vector dataset
            to visualize, or a path to a file readable by ``geopandas.read_file``.
        **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore()
            (for example ``column``, ``cmap``, ``tiles``, ``width``, ``height``).
            See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html

    Returns:
        folium.Map: The map object with the vector data added.

    Raises:
        TypeError: If ``vector_data`` is neither a path nor a GeoDataFrame.

    Examples:
        >>> import geopandas as gpd
        >>> cities = gpd.read_file("cities.shp")
        >>> view_vector_interactive(cities)

        >>> roads = gpd.read_file("roads.shp")
        >>> view_vector_interactive(roads, column="type", width=800, height=600)
    """
    import os

    if isinstance(vector_data, (str, os.PathLike)):
        vector_data = gpd.read_file(vector_data)

    # Check if input is a GeoDataFrame
    if not isinstance(vector_data, gpd.GeoDataFrame):
        raise TypeError("Input data must be a GeoDataFrame")

    return vector_data.explore(**kwargs)
|
geoai/extract.py
CHANGED
|
@@ -198,7 +198,7 @@ class BuildingFootprintExtractor:
|
|
|
198
198
|
|
|
199
199
|
# Define the repository ID and model filename
|
|
200
200
|
repo_id = "giswqs/geoai" # Update with your actual username/repo
|
|
201
|
-
filename = "
|
|
201
|
+
filename = "building_footprints_usa.pth"
|
|
202
202
|
|
|
203
203
|
# Ensure cache directory exists
|
|
204
204
|
# cache_dir = os.path.join(
|
|
@@ -718,7 +718,7 @@ class BuildingFootprintExtractor:
|
|
|
718
718
|
if "confidence" in gdf.columns:
|
|
719
719
|
# Create a colorbar legend
|
|
720
720
|
sm = plt.cm.ScalarMappable(
|
|
721
|
-
cmap=plt.
|
|
721
|
+
cmap=plt.get_cmap("viridis"),
|
|
722
722
|
norm=plt.Normalize(gdf.confidence.min(), gdf.confidence.max()),
|
|
723
723
|
)
|
|
724
724
|
sm.set_array([])
|
geoai/geoai.py
CHANGED