geoai-py 0.9.0__py2.py3-none-any.whl → 0.9.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.9.0"
5
+ __version__ = "0.9.2"
6
6
 
7
7
 
8
8
  import os
geoai/change_detection.py CHANGED
@@ -1,13 +1,16 @@
1
1
  """Change detection module for remote sensing imagery using torchange."""
2
2
 
3
3
  import os
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import matplotlib.pyplot as plt
4
8
  import numpy as np
5
9
  import rasterio
6
10
  from rasterio.windows import from_bounds
7
11
  from skimage.transform import resize
8
- import matplotlib.pyplot as plt
9
- import cv2
10
12
  from torchange.models.segment_any_change import AnyChange, show_change_masks
13
+
11
14
  from .utils import download_file
12
15
 
13
16
 
@@ -47,15 +50,15 @@ class ChangeDetection:
47
50
 
48
51
  def set_hyperparameters(
49
52
  self,
50
- change_confidence_threshold=155,
51
- auto_threshold=False,
52
- use_normalized_feature=True,
53
- area_thresh=0.8,
54
- match_hist=False,
55
- object_sim_thresh=60,
56
- bitemporal_match=True,
57
- **kwargs,
58
- ):
53
+ change_confidence_threshold: int = 155,
54
+ auto_threshold: bool = False,
55
+ use_normalized_feature: bool = True,
56
+ area_thresh: float = 0.8,
57
+ match_hist: bool = False,
58
+ object_sim_thresh: int = 60,
59
+ bitemporal_match: bool = True,
60
+ **kwargs: Any,
61
+ ) -> None:
59
62
  """
60
63
  Set hyperparameters for the change detection model.
61
64
 
@@ -83,16 +86,16 @@ class ChangeDetection:
83
86
 
84
87
  def set_mask_generator_params(
85
88
  self,
86
- points_per_side=32,
89
+ points_per_side: int = 32,
87
90
  points_per_batch: int = 64,
88
91
  pred_iou_thresh: float = 0.5,
89
92
  stability_score_thresh: float = 0.95,
90
93
  stability_score_offset: float = 1.0,
91
94
  box_nms_thresh: float = 0.7,
92
- point_grids=None,
95
+ point_grids: Optional[List] = None,
93
96
  min_mask_region_area: int = 0,
94
- **kwargs,
95
- ):
97
+ **kwargs: Any,
98
+ ) -> None:
96
99
  """
97
100
  Set mask generator parameters.
98
101
 
@@ -203,17 +206,17 @@ class ChangeDetection:
203
206
 
204
207
  def detect_changes(
205
208
  self,
206
- image1_path,
207
- image2_path,
208
- output_path=None,
209
- target_size=1024,
210
- return_results=True,
211
- export_probability=False,
212
- probability_output_path=None,
213
- export_instance_masks=False,
214
- instance_masks_output_path=None,
215
- return_detailed_results=False,
216
- ):
209
+ image1_path: str,
210
+ image2_path: str,
211
+ output_path: Optional[str] = None,
212
+ target_size: int = 1024,
213
+ return_results: bool = True,
214
+ export_probability: bool = False,
215
+ probability_output_path: Optional[str] = None,
216
+ export_instance_masks: bool = False,
217
+ instance_masks_output_path: Optional[str] = None,
218
+ return_detailed_results: bool = False,
219
+ ) -> Union[Tuple[Any, np.ndarray, np.ndarray], Dict[str, Any], None]:
217
220
  """
218
221
  Detect changes between two GeoTIFF images with instance segmentation.
219
222
 
@@ -530,7 +533,9 @@ class ChangeDetection:
530
533
  ) as dst:
531
534
  dst.write(prob_final.astype(np.float32), 1)
532
535
 
533
- def visualize_changes(self, image1_path, image2_path, figsize=(15, 5)):
536
+ def visualize_changes(
537
+ self, image1_path: str, image2_path: str, figsize: Tuple[int, int] = (15, 5)
538
+ ) -> plt.Figure:
534
539
  """
535
540
  Visualize change detection results.
536
541
 
@@ -1516,14 +1521,16 @@ Areas:
1516
1521
  return results
1517
1522
 
1518
1523
 
1519
- def download_checkpoint(model_type="vit_h", checkpoint_dir=None):
1524
+ def download_checkpoint(
1525
+ model_type: str = "vit_h", checkpoint_dir: Optional[str] = None
1526
+ ) -> str:
1520
1527
  """Download the SAM model checkpoint.
1521
1528
 
1522
1529
  Args:
1523
1530
  model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'].
1524
1531
  Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
1525
1532
  checkpoint_dir (str, optional): The checkpoint_dir directory. Defaults to None,
1526
- "~/.cache/torch/hub/checkpoints".
1533
+ which uses "~/.cache/torch/hub/checkpoints".
1527
1534
  """
1528
1535
 
1529
1536
  model_types = {
geoai/classify.py CHANGED
@@ -1,50 +1,51 @@
1
1
  """The module for training semantic segmentation models for classifying remote sensing imagery."""
2
2
 
3
3
  import os
4
+ from typing import Any, Dict, List, Optional, Union
4
5
 
5
6
  import numpy as np
6
7
 
7
8
 
8
9
  def train_classifier(
9
- image_root,
10
- label_root,
11
- output_dir="output",
12
- in_channels=4,
13
- num_classes=14,
14
- epochs=20,
15
- img_size=256,
16
- batch_size=8,
17
- sample_size=500,
18
- model="unet",
19
- backbone="resnet50",
20
- weights=True,
21
- num_filters=3,
22
- loss="ce",
23
- class_weights=None,
24
- ignore_index=None,
25
- lr=0.001,
26
- patience=10,
27
- freeze_backbone=False,
28
- freeze_decoder=False,
29
- transforms=None,
30
- use_augmentation=False,
31
- seed=42,
32
- train_val_test_split=(0.6, 0.2, 0.2),
33
- accelerator="auto",
34
- devices="auto",
35
- logger=None,
36
- callbacks=None,
37
- log_every_n_steps=10,
38
- use_distributed_sampler=False,
39
- monitor_metric="val_loss",
40
- mode="min",
41
- save_top_k=1,
42
- save_last=True,
43
- checkpoint_filename="best_model",
44
- checkpoint_path=None,
45
- every_n_epochs=1,
46
- **kwargs,
47
- ):
10
+ image_root: str,
11
+ label_root: str,
12
+ output_dir: str = "output",
13
+ in_channels: int = 4,
14
+ num_classes: int = 14,
15
+ epochs: int = 20,
16
+ img_size: int = 256,
17
+ batch_size: int = 8,
18
+ sample_size: int = 500,
19
+ model: str = "unet",
20
+ backbone: str = "resnet50",
21
+ weights: bool = True,
22
+ num_filters: int = 3,
23
+ loss: str = "ce",
24
+ class_weights: Optional[List[float]] = None,
25
+ ignore_index: Optional[int] = None,
26
+ lr: float = 0.001,
27
+ patience: int = 10,
28
+ freeze_backbone: bool = False,
29
+ freeze_decoder: bool = False,
30
+ transforms: Optional[Any] = None,
31
+ use_augmentation: bool = False,
32
+ seed: int = 42,
33
+ train_val_test_split: tuple = (0.6, 0.2, 0.2),
34
+ accelerator: str = "auto",
35
+ devices: str = "auto",
36
+ logger: Optional[Any] = None,
37
+ callbacks: Optional[List[Any]] = None,
38
+ log_every_n_steps: int = 10,
39
+ use_distributed_sampler: bool = False,
40
+ monitor_metric: str = "val_loss",
41
+ mode: str = "min",
42
+ save_top_k: int = 1,
43
+ save_last: bool = True,
44
+ checkpoint_filename: str = "best_model",
45
+ checkpoint_path: Optional[str] = None,
46
+ every_n_epochs: int = 1,
47
+ **kwargs: Any,
48
+ ) -> Any:
48
49
  """Train a semantic segmentation model on geospatial imagery.
49
50
 
50
51
  This function sets up datasets, model, trainer, and executes the training process
@@ -584,15 +585,15 @@ def _classify_image(
584
585
 
585
586
 
586
587
  def classify_image(
587
- image_path,
588
- model_path,
589
- output_path=None,
590
- chip_size=1024,
591
- overlap=256,
592
- batch_size=4,
593
- colormap=None,
594
- **kwargs,
595
- ):
588
+ image_path: str,
589
+ model_path: str,
590
+ output_path: Optional[str] = None,
591
+ chip_size: int = 1024,
592
+ overlap: int = 256,
593
+ batch_size: int = 4,
594
+ colormap: Optional[Dict] = None,
595
+ **kwargs: Any,
596
+ ) -> str:
596
597
  """
597
598
  Classify a geospatial image using a trained semantic segmentation model.
598
599
 
@@ -826,15 +827,15 @@ def classify_image(
826
827
 
827
828
 
828
829
  def classify_images(
829
- image_paths,
830
- model_path,
831
- output_dir=None,
832
- chip_size=1024,
833
- batch_size=4,
834
- colormap=None,
835
- file_extension=".tif",
836
- **kwargs,
837
- ):
830
+ image_paths: Union[str, List[str]],
831
+ model_path: str,
832
+ output_dir: Optional[str] = None,
833
+ chip_size: int = 1024,
834
+ batch_size: int = 4,
835
+ colormap: Optional[Dict] = None,
836
+ file_extension: str = ".tif",
837
+ **kwargs: Any,
838
+ ) -> List[str]:
838
839
  """
839
840
  Classify multiple geospatial images using a trained semantic segmentation model.
840
841