sqlsaber-viz 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlsaber_viz/spec.py ADDED
@@ -0,0 +1,130 @@
1
+ """Pydantic models for visualization specs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Annotated, Literal
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class FieldEncoding(BaseModel):
    # `field` names a data field (presumably a result column -- confirm
    # against the renderer); `type` tags it as category, number or time and
    # falls back to "number" when omitted.
    field: str
    type: Literal["category", "number", "time"] = "number"
13
+
14
+
15
class ChartOptions(BaseModel):
    # Optional display settings; every field defaults to None.
    # When given, width must lie in 20..200 and height in 10..100.
    width: int | None = Field(None, ge=20, le=200)
    height: int | None = Field(None, ge=10, le=100)
    x_label: str | None = None
    y_label: str | None = None
    color: str | None = None
    marker: str | None = None
22
+
23
+
24
class BarEncoding(BaseModel):
    # Required x and y field bindings plus an optional series binding.
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
28
+
29
+
30
class BarChart(BaseModel):
    # Chart variant selected by type == "bar".
    type: Literal["bar"]
    encoding: BarEncoding
    # Layout switches; defaults are vertical bars in grouped mode.
    orientation: Literal["vertical", "horizontal"] = "vertical"
    mode: Literal["grouped", "stacked"] = "grouped"
    # A fresh ChartOptions is created per instance when none is supplied.
    options: ChartOptions = Field(default_factory=ChartOptions)
36
+
37
+
38
class LineEncoding(BaseModel):
    # Required x and y field bindings plus an optional series binding.
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
42
+
43
+
44
class LineChart(BaseModel):
    # Chart variant selected by type == "line".
    type: Literal["line"]
    encoding: LineEncoding
    # A fresh ChartOptions is created per instance when none is supplied.
    options: ChartOptions = Field(default_factory=ChartOptions)
48
+
49
+
50
class ScatterEncoding(BaseModel):
    # Required x and y field bindings plus an optional series binding.
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
54
+
55
+
56
class ScatterChart(BaseModel):
    # Chart variant selected by type == "scatter".
    type: Literal["scatter"]
    encoding: ScatterEncoding
    # A fresh ChartOptions is created per instance when none is supplied.
    options: ChartOptions = Field(default_factory=ChartOptions)
60
+
61
+
62
class BoxplotConfig(BaseModel):
    # Names of the two fields a boxplot needs: one for labels, one for values.
    label_field: str
    value_field: str
65
+
66
+
67
class BoxplotChart(BaseModel):
    # Chart variant selected by type == "boxplot".
    type: Literal["boxplot"]
    boxplot: BoxplotConfig
    # A fresh ChartOptions is created per instance when none is supplied.
    options: ChartOptions = Field(default_factory=ChartOptions)
71
+
72
+
73
class HistogramConfig(BaseModel):
    # Field to bin, and the bin count (default 20, allowed range 2..100).
    field: str
    bins: int = Field(20, ge=2, le=100)
76
+
77
+
78
class HistogramChart(BaseModel):
    # Chart variant selected by type == "histogram".
    type: Literal["histogram"]
    histogram: HistogramConfig
    # A fresh ChartOptions is created per instance when none is supplied.
    options: ChartOptions = Field(default_factory=ChartOptions)
82
+
83
+
84
# Union of every chart model; the "type" field is the discriminator that
# selects which member validates a given payload.
ChartSpec = Annotated[
    BarChart | LineChart | ScatterChart | BoxplotChart | HistogramChart,
    Field(discriminator="type"),
]
88
+
89
+
90
class SortItem(BaseModel):
    # One sort key: the field name and a direction ("asc" unless stated).
    field: str
    dir: Literal["asc", "desc"] = "asc"
93
+
94
+
95
class SortTransform(BaseModel):
    # Transform carrying a list of sort keys under the "sort" key.
    sort: list[SortItem]
97
+
98
+
99
class LimitTransform(BaseModel):
    # Transform carrying a required row limit; must be at least 1.
    limit: int = Field(ge=1)
101
+
102
+
103
class FilterConfig(BaseModel):
    # A single comparison: field, operator, and the value to compare with.
    # `value` has no default, so it must be supplied (None is an accepted value).
    field: str
    op: Literal["==", "!=", ">", "<", ">=", "<="]
    value: str | int | float | bool | None
107
+
108
+
109
class FilterTransform(BaseModel):
    # Transform carrying one filter condition under the "filter" key.
    filter: FilterConfig
111
+
112
+
113
# Any one of the supported transform models (plain union, no discriminator).
Transform = SortTransform | LimitTransform | FilterTransform
114
+
115
+
116
class DataSource(BaseModel):
    # Result file key; must look like "result_<id>.json" where <id> uses
    # only letters, digits, ".", "_" and "-".
    file: str = Field(pattern=r"^result_[A-Za-z0-9._-]+\.json$")
118
+
119
+
120
class DataConfig(BaseModel):
    # Wrapper holding the data source of a visualization spec.
    source: DataSource
122
+
123
+
124
class VizSpec(BaseModel):
    # Top-level visualization spec. Only `data` and `chart` are required.
    version: Literal["1"] = "1"
    title: str | None = None
    description: str | None = None
    data: DataConfig
    chart: ChartSpec
    # Transform list; a new empty list per instance when omitted.
    transform: list[Transform] = Field(default_factory=list)
@@ -0,0 +1,144 @@
1
+ """Internal agent for generating visualization specs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import Any
7
+
8
+ from sqlsaber.agents.provider_factory import ProviderFactory
9
+ from sqlsaber.config import providers
10
+ from sqlsaber.config.settings import Config
11
+
12
+ from .prompts import VIZ_SYSTEM_PROMPT
13
+ from .spec import VizSpec
14
+ from .templates import ChartType, list_chart_types, vizspec_template
15
+
16
+
17
class SpecAgent:
    """Internal agent for generating visualization specs.

    Wraps an agent created by ``ProviderFactory`` and attaches the
    visualization system prompt and the template helper tools to it.
    """

    def __init__(self, model_name: str | None = None, api_key: str | None = None):
        """Load configuration and build the underlying agent.

        Args:
            model_name: Optional model name used instead of the configured one.
            api_key: Optional API key used instead of the configured one.
        """
        self.config = Config()
        self._model_name_override = model_name
        self._api_key_override = api_key
        self.agent = self._build_agent()

    def _build_agent(self):
        """Create the agent, attach the system prompt and register tools."""
        model_name = self._model_name_override or self.config.model.name
        # Keep only the part after the first ":" when the name contains one.
        _, separator, remainder = model_name.partition(":")
        model_name_only = remainder if separator else model_name

        # Configured auth is validated unless BOTH overrides were supplied.
        if not (self._model_name_override and self._api_key_override):
            self.config.auth.validate(model_name)

        provider = providers.provider_from_model(model_name) or ""
        api_key = self._api_key_override or self.config.auth.get_api_key(model_name)

        factory = ProviderFactory()
        agent = factory.create_agent(
            provider=provider,
            model_name=model_name_only,
            full_model_str=model_name,
            api_key=api_key,
            thinking_enabled=False,
        )

        @agent.system_prompt
        def viz_system_prompt() -> str:
            return VIZ_SYSTEM_PROMPT

        self._register_tools(agent)

        return agent

    def _register_tools(self, agent) -> None:
        """Register visualization helper tools on the agent."""

        @agent.tool_plain
        def get_vizspec_template(chart_type: ChartType, file: str) -> dict:
            """Get the complete VizSpec template for a chart type.

            Call this FIRST to get the correct JSON structure, then fill in
            the placeholder field names with actual column names from your data.

            Args:
                chart_type: One of "bar", "line", "scatter", "boxplot", "histogram"
                file: The result file key (e.g., "result_abc123.json")

            Returns:
                A complete VizSpec template with placeholders for field names.
            """
            return vizspec_template(chart_type, file)

        @agent.tool_plain
        def get_available_chart_types() -> list[dict]:
            """List available chart types with descriptions.

            Call this if you're unsure which chart type to use for the data.

            Returns:
                List of chart types with descriptions and use cases.
            """
            return list_chart_types()

    async def generate_spec(
        self,
        request: str,
        columns: list[dict],
        row_count: int,
        file: str,
        chart_type_hint: str | None = None,
    ) -> VizSpec:
        """Generate a VizSpec from user request and data summary.

        Args:
            request: The user's description of the desired visualization.
            columns: Column descriptions, serialized as JSON into the prompt.
            row_count: Row count reported in the prompt.
            file: Result file key reported in the prompt.
            chart_type_hint: Optional chart type suggestion added to the prompt.

        Returns:
            The validated ``VizSpec`` parsed from the agent's output.

        Raises:
            json.JSONDecodeError: If no JSON object can be parsed from the output.
            pydantic.ValidationError: If the parsed object is not a valid spec.
        """
        prompt = self._build_prompt(
            request=request,
            columns=columns,
            row_count=row_count,
            file=file,
            chart_type_hint=chart_type_hint,
        )

        result = await self.agent.run(prompt)
        output = str(result.output).strip()
        parsed = _parse_json(output)
        return VizSpec.model_validate(parsed)

    def _build_prompt(
        self,
        request: str,
        columns: list[dict],
        row_count: int,
        file: str,
        chart_type_hint: str | None,
    ) -> str:
        """Assemble the user prompt sent to the agent.

        Sections are separated by exactly one blank line. The chart type hint
        section is left out entirely when no hint is given, so the prompt has
        no run of empty lines where the hint would have been.
        """
        columns_json = json.dumps(columns, ensure_ascii=False, indent=2)

        sections = [
            f"## User Request\n{request.strip()}",
            (
                "## Data Summary\n"
                f"Row count: {row_count}\n"
                f"File: {file}\n"
                f"Columns:\n{columns_json}"
            ),
        ]
        if chart_type_hint:
            sections.append(f"Chart type hint: {chart_type_hint}")
        sections.append(
            "Use `get_vizspec_template` to get the correct spec structure, "
            "then fill in the placeholders with actual column names.\n"
            "Return ONLY the final JSON."
        )

        return "\n\n".join(sections).strip()
131
+
132
+
133
+ def _parse_json(text: str) -> dict[str, Any]:
134
+ try:
135
+ parsed = json.loads(text)
136
+ except json.JSONDecodeError:
137
+ start = text.find("{")
138
+ end = text.rfind("}")
139
+ if start == -1 or end == -1 or end <= start:
140
+ raise
141
+ parsed = json.loads(text[start : end + 1])
142
+ if not isinstance(parsed, dict):
143
+ raise json.JSONDecodeError("Expected JSON object", text, 0)
144
+ return parsed
@@ -0,0 +1,175 @@
1
+ """Template builders for visualization specs.
2
+
3
+ These functions generate minimal valid templates from Pydantic models,
4
+ ensuring they stay in sync with the schema definitions.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
from collections.abc import Callable
from typing import Literal
10
+
11
+ from .spec import (
12
+ BarChart,
13
+ BarEncoding,
14
+ BoxplotChart,
15
+ BoxplotConfig,
16
+ ChartSpec,
17
+ ChartOptions,
18
+ DataConfig,
19
+ DataSource,
20
+ FieldEncoding,
21
+ HistogramChart,
22
+ HistogramConfig,
23
+ LineChart,
24
+ LineEncoding,
25
+ ScatterChart,
26
+ ScatterEncoding,
27
+ VizSpec,
28
+ )
29
+
30
# Names of the chart kinds that have a template builder in this module.
ChartType = Literal["bar", "line", "scatter", "boxplot", "histogram"]

# Stand-in field names written into templates; the model is expected to
# replace them with real column names.
_CATEGORY_PLACEHOLDER = "<category_column>"
_NUMBER_PLACEHOLDER = "<number_column>"
_TIME_PLACEHOLDER = "<time_column>"
_LABEL_PLACEHOLDER = "<label_column>"
_VALUE_PLACEHOLDER = "<value_column>"
38
+
39
+
40
def _build_bar_chart() -> BarChart:
    """Build the placeholder bar chart: category field on x, number on y."""
    category_axis = FieldEncoding(field=_CATEGORY_PLACEHOLDER, type="category")
    value_axis = FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number")
    encoding = BarEncoding(x=category_axis, y=value_axis, series=None)
    return BarChart(
        type="bar",
        encoding=encoding,
        orientation="vertical",
        mode="grouped",
        options=ChartOptions(),
    )
52
+
53
+
54
def _build_line_chart() -> LineChart:
    """Build the placeholder line chart: time field on x, number on y."""
    time_axis = FieldEncoding(field=_TIME_PLACEHOLDER, type="time")
    value_axis = FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number")
    encoding = LineEncoding(x=time_axis, y=value_axis, series=None)
    return LineChart(type="line", encoding=encoding, options=ChartOptions())
64
+
65
+
66
def _build_scatter_chart() -> ScatterChart:
    """Build the placeholder scatter chart: number fields on both axes."""
    # Two separate FieldEncoding objects, one per axis, both using the
    # numeric placeholder.
    horizontal = FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number")
    vertical = FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number")
    encoding = ScatterEncoding(x=horizontal, y=vertical, series=None)
    return ScatterChart(type="scatter", encoding=encoding, options=ChartOptions())
76
+
77
+
78
def _build_boxplot_chart() -> BoxplotChart:
    """Build the placeholder boxplot chart from label and value placeholders."""
    config = BoxplotConfig(
        label_field=_LABEL_PLACEHOLDER,
        value_field=_VALUE_PLACEHOLDER,
    )
    return BoxplotChart(type="boxplot", boxplot=config, options=ChartOptions())
87
+
88
+
89
def _build_histogram_chart() -> HistogramChart:
    """Build the placeholder histogram chart with 20 bins."""
    config = HistogramConfig(field=_NUMBER_PLACEHOLDER, bins=20)
    return HistogramChart(
        type="histogram",
        histogram=config,
        options=ChartOptions(),
    )
98
+
99
+
100
# Maps each chart type to the zero-argument function that builds its
# placeholder chart model. The value type is spelled with ``Callable``;
# the builtin ``callable`` is a function, not a type.
_CHART_BUILDERS: dict[ChartType, Callable[[], ChartSpec]] = {
    "bar": _build_bar_chart,
    "line": _build_line_chart,
    "scatter": _build_scatter_chart,
    "boxplot": _build_boxplot_chart,
    "histogram": _build_histogram_chart,
}
107
+
108
+
109
def _build_chart(chart_type: ChartType) -> ChartSpec:
    """Return a freshly built placeholder chart for ``chart_type``.

    Raises:
        ValueError: If no builder is registered for ``chart_type``.
    """
    if chart_type in _CHART_BUILDERS:
        return _CHART_BUILDERS[chart_type]()
    raise ValueError(f"Unknown chart type: {chart_type}")
115
+
116
+
117
def chart_template(chart_type: ChartType) -> dict:
    """Return the chart-only template for ``chart_type`` as a plain dict.

    Field names in the result are placeholders meant to be replaced with
    real column names. Entries whose value is None are omitted.
    """
    chart = _build_chart(chart_type)
    return chart.model_dump(exclude_none=True)
124
+
125
+
126
def vizspec_template(chart_type: ChartType, file: str) -> dict:
    """Return a full VizSpec template as a plain dict.

    The data source is filled in with ``file``; the chart section is the
    placeholder chart for ``chart_type``. Entries whose value is None are
    omitted from the result.
    """
    # The data section is built before the chart, matching the order in
    # which an invalid file key or chart type would be reported.
    data = DataConfig(source=DataSource(file=file))
    chart = _build_chart(chart_type)
    template = VizSpec(
        version="1",
        title=None,
        description=None,
        data=data,
        chart=chart,
        transform=[],
    )
    return template.model_dump(exclude_none=True)
142
+
143
+
144
def list_chart_types() -> list[dict]:
    """Describe every available chart type.

    Each entry is a dict with the keys ``type``, ``description`` and
    ``use_when``; a new list is built on every call.
    """
    catalog = (
        (
            "bar",
            "Compare categories. Use x for category, y for numeric value.",
            "Comparing values across categories (e.g., sales by region)",
        ),
        (
            "line",
            "Show trends over time/sequence. Use x for time/sequence, y for value.",
            "Showing change over time (e.g., monthly revenue)",
        ),
        (
            "scatter",
            "Show correlation between two numeric variables.",
            "Exploring relationship between two numbers (e.g., age vs income)",
        ),
        (
            "boxplot",
            "Show distribution of values across groups.",
            "Comparing distributions (e.g., salary by department)",
        ),
        (
            "histogram",
            "Show distribution of a single numeric variable.",
            "Understanding value distribution (e.g., age distribution)",
        ),
    )
    return [
        {"type": kind, "description": summary, "use_when": scenario}
        for kind, summary, scenario in catalog
    ]
sqlsaber_viz/tools.py ADDED
@@ -0,0 +1,234 @@
1
+ """Visualization tool implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import json
7
+ import re
8
+ from html import escape
9
+
10
+ from pydantic import ValidationError
11
+ from pydantic_ai import RunContext
12
+ from rich.console import Console
13
+ from rich.text import Text
14
+
15
+ from sqlsaber.tools.base import Tool
16
+ from sqlsaber.utils.json_utils import json_dumps
17
+
18
+ from .data_loader import (
19
+ extract_data_summary,
20
+ find_tool_output_in_messages,
21
+ find_tool_output_payload,
22
+ )
23
+ from .renderers.plotext_renderer import PlotextRenderer
24
+ from .spec import BarChart, LimitTransform, SortItem, SortTransform, VizSpec
25
+ from .transforms import apply_transforms
26
+
27
# Shape of a result file key: "result_<id>.json". The pattern ends with
# ``\Z`` rather than ``$`` because ``$`` also matches just before a trailing
# newline, which would let "result_x.json\n" through.
TOOL_OUTPUT_FILE_PATTERN = re.compile(r"^result_[A-Za-z0-9._-]+\.json\Z")
# Maximum time, in seconds, to wait for spec generation.
SPEC_TIMEOUT_SECONDS = 300
29
+
30
+
31
class VizTool(Tool):
    """Terminal visualization tool for SQL results.

    ``execute`` asks a spec agent for a ``VizSpec`` describing a chart of an
    earlier tool output; ``render_result`` and ``render_result_html`` draw a
    spec with ``PlotextRenderer``.
    """

    requires_ctx = True

    def __init__(self):
        super().__init__()
        # Remembered from the latest ``execute`` call so rendering can reuse
        # the rows without searching the message history again.
        self._last_ctx: RunContext | None = None
        self._last_rows: list[dict] | None = None
        self._last_file: str | None = None
        # Message history supplied through ``set_replay_messages``.
        self._replay_messages: list | None = None

    def set_replay_messages(self, messages: list) -> None:
        """Set message history for replay scenarios (e.g., threads show)."""
        self._replay_messages = messages

    @property
    def name(self) -> str:
        return "viz"

    def render_executing(self, console: Console, args: dict) -> bool:
        """Suppress default JSON rendering during execution."""
        return True

    async def execute(
        self,
        ctx: RunContext,
        request: str,
        file: str,
        chart_type: str | None = None,
    ) -> str:
        """Generate a visualization spec for SQL results.

        Args:
            request: Natural language description of the desired visualization.
            file: Result file key from execute_sql (e.g., "result_abc123.json").
            chart_type: Optional hint for chart type (bar, line, scatter, boxplot, histogram).

        Returns:
            JSON string containing the visualization spec.
        """
        self._last_ctx = ctx

        # fullmatch: a key followed by a trailing newline must not be accepted
        # just because "$" matches before it.
        if not file or not TOOL_OUTPUT_FILE_PATTERN.fullmatch(file):
            return json_dumps({"error": "Invalid result file key format."})

        tool_call_id = file.removeprefix("result_").removesuffix(".json")
        payload = find_tool_output_payload(ctx, tool_call_id)
        if payload is None:
            return json_dumps({"error": "Tool output not found in message history."})

        summary = extract_data_summary(payload)
        columns = summary.get("columns", [])
        row_count = summary.get("row_count", 0)
        rows = summary.get("rows", [])

        self._last_rows = rows
        self._last_file = file

        agent = _get_spec_agent_cls()()

        try:
            spec = await asyncio.wait_for(
                agent.generate_spec(
                    request=request,
                    columns=columns,
                    row_count=row_count,
                    file=file,
                    chart_type_hint=chart_type,
                ),
                timeout=SPEC_TIMEOUT_SECONDS,
            )
            spec = self._ensure_bar_defaults(spec, row_count)
            return json_dumps(spec.model_dump())
        except asyncio.TimeoutError:
            return json_dumps(
                {
                    "error": "Spec generation timed out.",
                    "details": f"Timed out after {SPEC_TIMEOUT_SECONDS} seconds.",
                }
            )
        except (ValidationError, json.JSONDecodeError, ValueError) as exc:
            return json_dumps(
                {
                    "error": "Failed to generate a valid visualization spec.",
                    "details": str(exc),
                }
            )

    def render_result(self, console: Console, result: object) -> bool:
        """Render the spec as a terminal chart using plotext.

        Returns:
            False when ``result`` does not contain a valid spec; True once
            either the chart or a "no data" notice has been printed.
        """
        spec = self._parse_spec(result)
        if spec is None:
            return False

        rows = self._resolve_rows(spec)
        if rows is None:
            if console.is_terminal:
                console.print("[warning]No data available for visualization.[/warning]")
            else:
                console.print("*No data available for visualization.*\n")
            return True

        rows = apply_transforms(rows, spec.transform)

        renderer = PlotextRenderer()
        chart = renderer.render(spec, rows)
        if console.is_terminal:
            console.print(Text.from_ansi(chart))
        else:
            # Non-terminal output gets a fenced block with colour codes removed.
            console.print(f"```\n{self._strip_ansi(chart)}\n```\n", markup=False)
        return True

    def render_result_html(self, result: object) -> str | None:
        """Render the spec as an HTML chart.

        Returns:
            None when ``result`` does not contain a valid spec; otherwise an
            HTML fragment holding the escaped chart text or a "no data" notice.
        """
        spec = self._parse_spec(result)
        if spec is None:
            return None

        rows = self._resolve_rows(spec)
        if rows is None:
            return '<div class="viz-error">No data available for visualization.</div>'

        rows = apply_transforms(rows, spec.transform)

        renderer = PlotextRenderer()
        chart = renderer.render(spec, rows)
        return f'<pre class="viz-chart">{escape(self._strip_ansi(chart))}</pre>'

    def _parse_spec(self, result: object) -> VizSpec | None:
        """Return the ``VizSpec`` held in ``result``, or None if there is none.

        None is returned for non-dict data, for payloads with a truthy
        ``error`` entry, and for data that fails spec validation.
        """
        data = self._parse_result(result)
        if not isinstance(data, dict):
            return None
        if data.get("error"):
            return None
        try:
            return VizSpec.model_validate(data)
        except ValidationError:
            return None

    def _parse_result(self, result: object) -> object:
        """Decode a tool result: dicts pass through, strings are JSON-decoded.

        Anything that cannot be decoded is wrapped as ``{"error": ...}``.
        """
        if isinstance(result, dict):
            return result
        if isinstance(result, str):
            try:
                return json.loads(result)
            except json.JSONDecodeError:
                return {"error": result}
        return {"error": str(result)}

    def _strip_ansi(self, text: str) -> str:
        """Remove ANSI escape sequences of the form ``ESC [ ... m`` from ``text``."""
        return re.sub(r"\x1b\[[0-9;]*m", "", text)

    def _resolve_rows(self, spec: VizSpec) -> list[dict] | None:
        """Find the data rows for the file referenced by ``spec``.

        Lookup order: rows cached by ``execute`` for the same file, then the
        last run context, then the replay messages. Returns None when no
        payload is found or its summary has no list under ``rows``.
        """
        if self._last_rows is not None and self._last_file == spec.data.source.file:
            return self._last_rows

        tool_call_id = spec.data.source.file.removeprefix("result_").removesuffix(
            ".json"
        )

        payload: dict | None = None
        if self._last_ctx is not None:
            payload = find_tool_output_payload(self._last_ctx, tool_call_id)
        # Fall back to the replayed history when the live context has no
        # match, instead of giving up as soon as a context exists.
        if payload is None and self._replay_messages is not None:
            payload = find_tool_output_in_messages(self._replay_messages, tool_call_id)

        if payload is None:
            return None
        summary = extract_data_summary(payload)
        rows = summary.get("rows")
        if isinstance(rows, list):
            return rows
        return None

    def _ensure_bar_defaults(self, spec: VizSpec, row_count: int) -> VizSpec:
        """Add default sort/limit transforms to bar chart specs.

        Non-bar specs are returned unchanged. For bar charts, a descending
        sort on the y field is appended when the spec has no sort transform,
        and a limit of 20 is appended when it has no limit transform and
        ``row_count`` exceeds 20.
        """
        if not isinstance(spec.chart, BarChart):
            return spec

        transforms = list(spec.transform)
        has_limit = any(isinstance(t, LimitTransform) for t in transforms)
        has_sort = any(isinstance(t, SortTransform) for t in transforms)

        if not has_sort:
            transforms.append(
                SortTransform(
                    sort=[SortItem(field=spec.chart.encoding.y.field, dir="desc")]
                )
            )

        if not has_limit and row_count > 20:
            transforms.append(LimitTransform(limit=20))

        # Copy only when something was actually added.
        if transforms != spec.transform:
            return spec.model_copy(update={"transform": transforms})

        return spec
229
+
230
+
231
def _get_spec_agent_cls():
    """Return the ``SpecAgent`` class, importing its module at call time."""
    from . import spec_agent

    return spec_agent.SpecAgent