databao-agent 0.1.4.dev5__tar.gz → 0.1.4.dev8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/PKG-INFO +2 -1
  2. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/api.py +8 -3
  3. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/agent.py +1 -1
  4. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/executor.py +4 -4
  5. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/snowflake_adapter.py +7 -1
  6. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/react_tools.py +1 -1
  7. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/__init__.py +2 -1
  8. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/base.py +42 -40
  9. databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/claude_model_wrapper.py +313 -0
  10. databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/executor.py +115 -0
  11. databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/system_prompt.jinja +83 -0
  12. databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/utils.py +81 -0
  13. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/graph.py +1 -1
  14. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/frontend/text_frontend.py +22 -2
  15. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/executor.py +23 -12
  16. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/graph.py +18 -58
  17. databao_agent-0.1.4.dev8/databao/agent/executors/utils.py +69 -0
  18. databao_agent-0.1.4.dev8/databao/agent/visualizers/__init__.py +0 -0
  19. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/pyproject.toml +1 -0
  20. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/.gitignore +0 -0
  21. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/LICENSE.md +0 -0
  22. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/README.md +0 -0
  23. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-html/index.html +0 -0
  24. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-jupyter/index.js +0 -0
  25. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-jupyter/style.css +0 -0
  26. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/__init__.py +0 -0
  27. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/__init__.py +0 -0
  28. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/disk_cache.py +0 -0
  29. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/in_mem_cache.py +0 -0
  30. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/__init__.py +0 -0
  31. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/agent.py +0 -0
  32. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/llm.py +0 -0
  33. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/__init__.py +0 -0
  34. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/cache.py +0 -0
  35. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/data_source.py +0 -0
  36. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/domain.py +0 -0
  37. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/env.py +0 -0
  38. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/opa.py +0 -0
  39. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/sources.py +0 -0
  40. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/thread.py +0 -0
  41. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/visualizer.py +0 -0
  42. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/__init__.py +0 -0
  43. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/bigquery_adapter.py +0 -0
  44. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/database_adapter.py +0 -0
  45. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/database_connection.py +0 -0
  46. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/databases.py +0 -0
  47. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/duckdb_adapter.py +0 -0
  48. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/mysql_adapter.py +0 -0
  49. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/postgresql_adapter.py +0 -0
  50. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/sqlite_adapter.py +0 -0
  51. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/utils.py +0 -0
  52. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/dbt/__init__.py +0 -0
  53. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/dbt/dbt.py +0 -0
  54. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/__init__.py +0 -0
  55. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/schema_inspection.py +0 -0
  56. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/types.py +0 -0
  57. {databao_agent-0.1.4.dev5/databao/agent/executors/frontend → databao_agent-0.1.4.dev8/databao/agent/executors/claude_code}/__init__.py +0 -0
  58. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/__init__.py +0 -0
  59. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/config.py +0 -0
  60. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/dbt_runner.py +0 -0
  61. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/errors.py +0 -0
  62. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/executor.py +0 -0
  63. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/query_runner.py +0 -0
  64. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/system_prompt.jinja +0 -0
  65. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/task_instruction.jinja +0 -0
  66. {databao_agent-0.1.4.dev5/databao/agent/executors/react_duckdb → databao_agent-0.1.4.dev8/databao/agent/executors/frontend}/__init__.py +0 -0
  67. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/frontend/messages.py +0 -0
  68. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/history_cleaning.py +0 -0
  69. /databao_agent-0.1.4.dev5/databao/agent/executors/tools.py → /databao_agent-0.1.4.dev8/databao/agent/executors/langchain_tools.py +0 -0
  70. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/__init__.py +0 -0
  71. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/system_prompt.jinja +0 -0
  72. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/llm.py +0 -0
  73. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/prompt.py +0 -0
  74. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/query_expansion.py +0 -0
  75. {databao_agent-0.1.4.dev5/databao/agent/integrations → databao_agent-0.1.4.dev8/databao/agent/executors/react_duckdb}/__init__.py +0 -0
  76. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/react_duckdb/executor.py +0 -0
  77. {databao_agent-0.1.4.dev5/databao/agent/visualizers → databao_agent-0.1.4.dev8/databao/agent/integrations}/__init__.py +0 -0
  78. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/__init__.py +0 -0
  79. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context.py +0 -0
  80. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context_engine.py +0 -0
  81. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context_project_manager.py +0 -0
  82. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/__init__.py +0 -0
  83. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/adapter.py +0 -0
  84. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/config.py +0 -0
  85. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/connection.py +0 -0
  86. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/manager.py +0 -0
  87. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/oauth.py +0 -0
  88. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/__init__.py +0 -0
  89. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/html_viewer.py +0 -0
  90. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/jupyter_widget.py +0 -0
  91. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/utils.py +0 -0
  92. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/py.typed +0 -0
  93. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/dumb.py +0 -0
  94. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/vega_chat.py +0 -0
  95. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/vega_vis_tool.py +0 -0
  96. {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/hatch_build.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: databao-agent
3
- Version: 0.1.4.dev5
3
+ Version: 0.1.4.dev8
4
4
  Summary: databao-agent: NL queries for data
5
5
  Project-URL: Homepage, https://databao.app/
6
6
  Project-URL: Source, https://github.com/JetBrains/databao-agent
@@ -9,6 +9,7 @@ License-File: LICENSE.md
9
9
  Classifier: Operating System :: OS Independent
10
10
  Classifier: Programming Language :: Python :: 3
11
11
  Requires-Python: <3.15,>=3.11
12
+ Requires-Dist: claude-agent-sdk>=0.1.48
12
13
  Requires-Dist: databao-context-engine[mysql,postgresql,snowflake]~=0.6.0
13
14
  Requires-Dist: dbt-core~=1.9.0
14
15
  Requires-Dist: dbt-duckdb>=1.10.0
@@ -8,10 +8,13 @@ from databao.agent.configs.agent import DEFAULT_AGENT_CONFIG, AgentConfig
8
8
  from databao.agent.configs.llm import LLMConfig, LLMConfigDirectory
9
9
  from databao.agent.core import Agent, Cache, Executor, Visualizer
10
10
  from databao.agent.core.domain import Domain, _DCEProjectDomain, _InMemoryDomain
11
- from databao.agent.executors import ReactDuckDBExecutor
11
+ from databao.agent.executors import (
12
+ ClaudeCodeExecutor,
13
+ DbtProjectExecutor,
14
+ LighthouseExecutor,
15
+ ReactDuckDBExecutor,
16
+ )
12
17
  from databao.agent.executors.dbt.config import DbtConfig
13
- from databao.agent.executors.dbt.executor import DbtProjectExecutor
14
- from databao.agent.executors.lighthouse.executor import LighthouseExecutor
15
18
  from databao.agent.visualizers.vega_chat import VegaChatVisualizer
16
19
 
17
20
 
@@ -49,6 +52,8 @@ def agent(
49
52
  data_executor = DbtProjectExecutor(dbt_config=dbt_config, writer=writer)
50
53
  case "react_duckdb":
51
54
  data_executor = ReactDuckDBExecutor(writer=writer)
55
+ case "claude":
56
+ data_executor = ClaudeCodeExecutor(writer=writer)
52
57
  case _:
53
58
  raise ValueError(f"Invalid executor type: {executor_type}")
54
59
 
@@ -55,7 +55,7 @@ class Agent:
55
55
  self.__llm_config = llm
56
56
  self.__agent_config = agent_config
57
57
 
58
- self.__executor = data_executor
58
+ self.__executor: Executor = data_executor
59
59
  self.__visualizer = visualizer
60
60
  self.__cache = cache
61
61
  self.__mcp: McpManager = McpManager()
@@ -4,7 +4,6 @@ import re
4
4
  from abc import ABC, abstractmethod
5
5
  from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO
6
6
 
7
- from langchain_core.tools import BaseTool
8
7
  from pandas import DataFrame
9
8
  from pydantic import BaseModel, ConfigDict
10
9
 
@@ -149,8 +148,8 @@ class Executor(ABC):
149
148
  """
150
149
 
151
150
  @abstractmethod
152
- def register_tools(self, tools: list[BaseTool]) -> None:
153
- """Register additional LangChain tools to be available during execution."""
151
+ def register_tools(self, tools: list[Any]) -> None:
152
+ """Register additional tools to be available during execution."""
154
153
 
155
154
  @abstractmethod
156
155
  def drop_last_opa_group(self, cache: "Cache", n: int = 1) -> None:
@@ -184,7 +183,8 @@ class Executor(ABC):
184
183
  """
185
184
  pass
186
185
 
187
- def prepare_for_execution(self, domain: "Domain") -> None:
186
+ @staticmethod
187
+ def prepare_for_execution(domain: "Domain") -> None:
188
188
  if domain.supports_context and not domain.is_context_built():
189
189
  logger.warning(
190
190
  "Context has not been built yet. Building it now — this may take a while. "
@@ -11,6 +11,7 @@ from databao_context_engine import (
11
11
  SnowflakeSSOAuth,
12
12
  )
13
13
  from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile
14
+ from snowflake.connector.network import SNOWFLAKE_HOST_SUFFIX
14
15
  from sqlalchemy import Connection, Engine, make_url
15
16
 
16
17
  from databao.agent.databases.database_adapter import DatabaseAdapter
@@ -92,8 +93,13 @@ class SnowflakeAdapter(DatabaseAdapter):
92
93
  if "dbname" in content:
93
94
  content[DATABASE_KEY] = content.pop("dbname")
94
95
 
96
+ host: str | None = content.pop("host", None)
97
+ account: str = content.get(ACCOUNT_KEY, "")
98
+ if host and host.endswith(SNOWFLAKE_HOST_SUFFIX):
99
+ account = host[: -len(SNOWFLAKE_HOST_SUFFIX)]
100
+
95
101
  return SnowflakeConnectionProperties(
96
- account=content.get(ACCOUNT_KEY),
102
+ account=account,
97
103
  warehouse=content.get(WAREHOUSE_KEY),
98
104
  database=content.get(DATABASE_KEY),
99
105
  user=content.get(USER_KEY),
@@ -12,7 +12,7 @@ from pydantic import BaseModel
12
12
 
13
13
  from databao.agent.core import Domain
14
14
  from databao.agent.duckdb.schema_inspection import inspect_duckdb_schema, summarize_duckdb_schema
15
- from databao.agent.executors.tools import make_search_context_tool
15
+ from databao.agent.executors.langchain_tools import make_search_context_tool
16
16
 
17
17
  _LOGGER = logging.getLogger(__name__)
18
18
 
@@ -1,5 +1,6 @@
1
+ from databao.agent.executors.claude_code.executor import ClaudeCodeExecutor
1
2
  from databao.agent.executors.dbt.executor import DbtProjectExecutor
2
3
  from databao.agent.executors.lighthouse.executor import LighthouseExecutor
3
4
  from databao.agent.executors.react_duckdb.executor import ReactDuckDBExecutor
4
5
 
5
- __all__ = ["DbtProjectExecutor", "LighthouseExecutor", "ReactDuckDBExecutor"]
6
+ __all__ = ["ClaudeCodeExecutor", "DbtProjectExecutor", "LighthouseExecutor", "ReactDuckDBExecutor"]
@@ -1,9 +1,9 @@
1
+ import logging
1
2
  from abc import ABC, abstractmethod
2
3
  from collections.abc import Callable
3
4
  from typing import Any, TextIO
4
5
 
5
6
  import duckdb
6
- from databao_context_engine import DuckDBConnectionConfig
7
7
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
8
8
  from langchain_core.runnables import RunnableConfig
9
9
  from langchain_core.tools import BaseTool
@@ -20,11 +20,13 @@ from databao.agent.databases import register_db_in_duckdb
20
20
  from databao.agent.executors.frontend.text_frontend import TextStreamFrontend
21
21
  from databao.agent.executors.history_cleaning import clean_tool_history
22
22
 
23
+ logger = logging.getLogger(__name__)
23
24
 
24
- class GraphExecutor(Executor, ABC):
25
+
26
+ class DuckDBExecutor(Executor, ABC):
25
27
  """
26
- Base class for LangGraph executors that execute with a DuckDB connection and LLM configuration.
27
- Provides common functionality for graph caching, message handling, and OPA processing.
28
+ Base class for executors that execute with a DuckDB connection and LLM configuration.
29
+ Provides common functionality for message handling and OPA processing.
28
30
  """
29
31
 
30
32
  def __init__(self, writer: TextIO | None = None) -> None:
@@ -34,16 +36,11 @@ class GraphExecutor(Executor, ABC):
34
36
  writer: Optional TextIO for streaming output. If provided, streaming
35
37
  output will be written to this writer instead of stdout.
36
38
  """
37
- self._graph_recursion_limit = 50
38
39
  self._writer = writer
39
40
  self._duckdb_connection: duckdb.DuckDBPyConnection = duckdb.connect(":memory:")
40
41
  self._registered_dbs: dict[str, DBDataSource] = {}
41
42
  self._registered_dfs: dict[str, DFDataSource] = {}
42
43
  self._registered_dbts: dict[str, DBTDataSource] = {}
43
- self._extra_tools: dict[str, BaseTool] = {}
44
- self._compiled_graph: CompiledStateGraph[Any] | None = None
45
- self._compiled_tools_version: int = 0
46
- self._compiled_at_version: int = -1
47
44
 
48
45
  def _init_sources_from_domain(self, domain: Domain, *, register_in_duckdb: bool = True) -> None:
49
46
  """Sync sources from the domain into the executor's registered dictionaries.
@@ -61,7 +58,11 @@ class GraphExecutor(Executor, ABC):
61
58
  for name, db_source in sources.dbs.items():
62
59
  if name not in self._registered_dbs:
63
60
  if register_in_duckdb:
64
- register_db_in_duckdb(self._duckdb_connection, db_source.config, name)
61
+ try:
62
+ register_db_in_duckdb(self._duckdb_connection, db_source.config, name)
63
+ except Exception as exc:
64
+ logger.warning("Datasource '%s' is not available and will be skipped: %s", name, exc)
65
+ continue
65
66
  self._registered_dbs[name] = db_source
66
67
 
67
68
  for name, df_source in sources.dfs.items():
@@ -74,25 +75,26 @@ class GraphExecutor(Executor, ABC):
74
75
  if name not in self._registered_dbts:
75
76
  self._registered_dbts[name] = dbt_source
76
77
 
77
- def prepare_for_execution(self, domain: "Domain") -> None:
78
- if not domain.supports_context or domain.is_context_built():
79
- return
80
78
 
81
- # DuckDB does not allow two connections to hold a file open simultaneously.
82
- # Temporarily detach file-based DuckDB sources so the context engine can attach them.
83
- duckdb_file_sources = {
84
- name: source
85
- for name, source in self._registered_dbs.items()
86
- if isinstance(source.config, DuckDBConnectionConfig)
87
- }
88
- for name in duckdb_file_sources:
89
- self._duckdb_connection.execute(f'DETACH "{name}"')
90
-
91
- try:
92
- super().prepare_for_execution(domain)
93
- finally:
94
- for name, source in duckdb_file_sources.items():
95
- register_db_in_duckdb(self._duckdb_connection, source.config, name)
79
+ class GraphExecutor(DuckDBExecutor, ABC):
80
+ """
81
+ Base class for LangGraph executors that execute with a DuckDB connection and LLM configuration.
82
+ Provides common functionality for graph caching, message handling, and OPA processing.
83
+ """
84
+
85
def __init__(self, writer: TextIO | None = None) -> None:
    """Set up the inherited executor state and the graph caching infrastructure.

    Args:
        writer: Optional TextIO for streaming output. If provided, streaming
            output will be written to this writer instead of stdout.
    """
    super().__init__(writer)
    # Compiled-graph cache: nothing is compiled yet, and the "compiled at" marker (-1)
    # starts out different from the initial tools version (0).
    self._compiled_graph: CompiledStateGraph[Any] | None = None
    self._compiled_at_version: int = -1
    self._compiled_tools_version: int = 0
    self._graph_recursion_limit = 50
    # Additional tools, keyed by tool name.
    self._extra_tools: dict[str, BaseTool] = {}
96
98
 
97
99
  def register_tools(self, tools: list[BaseTool]) -> None:
98
100
  """Register additional LangChain tools and invalidate the cached compiled graph."""
@@ -100,6 +102,18 @@ class GraphExecutor(Executor, ABC):
100
102
  self._extra_tools[t.name] = t
101
103
  self._compiled_tools_version += 1
102
104
 
105
def drop_last_opa_group(self, cache: Cache, n: int = 1) -> None:
    """Remove the trailing ``n`` operation groups from the cached message history.

    Messages are popped from the end of the cached ``state["messages"]`` list until
    ``n`` ``HumanMessage`` instances have been removed; every message popped on the
    way is discarded as well.

    Args:
        cache: Cache whose ``"state"`` entry holds the ``"messages"`` list.
        n: Number of operation groups to drop.

    Raises:
        ValueError: If the history contains fewer than ``n`` human messages.
    """
    history = cache.get("state", default={}).get("messages", [])
    available = sum(1 for msg in history if isinstance(msg, HumanMessage))
    if available < n:
        raise ValueError(f"Cannot drop last {n} operations - only {available} operations found.")

    remaining = n
    while remaining > 0:
        # Each popped HumanMessage closes one operation group.
        if isinstance(history.pop(), HumanMessage):
            remaining -= 1
116
+
103
117
  @abstractmethod
104
118
  def _compile_graph(
105
119
  self,
@@ -120,18 +134,6 @@ class GraphExecutor(Executor, ABC):
120
134
  self._compiled_at_version = self._compiled_tools_version
121
135
  return self._compiled_graph
122
136
 
123
- def drop_last_opa_group(self, cache: Cache, n: int = 1) -> None:
124
- """Drop last n groups of operations from the message history."""
125
- messages = cache.get("state", default={}).get("messages", [])
126
- human_messages = [m for m in messages if isinstance(m, HumanMessage)]
127
- if len(human_messages) < n:
128
- raise ValueError(f"Cannot drop last {n} operations - only {len(human_messages)} operations found.")
129
- c = 0
130
- while c < n:
131
- m = messages.pop()
132
- if isinstance(m, HumanMessage):
133
- c += 1
134
-
135
137
  def _process_opas(self, opas: list[Opa], cache: Cache) -> list[Any]:
136
138
  """
137
139
  Process a single opa and convert it to a message, appending to message history.
@@ -0,0 +1,313 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import queue
5
+ import threading
6
+ from collections.abc import Generator
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Any, TextIO
10
+
11
+ import pandas as pd
12
+ from _duckdb import DuckDBPyConnection
13
+ from claude_agent_sdk import (
14
+ ClaudeAgentOptions,
15
+ ClaudeSDKClient,
16
+ SdkMcpTool,
17
+ create_sdk_mcp_server,
18
+ tool,
19
+ )
20
+ from claude_agent_sdk.types import McpSdkServerConfig, ResultMessage, SystemPromptPreset
21
+ from claude_agent_sdk.types import Message as ClaudeMessage
22
+ from claude_agent_sdk.types import SystemMessage as ClaudeSystemMessage
23
+ from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
24
+ from mcp.types import ToolAnnotations
25
+
26
+ from databao.agent.configs.llm import LLMConfig
27
+ from databao.agent.core.executor import ExecutionResult
28
+ from databao.agent.executors.claude_code.utils import cast_claude_message_to_langchain_message
29
+ from databao.agent.executors.frontend.messages import get_tool_call
30
+ from databao.agent.executors.frontend.text_frontend import TextStreamFrontend
31
+ from databao.agent.executors.lighthouse.graph import RUN_SQL_QUERY_TOOL_DESCRIPTION
32
+ from databao.agent.executors.utils import run_sql_query
33
+
34
+ _LOGGER = logging.getLogger(__name__)
35
+
36
+
37
@dataclass
class QueryResult:
    """SQL text and resulting dataframe of one executed query, as stored in the wrapper's query cache."""

    # The SQL statement as reported back by the query helper.
    sql: str
    # The dataframe produced by running that statement.
    df: pd.DataFrame
41
+
42
+
43
class ClaudeModelWrapper:
    """Synchronous wrapper around a ``ClaudeSDKClient`` equipped with DuckDB SQL tools.

    The wrapper is a context manager: ``__enter__`` starts a private event loop on a
    daemon thread and keeps the SDK client open on it until ``__exit__``. ``solve`` and
    ``ask`` schedule work on that loop, so they must be called inside the ``with`` block.

    Two in-process MCP tools are exposed to the model:

    * ``run_sql_query`` runs SQL on the DuckDB connection and caches each successful
      result under an increasing ``query_id``.
    * ``submit_query_id`` lets the model pick the cached query that answers the question.
    """

    DISPLAY_ROW_LIMIT = 12
    """Max number of rows to return in SQL tool calls."""

    DISPLAY_CELL_CHAR_LIMIT = 1024
    """Max number of characters a dataframe cell can have before it is trimmed."""

    __runtime_mcp_server: McpSdkServerConfig | None = None

    def __init__(
        self,
        *,
        config: LLMConfig,
        connection: DuckDBPyConnection,
        system_prompt: str,
        append_system_prompt: bool = False,
        session_id: str | None = None,
        limit_max_rows: int | None = None,
        max_turns: int | None = 100,
    ):
        """Build the tools, the agent options and the (not yet connected) SDK client.

        Args:
            config: LLM configuration; ``config.name`` is passed as the model.
            connection: DuckDB connection on which the SQL tool runs queries.
            system_prompt: System prompt text.
            append_system_prompt: If True, ``system_prompt`` is appended to the
                ``claude_code`` preset prompt instead of replacing it.
            session_id: Passed as ``resume`` to the agent options.
            limit_max_rows: Passed to the SQL helper as ``sql_row_limit``.
            max_turns: Passed to the agent options as ``max_turns``.
        """
        self._duckdb_connection = connection
        self._limit_max_rows = limit_max_rows
        self.config = config
        self.sdk_mcp_tools = self._build_tools()
        self._tool_server_name = Path(__file__).stem + "_mcp_server"
        self.mcp_tool_names = [self._get_full_tool_name(t.name) for t in self.sdk_mcp_tools]

        self.options = ClaudeAgentOptions(
            max_turns=max_turns,
            cwd=".",
            allowed_tools=self.mcp_tool_names,
            model=self.config.name,
            mcp_servers={self._tool_server_name: self._build_tool_server()},
            permission_mode="acceptEdits",
            resume=session_id,
            system_prompt=system_prompt
            if not append_system_prompt
            else SystemPromptPreset(
                type="preset",  # Append to Claude's internal system prompt
                preset="claude_code",
                append=system_prompt,
            ),
        )
        self.client = ClaudeSDKClient(options=self.options)
        self._query_cache: dict[int, QueryResult] = {}
        self._ready_event: threading.Event
        self._exit_event: asyncio.Event
        self._visualization_prompt: str | None = None

    def __enter__(self) -> "ClaudeModelWrapper":
        """Start the background event loop and open the SDK client on it.

        Blocks until the client context has been entered. If entering the client fails,
        the background loop is shut down and the failure is re-raised here instead of
        leaving the caller blocked forever.
        """
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True, name=self._tool_server_name)
        self._thread.start()

        self._ready_event = threading.Event()
        connected = False

        async def _lifecycle() -> None:
            nonlocal connected
            self._exit_event = asyncio.Event()
            try:
                async with self.client:
                    connected = True
                    self._ready_event.set()
                    await self._exit_event.wait()
            finally:
                # Always release the waiter in __enter__, even when connecting failed.
                self._ready_event.set()

        self._lifecycle_task = asyncio.run_coroutine_threadsafe(_lifecycle(), self._loop)
        self._ready_event.wait()
        if not connected:
            try:
                # Re-raises the error that prevented the client from opening.
                self._lifecycle_task.result()
            finally:
                self._shutdown_loop()
            raise RuntimeError("Claude SDK client terminated before it became ready.")
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
        """Close the SDK client and tear down the background loop and thread.

        The loop is stopped, joined and closed even if closing the client raises.
        """
        try:
            self._loop.call_soon_threadsafe(self._exit_event.set)
            self._lifecycle_task.result()
        finally:
            self._shutdown_loop()

    def _shutdown_loop(self) -> None:
        """Stop the background event loop, wait for its thread, and release the loop's resources."""
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()

    def _get_full_tool_name(self, tool_name: str) -> str:
        """Return the name under which an MCP tool of this wrapper's server is addressed."""
        return f"mcp__{self._tool_server_name}__{tool_name}"

    def _build_tools(self) -> list[SdkMcpTool[Any]]:
        """Create the ``run_sql_query`` and ``submit_query_id`` MCP tools bound to this instance."""
        # Set read only hints to enable parallel tool execution
        # (see https://platform.claude.com/docs/en/agent-sdk/agent-loop#parallel-tool-execution)

        tools = []

        @tool(
            "run_sql_query",
            RUN_SQL_QUERY_TOOL_DESCRIPTION,
            {"sql": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def _run_sql_query(args: dict[str, Any]) -> dict[str, Any]:
            result = run_sql_query(
                args.get("sql", ""),
                con=self._duckdb_connection,
                sql_row_limit=self._limit_max_rows,
                display_row_limit=self.DISPLAY_ROW_LIMIT,
                display_cell_char_limit=self.DISPLAY_CELL_CHAR_LIMIT,
            )
            if "error" in result:
                return {"content": [{"type": "text", "text": json.dumps(result, default=str)}]}

            # Only the CSV preview (and the query id) goes back to the model; the full
            # dataframe stays in the local cache.
            result_for_llm: dict[str, Any] = {"csv": result.get("csv", "")}

            if (sql := result.get("sql")) and (df := result.get("df")) is not None:
                query_id = len(self._query_cache) + 1
                self._query_cache[query_id] = QueryResult(sql=sql, df=df)
                result_for_llm["query_id"] = query_id

            return {"content": [{"type": "text", "text": json.dumps(result_for_llm, default=str)}]}

        tools.append(_run_sql_query)

        @tool(
            "submit_query_id",
            """\
This tool call must be the last tool to be called by the model.
It will provide to the user the generated sql and the output thereof resulting from the query with
the respective query id. You will find the query ids of the error-free queries in the outputs of
the run_sql_query tool in the `query_id` key. The `query_id` itself need not be the one of the last
generated query, it rather needs to reference the query which most closely matches the
user's question.

Args:
    query_id: The ID of the query to submit.""",
            {"query_id": int, "visualization_prompt": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def submit_query_id(args: dict[str, Any]) -> dict[str, Any]:
            query_id: int | None = args.get("query_id")
            self._visualization_prompt = args.get("visualization_prompt")

            if query_id not in self._query_cache:
                return {"content": [{"type": "text", "text": json.dumps({"error": f"Query id {query_id} not found"})}]}
            return {"content": [{"type": "text", "text": json.dumps({"query_id": query_id})}]}

        tools.append(submit_query_id)

        return tools

    def _build_tool_server(self) -> McpSdkServerConfig:
        """Return the in-process MCP server for this wrapper, creating it on first use.

        The server is built from ``self.sdk_mcp_tools`` so that the tools it serves are
        the same objects whose names were registered in ``self.mcp_tool_names``.
        """
        if self.__runtime_mcp_server is None:
            self.__runtime_mcp_server = create_sdk_mcp_server(
                name=self._tool_server_name,
                version="1.0.0",
                tools=self.sdk_mcp_tools,
            )
        return self.__runtime_mcp_server

    def _check_mcp_tool_availability(self, first_message: ClaudeMessage) -> None:
        """
        Each conversation begins with an initial init system message. This SystemMessage
        carries the information about the tools available to claude. To prevent
        the system from running with the mcp tools being silently not available, we
        explicitly look for them and raise an error if any of them is missing.

        Raises:
            TypeError: If ``first_message`` is not a system message.
            ValueError: If any of this wrapper's MCP tools is not listed in the message.
        """
        if not isinstance(first_message, ClaudeSystemMessage):
            raise TypeError(
                f"The first message should be a system message, got {type(first_message)}. "
                "Check if you are actually calling this function on the first message of the conversation."
            )

        if missing_tools := set(self.mcp_tool_names).difference(first_message.data["tools"]):
            raise ValueError(
                f"The following mcp tools are not available: {missing_tools}. "
                "Check the connection to the mcp servers by running /mcp in the claude console."
            )

    def _get_tool_query_id_results(self, message: ToolMessage) -> QueryResult | None:
        """Look up the cached query referenced by the ``query_id`` in a tool message.

        Returns:
            The cached ``QueryResult``, or None when the message text is not a JSON
            object, carries no ``query_id``, or the id is not in the cache.
        """
        try:
            payload = json.loads(message.text)
        except json.JSONDecodeError as e:
            _LOGGER.warning("Failed to parse tool call payload: %s", message.text, exc_info=e)
            return None
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (e.g. a list or a number) cannot carry a query id.
            return None
        query_id = payload.get("query_id")
        if query_id is None:
            return None
        return self._query_cache.get(query_id)

    def solve(self, prompt: str) -> Generator[ClaudeMessage, None, None]:
        """Send ``prompt`` to the client and yield the response messages as they arrive.

        The messages are produced on the background event loop and handed to the
        calling thread through a queue. The first message is validated with
        ``_check_mcp_tool_availability`` before it is yielded.

        Raises:
            TypeError, ValueError: From the tool availability check on the first message.
            Exception: Any error raised while querying or receiving is re-raised here.
        """
        _LOGGER.info("Querying %s", prompt)

        _sentinel = object()
        q: queue.Queue[Any] = queue.Queue()

        async def _produce() -> None:
            try:
                await self.client.query(prompt=prompt)
                async for message in self.client.receive_response():
                    q.put(message)
            except Exception as exc:
                # Hand the failure to the consuming thread instead of losing it on the loop.
                q.put(exc)
            finally:
                # Always terminate the stream so the consumer can never block forever.
                q.put(_sentinel)

        def _next_item() -> Any:
            item = q.get()
            if isinstance(item, BaseException):
                raise item
            return item

        asyncio.run_coroutine_threadsafe(_produce(), self._loop)

        first_message = _next_item()
        self._check_mcp_tool_availability(first_message)
        _LOGGER.info(first_message)
        yield first_message

        n_messages = 1
        while (message := _next_item()) is not _sentinel:
            _LOGGER.info(message)
            n_messages += 1
            yield message

        _LOGGER.info("End of conversation. Got %d messages.\n\n", n_messages)

    def ask(
        self,
        prompt: str,
        *,
        stream: bool = False,
        writer: TextIO | None = None,
    ) -> tuple[ExecutionResult, str | None]:
        """Run ``prompt`` to completion and pack the conversation into an ``ExecutionResult``.

        Messages from ``solve`` are converted to LangChain messages and collected in a
        message log; ``ResultMessage`` items are skipped.

        Args:
            prompt: The prompt to send.
            stream: If True, messages are written to a ``TextStreamFrontend`` as they arrive.
            writer: Writer handed to the ``TextStreamFrontend``.

        Returns:
            A tuple ``(result, session_id)``. ``result.text`` is the text of the last
            logged message. ``result.code`` and ``result.df`` come from the query chosen
            via ``submit_query_id``, or else from the most recent query cached during
            this call; they are ``""`` and None if there is neither. ``session_id`` is
            taken from the first system message, or None if there was none.
        """
        # Reset so a prompt submitted during an earlier call is not reported for this one.
        self._visualization_prompt = None

        session_id: str | None = None
        max_init_query_id = max(self._query_cache, default=0)
        message_log: list[BaseMessage] = []
        submitted_query_result: QueryResult | None = None
        frontend = TextStreamFrontend({"messages": message_log}, writer=writer)
        for message in self.solve(prompt):
            if isinstance(message, ClaudeSystemMessage) and session_id is None:
                # Child subagents have their own system messages, but we want the parent one only
                session_id = message.data.get("session_id", "default")

            # Skip the final text-only ResultMessage, as the previous AssistantMessage already contains the text
            # of this message.
            if isinstance(message, ResultMessage):
                continue

            lc_message = cast_claude_message_to_langchain_message(message)

            if isinstance(lc_message, ToolMessage):
                tool_call = get_tool_call(message_log, lc_message)
                if tool_call is not None:
                    if tool_call["name"] == self._get_full_tool_name("run_sql_query"):  # noqa: SIM102
                        if query_result := self._get_tool_query_id_results(lc_message):
                            lc_message.artifact = {
                                "sql": query_result.sql,
                                "df": query_result.df,
                            }  # To show when streaming
                    if tool_call["name"] == self._get_full_tool_name("submit_query_id"):  # noqa: SIM102
                        if query_result := self._get_tool_query_id_results(lc_message):
                            submitted_query_result = query_result

            message_log.append(lc_message)

            if stream:
                if isinstance(lc_message, AIMessage):
                    frontend.write_full_ai_message(lc_message)
                frontend.write_stream_chunk("values", {"messages": message_log})

        if stream:
            frontend.end()

        if submitted_query_result is None:
            # Fallback to the last executed query if no query was submitted
            max_query_id = max(self._query_cache, default=0)
            if max_query_id > max_init_query_id:
                submitted_query_result = self._query_cache[max_query_id]

        return ExecutionResult(
            text=message_log[-1].text if message_log else "",
            meta={
                "visualization_prompt": self._visualization_prompt,
                ExecutionResult.META_MESSAGES_KEY: message_log,
            },
            code=submitted_query_result.sql if submitted_query_result else "",
            df=submitted_query_result.df if submitted_query_result else None,
        ), session_id