geoai-py 0.15.0__py2.py3-none-any.whl → 0.17.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +1 -1
- geoai/agents/__init__.py +4 -0
- geoai/agents/catalog_models.py +51 -0
- geoai/agents/catalog_tools.py +907 -0
- geoai/agents/geo_agents.py +925 -41
- geoai/agents/stac_models.py +67 -0
- geoai/agents/stac_tools.py +435 -0
- geoai/change_detection.py +16 -6
- geoai/download.py +5 -1
- geoai/geoai.py +3 -0
- geoai/train.py +573 -31
- geoai/utils.py +752 -208
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/METADATA +2 -1
- geoai_py-0.17.0.dist-info/RECORD +30 -0
- geoai_py-0.15.0.dist-info/RECORD +0 -26
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/WHEEL +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.15.0.dist-info → geoai_py-0.17.0.dist-info}/top_level.txt +0 -0
geoai/agents/geo_agents.py
CHANGED
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
-import asyncio
+import json
 import os
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from types import SimpleNamespace
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import boto3
 import ipywidgets as widgets
@@ -19,14 +18,45 @@ from strands.models.anthropic import AnthropicModel
 from strands.models.ollama import OllamaModel as _OllamaModel
 from strands.models.openai import OpenAIModel
 
+from .catalog_tools import CatalogTools
 from .map_tools import MapSession, MapTools
+from .stac_tools import STACTools
 
-try:
-    import nest_asyncio
 
-
-
-
+class UICallbackHandler:
+    """Callback handler that updates UI status widget with agent progress.
+
+    This handler intercepts tool calls and progress events to provide
+    real-time feedback without overwhelming the user.
+    """
+
+    def __init__(self, status_widget=None):
+        """Initialize the callback handler.
+
+        Args:
+            status_widget: Optional ipywidgets.HTML widget to update with status.
+        """
+        self.status_widget = status_widget
+        self.current_tool = None
+
+    def __call__(self, **kwargs):
+        """Handle callback events from the agent.
+
+        Args:
+            **kwargs: Event data from the agent execution.
+        """
+        # Track when tools are being called
+        if "current_tool_use" in kwargs and kwargs["current_tool_use"].get("name"):
+            tool_name = kwargs["current_tool_use"]["name"]
+            self.current_tool = tool_name
+
+            # Update status widget if available
+            if self.status_widget is not None:
+                # Make tool names more user-friendly
+                friendly_name = tool_name.replace("_", " ").title()
+                self.status_widget.value = (
+                    f"<span style='color:#0a7'>⚙️ {friendly_name}...</span>"
+                )
 
 
 class OllamaModel(_OllamaModel):
@@ -170,18 +200,6 @@ def create_bedrock_model(
     )
 
 
-def _ensure_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    if loop.is_closed():
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop
-
-
 class GeoAgent(Agent):
     """Geospatial AI agent with interactive mapping capabilities."""
 
@@ -311,30 +329,14 @@ class GeoAgent(Agent):
     def ask(self, prompt: str) -> str:
         """Send a single-turn prompt to the agent.
 
-        Runs entirely on the same thread/event loop as the Agent
-        to avoid cross-loop asyncio object issues.
-
         Args:
             prompt: The text prompt to send to the agent.
 
         Returns:
             The agent's response as a string.
         """
-        #
-
-
-        # Preserve existing conversation messages
-        existing_messages = self.messages.copy()
-
-        # Create a fresh model but keep conversation history
-        self.model = self._model_factory()
-
-        # Restore the conversation messages
-        self.messages = existing_messages
-
-        # Execute the prompt using the Agent's async API on this loop
-        # Avoid Agent.__call__ since it spins a new thread+loop
-        result = loop.run_until_complete(self.invoke_async(prompt))
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
         return getattr(result, "final_text", str(result))
 
     def show_ui(self, *, height: int = 700) -> None:
@@ -349,7 +351,10 @@ class GeoAgent(Agent):
             m.create_container()
 
         map_panel = widgets.VBox(
-            [
+            [
+                widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
+                m.floating_sidebar_widget,
+            ],
             layout=widgets.Layout(
                 flex="2 1 0%",
                 min_width="520px",
@@ -427,7 +432,6 @@ class GeoAgent(Agent):
             examples=examples,
         )
         self._pending = {"fut": None}
-        self._executor = ThreadPoolExecutor(max_workers=1)
 
         def _esc(s: str) -> str:
             """Escape HTML characters in a string.
@@ -494,7 +498,14 @@ class GeoAgent(Agent):
             _lock(True)
             self._ui.status.value = "<span style='color:#0a7'>Running…</span>"
             try:
-
+                # Create a callback handler that updates the status widget
+                callback_handler = UICallbackHandler(status_widget=self._ui.status)
+
+                # Temporarily set callback_handler for this call
+                old_callback = self.callback_handler
+                self.callback_handler = callback_handler
+
+                out = self.ask(text)  # fresh Agent/model per call, with callback
                 _append("assistant", out)
                 self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
             except Exception as e:
@@ -503,6 +514,8 @@ class GeoAgent(Agent):
                     "<span style='color:#c00'>Finished with an issue.</span>"
                 )
             finally:
+                # Restore old callback handler
+                self.callback_handler = old_callback
                 self._ui.inp.value = ""
                 _lock(False)
 
@@ -593,3 +606,874 @@ class GeoAgent(Agent):
             [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
         )
         display(root)
+
+
+class STACAgent(Agent):
+    """AI agent for searching and interacting with STAC catalogs."""
+
+    def __init__(
+        self,
+        *,
+        model: str = "llama3.1",
+        system_prompt: str = "default",
+        endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
+        model_args: dict = None,
+        map_instance: Optional[leafmap.Map] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the STAC Agent.
+
+        Args:
+            model: Model identifier (default: "llama3.1").
+            system_prompt: System prompt for the agent (default: "default").
+            endpoint: STAC API endpoint URL. Defaults to Microsoft Planetary Computer.
+            model_args: Additional keyword arguments for the model.
+            map_instance: Optional leafmap.Map instance for visualization. If None, creates a new one.
+            **kwargs: Additional keyword arguments for the Agent.
+        """
+        self.stac_tools: STACTools = STACTools(endpoint=endpoint)
+        self.map_instance = map_instance if map_instance is not None else leafmap.Map()
+
+        if model_args is None:
+            model_args = {}
+
+        # --- save a model factory we can call each turn ---
+        if model == "llama3.1":
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host="http://localhost:11434", model_id=model, **model_args
+                )
+            )
+        elif isinstance(model, str):
+            self._model_factory: Callable[[], BedrockModel] = (
+                lambda: create_bedrock_model(model_id=model, **model_args)
+            )
+        elif isinstance(model, OllamaModel):
+            # Extract configuration from existing OllamaModel and create new instances
+            model_id = model.config["model_id"]
+            host = model.host
+            client_args = model.client_args
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host=host, model_id=model_id, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, OpenAIModel):
+            # Extract configuration from existing OpenAIModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], OpenAIModel] = (
+                lambda mid=model_id, client_args=client_args: create_openai_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, AnthropicModel):
+            # Extract configuration from existing AnthropicModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], AnthropicModel] = (
+                lambda mid=model_id, client_args=client_args: create_anthropic_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        else:
+            raise ValueError(f"Invalid model: {model}")
+
+        # build initial model (first turn)
+        model = self._model_factory()
+
+        if system_prompt == "default":
+            system_prompt = """You are a STAC search agent. Follow these steps EXACTLY:
+
+1. Determine collection ID based on data type:
+   - "sentinel-2-l2a" for Sentinel-2 or optical satellite imagery
+   - "landsat-c2-l2" for Landsat
+   - "naip" for NAIP or aerial imagery (USA only)
+   - "sentinel-1-grd" for Sentinel-1 or SAR/radar
+   - "cop-dem-glo-30" for DEM, elevation, or terrain data
+   - "aster-l1t" for ASTER
+   For other data (e.g., MODIS, land cover): call list_collections(filter_keyword="<keyword>")
+
+2. If location mentioned:
+   - Call geocode_location("<name>") FIRST
+   - WAIT for the response
+   - Extract the "bbox" array from the JSON response
+   - This bbox is [west, south, east, north] format
+
+3. Call search_items():
+   - collection: REQUIRED
+   - bbox: Use the EXACT bbox array from geocode_location (REQUIRED if location mentioned)
+   - time_range: "YYYY-MM-DD/YYYY-MM-DD" format if dates mentioned
+   - query: Use for cloud cover filtering (see examples)
+   - max_items: 1
+
+Cloud cover filtering:
+- "<10% cloud": query={"eo:cloud_cover": {"lt": 10}}
+- "<20% cloud": query={"eo:cloud_cover": {"lt": 20}}
+- "<5% cloud": query={"eo:cloud_cover": {"lt": 5}}
+
+Examples:
+1. "Find Landsat over Paris from June to July 2023"
+   geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
+   search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], time_range="2023-06-01/2023-07-31")
+
+2. "Find Landsat with <10% cloud cover over Paris"
+   geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
+   search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], query={"eo:cloud_cover": {"lt": 10}})
+
+4. Return first item as JSON:
+   {"id": "...", "collection": "...", "datetime": "...", "bbox": [...], "assets": [...]}
+
+ERROR HANDLING:
+- If no items found: {"error": "No items found"}
+- If tool result too large: {"error": "Result too large, try narrower search"}
+- If tool error: {"error": "Search failed: <error message>"}
+
+CRITICAL: Return ONLY JSON. NO explanatory text, NO made-up data."""
+
+        super().__init__(
+            name="STAC Search Agent",
+            model=model,
+            tools=[
+                self.stac_tools.list_collections,
+                self.stac_tools.search_items,
+                self.stac_tools.get_item_info,
+                self.stac_tools.geocode_location,
+                self.stac_tools.get_common_collections,
+            ],
+            system_prompt=system_prompt,
+            callback_handler=None,
+        )
+
+    def _extract_search_items_payload(self, result: Any) -> Optional[Dict[str, Any]]:
+        """Return the parsed payload from the search_items tool, if available."""
+        # Try to get tool_results from the result object
+        tool_results = getattr(result, "tool_results", None)
+        if tool_results:
+            for tool_result in tool_results:
+                if getattr(tool_result, "tool_name", "") != "search_items":
+                    continue
+
+                payload = getattr(tool_result, "result", None)
+                if payload is None:
+                    continue
+
+                if isinstance(payload, str):
+                    try:
+                        payload = json.loads(payload)
+                    except json.JSONDecodeError:
+                        continue
+
+                if isinstance(payload, dict):
+                    return payload
+
+        # Alternative: check messages for tool results
+        messages = getattr(self, "messages", [])
+        for msg in reversed(messages):  # Check recent messages first
+            # Handle dict-style messages
+            if isinstance(msg, dict):
+                role = msg.get("role")
+                content = msg.get("content", [])
+
+                # Look for tool results in user messages (strands pattern)
+                if role == "user" and isinstance(content, list):
+                    for item in content:
+                        if isinstance(item, dict) and "toolResult" in item:
+                            tool_result = item["toolResult"]
+                            # Check if this is a search_items result
+                            # We need to look at the preceding assistant message to identify the tool
+                            if tool_result.get("status") == "success":
+                                result_content = tool_result.get("content", [])
+                                if isinstance(result_content, list) and result_content:
+                                    text_content = result_content[0].get("text", "")
+                                    try:
+                                        payload = json.loads(text_content)
+                                        # Return ANY search_items payload, even if items is empty
+                                        # This is identified by having "query" and "collection" fields
+                                        if (
+                                            "query" in payload
+                                            and "collection" in payload
+                                            and "items" in payload
+                                        ):
+                                            return payload
+                                    except json.JSONDecodeError:
+                                        continue
+
+        return None
+
+    def ask(self, prompt: str) -> str:
+        """Send a single-turn prompt to the agent.
+
+        Args:
+            prompt: The text prompt to send to the agent.
+
+        Returns:
+            The agent's response as a string (JSON format for search queries).
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+
+        search_payload = self._extract_search_items_payload(result)
+        if search_payload is not None:
+            if "error" in search_payload:
+                return json.dumps({"error": search_payload["error"]}, indent=2)
+
+            items = search_payload.get("items") or []
+            if items:
+                return json.dumps(items[0], indent=2)
+
+            return json.dumps({"error": "No items found"}, indent=2)
+
+        return getattr(result, "final_text", str(result))
+
+    def search_and_get_first_item(self, prompt: str) -> Optional[Dict[str, Any]]:
+        """Search for imagery and return the first item as a structured dict.
+
+        This method sends a search query to the agent, extracts the search results
+        directly from the tool calls, and returns the first item as a STACItemInfo-compatible
+        dictionary.
+
+        Note: This method uses LLM inference which adds ~5-10 seconds overhead.
+        For faster searches, use STACTools directly:
+            >>> from geoai.agents import STACTools
+            >>> tools = STACTools()
+            >>> result = tools.search_items(
+            ...     collection="sentinel-2-l2a",
+            ...     bbox=[-122.5, 37.7, -122.4, 37.8],
+            ...     time_range="2024-08-01/2024-08-31"
+            ... )
+
+        Args:
+            prompt: Natural language search query (e.g., "Find Sentinel-2 imagery
+                over San Francisco in September 2024").
+
+        Returns:
+            Dictionary containing STACItemInfo fields (id, collection, datetime,
+            bbox, assets, properties), or None if no results found.
+
+        Example:
+            >>> agent = STACAgent()
+            >>> item = agent.search_and_get_first_item(
+            ...     "Find Sentinel-2 imagery over Paris in summer 2023"
+            ... )
+            >>> print(item['id'])
+            >>> print(item['assets'][0]['key'])  # or 'title'
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+
+        search_payload = self._extract_search_items_payload(result)
+        if search_payload is not None:
+            if "error" in search_payload:
+                print(f"Search error: {search_payload['error']}")
+                return None
+
+            items = search_payload.get("items") or []
+            if items:
+                return items[0]
+
+            print("No items found in search results")
+            return None
+
+        # Fallback: try to parse the final text response
+        response = getattr(result, "final_text", str(result))
+
+        try:
+            item_data = json.loads(response)
+
+            if "error" in item_data:
+                print(f"Search error: {item_data['error']}")
+                return None
+
+            if not all(k in item_data for k in ["id", "collection"]):
+                print("Response missing required fields (id, collection)")
+                return None
+
+            return item_data
+
+        except json.JSONDecodeError:
+            print("Could not extract item data from agent response")
+            return None
+
+    def _visualize_stac_item(self, item: Dict[str, Any]) -> None:
+        """Visualize a STAC item on the map.
+
+        Args:
+            item: STAC item dictionary with id, collection, assets, etc.
+        """
+        if not item or "id" not in item or "collection" not in item:
+            return
+
+        # Get the collection and item ID
+        collection = item.get("collection")
+        item_id = item.get("id")
+
+        kwargs = {}
+
+        # Determine which assets to display based on collection
+        assets = None
+        if collection == "sentinel-2-l2a":
+            assets = ["B04", "B03", "B02"]  # True color RGB
+        elif collection == "landsat-c2-l2":
+            assets = ["red", "green", "blue"]  # Landsat RGB
+        elif collection == "naip":
+            assets = ["image"]  # NAIP 4-band imagery
+        elif "sentinel-1" in collection:
+            assets = ["vv"]  # Sentinel-1 VV polarization
+        elif collection == "cop-dem-glo-30":
+            assets = ["data"]
+            kwargs["colormap_name"] = "terrain"
+        elif collection == "aster-l1t":
+            assets = ["VNIR"]  # ASTER L1T imagery
+        elif collection == "3dep-lidar-hag":
+            assets = ["data"]
+            kwargs["colormap_name"] = "terrain"
+        else:
+            # Try to find common asset names
+            if "assets" in item:
+                asset_keys = [
+                    a.get("key") for a in item["assets"] if isinstance(a, dict)
+                ]
+                # Look for visual or RGB assets
+                for possible in ["visual", "image", "data"]:
+                    if possible in asset_keys:
+                        assets = [possible]
+                        break
+                # If still no assets, use first few assets
+                if not assets and asset_keys:
+                    assets = asset_keys[:1]
+
+        if not assets:
+            return None
+        try:
+            # Add the STAC layer to the map
+            layer_name = f"{collection[:20]}_{item_id[:15]}"
+            self.map_instance.add_stac_layer(
+                collection=collection,
+                item=item_id,
+                assets=assets,
+                name=layer_name,
+                before_id=self.map_instance.first_symbol_layer_id,
+                **kwargs,
+            )
+            return assets  # Return the assets that were visualized
+        except Exception as e:
+            print(f"Could not visualize item on map: {e}")
+            return None
+
+    def show_ui(self, *, height: int = 700) -> None:
+        """Display an interactive UI with map and chat interface for STAC searches.
+
+        Args:
+            height: Height of the UI in pixels (default: 700).
+        """
+        m = self.map_instance
+        if not hasattr(m, "container") or m.container is None:
+            m.create_container()
+
+        map_panel = widgets.VBox(
+            [
+                widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
+                m.floating_sidebar_widget,
+            ],
+            layout=widgets.Layout(
+                flex="2 1 0%",
+                min_width="520px",
+                border="1px solid #ddd",
+                padding="8px",
+                height=f"{height}px",
+                overflow="hidden",
+            ),
+        )
+
+        # ----- chat widgets -----
+        session_id = str(uuid.uuid4())[:8]
+        title = widgets.HTML(
+            f"<h3 style='margin:0'>STAC Search Agent</h3>"
+            f"<p style='margin:4px 0 8px;color:#666'>Session: {session_id}</p>"
+        )
+        log = widgets.HTML(
+            value="<div style='color:#777'>No messages yet. Try searching for satellite imagery!</div>",
+            layout=widgets.Layout(
+                border="1px solid #ddd",
+                padding="8px",
+                height="520px",
+                overflow_y="auto",
+            ),
+        )
+        inp = widgets.Textarea(
+            placeholder="Search for satellite/aerial imagery (e.g., 'Find Sentinel-2 imagery over Paris in summer 2024')",
+            layout=widgets.Layout(width="99%", height="90px"),
+        )
+        btn_send = widgets.Button(
+            description="Search",
+            button_style="primary",
+            icon="search",
+            layout=widgets.Layout(width="120px"),
+        )
+        btn_stop = widgets.Button(
+            description="Stop", icon="stop", layout=widgets.Layout(width="120px")
+        )
+        btn_clear = widgets.Button(
+            description="Clear", icon="trash", layout=widgets.Layout(width="120px")
+        )
+        status = widgets.HTML("<span style='color:#666'>Ready to search.</span>")
+
+        examples = widgets.Dropdown(
+            options=[
+                ("— Example Searches —", ""),
+                (
+                    "Sentinel-2 over Las Vegas",
+                    "Find Sentinel-2 imagery over Las Vegas in August 2025",
+                ),
+                (
+                    "Landsat over Paris",
+                    "Find Landsat imagery over Paris from June to July 2025",
+                ),
+                (
+                    "Landsat with <10% cloud cover",
+                    "Find Landsat imagery over Paris with <10% cloud cover in June 2025",
+                ),
+                ("NAIP for NYC", "Show me NAIP aerial imagery for New York City"),
+                ("DEM for Seattle", "Show me DEM data for Seattle"),
+                (
+                    "3DEP Lidar HAG",
+                    "Show me data over Austin from 3dep-lidar-hag collection",
+                ),
+                ("ASTER for Tokyo", "Show me ASTER imagery for Tokyo"),
+            ],
+            value="",
+            layout=widgets.Layout(width="auto"),
+        )
+
+        # --- state kept on self so it persists ---
+        self._ui = SimpleNamespace(
+            messages=[],
+            map_panel=map_panel,
+            title=title,
+            log=log,
+            inp=inp,
+            btn_send=btn_send,
+            btn_stop=btn_stop,
+            btn_clear=btn_clear,
+            status=status,
+            examples=examples,
+        )
+        self._pending = {"fut": None}
+
+        def _esc(s: str) -> str:
+            """Escape HTML characters in a string."""
+            return (
+                s.replace("&", "&amp;")
+                .replace("<", "&lt;")
+                .replace(">", "&gt;")
+                .replace("\n", "<br/>")
+            )
+
+        def _append(role: str, msg: str) -> None:
+            """Append a message to the chat log."""
+            self._ui.messages.append((role, msg))
+            parts = []
+            for r, mm in self._ui.messages:
+                if r == "user":
+                    parts.append(
+                        f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#eef;'><b>You</b>: {_esc(mm)}</div>"
+                    )
+                else:
+                    parts.append(
+                        f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#f7f7f7;'><b>Agent</b>: {_esc(mm)}</div>"
+                    )
+            self._ui.log.value = (
+                "<div>"
+                + (
+                    "".join(parts)
+                    if parts
+                    else "<div style='color:#777'>No messages yet.</div>"
+                )
+                + "</div>"
+            )
+
+        def _lock(lock: bool) -> None:
+            """Lock or unlock UI controls."""
+            self._ui.btn_send.disabled = lock
+            self._ui.btn_stop.disabled = not lock
+            self._ui.btn_clear.disabled = lock
+            self._ui.inp.disabled = lock
+            self._ui.examples.disabled = lock
+
+        def _on_send(_: Any = None) -> None:
+            """Handle send button click or Enter key press."""
+            text = self._ui.inp.value.strip()
+            if not text:
+                return
+            _append("user", text)
+            _lock(True)
+            self._ui.status.value = "<span style='color:#0a7'>Searching…</span>"
+            try:
+                # Create a callback handler that updates the status widget
+                callback_handler = UICallbackHandler(status_widget=self._ui.status)
+
+                # Temporarily set callback_handler for this call
+                old_callback = self.callback_handler
+                self.callback_handler = callback_handler
+
+                # Get the structured search result directly (will show progress via callback)
+                item_data = self.search_and_get_first_item(text)
+
+                if item_data is not None:
+                    # Update status for visualization step
+                    self._ui.status.value = (
+                        "<span style='color:#0a7'>Adding to map...</span>"
+                    )
+
+                    # Visualize on map
+                    visualized_assets = self._visualize_stac_item(item_data)
+
+                    # Format response for display
+                    formatted_response = (
+                        f"Found item: {item_data['id']}\n"
+                        f"Collection: {item_data['collection']}\n"
+                        f"Date: {item_data.get('datetime', 'N/A')}\n"
+                    )
+
+                    if visualized_assets:
+                        assets_str = ", ".join(visualized_assets)
+                        formatted_response += f"✓ Added to map (assets: {assets_str})"
+                    else:
+                        formatted_response += "✓ Added to map"
+
+                    _append("assistant", formatted_response)
+                else:
+                    _append(
+                        "assistant",
+                        "No items found. Try adjusting your search query or date range.",
+                    )
+
+                self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
+            except Exception as e:
+                _append("assistant", f"[error] {type(e).__name__}: {e}")
+                self._ui.status.value = (
+                    "<span style='color:#c00'>Finished with an issue.</span>"
+                )
+            finally:
+                # Restore old callback handler
+                self.callback_handler = old_callback
+                self._ui.inp.value = ""
+                _lock(False)
+
+        def _on_stop(_: Any = None) -> None:
+            """Handle stop button click."""
+            fut = self._pending.get("fut")
+            if fut and not fut.done():
+                self._pending["fut"] = None
+                self._ui.status.value = (
+                    "<span style='color:#c00'>Stop requested.</span>"
+                )
+            _lock(False)
+
+        def _on_clear(_: Any = None) -> None:
+            """Handle clear button click."""
+            self._ui.messages.clear()
+            self._ui.log.value = "<div style='color:#777'>No messages yet.</div>"
+            self._ui.status.value = "<span style='color:#666'>Cleared.</span>"
+
+        def _on_example_change(change: dict[str, Any]) -> None:
+            """Handle example dropdown selection change."""
+            if change["name"] == "value" and change["new"]:
+                self._ui.inp.value = change["new"]
+                self._ui.examples.value = ""
+                self._ui.inp.send({"method": "focus"})
+
+        # keep handler refs
+        self._handlers = SimpleNamespace(
+            on_send=_on_send,
+            on_stop=_on_stop,
+            on_clear=_on_clear,
+            on_example_change=_on_example_change,
+        )
+
+        # wire events
+        self._ui.btn_send.on_click(self._handlers.on_send)
+        self._ui.btn_stop.on_click(self._handlers.on_stop)
+        self._ui.btn_clear.on_click(self._handlers.on_clear)
+        self._ui.examples.observe(self._handlers.on_example_change, names="value")
+
+        # Ctrl+Enter on textarea
+        self._keyev = Event(
+            source=self._ui.inp, watched_events=["keyup"], prevent_default_action=False
+        )
+
+        def _on_key(e: dict[str, Any]) -> None:
+            """Handle keyboard events on the input textarea."""
+            if (
+                e.get("type") == "keyup"
+                and e.get("key") == "Enter"
+                and e.get("ctrlKey")
+            ):
+                if self._ui.inp.value.endswith("\n"):
+                    self._ui.inp.value = self._ui.inp.value[:-1]
+                self._handlers.on_send()
+
+        # store callback too
+        self._on_key_cb: Callable[[dict[str, Any]], None] = _on_key
+        self._keyev.on_dom_event(self._on_key_cb)
+
+        buttons = widgets.HBox(
+            [
+                self._ui.btn_send,
+                self._ui.btn_stop,
+                self._ui.btn_clear,
+                widgets.Box(
+                    [self._ui.examples], layout=widgets.Layout(margin="0 0 0 auto")
+                ),
+            ]
+        )
+        right = widgets.VBox(
+            [
+                self._ui.title,
+                self._ui.log,
+                self._ui.inp,
+                buttons,
+                self._ui.status,
+            ],
+            layout=widgets.Layout(flex="1 1 0%", min_width="360px"),
+        )
+        root = widgets.HBox(
+            [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
+        )
+        display(root)
+
+
+class CatalogAgent(Agent):
+    """AI agent for searching data catalogs with natural language queries."""
+
+    def __init__(
+        self,
+        *,
+        model: str = "llama3.1",
+        system_prompt: str = "default",
+        catalog_url: Optional[str] = None,
+        catalog_df: Optional[Any] = None,
+        model_args: dict = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Catalog Agent.
+
+        Args:
+            model: Model identifier (default: "llama3.1").
+            system_prompt: System prompt for the agent (default: "default").
+            catalog_url: URL to a catalog file (TSV, CSV, or JSON). Use JSON format for spatial search support.
+                Example: "https://raw.githubusercontent.com/opengeos/Earth-Engine-Catalog/refs/heads/master/gee_catalog.json"
+            catalog_df: Pre-loaded catalog as a pandas DataFrame.
+            model_args: Additional keyword arguments for the model.
+            **kwargs: Additional keyword arguments for the Agent.
+        """
+        self.catalog_tools: CatalogTools = CatalogTools(
+            catalog_url=catalog_url, catalog_df=catalog_df
+        )
+
+        if model_args is None:
+            model_args = {}
+
+        # --- save a model factory we can call each turn ---
+        if model == "llama3.1":
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host="http://localhost:11434", model_id=model, **model_args
+                )
+            )
+        elif isinstance(model, str):
+            self._model_factory: Callable[[], BedrockModel] = (
+                lambda: create_bedrock_model(model_id=model, **model_args)
+            )
+        elif isinstance(model, OllamaModel):
+            # Extract configuration from existing OllamaModel and create new instances
+            model_id = model.config["model_id"]
+            host = model.host
+            client_args = model.client_args
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host=host, model_id=model_id, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, OpenAIModel):
+            # Extract configuration from existing OpenAIModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], OpenAIModel] = (
+                lambda mid=model_id, client_args=client_args: create_openai_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, AnthropicModel):
+            # Extract configuration from existing AnthropicModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], AnthropicModel] = (
+                lambda mid=model_id, client_args=client_args: create_anthropic_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        else:
+            raise ValueError(f"Invalid model: {model}")
+
+        # build initial model (first turn)
+        model = self._model_factory()
+
+        if system_prompt == "default":
+            system_prompt = """You are a data catalog search agent. Your job is to help users find datasets from a data catalog.
+
+IMPORTANT: Follow these steps EXACTLY:
+
+1. Understand the user's query:
+   - What type of data are they looking for? (e.g., landcover, elevation, imagery)
+   - Are they searching for a specific geographic region? (e.g., California, San Francisco, bounding box)
+   - Are they filtering by time period? (e.g., "from 2020", "between 2015-2020", "recent data")
+   - Are they filtering by provider? (e.g., NASA, USGS)
+   - Are they filtering by dataset type? (e.g., image, image_collection, table)
+
+2. Use the appropriate tool:
+   - search_by_region: PREFERRED for spatial queries - search datasets covering a geographic region
+     * Use location parameter for place names (e.g., "California", "San Francisco")
+     * Use bbox parameter for coordinates [west, south, east, north]
+     * Can combine with keywords, dataset_type, provider, start_date, end_date filters
+   - search_datasets: For keyword-only searches without spatial filter
+     * Can filter by start_date and end_date for temporal queries
+   - geocode_location: Convert location names to coordinates (called automatically by search_by_region)
+   - get_dataset_info: Get details about a specific dataset by ID
+   - list_dataset_types: Show available dataset types
+   - list_providers: Show available data providers
+   - get_catalog_stats: Get overall catalog statistics
+
+3. Search strategy:
+   - SPATIAL QUERIES: If user mentions ANY location or region, IMMEDIATELY use search_by_region
+     * Pass location names directly to the location parameter - DO NOT ask user for bbox coordinates
+     * Examples of locations: California, San Francisco, New York, Paris, any city/state/country name
+     * search_by_region will automatically geocode location names - you don't need to call geocode_location separately
+   - TEMPORAL QUERIES: If user mentions ANY time period, ALWAYS add start_date/end_date parameters
+     * "from 2022" or "since 2022" or "2022 onwards" → start_date="2022-01-01"
+     * "until 2023" or "before 2023" → end_date="2023-12-31"
+     * "between 2020 and 2023" → start_date="2020-01-01", end_date="2023-12-31"
+     * "recent" or "latest" → start_date="2020-01-01"
+     * Time indicators: from, since, after, before, until, between, onwards, recent, latest
+   - KEYWORD QUERIES: If no location mentioned, use search_datasets
+   - Extract key search terms from the user's query
+   - Use keywords parameter for the main search terms
+   - Use dataset_type parameter if user specifies type (image, table, etc.)
+   - Use provider parameter if user specifies provider (NASA, USGS, etc.)
+   - Default max_results is 10, but can be adjusted
+
+CRITICAL RULES:
+1. NEVER ask the user to provide bbox coordinates. If they mention a location name, pass it directly to search_by_region(location="name")
+2. ALWAYS add start_date or end_date when user mentions ANY time period (from, since, onwards, recent, etc.)
+3. Convert years to YYYY-MM-DD format: 2022 → "2022-01-01"
+
+4. Examples:
+   - "Find landcover datasets covering California" → search_by_region(location="California", keywords="landcover")
+   - "Show elevation data for San Francisco" → search_by_region(location="San Francisco", keywords="elevation")
+   - "Find datasets in bbox [-122, 37, -121, 38]" → search_by_region(bbox=[-122, 37, -121, 38])
+   - "Find landcover datasets from NASA" → search_datasets(keywords="landcover", provider="NASA")
+   - "Show me elevation data" → search_datasets(keywords="elevation")
+   - "What types of datasets are available?" → list_dataset_types()
+   - "Find image collections about forests" → search_datasets(keywords="forest", dataset_type="image_collection")
+   - "Find landcover data from 2020 onwards" → search_datasets(keywords="landcover", start_date="2020-01-01")
+   - "Show California datasets between 2015 and 2020" → search_by_region(location="California", start_date="2015-01-01", end_date="2020-12-31")
+   - "Find recent elevation data" → search_datasets(keywords="elevation", start_date="2020-01-01")
+
+5. Return results clearly:
+   - Summarize the number of results found
+   - List the top results with their EXACT IDs and titles FROM THE TOOL RESPONSE
+   - Mention key information like provider, type, geographic coverage, and date range if available
+   - For spatial searches, mention the search region
+
+ERROR HANDLING:
+- If no results found: Suggest trying different keywords, broader region, or removing filters
+- If location not found: Suggest alternative spellings or try a broader region
+- If tool error: Explain the error and suggest alternatives
+
+CRITICAL RULES - MUST FOLLOW:
+1. NEVER make up or hallucinate dataset IDs, titles, or any other information
+2. ONLY report datasets that appear in the actual tool response
+3. Copy dataset IDs and titles EXACTLY as they appear in the tool response
+4. If a field is null/None in the tool response, say "N/A" or omit it - DO NOT guess
+5. DO NOT use your training data knowledge about Earth Engine datasets
+6. DO NOT fill in missing information from your knowledge
+7. If unsure, say "Information not available in results"
+
+Example of CORRECT behavior:
+Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
+Your response: "Found dataset: AAFC/ACI - Canada AAFC Annual Crop Inventory"
+
+Example of INCORRECT behavior (DO NOT DO THIS):
+Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
+Your response: "Found dataset: USGS/NED - USGS Elevation Data" ← WRONG! This ID wasn't in the tool response!"""
+
+        super().__init__(
+            name="Catalog Search Agent",
+            model=model,
+            tools=[
+                self.catalog_tools.search_datasets,
+                self.catalog_tools.search_by_region,
+                self.catalog_tools.get_dataset_info,
+                self.catalog_tools.geocode_location,
+                self.catalog_tools.list_dataset_types,
+                self.catalog_tools.list_providers,
+                self.catalog_tools.get_catalog_stats,
+            ],
+            system_prompt=system_prompt,
+            callback_handler=None,
+        )
+
+    def ask(self, prompt: str) -> str:
+        """Send a single-turn prompt to the agent.
+
+        Args:
+            prompt: The text prompt to send to the agent.
+
+        Returns:
+            The agent's response as a string.
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+        return getattr(result, "final_text", str(result))
+
+    def search_datasets(
+        self,
+        keywords: Optional[str] = None,
+        dataset_type: Optional[str] = None,
+        provider: Optional[str] = None,
+        max_results: int = 10,
+    ) -> List[Dict[str, Any]]:
+        """Search for datasets and return structured results.
+
+        This method directly uses the CatalogTools without LLM inference for faster searches.
+
+        Args:
+            keywords: Keywords to search for.
+            dataset_type: Filter by dataset type.
+            provider: Filter by provider.
+            max_results: Maximum number of results to return.
+
+        Returns:
+            List of dataset dictionaries.
+
+        Example:
+            >>> agent = CatalogAgent(catalog_url="...")
+            >>> datasets = agent.search_datasets(keywords="landcover", provider="NASA")
+            >>> for ds in datasets:
+            ...     print(ds['id'], ds['title'])
+        """
+        result_json = self.catalog_tools.search_datasets(
+            keywords=keywords,
+            dataset_type=dataset_type,
+            provider=provider,
+            max_results=max_results,
+        )
+
+        result = json.loads(result_json)
+
+        if "error" in result:
+            print(f"Search error: {result['error']}")
+            return []
+
+        return result.get("datasets", [])
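
For context, here is a minimal usage sketch of the new agent classes added in this diff, based on the docstring examples above. It assumes a Jupyter environment, that STACAgent and CatalogAgent are exported from geoai.agents (as the 4-line geoai/agents/__init__.py change suggests), and that either a local Ollama server with the default llama3.1 model is running or another model id is passed.

    from geoai.agents import STACAgent, CatalogAgent

    # Natural-language STAC search against the default Planetary Computer endpoint;
    # search_and_get_first_item() returns the first matching item as a dict.
    stac_agent = STACAgent()
    item = stac_agent.search_and_get_first_item(
        "Find Sentinel-2 imagery over Paris in summer 2023"
    )
    if item is not None:
        print(item["id"], item["collection"])
    stac_agent.show_ui()  # interactive map + chat UI in a notebook

    # Keyword catalog search; CatalogAgent.search_datasets() bypasses LLM inference.
    catalog_agent = CatalogAgent(
        catalog_url="https://raw.githubusercontent.com/opengeos/Earth-Engine-Catalog/refs/heads/master/gee_catalog.json"
    )
    for ds in catalog_agent.search_datasets(keywords="landcover", provider="NASA"):
        print(ds["id"], ds["title"])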