geoai-py 0.12.0__py2.py3-none-any.whl → 0.13.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  __author__ = """Qiusheng Wu"""
4
4
  __email__ = "giswqs@gmail.com"
5
- __version__ = "0.12.0"
5
+ __version__ = "0.13.1"
6
6
 
7
7
 
8
8
  import os
@@ -99,5 +99,5 @@ def set_proj_lib_path(verbose=False):
99
99
  # if ("google.colab" not in sys.modules) and (sys.platform != "windows"):
100
100
  # set_proj_lib_path()
101
101
 
102
+ from .dinov3 import DINOv3GeoProcessor, analyze_image_patches, create_similarity_map
102
103
  from .geoai import *
103
- from .dinov3 import DINOv3GeoProcessor, create_similarity_map, analyze_image_patches
geoai/agents/__init__.py CHANGED
@@ -1,2 +1,8 @@
1
- from .geo_agents import GeoAgent
1
+ from .geo_agents import (
2
+ GeoAgent,
3
+ create_ollama_model,
4
+ create_anthropic_model,
5
+ create_openai_model,
6
+ create_bedrock_model,
7
+ )
2
8
  from .map_tools import MapTools
@@ -1,17 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import asyncio
4
+ import os
4
5
  import uuid
5
6
  from concurrent.futures import ThreadPoolExecutor
6
7
  from types import SimpleNamespace
7
8
  from typing import Any, Callable, Optional
8
9
 
10
+ import boto3
9
11
  import ipywidgets as widgets
10
12
  import leafmap.maplibregl as leafmap
13
+ from botocore.config import Config as BotocoreConfig
11
14
  from ipyevents import Event
12
15
  from IPython.display import display
13
16
  from strands import Agent
14
- from strands.models.ollama import OllamaModel
17
+ from strands.models import BedrockModel
18
+ from strands.models.anthropic import AnthropicModel
19
+ from strands.models.ollama import OllamaModel as _OllamaModel
20
+ from strands.models.openai import OpenAIModel
15
21
 
16
22
  from .map_tools import MapSession, MapTools
17
23
 
@@ -23,6 +29,147 @@ except Exception:
23
29
  pass
24
30
 
25
31
 
32
+ class OllamaModel(_OllamaModel):
33
+ """Fixed OllamaModel that ensures proper model_id handling."""
34
+
35
+ async def stream(self, *args, **kwargs):
36
+ """Override stream to ensure model_id is passed as string."""
37
+ # Patch the ollama client to handle model object correctly
38
+ import ollama
39
+
40
+ # Save original method if not already saved
41
+ if not hasattr(ollama.AsyncClient, "_original_chat"):
42
+ ollama.AsyncClient._original_chat = ollama.AsyncClient.chat
43
+
44
+ async def fixed_chat(self, **chat_kwargs):
45
+ # If model is an OllamaModel object, extract the model_id
46
+ if "model" in chat_kwargs and hasattr(chat_kwargs["model"], "config"):
47
+ chat_kwargs["model"] = chat_kwargs["model"].config["model_id"]
48
+ return await ollama.AsyncClient._original_chat(self, **chat_kwargs)
49
+
50
+ ollama.AsyncClient.chat = fixed_chat
51
+
52
+ # Call the original stream method
53
+ async for chunk in super().stream(*args, **kwargs):
54
+ yield chunk
55
+
56
+
57
+ def create_ollama_model(
58
+ host: str = "http://localhost:11434",
59
+ model_id: str = "llama3.1",
60
+ client_args: dict = None,
61
+ **kwargs: Any,
62
+ ) -> OllamaModel:
63
+ """Create an Ollama model.
64
+
65
+ Args:
66
+ host: Ollama host URL.
67
+ model_id: Ollama model ID.
68
+ client_args: Client arguments for the Ollama model.
69
+ **kwargs: Additional keyword arguments for the Ollama model.
70
+
71
+ Returns:
72
+ OllamaModel: An Ollama model.
73
+ """
74
+ if client_args is None:
75
+ client_args = {}
76
+ return OllamaModel(host=host, model_id=model_id, client_args=client_args, **kwargs)
77
+
78
+
79
+ def create_openai_model(
80
+ model_id: str = "gpt-4o-mini",
81
+ api_key: str = None,
82
+ client_args: dict = None,
83
+ **kwargs: Any,
84
+ ) -> OpenAIModel:
85
+ """Create an OpenAI model.
86
+
87
+ Args:
88
+ model_id: OpenAI model ID.
89
+ api_key: OpenAI API key.
90
+ client_args: Client arguments for the OpenAI model.
91
+ **kwargs: Additional keyword arguments for the OpenAI model.
92
+
93
+ Returns:
94
+ OpenAIModel: An OpenAI model.
95
+ """
96
+
97
+ if api_key is None:
98
+ try:
99
+ api_key = os.getenv("OPENAI_API_KEY", None)
100
+ if api_key is None:
101
+ raise ValueError("OPENAI_API_KEY is not set")
102
+ except Exception:
103
+ raise ValueError("OPENAI_API_KEY is not set")
104
+
105
+ if client_args is None:
106
+ client_args = kwargs.get("client_args", {})
107
+ if "api_key" not in client_args and api_key is not None:
108
+ client_args["api_key"] = api_key
109
+
110
+ return OpenAIModel(client_args=client_args, model_id=model_id, **kwargs)
111
+
112
+
113
+ def create_anthropic_model(
114
+ model_id: str = "claude-sonnet-4-20250514",
115
+ api_key: str = None,
116
+ client_args: dict = None,
117
+ **kwargs: Any,
118
+ ) -> AnthropicModel:
119
+ """Create an Anthropic model.
120
+
121
+ Args:
122
+ model_id: Anthropic model ID. Defaults to "claude-sonnet-4-20250514".
123
+ For a complete list of supported models,
124
+ see https://docs.claude.com/en/docs/about-claude/models/overview.
125
+ api_key: Anthropic API key.
126
+ client_args: Client arguments for the Anthropic model.
127
+ **kwargs: Additional keyword arguments for the Anthropic model.
128
+ """
129
+
130
+ if api_key is None:
131
+ try:
132
+ api_key = os.getenv("ANTHROPIC_API_KEY", None)
133
+ if api_key is None:
134
+ raise ValueError("ANTHROPIC_API_KEY is not set")
135
+ except Exception:
136
+ raise ValueError("ANTHROPIC_API_KEY is not set")
137
+
138
+ if client_args is None:
139
+ client_args = kwargs.get("client_args", {})
140
+ if "api_key" not in client_args and api_key is not None:
141
+ client_args["api_key"] = api_key
142
+
143
+ return AnthropicModel(client_args=client_args, model_id=model_id, **kwargs)
144
+
145
+
146
+ def create_bedrock_model(
147
+ model_id: str = "anthropic.claude-sonnet-4-20250514-v1:0",
148
+ region_name: str = None,
149
+ boto_session: Optional[boto3.Session] = None,
150
+ boto_client_config: Optional[BotocoreConfig] = None,
151
+ **kwargs: Any,
152
+ ) -> BedrockModel:
153
+ """Create a Bedrock model.
154
+
155
+ Args:
156
+ model_id: Bedrock model ID. Run the following command to get the model ID:
157
+ aws bedrock list-foundation-models | jq -r '.modelSummaries[].modelId'
158
+ region_name: Bedrock region name.
159
+ boto_session: Bedrock boto session.
160
+ boto_client_config: Bedrock boto client config.
161
+ **kwargs: Additional keyword arguments for the Bedrock model.
162
+ """
163
+
164
+ return BedrockModel(
165
+ model_id=model_id,
166
+ region_name=region_name,
167
+ boto_session=boto_session,
168
+ boto_client_config=boto_client_config,
169
+ **kwargs,
170
+ )
171
+
172
+
26
173
  def _ensure_loop() -> asyncio.AbstractEventLoop:
27
174
  try:
28
175
  loop = asyncio.get_event_loop()
@@ -39,28 +186,98 @@ class GeoAgent(Agent):
39
186
  """Geospatial AI agent with interactive mapping capabilities."""
40
187
 
41
188
  def __init__(
42
- self, *, model_id: str = "llama3.1", map_instance: Optional[leafmap.Map] = None
189
+ self,
190
+ *,
191
+ model: str = "llama3.1",
192
+ map_instance: Optional[leafmap.Map] = None,
193
+ system_prompt: str = "default",
194
+ model_args: dict = None,
195
+ **kwargs: Any,
43
196
  ) -> None:
44
197
  """Initialize the GeoAgent.
45
198
 
46
199
  Args:
47
- model_id: Ollama model identifier (default: "llama3.1").
200
+ model: Model identifier (default: "llama3.1").
48
201
  map_instance: Optional existing map instance.
202
+ model_args: Additional keyword arguments for the model.
203
+ **kwargs: Additional keyword arguments for the model.
49
204
  """
50
205
  self.session: MapSession = MapSession(map_instance)
51
206
  self.tools: MapTools = MapTools(self.session)
52
207
 
208
+ if model_args is None:
209
+ model_args = {}
210
+
53
211
  # --- save a model factory we can call each turn ---
54
- self._model_factory: Callable[[], OllamaModel] = lambda: OllamaModel(
55
- host="http://localhost:11434", model_id=model_id
56
- )
212
+ if model == "llama3.1":
213
+ self._model_factory: Callable[[], OllamaModel] = (
214
+ lambda: create_ollama_model(
215
+ host="http://localhost:11434", model_id=model, **model_args
216
+ )
217
+ )
218
+ elif isinstance(model, str):
219
+ self._model_factory: Callable[[], BedrockModel] = (
220
+ lambda: create_bedrock_model(model_id=model, **model_args)
221
+ )
222
+ elif isinstance(model, OllamaModel):
223
+ # Extract configuration from existing OllamaModel and create new instances
224
+ model_id = model.config["model_id"]
225
+ host = model.host
226
+ client_args = model.client_args
227
+ self._model_factory: Callable[[], OllamaModel] = (
228
+ lambda: create_ollama_model(
229
+ host=host, model_id=model_id, client_args=client_args, **model_args
230
+ )
231
+ )
232
+ elif isinstance(model, OpenAIModel):
233
+ # Extract configuration from existing OpenAIModel and create new instances
234
+ model_id = model.config["model_id"]
235
+ client_args = model.client_args.copy()
236
+ self._model_factory: Callable[[], OpenAIModel] = (
237
+ lambda mid=model_id, client_args=client_args: create_openai_model(
238
+ model_id=mid, client_args=client_args, **model_args
239
+ )
240
+ )
241
+ elif isinstance(model, AnthropicModel):
242
+ # Extract configuration from existing AnthropicModel and create new instances
243
+ model_id = model.config["model_id"]
244
+ client_args = model.client_args.copy()
245
+ self._model_factory: Callable[[], AnthropicModel] = (
246
+ lambda mid=model_id, client_args=client_args: create_anthropic_model(
247
+ model_id=mid, client_args=client_args, **model_args
248
+ )
249
+ )
250
+ else:
251
+ raise ValueError(f"Invalid model: {model}")
57
252
 
58
253
  # build initial model (first turn)
59
- ollama_model: OllamaModel = self._model_factory()
254
+ model = self._model_factory()
255
+
256
+ if system_prompt == "default":
257
+ system_prompt = """
258
+ You are a map control agent. Call tools with MINIMAL parameters only.
259
+
260
+ CRITICAL: Treat all kwargs parameters as optional parameters.
261
+ CRITICAL: NEVER include optional parameters unless user explicitly asks for them.
262
+
263
+ TOOL CALL RULES:
264
+ - zoom_to(zoom=N) - ONLY zoom parameter, OMIT options completely
265
+ - add_cog_layer(url='X') - NEVER include bands, nodata, opacity, etc.
266
+ - fly_to(longitude=N, latitude=N) - NEVER include zoom parameter
267
+ - add_basemap(name='X') - NEVER include any other parameters
268
+ - add_marker(lng_lat=[lon,lat]) - NEVER include popup or options
269
+
270
+ - remove_layer(name='X') - call get_layer_names() to get the layer name closest to
271
+ the name of the layer you want to remove before calling this tool
272
+
273
+ - add_overture_3d_buildings(kwargs={}) - kwargs parameter required by tool validation
274
+ FORBIDDEN: Optional parameters, string representations like '{}' or '[1,2,3]'
275
+ REQUIRED: Minimal tool calls with only what's absolutely necessary
276
+ """
60
277
 
61
278
  super().__init__(
62
279
  name="Leafmap Visualization Agent",
63
- model=ollama_model,
280
+ model=model,
64
281
  tools=[
65
282
  # Core navigation tools
66
283
  self.tools.fly_to,
@@ -87,36 +304,37 @@ class GeoAgent(Agent):
87
304
  self.tools.add_marker,
88
305
  self.tools.set_pitch,
89
306
  ],
90
- system_prompt="You are a map control agent. Call tools with MINIMAL parameters only.\n\n"
91
- + "CRITICAL: Treat all kwargs parameters as optional parameters.\n"
92
- + "CRITICAL: NEVER include optional parameters unless user explicitly asks for them.\n\n"
93
- + "TOOL CALL RULES:\n"
94
- + "- zoom_to(zoom=N) - ONLY zoom parameter, OMIT options completely\n"
95
- + "- add_cog_layer(url='X') - NEVER include bands, nodata, opacity, etc.\n"
96
- + "- fly_to(longitude=N, latitude=N) - NEVER include zoom parameter\n"
97
- + "- add_basemap(name='X') - NEVER include any other parameters\n"
98
- + "- add_marker(lng_lat=[lon,lat]) - NEVER include popup or options\n\n"
99
- + "- remove_layer(name='X') - call get_layer_names() to get the layer name closest to"
100
- + "the name of the layer you want to remove before calling this tool\n\n"
101
- + "- add_overture_3d_buildings(kwargs={}) - kwargs parameter required by tool validation\n"
102
- + "FORBIDDEN: Optional parameters, string representations like '{}' or '[1,2,3]'\n"
103
- + "REQUIRED: Minimal tool calls with only what's absolutely necessary",
307
+ system_prompt=system_prompt,
104
308
  callback_handler=None,
105
309
  )
106
310
 
107
311
  def ask(self, prompt: str) -> str:
108
312
  """Send a single-turn prompt to the agent.
109
313
 
314
+ Runs entirely on the same thread/event loop as the Agent
315
+ to avoid cross-loop asyncio object issues.
316
+
110
317
  Args:
111
318
  prompt: The text prompt to send to the agent.
112
319
 
113
320
  Returns:
114
321
  The agent's response as a string.
115
322
  """
116
- _ensure_loop()
323
+ # Ensure there's an event loop bound to this thread (Jupyter-safe)
324
+ loop = _ensure_loop()
325
+
326
+ # Preserve existing conversation messages
327
+ existing_messages = self.messages.copy()
328
+
329
+ # Create a fresh model but keep conversation history
117
330
  self.model = self._model_factory()
118
331
 
119
- result = self(prompt)
332
+ # Restore the conversation messages
333
+ self.messages = existing_messages
334
+
335
+ # Execute the prompt using the Agent's async API on this loop
336
+ # Avoid Agent.__call__ since it spins a new thread+loop
337
+ result = loop.run_until_complete(self.invoke_async(prompt))
120
338
  return getattr(result, "final_text", str(result))
121
339
 
122
340
  def show_ui(self, *, height: int = 700) -> None:
@@ -186,7 +404,7 @@ class GeoAgent(Agent):
186
404
  ),
187
405
  (
188
406
  "Add GeoJSON",
189
- "Add vector layer: https://github.com/opengeos/datasets/releases/download/us/us_states.geojson",
407
+ "Add GeoJSON layer: https://github.com/opengeos/datasets/releases/download/us/us_states.geojson",
190
408
  ),
191
409
  ("Remove layer", "Remove layer OpenTopoMap"),
192
410
  ("Save map", "Save the map as demo.html and return the path"),
geoai/agents/map_tools.py CHANGED
@@ -115,7 +115,7 @@ class MapTools:
115
115
  visible: bool = True,
116
116
  bands: Optional[List[int]] = None,
117
117
  nodata: Optional[Union[int, float]] = 0,
118
- titiler_endpoint: str = "https://giswqs-titiler-endpoint.hf.space",
118
+ titiler_endpoint: str = None,
119
119
  ) -> str:
120
120
  """Add a Cloud Optimized GeoTIFF (COG) layer to the map.
121
121
 
@@ -224,7 +224,6 @@ class MapTools:
224
224
 
225
225
  Args:
226
226
  pitch (float): The pitch value to set.
227
- **kwargs (Any): Additional keyword arguments to control the pitch.
228
227
 
229
228
  Returns:
230
229
  None
@@ -496,7 +495,7 @@ class MapTools:
496
495
  array_args=None,
497
496
  client_args={"cors_all": True},
498
497
  overwrite: bool = True,
499
- **kwargs,
498
+ **kwargs: Any,
500
499
  ):
501
500
  """Add a local raster dataset to the map.
502
501
  If you are using this function in JupyterHub on a remote server
@@ -523,12 +522,6 @@ class MapTools:
523
522
  the palette when plotting a single band. Defaults to None.
524
523
  nodata (float, optional): The value from the band to use to interpret
525
524
  as not valid data. Defaults to None.
526
- attribution (str, optional): Attribution for the source raster. This
527
- defaults to a message about it being a local file.. Defaults to None.
528
- layer_name (str, optional): The layer name to use. Defaults to 'Raster'.
529
- layer_index (int, optional): The index of the layer. Defaults to None.
530
- zoom_to_layer (bool, optional): Whether to zoom to the extent of the
531
- layer. Defaults to True.
532
525
  visible (bool, optional): Whether the layer is visible. Defaults to True.
533
526
  opacity (float, optional): The opacity of the layer. Defaults to 1.0.
534
527
  array_args (dict, optional): Additional arguments to pass to
@@ -570,8 +563,8 @@ class MapTools:
570
563
  remove_port: bool = True,
571
564
  preview: bool = False,
572
565
  overwrite: bool = False,
573
- **kwargs,
574
- ):
566
+ **kwargs: Any,
567
+ ) -> str:
575
568
  """Render the map to an HTML page.
576
569
 
577
570
  Args:
@@ -756,7 +749,6 @@ class MapTools:
756
749
  Adds a marker to the map.
757
750
 
758
751
  Args:
759
- marker (Marker, optional): A Marker object. Defaults to None.
760
752
  lng_lat (List[Union[float, float]]): A list of two floats
761
753
  representing the longitude and latitude of the marker.
762
754
  popup (Optional[str], optional): The text to display in a popup when
@@ -891,7 +883,6 @@ class MapTools:
891
883
  Args:
892
884
  zoom (float): The zoom level to zoom to.
893
885
  options (Dict[str, Any], optional): Additional options to control the zoom. Defaults to {}.
894
- **kwargs (Any): Additional keyword arguments to control the zoom.
895
886
 
896
887
  Returns:
897
888
  None
@@ -1059,7 +1050,7 @@ class MapTools:
1059
1050
  dpi: Optional[Union[str, float]] = "figure",
1060
1051
  transparent: Optional[bool] = False,
1061
1052
  position: str = "bottom-right",
1062
- **kwargs,
1053
+ **kwargs: Any,
1063
1054
  ) -> str:
1064
1055
  """
1065
1056
  Add a colorbar to the map.
@@ -1186,7 +1177,9 @@ class MapTools:
1186
1177
  return f"Video added: {layer_id}"
1187
1178
 
1188
1179
  @tool
1189
- def add_nlcd(self, years: list = [2023], add_legend: bool = True, **kwargs) -> None:
1180
+ def add_nlcd(
1181
+ self, years: list = [2023], add_legend: bool = True, **kwargs: Any
1182
+ ) -> None:
1190
1183
  """
1191
1184
  Adds National Land Cover Database (NLCD) data to the map.
1192
1185
 
geoai/dinov3.py CHANGED
@@ -8,22 +8,21 @@ import json
8
8
  import math
9
9
  import os
10
10
  import sys
11
- from typing import Tuple, Optional, Dict, List, Union
11
+ from typing import Dict, List, Optional, Tuple, Union
12
12
 
13
+ import matplotlib.patches as patches
14
+ import matplotlib.pyplot as plt
13
15
  import numpy as np
14
- from PIL import Image
16
+ import rasterio
15
17
  import torch
16
18
  import torch.nn.functional as F
17
19
  import torchvision.transforms as transforms
18
- import rasterio
19
- from rasterio.windows import Window
20
- from rasterio.io import DatasetReader
21
- import matplotlib.pyplot as plt
22
- import matplotlib.patches as patches
23
-
24
20
  from huggingface_hub import hf_hub_download
21
+ from PIL import Image
22
+ from rasterio.io import DatasetReader
23
+ from rasterio.windows import Window
25
24
 
26
- from .utils import get_device, coords_to_xy, dict_to_image, dict_to_rioxarray
25
+ from .utils import coords_to_xy, dict_to_image, dict_to_rioxarray, get_device
27
26
 
28
27
 
29
28
  class DINOv3GeoProcessor:
@@ -47,7 +46,6 @@ class DINOv3GeoProcessor:
47
46
  See https://github.com/facebookresearch/dinov3 for more details.
48
47
  weights_path: Path to model weights (optional)
49
48
  device: Torch device to use
50
- dinov3_location: Path to DINOv3 repository
51
49
  """
52
50
 
53
51
  dinov3_github_location = "facebookresearch/dinov3"
geoai/geoai.py CHANGED
@@ -8,7 +8,6 @@ logging.getLogger("maplibre").setLevel(logging.ERROR)
8
8
  import leafmap
9
9
  import leafmap.maplibregl as maplibregl
10
10
 
11
- from .change_detection import ChangeDetection
12
11
  from .classify import classify_image, classify_images, train_classifier
13
12
  from .download import (
14
13
  download_naip,
@@ -26,6 +25,7 @@ from .download import (
26
25
  )
27
26
  from .extract import *
28
27
  from .hf import *
28
+ from .map_widgets import DINOv3GUI
29
29
  from .segment import *
30
30
  from .train import (
31
31
  get_instance_segmentation_model,
@@ -41,7 +41,6 @@ from .train import (
41
41
  train_segmentation_model,
42
42
  )
43
43
  from .utils import *
44
- from .map_widgets import DINOv3GUI
45
44
 
46
45
 
47
46
  class LeafMap(leafmap.Map):
geoai/map_widgets.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Interactive widget for GeoAI."""
2
2
 
3
3
  import ipywidgets as widgets
4
+
4
5
  from .utils import dict_to_image, dict_to_rioxarray
5
6
 
6
7
 
@@ -28,14 +29,13 @@ class DINOv3GUI(widgets.VBox):
28
29
  colormap_options (list): The colormap options.
29
30
  raster_args (dict): The raster arguments.
30
31
 
31
- Returns:
32
- None
33
-
34
32
  Example:
35
33
  >>> processor = DINOv3GeoProcessor()
36
34
  >>> features, h_patches, w_patches = processor.extract_features(raster)
37
35
  >>> gui = DINOv3GUI(raster, processor, features, host_map=m)
38
36
  """
37
+ super().__init__()
38
+
39
39
  if raster_args is None:
40
40
  raster_args = {}
41
41
 
geoai/utils.py CHANGED
@@ -182,7 +182,6 @@ def view_image(
182
182
  image (Union[np.ndarray, torch.Tensor]): The image to visualize.
183
183
  transpose (bool, optional): Whether to transpose the image. Defaults to False.
184
184
  bdx (Optional[int], optional): The band index to visualize. Defaults to None.
185
- scale_factor (float, optional): The scale factor to apply to the image. Defaults to 1.0.
186
185
  figsize (Tuple[int, int], optional): The size of the figure. Defaults to (10, 5).
187
186
  axis_off (bool, optional): Whether to turn off the axis. Defaults to True.
188
187
  title (Optional[str], optional): The title of the plot. Defaults to None.
@@ -396,9 +395,10 @@ def dict_to_rioxarray(data_dict: Dict) -> xr.DataArray:
396
395
  xr.DataArray: The xarray DataArray.
397
396
  """
398
397
 
399
- from affine import Affine
400
398
  from collections import namedtuple
401
399
 
400
+ from affine import Affine
401
+
402
402
  BoundingBox = namedtuple("BoundingBox", ["minx", "maxx", "miny", "maxy"])
403
403
 
404
404
  # Extract components from the dictionary
@@ -710,7 +710,7 @@ def view_vector_interactive(
710
710
  tiles_args (dict, optional): Additional arguments for the localtileserver client.
711
711
  get_folium_tile_layer function. Defaults to None.
712
712
  **kwargs: Additional keyword arguments to pass to GeoDataFrame.explore() function.
713
- See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
713
+ See https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html
714
714
 
715
715
  Returns:
716
716
  folium.Map: The map object with the vector data added.
@@ -2508,7 +2508,7 @@ def batch_vector_to_raster(
2508
2508
  fill_value=0,
2509
2509
  dtype=np.uint8,
2510
2510
  nodata=None,
2511
- ):
2511
+ ) -> List[str]:
2512
2512
  """
2513
2513
  Batch convert vector data to multiple rasters based on different extents or reference rasters.
2514
2514
 
@@ -2527,7 +2527,7 @@ def batch_vector_to_raster(
2527
2527
  nodata (int): No data value for the output raster.
2528
2528
 
2529
2529
  Returns:
2530
- list: List of paths to the created raster files.
2530
+ List[str]: List of paths to the created raster files.
2531
2531
  """
2532
2532
  # Create output directory if it doesn't exist
2533
2533
  os.makedirs(output_dir, exist_ok=True)
@@ -3128,7 +3128,7 @@ def export_geotiff_tiles_batch(
3128
3128
  skip_empty_tiles=False,
3129
3129
  image_extensions=None,
3130
3130
  mask_extensions=None,
3131
- ):
3131
+ ) -> Dict[str, Any]:
3132
3132
  """
3133
3133
  Export georeferenced GeoTIFF tiles from folders of images and masks.
3134
3134
 
@@ -3156,7 +3156,7 @@ def export_geotiff_tiles_batch(
3156
3156
  mask_extensions (list): List of mask file extensions to process (default: common raster/vector formats)
3157
3157
 
3158
3158
  Returns:
3159
- dict: Dictionary containing batch processing statistics
3159
+ Dict[str, Any]: Dictionary containing batch processing statistics
3160
3160
 
3161
3161
  Raises:
3162
3162
  ValueError: If no images or masks found, or if counts don't match
@@ -3631,7 +3631,7 @@ def _process_image_mask_pair(
3631
3631
 
3632
3632
  def create_overview_image(
3633
3633
  src, tile_coordinates, output_path, tile_size, stride, geojson_path=None
3634
- ):
3634
+ ) -> str:
3635
3635
  """Create an overview image showing all tiles and their status, with optional GeoJSON export.
3636
3636
 
3637
3637
  Args:
@@ -3782,7 +3782,7 @@ def create_overview_image(
3782
3782
 
3783
3783
  def export_tiles_to_geojson(
3784
3784
  tile_coordinates, src, output_path, tile_size=None, stride=None
3785
- ):
3785
+ ) -> str:
3786
3786
  """
3787
3787
  Export tile rectangles directly to GeoJSON without creating an overview image.
3788
3788
 
@@ -4630,14 +4630,14 @@ def export_training_data(
4630
4630
 
4631
4631
 
4632
4632
  def masks_to_vector(
4633
- mask_path,
4634
- output_path=None,
4635
- simplify_tolerance=1.0,
4636
- mask_threshold=0.5,
4637
- min_object_area=100,
4638
- max_object_area=None,
4639
- nms_iou_threshold=0.5,
4640
- ):
4633
+ mask_path: str,
4634
+ output_path: Optional[str] = None,
4635
+ simplify_tolerance: float = 1.0,
4636
+ mask_threshold: float = 0.5,
4637
+ min_object_area: int = 100,
4638
+ max_object_area: Optional[int] = None,
4639
+ nms_iou_threshold: float = 0.5,
4640
+ ) -> Any:
4641
4641
  """
4642
4642
  Convert a building mask GeoTIFF to vector polygons and save as a vector dataset.
4643
4643
 
@@ -4651,7 +4651,7 @@ def masks_to_vector(
4651
4651
  nms_iou_threshold: IoU threshold for non-maximum suppression (default: self.nms_iou_threshold)
4652
4652
 
4653
4653
  Returns:
4654
- GeoDataFrame with building footprints
4654
+ Any: GeoDataFrame with building footprints
4655
4655
  """
4656
4656
  # Set default output path if not provided
4657
4657
  # if output_path is None:
@@ -5654,7 +5654,7 @@ def orthogonalize(
5654
5654
  min_segments=4,
5655
5655
  area_tolerance=0.7,
5656
5656
  detect_triangles=True,
5657
- ):
5657
+ ) -> Any:
5658
5658
  """
5659
5659
  Orthogonalizes object masks in a GeoTIFF file.
5660
5660
 
@@ -5678,7 +5678,7 @@ def orthogonalize(
5678
5678
  detect_triangles (bool, optional): If True, performs additional check to avoid creating triangular shapes.
5679
5679
 
5680
5680
  Returns:
5681
- geopandas.GeoDataFrame: A GeoDataFrame containing the orthogonalized features.
5681
+ Any: A GeoDataFrame containing the orthogonalized features.
5682
5682
  """
5683
5683
 
5684
5684
  from functools import partial
@@ -7085,8 +7085,8 @@ def regularize(
7085
7085
  num_cores: int = 1,
7086
7086
  include_metadata: bool = False,
7087
7087
  output_path: Optional[str] = None,
7088
- **kwargs,
7089
- ) -> gpd.GeoDataFrame:
7088
+ **kwargs: Any,
7089
+ ) -> Any:
7090
7090
  """Regularizes polygon geometries in a GeoDataFrame by aligning edges.
7091
7091
 
7092
7092
  Aligns edges to be parallel or perpendicular (optionally also 45 degrees)
@@ -7522,7 +7522,11 @@ def write_colormap(
7522
7522
 
7523
7523
 
7524
7524
  def plot_performance_metrics(
7525
- history_path: str, figsize: Tuple[int, int] = (15, 5), verbose: bool = True
7525
+ history_path: str,
7526
+ figsize: Tuple[int, int] = (15, 5),
7527
+ verbose: bool = True,
7528
+ save_path: Optional[str] = None,
7529
+ kwargs: Optional[Dict] = None,
7526
7530
  ) -> None:
7527
7531
  """Plot performance metrics from a history object.
7528
7532
 
@@ -7531,6 +7535,8 @@ def plot_performance_metrics(
7531
7535
  figsize: The figure size.
7532
7536
  verbose: Whether to print the best and final metrics.
7533
7537
  """
7538
+ if kwargs is None:
7539
+ kwargs = {}
7534
7540
  history = torch.load(history_path)
7535
7541
 
7536
7542
  # Handle different key naming conventions
@@ -7579,6 +7585,14 @@ def plot_performance_metrics(
7579
7585
  plt.grid(True)
7580
7586
 
7581
7587
  plt.tight_layout()
7588
+
7589
+ if save_path:
7590
+ if "dpi" not in kwargs:
7591
+ kwargs["dpi"] = 150
7592
+ if "bbox_inches" not in kwargs:
7593
+ kwargs["bbox_inches"] = "tight"
7594
+ plt.savefig(save_path, **kwargs)
7595
+
7582
7596
  plt.show()
7583
7597
 
7584
7598
  if verbose:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: geoai-py
3
- Version: 0.12.0
3
+ Version: 0.13.1
4
4
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
5
5
  Author-email: Qiusheng Wu <giswqs@gmail.com>
6
6
  License: MIT License
@@ -39,13 +39,13 @@ Requires-Dist: rioxarray
39
39
  Requires-Dist: scikit-image
40
40
  Requires-Dist: scikit-learn
41
41
  Requires-Dist: torch
42
- Requires-Dist: torchange
43
42
  Requires-Dist: torchgeo
44
43
  Requires-Dist: torchinfo
45
44
  Requires-Dist: tqdm
46
45
  Requires-Dist: transformers
47
46
  Provides-Extra: extra
48
47
  Requires-Dist: overturemaps; extra == "extra"
48
+ Requires-Dist: torchange; extra == "extra"
49
49
  Provides-Extra: agents
50
50
  Requires-Dist: strands-agents; extra == "agents"
51
51
  Requires-Dist: strands-agents-tools; extra == "agents"
@@ -0,0 +1,24 @@
1
+ geoai/__init__.py,sha256=HOXMIhkhHbKfAjjyW5KoS3iHaqq9-SUA6Vstr22G5f4,3851
2
+ geoai/change_detection.py,sha256=XkJjMEU1nD8uX3-nQy7NEmz8cukVeSaRxKJHlrv8xPM,59636
3
+ geoai/classify.py,sha256=0DcComVR6vKU4qWtH2oHVeXc7ZTcV0mFvdXRtlNmolo,35637
4
+ geoai/detectron2.py,sha256=dOOFM9M9-6PV8q2A4-mnIPrz7yTo-MpEvDiAW34nl0w,14610
5
+ geoai/dinov3.py,sha256=u4Lulihhvs4wTgi84RjRw8jWQpB8omQSl-dVVryNVus,40377
6
+ geoai/download.py,sha256=B0EwpQFndJknOKmwRfEEnnCJhplOAwcLyNzFuA6FjZs,47633
7
+ geoai/extract.py,sha256=595MBcSaFx-gQLIEv5g3oEM90QA5In4L59GPVgBOlQc,122092
8
+ geoai/geoai.py,sha256=ZnGhcTvXbhqpO98Bmt2c4q09VXEgawn0yF8dqxGrlRg,10066
9
+ geoai/hf.py,sha256=HbfJfpO6XnANKhmFOBvpwULiC65TeMlnLNtyQHHmlKA,17248
10
+ geoai/map_widgets.py,sha256=QLmkILsztNaRXRULHKOd7Glb7S0pEWXSK9-P8S5AuzQ,5856
11
+ geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
12
+ geoai/segment.py,sha256=yBGTxA-ti8lBpk7WVaBOp6yP23HkaulKJQk88acrmZ0,43788
13
+ geoai/segmentation.py,sha256=7yEzBSKCyHW1dNssoK0rdvhxi2IXsIQIFSga817KdI4,11535
14
+ geoai/train.py,sha256=r9eioaBpc2eg6hckkGVI3aGhQZffKas_UVRj-AWruu8,136049
15
+ geoai/utils.py,sha256=lpyhytBeDLiqWz31syeRvpbT5AUn3cOblKU57uDD9sU,301265
16
+ geoai/agents/__init__.py,sha256=NndUtQ5-i8Zuim8CJftCZYKbCvrkDXj9iLVtiBtc_qE,178
17
+ geoai/agents/geo_agents.py,sha256=4tLntKBL_FgTQsUVzReP9acbYotnfjMRc5BYwW9WEyE,21431
18
+ geoai/agents/map_tools.py,sha256=OK5uB0VUHjjUnc-DYRy2CQ__kyUIARSCPBucGabO0Xw,60669
19
+ geoai_py-0.13.1.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
20
+ geoai_py-0.13.1.dist-info/METADATA,sha256=sNcJv-QuPoSMPAahudoWE3Z0BnV2hxwYO5bRKNKoPaA,10345
21
+ geoai_py-0.13.1.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
22
+ geoai_py-0.13.1.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
23
+ geoai_py-0.13.1.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
24
+ geoai_py-0.13.1.dist-info/RECORD,,
@@ -1,24 +0,0 @@
1
- geoai/__init__.py,sha256=DKFI6qeNtIjJ48oO7YWDxyaZmzSlsQ7YeXPypF9mhKI,3851
2
- geoai/change_detection.py,sha256=XkJjMEU1nD8uX3-nQy7NEmz8cukVeSaRxKJHlrv8xPM,59636
3
- geoai/classify.py,sha256=0DcComVR6vKU4qWtH2oHVeXc7ZTcV0mFvdXRtlNmolo,35637
4
- geoai/detectron2.py,sha256=dOOFM9M9-6PV8q2A4-mnIPrz7yTo-MpEvDiAW34nl0w,14610
5
- geoai/dinov3.py,sha256=c1rNvCdGSRvI8Twj-4Eanunxjh411ctIXRo_GpVRahQ,40433
6
- geoai/download.py,sha256=B0EwpQFndJknOKmwRfEEnnCJhplOAwcLyNzFuA6FjZs,47633
7
- geoai/extract.py,sha256=595MBcSaFx-gQLIEv5g3oEM90QA5In4L59GPVgBOlQc,122092
8
- geoai/geoai.py,sha256=6WwUnYQtQmPAiE9AJtdb4Elztc8vcH39nFFgwJQBPVQ,10112
9
- geoai/hf.py,sha256=HbfJfpO6XnANKhmFOBvpwULiC65TeMlnLNtyQHHmlKA,17248
10
- geoai/map_widgets.py,sha256=8S0WCAeH8f1jswtBJHzV_lGaO92er8P58GxxotbKUng,5862
11
- geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
12
- geoai/segment.py,sha256=yBGTxA-ti8lBpk7WVaBOp6yP23HkaulKJQk88acrmZ0,43788
13
- geoai/segmentation.py,sha256=7yEzBSKCyHW1dNssoK0rdvhxi2IXsIQIFSga817KdI4,11535
14
- geoai/train.py,sha256=r9eioaBpc2eg6hckkGVI3aGhQZffKas_UVRj-AWruu8,136049
15
- geoai/utils.py,sha256=Vg8vNuIZ2PvmwPGAIZHN_1dHw4jC9ddjM_GmdkFN1KA,300899
16
- geoai/agents/__init__.py,sha256=MdJH4hA6dMV4mbXpq0wwhO7gMJdb5oDj8RJDJIRSr7A,65
17
- geoai/agents/geo_agents.py,sha256=Yv9WYqjfV5mv8PN1n2ZP6aPhuy0MWDtgJgaVxLzc24A,13771
18
- geoai/agents/map_tools.py,sha256=HrreesNpsj8jytwO3ZA1e2REYQ4nyas_5v23-xb5sdY,61354
19
- geoai_py-0.12.0.dist-info/licenses/LICENSE,sha256=vN2L5U7cZ6ZkOHFmc8WiGlsogWsZc5dllMeNxnKVOZg,1070
20
- geoai_py-0.12.0.dist-info/METADATA,sha256=NJ7ybA_ondbMpU3Jw58hDLOKD5maa_HMT04xnNY_cHI,10327
21
- geoai_py-0.12.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
22
- geoai_py-0.12.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
23
- geoai_py-0.12.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
24
- geoai_py-0.12.0.dist-info/RECORD,,