geoai-py 0.4.1__py2.py3-none-any.whl → 0.4.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/download.py +751 -3
- geoai/extract.py +671 -46
- geoai/geoai.py +22 -1
- geoai/train.py +98 -12
- geoai/utils.py +240 -8
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/METADATA +6 -6
- geoai_py-0.4.3.dist-info/RECORD +15 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/WHEEL +1 -1
- geoai_py-0.4.1.dist-info/RECORD +0 -15
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info/licenses}/LICENSE +0 -0
- {geoai_py-0.4.1.dist-info → geoai_py-0.4.3.dist-info}/top_level.txt +0 -0
geoai/geoai.py
CHANGED
|
@@ -1,7 +1,28 @@
|
|
|
1
1
|
"""Main module."""
|
|
2
2
|
|
|
3
|
+
import leafmap
|
|
4
|
+
|
|
5
|
+
from .download import (
|
|
6
|
+
download_naip,
|
|
7
|
+
download_overture_buildings,
|
|
8
|
+
download_pc_stac_item,
|
|
9
|
+
pc_collection_list,
|
|
10
|
+
pc_item_asset_list,
|
|
11
|
+
pc_stac_search,
|
|
12
|
+
pc_stac_download,
|
|
13
|
+
read_pc_item_asset,
|
|
14
|
+
view_pc_item,
|
|
15
|
+
)
|
|
3
16
|
from .extract import *
|
|
4
17
|
from .hf import *
|
|
5
18
|
from .segment import *
|
|
19
|
+
from .train import object_detection, train_MaskRCNN_model
|
|
6
20
|
from .utils import *
|
|
7
|
-
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Map(leafmap.Map):
    """Interactive map class for GeoAI, derived from ``leafmap.Map``.

    No behavior is added on top of the parent class: every positional and
    keyword argument given to the constructor is passed through unchanged to
    ``leafmap.Map``.
    """

    def __init__(self, *args, **kwargs):
        """Build the map by delegating all arguments to ``leafmap.Map``."""
        super(Map, self).__init__(*args, **kwargs)
|
geoai/train.py
CHANGED
|
@@ -571,33 +571,51 @@ def train_MaskRCNN_model(
|
|
|
571
571
|
output_dir,
|
|
572
572
|
num_channels=3,
|
|
573
573
|
pretrained=True,
|
|
574
|
+
pretrained_model_path=None,
|
|
574
575
|
batch_size=4,
|
|
575
576
|
num_epochs=10,
|
|
576
577
|
learning_rate=0.005,
|
|
577
578
|
seed=42,
|
|
578
579
|
val_split=0.2,
|
|
579
580
|
visualize=False,
|
|
581
|
+
resume_training=False,
|
|
580
582
|
):
|
|
581
|
-
"""
|
|
582
|
-
|
|
583
|
+
"""Train and evaluate Mask R-CNN model for instance segmentation.
|
|
584
|
+
|
|
585
|
+
This function trains a Mask R-CNN model for instance segmentation using the
|
|
586
|
+
provided dataset. It supports loading a pretrained model to either initialize
|
|
587
|
+
the backbone or to continue training from a specific checkpoint.
|
|
583
588
|
|
|
584
589
|
Args:
|
|
585
590
|
images_dir (str): Directory containing image GeoTIFF files.
|
|
586
591
|
labels_dir (str): Directory containing label GeoTIFF files.
|
|
587
592
|
output_dir (str): Directory to save model checkpoints and results.
|
|
588
593
|
num_channels (int, optional): Number of input channels. If None, auto-detected.
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
594
|
+
Defaults to 3.
|
|
595
|
+
pretrained (bool): Whether to use pretrained backbone. This is ignored if
|
|
596
|
+
pretrained_model_path is provided. Defaults to True.
|
|
597
|
+
pretrained_model_path (str, optional): Path to a .pth file to load as a
|
|
598
|
+
pretrained model for continued training. Defaults to None.
|
|
599
|
+
batch_size (int): Batch size for training. Defaults to 4.
|
|
600
|
+
num_epochs (int): Number of training epochs. Defaults to 10.
|
|
601
|
+
learning_rate (float): Initial learning rate. Defaults to 0.005.
|
|
602
|
+
seed (int): Random seed for reproducibility. Defaults to 42.
|
|
603
|
+
val_split (float): Fraction of data to use for validation (0-1). Defaults to 0.2.
|
|
595
604
|
visualize (bool): Whether to generate visualizations of model predictions.
|
|
605
|
+
Defaults to False.
|
|
606
|
+
resume_training (bool): If True and pretrained_model_path is provided,
|
|
607
|
+
will try to load optimizer and scheduler states as well. Defaults to False.
|
|
596
608
|
|
|
597
609
|
Returns:
|
|
598
610
|
None: Model weights are saved to output_dir.
|
|
611
|
+
|
|
612
|
+
Raises:
|
|
613
|
+
FileNotFoundError: If pretrained_model_path is provided but file doesn't exist.
|
|
614
|
+
RuntimeError: If there's an issue loading the pretrained model.
|
|
599
615
|
"""
|
|
600
616
|
|
|
617
|
+
import datetime
|
|
618
|
+
|
|
601
619
|
# Set random seeds for reproducibility
|
|
602
620
|
torch.manual_seed(seed)
|
|
603
621
|
np.random.seed(seed)
|
|
@@ -694,9 +712,49 @@ def train_MaskRCNN_model(
|
|
|
694
712
|
# Set up learning rate scheduler
|
|
695
713
|
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)
|
|
696
714
|
|
|
697
|
-
#
|
|
715
|
+
# Initialize training variables
|
|
716
|
+
start_epoch = 0
|
|
698
717
|
best_iou = 0
|
|
699
|
-
|
|
718
|
+
|
|
719
|
+
# Load pretrained model if provided
|
|
720
|
+
if pretrained_model_path:
|
|
721
|
+
if not os.path.exists(pretrained_model_path):
|
|
722
|
+
raise FileNotFoundError(
|
|
723
|
+
f"Pretrained model file not found: {pretrained_model_path}"
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
print(f"Loading pretrained model from: {pretrained_model_path}")
|
|
727
|
+
try:
|
|
728
|
+
# Check if it's a full checkpoint or just model weights
|
|
729
|
+
checkpoint = torch.load(pretrained_model_path, map_location=device)
|
|
730
|
+
|
|
731
|
+
if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
|
|
732
|
+
# It's a checkpoint with extra information
|
|
733
|
+
model.load_state_dict(checkpoint["model_state_dict"])
|
|
734
|
+
|
|
735
|
+
if resume_training:
|
|
736
|
+
# Resume from checkpoint
|
|
737
|
+
start_epoch = checkpoint.get("epoch", 0) + 1
|
|
738
|
+
best_iou = checkpoint.get("best_iou", 0)
|
|
739
|
+
|
|
740
|
+
if "optimizer_state_dict" in checkpoint:
|
|
741
|
+
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
742
|
+
|
|
743
|
+
if "scheduler_state_dict" in checkpoint:
|
|
744
|
+
lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
|
|
745
|
+
|
|
746
|
+
print(f"Resuming training from epoch {start_epoch}")
|
|
747
|
+
print(f"Previous best IoU: {best_iou:.4f}")
|
|
748
|
+
else:
|
|
749
|
+
# Assume it's just the model weights
|
|
750
|
+
model.load_state_dict(checkpoint)
|
|
751
|
+
|
|
752
|
+
print("Pretrained model loaded successfully")
|
|
753
|
+
except Exception as e:
|
|
754
|
+
raise RuntimeError(f"Failed to load pretrained model: {str(e)}")
|
|
755
|
+
|
|
756
|
+
# Training loop
|
|
757
|
+
for epoch in range(start_epoch, num_epochs):
|
|
700
758
|
# Train one epoch
|
|
701
759
|
train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
|
|
702
760
|
|
|
@@ -718,7 +776,7 @@ def train_MaskRCNN_model(
|
|
|
718
776
|
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
|
|
719
777
|
|
|
720
778
|
# Save checkpoint every 10 epochs
|
|
721
|
-
if (epoch + 1) % 10 == 0:
|
|
779
|
+
if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1:
|
|
722
780
|
torch.save(
|
|
723
781
|
{
|
|
724
782
|
"epoch": epoch,
|
|
@@ -733,6 +791,18 @@ def train_MaskRCNN_model(
|
|
|
733
791
|
# Save final model
|
|
734
792
|
torch.save(model.state_dict(), os.path.join(output_dir, "final_model.pth"))
|
|
735
793
|
|
|
794
|
+
# Save full checkpoint of final state
|
|
795
|
+
torch.save(
|
|
796
|
+
{
|
|
797
|
+
"epoch": num_epochs - 1,
|
|
798
|
+
"model_state_dict": model.state_dict(),
|
|
799
|
+
"optimizer_state_dict": optimizer.state_dict(),
|
|
800
|
+
"scheduler_state_dict": lr_scheduler.state_dict(),
|
|
801
|
+
"best_iou": best_iou,
|
|
802
|
+
},
|
|
803
|
+
os.path.join(output_dir, "final_checkpoint.pth"),
|
|
804
|
+
)
|
|
805
|
+
|
|
736
806
|
# Load best model for evaluation and visualization
|
|
737
807
|
model.load_state_dict(torch.load(os.path.join(output_dir, "best_model.pth")))
|
|
738
808
|
|
|
@@ -752,7 +822,23 @@ def train_MaskRCNN_model(
|
|
|
752
822
|
num_samples=5,
|
|
753
823
|
output_dir=os.path.join(output_dir, "visualizations"),
|
|
754
824
|
)
|
|
755
|
-
|
|
825
|
+
|
|
826
|
+
# Save training summary
|
|
827
|
+
with open(os.path.join(output_dir, "training_summary.txt"), "w") as f:
|
|
828
|
+
f.write(
|
|
829
|
+
f"Training completed on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
|
|
830
|
+
)
|
|
831
|
+
f.write(f"Total epochs: {num_epochs}\n")
|
|
832
|
+
f.write(f"Best validation IoU: {best_iou:.4f}\n")
|
|
833
|
+
f.write(f"Final validation IoU: {final_metrics['IoU']:.4f}\n")
|
|
834
|
+
f.write(f"Final validation loss: {final_metrics['loss']:.4f}\n")
|
|
835
|
+
|
|
836
|
+
if pretrained_model_path:
|
|
837
|
+
f.write(f"Started from pretrained model: {pretrained_model_path}\n")
|
|
838
|
+
if resume_training:
|
|
839
|
+
f.write(f"Resumed training from epoch {start_epoch}\n")
|
|
840
|
+
|
|
841
|
+
print(f"Training complete! Trained model saved to {output_dir}")
|
|
756
842
|
|
|
757
843
|
|
|
758
844
|
def inference_on_geotiff(
|
geoai/utils.py
CHANGED
|
@@ -6,6 +6,7 @@ import math
|
|
|
6
6
|
import os
|
|
7
7
|
import warnings
|
|
8
8
|
import xml.etree.ElementTree as ET
|
|
9
|
+
from collections import OrderedDict
|
|
9
10
|
from collections.abc import Iterable
|
|
10
11
|
from pathlib import Path
|
|
11
12
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
@@ -29,6 +30,7 @@ from rasterio.plot import show
|
|
|
29
30
|
from rasterio.windows import Window
|
|
30
31
|
from shapely.affinity import rotate
|
|
31
32
|
from shapely.geometry import MultiPolygon, Polygon, box, mapping, shape
|
|
33
|
+
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50
|
|
32
34
|
from torchvision.transforms import RandomRotation
|
|
33
35
|
from tqdm import tqdm
|
|
34
36
|
|
|
@@ -1986,6 +1988,7 @@ def raster_to_vector(
|
|
|
1986
1988
|
simplify_tolerance=None,
|
|
1987
1989
|
class_values=None,
|
|
1988
1990
|
attribute_name="class",
|
|
1991
|
+
unique_attribute_value=False,
|
|
1989
1992
|
output_format="geojson",
|
|
1990
1993
|
plot_result=False,
|
|
1991
1994
|
):
|
|
@@ -2000,6 +2003,7 @@ def raster_to_vector(
|
|
|
2000
2003
|
simplify_tolerance (float): Tolerance for geometry simplification. None for no simplification.
|
|
2001
2004
|
class_values (list): Specific pixel values to vectorize. If None, all values > threshold are vectorized.
|
|
2002
2005
|
attribute_name (str): Name of the attribute field for the class values.
|
|
2006
|
+
unique_attribute_value (bool): Whether to generate unique values for each shape within a class.
|
|
2003
2007
|
output_format (str): Format for output file - 'geojson', 'shapefile', 'gpkg'.
|
|
2004
2008
|
plot_result (bool): Whether to plot the resulting polygons overlaid on the raster.
|
|
2005
2009
|
|
|
@@ -2030,7 +2034,7 @@ def raster_to_vector(
|
|
|
2030
2034
|
# Process each class value
|
|
2031
2035
|
for class_val in class_values:
|
|
2032
2036
|
mask = masks[class_val]
|
|
2033
|
-
|
|
2037
|
+
shape_count = 1
|
|
2034
2038
|
# Vectorize the mask
|
|
2035
2039
|
for geom, value in features.shapes(
|
|
2036
2040
|
mask.astype(np.uint8), mask=mask, transform=transform
|
|
@@ -2047,7 +2051,14 @@ def raster_to_vector(
|
|
|
2047
2051
|
geom = geom.simplify(simplify_tolerance)
|
|
2048
2052
|
|
|
2049
2053
|
# Add to features list with class value
|
|
2050
|
-
|
|
2054
|
+
if unique_attribute_value:
|
|
2055
|
+
all_features.append(
|
|
2056
|
+
{"geometry": geom, attribute_name: class_val * shape_count}
|
|
2057
|
+
)
|
|
2058
|
+
else:
|
|
2059
|
+
all_features.append({"geometry": geom, attribute_name: class_val})
|
|
2060
|
+
|
|
2061
|
+
shape_count += 1
|
|
2051
2062
|
|
|
2052
2063
|
# Create GeoDataFrame
|
|
2053
2064
|
if all_features:
|
|
@@ -4482,8 +4493,8 @@ def region_groups(
|
|
|
4482
4493
|
"area_bbox",
|
|
4483
4494
|
"area_convex",
|
|
4484
4495
|
"area_filled",
|
|
4485
|
-
"
|
|
4486
|
-
"
|
|
4496
|
+
"axis_major_length",
|
|
4497
|
+
"axis_minor_length",
|
|
4487
4498
|
"eccentricity",
|
|
4488
4499
|
"diameter_areagth",
|
|
4489
4500
|
"extent",
|
|
@@ -4580,7 +4591,7 @@ def region_groups(
|
|
|
4580
4591
|
)
|
|
4581
4592
|
|
|
4582
4593
|
df = pd.DataFrame(props)
|
|
4583
|
-
df["elongation"] = df["
|
|
4594
|
+
df["elongation"] = df["axis_major_length"] / df["axis_minor_length"]
|
|
4584
4595
|
|
|
4585
4596
|
dtype = "uint8"
|
|
4586
4597
|
if num_labels > 255 and num_labels <= 65535:
|
|
@@ -4597,9 +4608,29 @@ def region_groups(
|
|
|
4597
4608
|
da.values = label_image
|
|
4598
4609
|
if out_image is not None:
|
|
4599
4610
|
da.rio.to_raster(out_image, dtype=dtype)
|
|
4600
|
-
|
|
4601
|
-
|
|
4602
|
-
|
|
4611
|
+
|
|
4612
|
+
if out_vector is not None:
|
|
4613
|
+
tmp_raster = None
|
|
4614
|
+
tmp_vector = None
|
|
4615
|
+
try:
|
|
4616
|
+
if out_image is None:
|
|
4617
|
+
tmp_raster = temp_file_path(".tif")
|
|
4618
|
+
da.rio.to_raster(tmp_raster, dtype=dtype)
|
|
4619
|
+
tmp_vector = temp_file_path(".gpkg")
|
|
4620
|
+
raster_to_vector(
|
|
4621
|
+
tmp_raster,
|
|
4622
|
+
tmp_vector,
|
|
4623
|
+
attribute_name="value",
|
|
4624
|
+
unique_attribute_value=True,
|
|
4625
|
+
)
|
|
4626
|
+
else:
|
|
4627
|
+
tmp_vector = temp_file_path(".gpkg")
|
|
4628
|
+
raster_to_vector(
|
|
4629
|
+
out_image,
|
|
4630
|
+
tmp_vector,
|
|
4631
|
+
attribute_name="value",
|
|
4632
|
+
unique_attribute_value=True,
|
|
4633
|
+
)
|
|
4603
4634
|
gdf = gpd.read_file(tmp_vector)
|
|
4604
4635
|
gdf["label"] = gdf["value"].astype(int)
|
|
4605
4636
|
gdf.drop(columns=["value"], inplace=True)
|
|
@@ -4607,6 +4638,15 @@ def region_groups(
|
|
|
4607
4638
|
gdf2.to_file(out_vector)
|
|
4608
4639
|
gdf2.sort_values("label", inplace=True)
|
|
4609
4640
|
df = gdf2
|
|
4641
|
+
finally:
|
|
4642
|
+
try:
|
|
4643
|
+
if tmp_raster is not None and os.path.exists(tmp_raster):
|
|
4644
|
+
os.remove(tmp_raster)
|
|
4645
|
+
if tmp_vector is not None and os.path.exists(tmp_vector):
|
|
4646
|
+
os.remove(tmp_vector)
|
|
4647
|
+
except Exception as e:
|
|
4648
|
+
print(f"Warning: Failed to delete temporary files: {str(e)}")
|
|
4649
|
+
|
|
4610
4650
|
return da, df
|
|
4611
4651
|
|
|
4612
4652
|
|
|
@@ -5821,3 +5861,195 @@ def orthogonalize(
|
|
|
5821
5861
|
print("Done!")
|
|
5822
5862
|
|
|
5823
5863
|
return gdf
|
|
5864
|
+
|
|
5865
|
+
|
|
5866
|
+
def inspect_pth_file(pth_path):
    """
    Inspect a PyTorch .pth model file to determine its architecture.

    The file is loaded onto the CPU and a human-readable report is printed:
    - the top-level checkpoint keys;
    - the first layer names and their shapes;
    - how often architecture-related substrings occur in the layer names;
    - the total number of tensor parameters;
    - a heuristic guess at the architecture family and the number of output
      classes.

    For dictionary checkpoints the weights are finally passed to
    ``try_common_architectures``.

    Args:
        pth_path (str): Path to the .pth file to inspect.

    Returns:
        None: All findings are printed to stdout. If the file does not exist or
        cannot be processed, an error message is printed instead of raising.
    """
    # Check if file exists
    if not os.path.exists(pth_path):
        print(f"Error: File {pth_path} not found")
        return

    try:
        # NOTE(review): torch.load unpickles the file, so only inspect files
        # from trusted sources. Depending on the installed torch version,
        # `weights_only` may default to True. In that case, files holding a
        # pickled model object fail to load and are reported by the except
        # clause below.
        checkpoint = torch.load(pth_path, map_location=torch.device("cpu"))
        print(f"\n{'='*50}")
        print(f"Inspecting model file: {pth_path}")
        print(f"{'='*50}\n")

        # OrderedDict is a subclass of dict, so one check covers both.
        if not isinstance(checkpoint, dict):
            print(
                "The file contains an entire model object rather than just a state dictionary."
            )
            # If it's a complete model, its repr shows the architecture directly.
            print(checkpoint)
            return

        # Locate the parameter dictionary inside the checkpoint.
        if "state_dict" in checkpoint:
            print("Found 'state_dict' key in the checkpoint.")
            state_dict = checkpoint["state_dict"]
        elif "model_state_dict" in checkpoint:
            print("Found 'model_state_dict' key in the checkpoint.")
            state_dict = checkpoint["model_state_dict"]
        else:
            print("Assuming file contains a direct state_dict.")
            state_dict = checkpoint

        # Print the keys in the checkpoint
        print("\nCheckpoint contains the following keys:")
        for key, value in checkpoint.items():
            if isinstance(value, dict):
                print(f"- {key} (dictionary with {len(value)} items)")
            elif isinstance(value, (list, tuple)):
                print(f"- {key} (shape/size: {len(value)})")
            elif isinstance(value, torch.Tensor):
                print(f"- {key} (shape/size: {value.shape})")
            else:
                print(f"- {key} ({type(value).__name__})")

        # Try to infer the model architecture from the state_dict keys
        print("\nAnalyzing model architecture from state_dict...")

        layer_keys = list(state_dict.keys())
        # Only tensor entries have shapes and parameters. A checkpoint treated
        # as a "direct" state_dict may also hold non-tensor metadata, which
        # previously crashed the .shape/.numel() calls below.
        tensors = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

        # Print the first few layer keys to understand naming pattern
        print("\nFirst 10 layer names in state_dict:")
        for key in layer_keys[:10]:
            value = state_dict[key]
            if isinstance(value, torch.Tensor):
                print(f"- {key} (shape: {value.shape})")
            else:
                print(f"- {key} ({type(value).__name__})")

        # Look for architecture indicators in the keys
        architecture_indicators = dict.fromkeys(
            (
                "conv",
                "bn",
                "layer",
                "fc",
                "backbone",
                "encoder",
                "decoder",
                "unet",
                "resnet",
                "classifier",
                "deeplab",
                "fcn",
            ),
            0,
        )
        for key in layer_keys:
            lowered = key.lower()
            for indicator in architecture_indicators:
                if indicator in lowered:
                    architecture_indicators[indicator] += 1

        print("\nArchitecture indicators found in layer names:")
        for indicator, count in architecture_indicators.items():
            if count > 0:
                print(f"- '{indicator}' appears {count} times")

        # Count total parameters (tensor entries only)
        total_params = sum(t.numel() for t in tensors.values())
        print(f"\nTotal parameters: {total_params:,}")

        print("\nAttempting to match with common architectures...")

        # Try to identify if it's a segmentation model
        if any("out" in k or "classifier" in k for k in layer_keys):
            print("Model appears to be a segmentation model.")

            # Check if it might be a UNet
            if (
                architecture_indicators["encoder"] > 0
                and architecture_indicators["decoder"] > 0
            ):
                print(
                    "Architecture seems to be a UNet-based model with encoder-decoder structure."
                )
            # Check for FCN or DeepLab indicators
            elif architecture_indicators["fcn"] > 0:
                print(
                    "Architecture seems to be FCN-based (Fully Convolutional Network)."
                )
            elif architecture_indicators["deeplab"] > 0:
                print("Architecture seems to be DeepLab-based.")
            elif architecture_indicators["backbone"] > 0:
                print(
                    "Model has a backbone architecture, likely a modern segmentation model."
                )

        # Infer output classes from the final layer. A head may contain
        # several "classifier.*" layers (intermediate convs, normalization
        # statistics, then a final projection). The class count is the first
        # dimension of the *last* such weight. Keys are in registration order,
        # so take the last matching tensor with >= 2 dims; this skips 1-D
        # biases and BatchNorm buffers.
        output_layer_keys = [
            k
            for k in layer_keys
            if ("classifier" in k or k.endswith(".out.weight"))
            and k in tensors
            and tensors[k].dim() >= 2
        ]
        if output_layer_keys:
            num_classes = tensors[output_layer_keys[-1]].shape[0]
            print(f"\nModel likely has {num_classes} output classes.")

        print("\nSUMMARY:")
        print("The model appears to be", end=" ")
        if architecture_indicators["unet"] > 0:
            print("a UNet architecture.", end=" ")
        elif architecture_indicators["fcn"] > 0:
            print("an FCN architecture.", end=" ")
        elif architecture_indicators["deeplab"] > 0:
            print("a DeepLab architecture.", end=" ")
        elif architecture_indicators["resnet"] > 0:
            print("ResNet-based.", end=" ")
        else:
            print("a custom architecture.", end=" ")

        # Try to load with common models
        try_common_architectures(state_dict)

    except Exception as e:
        print(f"Error loading the model file: {str(e)}")
|
|
6016
|
+
|
|
6017
|
+
|
|
6018
|
+
def try_common_architectures(state_dict, num_classes=None, input_size=(1, 3, 224, 224)):
    """
    Try to load the state_dict into common architectures to see which one fits.

    Each candidate torchvision segmentation model is built without pretrained
    weights. Only the entries of ``state_dict`` whose names and shapes match
    the model are loaded, because ``load_state_dict(strict=False)`` still raises
    on shape mismatches. A short fit report is printed for each candidate,
    followed by a ``torchinfo`` summary of the architecture when ``torchinfo``
    is installed.

    Args:
        state_dict: The model's state dictionary (parameter name -> tensor).
        num_classes (int, optional): Number of output classes used to build
            the candidate models. If None, it is read from the first dimension
            of a 4-D ``classifier.4.weight`` entry when present; otherwise 9 is
            used. Defaults to None.
        input_size (tuple, optional): Input size passed to
            ``torchinfo.summary``. Defaults to (1, 3, 224, 224).
    """
    # Sometimes state_dict keys have a 'model.' prefix
    prefix = "model."
    if state_dict and all(k.startswith(prefix) for k in state_dict.keys()):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    if num_classes is None:
        head_shape = getattr(state_dict.get("classifier.4.weight"), "shape", None)
        if head_shape is not None and len(head_shape) == 4:
            num_classes = int(head_shape[0])
        else:
            num_classes = 9

    # Build candidates without any pretrained weights: the checkpoint provides
    # the weights, and this avoids downloading backbone weights.
    models_to_try = {
        "FCN-ResNet50": lambda: fcn_resnet50(
            weights=None, weights_backbone=None, num_classes=num_classes
        ),
        "DeepLabV3-ResNet50": lambda: deeplabv3_resnet50(
            weights=None, weights_backbone=None, num_classes=num_classes
        ),
    }

    print("\nTrying to load state_dict into common architectures:")

    # torchinfo is only needed for the optional summaries; don't let a missing
    # install abort the fit report.
    try:
        import torchinfo
    except ImportError:
        torchinfo = None
        print("torchinfo is not installed; architecture summaries will be skipped.")

    for name, model_fn in models_to_try.items():
        try:
            model = model_fn()
            model_state = model.state_dict()

            # Split checkpoint entries into loadable, shape-mismatched and
            # unknown keys.
            compatible = {}
            mismatched = []
            unexpected = []
            for key, value in state_dict.items():
                target = model_state.get(key)
                if target is None:
                    unexpected.append(key)
                elif getattr(value, "shape", None) == target.shape:
                    compatible[key] = value
                else:
                    mismatched.append(key)

            model.load_state_dict(compatible, strict=False)

            print(
                f"- {name}: loaded {len(compatible)}/{len(model_state)} tensors "
                f"({len(mismatched)} shape-mismatched, {len(unexpected)} unexpected keys)"
            )

            if torchinfo is not None:
                print(f"\nSummary of {name} architecture:")
                # Inference mode: BatchNorm uses running statistics, so the
                # single-sample dummy input used by the summary is safe.
                model.eval()
                summary = torchinfo.summary(model, input_size=input_size, verbose=0)
                print(summary)

        except Exception as e:
            print(f"- {name}: Failed to load - {str(e)}")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: geoai-py
|
|
3
|
-
Version: 0.4.1
|
|
3
|
+
Version: 0.4.3
|
|
4
4
|
Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
|
|
5
5
|
Author-email: Qiusheng Wu <giswqs@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -34,10 +34,12 @@ Requires-Dist: scikit-image
|
|
|
34
34
|
Requires-Dist: scikit-learn
|
|
35
35
|
Requires-Dist: torch
|
|
36
36
|
Requires-Dist: torchgeo
|
|
37
|
+
Requires-Dist: torchinfo
|
|
37
38
|
Requires-Dist: tqdm
|
|
38
39
|
Requires-Dist: transformers
|
|
39
40
|
Provides-Extra: extra
|
|
40
41
|
Requires-Dist: overturemaps; extra == "extra"
|
|
42
|
+
Dynamic: license-file
|
|
41
43
|
|
|
42
44
|
# GeoAI: Artificial Intelligence for Geospatial Data
|
|
43
45
|
|
|
@@ -49,6 +51,8 @@ Requires-Dist: overturemaps; extra == "extra"
|
|
|
49
51
|
[License: MIT](https://opensource.org/licenses/MIT)
|
|
50
52
|
[GeoAI Tutorials](https://bit.ly/GeoAI-Tutorials)
|
|
51
53
|
|
|
54
|
+
[GeoAI logo](https://github.com/opengeos/geoai/blob/master/docs/assets/logo.png)
|
|
55
|
+
|
|
52
56
|
**A powerful Python package for integrating Artificial Intelligence with geospatial data analysis and visualization**
|
|
53
57
|
|
|
54
58
|
GeoAI bridges the gap between AI and geospatial analysis, providing tools for processing, analyzing, and visualizing geospatial data using advanced machine learning techniques. Whether you're working with satellite imagery, LiDAR point clouds, or vector data, GeoAI offers intuitive interfaces to apply cutting-edge AI models.
|
|
@@ -139,7 +143,3 @@ We welcome contributions of all kinds! See our [contributing guide](https://geoa
|
|
|
139
143
|
## 📄 License
|
|
140
144
|
|
|
141
145
|
GeoAI is free and open source software, licensed under the MIT License.
|
|
142
|
-
|
|
143
|
-
## 💖 Acknowledgment
|
|
144
|
-
|
|
145
|
-
Some of the pre-trained models used in the geoai package are adapted from the [ArcGIS Living Atlas](https://livingatlas.arcgis.com/en/browse/?q=dlpk#d=2&q=dlpk). Credits to Esri for making these models available.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
geoai/__init__.py,sha256=b12E2HztHEPaaKMVKpk6GPjC7ElUMxOW9pYKl4VZkmE,3592
|
|
2
|
+
geoai/download.py,sha256=BvCEpcBwZlEOWieixsdyvQDiE0CXRjU7oayLmy5_Dgs,40110
|
|
3
|
+
geoai/extract.py,sha256=GocJufMmrwEWxNBL1J91EXXHL8AKcO8m_lmtUF5AKPw,119102
|
|
4
|
+
geoai/geoai.py,sha256=BqKdWzNruDdGqwqoyTaJzUq4lKGj-RDBZlSO3t3-GxQ,626
|
|
5
|
+
geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
|
|
6
|
+
geoai/segment.py,sha256=g3YW17ftr--CKq6VB32TJEPY8owGQ7uQ0sg_tUT2ooE,13681
|
|
7
|
+
geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
|
|
8
|
+
geoai/train.py,sha256=-l2j1leTxDnFDLaBslu1q6CobXjm3LEdiQwUWOU8P6M,40088
|
|
9
|
+
geoai/utils.py,sha256=Wg9jbMBKUZSGUmU8Vkp6v19QcDNg5KmcyZxuHqJvgnc,233016
|
|
10
|
+
geoai_py-0.4.3.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
|
|
11
|
+
geoai_py-0.4.3.dist-info/METADATA,sha256=geDmJ-1zHImsOdcj4gypgq8JqSy8MznnxAnICwh0EbA,6049
|
|
12
|
+
geoai_py-0.4.3.dist-info/WHEEL,sha256=aoLN90hLOL0c0qxXMxWYUM3HA3WmFGZQqEJHX1V_OJE,109
|
|
13
|
+
geoai_py-0.4.3.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
|
|
14
|
+
geoai_py-0.4.3.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
|
|
15
|
+
geoai_py-0.4.3.dist-info/RECORD,,
|
geoai_py-0.4.1.dist-info/RECORD
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
geoai/__init__.py,sha256=lZ5LYzlwjX-TuvVtvuk0TC0le80P83lfUUOXBY0MPoU,3592
|
|
2
|
-
geoai/download.py,sha256=eqMecJqvqyrIVFViNA7pW8a5EIhqYJzRILmxQoFHG2k,13095
|
|
3
|
-
geoai/extract.py,sha256=CCXjUcGC4ZOKOKKjvElp8VFmTz46b0ATvGitbOPgTwE,95506
|
|
4
|
-
geoai/geoai.py,sha256=L1jkozDcjqJXvqT6i8oW04Ix9x4cc2-LNYi9_564ABQ,163
|
|
5
|
-
geoai/hf.py,sha256=mLKGxEAS5eHkxZLwuLpYc1o7e3-7QIXdBv-QUY-RkFk,17072
|
|
6
|
-
geoai/segment.py,sha256=g3YW17ftr--CKq6VB32TJEPY8owGQ7uQ0sg_tUT2ooE,13681
|
|
7
|
-
geoai/segmentation.py,sha256=AtPzCvguHAEeuyXafa4bzMFATvltEYcah1B8ZMfkM_s,11373
|
|
8
|
-
geoai/train.py,sha256=VaeFzIkVUNTdre8ImgUNhmbpA42qijSXaajLpmBF_Ic,36248
|
|
9
|
-
geoai/utils.py,sha256=oBNhk73_Owv-pmaRyBKkq0HinnqnMgP3U5CUAQx6ln0,223700
|
|
10
|
-
geoai_py-0.4.1.dist-info/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
|
|
11
|
-
geoai_py-0.4.1.dist-info/METADATA,sha256=0JZwrVh1EtR3mJIu8cumnafbCnElW70iBa77hSDwXHk,6078
|
|
12
|
-
geoai_py-0.4.1.dist-info/WHEEL,sha256=SrDKpSbFN1G94qcmBqS9nyHcDMp9cUS9OC06hC0G3G0,109
|
|
13
|
-
geoai_py-0.4.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
|
|
14
|
-
geoai_py-0.4.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
|
|
15
|
-
geoai_py-0.4.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|