sqlsaber-viz 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,19 @@
1
+ """SQLSaber visualization plugin."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ if TYPE_CHECKING:
6
+ from sqlsaber.tools.registry import ToolRegistry
7
+
8
+
9
def register_tools(registry: "ToolRegistry | None" = None):
    """Return the tool classes this visualization plugin contributes.

    sqlsaber performs the actual registration using the returned classes.
    ``registry`` is accepted for hook-signature compatibility and is not
    consulted here.
    """
    # Deferred import of the tool module; presumably to keep importing the
    # plugin package lightweight / free of import cycles (TODO confirm).
    from .tools import VizTool

    tool_classes = [VizTool]
    return tool_classes
17
+
18
+
19
+ __all__ = ["register_tools"]
@@ -0,0 +1,143 @@
1
+ """Helpers for loading SQL result payloads and extracting summaries."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from datetime import date, datetime, time
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from pydantic_ai import RunContext
11
+
12
+
13
def find_tool_output_payload(ctx: "RunContext", tool_call_id: str) -> dict | None:
    """Look up the output of ``tool_call_id`` in the message history of ``ctx``.

    Delegates to ``find_tool_output_in_messages`` using ``ctx.messages``.

    Returns:
        The decoded tool output as a dict, or None when no matching tool
        return is present.
    """
    history = ctx.messages
    return find_tool_output_in_messages(history, tool_call_id)
16
+
17
+
18
def find_tool_output_in_messages(messages: list, tool_call_id: str) -> dict | None:
    """Find the output of a tool call in a list of ModelMessage objects.

    Messages are scanned newest-first, and each message's ``parts`` are
    checked in order. A part matches when its ``part_kind`` is "tool-return"
    or "builtin-tool-return" and its ``tool_call_id`` equals ``tool_call_id``.
    Attributes are read with ``getattr``, so objects lacking them are skipped
    rather than raising.

    Args:
        messages: Message objects, each expected to expose a ``parts`` iterable.
        tool_call_id: Identifier of the tool call whose output is wanted.

    Returns:
        For the first matching part with usable content:

        - dict content is returned unchanged;
        - str content is JSON-decoded: a decoded dict is returned as-is, any
          other decoded value is wrapped as ``{"result": value}``, and text
          that is not valid JSON is wrapped as ``{"result": text}``;
        - list content is wrapped as ``{"result": content}``, the same shape
          a JSON-encoded list produces.

        Matching parts with any other content type are skipped. None is
        returned when nothing usable is found.
    """
    return_kinds = ("tool-return", "builtin-tool-return")
    for message in reversed(messages):
        for part in getattr(message, "parts", []):
            if getattr(part, "part_kind", "") not in return_kinds:
                continue
            if getattr(part, "tool_call_id", None) != tool_call_id:
                continue
            content = getattr(part, "content", None)
            if isinstance(content, dict):
                return content
            if isinstance(content, list):
                # Keep parity with a JSON string that decodes to a list.
                return {"result": content}
            if isinstance(content, str):
                try:
                    parsed = json.loads(content)
                except json.JSONDecodeError:
                    return {"result": content}
                return parsed if isinstance(parsed, dict) else {"result": parsed}
    return None
41
+
42
+
43
def extract_data_summary(payload: dict) -> dict:
    """Summarize an SQL result payload for visualization.

    ``payload["results"]`` is used as the row list when it is a list (each
    entry normalized by ``_coerce_rows``); otherwise there are no rows.
    ``payload["row_count"]`` is reported when it is an int, falling back to
    the number of rows.

    Returns:
        {
            "columns": [
                {"name": "col1", "type": "string", "sample": ["a", "b", "c"]},
                {"name": "col2", "type": "number", "sample": [1, 2, 3]},
            ],
            "row_count": 150,
            "rows": [...]  # Full rows for rendering
        }
    """
    raw_results = payload.get("results")
    rows = _coerce_rows(raw_results) if isinstance(raw_results, list) else []

    reported_count = payload.get("row_count")
    row_count = reported_count if isinstance(reported_count, int) else len(rows)

    return {
        "columns": _extract_columns(rows),
        "row_count": row_count,
        "rows": rows,
    }
65
+
66
+
67
def infer_column_type(values: list[object]) -> str:
    """Infer a coarse column type from sample values.

    ``None`` entries are ignored. The checks run in order, and the first one
    that holds for every remaining value wins. Booleans are tested before
    numbers because ``bool`` is a subclass of ``int``.

    Returns:
        "null" when no non-None values remain; otherwise "boolean", "number",
        "time", or "string" (the fallback).
    """
    import numbers
    from decimal import Decimal

    cleaned = [value for value in values if value is not None]
    if not cleaned:
        return "null"

    if all(isinstance(value, bool) for value in cleaned):
        return "boolean"

    # numbers.Real covers int/float/Fraction and ABC-registered numeric scalars.
    # Decimal is not a Real, but some drivers return it for NUMERIC columns.
    if all(isinstance(value, (numbers.Real, Decimal)) for value in cleaned):
        return "number"

    if all(_is_time_value(value) for value in cleaned):
        return "time"

    return "string"
87
+
88
+
89
+ def _extract_columns(rows: list[dict[str, object]]) -> list[dict[str, object]]:
90
+ if not rows:
91
+ return []
92
+
93
+ # Use the union of keys from the first 50 rows to avoid missing sparse columns.
94
+ keys: list[str] = []
95
+ seen: set[str] = set()
96
+ for row in rows[:50]:
97
+ for key in row.keys():
98
+ if key not in seen:
99
+ seen.add(key)
100
+ keys.append(key)
101
+
102
+ columns: list[dict[str, object]] = []
103
+ for key in keys:
104
+ sample_values = [row.get(key) for row in rows[:20] if key in row]
105
+ column_type = infer_column_type(sample_values)
106
+ columns.append(
107
+ {
108
+ "name": key,
109
+ "type": column_type,
110
+ "sample": sample_values[:5],
111
+ }
112
+ )
113
+
114
+ return columns
115
+
116
+
117
+ def _coerce_rows(rows: list[object]) -> list[dict[str, object]]:
118
+ coerced: list[dict[str, object]] = []
119
+ for row in rows:
120
+ if isinstance(row, dict):
121
+ coerced.append({str(key): value for key, value in row.items()})
122
+ else:
123
+ coerced.append({"value": row})
124
+ return coerced
125
+
126
+
127
+ def _is_time_value(value: object) -> bool:
128
+ if isinstance(value, (datetime, date, time)):
129
+ return True
130
+ if isinstance(value, str):
131
+ normalized = value
132
+ if value.endswith("Z"):
133
+ normalized = value[:-1] + "+00:00"
134
+ try:
135
+ datetime.fromisoformat(normalized)
136
+ return True
137
+ except ValueError:
138
+ try:
139
+ time.fromisoformat(normalized)
140
+ return True
141
+ except ValueError:
142
+ return False
143
+ return False
@@ -0,0 +1,31 @@
1
+ """Prompt definitions for viz spec generation."""
2
+
3
# System prompt for the visualization spec generator. It instructs the model
# to call a `get_vizspec_template` tool (defined outside this module) first,
# then to reply with bare JSON only. The text is sent to the model verbatim,
# so any wording change alters generation behavior; edit deliberately.
VIZ_SYSTEM_PROMPT = """You are a visualization spec generator. Given a user's request and data summary, generate a valid JSON visualization spec.

## Workflow
1. Decide the appropriate chart type based on the request and data
2. Call `get_vizspec_template` with the chart type and file to get the correct spec structure
3. Fill in the template with actual column names from the provided data summary
4. Return ONLY the final JSON spec (no explanations, no markdown code blocks)

## Chart Type Selection
- Comparing categories → bar
- Comparing categories across series → bar with encoding.series
- Trend over time → line
- Correlation between two numbers → scatter
- Distribution of one variable → histogram
- Distribution comparison across groups → boxplot

## Transform Operations (optional, add to "transform" array)
- {"sort": [{"field": "col", "dir": "desc"}]} - Sort data
- {"limit": 20} - Limit rows (recommended for bar charts with many categories)
- {"filter": {"field": "col", "op": "!=", "value": null}} - Filter rows

## Rules
1. ALWAYS call `get_vizspec_template` first to get the correct structure
2. Use ONLY columns that exist in the provided data summary
3. Match field types: category columns for x in bar charts, numeric columns for y
4. Add limit transform for bar charts to avoid overcrowding (10-20 bars max)
5. Sort bar charts by y value descending for better readability
6. Title should describe what the chart shows
"""
@@ -0,0 +1,6 @@
1
+ """Renderer exports for SQLSaber viz."""
2
+
3
+ from .base import RendererProtocol
4
+ from .plotext_renderer import PlotextRenderer
5
+
6
+ __all__ = ["RendererProtocol", "PlotextRenderer"]
@@ -0,0 +1,13 @@
1
+ """Renderer protocol for visualization outputs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+ from ..spec import VizSpec
8
+
9
+
10
class RendererProtocol(Protocol):
    """Structural interface implemented by visualization renderers.

    Any object with a compatible ``render`` method satisfies this protocol;
    inheriting from it is not required.
    """

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Render a visualization spec with data rows.

        Args:
            spec: The visualization spec to draw.
            rows: Result rows, one dict per row.

        Returns:
            The rendered visualization as a string.
        """
@@ -0,0 +1,17 @@
1
+ """Placeholder HTML renderer for future web UI support."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from ..spec import VizSpec
6
+
7
+
8
class HtmlRenderer:
    """Render a VizSpec as HTML.

    Stub reserved for a future web UI: it ignores its inputs and always
    returns an empty string.
    """

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Return HTML for ``spec`` over ``rows``; currently always ``""``."""
        del spec, rows  # Unused until a real HTML backend exists.
        return ""
@@ -0,0 +1,385 @@
1
+ """Plotext renderer for terminal charts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from collections import defaultdict
7
+ from datetime import datetime, time
8
+ from typing import Iterable
9
+
10
+ from ..spec import (
11
+ BarChart,
12
+ BoxplotChart,
13
+ HistogramChart,
14
+ LineChart,
15
+ ScatterChart,
16
+ VizSpec,
17
+ )
18
+
19
+
20
class PlotextRenderer:
    """Render a VizSpec to a terminal chart string using plotext.

    Bar, line, scatter, boxplot and histogram charts are supported. Problems
    are reported in-band: ``render`` returns a bracketed message (for example
    ``"[No data: ...]"``) instead of raising.
    """

    # Colors assigned, cycling in order, to each series of a multi-series
    # line or scatter chart.
    _series_colors = [
        "cyan+",
        "yellow+",
        "red+",
        "green+",
        "blue+",
        "magenta+",
        "white+",
    ]
    _default_width = 80
    _default_height = 25

    # Color names accepted from spec options. Anything else falls back to the
    # caller's default, so an unknown name is never handed to plotext.
    _valid_colors = frozenset(
        {
            "black", "red", "green", "yellow",
            "blue", "magenta", "cyan", "white",
            "black+", "red+", "green+", "yellow+",
            "blue+", "magenta+", "cyan+", "white+",
        }
    )
    # Marker names accepted from spec options.
    _valid_markers = frozenset(
        {"sd", "dot", "hd", "fhd", "braille", "heart", "point"}
    )
    # Bare year-month strings such as "2023-06".
    _year_month_re = re.compile(r"^\d{4}-\d{2}$")

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Render spec with data to an ASCII chart string.

        Args:
            spec: Visualization spec. ``spec.chart`` selects the chart type
                and carries the encoding and display options.
            rows: Result rows, one dict per row.

        Returns:
            The chart string from ``plt.build()``. A bracketed message is
            returned instead when the chart type is unsupported, when the data
            has no usable values, or when plotext raises while configuring,
            drawing or building the chart.
        """
        import plotext as plt

        # plotext draws through module-level figure state; reset it so a
        # previous chart does not bleed into this one.
        plt.clf()
        plt.clear_figure()

        chart = spec.chart
        options = chart.options

        try:
            plt.plot_size(
                width=options.width or self._default_width,
                height=options.height or self._default_height,
            )
            if spec.title:
                plt.title(spec.title)

            if isinstance(chart, BarChart):
                error_msg = self._render_bar(chart, rows, plt)
            elif isinstance(chart, LineChart):
                error_msg = self._render_line(chart, rows, plt)
            elif isinstance(chart, ScatterChart):
                error_msg = self._render_scatter(chart, rows, plt)
            elif isinstance(chart, BoxplotChart):
                error_msg = self._render_boxplot(chart, rows, plt)
            elif isinstance(chart, HistogramChart):
                error_msg = self._render_histogram(chart, rows, plt)
            else:
                return f"[Unsupported chart type: {type(chart).__name__}]"

            if error_msg:
                return error_msg

            if options.x_label:
                plt.xlabel(options.x_label)
            if options.y_label:
                plt.ylabel(options.y_label)

            # build() is inside the try so layout/build failures are reported
            # like drawing failures, as the docstring promises.
            return plt.build()
        except Exception as e:
            return f"[Chart rendering error: {e}]"

    def _render_bar(self, chart: BarChart, rows: list[dict], plt) -> str | None:
        """Draw a bar chart that sums y values per x category.

        With a series encoding, bars are stacked when ``chart.mode`` is
        "stacked" and grouped otherwise. Returns an error message when no row
        has a numeric y value, else None.
        """
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None
        orientation = "h" if chart.orientation == "horizontal" else "v"

        if series_field:
            categories, series_names, series_values = self._build_series_matrix(
                rows, x_field, y_field, series_field
            )
            if not categories or not series_names:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            draw = plt.stacked_bar if chart.mode == "stacked" else plt.multiple_bar
            draw(
                categories,
                series_values,
                labels=series_names,
                orientation=orientation,
            )
            return None

        # Sum duplicate categories, matching the aggregation of the series path.
        totals: dict[str, float] = {}
        for row in rows:
            value = self._to_number(row.get(y_field))
            if value is None:
                continue
            category = str(row.get(x_field, ""))
            totals[category] = totals.get(category, 0.0) + value

        if not totals:
            return f"[No data: no valid numeric values for '{y_field}']"

        color = self._safe_color(chart.options.color, "blue+")
        plt.bar(list(totals), list(totals.values()), color=color, orientation=orientation)
        return None

    def _render_line(self, chart: LineChart, rows: list[dict], plt) -> str | None:
        """Draw one line, or one line per series, with points sorted by x.

        Returns an error message when no series has a usable (x, y) pair.
        """
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None
        marker = self._safe_marker(chart.options.marker, "braille")

        if series_field:
            any_plotted = False
            series_map = self._group_series(rows, series_field)
            for idx, (series_name, series_rows) in enumerate(series_map.items()):
                x, y = self._extract_xy_sorted(series_rows, x_field, y_field)
                if not x or not y:
                    continue
                any_plotted = True
                color = self._series_colors[idx % len(self._series_colors)]
                plt.plot(x, y, color=color, marker=marker, label=series_name)
            if not any_plotted:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            return None

        x, y = self._extract_xy_sorted(rows, x_field, y_field)
        if not x or not y:
            return f"[No data: no valid values for '{x_field}' / '{y_field}']"
        color = self._safe_color(chart.options.color, "cyan+")
        plt.plot(x, y, color=color, marker=marker)
        return None

    def _render_scatter(self, chart: ScatterChart, rows: list[dict], plt) -> str | None:
        """Draw a scatter plot, optionally colored per series.

        Returns an error message when no series has a usable (x, y) pair.
        """
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None
        marker = self._safe_marker(chart.options.marker, "dot")

        if series_field:
            any_plotted = False
            series_map = self._group_series(rows, series_field)
            for idx, (series_name, series_rows) in enumerate(series_map.items()):
                x, y = self._extract_xy(series_rows, x_field, y_field)
                if not x or not y:
                    continue
                any_plotted = True
                color = self._series_colors[idx % len(self._series_colors)]
                plt.scatter(x, y, color=color, marker=marker, label=series_name)
            if not any_plotted:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            return None

        x, y = self._extract_xy(rows, x_field, y_field)
        if not x or not y:
            return f"[No data: no valid values for '{x_field}' / '{y_field}']"
        color = self._safe_color(chart.options.color, "red+")
        plt.scatter(x, y, color=color, marker=marker)
        return None

    def _render_boxplot(self, chart: BoxplotChart, rows: list[dict], plt) -> str | None:
        """Draw one box per label, built from that label's numeric values.

        Returns an error message when no row has a numeric value.
        """
        label_field = chart.boxplot.label_field
        value_field = chart.boxplot.value_field

        groups: dict[str, list[float]] = {}
        for row in rows:
            value = self._to_number(row.get(value_field))
            if value is None:
                continue
            groups.setdefault(str(row.get(label_field, "")), []).append(value)

        if not groups:
            return f"[No data: no valid numeric values for '{value_field}']"

        labels = list(groups)
        plt.box(labels, [groups[label] for label in labels])
        return None

    def _render_histogram(self, chart: HistogramChart, rows: list[dict], plt) -> str | None:
        """Draw a histogram of one numeric field.

        Returns an error message when the field has no numeric values.
        """
        field = chart.histogram.field
        bins = chart.histogram.bins

        values = [
            number
            for number in (self._to_number(row.get(field)) for row in rows)
            if number is not None
        ]
        if not values:
            return f"[No data: no valid numeric values for '{field}']"

        color = self._safe_color(chart.options.color, "green+")
        # NOTE(review): bins may be optional in the spec (TODO confirm). When it
        # is None, omit it so plotext uses its own default instead of None.
        if bins is None:
            plt.hist(values, color=color)
        else:
            plt.hist(values, bins=bins, color=color)
        return None

    def _extract_xy(
        self, rows: Iterable[dict], x_field: str, y_field: str
    ) -> tuple[list[float], list[float]]:
        """Collect (x, y) pairs in row order, skipping rows where either is non-numeric."""
        x: list[float] = []
        y: list[float] = []
        for row in rows:
            x_val = self._to_number(row.get(x_field))
            y_val = self._to_number(row.get(y_field))
            if x_val is None or y_val is None:
                continue
            x.append(x_val)
            y.append(y_val)
        return x, y

    def _extract_xy_sorted(
        self, rows: Iterable[dict], x_field: str, y_field: str
    ) -> tuple[list[float], list[float]]:
        """Extract x/y pairs and sort them by x (stable) for line rendering."""
        x, y = self._extract_xy(rows, x_field, y_field)
        pairs = sorted(zip(x, y), key=lambda p: p[0])
        return [p[0] for p in pairs], [p[1] for p in pairs]

    def _group_series(
        self, rows: Iterable[dict], series_field: str
    ) -> dict[str, list[dict]]:
        """Group rows by ``str(row[series_field])``, preserving first-seen order."""
        groups: dict[str, list[dict]] = defaultdict(list)
        for row in rows:
            groups[str(row.get(series_field, ""))].append(row)
        return dict(groups)

    def _build_series_matrix(
        self,
        rows: Iterable[dict],
        x_field: str,
        y_field: str,
        series_field: str,
    ) -> tuple[list[str], list[str], list[list[float]]]:
        """Pivot rows into per-series value lists aligned to the categories.

        Only rows with a numeric y contribute. Categories and series names are
        kept in first-seen order, duplicate (series, category) values are
        summed, and missing combinations are filled with 0.0.
        """
        # dicts give O(1) membership and keep insertion (first-seen) order.
        categories: dict[str, None] = {}
        data: dict[str, dict[str, float]] = {}

        for row in rows:
            value = self._to_number(row.get(y_field))
            if value is None:
                continue
            category = str(row.get(x_field, ""))
            series_name = str(row.get(series_field, ""))
            categories.setdefault(category, None)
            cells = data.setdefault(series_name, {})
            cells[category] = cells.get(category, 0.0) + value

        category_list = list(categories)
        series_names = list(data)
        series_values = [
            [data[name].get(category, 0.0) for category in category_list]
            for name in series_names
        ]
        return category_list, series_names, series_values

    def _to_number(self, value: object) -> float | None:
        """Convert a cell value to a float for plotting, or None if impossible.

        The conversions are:

        - ``None`` and ``bool`` become None (bool is excluded even though it
          subclasses int).
        - Real numbers (int, float, Fraction, ...) and ``Decimal`` become a
          float, or None if the conversion fails.
        - A ``datetime`` becomes its POSIX timestamp; naive values are
          interpreted in local time, as ``datetime.timestamp`` does.
        - A ``date`` becomes the timestamp of local midnight on that day.
        - A ``time`` becomes seconds since midnight.
        - A ``str`` is handled by ``_parse_numeric_string``.
        """
        import numbers
        from datetime import date
        from decimal import Decimal

        if value is None or isinstance(value, bool):
            return None
        if isinstance(value, (numbers.Real, Decimal)):
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError):
                # e.g. Decimal("sNaN") or an int too large for a float.
                return None
        if isinstance(value, datetime):
            return value.timestamp()
        # Checked after datetime, because datetime is a subclass of date.
        if isinstance(value, date):
            return datetime.combine(value, time.min).timestamp()
        if isinstance(value, time):
            return self._time_to_seconds(value)
        if isinstance(value, str):
            return self._parse_numeric_string(value)
        return None

    def _parse_numeric_string(self, value: str) -> float | None:
        """Parse a string cell as a number, datetime, year-month, or time.

        The attempts run in this order:

        1. a float literal;
        2. an ISO-8601 datetime, with a trailing "Z" read as UTC;
        3. "YYYY-MM", taken as the first day of that month;
        4. an ISO-8601 time, converted to seconds since midnight.

        Returns None when every attempt fails.
        """
        try:
            return float(value)
        except ValueError:
            pass

        normalized = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            return datetime.fromisoformat(normalized).timestamp()
        except ValueError:
            pass

        # Must come before the time parse. Since Python 3.11, time.fromisoformat
        # accepts compact forms and reads "2023-06" as 20:23 at UTC-06:00.
        if self._year_month_re.match(value):
            try:
                return datetime.fromisoformat(f"{value}-01").timestamp()
            except ValueError:
                pass

        try:
            return self._time_to_seconds(time.fromisoformat(normalized))
        except ValueError:
            return None

    def _time_to_seconds(self, value: time) -> float:
        """Convert a time-of-day value to seconds since midnight (tzinfo ignored)."""
        return (
            value.hour * 3600
            + value.minute * 60
            + value.second
            + value.microsecond / 1_000_000
        )

    def _safe_color(self, color: str | None, default: str) -> str:
        """Return ``color`` if it is a known plotext color name, else ``default``."""
        if not color:
            return default
        return color if color in self._valid_colors else default

    def _safe_marker(self, marker: str | None, default: str) -> str:
        """Return ``marker`` if it is an accepted marker name, else ``default``."""
        if not marker:
            return default
        return marker if marker in self._valid_markers else default