sqlsaber-viz 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber_viz/__init__.py +19 -0
- sqlsaber_viz/data_loader.py +143 -0
- sqlsaber_viz/prompts.py +31 -0
- sqlsaber_viz/renderers/__init__.py +6 -0
- sqlsaber_viz/renderers/base.py +13 -0
- sqlsaber_viz/renderers/html_renderer.py +17 -0
- sqlsaber_viz/renderers/plotext_renderer.py +385 -0
- sqlsaber_viz/spec.py +130 -0
- sqlsaber_viz/spec_agent.py +144 -0
- sqlsaber_viz/templates.py +175 -0
- sqlsaber_viz/tools.py +234 -0
- sqlsaber_viz/transforms.py +155 -0
- sqlsaber_viz-0.1.1.dist-info/METADATA +12 -0
- sqlsaber_viz-0.1.1.dist-info/RECORD +16 -0
- sqlsaber_viz-0.1.1.dist-info/WHEEL +4 -0
- sqlsaber_viz-0.1.1.dist-info/entry_points.txt +2 -0
sqlsaber_viz/spec.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Pydantic models for visualization specs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Annotated, Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Binds one chart channel to a result column: `field` is the column name and
# `type` says how to interpret it (numeric unless stated otherwise).
class FieldEncoding(BaseModel):
    field: str
    type: Literal["category", "number", "time"] = "number"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Optional presentation settings shared by every chart kind. Every entry may be
# left unset (None); width and height are range-checked when provided.
class ChartOptions(BaseModel):
    width: int | None = Field(None, ge=20, le=200)
    height: int | None = Field(None, ge=10, le=100)
    x_label: str | None = None
    y_label: str | None = None
    color: str | None = None
    marker: str | None = None
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Column bindings for a bar chart; `series` optionally splits bars into groups.
class BarEncoding(BaseModel):
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Bar chart spec. The literal `type` tag is what the ChartSpec union
# discriminates on; layout defaults to vertical, grouped bars.
class BarChart(BaseModel):
    type: Literal["bar"]
    encoding: BarEncoding
    orientation: Literal["vertical", "horizontal"] = "vertical"
    mode: Literal["grouped", "stacked"] = "grouped"
    options: ChartOptions = Field(default_factory=ChartOptions)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Column bindings for a line chart; `series` optionally draws one line per group.
class LineEncoding(BaseModel):
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Line chart spec, tagged with type="line" for the ChartSpec union.
class LineChart(BaseModel):
    type: Literal["line"]
    encoding: LineEncoding
    options: ChartOptions = Field(default_factory=ChartOptions)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Column bindings for a scatter plot; `series` optionally colours points by group.
class ScatterEncoding(BaseModel):
    x: FieldEncoding
    y: FieldEncoding
    series: FieldEncoding | None = None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Scatter chart spec, tagged with type="scatter" for the ChartSpec union.
class ScatterChart(BaseModel):
    type: Literal["scatter"]
    encoding: ScatterEncoding
    options: ChartOptions = Field(default_factory=ChartOptions)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Boxplot inputs: the column that names each group and the column holding
# the values whose distribution is drawn.
class BoxplotConfig(BaseModel):
    label_field: str
    value_field: str
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Boxplot spec, tagged with type="boxplot" for the ChartSpec union.
class BoxplotChart(BaseModel):
    type: Literal["boxplot"]
    boxplot: BoxplotConfig
    options: ChartOptions = Field(default_factory=ChartOptions)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Histogram inputs: the column to bucket and the bucket count (2..100, default 20).
class HistogramConfig(BaseModel):
    field: str
    bins: int = Field(20, ge=2, le=100)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Histogram spec, tagged with type="histogram" for the ChartSpec union.
class HistogramChart(BaseModel):
    type: Literal["histogram"]
    histogram: HistogramConfig
    options: ChartOptions = Field(default_factory=ChartOptions)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Any supported chart. The `type` key of the incoming data selects which
# concrete model is used for validation.
ChartSpec = Annotated[
    BarChart | LineChart | ScatterChart | BoxplotChart | HistogramChart,
    Field(discriminator="type"),
]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# One sort key: a column name plus direction (ascending by default).
class SortItem(BaseModel):
    field: str
    dir: Literal["asc", "desc"] = "asc"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Transform step `{"sort": [...]}` holding an ordered list of sort keys.
class SortTransform(BaseModel):
    sort: list[SortItem]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# Transform step `{"limit": n}`; n must be at least 1.
class LimitTransform(BaseModel):
    limit: int = Field(ge=1)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# A single comparison `<field> <op> <value>`. `value` has no default: it must
# be supplied, although an explicit null is accepted.
class FilterConfig(BaseModel):
    field: str
    op: Literal["==", "!=", ">", "<", ">=", "<="]
    value: str | int | float | bool | None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# Transform step `{"filter": {...}}` wrapping one comparison.
class FilterTransform(BaseModel):
    filter: FilterConfig
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
# One data transform step. Each variant requires a different top-level key
# (`sort`, `limit` or `filter`).
Transform = SortTransform | LimitTransform | FilterTransform
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Where chart rows come from. Only keys shaped like `result_<id>.json`
# (id made of letters, digits, `.`, `_` or `-`) are accepted.
class DataSource(BaseModel):
    file: str = Field(pattern=r"^result_[A-Za-z0-9._-]+\.json$")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# Data section of a VizSpec; currently it only wraps the source reference.
class DataConfig(BaseModel):
    source: DataSource
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Top-level visualization spec (schema version "1"): where the rows come from,
# which chart to draw, and an ordered list of transforms (empty by default).
class VizSpec(BaseModel):
    version: Literal["1"] = "1"
    title: str | None = None
    description: str | None = None
    data: DataConfig
    chart: ChartSpec
    transform: list[Transform] = Field(default_factory=list)
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""Internal agent for generating visualization specs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from sqlsaber.agents.provider_factory import ProviderFactory
|
|
9
|
+
from sqlsaber.config import providers
|
|
10
|
+
from sqlsaber.config.settings import Config
|
|
11
|
+
|
|
12
|
+
from .prompts import VIZ_SYSTEM_PROMPT
|
|
13
|
+
from .spec import VizSpec
|
|
14
|
+
from .templates import ChartType, list_chart_types, vizspec_template
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SpecAgent:
    """Internal agent for generating visualization specs.

    Builds an agent through ``ProviderFactory``, gives it the visualization
    system prompt plus two template helper tools, and turns a user request and
    a data summary into a validated ``VizSpec``.
    """

    def __init__(self, model_name: str | None = None, api_key: str | None = None):
        self.config = Config()
        # Explicit arguments take precedence over the values held in Config.
        self._model_name_override = model_name
        self._api_key_override = api_key
        self.agent = self._build_agent()

    def _build_agent(self):
        """Create the underlying agent and attach the system prompt and tools.

        The model string may have the form ``"<provider>:<model>"``; the text
        after the first colon is passed on as the bare model name.
        """
        model_name = self._model_name_override or self.config.model.name
        model_name_only = (
            model_name.split(":", 1)[1] if ":" in model_name else model_name
        )

        # Stored credentials are validated unless both the model name and the
        # API key were supplied explicitly.
        if not (self._model_name_override and self._api_key_override):
            self.config.auth.validate(model_name)

        provider = providers.provider_from_model(model_name) or ""
        api_key = self._api_key_override or self.config.auth.get_api_key(model_name)

        factory = ProviderFactory()
        agent = factory.create_agent(
            provider=provider,
            model_name=model_name_only,
            full_model_str=model_name,
            api_key=api_key,
            thinking_enabled=False,
        )

        @agent.system_prompt
        def viz_system_prompt() -> str:
            return VIZ_SYSTEM_PROMPT

        self._register_tools(agent)

        return agent

    def _register_tools(self, agent) -> None:
        """Register visualization helper tools on the agent."""

        # NOTE: the tool docstrings below are shown to the model as tool
        # descriptions; keep their wording stable.
        @agent.tool_plain
        def get_vizspec_template(chart_type: ChartType, file: str) -> dict:
            """Get the complete VizSpec template for a chart type.

            Call this FIRST to get the correct JSON structure, then fill in
            the placeholder field names with actual column names from your data.

            Args:
                chart_type: One of "bar", "line", "scatter", "boxplot", "histogram"
                file: The result file key (e.g., "result_abc123.json")

            Returns:
                A complete VizSpec template with placeholders for field names.
            """
            return vizspec_template(chart_type, file)

        @agent.tool_plain
        def get_available_chart_types() -> list[dict]:
            """List available chart types with descriptions.

            Call this if you're unsure which chart type to use for the data.

            Returns:
                List of chart types with descriptions and use cases.
            """
            return list_chart_types()

    async def generate_spec(
        self,
        request: str,
        columns: list[dict],
        row_count: int,
        file: str,
        chart_type_hint: str | None = None,
    ) -> VizSpec:
        """Generate a VizSpec from user request and data summary.

        Raises:
            json.JSONDecodeError: if the model output holds no JSON object.
            pydantic.ValidationError: if the JSON does not match ``VizSpec``.
        """

        prompt = self._build_prompt(
            request=request,
            columns=columns,
            row_count=row_count,
            file=file,
            chart_type_hint=chart_type_hint,
        )

        result = await self.agent.run(prompt)
        output = str(result.output).strip()
        parsed = _parse_json(output)
        return VizSpec.model_validate(parsed)

    def _build_prompt(
        self,
        request: str,
        columns: list[dict],
        row_count: int,
        file: str,
        chart_type_hint: str | None,
    ) -> str:
        """Assemble the user prompt sent to the spec agent.

        Sections are separated by one blank line. The chart-type hint section
        is included only when a hint is given, so an absent hint no longer
        leaves a run of empty lines in the prompt.
        """
        columns_json = json.dumps(columns, ensure_ascii=False, indent=2)

        sections = [
            f"## User Request\n{request.strip()}",
            "## Data Summary\n"
            f"Row count: {row_count}\n"
            f"File: {file}\n"
            f"Columns:\n{columns_json}",
        ]
        if chart_type_hint:
            sections.append(f"Chart type hint: {chart_type_hint}")
        sections.append(
            "Use `get_vizspec_template` to get the correct spec structure, "
            "then fill in the placeholders with actual column names.\n"
            "Return ONLY the final JSON."
        )
        return "\n\n".join(sections).strip()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _parse_json(text: str) -> dict[str, Any]:
|
|
134
|
+
try:
|
|
135
|
+
parsed = json.loads(text)
|
|
136
|
+
except json.JSONDecodeError:
|
|
137
|
+
start = text.find("{")
|
|
138
|
+
end = text.rfind("}")
|
|
139
|
+
if start == -1 or end == -1 or end <= start:
|
|
140
|
+
raise
|
|
141
|
+
parsed = json.loads(text[start : end + 1])
|
|
142
|
+
if not isinstance(parsed, dict):
|
|
143
|
+
raise json.JSONDecodeError("Expected JSON object", text, 0)
|
|
144
|
+
return parsed
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Template builders for visualization specs.
|
|
2
|
+
|
|
3
|
+
These functions generate minimal valid templates from Pydantic models,
|
|
4
|
+
ensuring they stay in sync with the schema definitions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from collections.abc import Callable
from typing import Literal
|
|
10
|
+
|
|
11
|
+
from .spec import (
|
|
12
|
+
BarChart,
|
|
13
|
+
BarEncoding,
|
|
14
|
+
BoxplotChart,
|
|
15
|
+
BoxplotConfig,
|
|
16
|
+
ChartSpec,
|
|
17
|
+
ChartOptions,
|
|
18
|
+
DataConfig,
|
|
19
|
+
DataSource,
|
|
20
|
+
FieldEncoding,
|
|
21
|
+
HistogramChart,
|
|
22
|
+
HistogramConfig,
|
|
23
|
+
LineChart,
|
|
24
|
+
LineEncoding,
|
|
25
|
+
ScatterChart,
|
|
26
|
+
ScatterEncoding,
|
|
27
|
+
VizSpec,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Chart kinds for which a template can be produced.
ChartType = Literal["bar", "line", "scatter", "boxplot", "histogram"]

# Placeholder column names written into templates. The spec agent is told to
# replace them with real column names from the query result.
_CATEGORY_PLACEHOLDER = "<category_column>"
_NUMBER_PLACEHOLDER = "<number_column>"
_TIME_PLACEHOLDER = "<time_column>"
_LABEL_PLACEHOLDER = "<label_column>"
_VALUE_PLACEHOLDER = "<value_column>"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _build_bar_chart() -> BarChart:
    """Return a vertical, grouped bar chart with placeholder x/y columns."""
    encoding = BarEncoding(
        x=FieldEncoding(field=_CATEGORY_PLACEHOLDER, type="category"),
        y=FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number"),
        series=None,
    )
    return BarChart(
        type="bar",
        encoding=encoding,
        orientation="vertical",
        mode="grouped",
        options=ChartOptions(),
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _build_line_chart() -> LineChart:
    """Return a line chart with a placeholder time x-axis and numeric y-axis."""
    encoding = LineEncoding(
        x=FieldEncoding(field=_TIME_PLACEHOLDER, type="time"),
        y=FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number"),
        series=None,
    )
    return LineChart(type="line", encoding=encoding, options=ChartOptions())
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _build_scatter_chart() -> ScatterChart:
    """Return a scatter chart whose x and y are both placeholder numeric columns."""
    encoding = ScatterEncoding(
        x=FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number"),
        y=FieldEncoding(field=_NUMBER_PLACEHOLDER, type="number"),
        series=None,
    )
    return ScatterChart(type="scatter", encoding=encoding, options=ChartOptions())
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _build_boxplot_chart() -> BoxplotChart:
    """Return a boxplot with placeholder label and value columns."""
    config = BoxplotConfig(
        label_field=_LABEL_PLACEHOLDER,
        value_field=_VALUE_PLACEHOLDER,
    )
    return BoxplotChart(type="boxplot", boxplot=config, options=ChartOptions())
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _build_histogram_chart() -> HistogramChart:
    """Return a 20-bin histogram over a placeholder numeric column."""
    config = HistogramConfig(field=_NUMBER_PLACEHOLDER, bins=20)
    return HistogramChart(type="histogram", histogram=config, options=ChartOptions())
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
# Dispatch table: chart type -> zero-argument builder returning that chart's
# placeholder model. Annotated with Callable (the builtin `callable` is a
# predicate function, not a type).
_CHART_BUILDERS: dict[ChartType, Callable[[], ChartSpec]] = {
    "bar": _build_bar_chart,
    "line": _build_line_chart,
    "scatter": _build_scatter_chart,
    "boxplot": _build_boxplot_chart,
    "histogram": _build_histogram_chart,
}
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _build_chart(chart_type: ChartType) -> ChartSpec:
    """Instantiate the placeholder chart model registered for ``chart_type``.

    Raises:
        ValueError: if no builder is registered for ``chart_type``.
    """
    if (make_chart := _CHART_BUILDERS.get(chart_type)) is not None:
        return make_chart()
    raise ValueError(f"Unknown chart type: {chart_type}")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def chart_template(chart_type: ChartType) -> dict:
    """Return a minimal valid chart template for the given chart type.

    Field names are placeholders that the model is expected to replace with
    real column names. ``None``-valued fields (e.g. an unset ``series``) are
    omitted from the returned dict.
    """
    chart = _build_chart(chart_type)
    return chart.model_dump(exclude_none=True)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def vizspec_template(chart_type: ChartType, file: str) -> dict:
    """Return a complete VizSpec template with the data source filled in.

    The chart section is the placeholder template for ``chart_type``; title,
    description and other ``None`` fields are dropped from the result.

    Raises:
        pydantic.ValidationError: if ``file`` does not satisfy the
            ``DataSource.file`` pattern (checked before the chart is built).
        ValueError: if ``chart_type`` is unknown.
    """
    source = DataSource(file=file)
    chart = _build_chart(chart_type)
    template = VizSpec(
        version="1",
        title=None,
        description=None,
        data=DataConfig(source=source),
        chart=chart,
        transform=[],
    )
    return template.model_dump(exclude_none=True)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def list_chart_types() -> list[dict]:
    """Describe each supported chart type and when to use it.

    Returned to the model so it can pick a suitable chart for the data; a fresh
    list of fresh dicts is built on every call.
    """
    catalogue = (
        (
            "bar",
            "Compare categories. Use x for category, y for numeric value.",
            "Comparing values across categories (e.g., sales by region)",
        ),
        (
            "line",
            "Show trends over time/sequence. Use x for time/sequence, y for value.",
            "Showing change over time (e.g., monthly revenue)",
        ),
        (
            "scatter",
            "Show correlation between two numeric variables.",
            "Exploring relationship between two numbers (e.g., age vs income)",
        ),
        (
            "boxplot",
            "Show distribution of values across groups.",
            "Comparing distributions (e.g., salary by department)",
        ),
        (
            "histogram",
            "Show distribution of a single numeric variable.",
            "Understanding value distribution (e.g., age distribution)",
        ),
    )
    return [
        {"type": kind, "description": description, "use_when": use_when}
        for kind, description, use_when in catalogue
    ]
|
sqlsaber_viz/tools.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Visualization tool implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import re
|
|
8
|
+
from html import escape
|
|
9
|
+
|
|
10
|
+
from pydantic import ValidationError
|
|
11
|
+
from pydantic_ai import RunContext
|
|
12
|
+
from rich.console import Console
|
|
13
|
+
from rich.text import Text
|
|
14
|
+
|
|
15
|
+
from sqlsaber.tools.base import Tool
|
|
16
|
+
from sqlsaber.utils.json_utils import json_dumps
|
|
17
|
+
|
|
18
|
+
from .data_loader import (
|
|
19
|
+
extract_data_summary,
|
|
20
|
+
find_tool_output_in_messages,
|
|
21
|
+
find_tool_output_payload,
|
|
22
|
+
)
|
|
23
|
+
from .renderers.plotext_renderer import PlotextRenderer
|
|
24
|
+
from .spec import BarChart, LimitTransform, SortItem, SortTransform, VizSpec
|
|
25
|
+
from .transforms import apply_transforms
|
|
26
|
+
|
|
27
|
+
# Valid result keys look like `result_<id>.json`. `\Z` anchors at the true end
# of the string; Python's `$` would also match just before a trailing newline,
# letting "result_x.json\n" through.
TOOL_OUTPUT_FILE_PATTERN = re.compile(r"^result_[A-Za-z0-9._-]+\.json\Z")
# Upper bound, in seconds, on a single spec-generation run.
SPEC_TIMEOUT_SECONDS = 300
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class VizTool(Tool):
    """Terminal visualization tool for SQL results.

    ``execute`` asks an internal spec agent for a ``VizSpec`` over a previous
    ``execute_sql`` result; ``render_result`` and ``render_result_html`` draw
    that spec with ``PlotextRenderer``.
    """

    requires_ctx = True

    # Bar charts over results with more rows than this get a default limit.
    _BAR_ROW_LIMIT = 20

    def __init__(self):
        super().__init__()
        # Context, rows and file key captured by the most recent `execute`;
        # rendering reuses the rows when the spec points at the same file.
        self._last_ctx: RunContext | None = None
        self._last_rows: list[dict] | None = None
        self._last_file: str | None = None
        # History used to look up result rows when replaying a conversation.
        self._replay_messages: list | None = None

    def set_replay_messages(self, messages: list) -> None:
        """Set message history for replay scenarios (e.g., threads show)."""
        self._replay_messages = messages

    @property
    def name(self) -> str:
        return "viz"

    def render_executing(self, console: Console, args: dict) -> bool:
        """Suppress default JSON rendering during execution."""
        return True

    async def execute(
        self,
        ctx: RunContext,
        request: str,
        file: str,
        chart_type: str | None = None,
    ) -> str:
        """Generate a visualization spec for SQL results.

        Args:
            request: Natural language description of the desired visualization.
            file: Result file key from execute_sql (e.g., "result_abc123.json").
            chart_type: Optional hint for chart type (bar, line, scatter, boxplot, histogram).

        Returns:
            JSON string containing the visualization spec.
        """
        self._last_ctx = ctx

        # fullmatch: the whole key must match (no trailing newline or junk).
        if not file or not TOOL_OUTPUT_FILE_PATTERN.fullmatch(file):
            return json_dumps({"error": "Invalid result file key format."})

        payload = find_tool_output_payload(ctx, self._tool_call_id(file))
        if payload is None:
            return json_dumps({"error": "Tool output not found in message history."})

        summary = extract_data_summary(payload)
        columns = summary.get("columns", [])
        row_count = summary.get("row_count", 0)
        rows = summary.get("rows", [])

        # Remember the rows so rendering can skip another history lookup.
        self._last_rows = rows
        self._last_file = file

        agent = _get_spec_agent_cls()()

        try:
            spec = await asyncio.wait_for(
                agent.generate_spec(
                    request=request,
                    columns=columns,
                    row_count=row_count,
                    file=file,
                    chart_type_hint=chart_type,
                ),
                timeout=SPEC_TIMEOUT_SECONDS,
            )
            spec = self._ensure_bar_defaults(spec, row_count)
            return json_dumps(spec.model_dump())
        except asyncio.TimeoutError:
            return json_dumps(
                {
                    "error": "Spec generation timed out.",
                    "details": f"Timed out after {SPEC_TIMEOUT_SECONDS} seconds.",
                }
            )
        except (ValidationError, json.JSONDecodeError, ValueError) as exc:
            return json_dumps(
                {
                    "error": "Failed to generate a valid visualization spec.",
                    "details": str(exc),
                }
            )

    def render_result(self, console: Console, result: object) -> bool:
        """Render the spec as a terminal chart using plotext.

        Returns False when ``result`` is not a valid spec (so the caller can
        fall back to default rendering) and True otherwise.
        """
        spec = self._parse_spec(result)
        if spec is None:
            return False

        chart = self._render_chart(spec)
        if chart is None:
            if console.is_terminal:
                console.print("[warning]No data available for visualization.[/warning]")
            else:
                console.print("*No data available for visualization.*\n")
            return True

        if console.is_terminal:
            console.print(Text.from_ansi(chart))
        else:
            # Non-terminal output gets the plain chart inside a code fence.
            console.print(f"```\n{self._strip_ansi(chart)}\n```\n", markup=False)
        return True

    def render_result_html(self, result: object) -> str | None:
        """Render the spec as an HTML chart (escaped plotext text in a <pre>)."""
        spec = self._parse_spec(result)
        if spec is None:
            return None

        chart = self._render_chart(spec)
        if chart is None:
            return '<div class="viz-error">No data available for visualization.</div>'
        return f'<pre class="viz-chart">{escape(self._strip_ansi(chart))}</pre>'

    def _render_chart(self, spec: VizSpec) -> str | None:
        """Resolve, transform and plot the rows for ``spec``; None if no rows."""
        rows = self._resolve_rows(spec)
        if rows is None:
            return None
        rows = apply_transforms(rows, spec.transform)
        return PlotextRenderer().render(spec, rows)

    def _parse_spec(self, result: object) -> VizSpec | None:
        """Return the VizSpec in ``result``, or None for errors/invalid specs."""
        data = self._parse_result(result)
        if not isinstance(data, dict):
            return None
        if "error" in data and data["error"]:
            return None
        try:
            return VizSpec.model_validate(data)
        except ValidationError:
            return None

    def _parse_result(self, result: object) -> object:
        """Normalize a tool result to a decoded JSON value or an error dict."""
        if isinstance(result, dict):
            return result
        if isinstance(result, str):
            try:
                return json.loads(result)
            except json.JSONDecodeError:
                return {"error": result}
        return {"error": str(result)}

    def _strip_ansi(self, text: str) -> str:
        """Remove ANSI SGR (colour/style) escape sequences from ``text``."""
        return re.sub(r"\x1b\[[0-9;]*m", "", text)

    @staticmethod
    def _tool_call_id(file: str) -> str:
        """Derive the originating tool-call id from a ``result_<id>.json`` key."""
        return file.removeprefix("result_").removesuffix(".json")

    def _resolve_rows(self, spec: VizSpec) -> list[dict] | None:
        """Find the data rows referenced by ``spec.data.source.file``.

        Uses the rows cached by ``execute`` when the file matches. Otherwise it
        searches the live run context, then the replay history, and returns
        None if neither holds the tool output.
        """
        if self._last_rows is not None and self._last_file == spec.data.source.file:
            return self._last_rows

        tool_call_id = self._tool_call_id(spec.data.source.file)

        payload: dict | None = None
        if self._last_ctx is not None:
            payload = find_tool_output_payload(self._last_ctx, tool_call_id)
        # Fall back to replayed history when the live context has no match.
        if payload is None and self._replay_messages is not None:
            payload = find_tool_output_in_messages(self._replay_messages, tool_call_id)

        if payload is None:
            return None
        rows = extract_data_summary(payload).get("rows")
        return rows if isinstance(rows, list) else None

    def _ensure_bar_defaults(self, spec: VizSpec, row_count: int) -> VizSpec:
        """Add default transforms to bar-chart specs.

        When no sort is present, a descending sort on the y field is appended.
        When no limit is present and the result has more than
        ``_BAR_ROW_LIMIT`` rows, a limit of that size is appended. Other chart
        types, and specs that already have both, are returned unchanged.
        """
        if not isinstance(spec.chart, BarChart):
            return spec

        transforms = list(spec.transform)
        has_limit = any(isinstance(t, LimitTransform) for t in transforms)
        has_sort = any(isinstance(t, SortTransform) for t in transforms)

        if not has_sort:
            transforms.append(
                SortTransform(
                    sort=[SortItem(field=spec.chart.encoding.y.field, dir="desc")]
                )
            )

        if not has_limit and row_count > self._BAR_ROW_LIMIT:
            transforms.append(LimitTransform(limit=self._BAR_ROW_LIMIT))

        if transforms != spec.transform:
            return spec.model_copy(update={"transform": transforms})

        return spec
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _get_spec_agent_cls():
    """Return the ``SpecAgent`` class.

    The import runs at call time rather than at module import, so importing
    this module does not import ``spec_agent``.
    """
    from .spec_agent import SpecAgent as spec_agent_cls

    return spec_agent_cls
|