geoai-py 0.11.1__py2.py3-none-any.whl → 0.13.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
  
  __author__ = """Qiusheng Wu"""
  __email__ = "giswqs@gmail.com"
- __version__ = "0.11.1"
+ __version__ = "0.13.0"
  
  
  import os
@@ -99,5 +99,5 @@ def set_proj_lib_path(verbose=False):
  # if ("google.colab" not in sys.modules) and (sys.platform != "windows"):
  #     set_proj_lib_path()
  
+ from .dinov3 import DINOv3GeoProcessor, analyze_image_patches, create_similarity_map
  from .geoai import *
- from .dinov3 import DINOv3GeoProcessor, create_similarity_map, analyze_image_patches
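The hunk above bumps the package version and moves the DINOv3 re-exports ahead of the wildcard geoai import. A minimal sketch of the resulting top-level import surface; only the names come from this diff, and no call signatures are shown here:

```python
# Names re-exported by geoai/__init__.py as of 0.13.0 (per the hunk above).
from geoai import DINOv3GeoProcessor, analyze_image_patches, create_similarity_map
```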
@@ -0,0 +1,8 @@
+ from .geo_agents import (
+     GeoAgent,
+     create_ollama_model,
+     create_anthropic_model,
+     create_openai_model,
+     create_bedrock_model,
+ )
+ from .map_tools import MapTools
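The new module above re-exports the agent class, the model factory helpers, and MapTools. A hedged sketch of how the factories might be used; the geoai.agents import path is an assumption (the new file's path is not shown in this diff), while the keyword arguments follow the signatures defined in geo_agents.py below:

```python
# Sketch only: assumes the new subpackage is importable as geoai.agents.
from geoai.agents import create_ollama_model, create_openai_model

# Local Ollama backend (defaults shown in create_ollama_model below).
ollama_model = create_ollama_model(host="http://localhost:11434", model_id="llama3.1")

# OpenAI backend; falls back to the OPENAI_API_KEY environment variable
# and raises ValueError if the key is not set (see create_openai_model below).
openai_model = create_openai_model(model_id="gpt-4o-mini")
```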
@@ -0,0 +1,580 @@
+ from __future__ import annotations
+
+ import asyncio
+ import os
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor
+ from types import SimpleNamespace
+ from typing import Any, Callable, Optional
+
+ import boto3
+ import ipywidgets as widgets
+ import leafmap.maplibregl as leafmap
+ from botocore.config import Config as BotocoreConfig
+ from ipyevents import Event
+ from IPython.display import display
+ from strands import Agent
+ from strands.models import BedrockModel
+ from strands.models.anthropic import AnthropicModel
+ from strands.models.ollama import OllamaModel as _OllamaModel
+ from strands.models.openai import OpenAIModel
+
+ from .map_tools import MapSession, MapTools
+
+ try:
+     import nest_asyncio
+
+     nest_asyncio.apply()
+ except Exception:
+     pass
+
+
+ class OllamaModel(_OllamaModel):
+     """Fixed OllamaModel that ensures proper model_id handling."""
+
+     async def stream(self, *args, **kwargs):
+         """Override stream to ensure model_id is passed as string."""
+         # Patch the ollama client to handle model object correctly
+         import ollama
+
+         # Save original method if not already saved
+         if not hasattr(ollama.AsyncClient, "_original_chat"):
+             ollama.AsyncClient._original_chat = ollama.AsyncClient.chat
+
+         async def fixed_chat(self, **chat_kwargs):
+             # If model is an OllamaModel object, extract the model_id
+             if "model" in chat_kwargs and hasattr(chat_kwargs["model"], "config"):
+                 chat_kwargs["model"] = chat_kwargs["model"].config["model_id"]
+             return await ollama.AsyncClient._original_chat(self, **chat_kwargs)
+
+         ollama.AsyncClient.chat = fixed_chat
+
+         # Call the original stream method
+         async for chunk in super().stream(*args, **kwargs):
+             yield chunk
+
+
+ def create_ollama_model(
+     host: str = "http://localhost:11434",
+     model_id: str = "llama3.1",
+     client_args: dict = None,
+     **kwargs: Any,
+ ) -> OllamaModel:
+     """Create an Ollama model.
+
+     Args:
+         host: Ollama host URL.
+         model_id: Ollama model ID.
+         client_args: Client arguments for the Ollama model.
+         **kwargs: Additional keyword arguments for the Ollama model.
+
+     Returns:
+         OllamaModel: An Ollama model.
+     """
+     if client_args is None:
+         client_args = {}
+     return OllamaModel(host=host, model_id=model_id, client_args=client_args, **kwargs)
+
+
+ def create_openai_model(
+     model_id: str = "gpt-4o-mini",
+     api_key: str = None,
+     client_args: dict = None,
+     **kwargs: Any,
+ ) -> OpenAIModel:
+     """Create an OpenAI model.
+
+     Args:
+         model_id: OpenAI model ID.
+         api_key: OpenAI API key.
+         client_args: Client arguments for the OpenAI model.
+         **kwargs: Additional keyword arguments for the OpenAI model.
+
+     Returns:
+         OpenAIModel: An OpenAI model.
+     """
+
+     if api_key is None:
+         try:
+             api_key = os.getenv("OPENAI_API_KEY", None)
+             if api_key is None:
+                 raise ValueError("OPENAI_API_KEY is not set")
+         except Exception:
+             raise ValueError("OPENAI_API_KEY is not set")
+
+     if client_args is None:
+         client_args = kwargs.get("client_args", {})
+     if "api_key" not in client_args and api_key is not None:
+         client_args["api_key"] = api_key
+
+     return OpenAIModel(client_args=client_args, model_id=model_id, **kwargs)
+
+
+ def create_anthropic_model(
+     model_id: str = "claude-sonnet-4-20250514",
+     api_key: str = None,
+     client_args: dict = None,
+     **kwargs: Any,
+ ) -> AnthropicModel:
+     """Create an Anthropic model.
+
+     Args:
+         model_id: Anthropic model ID. Defaults to "claude-sonnet-4-20250514".
+             For a complete list of supported models,
+             see https://docs.claude.com/en/docs/about-claude/models/overview.
+         api_key: Anthropic API key.
+         client_args: Client arguments for the Anthropic model.
+         **kwargs: Additional keyword arguments for the Anthropic model.
+     """
+
+     if api_key is None:
+         try:
+             api_key = os.getenv("ANTHROPIC_API_KEY", None)
+             if api_key is None:
+                 raise ValueError("ANTHROPIC_API_KEY is not set")
+         except Exception:
+             raise ValueError("ANTHROPIC_API_KEY is not set")
+
+     if client_args is None:
+         client_args = kwargs.get("client_args", {})
+     if "api_key" not in client_args and api_key is not None:
+         client_args["api_key"] = api_key
+
+     return AnthropicModel(client_args=client_args, model_id=model_id, **kwargs)
+
+
+ def create_bedrock_model(
+     model_id: str = "anthropic.claude-sonnet-4-20250514-v1:0",
+     region_name: str = None,
+     boto_session: Optional[boto3.Session] = None,
+     boto_client_config: Optional[BotocoreConfig] = None,
+     **kwargs: Any,
+ ) -> BedrockModel:
+     """Create a Bedrock model.
+
+     Args:
+         model_id: Bedrock model ID. Run the following command to get the model ID:
+             aws bedrock list-foundation-models | jq -r '.modelSummaries[].modelId'
+         region_name: Bedrock region name.
+         boto_session: Bedrock boto session.
+         boto_client_config: Bedrock boto client config.
+         **kwargs: Additional keyword arguments for the Bedrock model.
+     """
+
+     return BedrockModel(
+         model_id=model_id,
+         region_name=region_name,
+         boto_session=boto_session,
+         boto_client_config=boto_client_config,
+         **kwargs,
+     )
+
+
+ def _ensure_loop() -> asyncio.AbstractEventLoop:
+     try:
+         loop = asyncio.get_event_loop()
+     except RuntimeError:
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+     if loop.is_closed():
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+     return loop
+
+
+ class GeoAgent(Agent):
+     """Geospatial AI agent with interactive mapping capabilities."""
+
+     def __init__(
+         self,
+         *,
+         model: str = "llama3.1",
+         map_instance: Optional[leafmap.Map] = None,
+         **kwargs: Any,
+     ) -> None:
+         """Initialize the GeoAgent.
+
+         Args:
+             model: Model identifier (default: "llama3.1").
+             map_instance: Optional existing map instance.
+             **kwargs: Additional keyword arguments for the model.
+         """
+         self.session: MapSession = MapSession(map_instance)
+         self.tools: MapTools = MapTools(self.session)
+
+         # --- save a model factory we can call each turn ---
+         if model == "llama3.1":
+             self._model_factory: Callable[[], OllamaModel] = (
+                 lambda: create_ollama_model(
+                     host="http://localhost:11434", model_id=model, **kwargs
+                 )
+             )
+         elif isinstance(model, str):
+             self._model_factory: Callable[[], BedrockModel] = (
+                 lambda: create_bedrock_model(model_id=model, **kwargs)
+             )
+         elif isinstance(model, OllamaModel):
+             # Extract configuration from existing OllamaModel and create new instances
+             model_id = model.config["model_id"]
+             host = model.host
+             client_args = model.client_args
+             self._model_factory: Callable[[], OllamaModel] = (
+                 lambda: create_ollama_model(
+                     host=host, model_id=model_id, client_args=client_args, **kwargs
+                 )
+             )
+         elif isinstance(model, OpenAIModel):
+             # Extract configuration from existing OpenAIModel and create new instances
+             model_id = model.config["model_id"]
+             client_args = model.client_args.copy()
+             self._model_factory: Callable[[], OpenAIModel] = (
+                 lambda mid=model_id, client_args=client_args: create_openai_model(
+                     model_id=mid, client_args=client_args, **kwargs
+                 )
+             )
+         elif isinstance(model, AnthropicModel):
+             # Extract configuration from existing AnthropicModel and create new instances
+             model_id = model.config["model_id"]
+             client_args = model.client_args.copy()
+             self._model_factory: Callable[[], AnthropicModel] = (
+                 lambda mid=model_id, client_args=client_args: create_anthropic_model(
+                     model_id=mid, client_args=client_args, **kwargs
+                 )
+             )
+         else:
+             raise ValueError(f"Invalid model: {model}")
+
+         # build initial model (first turn)
+         model = self._model_factory()
+
+         super().__init__(
+             name="Leafmap Visualization Agent",
+             model=model,
+             tools=[
+                 # Core navigation tools
+                 self.tools.fly_to,
+                 self.tools.create_map,
+                 self.tools.zoom_to,
+                 self.tools.jump_to,
+                 # Essential layer tools
+                 self.tools.add_basemap,
+                 self.tools.add_vector,
+                 self.tools.add_raster,
+                 self.tools.add_cog_layer,
+                 self.tools.remove_layer,
+                 self.tools.get_layer_names,
+                 self.tools.set_terrain,
+                 self.tools.remove_terrain,
+                 self.tools.add_overture_3d_buildings,
+                 self.tools.set_paint_property,
+                 self.tools.set_layout_property,
+                 self.tools.set_color,
+                 self.tools.set_opacity,
+                 self.tools.set_visibility,
+                 # self.tools.save_map,
+                 # Basic interaction tools
+                 self.tools.add_marker,
+                 self.tools.set_pitch,
+             ],
+             system_prompt="You are a map control agent. Call tools with MINIMAL parameters only.\n\n"
+             + "CRITICAL: Treat all kwargs parameters as optional parameters.\n"
+             + "CRITICAL: NEVER include optional parameters unless user explicitly asks for them.\n\n"
+             + "TOOL CALL RULES:\n"
+             + "- zoom_to(zoom=N) - ONLY zoom parameter, OMIT options completely\n"
+             + "- add_cog_layer(url='X') - NEVER include bands, nodata, opacity, etc.\n"
+             + "- fly_to(longitude=N, latitude=N) - NEVER include zoom parameter\n"
+             + "- add_basemap(name='X') - NEVER include any other parameters\n"
+             + "- add_marker(lng_lat=[lon,lat]) - NEVER include popup or options\n\n"
+             + "- remove_layer(name='X') - call get_layer_names() to get the layer name closest to"
+             + "the name of the layer you want to remove before calling this tool\n\n"
+             + "- add_overture_3d_buildings(kwargs={}) - kwargs parameter required by tool validation\n"
+             + "FORBIDDEN: Optional parameters, string representations like '{}' or '[1,2,3]'\n"
+             + "REQUIRED: Minimal tool calls with only what's absolutely necessary",
+             callback_handler=None,
+         )
+
+     def ask(self, prompt: str) -> str:
+         """Send a single-turn prompt to the agent.
+
+         Runs entirely on the same thread/event loop as the Agent
+         to avoid cross-loop asyncio object issues.
+
+         Args:
+             prompt: The text prompt to send to the agent.
+
+         Returns:
+             The agent's response as a string.
+         """
+         # Ensure there's an event loop bound to this thread (Jupyter-safe)
+         loop = _ensure_loop()
+
+         # Preserve existing conversation messages
+         existing_messages = self.messages.copy()
+
+         # Create a fresh model but keep conversation history
+         self.model = self._model_factory()
+
+         # Restore the conversation messages
+         self.messages = existing_messages
+
+         # Execute the prompt using the Agent's async API on this loop
+         # Avoid Agent.__call__ since it spins a new thread+loop
+         result = loop.run_until_complete(self.invoke_async(prompt))
+         return getattr(result, "final_text", str(result))
+
+     def show_ui(self, *, height: int = 700) -> None:
+         """Display an interactive UI with map and chat interface.
+
+         Args:
+             height: Height of the UI in pixels (default: 700).
+         """
+
+         m = self.tools.session.m
+         if not hasattr(m, "container") or m.container is None:
+             m.create_container()
+
+         map_panel = widgets.VBox(
+             [widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"), m.container],
+             layout=widgets.Layout(
+                 flex="2 1 0%",
+                 min_width="520px",
+                 border="1px solid #ddd",
+                 padding="8px",
+                 height=f"{height}px",
+                 overflow="hidden",
+             ),
+         )
+
+         # ----- chat widgets -----
+         session_id = str(uuid.uuid4())[:8]
+         title = widgets.HTML(
+             f"<h3 style='margin:0'>Chatbot</h3>"
+             f"<p style='margin:4px 0 8px;color:#666'>Session: {session_id}</p>"
+         )
+         log = widgets.HTML(
+             value="<div style='color:#777'>No messages yet.</div>",
+             layout=widgets.Layout(
+                 border="1px solid #ddd",
+                 padding="8px",
+                 height="520px",
+                 overflow_y="auto",
+             ),
+         )
+         inp = widgets.Textarea(
+             placeholder="Ask to add/remove/list layers, set basemap, save the map, etc.",
+             layout=widgets.Layout(width="99%", height="90px"),
+         )
+         btn_send = widgets.Button(
+             description="Send",
+             button_style="primary",
+             icon="paper-plane",
+             layout=widgets.Layout(width="120px"),
+         )
+         btn_stop = widgets.Button(
+             description="Stop", icon="stop", layout=widgets.Layout(width="120px")
+         )
+         btn_clear = widgets.Button(
+             description="Clear", icon="trash", layout=widgets.Layout(width="120px")
+         )
+         status = widgets.HTML("<span style='color:#666'>Ready.</span>")
+
+         examples = widgets.Dropdown(
+             options=[
+                 ("— Examples —", ""),
+                 ("Fly to", "Fly to Chicago"),
+                 ("Add basemap", "Add basemap OpenTopoMap"),
+                 (
+                     "Add COG layer",
+                     "Add COG layer https://huggingface.co/datasets/giswqs/geospatial/resolve/main/naip_rgb_train.tif",
+                 ),
+                 (
+                     "Add GeoJSON",
+                     "Add vector layer: https://github.com/opengeos/datasets/releases/download/us/us_states.geojson",
+                 ),
+                 ("Remove layer", "Remove layer OpenTopoMap"),
+                 ("Save map", "Save the map as demo.html and return the path"),
+             ],
+             value="",
+             layout=widgets.Layout(width="auto"),
+         )
+
+         # --- state kept on self so it persists ---
+         self._ui = SimpleNamespace(
+             messages=[],
+             map_panel=map_panel,
+             title=title,
+             log=log,
+             inp=inp,
+             btn_send=btn_send,
+             btn_stop=btn_stop,
+             btn_clear=btn_clear,
+             status=status,
+             examples=examples,
+         )
+         self._pending = {"fut": None}
+         self._executor = ThreadPoolExecutor(max_workers=1)
+
+         def _esc(s: str) -> str:
+             """Escape HTML characters in a string.
+
+             Args:
+                 s: Input string to escape.
+
+             Returns:
+                 HTML-escaped string.
+             """
+             return (
+                 s.replace("&", "&amp;")
+                 .replace("<", "&lt;")
+                 .replace(">", "&gt;")
+                 .replace("\n", "<br/>")
+             )
+
+         def _append(role: str, msg: str) -> None:
+             """Append a message to the chat log.
+
+             Args:
+                 role: Role of the message sender ("user" or "assistant").
+                 msg: Message content.
+             """
+             self._ui.messages.append((role, msg))
+             parts = []
+             for r, mm in self._ui.messages:
+                 if r == "user":
+                     parts.append(
+                         f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#eef;'><b>You</b>: {_esc(mm)}</div>"
+                     )
+                 else:
+                     parts.append(
+                         f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#f7f7f7;'><b>Agent</b>: {_esc(mm)}</div>"
+                     )
+             self._ui.log.value = (
+                 "<div>"
+                 + (
+                     "".join(parts)
+                     if parts
+                     else "<div style='color:#777'>No messages yet.</div>"
+                 )
+                 + "</div>"
+             )
+
+         def _lock(lock: bool) -> None:
+             """Lock or unlock UI controls.
+
+             Args:
+                 lock: True to lock controls, False to unlock.
+             """
+             self._ui.btn_send.disabled = lock
+             self._ui.btn_stop.disabled = not lock
+             self._ui.btn_clear.disabled = lock
+             self._ui.inp.disabled = lock
+             self._ui.examples.disabled = lock
+
+         def _on_send(_: Any = None) -> None:
+             """Handle send button click or Enter key press."""
+             text = self._ui.inp.value.strip()
+             if not text:
+                 return
+             _append("user", text)
+             _lock(True)
+             self._ui.status.value = "<span style='color:#0a7'>Running…</span>"
+             try:
+                 out = self.ask(text)  # fresh Agent/model per call, silent
+                 _append("assistant", out)
+                 self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
+             except Exception as e:
+                 _append("assistant", f"[error] {type(e).__name__}: {e}")
+                 self._ui.status.value = (
+                     "<span style='color:#c00'>Finished with an issue.</span>"
+                 )
+             finally:
+                 self._ui.inp.value = ""
+                 _lock(False)
+
+         def _on_stop(_: Any = None) -> None:
+             """Handle stop button click."""
+             fut = self._pending.get("fut")
+             if fut and not fut.done():
+                 self._pending["fut"] = None
+                 self._ui.status.value = "<span style='color:#c00'>Stop requested. If it finishes, result will be ignored.</span>"
+             _lock(False)
+
+         def _on_clear(_: Any = None) -> None:
+             """Handle clear button click."""
+             self._ui.messages.clear()
+             self._ui.log.value = "<div style='color:#777'>No messages yet.</div>"
+             self._ui.status.value = "<span style='color:#666'>Cleared.</span>"
+
+         def _on_example_change(change: dict[str, Any]) -> None:
+             """Handle example dropdown selection change.
+
+             Args:
+                 change: Change event dictionary from the dropdown widget.
+             """
+             if change["name"] == "value" and change["new"]:
+                 self._ui.inp.value = change["new"]
+                 self._ui.examples.value = ""
+                 self._ui.inp.send({"method": "focus"})
+
+         # keep handler refs
+         self._handlers = SimpleNamespace(
+             on_send=_on_send,
+             on_stop=_on_stop,
+             on_clear=_on_clear,
+             on_example_change=_on_example_change,
+         )
+
+         # wire events
+         self._ui.btn_send.on_click(self._handlers.on_send)
+         self._ui.btn_stop.on_click(self._handlers.on_stop)
+         self._ui.btn_clear.on_click(self._handlers.on_clear)
+         self._ui.examples.observe(self._handlers.on_example_change, names="value")
+
+         # Ctrl+Enter on textarea (keyup only; do not block defaults)
+         self._keyev = Event(
+             source=self._ui.inp, watched_events=["keyup"], prevent_default_action=False
+         )
+
+         def _on_key(e: dict[str, Any]) -> None:
+             """Handle keyboard events on the input textarea.
+
+             Args:
+                 e: Keyboard event dictionary.
+             """
+             if (
+                 e.get("type") == "keyup"
+                 and e.get("key") == "Enter"
+                 and e.get("ctrlKey")
+             ):
+                 if self._ui.inp.value.endswith("\n"):
+                     self._ui.inp.value = self._ui.inp.value[:-1]
+                 self._handlers.on_send()
+
+         # store callback too
+         self._on_key_cb: Callable[[dict[str, Any]], None] = _on_key
+         self._keyev.on_dom_event(self._on_key_cb)
+
+         buttons = widgets.HBox(
+             [
+                 self._ui.btn_send,
+                 self._ui.btn_stop,
+                 self._ui.btn_clear,
+                 widgets.Box(
+                     [self._ui.examples], layout=widgets.Layout(margin="0 0 0 auto")
+                 ),
+             ]
+         )
+         right = widgets.VBox(
+             [
+                 self._ui.title if hasattr(self._ui, "title") else title,
+                 self._ui.log,
+                 self._ui.inp,
+                 buttons,
+                 self._ui.status,
+             ],
+             layout=widgets.Layout(flex="1 1 0%", min_width="360px"),
+         )
+         root = widgets.HBox(
+             [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
+         )
+         display(root)
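Taken together, GeoAgent wires the MapTools tool set into a strands Agent and rebuilds its model from the saved factory on every turn. A hedged usage sketch based on the signatures above; the geoai.agents import path is an assumption, everything else follows the code in this diff:

```python
# Sketch only: assumes the new subpackage is importable as geoai.agents.
from geoai.agents import GeoAgent, create_anthropic_model

agent = GeoAgent()  # defaults to the local Ollama "llama3.1" model
agent.show_ui(height=700)  # side-by-side map and chat panels in a Jupyter notebook

# Single-turn prompt; ask() returns the agent's reply as a string.
reply = agent.ask("Fly to Chicago and add basemap OpenTopoMap")

# Alternatively, pass a pre-configured model instance (keyword-only argument).
agent = GeoAgent(model=create_anthropic_model(model_id="claude-sonnet-4-20250514"))
```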