datahub-analytics-agent 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. analytics_agent/__init__.py +0 -0
  2. analytics_agent/agent/__init__.py +0 -0
  3. analytics_agent/agent/analysis.py +149 -0
  4. analytics_agent/agent/chart_generator.py +70 -0
  5. analytics_agent/agent/chart_tool.py +103 -0
  6. analytics_agent/agent/compaction.py +57 -0
  7. analytics_agent/agent/compactor_registry.py +22 -0
  8. analytics_agent/agent/graph.py +121 -0
  9. analytics_agent/agent/history.py +159 -0
  10. analytics_agent/agent/llm.py +87 -0
  11. analytics_agent/agent/mock_llm.py +111 -0
  12. analytics_agent/agent/state.py +13 -0
  13. analytics_agent/agent/streaming.py +304 -0
  14. analytics_agent/api/__init__.py +135 -0
  15. analytics_agent/api/chat.py +439 -0
  16. analytics_agent/api/conversations.py +244 -0
  17. analytics_agent/api/oauth.py +741 -0
  18. analytics_agent/api/settings.py +1947 -0
  19. analytics_agent/config.py +236 -0
  20. analytics_agent/context/__init__.py +0 -0
  21. analytics_agent/context/base.py +26 -0
  22. analytics_agent/context/datahub.py +242 -0
  23. analytics_agent/context/mcp_platform.py +123 -0
  24. analytics_agent/context/native_datahub.py +58 -0
  25. analytics_agent/context/registry.py +84 -0
  26. analytics_agent/db/__init__.py +0 -0
  27. analytics_agent/db/alembic/env.py +49 -0
  28. analytics_agent/db/alembic/script.py.mako +25 -0
  29. analytics_agent/db/alembic/versions/001_init.py +47 -0
  30. analytics_agent/db/alembic/versions/002_settings_table.py +30 -0
  31. analytics_agent/db/alembic/versions/003_integrations.py +52 -0
  32. analytics_agent/db/alembic/versions/004_conversation_quality.py +28 -0
  33. analytics_agent/db/alembic/versions/005_context_platforms.py +36 -0
  34. analytics_agent/db/base.py +33 -0
  35. analytics_agent/db/models.py +137 -0
  36. analytics_agent/db/repository.py +294 -0
  37. analytics_agent/db/types.py +69 -0
  38. analytics_agent/engines/__init__.py +0 -0
  39. analytics_agent/engines/base.py +30 -0
  40. analytics_agent/engines/factory.py +95 -0
  41. analytics_agent/engines/mcp/__init__.py +0 -0
  42. analytics_agent/engines/mcp/engine.py +78 -0
  43. analytics_agent/engines/resolver.py +84 -0
  44. analytics_agent/engines/snowflake/__init__.py +0 -0
  45. analytics_agent/engines/snowflake/engine.py +304 -0
  46. analytics_agent/engines/sqlalchemy/__init__.py +0 -0
  47. analytics_agent/engines/sqlalchemy/engine.py +163 -0
  48. analytics_agent/main.py +536 -0
  49. analytics_agent/prompts/__init__.py +0 -0
  50. analytics_agent/prompts/chart.py +101 -0
  51. analytics_agent/prompts/system.py +33 -0
  52. analytics_agent/prompts/system_prompt.md +184 -0
  53. analytics_agent/skills/__init__.py +0 -0
  54. analytics_agent/skills/datahub_skills.py +409 -0
  55. analytics_agent/skills/improve-context/SKILL.md +73 -0
  56. analytics_agent/skills/loader.py +162 -0
  57. analytics_agent/skills/publish-analysis/SKILL.md +99 -0
  58. analytics_agent/skills/save-correction/SKILL.md +161 -0
  59. analytics_agent/skills/search-business-context/SKILL.md +109 -0
  60. analytics_agent/tracing.py +88 -0
  61. datahub_analytics_agent-0.1.0.dist-info/METADATA +328 -0
  62. datahub_analytics_agent-0.1.0.dist-info/RECORD +63 -0
  63. datahub_analytics_agent-0.1.0.dist-info/WHEEL +4 -0
File without changes
File without changes
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ from dataclasses import dataclass
6
+
7
+ import orjson
8
+
9
logger = logging.getLogger(__name__)

# Tool names whose results count as "DataHub context" when
# compute_context_quality judges how well the knowledge base supported an answer.
CONTEXT_TOOLS: frozenset[str] = frozenset(
    (
        "search",
        "search_documents",
        "grep_documents",
        "get_entities",
        "search_business_context",
    )
)
# Older private spelling of the same set, kept as an alias of the same object.
_CONTEXT_TOOLS = CONTEXT_TOOLS

# Fallback label per score (1–5). Both 1 and 2 map to "Poor", matching the
# label vocabulary the assessment prompt allows.
_SCORE_LABELS = {
    1: "Poor",
    2: "Poor",
    3: "Fair",
    4: "Good",
    5: "Excellent",
}
17
+
18
# Prompt template for the context-quality judge used by compute_context_quality.
# It is filled with str.format(context_calls=..., agent_reasoning=...); the
# doubled braces in the final line are escapes that render as literal braces
# in the JSON example. Lines ending in a backslash are joined to the next line
# (string line continuation), so each rubric entry reaches the model as one line.
# NOTE(review): the rubric names score 1 "Very Poor", but the allowed label list
# and the _SCORE_LABELS fallback only use "Poor" for scores 1 and 2.
_ASSESSMENT_PROMPT = """\
You are assessing the **context quality** of a data assistant conversation.

Context quality measures how well the DataHub knowledge base (documentation, \
definitions, dataset descriptions) supported the agent's work.

Score 1–5:
5 Excellent — DataHub had rich, accurate documentation that fully covered the \
question. Agent applied the definition directly with no improvisation.
4 Good — Useful docs found; definition was mostly complete; agent made one minor \
stated assumption that didn't change the answer meaningfully.
3 Fair — Definition found but incomplete or ambiguous; agent had to fill gaps, \
deviate from the definition, or ask the user for clarification about what the \
context should have made clear.
2 Poor — Docs mostly missing or returned empty results; agent improvised \
substantially and the answer depended heavily on undocumented choices.
1 Very Poor — No useful context; agent expressed significant uncertainty, made \
conflicting assumptions, or produced an answer that contradicts available definitions.

Key signals that push the score DOWN:
- Agent says "the definition doesn't cover this" or "I'll interpret this as…"
- Agent switches columns, tables, or date anchors not mentioned in the definition
- Agent produces a result that varies based on an undocumented assumption
- Agent asks the user to clarify something the glossary/docs should have defined

--- CONTEXT TOOL CALLS AND RESULTS ---
{context_calls}
--- END CONTEXT ---

--- AGENT REASONING (what the agent said and concluded) ---
{agent_reasoning}
--- END REASONING ---

Respond with ONLY valid JSON, no explanation outside it:
{{"score": <1-5>, "label": "<Excellent|Good|Fair|Poor>", "reason": "<one sentence that names the specific gap or strength>"}}"""
53
+
54
+
55
async def compute_context_quality(messages: list) -> QualityScore:
    """
    LLM-assessed context quality score (1–5).

    Extracts DataHub context tool calls + results and the agent's own text,
    then asks a cheap model to judge whether the returned context was actually
    useful and complete — penalising cases where the agent had to improvise or
    deviate from the definition.

    Args:
        messages: Persisted conversation events. Each item must expose
            ``event_type`` and ``payload`` (a dict, or JSON ``str``/``bytes``);
            an optional ``role`` attribute keeps different speakers' streamed
            text apart.

    Returns:
        A ``QualityScore``. Neutral (3) is returned immediately when no context
        tool calls have occurred yet, and also when the LLM assessment fails.
    """
    context_calls: list[dict] = []
    # Text is collected as whole "segments". TEXT events are streamed token
    # fragments, so consecutive chunks from the same role are concatenated
    # verbatim: joining them with spaces, or de-duplicating individual tokens,
    # garbles the words. Each COMPLETE event contributes its own segment.
    segments: list[str] = []
    run: list[str] = []
    run_role = None

    def flush_run() -> None:
        # Close the current streamed run (if any) as one capped segment.
        if run:
            segments.append("".join(run)[:400])
            run.clear()

    for msg in messages:
        try:
            payload = (
                orjson.loads(msg.payload) if isinstance(msg.payload, (str, bytes)) else msg.payload
            )
        except (AttributeError, orjson.JSONDecodeError):
            continue
        if not isinstance(payload, dict):
            # Nothing readable in this event; skip it instead of failing on .get().
            continue

        event_type = msg.event_type
        if event_type == "TEXT":
            role = getattr(msg, "role", None)
            if role != run_role:
                flush_run()
                run_role = role
            chunk = payload.get("text", "")
            if isinstance(chunk, str) and chunk:
                run.append(chunk)
            continue

        # Any non-TEXT event ends the current streamed run.
        flush_run()
        if event_type == "TOOL_RESULT":
            tool_name = payload.get("tool_name", "")
            if tool_name not in CONTEXT_TOOLS:
                continue
            result_str = str(payload.get("result", ""))
            if len(result_str) > 800:
                result_str = result_str[:800] + "…"
            context_calls.append(
                {
                    "tool": tool_name,
                    "is_error": bool(payload.get("is_error", False)),
                    "result": result_str,
                }
            )
        elif event_type == "COMPLETE":
            text = payload.get("text", "")
            if isinstance(text, str) and text:
                segments.append(text[:400])
    flush_run()

    if not context_calls:
        return QualityScore(
            score=3, label="Neutral", breakdown={"reason": "No context lookups yet"}
        )

    calls_text = "\n\n".join(
        f"Tool: {c['tool']}\nError: {c['is_error']}\nResult: {c['result']}" for c in context_calls
    )
    # A COMPLETE event usually repeats the streamed run it finalises, so drop
    # exact duplicate segments, then cap the total reasoning sent to the judge.
    reasoning = "\n".join(dict.fromkeys(segments))[:1200]

    prompt = _ASSESSMENT_PROMPT.format(context_calls=calls_text, agent_reasoning=reasoning)

    # Best effort: any failure (import, LLM call, malformed JSON) yields Neutral.
    try:
        from langchain_core.messages import HumanMessage, SystemMessage

        from analytics_agent.agent.llm import get_quality_llm

        llm = get_quality_llm()
        response = await llm.ainvoke(
            [
                SystemMessage(
                    content="You assess data assistant context quality. Reply only with the requested JSON."
                ),
                HumanMessage(content=prompt),
            ]
        )
        raw = response.content
        if isinstance(raw, list):
            # Content-block responses: use the first text block.
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        raw = raw.strip()
        # Strip a markdown code fence (optionally tagged "json") if present.
        if raw.startswith("```"):
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        data = json.loads(raw.strip())
        score = max(1, min(5, int(data.get("score", 3))))
        # Fall back to the canonical label when the model omits or blanks it.
        label = data.get("label") or _SCORE_LABELS[score]
        reason = data.get("reason", "")
        return QualityScore(score=score, label=label, breakdown={"reason": reason})
    except Exception as exc:
        logger.warning("Context quality LLM assessment failed: %s", exc)
        return QualityScore(
            score=3, label="Neutral", breakdown={"reason": "Assessment unavailable"}
        )
143
+
144
+
145
@dataclass
class QualityScore:
    """Outcome of a context-quality assessment.

    Attributes are positional in the order score, label, breakdown.
    """

    # Integer rating on a 1–5 scale (3 is used as the neutral default).
    score: int
    # Human-readable label, e.g. "Good" or "Neutral".
    label: str
    # Free-form details; compute_context_quality stores a "reason" key here.
    breakdown: dict
@@ -0,0 +1,70 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import orjson
6
+ from langchain_core.messages import HumanMessage, SystemMessage
7
+
8
+ from analytics_agent.agent.llm import get_chart_llm
9
+ from analytics_agent.agent.state import AgentState
10
+ from analytics_agent.prompts.chart import CHART_SYSTEM_PROMPT, build_chart_user_prompt
11
+
12
# Module logger; chart generation failures are reported here (non-fatal).
logger = logging.getLogger(name=__name__)
13
+
14
+
15
async def chart_node(state: AgentState) -> dict:
    """
    Generate a Vega-Lite chart spec from the last SQL result and store it in state.
    streaming.py reads it from state after the graph completes.

    Returns:
        ``{"pending_chart": {...}}`` with keys ``vega_lite_spec``, ``reasoning``
        and ``chart_type`` on success; ``{}`` when there is nothing to chart or
        generation fails. Failures are logged, never raised, so a bad chart
        cannot break the answer.
    """
    from analytics_agent.agent.graph import get_last_sql_result

    # The chart model only needs a sample to choose encodings; the spec's data
    # is always the full row set. Same cap as the create_chart tool uses.
    max_sample_rows = 50

    sql_result = get_last_sql_result(state)
    if not isinstance(sql_result, dict):
        return {}
    rows = sql_result.get("rows")
    if not rows or not isinstance(rows, list):
        return {}

    llm = get_chart_llm()

    user_prompt = build_chart_user_prompt(
        question=state.get("user_question", ""),
        sql=sql_result.get("sql", ""),
        columns=sql_result.get("columns", []),
        sample_rows=rows[:max_sample_rows],
    )

    try:
        response = await llm.ainvoke(
            [SystemMessage(content=CHART_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]
        )

        raw = response.content
        if isinstance(raw, list):
            # Anthropic returns list of content blocks; use the first text block.
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        if "```" in raw:
            # Take the body of the first markdown code fence, minus a "json" tag.
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        result = orjson.loads(raw.strip())
        if not isinstance(result, dict):
            logger.warning(
                "Chart model returned %s instead of a JSON object", type(result).__name__
            )
            return {}

        chart_schema = result.get("chart_schema", {})
        chart_type = result.get("chart_type", "")

        if isinstance(chart_schema, dict) and chart_schema and chart_type:
            # Inline the full result set so the spec renders without the model
            # ever having to echo the rows back.
            chart_schema["data"] = {"values": rows}

            # Store in state so streaming.py can emit it after graph completion
            return {
                "pending_chart": {
                    "vega_lite_spec": chart_schema,
                    "reasoning": result.get("reasoning", ""),
                    "chart_type": chart_type,
                }
            }
    except Exception:
        logger.exception("Chart generation failed (non-fatal)")

    return {}
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import uuid
5
+
6
+ import orjson
7
+ from langchain_core.tools import tool
8
+
9
logger = logging.getLogger(name=__name__)

# Side-channel store for generated Vega-Lite specs, keyed by chart_id. The tool
# hands the model only the chart_id (not the spec JSON); streaming.py fetches
# the spec from here by that id.
_pending_charts: dict[str, dict] = dict()
14
+
15
+
16
@tool
async def create_chart(
    data: list[dict] | None = None,
    question: str = "",
    title: str = "",
    color_scheme: str = "",
) -> str:
    """
    Generate a Vega-Lite chart from structured data. Call this when the user asks
    for a chart, graph, or visualization. The chart renders automatically in the UI.

    Args:
        data: list of dicts with consistent keys (e.g. [{"platform": "snowflake", "count": 2290}])
        question: the user's question or description of what to visualize
        title: optional chart title
        color_scheme: optional color instruction e.g. "rainbow", "blue", "categorical", "green"

    On follow-up requests to change chart colors or style, call this again with the
    same data and the new color_scheme.

    Example: create_chart(data=[...], question="datasets by platform", color_scheme="rainbow")
    """
    # NOTE: the docstring above is the tool description shown to the model, so
    # implementation notes live in comments instead.
    #
    # Always returns a string: a "CHART_READY:<chart_id> ..." summary on success
    # (the full spec is parked in _pending_charts for streaming.py), otherwise a
    # human-readable failure message the agent can read and recover from.
    from analytics_agent.agent.llm import get_chart_llm
    from analytics_agent.prompts.chart import CHART_SYSTEM_PROMPT, build_chart_user_prompt

    if not data:
        return "No data provided — cannot create chart."

    # Union of keys across all rows (first-seen order) so a field missing from
    # the first row still reaches the chart prompt.
    columns = list(dict.fromkeys(key for row in data for key in row))

    full_question = question or title
    if color_scheme:
        full_question = f"{full_question} (use {color_scheme} color scheme)"

    try:
        from langchain_core.messages import HumanMessage, SystemMessage

        llm = get_chart_llm()
        user_prompt = build_chart_user_prompt(
            question=full_question,
            sql="",
            columns=columns,
            sample_rows=data[:50],
        )

        response = await llm.ainvoke(
            [SystemMessage(content=CHART_SYSTEM_PROMPT), HumanMessage(content=user_prompt)]
        )

        raw = response.content
        if isinstance(raw, list):
            # Content-block responses: use the first text block.
            raw = next(
                (b.get("text", "") for b in raw if isinstance(b, dict) and b.get("type") == "text"),
                "",
            )
        if "```" in raw:
            # Take the body of the first markdown code fence, minus a "json" tag.
            raw = raw.split("```")[1]
            if raw.startswith("json"):
                raw = raw[4:]
        result = orjson.loads(raw.strip())

        chart_schema = result.get("chart_schema", {}) if isinstance(result, dict) else None
        chart_type = result.get("chart_type", "") if isinstance(result, dict) else ""

        if not (isinstance(chart_schema, dict) and chart_schema and chart_type):
            # Previously this path fell through and returned None.
            logger.warning("create_chart: model response had no usable chart_schema/chart_type")
            return "Chart generation failed: the chart model did not return a usable chart spec."

        chart_schema["data"] = {"values": data}
        if title:
            # Honour an explicit title even when `question` drove the prompt.
            chart_schema["title"] = title

        # Store spec in side-channel — return a short human-readable summary so
        # the model retains context for follow-up requests (e.g. "change color")
        chart_id = str(uuid.uuid4())
        _pending_charts[chart_id] = {
            "vega_lite_spec": chart_schema,
            "reasoning": result.get("reasoning", ""),
            "chart_type": chart_type,
        }
        color_note = f", color_scheme={color_scheme!r}" if color_scheme else ""
        # Include the full data inline so the model can reuse it on follow-up requests
        # (e.g. "redraw with different colors")
        data_summary = orjson.dumps(data).decode()
        return (
            f"CHART_READY:{chart_id} "
            f"({chart_type} chart, {len(data)} rows{color_note})\n"
            f"data={data_summary}"
        )

    except Exception as e:
        logger.exception("create_chart failed")
        return f"Chart generation failed: {e}"
@@ -0,0 +1,57 @@
1
+ """
2
+ Pluggable chat history compaction.
3
+
4
+ OSS default: TurnWindowCompactor drops oldest turns by token budget.
5
+ DataHub Cloud (or other extensions) can register a SummarizingCompactor
6
+ via compactor_registry.register_compactor() at app startup.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Protocol, runtime_checkable
12
+
13
+ from langchain_core.messages import AIMessage, BaseMessage
14
+
15
+
16
@runtime_checkable
class HistoryCompactor(Protocol):
    """Strategy interface for trimming reconstructed chat history.

    Any object with a matching ``compact`` method satisfies this protocol.
    Because it is ``runtime_checkable``, ``isinstance`` checks are supported;
    they verify only that a ``compact`` attribute is present.
    """

    def compact(
        self,
        turns: list[list[BaseMessage]],
        max_tokens: int,
    ) -> list[list[BaseMessage]]:
        """Shrink ``turns`` so the flattened messages fit within ``max_tokens``.

        ``turns`` is chronological (oldest first). Implementations must keep
        the most recent turn, and must not return an empty list when the input
        is non-empty.
        """
        ...
29
+
30
+
31
def estimate_tokens(msgs: list[BaseMessage]) -> int:
    """Cheaply approximate the token count of ``msgs``.

    Each message costs a flat 100 tokens of overhead plus one token per four
    characters of content (non-string content is stringified first). For AI
    messages, each tool call's stringified ``args`` are counted the same way.
    """

    def _cost(message: BaseMessage) -> int:
        body = message.content if isinstance(message.content, str) else str(message.content)
        cost = 100 + len(body) // 4
        if isinstance(message, AIMessage):
            cost += sum(len(str(call.get("args", ""))) // 4 for call in message.tool_calls or [])
        return cost

    return sum(_cost(m) for m in msgs)
42
+
43
+
44
class TurnWindowCompactor:
    """Drop oldest turns until the flattened history fits within max_tokens.

    The most recent turn is always kept, even if it alone exceeds the budget,
    so a non-empty input never yields an empty result.
    """

    def compact(
        self,
        turns: list[list[BaseMessage]],
        max_tokens: int,
    ) -> list[list[BaseMessage]]:
        # estimate_tokens is a plain per-message sum, so the estimate for any
        # suffix of turns equals the sum of its per-turn estimates. Computing
        # each turn once and subtracting as turns are dropped gives the same
        # result without re-flattening and re-estimating the whole history on
        # every iteration (which was quadratic in the number of turns).
        per_turn = [estimate_tokens(turn) for turn in turns]
        total = sum(per_turn)
        start = 0
        while len(turns) - start > 1 and total > max_tokens:
            total -= per_turn[start]
            start += 1
        # Return the input object untouched when nothing was dropped.
        return turns if start == 0 else turns[start:]
@@ -0,0 +1,22 @@
1
+ """
2
+ Module-level registry for the active HistoryCompactor.
3
+
4
+ DataHub Cloud (or any extension) can call register_compactor() at app startup
5
+ to swap in a more sophisticated strategy (e.g. LLM summarization) without
6
+ modifying core files.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from analytics_agent.agent.compaction import HistoryCompactor, TurnWindowCompactor
12
+
13
# Process-wide active compactor. It starts as the OSS turn-window strategy and
# is replaced via register_compactor().
_compactor: HistoryCompactor
_compactor = TurnWindowCompactor()
14
+
15
+
16
def register_compactor(c: HistoryCompactor) -> None:
    """Make ``c`` the compactor that get_compactor() returns from now on.

    Intended for extensions to call at app startup. The most recent
    registration wins, and no validation is performed on ``c``.
    """
    global _compactor
    _compactor = c
19
+
20
+
21
def get_compactor() -> HistoryCompactor:
    """Return the currently registered history compactor."""
    active = _compactor
    return active
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ import orjson
6
+ from langchain.agents import create_agent
7
+ from langchain_core.messages import ToolMessage
8
+ from langgraph.graph import END, START, StateGraph
9
+
10
+ from analytics_agent.agent.llm import get_llm
11
+ from analytics_agent.agent.state import AgentState
12
+ from analytics_agent.prompts.system import build_system_prompt
13
+
14
+ # Write-back skills are opt-in; only included when explicitly enabled by the user
15
# Opt-in write-back skill tools. These are only added to the agent when the
# user explicitly enables them.
_SKILL_TOOL_NAMES: frozenset[str] = frozenset(("publish_analysis", "save_correction"))
# Alias of the same set. Nothing in this module reads it; it is presumably kept
# for external importers of the old name.
_MUTATION_TOOL_NAMES = _SKILL_TOOL_NAMES
17
+
18
+
19
def get_last_sql_result(state: AgentState) -> dict | None:
    """Return the current turn's latest parseable execute_sql result.

    Walks ``state["messages"]`` newest-first and stops at the most recent
    HumanMessage, i.e. the start of the current turn. Replayed history (which
    includes execute_sql ToolMessages from earlier turns) is therefore never
    picked up. Otherwise every later turn would be routed to re-chart stale data.

    Within the turn, an execute_sql ToolMessage whose content is not a JSON
    object (e.g. an error string) is skipped in favour of an earlier one.

    Returns:
        The parsed JSON object, or ``None`` if the current turn has no
        parseable execute_sql result.
    """
    from langchain_core.messages import HumanMessage

    for msg in reversed(state["messages"]):
        if isinstance(msg, HumanMessage):
            break
        if not (isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "execute_sql"):
            continue
        if not isinstance(msg.content, str):
            continue
        try:
            parsed = orjson.loads(msg.content)
        except orjson.JSONDecodeError:
            continue
        # Callers use .get("rows"), so only a JSON object is a usable result.
        if isinstance(parsed, dict):
            return parsed
    return None
29
+
30
+
31
def _route_after_agent(state: AgentState) -> Literal["chart", "__end__"]:
    """Send the run to the chart node only when the latest SQL result has rows."""
    last = get_last_sql_result(state)
    has_rows = bool(last and last.get("rows"))
    return "chart" if has_rows else "__end__"
36
+
37
+
38
def build_graph(
    engine_name: str,
    engine=None,  # pre-resolved engine from resolver.py; if None falls back to registry
    system_prompt_override: str | None = None,
    disabled_tools: set[str] | None = None,
    enabled_mutations: set[str] | None = None,
    context_tools: list | None = None,  # pre-built from DB context platforms at request time
    engine_tools: list | None = None,  # pre-built for MCP data sources (bypasses QueryEngine)
):
    """Assemble and compile the agent graph for one request.

    The graph runs a tool-calling agent node and then, when the turn produced
    an execute_sql result with rows, a chart node.

    Args:
        engine_name: Query engine name; also substituted into the system prompt.
        engine: Pre-resolved engine. When ``None`` (and ``engine_tools`` is not
            given) it is looked up in the engine registry by ``engine_name``.
        system_prompt_override: Custom prompt template. ``{engine_name}`` is
            substituted and the skill prompt sections are appended.
        disabled_tools: Tool names to omit from every tool group.
        enabled_mutations: Opt-in write-back skills to expose.
        context_tools: Pre-built context-platform tools. When ``None``, the
            env-configured DataHub tools are built instead.
        engine_tools: Pre-built engine tools (MCP data sources). When ``None``,
            the engine's ``get_tools()`` are used.

    Returns:
        The compiled graph.

    Raises:
        ValueError: if no engine tools were supplied and no engine is found
            for ``engine_name``.
    """
    from analytics_agent.agent.chart_generator import chart_node
    from analytics_agent.agent.chart_tool import create_chart
    from analytics_agent.engines.factory import get_registry
    from analytics_agent.skills.loader import build_always_on_skill_tools, build_skill_tools

    disabled = disabled_tools or set()

    def keep_enabled(tools: list) -> list:
        # Drop every tool the caller disabled by name.
        return [t for t in tools if t.name not in disabled]

    llm = get_llm(streaming=True)

    # Context platform tools — built dynamically from DB at request time.
    # Falls back to env-var based build only when caller doesn't provide them.
    if context_tools is None:
        from analytics_agent.context.datahub import build_datahub_tools

        context_tools = build_datahub_tools()
    datahub_tools = keep_enabled(context_tools)

    # Always-on skills (context search etc.) + opt-in write-back skills. These
    # honour disabled_tools exactly like every other tool group.
    skill_tools = keep_enabled(
        build_always_on_skill_tools() + build_skill_tools(enabled_mutations or set())
    )

    # Engine tools — MCP data sources supply pre-built tools; native engines use QueryEngine
    if engine_tools is None:
        if engine is None:
            engine = get_registry().get(engine_name)
        if not engine:
            raise ValueError(f"Engine '{engine_name}' not found.")
        engine_tools = engine.get_tools()
    engine_tools = keep_enabled(engine_tools)

    chart_tools = keep_enabled([create_chart])
    all_tools = datahub_tools + skill_tools + engine_tools + chart_tools

    if system_prompt_override:
        from analytics_agent.skills.loader import (
            get_improve_context_prompt_section,
            get_search_business_context_section,
            get_skill_system_prompt_section,
        )

        try:
            system_prompt = system_prompt_override.format(engine_name=engine_name)
        except (AttributeError, IndexError, KeyError, ValueError):
            # User-authored overrides can contain literal braces (e.g. JSON
            # examples) that str.format rejects; substitute only the
            # {engine_name} placeholder instead of failing the whole request.
            system_prompt = system_prompt_override.replace("{engine_name}", engine_name)
        system_prompt += get_search_business_context_section()
        system_prompt += get_improve_context_prompt_section()
        if enabled_mutations:
            system_prompt += get_skill_system_prompt_section(enabled_mutations)
    else:
        system_prompt = build_system_prompt(engine_name, enabled_skills=enabled_mutations)

    # Enable per-tool error handling so validation errors (e.g. hallucinated
    # arguments like filter= on get_entities) are returned as tool messages
    # the agent can read and recover from, rather than crashing the loop.
    for t in all_tools:
        t.handle_tool_error = True

    react_agent = create_agent(
        model=llm,
        tools=all_tools,
        state_schema=AgentState,
        system_prompt=system_prompt,
    )

    graph = StateGraph(AgentState)
    graph.add_node("agent", react_agent)
    graph.add_node("chart", chart_node)
    graph.add_edge(START, "agent")
    graph.add_conditional_edges(
        "agent",
        _route_after_agent,
        {"chart": "chart", "__end__": END},
    )
    graph.add_edge("chart", END)

    return graph.compile()
@@ -0,0 +1,159 @@
1
+ """
2
+ Reconstruct LangChain-compatible message history from DB-persisted events.
3
+
4
+ Strategy:
5
+ - Group stored messages by user turn (each user TEXT starts a new turn).
6
+ - Each turn emits: HumanMessage → [tool call/result pairs] → AIMessage (final text).
7
+ - Tool calls and results are matched by sequence order within a turn.
8
+ - If a turn had no tool calls and no COMPLETE text, we fall back to TEXT chunks.
9
+ - Turns that have no assistant response at all (e.g. error turns) are skipped entirely
10
+ to avoid consecutive HumanMessages which LangGraph rejects.
11
+ - An optional HistoryCompactor drops the oldest turns to stay within the token budget.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import re
17
+
18
+ import orjson
19
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
20
+
21
+ from analytics_agent.agent.compaction import HistoryCompactor
22
+
23
+
24
# Fenced JSON block containing a "chart_schema" key; such blocks are stripped
# from replayed assistant text.
_CHART_SPEC_FENCE_RE = re.compile(
    r"```(?:json)?\s*\{.*?\"chart_schema\".*?\}\s*```", flags=re.DOTALL
)


def _load_payload(raw) -> dict:
    """Decode a stored payload (JSON ``str``/``bytes`` or an object) into a dict.

    Non-dict payloads become ``{}`` so callers can use ``.get`` safely.
    Malformed JSON raises ``orjson.JSONDecodeError``.
    """
    payload = orjson.loads(raw) if isinstance(raw, (str, bytes)) else raw
    return payload if isinstance(payload, dict) else {}


def _tool_result_text(payload: dict) -> str:
    """Text of a TOOL_RESULT/SQL event: ``result`` (else ``sql``), capped at 4000 chars."""
    result = payload.get("result", payload.get("sql", ""))
    if result is None:
        return ""
    # Stringify before slicing: slicing a stored dict result would raise.
    return (result if isinstance(result, str) else str(result))[:4000]


def _split_turns(stored_messages: list) -> list[tuple[str, list[tuple]]]:
    """Group rows into ``(user_text, events)`` turns.

    Each user TEXT row opens a turn. ``events`` holds ``(role, payload, row)``
    for every following row up to the next user TEXT. Rows that precede the
    first user TEXT belong to no turn and are dropped.
    """
    turns: list[tuple[str, list[tuple]]] = []
    for row in stored_messages:
        payload = _load_payload(row.payload)
        if row.role == "user" and row.event_type == "TEXT":
            turns.append((payload.get("text") or "", []))
        elif turns:
            turns[-1][1].append((row.role, payload, row))
    return turns


def _turn_to_messages(user_text: str, events: list[tuple]) -> list[BaseMessage]:
    """Rebuild one turn as LangChain messages.

    Returns ``[]`` when the turn has no assistant output at all (e.g. an error
    turn). Replaying it would put two HumanMessages back to back, which
    LangGraph rejects.
    """
    tool_calls: list[dict] = []
    tool_results: list[dict] = []
    text_chunks: list[str] = []
    final_text = ""
    has_chart = False

    for role, payload, row in events:
        if role != "assistant":
            continue
        evt = row.event_type

        if evt == "TOOL_CALL":
            tool_input = payload.get("tool_input", {})
            tool_calls.append(
                {
                    # LangChain types tool-call ids as str; stringify DB ids.
                    "id": str(row.id),
                    "name": payload.get("tool_name", ""),
                    "input": tool_input if isinstance(tool_input, dict) else {},
                }
            )
        elif evt in ("TOOL_RESULT", "SQL"):
            # NOTE(review): results pair with calls purely by order. If one call
            # can emit both a SQL and a TOOL_RESULT event, later pairs shift;
            # confirm against the event writer.
            tool_results.append(
                {"name": payload.get("tool_name", ""), "result": _tool_result_text(payload)}
            )
        elif evt == "TEXT":
            chunk = payload.get("text", "")
            if isinstance(chunk, str) and chunk:
                text_chunks.append(chunk)
        elif evt == "COMPLETE":
            final_text = payload.get("text", "") or ""
        elif evt == "CHART":
            has_chart = True

    if not final_text:
        # No COMPLETE text: fall back to the streamed chunks, minus inline chart specs.
        final_text = _CHART_SPEC_FENCE_RE.sub("", "".join(text_chunks)).strip()[:500]
    if not final_text and has_chart:
        final_text = "[Chart rendered]"

    if not (tool_calls or final_text or has_chart):
        return []

    messages: list[BaseMessage] = [HumanMessage(content=user_text)]

    if tool_calls:
        messages.append(
            AIMessage(
                content="",
                tool_calls=[
                    {"id": tc["id"], "name": tc["name"], "args": tc["input"], "type": "tool_call"}
                    for tc in tool_calls
                ],
            )
        )
        # Every tool_call must have a ToolMessage with its exact ID. Pad missing
        # results and always use tc["id"] as tool_call_id, so the IDs match the
        # AIMessage (avoids Anthropic "unexpected tool_use_id" errors from
        # orphaned DB records).
        for i, tc in enumerate(tool_calls):
            if i < len(tool_results):
                tr = tool_results[i]
                messages.append(
                    ToolMessage(content=tr["result"], tool_call_id=tc["id"], name=tr["name"])
                )
            else:
                messages.append(
                    ToolMessage(
                        content="[Tool did not return a result]",
                        tool_call_id=tc["id"],
                        name=tc["name"],
                    )
                )

    if final_text or not tool_calls:
        messages.append(AIMessage(content=final_text or "Done."))

    return messages


def build_history(
    stored_messages: list,
    current_user_text: str,
    compactor: HistoryCompactor | None = None,
    max_history_tokens: int = 120_000,
) -> list[BaseMessage]:
    """
    Convert persisted message rows into a LangChain message list ending with
    the current user turn.

    Args:
        stored_messages: Persisted rows exposing ``role``, ``event_type``,
            ``payload`` and ``id``, in chronological order.
        current_user_text: Text of the new user message, appended last.
        compactor: Optional strategy that drops the oldest turns to stay within
            ``max_history_tokens``; no trimming happens when ``None``.
        max_history_tokens: Token budget passed to the compactor.

    Returns:
        The flattened history messages followed by a HumanMessage for
        ``current_user_text``.
    """
    lc_turns: list[list[BaseMessage]] = []
    for user_text, events in _split_turns(stored_messages):
        turn_msgs = _turn_to_messages(user_text, events)
        if turn_msgs:
            lc_turns.append(turn_msgs)

    # Drop oldest turns if needed
    if compactor is not None and lc_turns:
        lc_turns = compactor.compact(lc_turns, max_tokens=max_history_tokens)

    result = [msg for turn in lc_turns for msg in turn]
    result.append(HumanMessage(content=current_user_text))
    return result
+ return result