geoai-py 0.15.0__py2.py3-none-any.whl → 0.18.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
-import asyncio
+import json
 import os
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from types import SimpleNamespace
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 import boto3
 import ipywidgets as widgets
@@ -19,14 +18,47 @@ from strands.models.anthropic import AnthropicModel
 from strands.models.ollama import OllamaModel as _OllamaModel
 from strands.models.openai import OpenAIModel
 
+from .catalog_tools import CatalogTools
 from .map_tools import MapSession, MapTools
+from .stac_tools import STACTools
 
-try:
-    import nest_asyncio
 
-    nest_asyncio.apply()
-except Exception:
-    pass
+class UICallbackHandler:
+    """Callback handler that updates UI status widget with agent progress.
+
+    This handler intercepts tool calls and progress events to provide
+    real-time feedback without overwhelming the user.
+    """
+
+    def __init__(self, status_widget=None):
+        """Initialize the callback handler.
+
+        Args:
+            status_widget: Optional ipywidgets.HTML widget to update with status.
+        """
+        self.status_widget = status_widget
+        self.current_tool = None
+
+    def __call__(self, **kwargs):
+        """Handle callback events from the agent.
+
+        Args:
+            **kwargs: Event data from the agent execution.
+        """
+        # Track when tools are being called
+        if "current_tool_use" in kwargs and kwargs["current_tool_use"].get("name"):
+            tool_name = kwargs["current_tool_use"]["name"]
+            self.current_tool = tool_name
+
+            # Update status widget if available
+            if self.status_widget is not None:
+                # Make tool names more user-friendly
+                friendly_name = tool_name.replace("_", " ").title()
+                self.status_widget.value = (
+                    f"<span style='color:#0a7'>"
+                    f"<i class='fas fa-spinner fa-spin' style='font-size:1.2em'></i> "
+                    f"{friendly_name}...</span>"
+                )
 
 
 class OllamaModel(_OllamaModel):
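Note: a minimal sketch of how the new UICallbackHandler behaves when the agent reports a tool call, shown outside the notebook UI. The import path and the stand-in widget below are assumptions for illustration; the handler only needs an object exposing a `value` attribute, as ipywidgets.HTML does.

    from geoai.agents import UICallbackHandler  # import path assumed

    class FakeStatus:
        # Stand-in for ipywidgets.HTML; only .value is used by the handler.
        value = ""

    status = FakeStatus()
    handler = UICallbackHandler(status_widget=status)

    # Simulate the event the Strands agent emits when a tool call starts.
    handler(current_tool_use={"name": "search_items"})
    print(status.value)  # spinner markup ending in "Search Items..."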
@@ -170,18 +202,6 @@ def create_bedrock_model(
     )
 
 
-def _ensure_loop() -> asyncio.AbstractEventLoop:
-    try:
-        loop = asyncio.get_event_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    if loop.is_closed():
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-    return loop
-
-
 class GeoAgent(Agent):
     """Geospatial AI agent with interactive mapping capabilities."""
 
@@ -311,30 +331,14 @@ class GeoAgent(Agent):
     def ask(self, prompt: str) -> str:
         """Send a single-turn prompt to the agent.
 
-        Runs entirely on the same thread/event loop as the Agent
-        to avoid cross-loop asyncio object issues.
-
         Args:
             prompt: The text prompt to send to the agent.
 
         Returns:
             The agent's response as a string.
         """
-        # Ensure there's an event loop bound to this thread (Jupyter-safe)
-        loop = _ensure_loop()
-
-        # Preserve existing conversation messages
-        existing_messages = self.messages.copy()
-
-        # Create a fresh model but keep conversation history
-        self.model = self._model_factory()
-
-        # Restore the conversation messages
-        self.messages = existing_messages
-
-        # Execute the prompt using the Agent's async API on this loop
-        # Avoid Agent.__call__ since it spins a new thread+loop
-        result = loop.run_until_complete(self.invoke_async(prompt))
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
         return getattr(result, "final_text", str(result))
 
     def show_ui(self, *, height: int = 700) -> None:
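Note: the ask() rewrite above means repeated turns now go through the agent's own __call__ instead of a manually managed event loop. A hedged usage sketch follows; the GeoAgent constructor arguments are assumed to mirror the STACAgent/CatalogAgent signatures added later in this diff, and a local Ollama server is assumed.

    from geoai.agents import GeoAgent  # import path assumed

    agent = GeoAgent(model="llama3.1")  # assumes Ollama at http://localhost:11434
    print(agent.ask("Create a map centered on San Francisco"))
    print(agent.ask("Now add a terrain basemap"))  # second turn reuses the same agent and history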
@@ -349,7 +353,10 @@ class GeoAgent(Agent):
             m.create_container()
 
         map_panel = widgets.VBox(
-            [widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"), m.container],
+            [
+                widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
+                m.floating_sidebar_widget,
+            ],
             layout=widgets.Layout(
                 flex="2 1 0%",
                 min_width="520px",
@@ -391,7 +398,10 @@ class GeoAgent(Agent):
         btn_clear = widgets.Button(
             description="Clear", icon="trash", layout=widgets.Layout(width="120px")
         )
-        status = widgets.HTML("<span style='color:#666'>Ready.</span>")
+        status = widgets.HTML(
+            "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
+            "<span style='color:#666'>Ready.</span>"
+        )
 
         examples = widgets.Dropdown(
             options=[
@@ -427,7 +437,6 @@ class GeoAgent(Agent):
             examples=examples,
         )
         self._pending = {"fut": None}
-        self._executor = ThreadPoolExecutor(max_workers=1)
 
         def _esc(s: str) -> str:
             """Escape HTML characters in a string.
@@ -494,7 +503,14 @@ class GeoAgent(Agent):
             _lock(True)
             self._ui.status.value = "<span style='color:#0a7'>Running…</span>"
             try:
-                out = self.ask(text)  # fresh Agent/model per call, silent
+                # Create a callback handler that updates the status widget
+                callback_handler = UICallbackHandler(status_widget=self._ui.status)
+
+                # Temporarily set callback_handler for this call
+                old_callback = self.callback_handler
+                self.callback_handler = callback_handler
+
+                out = self.ask(text)  # fresh Agent/model per call, with callback
                 _append("assistant", out)
                 self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
             except Exception as e:
@@ -503,6 +519,8 @@ class GeoAgent(Agent):
                     "<span style='color:#c00'>Finished with an issue.</span>"
                 )
             finally:
+                # Restore old callback handler
+                self.callback_handler = old_callback
                 self._ui.inp.value = ""
                 _lock(False)
 
@@ -593,3 +611,877 @@ class GeoAgent(Agent):
             [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
         )
         display(root)
+
+
+class STACAgent(Agent):
+    """AI agent for searching and interacting with STAC catalogs."""
+
+    def __init__(
+        self,
+        *,
+        model: str = "llama3.1",
+        system_prompt: str = "default",
+        endpoint: str = "https://planetarycomputer.microsoft.com/api/stac/v1",
+        model_args: dict = None,
+        map_instance: Optional[leafmap.Map] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the STAC Agent.
+
+        Args:
+            model: Model identifier (default: "llama3.1").
+            system_prompt: System prompt for the agent (default: "default").
+            endpoint: STAC API endpoint URL. Defaults to Microsoft Planetary Computer.
+            model_args: Additional keyword arguments for the model.
+            map_instance: Optional leafmap.Map instance for visualization. If None, creates a new one.
+            **kwargs: Additional keyword arguments for the Agent.
+        """
+        self.stac_tools: STACTools = STACTools(endpoint=endpoint)
+        self.map_instance = map_instance if map_instance is not None else leafmap.Map()
+
+        if model_args is None:
+            model_args = {}
+
+        # --- save a model factory we can call each turn ---
+        if model == "llama3.1":
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host="http://localhost:11434", model_id=model, **model_args
+                )
+            )
+        elif isinstance(model, str):
+            self._model_factory: Callable[[], BedrockModel] = (
+                lambda: create_bedrock_model(model_id=model, **model_args)
+            )
+        elif isinstance(model, OllamaModel):
+            # Extract configuration from existing OllamaModel and create new instances
+            model_id = model.config["model_id"]
+            host = model.host
+            client_args = model.client_args
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host=host, model_id=model_id, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, OpenAIModel):
+            # Extract configuration from existing OpenAIModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], OpenAIModel] = (
+                lambda mid=model_id, client_args=client_args: create_openai_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, AnthropicModel):
+            # Extract configuration from existing AnthropicModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], AnthropicModel] = (
+                lambda mid=model_id, client_args=client_args: create_anthropic_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        else:
+            raise ValueError(f"Invalid model: {model}")
+
+        # build initial model (first turn)
+        model = self._model_factory()
+
+        if system_prompt == "default":
+            system_prompt = """You are a STAC search agent. Follow these steps EXACTLY:
+
+1. Determine collection ID based on data type:
+- "sentinel-2-l2a" for Sentinel-2 or optical satellite imagery
+- "landsat-c2-l2" for Landsat
+- "naip" for NAIP or aerial imagery (USA only)
+- "sentinel-1-grd" for Sentinel-1 or SAR/radar
+- "cop-dem-glo-30" for DEM, elevation, or terrain data
+- "aster-l1t" for ASTER
+For other data (e.g., MODIS, land cover): call list_collections(filter_keyword="<keyword>")
+
+2. If location mentioned:
+- Call geocode_location("<name>") FIRST
+- WAIT for the response
+- Extract the "bbox" array from the JSON response
+- This bbox is [west, south, east, north] format
+
+3. Call search_items():
+- collection: REQUIRED
+- bbox: Use the EXACT bbox array from geocode_location (REQUIRED if location mentioned)
+- time_range: "YYYY-MM-DD/YYYY-MM-DD" format if dates mentioned
+- query: Use for cloud cover filtering (see examples)
+- max_items: 1
+
+Cloud cover filtering:
+- "<10% cloud": query={"eo:cloud_cover": {"lt": 10}}
+- "<20% cloud": query={"eo:cloud_cover": {"lt": 20}}
+- "<5% cloud": query={"eo:cloud_cover": {"lt": 5}}
+
+Examples:
+1. "Find Landsat over Paris from June to July 2023"
+geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
+search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], time_range="2023-06-01/2023-07-31")
+
+2. "Find Landsat with <10% cloud cover over Paris"
+geocode_location("Paris") → {"bbox": [2.224, 48.815, 2.469, 48.902], ...}
+search_items(collection="landsat-c2-l2", bbox=[2.224, 48.815, 2.469, 48.902], query={"eo:cloud_cover": {"lt": 10}})
+
+4. Return first item as JSON:
+{"id": "...", "collection": "...", "datetime": "...", "bbox": [...], "assets": [...]}
+
+ERROR HANDLING:
+- If no items found: {"error": "No items found"}
+- If tool result too large: {"error": "Result too large, try narrower search"}
+- If tool error: {"error": "Search failed: <error message>"}
+
+CRITICAL: Return ONLY JSON. NO explanatory text, NO made-up data."""
+
+        super().__init__(
+            name="STAC Search Agent",
+            model=model,
+            tools=[
+                self.stac_tools.list_collections,
+                self.stac_tools.search_items,
+                self.stac_tools.get_item_info,
+                self.stac_tools.geocode_location,
+                self.stac_tools.get_common_collections,
+            ],
+            system_prompt=system_prompt,
+            callback_handler=None,
+        )
+
+    def _extract_search_items_payload(self, result: Any) -> Optional[Dict[str, Any]]:
+        """Return the parsed payload from the search_items tool, if available."""
+        # Try to get tool_results from the result object
+        tool_results = getattr(result, "tool_results", None)
+        if tool_results:
+            for tool_result in tool_results:
+                if getattr(tool_result, "tool_name", "") != "search_items":
+                    continue
+
+                payload = getattr(tool_result, "result", None)
+                if payload is None:
+                    continue
+
+                if isinstance(payload, str):
+                    try:
+                        payload = json.loads(payload)
+                    except json.JSONDecodeError:
+                        continue
+
+                if isinstance(payload, dict):
+                    return payload
+
+        # Alternative: check messages for tool results
+        messages = getattr(self, "messages", [])
+        for msg in reversed(messages):  # Check recent messages first
+            # Handle dict-style messages
+            if isinstance(msg, dict):
+                role = msg.get("role")
+                content = msg.get("content", [])
+
+                # Look for tool results in user messages (strands pattern)
+                if role == "user" and isinstance(content, list):
+                    for item in content:
+                        if isinstance(item, dict) and "toolResult" in item:
+                            tool_result = item["toolResult"]
+                            # Check if this is a search_items result
+                            # We need to look at the preceding assistant message to identify the tool
+                            if tool_result.get("status") == "success":
+                                result_content = tool_result.get("content", [])
+                                if isinstance(result_content, list) and result_content:
+                                    text_content = result_content[0].get("text", "")
+                                    try:
+                                        payload = json.loads(text_content)
+                                        # Return ANY search_items payload, even if items is empty
+                                        # This is identified by having "query" and "collection" fields
+                                        if (
+                                            "query" in payload
+                                            and "collection" in payload
+                                            and "items" in payload
+                                        ):
+                                            return payload
+                                    except json.JSONDecodeError:
+                                        continue
+
+        return None
+
+    def ask(self, prompt: str) -> str:
+        """Send a single-turn prompt to the agent.
+
+        Args:
+            prompt: The text prompt to send to the agent.
+
+        Returns:
+            The agent's response as a string (JSON format for search queries).
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+
+        search_payload = self._extract_search_items_payload(result)
+        if search_payload is not None:
+            if "error" in search_payload:
+                return json.dumps({"error": search_payload["error"]}, indent=2)
+
+            items = search_payload.get("items") or []
+            if items:
+                return json.dumps(items[0], indent=2)
+
+            return json.dumps({"error": "No items found"}, indent=2)
+
+        return getattr(result, "final_text", str(result))
+
+    def search_and_get_first_item(self, prompt: str) -> Optional[Dict[str, Any]]:
+        """Search for imagery and return the first item as a structured dict.
+
+        This method sends a search query to the agent, extracts the search results
+        directly from the tool calls, and returns the first item as a STACItemInfo-compatible
+        dictionary.
+
+        Note: This method uses LLM inference which adds ~5-10 seconds overhead.
+        For faster searches, use STACTools directly:
+            >>> from geoai.agents import STACTools
+            >>> tools = STACTools()
+            >>> result = tools.search_items(
+            ...     collection="sentinel-2-l2a",
+            ...     bbox=[-122.5, 37.7, -122.4, 37.8],
+            ...     time_range="2024-08-01/2024-08-31"
+            ... )
+
+        Args:
+            prompt: Natural language search query (e.g., "Find Sentinel-2 imagery
+                over San Francisco in September 2024").
+
+        Returns:
+            Dictionary containing STACItemInfo fields (id, collection, datetime,
+            bbox, assets, properties), or None if no results found.
+
+        Example:
+            >>> agent = STACAgent()
+            >>> item = agent.search_and_get_first_item(
+            ...     "Find Sentinel-2 imagery over Paris in summer 2023"
+            ... )
+            >>> print(item['id'])
+            >>> print(item['assets'][0]['key'])  # or 'title'
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+
+        search_payload = self._extract_search_items_payload(result)
+        if search_payload is not None:
+            if "error" in search_payload:
+                print(f"Search error: {search_payload['error']}")
+                return None
+
+            items = search_payload.get("items") or []
+            if items:
+                return items[0]
+
+            print("No items found in search results")
+            return None
+
+        # Fallback: try to parse the final text response
+        response = getattr(result, "final_text", str(result))
+
+        try:
+            item_data = json.loads(response)
+
+            if "error" in item_data:
+                print(f"Search error: {item_data['error']}")
+                return None
+
+            if not all(k in item_data for k in ["id", "collection"]):
+                print("Response missing required fields (id, collection)")
+                return None
+
+            return item_data
+
+        except json.JSONDecodeError:
+            print("Could not extract item data from agent response")
+            return None
+
+    def _visualize_stac_item(self, item: Dict[str, Any]) -> None:
+        """Visualize a STAC item on the map.
+
+        Args:
+            item: STAC item dictionary with id, collection, assets, etc.
+        """
+        if not item or "id" not in item or "collection" not in item:
+            return
+
+        # Get the collection and item ID
+        collection = item.get("collection")
+        item_id = item.get("id")
+
+        kwargs = {}
+
+        # Determine which assets to display based on collection
+        assets = None
+        if collection == "sentinel-2-l2a":
+            assets = ["B04", "B03", "B02"]  # True color RGB
+        elif collection == "landsat-c2-l2":
+            assets = ["red", "green", "blue"]  # Landsat RGB
+        elif collection == "naip":
+            assets = ["image"]  # NAIP 4-band imagery
+        elif "sentinel-1" in collection:
+            assets = ["vv"]  # Sentinel-1 VV polarization
+        elif collection == "cop-dem-glo-30":
+            assets = ["data"]
+            kwargs["colormap_name"] = "terrain"
+        elif collection == "aster-l1t":
+            assets = ["VNIR"]  # ASTER L1T imagery
+        elif collection == "3dep-lidar-hag":
+            assets = ["data"]
+            kwargs["colormap_name"] = "terrain"
+        else:
+            # Try to find common asset names
+            if "assets" in item:
+                asset_keys = [
+                    a.get("key") for a in item["assets"] if isinstance(a, dict)
+                ]
+                # Look for visual or RGB assets
+                for possible in ["visual", "image", "data"]:
+                    if possible in asset_keys:
+                        assets = [possible]
+                        break
+                # If still no assets, use first few assets
+                if not assets and asset_keys:
+                    assets = asset_keys[:1]
+
+        if not assets:
+            return None
+        try:
+            # Add the STAC layer to the map
+            layer_name = f"{collection[:20]}_{item_id[:15]}"
+            self.map_instance.add_stac_layer(
+                collection=collection,
+                item=item_id,
+                assets=assets,
+                name=layer_name,
+                before_id=self.map_instance.first_symbol_layer_id,
+                **kwargs,
+            )
+            return assets  # Return the assets that were visualized
+        except Exception as e:
+            print(f"Could not visualize item on map: {e}")
+            return None
+
+    def show_ui(self, *, height: int = 700) -> None:
+        """Display an interactive UI with map and chat interface for STAC searches.
+
+        Args:
+            height: Height of the UI in pixels (default: 700).
+        """
+        m = self.map_instance
+        if not hasattr(m, "container") or m.container is None:
+            m.create_container()
+
+        map_panel = widgets.VBox(
+            [
+                widgets.HTML("<h3 style='margin:0 0 8px 0'>Map</h3>"),
+                m.floating_sidebar_widget,
+            ],
+            layout=widgets.Layout(
+                flex="2 1 0%",
+                min_width="520px",
+                border="1px solid #ddd",
+                padding="8px",
+                height=f"{height}px",
+                overflow="hidden",
+            ),
+        )
+
+        # ----- chat widgets -----
+        session_id = str(uuid.uuid4())[:8]
+        title = widgets.HTML(
+            f"<h3 style='margin:0'>STAC Search Agent</h3>"
+            f"<p style='margin:4px 0 8px;color:#666'>Session: {session_id}</p>"
+        )
+        log = widgets.HTML(
+            value="<div style='color:#777'>No messages yet. Try searching for satellite imagery!</div>",
+            layout=widgets.Layout(
+                border="1px solid #ddd",
+                padding="8px",
+                height="520px",
+                overflow_y="auto",
+            ),
+        )
+        inp = widgets.Textarea(
+            placeholder="Search for satellite/aerial imagery (e.g., 'Find Sentinel-2 imagery over Paris in summer 2024')",
+            layout=widgets.Layout(width="99%", height="90px"),
+        )
+        btn_send = widgets.Button(
+            description="Search",
+            button_style="primary",
+            icon="search",
+            layout=widgets.Layout(width="120px"),
+        )
+        btn_stop = widgets.Button(
+            description="Stop", icon="stop", layout=widgets.Layout(width="120px")
+        )
+        btn_clear = widgets.Button(
+            description="Clear", icon="trash", layout=widgets.Layout(width="120px")
+        )
+        status = widgets.HTML(
+            "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'>"
+            "<span style='color:#666'>Ready to search.</span>"
+        )
+
+        examples = widgets.Dropdown(
+            options=[
+                ("— Example Searches —", ""),
+                (
+                    "Sentinel-2 over Las Vegas",
+                    "Find Sentinel-2 imagery over Las Vegas in August 2025",
+                ),
+                (
+                    "Landsat over Paris",
+                    "Find Landsat imagery over Paris from June to July 2025",
+                ),
+                (
+                    "Landsat with <10% cloud cover",
+                    "Find Landsat imagery over Paris with <10% cloud cover in June 2025",
+                ),
+                ("NAIP for NYC", "Show me NAIP aerial imagery for New York City"),
+                ("DEM for Seattle", "Show me DEM data for Seattle"),
+                (
+                    "3DEP Lidar HAG",
+                    "Show me data over Austin from 3dep-lidar-hag collection",
+                ),
+                ("ASTER for Tokyo", "Show me ASTER imagery for Tokyo"),
+            ],
+            value="",
+            layout=widgets.Layout(width="auto"),
+        )
+
+        # --- state kept on self so it persists ---
+        self._ui = SimpleNamespace(
+            messages=[],
+            map_panel=map_panel,
+            title=title,
+            log=log,
+            inp=inp,
+            btn_send=btn_send,
+            btn_stop=btn_stop,
+            btn_clear=btn_clear,
+            status=status,
+            examples=examples,
+        )
+        self._pending = {"fut": None}
+
+        def _esc(s: str) -> str:
+            """Escape HTML characters in a string."""
+            return (
+                s.replace("&", "&amp;")
+                .replace("<", "&lt;")
+                .replace(">", "&gt;")
+                .replace("\n", "<br/>")
+            )
+
+        def _append(role: str, msg: str) -> None:
+            """Append a message to the chat log."""
+            self._ui.messages.append((role, msg))
+            parts = []
+            for r, mm in self._ui.messages:
+                if r == "user":
+                    parts.append(
+                        f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#eef;'><b>You</b>: {_esc(mm)}</div>"
+                    )
+                else:
+                    parts.append(
+                        f"<div style='margin:6px 0;padding:6px 8px;border-radius:8px;background:#f7f7f7;'><b>Agent</b>: {_esc(mm)}</div>"
+                    )
+            self._ui.log.value = (
+                "<div>"
+                + (
+                    "".join(parts)
+                    if parts
+                    else "<div style='color:#777'>No messages yet.</div>"
+                )
+                + "</div>"
+            )
+
+        def _lock(lock: bool) -> None:
+            """Lock or unlock UI controls."""
+            self._ui.btn_send.disabled = lock
+            self._ui.btn_stop.disabled = not lock
+            self._ui.btn_clear.disabled = lock
+            self._ui.inp.disabled = lock
+            self._ui.examples.disabled = lock
+
+        def _on_send(_: Any = None) -> None:
+            """Handle send button click or Enter key press."""
+            text = self._ui.inp.value.strip()
+            if not text:
+                return
+            _append("user", text)
+            _lock(True)
+            self._ui.status.value = "<span style='color:#0a7'>Searching…</span>"
+            try:
+                # Create a callback handler that updates the status widget
+                callback_handler = UICallbackHandler(status_widget=self._ui.status)
+
+                # Temporarily set callback_handler for this call
+                old_callback = self.callback_handler
+                self.callback_handler = callback_handler
+
+                # Get the structured search result directly (will show progress via callback)
+                item_data = self.search_and_get_first_item(text)
+
+                if item_data is not None:
+                    # Update status for visualization step
+                    self._ui.status.value = (
+                        "<span style='color:#0a7'>Adding to map...</span>"
+                    )
+
+                    # Visualize on map
+                    visualized_assets = self._visualize_stac_item(item_data)
+
+                    # Format response for display
+                    formatted_response = (
+                        f"Found item: {item_data['id']}\n"
+                        f"Collection: {item_data['collection']}\n"
+                        f"Date: {item_data.get('datetime', 'N/A')}\n"
+                    )
+
+                    if visualized_assets:
+                        assets_str = ", ".join(visualized_assets)
+                        formatted_response += f"✓ Added to map (assets: {assets_str})"
+                    else:
+                        formatted_response += "✓ Added to map"
+
+                    _append("assistant", formatted_response)
+                else:
+                    _append(
+                        "assistant",
+                        "No items found. Try adjusting your search query or date range.",
+                    )
+
+                self._ui.status.value = "<span style='color:#0a7'>Done.</span>"
+            except Exception as e:
+                _append("assistant", f"[error] {type(e).__name__}: {e}")
+                self._ui.status.value = (
+                    "<span style='color:#c00'>Finished with an issue.</span>"
+                )
+            finally:
+                # Restore old callback handler
+                self.callback_handler = old_callback
+                self._ui.inp.value = ""
+                _lock(False)
+
+        def _on_stop(_: Any = None) -> None:
+            """Handle stop button click."""
+            fut = self._pending.get("fut")
+            if fut and not fut.done():
+                self._pending["fut"] = None
+                self._ui.status.value = (
+                    "<span style='color:#c00'>Stop requested.</span>"
+                )
+            _lock(False)
+
+        def _on_clear(_: Any = None) -> None:
+            """Handle clear button click."""
+            self._ui.messages.clear()
+            self._ui.log.value = "<div style='color:#777'>No messages yet.</div>"
+            self._ui.status.value = "<span style='color:#666'>Cleared.</span>"
+
+        def _on_example_change(change: dict[str, Any]) -> None:
+            """Handle example dropdown selection change."""
+            if change["name"] == "value" and change["new"]:
+                self._ui.inp.value = change["new"]
+                self._ui.examples.value = ""
+                self._ui.inp.send({"method": "focus"})
+
+        # keep handler refs
+        self._handlers = SimpleNamespace(
+            on_send=_on_send,
+            on_stop=_on_stop,
+            on_clear=_on_clear,
+            on_example_change=_on_example_change,
+        )
+
+        # wire events
+        self._ui.btn_send.on_click(self._handlers.on_send)
+        self._ui.btn_stop.on_click(self._handlers.on_stop)
+        self._ui.btn_clear.on_click(self._handlers.on_clear)
+        self._ui.examples.observe(self._handlers.on_example_change, names="value")
+
+        # Ctrl+Enter on textarea
+        self._keyev = Event(
+            source=self._ui.inp, watched_events=["keyup"], prevent_default_action=False
+        )
+
+        def _on_key(e: dict[str, Any]) -> None:
+            """Handle keyboard events on the input textarea."""
+            if (
+                e.get("type") == "keyup"
+                and e.get("key") == "Enter"
+                and e.get("ctrlKey")
+            ):
+                if self._ui.inp.value.endswith("\n"):
+                    self._ui.inp.value = self._ui.inp.value[:-1]
+                self._handlers.on_send()
+
+        # store callback too
+        self._on_key_cb: Callable[[dict[str, Any]], None] = _on_key
+        self._keyev.on_dom_event(self._on_key_cb)
+
+        buttons = widgets.HBox(
+            [
+                self._ui.btn_send,
+                self._ui.btn_stop,
+                self._ui.btn_clear,
+                widgets.Box(
+                    [self._ui.examples], layout=widgets.Layout(margin="0 0 0 auto")
+                ),
+            ]
+        )
+        right = widgets.VBox(
+            [
+                self._ui.title,
+                self._ui.log,
+                self._ui.inp,
+                buttons,
+                self._ui.status,
+            ],
+            layout=widgets.Layout(flex="1 1 0%", min_width="360px"),
+        )
+        root = widgets.HBox(
+            [map_panel, right], layout=widgets.Layout(width="100%", gap="8px")
+        )
+        display(root)
+
+
+class CatalogAgent(Agent):
+    """AI agent for searching data catalogs with natural language queries."""
+
+    def __init__(
+        self,
+        *,
+        model: str = "llama3.1",
+        system_prompt: str = "default",
+        catalog_url: Optional[str] = None,
+        catalog_df: Optional[Any] = None,
+        model_args: dict = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Catalog Agent.
+
+        Args:
+            model: Model identifier (default: "llama3.1").
+            system_prompt: System prompt for the agent (default: "default").
+            catalog_url: URL to a catalog file (TSV, CSV, or JSON). Use JSON format for spatial search support.
+                Example: "https://raw.githubusercontent.com/opengeos/Earth-Engine-Catalog/refs/heads/master/gee_catalog.json"
+            catalog_df: Pre-loaded catalog as a pandas DataFrame.
+            model_args: Additional keyword arguments for the model.
+            **kwargs: Additional keyword arguments for the Agent.
+        """
+        self.catalog_tools: CatalogTools = CatalogTools(
+            catalog_url=catalog_url, catalog_df=catalog_df
+        )
+
+        if model_args is None:
+            model_args = {}
+
+        # --- save a model factory we can call each turn ---
+        if model == "llama3.1":
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host="http://localhost:11434", model_id=model, **model_args
+                )
+            )
+        elif isinstance(model, str):
+            self._model_factory: Callable[[], BedrockModel] = (
+                lambda: create_bedrock_model(model_id=model, **model_args)
+            )
+        elif isinstance(model, OllamaModel):
+            # Extract configuration from existing OllamaModel and create new instances
+            model_id = model.config["model_id"]
+            host = model.host
+            client_args = model.client_args
+            self._model_factory: Callable[[], OllamaModel] = (
+                lambda: create_ollama_model(
+                    host=host, model_id=model_id, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, OpenAIModel):
+            # Extract configuration from existing OpenAIModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], OpenAIModel] = (
+                lambda mid=model_id, client_args=client_args: create_openai_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        elif isinstance(model, AnthropicModel):
+            # Extract configuration from existing AnthropicModel and create new instances
+            model_id = model.config["model_id"]
+            client_args = model.client_args.copy()
+            self._model_factory: Callable[[], AnthropicModel] = (
+                lambda mid=model_id, client_args=client_args: create_anthropic_model(
+                    model_id=mid, client_args=client_args, **model_args
+                )
+            )
+        else:
+            raise ValueError(f"Invalid model: {model}")
+
+        # build initial model (first turn)
+        model = self._model_factory()
+
+        if system_prompt == "default":
+            system_prompt = """You are a data catalog search agent. Your job is to help users find datasets from a data catalog.
+
+IMPORTANT: Follow these steps EXACTLY:
+
+1. Understand the user's query:
+- What type of data are they looking for? (e.g., landcover, elevation, imagery)
+- Are they searching for a specific geographic region? (e.g., California, San Francisco, bounding box)
+- Are they filtering by time period? (e.g., "from 2020", "between 2015-2020", "recent data")
+- Are they filtering by provider? (e.g., NASA, USGS)
+- Are they filtering by dataset type? (e.g., image, image_collection, table)
+
+2. Use the appropriate tool:
+- search_by_region: PREFERRED for spatial queries - search datasets covering a geographic region
+* Use location parameter for place names (e.g., "California", "San Francisco")
+* Use bbox parameter for coordinates [west, south, east, north]
+* Can combine with keywords, dataset_type, provider, start_date, end_date filters
+- search_datasets: For keyword-only searches without spatial filter
+* Can filter by start_date and end_date for temporal queries
+- geocode_location: Convert location names to coordinates (called automatically by search_by_region)
+- get_dataset_info: Get details about a specific dataset by ID
+- list_dataset_types: Show available dataset types
+- list_providers: Show available data providers
+- get_catalog_stats: Get overall catalog statistics
+
+3. Search strategy:
+- SPATIAL QUERIES: If user mentions ANY location or region, IMMEDIATELY use search_by_region
+* Pass location names directly to the location parameter - DO NOT ask user for bbox coordinates
+* Examples of locations: California, San Francisco, New York, Paris, any city/state/country name
+* search_by_region will automatically geocode location names - you don't need to call geocode_location separately
+- TEMPORAL QUERIES: If user mentions ANY time period, ALWAYS add start_date/end_date parameters
+* "from 2022" or "since 2022" or "2022 onwards" → start_date="2022-01-01"
+* "until 2023" or "before 2023" → end_date="2023-12-31"
+* "between 2020 and 2023" → start_date="2020-01-01", end_date="2023-12-31"
+* "recent" or "latest" → start_date="2020-01-01"
+* Time indicators: from, since, after, before, until, between, onwards, recent, latest
+- KEYWORD QUERIES: If no location mentioned, use search_datasets
+- Extract key search terms from the user's query
+- Use keywords parameter for the main search terms
+- Use dataset_type parameter if user specifies type (image, table, etc.)
+- Use provider parameter if user specifies provider (NASA, USGS, etc.)
+- Default max_results is 10, but can be adjusted
+
+CRITICAL RULES:
+1. NEVER ask the user to provide bbox coordinates. If they mention a location name, pass it directly to search_by_region(location="name")
+2. ALWAYS add start_date or end_date when user mentions ANY time period (from, since, onwards, recent, etc.)
+3. Convert years to YYYY-MM-DD format: 2022 → "2022-01-01"
+
+4. Examples:
+- "Find landcover datasets covering California" → search_by_region(location="California", keywords="landcover")
+- "Show elevation data for San Francisco" → search_by_region(location="San Francisco", keywords="elevation")
+- "Find datasets in bbox [-122, 37, -121, 38]" → search_by_region(bbox=[-122, 37, -121, 38])
+- "Find landcover datasets from NASA" → search_datasets(keywords="landcover", provider="NASA")
+- "Show me elevation data" → search_datasets(keywords="elevation")
+- "What types of datasets are available?" → list_dataset_types()
+- "Find image collections about forests" → search_datasets(keywords="forest", dataset_type="image_collection")
+- "Find landcover data from 2020 onwards" → search_datasets(keywords="landcover", start_date="2020-01-01")
+- "Show California datasets between 2015 and 2020" → search_by_region(location="California", start_date="2015-01-01", end_date="2020-12-31")
+- "Find recent elevation data" → search_datasets(keywords="elevation", start_date="2020-01-01")
+
+5. Return results clearly:
+- Summarize the number of results found
+- List the top results with their EXACT IDs and titles FROM THE TOOL RESPONSE
+- Mention key information like provider, type, geographic coverage, and date range if available
+- For spatial searches, mention the search region
+
+ERROR HANDLING:
+- If no results found: Suggest trying different keywords, broader region, or removing filters
+- If location not found: Suggest alternative spellings or try a broader region
+- If tool error: Explain the error and suggest alternatives
+
+CRITICAL RULES - MUST FOLLOW:
+1. NEVER make up or hallucinate dataset IDs, titles, or any other information
+2. ONLY report datasets that appear in the actual tool response
+3. Copy dataset IDs and titles EXACTLY as they appear in the tool response
+4. If a field is null/None in the tool response, say "N/A" or omit it - DO NOT guess
+5. DO NOT use your training data knowledge about Earth Engine datasets
+6. DO NOT fill in missing information from your knowledge
+7. If unsure, say "Information not available in results"
+
+Example of CORRECT behavior:
+Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
+Your response: "Found dataset: AAFC/ACI - Canada AAFC Annual Crop Inventory"
+
+Example of INCORRECT behavior (DO NOT DO THIS):
+Tool returns: {"id": "AAFC/ACI", "title": "Canada AAFC Annual Crop Inventory"}
+Your response: "Found dataset: USGS/NED - USGS Elevation Data" ← WRONG! This ID wasn't in the tool response!"""
+
+        super().__init__(
+            name="Catalog Search Agent",
+            model=model,
+            tools=[
+                self.catalog_tools.search_datasets,
+                self.catalog_tools.search_by_region,
+                self.catalog_tools.get_dataset_info,
+                self.catalog_tools.geocode_location,
+                self.catalog_tools.list_dataset_types,
+                self.catalog_tools.list_providers,
+                self.catalog_tools.get_catalog_stats,
+            ],
+            system_prompt=system_prompt,
+            callback_handler=None,
+        )
+
+    def ask(self, prompt: str) -> str:
+        """Send a single-turn prompt to the agent.
+
+        Args:
+            prompt: The text prompt to send to the agent.
+
+        Returns:
+            The agent's response as a string.
+        """
+        # Use strands' built-in __call__ method which now supports multiple calls
+        result = self(prompt)
+        return getattr(result, "final_text", str(result))
+
+    def search_datasets(
+        self,
+        keywords: Optional[str] = None,
+        dataset_type: Optional[str] = None,
+        provider: Optional[str] = None,
+        max_results: int = 10,
+    ) -> List[Dict[str, Any]]:
+        """Search for datasets and return structured results.
+
+        This method directly uses the CatalogTools without LLM inference for faster searches.
+
+        Args:
+            keywords: Keywords to search for.
+            dataset_type: Filter by dataset type.
+            provider: Filter by provider.
+            max_results: Maximum number of results to return.
+
+        Returns:
+            List of dataset dictionaries.
+
+        Example:
+            >>> agent = CatalogAgent(catalog_url="...")
+            >>> datasets = agent.search_datasets(keywords="landcover", provider="NASA")
+            >>> for ds in datasets:
+            ...     print(ds['id'], ds['title'])
+        """
+        result_json = self.catalog_tools.search_datasets(
+            keywords=keywords,
+            dataset_type=dataset_type,
+            provider=provider,
+            max_results=max_results,
+        )
+
+        result = json.loads(result_json)
+
+        if "error" in result:
+            print(f"Search error: {result['error']}")
+            return []
+
+        return result.get("datasets", [])
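
Note: a usage sketch for the two agents introduced in this release, assembled from the docstring examples above. It assumes the default Ollama-backed model factory (a server at http://localhost:11434) plus network access to the STAC endpoint and the catalog URL; the import path is an assumption.

    from geoai.agents import STACAgent, CatalogAgent  # import path assumed

    stac_agent = STACAgent()  # defaults to the Microsoft Planetary Computer STAC API
    item = stac_agent.search_and_get_first_item(
        "Find Sentinel-2 imagery over Paris in summer 2023"
    )
    if item is not None:
        print(item["id"], item["collection"])

    catalog_agent = CatalogAgent(
        catalog_url="https://raw.githubusercontent.com/opengeos/Earth-Engine-Catalog/refs/heads/master/gee_catalog.json"
    )
    # search_datasets() queries CatalogTools directly, bypassing the LLM.
    for ds in catalog_agent.search_datasets(keywords="landcover", provider="NASA"):
        print(ds["id"], ds["title"])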