geoai-py 0.15.0__py2.py3-none-any.whl → 0.17.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
3
+ import json
4
4
  import os
5
5
  import uuid
6
- from concurrent.futures import ThreadPoolExecutor
7
6
  from types import SimpleNamespace
8
- from typing import Any, Callable, Optional
7
+ from typing import Any, Callable, Dict, Optional
9
8
 
10
9
  import boto3
11
10
  import ipywidgets as widgets
@@ -19,14 +18,45 @@ from strands.models.anthropic import AnthropicModel
19
18
  from strands.models.ollama import OllamaModel as _OllamaModel
20
19
  from strands.models.openai import OpenAIModel
21
20
 
21
+ from .catalog_tools import CatalogTools
22
22
  from .map_tools import MapSession, MapTools
23
+ from .stac_tools import STACTools
23
24
 
24
- try:
25
- import nest_asyncio
26
25
 
27
- nest_asyncio.apply()
28
- except Exception:
29
- pass
26
class UICallbackHandler:
    """Callback handler that mirrors agent tool activity into a status widget.

    The handler is invoked with keyword-only event data. It only reacts to
    events that carry a ``current_tool_use`` mapping with a non-empty
    ``name``; every other event is ignored.
    """

    def __init__(self, status_widget=None):
        """Initialize the callback handler.

        Args:
            status_widget: Optional widget-like object whose ``value``
                attribute is overwritten with an HTML status snippet.
        """
        self.status_widget = status_widget
        # Name of the most recently announced tool (None until a tool is seen).
        self.current_tool = None

    def __call__(self, **kwargs):
        """Handle one callback event from the agent.

        Args:
            **kwargs: Event data. Only ``current_tool_use`` is inspected.
        """
        # The key may be present with a None value; treat that the same as
        # "no tool in progress" instead of failing on ``None.get``.
        tool_use = kwargs.get("current_tool_use") or {}
        tool_name = tool_use.get("name")
        if not tool_name:
            return

        self.current_tool = tool_name

        if self.status_widget is not None:
            # "search_items" -> "Search Items" for display purposes.
            friendly_name = tool_name.replace("_", " ").title()
            self.status_widget.value = (
                f"<span style='color:#0a7'>⚙️ {friendly_name}...</span>"
            )
30
60
 
31
61
 
32
62
  class OllamaModel(_OllamaModel):
@@ -170,18 +200,6 @@ def create_bedrock_model(
170
200
  )
171
201
 
172
202
 
173
- def _ensure_loop() -> asyncio.AbstractEventLoop:
174
- try:
175
- loop = asyncio.get_event_loop()
176
- except RuntimeError:
177
- loop = asyncio.new_event_loop()
178
- asyncio.set_event_loop(loop)
179
- if loop.is_closed():
180
- loop = asyncio.new_event_loop()
181
- asyncio.set_event_loop(loop)
182
- return loop
183
-
184
-
185
203
  class GeoAgent(Agent):
186
204
  """Geospatial AI agent with interactive mapping capabilities."""
187
205
 
@@ -311,30 +329,14 @@ class GeoAgent(Agent):
311
329
  def ask(self, prompt: str) -> str:
312
330
  """Send a single-turn prompt to the agent.
313
331
 
314
- Runs entirely on the same thread/event loop as the Agent
315
- to avoid cross-loop asyncio object issues.
316
-
317
332
  Args:
318
333
  prompt: The text prompt to send to the agent.
319
334
 
320
335
  Returns:
321
336
  The agent's response as a string.
322
337
  """
323
- # Ensure there's an event loop bound to this thread (Jupyter-safe)
324
- loop = _ensure_loop()
325
-
326
- # Preserve existing conversation messages
327
- existing_messages = self.messages.copy()
328
-
329
- # Create a fresh model but keep conversation history
330
- self.model = self._model_factory()
331
-
332
- # Restore the conversation messages
333
- self.messages = existing_messages
334
-
335
- # Execute the prompt using the Agent's async API on this loop
336
- # Avoid Agent.__call__ since it spins a new thread+loop
337
- result = loop.run_until_complete(self.invoke_async(prompt))
338
+ # Use strands' built-in __call__ method which now supports multiple calls
339
+ result = self(prompt)
338
340
  return getattr(result, "final_text", str(result))
339
341
 
340
342
  def show_ui(self, *, height: int = 700) -> None:
@@ -349,7 +351,10 @@ class GeoAgent(Agent):
349
351
  m.create_container()
350
352
 
351
353
  map_panel = widgets.VBox(
352
- [widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"), m.container],
354
+ [
355
+ widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
356
+ m.floating_sidebar_widget,
357
+ ],
353
358
  layout=widgets.Layout(
354
359
  flex="2 1 0%",
355
360
  min_width="520px",
@@ -427,7 +432,6 @@ class GeoAgent(Agent):
427
432
  examples=examples,
428
433
  )
429
434
  self._pending = {"fut": None}
430
- self._executor = ThreadPoolExecutor(max_workers=1)
431
435
 
432
436
  def _esc(s: str) -> str:
433
437
  """Escape HTML characters in a string.
@@ -494,7 +498,14 @@ class GeoAgent(Agent):
494
498
  _lock(True)
495
499
  self._ui.status.value = "<span style='color:#0a7'>Running…</span>"
496
500
  try:
497
- out = self.ask(text) # fresh Agent/model per call, silent
501
+ # Create a callback handler that updates the status widget
502
+ callback_handler = UICallbackHandler(status_widget=self._ui.status)
503
+
504
+ # Temporarily set callback_handler for this call
505
+ old_callback = self.callback_handler
506
+ self.callback_handler = callback_handler
507
+
508
+ out = self.ask(text) # fresh Agent/model per call, with callback
498
509
  _append("assistant", out)
499
510
  self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
500
511
  except Exception as e:
@@ -503,6 +514,8 @@ class GeoAgent(Agent):
503
514
  "<span style='color:#c00'>Finished with an issue.</span>"
504
515
  )
505
516
  finally:
517
+ # Restore old callback handler
518
+ self.callback_handler = old_callback
506
519
  self._ui.inp.value = ""
507
520
  _lock(False)
508
521
 
@@ -593,3 +606,874 @@ class GeoAgent(Agent):
593
606
  [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
594
607
  )
595
608
  display(root)
609
+
610
+
611
+ class STACAgent(Agent):
612
+ """AI agent for searching and interacting with STAC catalogs."""
613
+
614
def __init__(
    self,
    *,
    model: str = "llama3.1",
    system_prompt: str = "default",
    endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
    model_args: Optional[dict] = None,
    map_instance: Optional[leafmap.Map] = None,
    **kwargs: Any,
) -> None:
    """Initialize the STAC Agent.

    Args:
        model: Model identifier string (default: "llama3.1"), or an existing
            OllamaModel / OpenAIModel / AnthropicModel instance whose
            configuration is copied into a fresh model.
        system_prompt: System prompt for the agent. The sentinel value
            "default" selects the built-in STAC search prompt.
        endpoint: STAC API endpoint URL. Defaults to Microsoft Planetary Computer.
        model_args: Additional keyword arguments for the model.
        map_instance: Optional leafmap.Map instance for visualization. If None,
            creates a new one.
        **kwargs: Additional keyword arguments. NOTE(review): these are
            accepted but not forwarded to ``Agent.__init__`` below — confirm
            whether that is intentional.

    Raises:
        ValueError: If ``model`` is neither a string nor a supported model
            instance.
    """
    self.stac_tools: STACTools = STACTools(endpoint=endpoint)
    self.map_instance = map_instance if map_instance is not None else leafmap.Map()

    if model_args is None:
        model_args = {}

    # --- save a model factory we can call each turn ---
    # Every factory binds the values it needs as lambda defaults. A plain
    # closure over a local that is rebound later would see the new value on
    # every call after the first.
    if model == "llama3.1":
        self._model_factory: Callable[[], OllamaModel] = (
            lambda mid=model: create_ollama_model(
                host="http://localhost:11434", model_id=mid, **model_args
            )
        )
    elif isinstance(model, str):
        self._model_factory: Callable[[], BedrockModel] = (
            lambda mid=model: create_bedrock_model(model_id=mid, **model_args)
        )
    elif isinstance(model, OllamaModel):
        # Extract configuration from existing OllamaModel and create new instances
        model_id = model.config["model_id"]
        host = model.host
        client_args = model.client_args
        self._model_factory: Callable[[], OllamaModel] = (
            lambda mid=model_id, host=host, client_args=client_args: create_ollama_model(
                host=host, model_id=mid, client_args=client_args, **model_args
            )
        )
    elif isinstance(model, OpenAIModel):
        # Extract configuration from existing OpenAIModel and create new instances
        model_id = model.config["model_id"]
        client_args = model.client_args.copy()
        self._model_factory: Callable[[], OpenAIModel] = (
            lambda mid=model_id, client_args=client_args: create_openai_model(
                model_id=mid, client_args=client_args, **model_args
            )
        )
    elif isinstance(model, AnthropicModel):
        # Extract configuration from existing AnthropicModel and create new instances
        model_id = model.config["model_id"]
        client_args = model.client_args.copy()
        self._model_factory: Callable[[], AnthropicModel] = (
            lambda mid=model_id, client_args=client_args: create_anthropic_model(
                model_id=mid, client_args=client_args, **model_args
            )
        )
    else:
        raise ValueError(f"Invalid model: {model}")

    # Build the initial model (first turn). Kept in its own name so the
    # ``model`` argument is never rebound.
    initial_model = self._model_factory()

    if system_prompt == "default":
        system_prompt = """You are a STAC search agent. Follow these steps EXACTLY:

1. Determine collection ID based on data type:
   - "sentinel-2-l2a" for Sentinel-2 or optical satellite imagery
   - "landsat-c2-l2" for Landsat
   - "naip" for NAIP or aerial imagery (USA only)
   - "sentinel-1-grd" for Sentinel-1 or SAR/radar
   - "cop-dem-glo-30" for DEM, elevation, or terrain data
   - "aster-l1t" for ASTER
   For other data (e.g., MODIS, land cover): call list_collections(filter_keyword="<keyword>")

2. If location mentioned:
   - Call geocode_location("<name>") FIRST
   - WAIT for the response
   - Extract the "bbox" array from the JSON response
   - This bbox is [west, south, east, north] format

3. Call search_items():
   - collection: REQUIRED
   - bbox: Use the EXACT bbox array from geocode_location (REQUIRED if location mentioned)
   - time_range: "YYYY-MM-DD/YYYY-MM-DD" format if dates mentioned
   - query: Use for cloud cover filtering (see examples)
   - max_items: 1

Cloud cover filtering:
   - "<10% cloud": query={"eo:cloud_cover": {"lt": 10}}
   - "<20% cloud": query={"eo:cloud_cover": {"lt": 20}}
   - "<5% cloud": query={"eo:cloud_cover": {"lt": 5}}

Examples:
1. "Find Landsat over Paris from June to July 2023"
   geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
   search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], time_range="2023-06-01/2023-07-31")

2. "Find Landsat with <10% cloud cover over Paris"
   geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
   search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], query={"eo:cloud_cover": {"lt": 10}})

4. Return first item as JSON:
   {"id": "...", "collection": "...", "datetime": "...", "bbox": [...], "assets": [...]}

ERROR HANDLING:
- If no items found: {"error": "No items found"}
- If tool result too large: {"error": "Result too large, try narrower search"}
- If tool error: {"error": "Search failed: <error message>"}

CRITICAL: Return ONLY JSON. NO explanatory text, NO made-up data."""

    super().__init__(
        name="STAC Search Agent",
        model=initial_model,
        tools=[
            self.stac_tools.list_collections,
            self.stac_tools.search_items,
            self.stac_tools.get_item_info,
            self.stac_tools.geocode_location,
            self.stac_tools.get_common_collections,
        ],
        system_prompt=system_prompt,
        callback_handler=None,
    )
747
+
748
def _extract_search_items_payload(self, result: Any) -> Optional[Dict[str, Any]]:
    """Return the parsed payload from the search_items tool, if available.

    Two sources are tried in order:

    1. ``result.tool_results`` entries whose ``tool_name`` is
       ``"search_items"`` (the ``result`` attribute may be a dict or a JSON
       string).
    2. ``self.messages``, newest first: successful ``toolResult`` entries in
       user messages whose JSON text contains the keys ``query``,
       ``collection`` and ``items``.

    Args:
        result: Object returned by the agent call; may be None.

    Returns:
        The payload dict, or None when nothing usable is found.
    """
    # Source 1: tool results attached directly to the result object.
    for tool_result in getattr(result, "tool_results", None) or []:
        if getattr(tool_result, "tool_name", "") != "search_items":
            continue

        payload = getattr(tool_result, "result", None)
        if isinstance(payload, str):
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError:
                continue

        if isinstance(payload, dict):
            return payload

    # Source 2: tool results recorded in the conversation history.
    # NOTE(review): this scans the whole history, so a payload from an
    # earlier turn can be returned when the current turn made no search.
    required_keys = ("query", "collection", "items")
    for msg in reversed(getattr(self, "messages", None) or []):
        if not isinstance(msg, dict) or msg.get("role") != "user":
            continue

        content = msg.get("content", [])
        if not isinstance(content, list):
            continue

        for item in content:
            if not (isinstance(item, dict) and "toolResult" in item):
                continue

            tool_result = item["toolResult"]
            if not isinstance(tool_result, dict):
                continue
            if tool_result.get("status") != "success":
                continue

            result_content = tool_result.get("content", [])
            if not isinstance(result_content, list):
                continue

            for entry in result_content:
                text_content = entry.get("text") if isinstance(entry, dict) else None
                if not isinstance(text_content, str):
                    continue
                try:
                    payload = json.loads(text_content)
                except json.JSONDecodeError:
                    continue
                # Valid JSON is not necessarily an object (e.g. "5" or "[]");
                # only dicts can be a search_items payload. An empty "items"
                # list still counts as a payload.
                if isinstance(payload, dict) and all(
                    key in payload for key in required_keys
                ):
                    return payload

    return None
803
+
804
def ask(self, prompt: str) -> str:
    """Run one prompt through the agent and return its answer as text.

    When the turn produced a search_items payload, the answer is a JSON
    string: the first item, or an ``{"error": ...}`` object. Otherwise the
    agent's own final text is returned.

    Args:
        prompt: The text prompt to send to the agent.

    Returns:
        The agent's response as a string (JSON format for search queries).
    """
    response = self(prompt)
    payload = self._extract_search_items_payload(response)

    # No search payload: hand back whatever the agent said.
    if payload is None:
        return getattr(response, "final_text", str(response))

    if "error" in payload:
        body = {"error": payload["error"]}
    else:
        hits = payload.get("items") or []
        body = hits[0] if hits else {"error": "No items found"}

    return json.dumps(body, indent=2)
828
+
829
def search_and_get_first_item(self, prompt: str) -> Optional[Dict[str, Any]]:
    """Search for imagery and return the first item as a structured dict.

    This method sends a search query to the agent, extracts the search results
    directly from the tool calls, and returns the first item as a
    STACItemInfo-compatible dictionary. If no tool payload can be found, it
    falls back to parsing the agent's final text as a JSON object.

    Note: This method uses LLM inference which adds ~5-10 seconds overhead.
    For faster searches, use STACTools directly:
        >>> from geoai.agents import STACTools
        >>> tools = STACTools()
        >>> result = tools.search_items(
        ...     collection="sentinel-2-l2a",
        ...     bbox=[-122.5, 37.7, -122.4, 37.8],
        ...     time_range="2024-08-01/2024-08-31"
        ... )

    Args:
        prompt: Natural language search query (e.g., "Find Sentinel-2 imagery
            over San Francisco in September 2024").

    Returns:
        Dictionary containing STACItemInfo fields (id, collection, datetime,
        bbox, assets, properties), or None if no results found.

    Example:
        >>> agent = STACAgent()
        >>> item = agent.search_and_get_first_item(
        ...     "Find Sentinel-2 imagery over Paris in summer 2023"
        ... )
        >>> print(item['id'])
        >>> print(item['assets'][0]['key'])  # or 'title'
    """
    result = self(prompt)

    search_payload = self._extract_search_items_payload(result)
    if search_payload is not None:
        if "error" in search_payload:
            print(f"Search error: {search_payload['error']}")
            return None

        items = search_payload.get("items") or []
        if items:
            return items[0]

        print("No items found in search results")
        return None

    # Fallback: try to parse the final text response
    response = getattr(result, "final_text", str(result))

    try:
        item_data = json.loads(response)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a non-string final_text (e.g. None).
        print("Could not extract item data from agent response")
        return None

    # Valid JSON may be a string, number or list; only an object can be an
    # item, and key lookups on the other types would misbehave or raise.
    if not isinstance(item_data, dict):
        print("Could not extract item data from agent response")
        return None

    if "error" in item_data:
        print(f"Search error: {item_data['error']}")
        return None

    if not all(k in item_data for k in ["id", "collection"]):
        print("Response missing required fields (id, collection)")
        return None

    return item_data
897
+
898
def _visualize_stac_item(self, item: Dict[str, Any]) -> Optional[list[str]]:
    """Add a STAC item to the map as a layer.

    Known collections use a fixed asset selection; for any other collection
    the item's own ``assets`` are inspected, preferring "visual", "image" or
    "data" and otherwise taking the first listed asset key.

    Args:
        item: STAC item dictionary with id, collection, assets, etc.
            ``assets`` may be a list of dicts carrying a "key" entry or a
            mapping keyed by asset name.

    Returns:
        The list of asset names that were added to the map, or None when the
        item is unusable, no asset could be chosen, or adding the layer
        failed.
    """
    if not item:
        return None

    collection = item.get("collection")
    item_id = item.get("id")
    # Both values are sliced and searched below, so None/empty is unusable.
    if not collection or not item_id:
        return None

    # Fixed (assets, extra layer kwargs) per known collection.
    presets = {
        "sentinel-2-l2a": (["B04", "B03", "B02"], {}),  # True color RGB
        "landsat-c2-l2": (["red", "green", "blue"], {}),  # Landsat RGB
        "naip": (["image"], {}),  # NAIP 4-band imagery
        "cop-dem-glo-30": (["data"], {"colormap_name": "terrain"}),
        "aster-l1t": (["VNIR"], {}),  # ASTER L1T imagery
        "3dep-lidar-hag": (["data"], {"colormap_name": "terrain"}),
    }

    assets = None
    kwargs = {}
    if collection in presets:
        preset_assets, preset_kwargs = presets[collection]
        assets = list(preset_assets)
        kwargs = dict(preset_kwargs)
    elif "sentinel-1" in collection:
        assets = ["vv"]  # Sentinel-1 VV polarization
    else:
        raw_assets = item.get("assets")
        if isinstance(raw_assets, dict):
            asset_keys = list(raw_assets)
        elif isinstance(raw_assets, (list, tuple)):
            # Skip entries without a usable key so None is never selected.
            asset_keys = [
                a.get("key")
                for a in raw_assets
                if isinstance(a, dict) and a.get("key")
            ]
        else:
            asset_keys = []

        # Look for visual or RGB assets first, then fall back to the first key.
        for possible in ["visual", "image", "data"]:
            if possible in asset_keys:
                assets = [possible]
                break
        if not assets and asset_keys:
            assets = asset_keys[:1]

    if not assets:
        return None

    try:
        # Add the STAC layer to the map
        layer_name = f"{collection[:20]}_{item_id[:15]}"
        self.map_instance.add_stac_layer(
            collection=collection,
            item=item_id,
            assets=assets,
            name=layer_name,
            before_id=self.map_instance.first_symbol_layer_id,
            **kwargs,
        )
        return assets  # Return the assets that were visualized
    except Exception as e:
        # Best-effort: a map failure must not abort the search flow.
        print(f"Could not visualize item on map: {e}")
        return None
963
+
964
def show_ui(self, *, height: int = 700) -> None:
    """Display an interactive UI with map and chat interface for STAC searches.

    The layout is a map panel on the left and a chat column (title, message
    log, input box, buttons, status line) on the right. Widgets and handlers
    are stored on ``self`` so they stay referenced after this call returns.

    Args:
        height: Height of the UI in pixels (default: 700).
    """
    m = self.map_instance
    if not hasattr(m, "container") or m.container is None:
        m.create_container()

    map_panel = widgets.VBox(
        [
            widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
            m.floating_sidebar_widget,
        ],
        layout=widgets.Layout(
            flex="2 1 0%",
            min_width="520px",
            border="1px solid #ddd",
            padding="8px",
            height=f"{height}px",
            overflow="hidden",
        ),
    )

    # ----- chat widgets -----
    session_id = str(uuid.uuid4())[:8]
    title = widgets.HTML(
        f"<h3 style='margin:0'>STAC Search Agent</h3>"
        f"<p style='margin:4px 0 8px;color:#666'>Session: {session_id}</p>"
    )
    log = widgets.HTML(
        value="<div style='color:#777'>No messages yet. Try searching for satellite imagery!</div>",
        layout=widgets.Layout(
            border="1px solid #ddd",
            padding="8px",
            height="520px",
            overflow_y="auto",
        ),
    )
    inp = widgets.Textarea(
        placeholder="Search for satellite/aerial imagery (e.g., 'Find Sentinel-2 imagery over Paris in summer 2024')",
        layout=widgets.Layout(width="99%", height="90px"),
    )
    btn_send = widgets.Button(
        description="Search",
        button_style="primary",
        icon="search",
        layout=widgets.Layout(width="120px"),
    )
    btn_stop = widgets.Button(
        description="Stop", icon="stop", layout=widgets.Layout(width="120px")
    )
    btn_clear = widgets.Button(
        description="Clear", icon="trash", layout=widgets.Layout(width="120px")
    )
    status = widgets.HTML("<span style='color:#666'>Ready to search.</span>")

    examples = widgets.Dropdown(
        options=[
            ("— Example Searches —", ""),
            (
                "Sentinel-2 over Las Vegas",
                "Find Sentinel-2 imagery over Las Vegas in August 2025",
            ),
            (
                "Landsat over Paris",
                "Find Landsat imagery over Paris from June to July 2025",
            ),
            (
                "Landsat with <10% cloud cover",
                "Find Landsat imagery over Paris with <10% cloud cover in June 2025",
            ),
            ("NAIP for NYC", "Show me NAIP aerial imagery for New York City"),
            ("DEM for Seattle", "Show me DEM data for Seattle"),
            (
                "3DEP Lidar HAG",
                "Show me data over Austin from 3dep-lidar-hag collection",
            ),
            ("ASTER for Tokyo", "Show me ASTER imagery for Tokyo"),
        ],
        value="",
        layout=widgets.Layout(width="auto"),
    )

    # --- state kept on self so it persists ---
    self._ui = SimpleNamespace(
        messages=[],
        map_panel=map_panel,
        title=title,
        log=log,
        inp=inp,
        btn_send=btn_send,
        btn_stop=btn_stop,
        btn_clear=btn_clear,
        status=status,
        examples=examples,
    )
    # NOTE(review): nothing in this method ever stores a future here, so the
    # Stop handler below currently has nothing to act on.
    self._pending = {"fut": None}

    def _esc(s: str) -> str:
        """Escape HTML characters in a string."""
        return (
            s.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("\n", "<br/>")
        )

    def _append(role: str, msg: str) -> None:
        """Append a message to the chat log and re-render the whole log."""
        self._ui.messages.append((role, msg))
        parts = []
        for r, mm in self._ui.messages:
            if r == "user":
                parts.append(
                    f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#eef;'><b>You</b>: {_esc(mm)}</div>"
                )
            else:
                parts.append(
                    f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#f7f7f7;'><b>Agent</b>: {_esc(mm)}</div>"
                )
        self._ui.log.value = (
            "<div>"
            + (
                "".join(parts)
                if parts
                else "<div style='color:#777'>No messages yet.</div>"
            )
            + "</div>"
        )

    def _lock(lock: bool) -> None:
        """Lock or unlock UI controls (Stop is enabled only while locked)."""
        self._ui.btn_send.disabled = lock
        self._ui.btn_stop.disabled = not lock
        self._ui.btn_clear.disabled = lock
        self._ui.inp.disabled = lock
        self._ui.examples.disabled = lock

    def _on_send(_: Any = None) -> None:
        """Handle send button click or Ctrl+Enter key press."""
        text = self._ui.inp.value.strip()
        if not text:
            return
        _append("user", text)
        _lock(True)
        self._ui.status.value = "<span style='color:#0a7'>Searching…</span>"

        # Capture the current handler before the try block so the finally
        # clause can always restore it.
        old_callback = self.callback_handler
        try:
            # Temporarily route agent progress events to the status widget.
            self.callback_handler = UICallbackHandler(status_widget=self._ui.status)

            # Get the structured search result directly (will show progress via callback)
            item_data = self.search_and_get_first_item(text)

            if item_data is not None:
                # Update status for visualization step
                self._ui.status.value = (
                    "<span style='color:#0a7'>Adding to map...</span>"
                )

                # Visualize on map
                visualized_assets = self._visualize_stac_item(item_data)

                # Format response for display
                formatted_response = (
                    f"Found item: {item_data['id']}\n"
                    f"Collection: {item_data['collection']}\n"
                    f"Date: {item_data.get('datetime', 'N/A')}\n"
                )

                if visualized_assets:
                    assets_str = ", ".join(visualized_assets)
                    formatted_response += f"✓ Added to map (assets: {assets_str})"
                else:
                    # _visualize_stac_item returns nothing when no asset
                    # could be chosen or the layer could not be added, so
                    # do not claim success here.
                    formatted_response += "⚠ Could not add this item to the map"

                _append("assistant", formatted_response)
            else:
                _append(
                    "assistant",
                    "No items found. Try adjusting your search query or date range.",
                )

            self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
        except Exception as e:
            # UI boundary: surface the failure in the chat instead of raising.
            _append("assistant", f"[error] {type(e).__name__}: {e}")
            self._ui.status.value = (
                "<span style='color:#c00'>Finished with an issue.</span>"
            )
        finally:
            # Restore old callback handler
            self.callback_handler = old_callback
            self._ui.inp.value = ""
            _lock(False)

    def _on_stop(_: Any = None) -> None:
        """Handle stop button click."""
        fut = self._pending.get("fut")
        if fut and not fut.done():
            self._pending["fut"] = None
            self._ui.status.value = (
                "<span style='color:#c00'>Stop requested.</span>"
            )
            _lock(False)

    def _on_clear(_: Any = None) -> None:
        """Handle clear button click."""
        self._ui.messages.clear()
        self._ui.log.value = "<div style='color:#777'>No messages yet.</div>"
        self._ui.status.value = "<span style='color:#666'>Cleared.</span>"

    def _on_example_change(change: dict[str, Any]) -> None:
        """Handle example dropdown selection change."""
        if change["name"] == "value" and change["new"]:
            self._ui.inp.value = change["new"]
            # Reset the dropdown so the same example can be picked again.
            self._ui.examples.value = ""
            self._ui.inp.send({"method": "focus"})

    # keep handler refs
    self._handlers = SimpleNamespace(
        on_send=_on_send,
        on_stop=_on_stop,
        on_clear=_on_clear,
        on_example_change=_on_example_change,
    )

    # wire events
    self._ui.btn_send.on_click(self._handlers.on_send)
    self._ui.btn_stop.on_click(self._handlers.on_stop)
    self._ui.btn_clear.on_click(self._handlers.on_clear)
    self._ui.examples.observe(self._handlers.on_example_change, names="value")

    # Ctrl+Enter on textarea
    self._keyev = Event(
        source=self._ui.inp, watched_events=["keyup"], prevent_default_action=False
    )

    def _on_key(e: dict[str, Any]) -> None:
        """Handle keyboard events on the input textarea."""
        if (
            e.get("type") == "keyup"
            and e.get("key") == "Enter"
            and e.get("ctrlKey")
        ):
            # Drop the newline the key press may have inserted before sending.
            if self._ui.inp.value.endswith("\n"):
                self._ui.inp.value = self._ui.inp.value[:-1]
            self._handlers.on_send()

    # store callback too
    self._on_key_cb: Callable[[dict[str, Any]], None] = _on_key
    self._keyev.on_dom_event(self._on_key_cb)

    buttons = widgets.HBox(
        [
            self._ui.btn_send,
            self._ui.btn_stop,
            self._ui.btn_clear,
            widgets.Box(
                [self._ui.examples], layout=widgets.Layout(margin="0 0 0 auto")
            ),
        ]
    )
    right = widgets.VBox(
        [
            self._ui.title,
            self._ui.log,
            self._ui.inp,
            buttons,
            self._ui.status,
        ],
        layout=widgets.Layout(flex="1 1 0%", min_width="360px"),
    )
    root = widgets.HBox(
        [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
    )
    display(root)
1245
+
1246
+
1247
+ class CatalogAgent(Agent):
1248
+ """AI agent for searching data catalogs with natural language queries."""
1249
+
1250
+ def __init__(
1251
+ self,
1252
+ *,
1253
+ model: str = "llama3.1",
1254
+ system_prompt: str = "default",
1255
+ catalog_url: Optional[str] = None,
1256
+ catalog_df: Optional[Any] = None,
1257
+ model_args: dict = None,
1258
+ **kwargs: Any,
1259
+ ) -> None:
1260
+ """Initialize the Catalog Agent.
1261
+
1262
+ Args:
1263
+ model: Model identifier (default: "llama3.1").
1264
+ system_prompt: System prompt for the agent (default: "default").
1265
+ catalog_url: URL to a catalog file (TSV, CSV, or JSON). Use JSON format for spatial search support.
1266
+ Example: "https://raw.githubusercontent.com/opengeos/Earth-Engine-Catalog/refs/heads/master/gee_catalog.json"
1267
+ catalog_df: Pre-loaded catalog as a pandas DataFrame.
1268
+ model_args: Additional keyword arguments for the model.
1269
+ **kwargs: Additional keyword arguments for the Agent.
1270
+ """
1271
+ self.catalog_tools: CatalogTools = CatalogTools(
1272
+ catalog_url=catalog_url, catalog_df=catalog_df
1273
+ )
1274
+
1275
+ if model_args is None:
1276
+ model_args = {}
1277
+
1278
+ # --- save a model factory we can call each turn ---
1279
+ if model == "llama3.1":
1280
+ self._model_factory: Callable[[], OllamaModel] = (
1281
+ lambda: create_ollama_model(
1282
+ host="http://localhost:11434", model_id=model, **model_args
1283
+ )
1284
+ )
1285
+ elif isinstance(model, str):
1286
+ self._model_factory: Callable[[], BedrockModel] = (
1287
+ lambda: create_bedrock_model(model_id=model, **model_args)
1288
+ )
1289
+ elif isinstance(model, OllamaModel):
1290
+ # Extract configuration from existing OllamaModel and create new instances
1291
+ model_id = model.config["model_id"]
1292
+ host = model.host
1293
+ client_args = model.client_args
1294
+ self._model_factory: Callable[[], OllamaModel] = (
1295
+ lambda: create_ollama_model(
1296
+ host=host, model_id=model_id, client_args=client_args, **model_args
1297
+ )
1298
+ )
1299
+ elif isinstance(model, OpenAIModel):
1300
+ # Extract configuration from existing OpenAIModel and create new instances
1301
+ model_id = model.config["model_id"]
1302
+ client_args = model.client_args.copy()
1303
+ self._model_factory: Callable[[], OpenAIModel] = (
1304
+ lambda mid=model_id, client_args=client_args: create_openai_model(
1305
+ model_id=mid, client_args=client_args, **model_args
1306
+ )
1307
+ )
1308
+ elif isinstance(model, AnthropicModel):
1309
+ # Extract configuration from existing AnthropicModel and create new instances
1310
+ model_id = model.config["model_id"]
1311
+ client_args = model.client_args.copy()
1312
+ self._model_factory: Callable[[], AnthropicModel] = (
1313
+ lambda mid=model_id, client_args=client_args: create_anthropic_model(
1314
+ model_id=mid, client_args=client_args, **model_args
1315
+ )
1316
+ )
1317
+ else:
1318
+ raise ValueError(f"Invalid model: {model}")
1319
+
1320
+ # build initial model (first turn)
1321
+ model = self._model_factory()
1322
+
1323
+ if system_prompt == "default":
1324
+ system_prompt = """You are a data catalog search agent. Your job is to help users find datasets from a data catalog.
1325
+
1326
+ IMPORTANT: Follow these steps EXACTLY:
1327
+
1328
+ 1. Understand the user's query:
1329
+ - What type of data are they looking for? (e.g., landcover, elevation, imagery)
1330
+ - Are they searching for a specific geographic region? (e.g., California, San Francisco, bounding box)
1331
+ - Are they filtering by time period? (e.g., "from 2020", "between 2015-2020", "recent data")
1332
+ - Are they filtering by provider? (e.g., NASA, USGS)
1333
+ - Are they filtering by dataset type? (e.g., image, image_collection, table)
1334
+
1335
+ 2. Use the appropriate tool:
1336
+ - search_by_region: PREFERRED for spatial queries - search datasets covering a geographic region
1337
+ * Use location parameter for place names (e.g., "California", "San Francisco")
1338
+ * Use bbox parameter for coordinates [west, south, east, north]
1339
+ * Can combine with keywords, dataset_type, provider, start_date, end_date filters
1340
+ - search_datasets: For keyword-only searches without spatial filter
1341
+ * Can filter by start_date and end_date for temporal queries
1342
+ - geocode_location: Convert location names to coordinates (called automatically by search_by_region)
1343
+ - get_dataset_info: Get details about a specific dataset by ID
1344
+ - list_dataset_types: Show available dataset types
1345
+ - list_providers: Show available data providers
1346
+ - get_catalog_stats: Get overall catalog statistics
1347
+
1348
+ 3. Search strategy:
1349
+ - SPATIAL QUERIES: If user mentions ANY location or region, IMMEDIATELY use search_by_region
1350
+ * Pass location names directly to the location parameter - DO NOT ask user for bbox coordinates
1351
+ * Examples of locations: California, San Francisco, New York, Paris, any city/state/country name
1352
+ * search_by_region will automatically geocode location names - you don't need to call geocode_location separately
1353
+ - TEMPORAL QUERIES: If user mentions ANY time period, ALWAYS add start_date/end_date parameters
1354
+ * "from 2022" or "since 2022" or "2022 onwards" → start_date="2022-01-01"
1355
+ * "until 2023" or "before 2023" → end_date="2023-12-31"
1356
+ * "between 2020 and 2023" → start_date="2020-01-01", end_date="2023-12-31"
1357
+ * "recent" or "latest" → start_date="2020-01-01"
1358
+ * Time indicators: from, since, after, before, until, between, onwards, recent, latest
1359
+ - KEYWORD QUERIES: If no location mentioned, use search_datasets
1360
+ - Extract key search terms from the user's query
1361
+ - Use keywords parameter for the main search terms
1362
+ - Use dataset_type parameter if user specifies type (image, table, etc.)
1363
+ - Use provider parameter if user specifies provider (NASA, USGS, etc.)
1364
+ - Default max_results is 10, but can be adjusted
1365
+
1366
+ CRITICAL RULES:
1367
+ 1. NEVER ask the user to provide bbox coordinates. If they mention a location name, pass it directly to search_by_region(location="name")
1368
+ 2. ALWAYS add start_date or end_date when user mentions ANY time period (from, since, onwards, recent, etc.)
1369
+ 3. Convert years to YYYY-MM-DD format: 2022 → "2022-01-01"
1370
+
1371
+ 4. Examples:
1372
+ - "Find landcover datasets covering California" → search_by_region(location="California", keywords="landcover")
1373
+ - "Show elevation data for San Francisco" → search_by_region(location="San Francisco", keywords="elevation")
1374
+ - "Find datasets in bbox [-122, 37, -121, 38]" → search_by_region(bbox=[-122, 37, -121, 38])
1375
+ - "Find landcover datasets from NASA" → search_datasets(keywords="landcover", provider="NASA")
1376
+ - "Show me elevation data" → search_datasets(keywords="elevation")
1377
+ - "What types of datasets are available?" → list_dataset_types()
1378
+ - "Find image collections about forests" → search_datasets(keywords="forest", dataset_type="image_collection")
1379
+ - "Find landcover data from 2020 onwards" → search_datasets(keywords="landcover", start_date="2020-01-01")
1380
+ - "Show California datasets between 2015 and 2020" → search_by_region(location="California", start_date="2015-01-01", end_date="2020-12-31")
1381
+ - "Find recent elevation data" → search_datasets(keywords="elevation", start_date="2020-01-01")
1382
+
1383
+ 5. Return results clearly:
1384
+ - Summarize the number of results found
1385
+ - List the top results with their EXACT IDs and titles FROM THE TOOL RESPONSE
1386
+ - Mention key information like provider, type, geographic coverage, and date range if available
1387
+ - For spatial searches, mention the search region
1388
+
1389
+ ERROR HANDLING:
1390
+ - If no results found: Suggest trying different keywords, broader region, or removing filters
1391
+ - If location not found: Suggest alternative spellings or try a broader region
1392
+ - If tool error: Explain the error and suggest alternatives
1393
+
1394
+ CRITICAL RULES - MUST FOLLOW:
1395
+ 1. NEVER make up or hallucinate dataset IDs, titles, or any other information
1396
+ 2. ONLY report datasets that appear in the actual tool response
1397
+ 3. Copy dataset IDs and titles EXACTLY as they appear in the tool response
1398
+ 4. If a field is null/None in the tool response, say "N/A" or omit it - DO NOT guess
1399
+ 5. DO NOT use your training data knowledge about Earth Engine datasets
1400
+ 6. DO NOT fill in missing information from your knowledge
1401
+ 7. If unsure, say "Information not available in results"
1402
+
1403
+ Example of CORRECT behavior:
1404
+ Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
1405
+ Your response: "Found dataset: AAFC/ACI - Canada AAFC Annual Crop Inventory"
1406
+
1407
+ Example of INCORRECT behavior (DO NOT DO THIS):
1408
+ Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
1409
+ Your response: "Found dataset: USGS/NED - USGS Elevation Data" ← WRONG! This ID wasn't in the tool response!"""
1410
+
1411
+ super().__init__(
1412
+ name="Catalog Search Agent",
1413
+ model=model,
1414
+ tools=[
1415
+ self.catalog_tools.search_datasets,
1416
+ self.catalog_tools.search_by_region,
1417
+ self.catalog_tools.get_dataset_info,
1418
+ self.catalog_tools.geocode_location,
1419
+ self.catalog_tools.list_dataset_types,
1420
+ self.catalog_tools.list_providers,
1421
+ self.catalog_tools.get_catalog_stats,
1422
+ ],
1423
+ system_prompt=system_prompt,
1424
+ callback_handler=None,
1425
+ )
1426
+
1427
def ask(self, prompt: str) -> str:
    """Run a single prompt through the agent and return the reply text.

    Args:
        prompt: The text to hand to the agent.

    Returns:
        The ``final_text`` attribute of the agent's result when the result
        has one; otherwise the ``str()`` form of the result.
    """
    # Calling the instance goes through the agent's own __call__, which
    # can be invoked repeatedly across turns.
    reply = self(prompt)
    fallback_text = str(reply)
    return getattr(reply, "final_text", fallback_text)
1439
+
1440
def search_datasets(
    self,
    keywords: Optional[str] = None,
    dataset_type: Optional[str] = None,
    provider: Optional[str] = None,
    max_results: int = 10,
) -> list[Dict[str, Any]]:
    """Search for datasets and return structured results.

    This method calls ``self.catalog_tools.search_datasets`` directly, without
    LLM inference, and decodes the JSON string it returns.

    Args:
        keywords: Keywords to search for.
        dataset_type: Filter by dataset type.
        provider: Filter by provider.
        max_results: Maximum number of results to return.

    Returns:
        List of dataset dictionaries taken from the ``"datasets"`` key of the
        tool response. An empty list is returned when the response carries an
        ``"error"`` key (the error is printed), or when ``"datasets"`` is
        missing or null.

    Raises:
        json.JSONDecodeError: If the tool response is not valid JSON.

    Example:
        >>> agent = CatalogAgent(catalog_url="...")
        >>> datasets = agent.search_datasets(keywords="landcover", provider="NASA")
        >>> for ds in datasets:
        ...     print(ds['id'], ds['title'])
    """
    # NOTE: the return annotation uses the builtin generic ``list[...]``
    # because ``typing.List`` is not among this module's typing imports; an
    # unresolved name there breaks ``typing.get_type_hints`` on this method.
    result_json = self.catalog_tools.search_datasets(
        keywords=keywords,
        dataset_type=dataset_type,
        provider=provider,
        max_results=max_results,
    )

    result = json.loads(result_json)

    if "error" in result:
        # Best-effort behaviour: report the failure and hand back no results
        # instead of raising.
        print(f"Search error: {result['error']}")
        return []

    # ``or []`` also covers an explicit JSON null so callers always get a list.
    return result.get("datasets") or []