databao-agent 0.1.4.dev5__tar.gz → 0.1.4.dev8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/PKG-INFO +2 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/api.py +8 -3
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/agent.py +1 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/executor.py +4 -4
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/snowflake_adapter.py +7 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/react_tools.py +1 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/__init__.py +2 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/base.py +42 -40
- databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/claude_model_wrapper.py +313 -0
- databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/executor.py +115 -0
- databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/system_prompt.jinja +83 -0
- databao_agent-0.1.4.dev8/databao/agent/executors/claude_code/utils.py +81 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/graph.py +1 -1
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/frontend/text_frontend.py +22 -2
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/executor.py +23 -12
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/graph.py +18 -58
- databao_agent-0.1.4.dev8/databao/agent/executors/utils.py +69 -0
- databao_agent-0.1.4.dev8/databao/agent/visualizers/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/pyproject.toml +1 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/.gitignore +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/LICENSE.md +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/README.md +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-html/index.html +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-jupyter/index.js +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/client/out/multimodal-jupyter/style.css +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/disk_cache.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/caches/in_mem_cache.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/agent.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/configs/llm.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/cache.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/data_source.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/domain.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/env.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/opa.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/sources.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/thread.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/core/visualizer.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/bigquery_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/database_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/database_connection.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/databases.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/duckdb_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/mysql_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/postgresql_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/sqlite_adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/utils.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/dbt/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/dbt/dbt.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/schema_inspection.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/duckdb/types.py +0 -0
- {databao_agent-0.1.4.dev5/databao/agent/executors/frontend → databao_agent-0.1.4.dev8/databao/agent/executors/claude_code}/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/config.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/dbt_runner.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/errors.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/executor.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/query_runner.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/system_prompt.jinja +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/dbt/task_instruction.jinja +0 -0
- {databao_agent-0.1.4.dev5/databao/agent/executors/react_duckdb → databao_agent-0.1.4.dev8/databao/agent/executors/frontend}/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/frontend/messages.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/history_cleaning.py +0 -0
- /databao_agent-0.1.4.dev5/databao/agent/executors/tools.py → /databao_agent-0.1.4.dev8/databao/agent/executors/langchain_tools.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/lighthouse/system_prompt.jinja +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/llm.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/prompt.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/query_expansion.py +0 -0
- {databao_agent-0.1.4.dev5/databao/agent/integrations → databao_agent-0.1.4.dev8/databao/agent/executors/react_duckdb}/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/executors/react_duckdb/executor.py +0 -0
- {databao_agent-0.1.4.dev5/databao/agent/visualizers → databao_agent-0.1.4.dev8/databao/agent/integrations}/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context_engine.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/integrations/dce/databao_context_project_manager.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/adapter.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/config.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/connection.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/manager.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/mcp/oauth.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/__init__.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/html_viewer.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/jupyter_widget.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/multimodal/utils.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/py.typed +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/dumb.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/vega_chat.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/visualizers/vega_vis_tool.py +0 -0
- {databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/hatch_build.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: databao-agent
|
|
3
|
-
Version: 0.1.4.
|
|
3
|
+
Version: 0.1.4.dev8
|
|
4
4
|
Summary: databao-agent: NL queries for data
|
|
5
5
|
Project-URL: Homepage, https://databao.app/
|
|
6
6
|
Project-URL: Source, https://github.com/JetBrains/databao-agent
|
|
@@ -9,6 +9,7 @@ License-File: LICENSE.md
|
|
|
9
9
|
Classifier: Operating System :: OS Independent
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
11
|
Requires-Python: <3.15,>=3.11
|
|
12
|
+
Requires-Dist: claude-agent-sdk>=0.1.48
|
|
12
13
|
Requires-Dist: databao-context-engine[mysql,postgresql,snowflake]~=0.6.0
|
|
13
14
|
Requires-Dist: dbt-core~=1.9.0
|
|
14
15
|
Requires-Dist: dbt-duckdb>=1.10.0
|
|
@@ -8,10 +8,13 @@ from databao.agent.configs.agent import DEFAULT_AGENT_CONFIG, AgentConfig
|
|
|
8
8
|
from databao.agent.configs.llm import LLMConfig, LLMConfigDirectory
|
|
9
9
|
from databao.agent.core import Agent, Cache, Executor, Visualizer
|
|
10
10
|
from databao.agent.core.domain import Domain, _DCEProjectDomain, _InMemoryDomain
|
|
11
|
-
from databao.agent.executors import
|
|
11
|
+
from databao.agent.executors import (
|
|
12
|
+
ClaudeCodeExecutor,
|
|
13
|
+
DbtProjectExecutor,
|
|
14
|
+
LighthouseExecutor,
|
|
15
|
+
ReactDuckDBExecutor,
|
|
16
|
+
)
|
|
12
17
|
from databao.agent.executors.dbt.config import DbtConfig
|
|
13
|
-
from databao.agent.executors.dbt.executor import DbtProjectExecutor
|
|
14
|
-
from databao.agent.executors.lighthouse.executor import LighthouseExecutor
|
|
15
18
|
from databao.agent.visualizers.vega_chat import VegaChatVisualizer
|
|
16
19
|
|
|
17
20
|
|
|
@@ -49,6 +52,8 @@ def agent(
|
|
|
49
52
|
data_executor = DbtProjectExecutor(dbt_config=dbt_config, writer=writer)
|
|
50
53
|
case "react_duckdb":
|
|
51
54
|
data_executor = ReactDuckDBExecutor(writer=writer)
|
|
55
|
+
case "claude":
|
|
56
|
+
data_executor = ClaudeCodeExecutor(writer=writer)
|
|
52
57
|
case _:
|
|
53
58
|
raise ValueError(f"Invalid executor type: {executor_type}")
|
|
54
59
|
|
|
@@ -4,7 +4,6 @@ import re
|
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TextIO
|
|
6
6
|
|
|
7
|
-
from langchain_core.tools import BaseTool
|
|
8
7
|
from pandas import DataFrame
|
|
9
8
|
from pydantic import BaseModel, ConfigDict
|
|
10
9
|
|
|
@@ -149,8 +148,8 @@ class Executor(ABC):
|
|
|
149
148
|
"""
|
|
150
149
|
|
|
151
150
|
@abstractmethod
|
|
152
|
-
def register_tools(self, tools: list[
|
|
153
|
-
"""Register additional
|
|
151
|
+
def register_tools(self, tools: list[Any]) -> None:
|
|
152
|
+
"""Register additional tools to be available during execution."""
|
|
154
153
|
|
|
155
154
|
@abstractmethod
|
|
156
155
|
def drop_last_opa_group(self, cache: "Cache", n: int = 1) -> None:
|
|
@@ -184,7 +183,8 @@ class Executor(ABC):
|
|
|
184
183
|
"""
|
|
185
184
|
pass
|
|
186
185
|
|
|
187
|
-
|
|
186
|
+
@staticmethod
|
|
187
|
+
def prepare_for_execution(domain: "Domain") -> None:
|
|
188
188
|
if domain.supports_context and not domain.is_context_built():
|
|
189
189
|
logger.warning(
|
|
190
190
|
"Context has not been built yet. Building it now — this may take a while. "
|
{databao_agent-0.1.4.dev5 → databao_agent-0.1.4.dev8}/databao/agent/databases/snowflake_adapter.py
RENAMED
|
@@ -11,6 +11,7 @@ from databao_context_engine import (
|
|
|
11
11
|
SnowflakeSSOAuth,
|
|
12
12
|
)
|
|
13
13
|
from databao_context_engine.pluginlib.build_plugin import AbstractConfigFile
|
|
14
|
+
from snowflake.connector.network import SNOWFLAKE_HOST_SUFFIX
|
|
14
15
|
from sqlalchemy import Connection, Engine, make_url
|
|
15
16
|
|
|
16
17
|
from databao.agent.databases.database_adapter import DatabaseAdapter
|
|
@@ -92,8 +93,13 @@ class SnowflakeAdapter(DatabaseAdapter):
|
|
|
92
93
|
if "dbname" in content:
|
|
93
94
|
content[DATABASE_KEY] = content.pop("dbname")
|
|
94
95
|
|
|
96
|
+
host: str | None = content.pop("host", None)
|
|
97
|
+
account: str = content.get(ACCOUNT_KEY, "")
|
|
98
|
+
if host and host.endswith(SNOWFLAKE_HOST_SUFFIX):
|
|
99
|
+
account = host[: -len(SNOWFLAKE_HOST_SUFFIX)]
|
|
100
|
+
|
|
95
101
|
return SnowflakeConnectionProperties(
|
|
96
|
-
account=
|
|
102
|
+
account=account,
|
|
97
103
|
warehouse=content.get(WAREHOUSE_KEY),
|
|
98
104
|
database=content.get(DATABASE_KEY),
|
|
99
105
|
user=content.get(USER_KEY),
|
|
@@ -12,7 +12,7 @@ from pydantic import BaseModel
|
|
|
12
12
|
|
|
13
13
|
from databao.agent.core import Domain
|
|
14
14
|
from databao.agent.duckdb.schema_inspection import inspect_duckdb_schema, summarize_duckdb_schema
|
|
15
|
-
from databao.agent.executors.
|
|
15
|
+
from databao.agent.executors.langchain_tools import make_search_context_tool
|
|
16
16
|
|
|
17
17
|
_LOGGER = logging.getLogger(__name__)
|
|
18
18
|
|
|
@@ -1,5 +1,6 @@
|
|
|
1
|
+
from databao.agent.executors.claude_code.executor import ClaudeCodeExecutor
|
|
1
2
|
from databao.agent.executors.dbt.executor import DbtProjectExecutor
|
|
2
3
|
from databao.agent.executors.lighthouse.executor import LighthouseExecutor
|
|
3
4
|
from databao.agent.executors.react_duckdb.executor import ReactDuckDBExecutor
|
|
4
5
|
|
|
5
|
-
__all__ = ["DbtProjectExecutor", "LighthouseExecutor", "ReactDuckDBExecutor"]
|
|
6
|
+
__all__ = ["ClaudeCodeExecutor", "DbtProjectExecutor", "LighthouseExecutor", "ReactDuckDBExecutor"]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from abc import ABC, abstractmethod
|
|
2
3
|
from collections.abc import Callable
|
|
3
4
|
from typing import Any, TextIO
|
|
4
5
|
|
|
5
6
|
import duckdb
|
|
6
|
-
from databao_context_engine import DuckDBConnectionConfig
|
|
7
7
|
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
|
8
8
|
from langchain_core.runnables import RunnableConfig
|
|
9
9
|
from langchain_core.tools import BaseTool
|
|
@@ -20,11 +20,13 @@ from databao.agent.databases import register_db_in_duckdb
|
|
|
20
20
|
from databao.agent.executors.frontend.text_frontend import TextStreamFrontend
|
|
21
21
|
from databao.agent.executors.history_cleaning import clean_tool_history
|
|
22
22
|
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
23
24
|
|
|
24
|
-
|
|
25
|
+
|
|
26
|
+
class DuckDBExecutor(Executor, ABC):
|
|
25
27
|
"""
|
|
26
|
-
Base class for
|
|
27
|
-
Provides common functionality for
|
|
28
|
+
Base class for executors that execute with a DuckDB connection and LLM configuration.
|
|
29
|
+
Provides common functionality for message handling and OPA processing.
|
|
28
30
|
"""
|
|
29
31
|
|
|
30
32
|
def __init__(self, writer: TextIO | None = None) -> None:
|
|
@@ -34,16 +36,11 @@ class GraphExecutor(Executor, ABC):
|
|
|
34
36
|
writer: Optional TextIO for streaming output. If provided, streaming
|
|
35
37
|
output will be written to this writer instead of stdout.
|
|
36
38
|
"""
|
|
37
|
-
self._graph_recursion_limit = 50
|
|
38
39
|
self._writer = writer
|
|
39
40
|
self._duckdb_connection: duckdb.DuckDBPyConnection = duckdb.connect(":memory:")
|
|
40
41
|
self._registered_dbs: dict[str, DBDataSource] = {}
|
|
41
42
|
self._registered_dfs: dict[str, DFDataSource] = {}
|
|
42
43
|
self._registered_dbts: dict[str, DBTDataSource] = {}
|
|
43
|
-
self._extra_tools: dict[str, BaseTool] = {}
|
|
44
|
-
self._compiled_graph: CompiledStateGraph[Any] | None = None
|
|
45
|
-
self._compiled_tools_version: int = 0
|
|
46
|
-
self._compiled_at_version: int = -1
|
|
47
44
|
|
|
48
45
|
def _init_sources_from_domain(self, domain: Domain, *, register_in_duckdb: bool = True) -> None:
|
|
49
46
|
"""Sync sources from the domain into the executor's registered dictionaries.
|
|
@@ -61,7 +58,11 @@ class GraphExecutor(Executor, ABC):
|
|
|
61
58
|
for name, db_source in sources.dbs.items():
|
|
62
59
|
if name not in self._registered_dbs:
|
|
63
60
|
if register_in_duckdb:
|
|
64
|
-
|
|
61
|
+
try:
|
|
62
|
+
register_db_in_duckdb(self._duckdb_connection, db_source.config, name)
|
|
63
|
+
except Exception as exc:
|
|
64
|
+
logger.warning("Datasource '%s' is not available and will be skipped: %s", name, exc)
|
|
65
|
+
continue
|
|
65
66
|
self._registered_dbs[name] = db_source
|
|
66
67
|
|
|
67
68
|
for name, df_source in sources.dfs.items():
|
|
@@ -74,25 +75,26 @@ class GraphExecutor(Executor, ABC):
|
|
|
74
75
|
if name not in self._registered_dbts:
|
|
75
76
|
self._registered_dbts[name] = dbt_source
|
|
76
77
|
|
|
77
|
-
def prepare_for_execution(self, domain: "Domain") -> None:
|
|
78
|
-
if not domain.supports_context or domain.is_context_built():
|
|
79
|
-
return
|
|
80
78
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
79
|
+
class GraphExecutor(DuckDBExecutor, ABC):
|
|
80
|
+
"""
|
|
81
|
+
Base class for LangGraph executors that execute with a DuckDB connection and LLM configuration.
|
|
82
|
+
Provides common functionality for graph caching, message handling, and OPA processing.
|
|
83
|
+
"""
|
|
84
|
+
|
|
85
|
+
def __init__(self, writer: TextIO | None = None) -> None:
|
|
86
|
+
"""Initialize agent with graph caching infrastructure.
|
|
87
|
+
|
|
88
|
+
Args:
|
|
89
|
+
writer: Optional TextIO for streaming output. If provided, streaming
|
|
90
|
+
output will be written to this writer instead of stdout.
|
|
91
|
+
"""
|
|
92
|
+
super().__init__(writer)
|
|
93
|
+
self._extra_tools: dict[str, BaseTool] = {}
|
|
94
|
+
self._graph_recursion_limit = 50
|
|
95
|
+
self._compiled_graph: CompiledStateGraph[Any] | None = None
|
|
96
|
+
self._compiled_tools_version: int = 0
|
|
97
|
+
self._compiled_at_version: int = -1
|
|
96
98
|
|
|
97
99
|
def register_tools(self, tools: list[BaseTool]) -> None:
|
|
98
100
|
"""Register additional LangChain tools and invalidate the cached compiled graph."""
|
|
@@ -100,6 +102,18 @@ class GraphExecutor(Executor, ABC):
|
|
|
100
102
|
self._extra_tools[t.name] = t
|
|
101
103
|
self._compiled_tools_version += 1
|
|
102
104
|
|
|
105
|
+
def drop_last_opa_group(self, cache: Cache, n: int = 1) -> None:
|
|
106
|
+
"""Drop last n groups of operations from the message history."""
|
|
107
|
+
messages = cache.get("state", default={}).get("messages", [])
|
|
108
|
+
human_messages = [m for m in messages if isinstance(m, HumanMessage)]
|
|
109
|
+
if len(human_messages) < n:
|
|
110
|
+
raise ValueError(f"Cannot drop last {n} operations - only {len(human_messages)} operations found.")
|
|
111
|
+
c = 0
|
|
112
|
+
while c < n:
|
|
113
|
+
m = messages.pop()
|
|
114
|
+
if isinstance(m, HumanMessage):
|
|
115
|
+
c += 1
|
|
116
|
+
|
|
103
117
|
@abstractmethod
|
|
104
118
|
def _compile_graph(
|
|
105
119
|
self,
|
|
@@ -120,18 +134,6 @@ class GraphExecutor(Executor, ABC):
|
|
|
120
134
|
self._compiled_at_version = self._compiled_tools_version
|
|
121
135
|
return self._compiled_graph
|
|
122
136
|
|
|
123
|
-
def drop_last_opa_group(self, cache: Cache, n: int = 1) -> None:
|
|
124
|
-
"""Drop last n groups of operations from the message history."""
|
|
125
|
-
messages = cache.get("state", default={}).get("messages", [])
|
|
126
|
-
human_messages = [m for m in messages if isinstance(m, HumanMessage)]
|
|
127
|
-
if len(human_messages) < n:
|
|
128
|
-
raise ValueError(f"Cannot drop last {n} operations - only {len(human_messages)} operations found.")
|
|
129
|
-
c = 0
|
|
130
|
-
while c < n:
|
|
131
|
-
m = messages.pop()
|
|
132
|
-
if isinstance(m, HumanMessage):
|
|
133
|
-
c += 1
|
|
134
|
-
|
|
135
137
|
def _process_opas(self, opas: list[Opa], cache: Cache) -> list[Any]:
|
|
136
138
|
"""
|
|
137
139
|
Process a single opa and convert it to a message, appending to message history.
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import queue
|
|
5
|
+
import threading
|
|
6
|
+
from collections.abc import Generator
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, TextIO
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
from _duckdb import DuckDBPyConnection
|
|
13
|
+
from claude_agent_sdk import (
|
|
14
|
+
ClaudeAgentOptions,
|
|
15
|
+
ClaudeSDKClient,
|
|
16
|
+
SdkMcpTool,
|
|
17
|
+
create_sdk_mcp_server,
|
|
18
|
+
tool,
|
|
19
|
+
)
|
|
20
|
+
from claude_agent_sdk.types import McpSdkServerConfig, ResultMessage, SystemPromptPreset
|
|
21
|
+
from claude_agent_sdk.types import Message as ClaudeMessage
|
|
22
|
+
from claude_agent_sdk.types import SystemMessage as ClaudeSystemMessage
|
|
23
|
+
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
24
|
+
from mcp.types import ToolAnnotations
|
|
25
|
+
|
|
26
|
+
from databao.agent.configs.llm import LLMConfig
|
|
27
|
+
from databao.agent.core.executor import ExecutionResult
|
|
28
|
+
from databao.agent.executors.claude_code.utils import cast_claude_message_to_langchain_message
|
|
29
|
+
from databao.agent.executors.frontend.messages import get_tool_call
|
|
30
|
+
from databao.agent.executors.frontend.text_frontend import TextStreamFrontend
|
|
31
|
+
from databao.agent.executors.lighthouse.graph import RUN_SQL_QUERY_TOOL_DESCRIPTION
|
|
32
|
+
from databao.agent.executors.utils import run_sql_query
|
|
33
|
+
|
|
34
|
+
_LOGGER = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class QueryResult:
    """A successfully executed SQL query, as cached by ``ClaudeModelWrapper``.

    Attributes:
        sql: SQL text reported by ``run_sql_query`` for this query.
        df: Result DataFrame reported by ``run_sql_query`` for this query.
    """

    sql: str  # the SQL statement that produced ``df``
    df: pd.DataFrame  # the query's tabular result
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ClaudeModelWrapper:
    """Drive a Claude Agent SDK session that answers questions with SQL over DuckDB.

    Two in-process MCP tools are exposed to Claude:

    * ``run_sql_query`` executes SQL on the wrapped DuckDB connection and caches
      every successful result under an integer ``query_id``.
    * ``submit_query_id`` lets the model choose which cached query answers the
      user's question, together with a visualization prompt.

    The SDK client is asynchronous while :meth:`ask` is synchronous, so the
    wrapper runs a private asyncio event loop in a daemon thread. Use it as a
    context manager: ``__enter__`` starts the loop and enters the client's async
    context, ``__exit__`` leaves it and stops the loop.
    """

    DISPLAY_ROW_LIMIT = 12
    """Max number of rows to return in SQL tool calls."""

    DISPLAY_CELL_CHAR_LIMIT = 1024
    """Max number of characters a dataframe cell can have before it is trimmed."""

    __runtime_mcp_server: McpSdkServerConfig | None = None

    def __init__(
        self,
        *,
        config: LLMConfig,
        connection: DuckDBPyConnection,
        system_prompt: str,
        append_system_prompt: bool = False,
        session_id: str | None = None,
        limit_max_rows: int | None = None,
        max_turns: int | None = 100,
    ):
        """Configure the SDK client; nothing is started until ``__enter__``.

        Args:
            config: LLM configuration; ``config.name`` is passed as the Claude model.
            connection: DuckDB connection the ``run_sql_query`` tool executes against.
            system_prompt: Used verbatim as the system prompt, or appended to the
                ``claude_code`` preset when ``append_system_prompt`` is true.
            append_system_prompt: Append to Claude's internal prompt instead of replacing it.
            session_id: Claude session to resume, if any.
            limit_max_rows: Forwarded to ``run_sql_query`` as ``sql_row_limit``.
            max_turns: Maximum number of agent turns.
        """
        self._duckdb_connection = connection
        self._limit_max_rows = limit_max_rows
        self.config = config
        # Tools read these lazily, but create them before the tools exist for clarity.
        self._query_cache: dict[int, QueryResult] = {}
        self._visualization_prompt: str | None = None

        self.sdk_mcp_tools = self._build_tools()
        self._tool_server_name = Path(__file__).stem + "_mcp_server"
        self.mcp_tool_names = [self._get_full_tool_name(t.name) for t in self.sdk_mcp_tools]

        prompt_option: str | SystemPromptPreset
        if append_system_prompt:
            prompt_option = SystemPromptPreset(
                type="preset",  # Append to Claude's internal system prompt
                preset="claude_code",
                append=system_prompt,
            )
        else:
            prompt_option = system_prompt

        self.options = ClaudeAgentOptions(
            max_turns=max_turns,
            cwd=".",
            allowed_tools=self.mcp_tool_names,
            model=self.config.name,
            mcp_servers={self._tool_server_name: self._build_tool_server()},
            permission_mode="acceptEdits",
            resume=session_id,
            system_prompt=prompt_option,
        )
        self.client = ClaudeSDKClient(options=self.options)
        self._ready_event: threading.Event
        self._exit_event: asyncio.Event

    def __enter__(self) -> "ClaudeModelWrapper":
        """Start the background loop and enter the SDK client's async context.

        Blocks until the client is ready. If entering the client fails, the
        background loop is shut down and the error is re-raised here (instead
        of blocking forever on the readiness event).
        """
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True, name=self._tool_server_name)
        self._thread.start()

        self._ready_event = threading.Event()
        startup_failed = threading.Event()

        async def _lifecycle() -> None:
            # Created on the background loop, the only place it is awaited.
            self._exit_event = asyncio.Event()
            try:
                async with self.client:
                    self._ready_event.set()
                    await self._exit_event.wait()
            except BaseException:
                # Only failures before readiness are startup failures; later ones
                # surface through ``__exit__``.
                if not self._ready_event.is_set():
                    startup_failed.set()
                raise
            finally:
                # Always release ``__enter__``, even when the client never connected.
                self._ready_event.set()

        self._lifecycle_task = asyncio.run_coroutine_threadsafe(_lifecycle(), self._loop)
        self._ready_event.wait()
        if startup_failed.is_set():
            try:
                self._lifecycle_task.result()  # re-raises the startup error
            finally:
                self._stop_loop()
            raise RuntimeError("Failed to start the Claude SDK client.")
        return self

    def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
        """Leave the SDK client's context and stop the background loop.

        The loop is stopped even when leaving the client raises; that error is
        then propagated to the caller.
        """
        self._loop.call_soon_threadsafe(self._exit_event.set)
        try:
            self._lifecycle_task.result()
        finally:
            self._stop_loop()

    def _stop_loop(self) -> None:
        """Stop the background event loop, wait for its thread, and close the loop."""
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()

    def _get_full_tool_name(self, tool_name: str) -> str:
        """Return the namespaced name under which Claude sees an MCP tool of this server."""
        return f"mcp__{self._tool_server_name}__{tool_name}"

    def _build_tools(self) -> list[SdkMcpTool[Any]]:
        """Create the ``run_sql_query`` and ``submit_query_id`` MCP tools bound to this wrapper."""
        # Set read only hints to enable parallel tool execution
        # (see https://platform.claude.com/docs/en/agent-sdk/agent-loop#parallel-tool-execution)

        tools = []

        @tool(
            "run_sql_query",
            RUN_SQL_QUERY_TOOL_DESCRIPTION,
            {"sql": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def _run_sql_query(args: dict[str, Any]) -> dict[str, Any]:
            result = run_sql_query(
                args.get("sql", ""),
                con=self._duckdb_connection,
                sql_row_limit=self._limit_max_rows,
                display_row_limit=self.DISPLAY_ROW_LIMIT,
                display_cell_char_limit=self.DISPLAY_CELL_CHAR_LIMIT,
            )
            if "error" in result:
                return {"content": [{"type": "text", "text": json.dumps(result, default=str)}]}

            result_for_llm: dict[str, Any] = {"csv": result.get("csv", "")}

            if (sql := result.get("sql")) and (df := result.get("df")) is not None:
                # Ids are 1-based and strictly increasing; entries are never removed.
                query_id = max(self._query_cache, default=0) + 1
                self._query_cache[query_id] = QueryResult(sql=sql, df=df)
                result_for_llm["query_id"] = query_id

            return {"content": [{"type": "text", "text": json.dumps(result_for_llm, default=str)}]}

        tools.append(_run_sql_query)

        @tool(
            "submit_query_id",
            """\
This tool call must be the last tool to be called by the model.
It will provide to the user the generated sql and the output thereof resulting from the query with
the respective query id. You will find the query ids of the error-free queries in the outputs of
the run_sql_query tool in the `query_id` key. The `query_id` itself need not be the one of the last
generated query, it rather needs to reference the query which most closely matches the
user's question.

Args:
    query_id: The ID of the query to submit.""",
            {"query_id": int, "visualization_prompt": str},
            annotations=ToolAnnotations(readOnlyHint=True),
        )
        async def submit_query_id(args: dict[str, Any]) -> dict[str, Any]:
            query_id: int | None = args.get("query_id")
            self._visualization_prompt = args.get("visualization_prompt")

            if query_id not in self._query_cache:
                return {"content": [{"type": "text", "text": json.dumps({"error": f"Query id {query_id} not found"})}]}
            return {"content": [{"type": "text", "text": json.dumps({"query_id": query_id})}]}

        tools.append(submit_query_id)

        return tools

    def _build_tool_server(self) -> McpSdkServerConfig:
        """Return the in-process MCP server exposing ``self.sdk_mcp_tools`` (created once per instance).

        Reuses the tools built in ``__init__`` so the advertised tool names and
        the served tools come from the same objects.
        """
        if self.__runtime_mcp_server is None:
            self.__runtime_mcp_server = create_sdk_mcp_server(
                name=self._tool_server_name,
                version="1.0.0",
                tools=self.sdk_mcp_tools,
            )
        return self.__runtime_mcp_server

    def _check_mcp_tool_availability(self, first_message: ClaudeMessage) -> None:
        """
        Each conversation begins with an initial init system message. This SystemMessage
        carries the information about the tools available to claude. To prevent
        the system from running with the mcp tools being silently not available, we
        explicitly look for them and raise an error if any of them is missing.

        Raises:
            TypeError: If ``first_message`` is not a Claude system message.
            ValueError: If any of this wrapper's MCP tools is missing from the
                message's ``tools`` list (an absent list counts as empty).
        """
        if not isinstance(first_message, ClaudeSystemMessage):
            raise TypeError(
                f"The first message should be a system message, got {type(first_message)}. "
                "Check if you are actually calling this function on the first message of the conversation."
            )

        available_tools = first_message.data.get("tools") or []
        if missing_tools := set(self.mcp_tool_names).difference(available_tools):
            raise ValueError(
                f"The following mcp tools are not available: {missing_tools}. "
                "Check the connection to the mcp servers by running /mcp in the claude console."
            )

    def _get_tool_query_id_results(self, message: ToolMessage) -> QueryResult | None:
        """Return the cached query referenced by a tool result's ``query_id``, if any.

        Returns ``None`` when the tool output is not a JSON object, carries no
        integer ``query_id``, or the id is not in the cache.
        """
        try:
            payload = json.loads(message.text)
        except json.JSONDecodeError as e:
            _LOGGER.warning("Failed to parse tool call payload: %s", message.text, exc_info=e)
            return None
        if not isinstance(payload, dict):
            return None
        query_id = payload.get("query_id")
        if isinstance(query_id, int):
            return self._query_cache.get(query_id)
        return None

    def solve(self, prompt: str) -> Generator[ClaudeMessage, None, None]:
        """Send ``prompt`` to Claude and yield the response messages as they arrive.

        The query runs on the background loop; messages reach the calling thread
        through a queue. The first message is validated with
        :meth:`_check_mcp_tool_availability`.

        Raises:
            Exception: Any error raised by the SDK while querying or streaming.
                It is re-raised in the calling thread after the messages already
                received have been yielded (previously the caller hung forever).
            RuntimeError: If the conversation produced no messages at all.
        """
        _LOGGER.info("Querying %s", prompt)

        sentinel = object()
        q: queue.Queue[Any] = queue.Queue()
        errors: list[Exception] = []

        async def _produce() -> None:
            try:
                await self.client.query(prompt=prompt)
                async for message in self.client.receive_response():
                    q.put(message)
            except Exception as exc:
                # Hand the failure to the consumer instead of losing it in a discarded future.
                errors.append(exc)
            finally:
                # Always terminate the consumer loop below.
                q.put(sentinel)

        asyncio.run_coroutine_threadsafe(_produce(), self._loop)

        n_messages = 0
        while (message := q.get()) is not sentinel:
            if n_messages == 0:
                self._check_mcp_tool_availability(message)
            n_messages += 1
            _LOGGER.info(message)
            yield message

        if errors:
            raise errors[0]
        if n_messages == 0:
            raise RuntimeError("Claude returned no messages for the query.")

        _LOGGER.info("End of conversation. Got %d messages.\n\n", n_messages)

    def ask(
        self,
        prompt: str,
        *,
        stream: bool = False,
        writer: TextIO | None = None,
    ) -> tuple[ExecutionResult, str | None]:
        """
        Iterate through the messages from claude, cast them into BaseMessage
        objects and pack them into an ExecutionResult.

        Args:
            prompt: The user's question.
            stream: Render messages through ``TextStreamFrontend`` as they arrive.
            writer: Optional output stream passed to ``TextStreamFrontend``.

        Returns:
            The execution result and the session id from the first Claude system
            message (``None`` if none was received). ``code``/``df`` come from
            the submitted query, falling back to the last query executed during
            this call, or are empty when no query ran.
        """
        # Each question starts without a visualization prompt left from an earlier one.
        self._visualization_prompt = None
        session_id: str | None = None
        max_init_query_id = max(self._query_cache, default=0)
        message_log: list[BaseMessage] = []
        submitted_query_result: QueryResult | None = None
        frontend = TextStreamFrontend({"messages": message_log}, writer=writer)
        run_sql_tool_name = self._get_full_tool_name("run_sql_query")
        submit_tool_name = self._get_full_tool_name("submit_query_id")

        for message in self.solve(prompt):
            if isinstance(message, ClaudeSystemMessage) and session_id is None:
                # Child subagents have their own system messages, but we want the parent one only
                session_id = message.data.get("session_id", "default")

            # Skip the final text-only ResultMessage, as the previous AssistantMessage already contains the text
            # of this message.
            if isinstance(message, ResultMessage):
                if message.is_error:
                    _LOGGER.warning("Claude finished with an error result: %s", message)
                continue

            lc_message = cast_claude_message_to_langchain_message(message)

            if isinstance(lc_message, ToolMessage):
                tool_call = get_tool_call(message_log, lc_message)
                tool_name = tool_call["name"] if tool_call is not None else None
                if tool_name == run_sql_tool_name:
                    if query_result := self._get_tool_query_id_results(lc_message):
                        lc_message.artifact = {
                            "sql": query_result.sql,
                            "df": query_result.df,
                        }  # To show when streaming
                elif tool_name == submit_tool_name:
                    if query_result := self._get_tool_query_id_results(lc_message):
                        submitted_query_result = query_result

            message_log.append(lc_message)

            if stream:
                if isinstance(lc_message, AIMessage):
                    frontend.write_full_ai_message(lc_message)
                frontend.write_stream_chunk("values", {"messages": message_log})

        if stream:
            frontend.end()

        if submitted_query_result is None:
            # Fallback to the last query executed during this call if none was submitted
            max_query_id = max(self._query_cache, default=0)
            if max_query_id > max_init_query_id:
                submitted_query_result = self._query_cache[max_query_id]

        return ExecutionResult(
            text=message_log[-1].text if message_log else "",
            meta={
                "visualization_prompt": self._visualization_prompt,
                ExecutionResult.META_MESSAGES_KEY: message_log,
            },
            code=submitted_query_result.sql if submitted_query_result else "",
            df=submitted_query_result.df if submitted_query_result else None,
        ), session_id
|