qcp-cli 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qcp/__init__.py +8 -0
- qcp/agent.py +165 -0
- qcp/cli.py +191 -0
- qcp/config.py +85 -0
- qcp/db.py +176 -0
- qcp/errors.py +61 -0
- qcp/llm.py +61 -0
- qcp/memory.py +94 -0
- qcp/models.py +125 -0
- qcp/output.py +119 -0
- qcp/tools.py +168 -0
- qcp_cli-0.1.5.dist-info/METADATA +207 -0
- qcp_cli-0.1.5.dist-info/RECORD +16 -0
- qcp_cli-0.1.5.dist-info/WHEEL +4 -0
- qcp_cli-0.1.5.dist-info/entry_points.txt +2 -0
- qcp_cli-0.1.5.dist-info/licenses/LICENSE +21 -0
qcp/llm.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""LangChain model construction and Gemini credential validation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
from langchain_core.language_models import BaseChatModel
|
|
9
|
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
10
|
+
|
|
11
|
+
from qcp import config as cfg
|
|
12
|
+
from qcp.errors import NoApiKeyConfiguredError
|
|
13
|
+
|
|
14
|
+
# Built-in fallback used by get_model() when neither the GEMINI_MODEL
# environment variable nor the ``gemini_model`` config value names a model.
DEFAULT_GEMINI_MODEL = "gemini-2.5-flash"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_model() -> str:
    """Pick the Gemini model name to use.

    Precedence: a non-empty ``GEMINI_MODEL`` environment variable, then a
    non-empty ``gemini_model`` config value, then ``DEFAULT_GEMINI_MODEL``.
    The config is only consulted when the environment variable is unset/empty.
    """
    env_model = os.environ.get("GEMINI_MODEL")
    if env_model:
        return str(env_model)
    configured_model = cfg.get("gemini_model")
    if configured_model:
        return str(configured_model)
    return str(DEFAULT_GEMINI_MODEL)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def require_api_key() -> str:
    """Fetch the configured Gemini API key.

    Returns:
        The key returned by ``cfg.get_gemini_api_key()``.

    Raises:
        NoApiKeyConfiguredError: if that lookup yields an empty or missing key.
    """
    if api_key := cfg.get_gemini_api_key():
        return api_key
    raise NoApiKeyConfiguredError("gemini")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ChatModelFactory(ABC):
    """Abstract builder for the LangChain chat model that QCP agents talk to."""

    @abstractmethod
    def create(self) -> BaseChatModel:
        """Build and return a configured LangChain chat model instance."""
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GeminiChatModelFactory(ChatModelFactory):
    """Build ``ChatGoogleGenerativeAI`` clients for the selected Gemini model.

    The model defaults to whatever ``get_model()`` resolves (environment,
    config, then the built-in default).
    """

    def __init__(self, api_key: str, model: str | None = None) -> None:
        """Remember the credentials and the model name to instantiate.

        Args:
            api_key: Gemini API key handed to the LangChain client as-is.
            model: Explicit model name; any falsy value defers to ``get_model()``.
        """
        self._api_key = api_key
        self._model = model if model else get_model()

    def create(self) -> ChatGoogleGenerativeAI:
        """Instantiate the LangChain Gemini client with a low sampling temperature (0.1)."""
        client_settings = {
            "model": self._model,
            "google_api_key": self._api_key,
            "temperature": 0.1,
        }
        return ChatGoogleGenerativeAI(**client_settings)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def validate_api_key(api_key: str) -> tuple[bool, str]:
    """Check a Gemini API key by sending one tiny prompt through LangChain.

    The request goes through ``GeminiChatModelFactory``, the same factory the
    agents use.

    Args:
        api_key: The candidate key.

    Returns:
        ``(True, "")`` when the test prompt succeeds; otherwise
        ``(False, reason)`` where ``reason`` is at most 300 characters and never
        contains the key itself.
    """
    if not api_key or not api_key.strip():
        # An empty key can never be valid; skip the network round-trip entirely.
        return False, "API key is empty."
    try:
        GeminiChatModelFactory(api_key).create().invoke("Reply with the single word pong.")
    except Exception as error:  # any failure means the key is not usable
        # Fall back to the exception type when the message is empty, and mask
        # the key in case the client echoes request details back.
        reason = (str(error) or type(error).__name__).replace(api_key, "***")
        return False, reason[:300]
    return True, ""
|
qcp/memory.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Local, credential-free storage for database schema snapshots."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import stat
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from contextlib import suppress
|
|
10
|
+
from datetime import UTC, datetime, timedelta
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from pydantic import ValidationError
|
|
15
|
+
|
|
16
|
+
from qcp import config as cfg
|
|
17
|
+
from qcp.models import SchemaSnapshot
|
|
18
|
+
|
|
19
|
+
# Default maximum age accepted by JsonSchemaMemoryStore.recall() (one day).
SCHEMA_CACHE_TTL = timedelta(days=1)
# Snapshots whose ``format_version`` differs from this are ignored by recall().
SCHEMA_CACHE_VERSION = 2
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SchemaMemoryStore(ABC):
    """Interface for keeping schema snapshots around between CLI invocations."""

    @abstractmethod
    def recall(self, database_id: str) -> SchemaSnapshot | None:
        """Look up the snapshot for ``database_id``; ``None`` when absent or no longer fresh."""

    @abstractmethod
    def store(self, snapshot: SchemaSnapshot) -> None:
        """Save ``snapshot`` so a later run can recall it."""

    @abstractmethod
    def invalidate(self, database_id: str) -> None:
        """Forget whatever snapshot is saved for ``database_id``."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class JsonSchemaMemoryStore(SchemaMemoryStore):
    """Store isolated schema snapshots in ``~/.qcp/schema.json``.

    The file holds one JSON object mapping each ``database_id`` to its
    serialized ``SchemaSnapshot``; reading one entry never exposes another.
    """

    def __init__(self, path: Path | None = None, ttl: timedelta = SCHEMA_CACHE_TTL) -> None:
        """Configure the cache location and freshness policy.

        Args:
            path: Explicit cache file. ``None`` resolves ``cfg.CONFIG_DIR /
                "schema.json"`` on every access, so later changes to
                ``cfg.CONFIG_DIR`` (e.g. test-time overrides) are honoured.
            ttl: Maximum snapshot age accepted by ``recall``.
        """
        self._path = path
        self._ttl = ttl

    @property
    def path(self) -> Path:
        """Return the current cache path, respecting test-time QCP_HOME overrides."""
        return self._path or cfg.CONFIG_DIR / "schema.json"

    def _load_all(self) -> dict[str, Any]:
        """Read every stored snapshot; a missing, unreadable or corrupt file reads as empty."""
        try:
            value = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # OSError covers a missing/unreadable file; ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError. The parenthesized
            # form is valid on every Python 3 version, unlike ``except A, B:``.
            return {}
        return value if isinstance(value, dict) else {}

    def _write_all(self, snapshots: dict[str, Any]) -> None:
        """Replace the cache file with ``snapshots``, readable/writable by the owner only."""
        target = self.path
        payload = json.dumps(snapshots, indent=2)
        temp_path = target.with_name(f".{target.name}.{os.getpid()}.tmp")
        # Create the temporary file as 0o600 up front so the contents are never
        # exposed with umask-derived permissions, then swap it in with
        # os.replace (atomic on POSIX) so a crash cannot leave a truncated cache.
        fd = os.open(temp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                handle.write(payload)
            os.replace(temp_path, target)
        except BaseException:
            with suppress(OSError):
                temp_path.unlink()
            raise
        # Best effort: a stale temp file reused via O_CREAT keeps its old mode.
        with suppress(OSError):
            os.chmod(target, stat.S_IRUSR | stat.S_IWUSR)

    def recall(self, database_id: str) -> SchemaSnapshot | None:
        """Return a fresh, valid snapshot without exposing other databases."""
        raw_snapshot = self._load_all().get(database_id)
        if not isinstance(raw_snapshot, dict) or raw_snapshot.get("format_version") != SCHEMA_CACHE_VERSION:
            return None
        try:
            snapshot = SchemaSnapshot.model_validate(raw_snapshot)
        except ValidationError:
            return None
        captured_at = snapshot.captured_at
        if captured_at.tzinfo is None:
            # Treat naive timestamps as UTC so they compare with an aware "now".
            captured_at = captured_at.replace(tzinfo=UTC)
        if datetime.now(UTC) - captured_at > self._ttl:
            return None
        return snapshot

    def store(self, snapshot: SchemaSnapshot) -> None:
        """Persist one validated snapshot with owner-only permissions."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        with suppress(OSError):
            os.chmod(self.path.parent, stat.S_IRWXU)
        snapshots = self._load_all()
        snapshots[snapshot.database_id] = snapshot.model_dump(mode="json")
        self._write_all(snapshots)

    def invalidate(self, database_id: str) -> None:
        """Remove one database's snapshot while preserving all others."""
        snapshots = self._load_all()
        if database_id not in snapshots:
            return
        del snapshots[database_id]
        self._write_all(snapshots)
|
qcp/models.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""Typed domain models used by the QCP agent and its tools."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field, SecretStr, field_serializer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class QcpConfig(BaseModel):
    """Validated representation of the persisted QCP configuration."""

    # Unknown keys in the stored config are dropped rather than rejected.
    model_config = ConfigDict(extra="ignore")

    database_url: str | None = None
    # SecretStr masks the key in repr()/str() output.
    gemini_api_key: SecretStr | None = None
    provider: Literal["gemini"] = "gemini"
    gemini_model: str = "gemini-2.5-flash"

    @field_serializer("gemini_api_key", when_used="json")
    def serialize_api_key(self, value: SecretStr | None) -> str | None:
        """Write the real key in JSON-mode dumps; other representations stay redacted."""
        if value is None:
            return None
        return value.get_secret_value()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SchemaColumn(BaseModel):
    """A PostgreSQL column exposed to the database agent."""

    # ``name`` and ``data_type`` are what SchemaSnapshot.summary() renders.
    name: str
    data_type: str
    # Stored with the snapshot but not included in summary().
    nullable: bool
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SchemaTable(BaseModel):
    """A PostgreSQL table and its columns."""

    # Qualifies the table name in summary(); defaults to "public".
    schema_name: str = "public"
    name: str
    # Rendered in this order by SchemaSnapshot.summary().
    columns: list[SchemaColumn]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class SchemaSnapshot(BaseModel):
    """A schema snapshot persisted in local QCP memory."""

    # JsonSchemaMemoryStore.recall() ignores entries whose version differs
    # from memory.SCHEMA_CACHE_VERSION.
    format_version: int = 2
    database_id: str
    # Compared against the cache TTL by JsonSchemaMemoryStore.recall().
    captured_at: datetime
    tables: list[SchemaTable]

    def summary(self, max_tables: int = 50) -> str:
        """Render one ``- schema.table(column type, ...)`` line per table.

        Only the first ``max_tables`` tables are listed; any remainder is
        reported as a count on a final line.
        """
        rendered = [
            "- {}.{}({})".format(
                table.schema_name,
                table.name,
                ", ".join(f"{column.name} {column.data_type}" for column in table.columns),
            )
            for table in self.tables[:max_tables]
        ]
        hidden_count = len(self.tables) - max_tables
        if hidden_count > 0:
            rendered.append(f"... and {hidden_count} more tables")
        if not rendered:
            return "(no tables found in 'public' schema)"
        return "\n".join(rendered)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class QueryResult(BaseModel):
    """The exact SQL and rows returned by the read-query tool."""

    sql: str
    columns: list[str] = Field(default_factory=list)
    # NOTE(review): presumably one inner list per row, aligned with ``columns``.
    rows: list[list[Any]] = Field(default_factory=list)
    # NOTE(review): presumably set when a row limit cut the result short; confirm in db.py.
    truncated: bool = False
    # False for dry runs, where the SQL is normalized but never sent to the database.
    executed: bool = True
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class InsightContext(BaseModel):
    """Grounded facts supplied to the model for insight generation."""

    # Plain-sentence facts assembled by the analyze_insights tool.
    facts: list[str]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class QueryNarrative(BaseModel):
    """Structured natural-language response from the query agent."""

    # The description below is part of the schema the model sees; keep it verbatim.
    answer: str = Field(description="A concise answer grounded only in the executed query result.")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class InsightsNarrative(BaseModel):
    """Structured natural-language response from the insights agent."""

    # Validation rejects fewer than three or more than six insights.
    insights: list[str] = Field(min_length=3, max_length=6)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class AgentQueryResponse(BaseModel):
    """Application-level response returned to the CLI query command."""

    # The SQL that ran (or was prepared, in dry-run mode) plus its rows.
    query_result: QueryResult
    # Natural-language answer produced by the agent.
    answer: str
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class AgentInsightsResponse(BaseModel):
    """Application-level response returned to the CLI insights command."""

    insights: list[str]
    # Optional: insights can be derived from the schema alone, without a query.
    query_result: QueryResult | None = None
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class LookupSchemaInput(BaseModel):
    """Input for the schema lookup tool."""

    # The description is sent to the model as part of the tool's argument schema.
    force_refresh: bool = Field(default=False, description="Ignore cached state and query PostgreSQL again.")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class SchemaMemoryInput(BaseModel):
    """Input for the local schema-memory tool."""

    # Only these two operations are accepted by validation.
    operation: Literal["recall", "store"]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class ExecuteReadQueryInput(BaseModel):
    """Input for the read-only query execution tool."""

    # The description is sent to the model as part of the tool's argument schema.
    sql: str = Field(description="One PostgreSQL SELECT or WITH query without multiple statements.")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class AnalyzeInsightsInput(BaseModel):
    """Input for the grounded insights tool."""

    # None means "no particular focus"; the description is shown to the model.
    focus: str | None = Field(default=None, description="Optional analytical focus supplied by the user.")
|
qcp/output.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Small terminal output helpers kept independent from the agent layer."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import shutil
|
|
6
|
+
import textwrap
|
|
7
|
+
from collections.abc import Sequence
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
# Width assumed when shutil cannot determine the terminal size.
DEFAULT_TERMINAL_WIDTH = 120
# Narrowest layout ever used, even when a smaller width is detected or requested.
MIN_TERMINAL_WIDTH = 40
# Gap placed between cells of the compact horizontal table.
COLUMN_SEPARATOR = " "
# Divider between the label and value columns of an expanded record.
EXPANDED_SEPARATOR = " | "
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def format_table(
|
|
17
|
+
columns: Sequence[str],
|
|
18
|
+
rows: Sequence[Sequence[Any]],
|
|
19
|
+
max_width: int | None = None,
|
|
20
|
+
) -> str:
|
|
21
|
+
"""Format query results for the current terminal width.
|
|
22
|
+
|
|
23
|
+
Compact results use a conventional horizontal table. Results that cannot
|
|
24
|
+
fit the terminal switch to an expanded record layout so long values remain
|
|
25
|
+
readable without truncation.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
columns: Column names in result order.
|
|
29
|
+
rows: Query result rows.
|
|
30
|
+
max_width: Optional deterministic width override, primarily for tests.
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
A terminal-friendly representation including the result row count.
|
|
34
|
+
"""
|
|
35
|
+
if not columns:
|
|
36
|
+
return "(no output)"
|
|
37
|
+
if not rows:
|
|
38
|
+
return "(0 rows)"
|
|
39
|
+
|
|
40
|
+
terminal_width = max(
|
|
41
|
+
MIN_TERMINAL_WIDTH,
|
|
42
|
+
max_width or shutil.get_terminal_size(fallback=(DEFAULT_TERMINAL_WIDTH, 24)).columns,
|
|
43
|
+
)
|
|
44
|
+
str_rows = [[_stringify(value) for value in row] for row in rows]
|
|
45
|
+
widths = [len(c) for c in columns]
|
|
46
|
+
for row in str_rows:
|
|
47
|
+
for index, value in enumerate(row[: len(widths)]):
|
|
48
|
+
widths[index] = max(widths[index], len(value))
|
|
49
|
+
|
|
50
|
+
table_width = sum(widths) + len(COLUMN_SEPARATOR) * (len(widths) - 1)
|
|
51
|
+
contains_multiline_value = any("\n" in value for row in str_rows for value in row)
|
|
52
|
+
if table_width > terminal_width or contains_multiline_value:
|
|
53
|
+
return _format_expanded(columns, str_rows, terminal_width)
|
|
54
|
+
|
|
55
|
+
def fmt_row(vals: list[str]) -> str:
|
|
56
|
+
padded_values = [*vals[: len(widths)], *([""] * max(0, len(widths) - len(vals)))]
|
|
57
|
+
return COLUMN_SEPARATOR.join(value.ljust(widths[index]) for index, value in enumerate(padded_values))
|
|
58
|
+
|
|
59
|
+
separator = COLUMN_SEPARATOR.join("-" * width for width in widths)
|
|
60
|
+
lines = [fmt_row(list(columns)), separator]
|
|
61
|
+
lines.extend(fmt_row(row) for row in str_rows)
|
|
62
|
+
lines.append(f"\n{_row_count(len(rows))}")
|
|
63
|
+
return "\n".join(lines)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _stringify(value: Any) -> str:
|
|
67
|
+
"""Convert a result value to one safe display string."""
|
|
68
|
+
if value is None:
|
|
69
|
+
return ""
|
|
70
|
+
return str(value).replace("\r\n", "\n").replace("\r", "\n")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _format_expanded(columns: Sequence[str], rows: Sequence[Sequence[str]], max_width: int) -> str:
    """Lay each row out as a vertical ``label | value`` record.

    Labels take at most a third of ``max_width``; labels and values are both
    wrapped (never truncated) to their share of the width. Records are separated
    by a blank line, and the output ends with a row-count footer.
    """
    label_width = min(max(map(len, columns)), max_width // 3)
    value_width = max(1, max_width - label_width - len(EXPANDED_SEPARATOR))
    column_count = len(columns)
    out: list[str] = []

    for record_number, record in enumerate(rows, start=1):
        if record_number > 1:
            out.append("")
        out.append(f"-[ RECORD {record_number} ]".ljust(max_width, "-"))
        cells = list(record[:column_count]) + [""] * (column_count - len(record))
        for label, cell in zip(columns, cells, strict=True):
            label_lines = _wrap_value(label, label_width)
            value_lines = _wrap_value(cell, value_width)
            # Pad the shorter side so label and value fragments line up.
            height = max(len(label_lines), len(value_lines))
            label_lines += [""] * (height - len(label_lines))
            value_lines += [""] * (height - len(value_lines))
            for label_part, value_part in zip(label_lines, value_lines):
                out.append(label_part.ljust(label_width) + EXPANDED_SEPARATOR + value_part)

    out.append("\n" + _row_count(len(rows)))
    return "\n".join(out)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _wrap_value(value: str, width: int) -> list[str]:
|
|
98
|
+
"""Wrap long and multiline values without dropping their content."""
|
|
99
|
+
if not value:
|
|
100
|
+
return [""]
|
|
101
|
+
wrapped_lines: list[str] = []
|
|
102
|
+
for logical_line in value.split("\n"):
|
|
103
|
+
wrapped_lines.extend(
|
|
104
|
+
textwrap.wrap(
|
|
105
|
+
logical_line,
|
|
106
|
+
width=width,
|
|
107
|
+
replace_whitespace=False,
|
|
108
|
+
drop_whitespace=False,
|
|
109
|
+
break_long_words=True,
|
|
110
|
+
break_on_hyphens=False,
|
|
111
|
+
)
|
|
112
|
+
or [""]
|
|
113
|
+
)
|
|
114
|
+
return wrapped_lines
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _row_count(row_count: int) -> str:
|
|
118
|
+
"""Return the conventional result-count footer."""
|
|
119
|
+
return f"({row_count} row{'s' if row_count != 1 else ''})"
|
qcp/tools.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
"""LangChain tools exposed to the QCP database agent."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any, NotRequired
|
|
5
|
+
|
|
6
|
+
from langchain.agents import AgentState
|
|
7
|
+
from langchain.messages import ToolMessage
|
|
8
|
+
from langchain.tools import ToolRuntime, tool
|
|
9
|
+
from langchain_core.tools import BaseTool
|
|
10
|
+
from langgraph.types import Command
|
|
11
|
+
|
|
12
|
+
from qcp.db import DatabaseClient, normalize_read_query
|
|
13
|
+
from qcp.errors import SchemaChangedError
|
|
14
|
+
from qcp.memory import SchemaMemoryStore
|
|
15
|
+
from qcp.models import (
|
|
16
|
+
AnalyzeInsightsInput,
|
|
17
|
+
ExecuteReadQueryInput,
|
|
18
|
+
InsightContext,
|
|
19
|
+
LookupSchemaInput,
|
|
20
|
+
QueryResult,
|
|
21
|
+
SchemaMemoryInput,
|
|
22
|
+
SchemaSnapshot,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class QcpAgentState(AgentState):
    """LangChain agent state extended with the artifacts QCP's tools produce."""

    # SchemaSnapshot.model_dump(mode="json") written by lookup_schema and by a
    # successful schema_memory recall; reset to None after a stale-schema error.
    schema_snapshot: NotRequired[dict[str, Any] | None]
    # QueryResult.model_dump(mode="json") written by execute_read_query.
    query_result: NotRequired[dict[str, Any] | None]
    # InsightContext.model_dump(mode="json") written by analyze_insights.
    insight_context: NotRequired[dict[str, Any] | None]
    # Set to 1 after the first SchemaChangedError; a second one is re-raised.
    schema_retry_count: NotRequired[int]
    # Not written by the tools in this module.
    # NOTE(review): presumably maintained by the agent layer; confirm.
    query_execution_retry_count: NotRequired[int]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DatabaseToolkit:
    """Build dependency-injected LangChain tools for one CLI invocation.

    The database client, schema memory and dry-run flag are injected once and
    captured by the tool closures created in ``build``.
    """

    def __init__(
        self,
        database: DatabaseClient,
        memory: SchemaMemoryStore,
        *,
        dry_run: bool = False,
    ) -> None:
        """Store the collaborators shared by every tool.

        Args:
            database: Client used for schema lookup and read-only query execution.
            memory: Local schema cache used by ``schema_memory`` and invalidated
                by ``execute_read_query`` on schema changes.
            dry_run: When true, ``execute_read_query`` only normalizes the SQL
                and never executes it.
        """
        self._database = database
        self._memory = memory
        self._dry_run = dry_run

    def build(self) -> list[BaseTool]:
        """Create the schema, memory, query, and insights tools.

        LangChain's ``@tool`` uses each function docstring below as the tool
        description shown to the model, so that wording is runtime behaviour.
        """
        database = self._database
        memory = self._memory
        dry_run = self._dry_run

        @tool("lookup_schema", args_schema=LookupSchemaInput)
        def lookup_schema(force_refresh: bool, runtime: ToolRuntime[None, QcpAgentState]) -> Command:
            """Read the current public PostgreSQL schema when cache is absent or stale."""
            # This tool always queries PostgreSQL, so the flag adds nothing here.
            del force_refresh
            snapshot = database.lookup_schema()
            return _state_command(
                runtime,
                content=f"Current PostgreSQL schema:\n{snapshot.summary()}",
                schema_snapshot=snapshot.model_dump(mode="json"),
            )

        @tool("schema_memory", args_schema=SchemaMemoryInput)
        def schema_memory(operation: str, runtime: ToolRuntime[None, QcpAgentState]) -> Command:
            """Recall a fresh schema snapshot from local memory or store a looked-up snapshot."""
            if operation == "recall":
                snapshot = memory.recall(database.database_id)
                if snapshot is None:
                    return _state_command(runtime, content="Schema memory is missing or older than 24 hours.")
                return _state_command(
                    runtime,
                    content=f"Fresh schema recalled from local memory:\n{snapshot.summary()}",
                    schema_snapshot=snapshot.model_dump(mode="json"),
                )

            # operation == "store": persist whatever schema is currently in state.
            raw_snapshot = runtime.state.get("schema_snapshot")
            if raw_snapshot is None:
                return _state_command(runtime, content="No looked-up schema is available to store.")
            snapshot = SchemaSnapshot.model_validate(raw_snapshot)
            memory.store(snapshot)
            return _state_command(runtime, content="Schema snapshot stored in local memory.")

        @tool("execute_read_query", args_schema=ExecuteReadQueryInput)
        def execute_read_query(sql: str, runtime: ToolRuntime[None, QcpAgentState]) -> Command:
            """Execute one PostgreSQL SELECT or WITH query in a read-only transaction."""
            if runtime.state.get("schema_snapshot") is None:
                return _state_command(runtime, content="Schema is required before query execution.")

            if dry_run:
                query_result = QueryResult(sql=normalize_read_query(sql), executed=False)
                return _state_command(
                    runtime,
                    content=json.dumps(query_result.model_dump(mode="json")),
                    query_result=query_result.model_dump(mode="json"),
                )

            try:
                query_result = database.execute_read_query(sql)
            except SchemaChangedError:
                # Allow exactly one refresh cycle per agent run; a second schema
                # mismatch is surfaced to the caller.
                retry_count = runtime.state.get("schema_retry_count", 0)
                if retry_count >= 1:
                    raise
                memory.invalidate(database.database_id)
                return _state_command(
                    runtime,
                    content=(
                        "The cached schema is stale. Call lookup_schema with force_refresh=true, "
                        "store it with schema_memory, then retry this query once."
                    ),
                    schema_snapshot=None,
                    query_result=None,
                    schema_retry_count=1,
                )

            payload = query_result.model_dump(mode="json")
            return _state_command(
                runtime,
                content=json.dumps(payload),
                query_result=payload,
            )

        @tool("analyze_insights", args_schema=AnalyzeInsightsInput)
        def analyze_insights(focus: str | None, runtime: ToolRuntime[None, QcpAgentState]) -> Command:
            """Build grounded facts from schema and optional query results for insight generation."""
            raw_snapshot = runtime.state.get("schema_snapshot")
            if raw_snapshot is None:
                return _state_command(runtime, content="Schema is required before analyzing insights.")
            snapshot = SchemaSnapshot.model_validate(raw_snapshot)
            facts = [
                f"The database snapshot contains {len(snapshot.tables)} tables.",
                "Available tables: " + ", ".join(f"{table.schema_name}.{table.name}" for table in snapshot.tables),
            ]
            if focus:
                facts.append(f"The user's requested analytical focus is: {focus}")
            raw_result = runtime.state.get("query_result")
            if raw_result is not None:
                query_result = QueryResult.model_validate(raw_result)
                if not query_result.executed:
                    # A dry-run result has no rows because nothing ran; reporting
                    # "returned 0 rows" would wrongly ground insights on empty data.
                    facts.append(
                        "The query was not executed because QCP is running in dry-run mode, "
                        f"so no result rows are available. Prepared SQL: {query_result.sql}"
                    )
                else:
                    row_fact = f"The executed query returned {len(query_result.rows)} rows."
                    if query_result.truncated:
                        row_fact += " The result was marked as truncated; the full result set may contain more rows."
                    facts.extend(
                        [
                            row_fact,
                            "Result columns: " + ", ".join(query_result.columns),
                            "Result sample: " + json.dumps(query_result.model_dump(mode="json")["rows"][:20]),
                        ]
                    )
            context = InsightContext(facts=facts)
            payload = context.model_dump(mode="json")
            return _state_command(
                runtime,
                content=json.dumps(payload),
                insight_context=payload,
            )

        return [lookup_schema, schema_memory, execute_read_query, analyze_insights]
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _state_command(runtime: ToolRuntime, content: str, **updates: Any) -> Command:
    """Wrap state ``updates`` in a Command that also answers the pending tool call.

    The required ToolMessage carries ``content`` and reuses the runtime's
    ``tool_call_id``; it replaces any ``messages`` key passed in ``updates``.
    """
    reply = ToolMessage(content=content, tool_call_id=runtime.tool_call_id)
    state_update = dict(updates)
    state_update["messages"] = [reply]
    return Command(update=state_update)
|