datahub-analytics-agent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- analytics_agent/__init__.py +0 -0
- analytics_agent/agent/__init__.py +0 -0
- analytics_agent/agent/analysis.py +149 -0
- analytics_agent/agent/chart_generator.py +70 -0
- analytics_agent/agent/chart_tool.py +103 -0
- analytics_agent/agent/compaction.py +57 -0
- analytics_agent/agent/compactor_registry.py +22 -0
- analytics_agent/agent/graph.py +121 -0
- analytics_agent/agent/history.py +159 -0
- analytics_agent/agent/llm.py +87 -0
- analytics_agent/agent/mock_llm.py +111 -0
- analytics_agent/agent/state.py +13 -0
- analytics_agent/agent/streaming.py +304 -0
- analytics_agent/api/__init__.py +135 -0
- analytics_agent/api/chat.py +439 -0
- analytics_agent/api/conversations.py +244 -0
- analytics_agent/api/oauth.py +741 -0
- analytics_agent/api/settings.py +1947 -0
- analytics_agent/config.py +236 -0
- analytics_agent/context/__init__.py +0 -0
- analytics_agent/context/base.py +26 -0
- analytics_agent/context/datahub.py +242 -0
- analytics_agent/context/mcp_platform.py +123 -0
- analytics_agent/context/native_datahub.py +58 -0
- analytics_agent/context/registry.py +84 -0
- analytics_agent/db/__init__.py +0 -0
- analytics_agent/db/alembic/env.py +49 -0
- analytics_agent/db/alembic/script.py.mako +25 -0
- analytics_agent/db/alembic/versions/001_init.py +47 -0
- analytics_agent/db/alembic/versions/002_settings_table.py +30 -0
- analytics_agent/db/alembic/versions/003_integrations.py +52 -0
- analytics_agent/db/alembic/versions/004_conversation_quality.py +28 -0
- analytics_agent/db/alembic/versions/005_context_platforms.py +36 -0
- analytics_agent/db/base.py +33 -0
- analytics_agent/db/models.py +137 -0
- analytics_agent/db/repository.py +294 -0
- analytics_agent/db/types.py +69 -0
- analytics_agent/engines/__init__.py +0 -0
- analytics_agent/engines/base.py +30 -0
- analytics_agent/engines/factory.py +95 -0
- analytics_agent/engines/mcp/__init__.py +0 -0
- analytics_agent/engines/mcp/engine.py +78 -0
- analytics_agent/engines/resolver.py +84 -0
- analytics_agent/engines/snowflake/__init__.py +0 -0
- analytics_agent/engines/snowflake/engine.py +304 -0
- analytics_agent/engines/sqlalchemy/__init__.py +0 -0
- analytics_agent/engines/sqlalchemy/engine.py +163 -0
- analytics_agent/main.py +536 -0
- analytics_agent/prompts/__init__.py +0 -0
- analytics_agent/prompts/chart.py +101 -0
- analytics_agent/prompts/system.py +33 -0
- analytics_agent/prompts/system_prompt.md +184 -0
- analytics_agent/skills/__init__.py +0 -0
- analytics_agent/skills/datahub_skills.py +409 -0
- analytics_agent/skills/improve-context/SKILL.md +73 -0
- analytics_agent/skills/loader.py +162 -0
- analytics_agent/skills/publish-analysis/SKILL.md +99 -0
- analytics_agent/skills/save-correction/SKILL.md +161 -0
- analytics_agent/skills/search-business-context/SKILL.md +109 -0
- analytics_agent/tracing.py +88 -0
- datahub_analytics_agent-0.1.0.dist-info/METADATA +328 -0
- datahub_analytics_agent-0.1.0.dist-info/RECORD +63 -0
- datahub_analytics_agent-0.1.0.dist-info/WHEEL +4 -0
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
|
|
7
|
+
import orjson
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
CONTEXT_TOOLS: frozenset[str] = frozenset(
|
|
12
|
+
{"search_documents", "grep_documents", "search", "get_entities", "search_business_context"}
|
|
13
|
+
)
|
|
14
|
+
_CONTEXT_TOOLS = CONTEXT_TOOLS # backward compat alias
|
|
15
|
+
|
|
16
|
+
_SCORE_LABELS = {1: "Poor", 2: "Poor", 3: "Fair", 4: "Good", 5: "Excellent"}
|
|
17
|
+
|
|
18
|
+
_ASSESSMENT_PROMPT = """\
|
|
19
|
+
You are assessing the **context quality** of a data assistant conversation.
|
|
20
|
+
|
|
21
|
+
Context quality measures how well the DataHub knowledge base (documentation, \
|
|
22
|
+
definitions, dataset descriptions) supported the agent's work.
|
|
23
|
+
|
|
24
|
+
Score 1–5:
|
|
25
|
+
5 Excellent — DataHub had rich, accurate documentation that fully covered the \
|
|
26
|
+
question. Agent applied the definition directly with no improvisation.
|
|
27
|
+
4 Good — Useful docs found; definition was mostly complete; agent made one minor \
|
|
28
|
+
stated assumption that didn't change the answer meaningfully.
|
|
29
|
+
3 Fair — Definition found but incomplete or ambiguous; agent had to fill gaps, \
|
|
30
|
+
deviate from the definition, or ask the user for clarification about what the \
|
|
31
|
+
context should have made clear.
|
|
32
|
+
2 Poor — Docs mostly missing or returned empty results; agent improvised \
|
|
33
|
+
substantially and the answer depended heavily on undocumented choices.
|
|
34
|
+
1 Very Poor — No useful context; agent expressed significant uncertainty, made \
|
|
35
|
+
conflicting assumptions, or produced an answer that contradicts available definitions.
|
|
36
|
+
|
|
37
|
+
Key signals that push the score DOWN:
|
|
38
|
+
- Agent says "the definition doesn't cover this" or "I'll interpret this as…"
|
|
39
|
+
- Agent switches columns, tables, or date anchors not mentioned in the definition
|
|
40
|
+
- Agent produces a result that varies based on an undocumented assumption
|
|
41
|
+
- Agent asks the user to clarify something the glossary/docs should have defined
|
|
42
|
+
|
|
43
|
+
--- CONTEXT TOOL CALLS AND RESULTS ---
|
|
44
|
+
{context_calls}
|
|
45
|
+
--- END CONTEXT ---
|
|
46
|
+
|
|
47
|
+
--- AGENT REASONING (what the agent said and concluded) ---
|
|
48
|
+
{agent_reasoning}
|
|
49
|
+
--- END REASONING ---
|
|
50
|
+
|
|
51
|
+
Respond with ONLY valid JSON, no explanation outside it:
|
|
52
|
+
{{"score": <1-5>, "label": "<Excellent|Good|Fair|Poor>", "reason": "<one sentence that names the specific gap or strength>"}}"""
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
async def compute_context_quality(messages: list) -> QualityScore:
    """
    LLM-assessed context quality score (1–5).

    Extracts DataHub context tool calls + results and the conversation text,
    then asks a cheap model to judge whether the returned context was actually
    useful and complete — penalising cases where the agent had to improvise or
    deviate from the definition.

    Args:
        messages: persisted conversation rows in chronological order. Each row
            is read via ``.event_type``, ``.payload`` (JSON str/bytes or an
            already-decoded mapping) and, when present, ``.role``.

    Returns:
        A ``QualityScore``. Neutral (3) is returned immediately when no context
        tool calls have occurred yet, and also when the LLM assessment fails
        for any reason (the failure is logged as a warning).
    """
    context_calls: list[dict] = []
    # Conversation text in chronological order, one segment per user message
    # or assistant reply, each capped at 400 chars.
    reasoning_segments: list[str] = []
    # Assistant text of the turn being read. TEXT events stream token-by-token,
    # so their chunks must be concatenated verbatim; a COMPLETE event carries
    # the full reply and takes precedence (same rule as history.py).
    turn_chunks: list[str] = []
    turn_final = ""

    def _close_turn() -> None:
        """Flush the pending assistant text of the current turn into a segment."""
        nonlocal turn_final
        text = turn_final or "".join(turn_chunks)
        if text:
            reasoning_segments.append(text[:400])
        turn_chunks.clear()
        turn_final = ""

    for msg in messages:
        try:
            payload = (
                orjson.loads(msg.payload) if isinstance(msg.payload, (str, bytes)) else msg.payload
            )
        except Exception:
            continue
        if not isinstance(payload, dict):
            # null / list / scalar payloads carry nothing usable (and have no .get).
            continue

        event_type = msg.event_type
        if event_type == "TOOL_RESULT":
            tool_name = payload.get("tool_name", "")
            if tool_name not in _CONTEXT_TOOLS:
                continue
            result_str = str(payload.get("result", ""))
            if len(result_str) > 800:
                result_str = result_str[:800] + "…"
            context_calls.append(
                {
                    "tool": tool_name,
                    "is_error": bool(payload.get("is_error", False)),
                    "result": result_str,
                }
            )
        elif event_type == "TEXT":
            text = payload.get("text", "")
            if not text:
                continue
            if getattr(msg, "role", None) == "user":
                # A user message starts a new turn: flush the previous reply first.
                _close_turn()
                reasoning_segments.append(text[:400])
            else:
                turn_chunks.append(text)
        elif event_type == "COMPLETE":
            text = payload.get("text", "")
            if text:
                turn_final = text
    _close_turn()

    if not context_calls:
        return QualityScore(
            score=3, label="Neutral", breakdown={"reason": "No context lookups yet"}
        )

    calls_text = "\n\n".join(
        f"Tool: {c['tool']}\nError: {c['is_error']}\nResult: {c['result']}" for c in context_calls
    )
    # Drop exact-duplicate segments (order preserved) and cap the excerpt.
    reasoning = "\n".join(dict.fromkeys(reasoning_segments))[:1200]

    prompt = _ASSESSMENT_PROMPT.format(context_calls=calls_text, agent_reasoning=reasoning)

    try:
        from langchain_core.messages import HumanMessage, SystemMessage

        from analytics_agent.agent.llm import get_quality_llm

        llm = get_quality_llm()
        response = await llm.ainvoke(
            [
                SystemMessage(
                    content="You assess data assistant context quality. Reply only with the requested JSON."
                ),
                HumanMessage(content=prompt),
            ]
        )
        raw = response.content
        if isinstance(raw, list):
            # Some providers return a list of content blocks; take the first text block.
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        raw = raw.strip()
        # Strip markdown code fences if present
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        data = json.loads(raw.strip())
        score = max(1, min(5, int(data.get("score", 3))))
        # Fall back to the canonical label when the judge omits it or leaves it empty.
        label = data.get("label") or _SCORE_LABELS[score]
        reason = data.get("reason", "")
        return QualityScore(score=score, label=label, breakdown={"reason": reason})
    except Exception as exc:
        logger.warning("Context quality LLM assessment failed: %s", exc)
        return QualityScore(
            score=3, label="Neutral", breakdown={"reason": "Assessment unavailable"}
        )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class QualityScore:
    """Outcome of a context-quality assessment.

    Holds an integer rating, its display label (e.g. "Good", or "Neutral" for
    the fallback result) and a free-form breakdown mapping;
    compute_context_quality() stores a ``"reason"`` string there.
    """

    score: int  # rating on the 1–5 scale
    label: str  # human-readable name for the rating
    breakdown: dict  # extra detail, e.g. {"reason": "..."}
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import orjson
|
|
6
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
7
|
+
|
|
8
|
+
from analytics_agent.agent.llm import get_chart_llm
|
|
9
|
+
from analytics_agent.agent.state import AgentState
|
|
10
|
+
from analytics_agent.prompts.chart import CHART_SYSTEM_PROMPT, build_chart_user_prompt
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)

# Maximum number of result rows sent to the chart LLM as a sample
# (the same cap create_chart applies). The full row set still goes into the spec.
_CHART_SAMPLE_ROWS = 50


async def chart_node(state: AgentState) -> dict:
    """
    Generate a Vega-Lite chart spec from the last SQL result and store it in state.
    streaming.py reads it from state after the graph completes.

    Returns:
        ``{"pending_chart": {"vega_lite_spec", "reasoning", "chart_type"}}`` on
        success. Returns ``{}`` when there is no SQL result with rows, when the
        model does not produce a usable spec, or when anything fails. Failures
        are logged and never abort the graph.
    """
    from analytics_agent.agent.graph import get_last_sql_result

    sql_result = get_last_sql_result(state)
    # Type guard: a non-object JSON payload would have no .get().
    if not isinstance(sql_result, dict) or not sql_result.get("rows"):
        return {}
    rows = sql_result.get("rows", [])

    try:
        # LLM construction and prompt building live inside the try as well,
        # so a misconfigured chart model cannot take down the whole run.
        llm = get_chart_llm()

        user_prompt = build_chart_user_prompt(
            question=state.get("user_question", ""),
            sql=sql_result.get("sql", ""),
            columns=sql_result.get("columns", []),
            # Only a sample is needed to pick the chart; large results would
            # otherwise be dumped wholesale into the prompt.
            sample_rows=rows[:_CHART_SAMPLE_ROWS],
        )

        response = await llm.ainvoke(
            [SystemMessage(content=CHART_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]
        )

        raw = response.content
        if isinstance(raw, list):
            # Anthropic returns list of content blocks
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        if "```" in raw:
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        result = orjson.loads(raw.strip())

        chart_schema = result.get("chart_schema", {})
        chart_type = result.get("chart_type", "")

        if isinstance(chart_schema, dict) and chart_schema and chart_type:
            # Attach the full result set (not just the sample) as inline data.
            chart_schema["data"] = {"values": rows}

            # Store in state so streaming.py can emit it after graph completion
            return {
                "pending_chart": {
                    "vega_lite_spec": chart_schema,
                    "reasoning": result.get("reasoning", ""),
                    "chart_type": chart_type,
                }
            }
    except Exception:
        logger.exception("Chart generation failed (non-fatal)")

    return {}
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
import orjson
|
|
7
|
+
from langchain_core.tools import tool
|
|
8
|
+
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
# Side-channel: keyed by chart_id so streaming.py can fetch the spec
|
|
12
|
+
# without the model ever seeing the full JSON.
|
|
13
|
+
_pending_charts: dict[str, dict] = {}
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@tool
async def create_chart(
    data: list[dict] | None = None,
    question: str = "",
    title: str = "",
    color_scheme: str = "",
) -> str:
    """
    Generate a Vega-Lite chart from structured data. Call this when the user asks
    for a chart, graph, or visualization. The chart renders automatically in the UI.

    Args:
        data: list of dicts with consistent keys (e.g. [{"platform": "snowflake", "count": 2290}])
        question: the user's question or description of what to visualize
        title: optional chart title
        color_scheme: optional color instruction e.g. "rainbow", "blue", "categorical", "green"

    On follow-up requests to change chart colors or style, call this again with the
    same data and the new color_scheme.

    Example: create_chart(data=[...], question="datasets by platform", color_scheme="rainbow")
    """
    # NOTE: the docstring above is the tool description the LLM sees (via @tool),
    # so implementation notes are kept in comments instead.
    #
    # Always returns a string. On success it is "CHART_READY:<chart_id> (...)\ndata=<json>"
    # and the spec is stored in _pending_charts[chart_id]. Otherwise it is a
    # human-readable failure message the agent can read and react to.
    from analytics_agent.agent.llm import get_chart_llm
    from analytics_agent.prompts.chart import CHART_SYSTEM_PROMPT, build_chart_user_prompt

    if not data:
        return "No data provided — cannot create chart."

    # Union of keys across all rows in first-seen order, so a column that is
    # missing from the first row is still described to the chart LLM.
    columns = list(dict.fromkeys(key for row in data for key in row))

    full_question = question or title
    if color_scheme:
        full_question = f"{full_question} (use {color_scheme} color scheme)"

    try:
        from langchain_core.messages import HumanMessage, SystemMessage

        # Inside the try: a misconfigured chart model or prompt failure must come
        # back as a tool message, not as an exception that breaks the agent loop.
        llm = get_chart_llm()
        user_prompt = build_chart_user_prompt(
            question=full_question,
            sql="",
            columns=columns,
            sample_rows=data[:50],
        )

        response = await llm.ainvoke(
            [SystemMessage(content=CHART_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]
        )

        raw = response.content
        if isinstance(raw, list):
            # Some providers return a list of content blocks; take the first text block.
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        if "```" in raw:
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        result = orjson.loads(raw.strip())

        chart_schema = result.get("chart_schema", {})
        chart_type = result.get("chart_type", "")

        if not (isinstance(chart_schema, dict) and chart_schema and chart_type):
            # Previously this path fell through and returned None despite "-> str".
            return (
                "Chart generation failed: the chart model did not return a "
                "chart_schema and chart_type."
            )

        chart_schema["data"] = {"values": data}

        # Store spec in side-channel — return a short human-readable summary so
        # the model retains context for follow-up requests (e.g. "change color")
        chart_id = str(uuid.uuid4())
        _pending_charts[chart_id] = {
            "vega_lite_spec": chart_schema,
            "reasoning": result.get("reasoning", ""),
            "chart_type": chart_type,
        }
        color_note = f", color_scheme={color_scheme!r}" if color_scheme else ""
        # Include the full data inline so the model can reuse it on follow-up requests
        # (e.g. "redraw with different colors")
        data_summary = orjson.dumps(data).decode()
        return (
            f"CHART_READY:{chart_id} "
            f"({chart_type} chart, {len(data)} rows{color_note})\n"
            f"data={data_summary}"
        )

    except Exception as e:
        logger.exception("create_chart failed")
        return f"Chart generation failed: {e}"
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pluggable chat history compaction.
|
|
3
|
+
|
|
4
|
+
OSS default: TurnWindowCompactor drops oldest turns by token budget.
|
|
5
|
+
DataHub Cloud (or other extensions) can register a SummarizingCompactor
|
|
6
|
+
via compactor_registry.register_compactor() at app startup.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Protocol, runtime_checkable
|
|
12
|
+
|
|
13
|
+
from langchain_core.messages import AIMessage, BaseMessage
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@runtime_checkable
class HistoryCompactor(Protocol):
    """Strategy interface for shrinking chat history to fit a token budget.

    Any object with a matching ``compact`` method satisfies this protocol
    structurally. Because it is ``@runtime_checkable``, ``isinstance`` checks
    work too, although they only verify that the method exists, not its signature.
    """

    def compact(self, turns: list[list[BaseMessage]], max_tokens: int) -> list[list[BaseMessage]]:
        """Return a (possibly shorter) list of turns that fits within max_tokens.

        Contract for implementations:
          * ``turns`` arrive in chronological order (oldest first);
          * the most recent turn must always be kept;
          * a non-empty input must never produce an empty list.
        """
        ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def estimate_tokens(msgs: list[BaseMessage]) -> int:
    """Return a rough token count for *msgs* (about 4 characters per token).

    Each message costs a flat 100 tokens of framing overhead plus
    ``len(content) // 4``; non-string content is measured via ``str()``.
    AI messages additionally add ``len(str(args)) // 4`` for each tool call.
    """
    framing_overhead = 100  # role, metadata, framing
    chars_per_token = 4

    def _message_cost(message: BaseMessage) -> int:
        body = message.content
        if not isinstance(body, str):
            body = str(body)
        cost = framing_overhead + len(body) // chars_per_token
        if isinstance(message, AIMessage):
            cost += sum(
                len(str(call.get("args", ""))) // chars_per_token
                for call in (message.tool_calls or [])
            )
        return cost

    return sum(_message_cost(message) for message in msgs)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class TurnWindowCompactor:
    """Drop oldest turns until the flattened history fits within max_tokens.

    The most recent turn is always kept, even if it alone exceeds the budget,
    so a non-empty input never yields an empty result.
    """

    def compact(
        self,
        turns: list[list[BaseMessage]],
        max_tokens: int,
    ) -> list[list[BaseMessage]]:
        """Return the shortest suffix of *turns* that fits *max_tokens* (min. one turn).

        Returns *turns* itself when nothing has to be dropped, otherwise a new list.
        """
        if len(turns) <= 1:
            return turns
        # estimate_tokens is a plain per-message sum, so the estimate of the
        # flattened history equals the sum of per-turn estimates. Cost every turn
        # once and slide the start index, instead of re-flattening and
        # re-estimating the remaining history after each dropped turn (quadratic).
        costs = [estimate_tokens(turn) for turn in turns]
        remaining = sum(costs)
        start = 0
        last = len(turns) - 1
        while start < last and remaining > max_tokens:
            remaining -= costs[start]
            start += 1
        return turns[start:] if start else turns
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module-level registry for the active HistoryCompactor.
|
|
3
|
+
|
|
4
|
+
DataHub Cloud (or any extension) can call register_compactor() at app startup
|
|
5
|
+
to swap in a more sophisticated strategy (e.g. LLM summarization) without
|
|
6
|
+
modifying core files.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from analytics_agent.agent.compaction import HistoryCompactor, TurnWindowCompactor
|
|
12
|
+
|
|
13
|
+
# The active compactor handed out by get_compactor(); starts as the OSS default.
_compactor: HistoryCompactor = TurnWindowCompactor()


def register_compactor(c: HistoryCompactor) -> None:
    """Install *c* as the active compactor, replacing the one currently set.

    Intended to be called at app startup by an extension that wants a
    different strategy (e.g. LLM summarization).
    """
    global _compactor
    _compactor = c


def get_compactor() -> HistoryCompactor:
    """Return the active compactor (a TurnWindowCompactor unless one was registered)."""
    active = _compactor
    return active
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import orjson
|
|
6
|
+
from langchain.agents import create_agent
|
|
7
|
+
from langchain_core.messages import ToolMessage
|
|
8
|
+
from langgraph.graph import END, START, StateGraph
|
|
9
|
+
|
|
10
|
+
from analytics_agent.agent.llm import get_llm
|
|
11
|
+
from analytics_agent.agent.state import AgentState
|
|
12
|
+
from analytics_agent.prompts.system import build_system_prompt
|
|
13
|
+
|
|
14
|
+
# Write-back skills are opt-in; only included when explicitly enabled by the user
|
|
15
|
+
_SKILL_TOOL_NAMES: frozenset[str] = frozenset({"publish_analysis", "save_correction"})
|
|
16
|
+
_MUTATION_TOOL_NAMES = _SKILL_TOOL_NAMES # alias used in filter below
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_last_sql_result(state: AgentState) -> dict | None:
    """Return the parsed result of the latest execute_sql call in the current turn.

    Scans ``state["messages"]`` newest-first and stops at the most recent
    HumanMessage. Without that stop, a query from an earlier turn that was
    rebuilt into the history would be picked up and auto-charted again on
    every later turn.

    ``execute_sql`` ToolMessages whose content is not a string, is not valid
    JSON (e.g. error text), or does not decode to a JSON object are skipped.
    Callers rely on ``.get()`` on the returned value, so a JSON array or scalar
    is never returned.

    Returns:
        The decoded result dict, or None when the current turn has none.
    """
    from langchain_core.messages import HumanMessage

    for msg in reversed(state["messages"]):
        if isinstance(msg, HumanMessage):
            break  # reached the start of the current turn
        if not isinstance(msg, ToolMessage) or getattr(msg, "name", None) != "execute_sql":
            continue
        if not isinstance(msg.content, str):
            continue
        try:
            parsed = orjson.loads(msg.content)
        except orjson.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _route_after_agent(state: AgentState) -> Literal["chart", "__end__"]:
    """Route to the chart node only when the latest SQL result has rows.

    Returns "chart" in that case and "__end__" otherwise (including when the
    result is missing or is not a JSON object).
    """
    result = get_last_sql_result(state)
    # Type guard: a JSON array/scalar payload has no .get() and would crash routing.
    if isinstance(result, dict) and result.get("rows"):
        return "chart"
    return "__end__"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _render_system_prompt_override(template: str, engine_name: str) -> str:
    """Fill the ``{engine_name}`` placeholder in a custom system prompt.

    str.format() is tried first, so existing ``{{``/``}}`` escapes keep working.
    A user-written prompt may also contain other brace text (e.g. a JSON
    example), which makes format() raise. In that case only the literal
    ``{engine_name}`` placeholder is replaced, instead of failing the request.
    """
    try:
        return template.format(engine_name=engine_name)
    except (KeyError, IndexError, ValueError, AttributeError):
        return template.replace("{engine_name}", engine_name)


def build_graph(
    engine_name: str,
    engine=None,  # pre-resolved engine from resolver.py; if None falls back to registry
    system_prompt_override: str | None = None,
    disabled_tools: set[str] | None = None,
    enabled_mutations: set[str] | None = None,
    context_tools: list | None = None,  # pre-built from DB context platforms at request time
    engine_tools: list | None = None,  # pre-built for MCP data sources (bypasses QueryEngine)
):
    """Assemble and compile the agent graph for one request.

    The graph runs the tool-calling agent and then, when the latest SQL result
    has rows (see ``_route_after_agent``), the chart node.

    Args:
        engine_name: query engine name; also substituted into the system prompt.
        engine: pre-resolved engine; when None (and ``engine_tools`` is None) it
            is looked up in the engine registry by ``engine_name``.
        system_prompt_override: custom system prompt, optionally containing an
            ``{engine_name}`` placeholder; replaces the default prompt.
        disabled_tools: tool names to drop from the context, engine and chart tools.
        enabled_mutations: opt-in write-back skills to enable.
        context_tools: pre-built context platform tools; when None they are
            built via ``build_datahub_tools()`` (env-var based).
        engine_tools: pre-built tools for MCP data sources; when given, no
            QueryEngine is used.

    Returns:
        The compiled StateGraph.

    Raises:
        ValueError: if engine tools must come from the registry and
            ``engine_name`` is not registered.
    """
    from analytics_agent.agent.chart_generator import chart_node
    from analytics_agent.engines.factory import get_registry

    disabled = disabled_tools or set()
    llm = get_llm(streaming=True)

    from analytics_agent.agent.chart_tool import create_chart

    # Context platform tools — built dynamically from DB at request time.
    # Falls back to env-var based build only when caller doesn't provide them.
    if context_tools is not None:
        datahub_tools = [t for t in context_tools if t.name not in disabled]
    else:
        from analytics_agent.context.datahub import build_datahub_tools

        datahub_tools = [t for t in build_datahub_tools() if t.name not in disabled]

    # Always-on skills (context search etc.) + opt-in write-back skills
    from analytics_agent.skills.loader import build_always_on_skill_tools, build_skill_tools

    # NOTE(review): unlike the other tool groups, skill tools are not filtered by
    # `disabled` — confirm always-on skills are meant to be non-disableable.
    skill_tools = build_always_on_skill_tools() + build_skill_tools(enabled_mutations or set())

    # Engine tools — MCP data sources supply pre-built tools; native engines use QueryEngine
    if engine_tools is not None:
        engine_tools = [t for t in engine_tools if t.name not in disabled]
    else:
        if engine is None:
            registry = get_registry()
            engine = registry.get(engine_name)
            if not engine:
                raise ValueError(f"Engine '{engine_name}' not found.")
        engine_tools = [t for t in engine.get_tools() if t.name not in disabled]
    chart_tools = [] if "create_chart" in disabled else [create_chart]
    all_tools = datahub_tools + skill_tools + engine_tools + chart_tools

    if system_prompt_override:
        from analytics_agent.skills.loader import (
            get_improve_context_prompt_section,
            get_search_business_context_section,
            get_skill_system_prompt_section,
        )

        # Custom prompts are user-authored and may contain arbitrary braces,
        # so a bare .format() could raise and fail the whole request.
        system_prompt = _render_system_prompt_override(system_prompt_override, engine_name)
        system_prompt += get_search_business_context_section()
        system_prompt += get_improve_context_prompt_section()
        if enabled_mutations:
            system_prompt += get_skill_system_prompt_section(enabled_mutations)
    else:
        system_prompt = build_system_prompt(engine_name, enabled_skills=enabled_mutations)

    # Enable per-tool error handling so validation errors (e.g. hallucinated
    # arguments like filter= on get_entities) are returned as tool messages
    # the agent can read and recover from, rather than crashing the loop.
    for tool in all_tools:
        tool.handle_tool_error = True

    react_agent = create_agent(
        model=llm,
        tools=all_tools,
        state_schema=AgentState,
        system_prompt=system_prompt,
    )

    graph = StateGraph(AgentState)
    graph.add_node("agent", react_agent)
    graph.add_node("chart", chart_node)
    graph.add_edge(START, "agent")
    graph.add_conditional_edges(
        "agent",
        _route_after_agent,
        {"chart": "chart", "__end__": END},
    )
    graph.add_edge("chart", END)

    return graph.compile()
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reconstruct LangChain-compatible message history from DB-persisted events.
|
|
3
|
+
|
|
4
|
+
Strategy:
|
|
5
|
+
- Group stored messages by user turn (each user TEXT starts a new turn).
|
|
6
|
+
- Each turn emits: HumanMessage → [tool call/result pairs] → AIMessage (final text).
|
|
7
|
+
- Tool calls and results are matched by sequence order within a turn.
|
|
8
|
+
- If a turn had no tool calls and no COMPLETE text, we fall back to TEXT chunks.
|
|
9
|
+
- Turns that have no assistant response at all (e.g. error turns) are skipped entirely
|
|
10
|
+
to avoid consecutive HumanMessages which LangGraph rejects.
|
|
11
|
+
- An optional HistoryCompactor drops the oldest turns to stay within the token budget.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import re
|
|
17
|
+
|
|
18
|
+
import orjson
|
|
19
|
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
|
20
|
+
|
|
21
|
+
from analytics_agent.agent.compaction import HistoryCompactor
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Fenced JSON blocks that contain a chart spec; stripped from fallback text.
_CHART_JSON_FENCE = re.compile(r"```(?:json)?\s*\{.*?\"chart_schema\".*?\}\s*```", re.DOTALL)


def _decode_payload(raw) -> dict | None:
    """Return a stored row payload as a dict, or None if it is malformed or not a JSON object."""
    if isinstance(raw, (str, bytes)):
        try:
            raw = orjson.loads(raw)
        except orjson.JSONDecodeError:
            return None
    return raw if isinstance(raw, dict) else None


def _tool_result_text(payload: dict) -> str:
    """Return a TOOL_RESULT/SQL payload's result as text, capped at 4000 chars.

    Uses ``result`` and falls back to ``sql``. Non-string values (e.g. a decoded
    JSON object) are serialized rather than sliced directly, which would raise
    for dicts or None.
    """
    value = payload.get("result", payload.get("sql", ""))
    if value is None:
        value = ""
    if not isinstance(value, str):
        try:
            value = orjson.dumps(value).decode()
        except TypeError:  # orjson.JSONEncodeError subclasses TypeError
            value = str(value)
    return value[:4000]


def _split_into_turns(stored_messages: list) -> list[list[tuple]]:
    """Group rows into turns; every user TEXT row opens a new turn.

    Each entry is ``(role, content, row)``. For the opening user row ``content``
    is the user's text; otherwise it is the decoded payload dict. Rows before
    the first user TEXT form a leading turn that does not start with "user".
    Rows whose payload cannot be decoded into a dict are skipped.
    """
    turns: list[list[tuple]] = []
    current: list[tuple] = []
    for row in stored_messages:
        payload = _decode_payload(row.payload)
        if payload is None:
            continue
        if row.role == "user" and row.event_type == "TEXT":
            if current:
                turns.append(current)
            current = [("user", payload.get("text", ""), row)]
        else:
            current.append((row.role, payload, row))
    if current:
        turns.append(current)
    return turns


def _turn_to_messages(turn: list[tuple]) -> list[BaseMessage]:
    """Convert one grouped turn into LangChain messages ([] when it must be skipped).

    Emits HumanMessage → AIMessage(tool_calls) + one ToolMessage per call →
    AIMessage(final text). Turns that do not start with a user row, or that have
    no assistant content at all (e.g. error turns), yield [], because
    consecutive HumanMessages are rejected.
    """
    role0, content0, _ = turn[0]
    if role0 != "user":
        return []

    tool_calls: list[dict] = []
    tool_results: list[dict] = []
    text_chunks: list[str] = []
    final_text = ""
    has_chart = False

    for role, payload, row in turn[1:]:
        if role != "assistant":
            continue
        evt = row.event_type

        if evt == "TOOL_CALL":
            args = payload.get("tool_input")
            tool_calls.append(
                {
                    # Tool-call ids must be strings; the DB id may not be.
                    "id": str(row.id),
                    "name": payload.get("tool_name", ""),
                    # AIMessage requires dict args; drop anything else.
                    "input": args if isinstance(args, dict) else {},
                }
            )
        elif evt in ("TOOL_RESULT", "SQL"):
            # Results are matched to calls by order within the turn.
            tool_results.append(
                {
                    "name": payload.get("tool_name", ""),
                    "result": _tool_result_text(payload),
                }
            )
        elif evt == "TEXT":
            chunk = payload.get("text", "")
            if chunk:
                text_chunks.append(chunk)
        elif evt == "COMPLETE":
            final_text = payload.get("text", "")
        elif evt == "CHART":
            has_chart = True

    if not final_text:
        # Fall back to the streamed chunks, minus any inline chart JSON.
        assembled = _CHART_JSON_FENCE.sub("", "".join(text_chunks)).strip()
        final_text = assembled[:500]

    if not final_text and has_chart:
        final_text = "[Chart rendered]"

    if not (tool_calls or final_text or has_chart):
        return []

    messages: list[BaseMessage] = [HumanMessage(content=content0)]

    if tool_calls:
        messages.append(
            AIMessage(
                content="",
                tool_calls=[
                    {"id": tc["id"], "name": tc["name"], "args": tc["input"], "type": "tool_call"}
                    for tc in tool_calls
                ],
            )
        )
        # Every tool_call must have a ToolMessage with its exact ID.
        # Pad missing results; always use tc["id"] as tool_call_id so
        # the IDs are guaranteed to match the AIMessage (avoids Anthropic
        # "unexpected tool_use_id" errors from orphaned DB records).
        for i, tc in enumerate(tool_calls):
            if i < len(tool_results):
                tr = tool_results[i]
                messages.append(
                    ToolMessage(
                        content=tr["result"],
                        tool_call_id=tc["id"],
                        # SQL events may carry no tool_name; use the call's name then.
                        name=tr["name"] or tc["name"],
                    )
                )
            else:
                messages.append(
                    ToolMessage(
                        content="[Tool did not return a result]",
                        tool_call_id=tc["id"],
                        name=tc["name"],
                    )
                )

    if final_text or not tool_calls:
        messages.append(AIMessage(content=final_text or "Done."))

    return messages


def build_history(
    stored_messages: list,
    current_user_text: str,
    compactor: HistoryCompactor | None = None,
    max_history_tokens: int = 120_000,
) -> list[BaseMessage]:
    """
    Convert persisted message rows into a LangChain message list ending with
    the current user turn.

    Args:
        stored_messages: persisted rows in chronological order; each row is read
            via ``.role``, ``.event_type``, ``.payload`` (JSON str/bytes or an
            already-decoded mapping) and ``.id``.
        current_user_text: text of the new user message, appended last.
        compactor: if provided, oldest turns are dropped to stay within
            max_history_tokens before returning.
        max_history_tokens: token budget passed to the compactor.

    Returns:
        The reconstructed history followed by ``HumanMessage(current_user_text)``.
    """
    lc_turns: list[list[BaseMessage]] = []
    for turn in _split_into_turns(stored_messages):
        turn_msgs = _turn_to_messages(turn)
        if turn_msgs:
            lc_turns.append(turn_msgs)

    # Drop oldest turns if needed
    if compactor is not None and lc_turns:
        lc_turns = compactor.compact(lc_turns, max_tokens=max_history_tokens)

    result = [msg for turn in lc_turns for msg in turn]
    result.append(HumanMessage(content=current_user_text))
    return result
|