sqlsaber-viz 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlsaber_viz/__init__.py +19 -0
- sqlsaber_viz/data_loader.py +143 -0
- sqlsaber_viz/prompts.py +31 -0
- sqlsaber_viz/renderers/__init__.py +6 -0
- sqlsaber_viz/renderers/base.py +13 -0
- sqlsaber_viz/renderers/html_renderer.py +17 -0
- sqlsaber_viz/renderers/plotext_renderer.py +385 -0
- sqlsaber_viz/spec.py +130 -0
- sqlsaber_viz/spec_agent.py +144 -0
- sqlsaber_viz/templates.py +175 -0
- sqlsaber_viz/tools.py +234 -0
- sqlsaber_viz/transforms.py +155 -0
- sqlsaber_viz-0.1.1.dist-info/METADATA +12 -0
- sqlsaber_viz-0.1.1.dist-info/RECORD +16 -0
- sqlsaber_viz-0.1.1.dist-info/WHEEL +4 -0
- sqlsaber_viz-0.1.1.dist-info/entry_points.txt +2 -0
sqlsaber_viz/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""SQLSaber visualization plugin."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from sqlsaber.tools.registry import ToolRegistry
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def register_tools(registry: "ToolRegistry | None" = None):
    """Register visualization tools.

    Returns list of tool classes for sqlsaber to register.

    Args:
        registry: Currently unused by this implementation — presumably
            accepted to satisfy the plugin-hook signature; confirm against
            sqlsaber's tool-registration contract.
    """
    # Imported lazily so that importing the package itself stays cheap and
    # does not pull in the tool's dependencies.
    from .tools import VizTool

    return [VizTool]


__all__ = ["register_tools"]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Helpers for loading SQL result payloads and extracting summaries."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from datetime import date, datetime, time
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from pydantic_ai import RunContext
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def find_tool_output_payload(ctx: "RunContext", tool_call_id: str) -> dict | None:
    """Find tool output from RunContext message history.

    Thin wrapper that delegates to ``find_tool_output_in_messages`` using the
    messages recorded on the pydantic-ai ``RunContext``.
    """
    return find_tool_output_in_messages(ctx.messages, tool_call_id)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_tool_output_in_messages(messages: list, tool_call_id: str) -> dict | None:
    """Find tool output from a list of ModelMessage objects.

    Scans newest-first for a tool-return part matching *tool_call_id*.
    Dict content is returned as-is; string content is JSON-decoded when
    possible, otherwise wrapped under a ``"result"`` key.
    """
    return_kinds = ("tool-return", "builtin-tool-return")
    for message in reversed(messages):
        for part in getattr(message, "parts", []):
            kind = getattr(part, "part_kind", "")
            if kind not in return_kinds:
                continue
            if getattr(part, "tool_call_id", None) != tool_call_id:
                continue
            content = getattr(part, "content", None)
            if isinstance(content, dict):
                return content
            if not isinstance(content, str):
                # Unexpected content shape — keep looking at older parts.
                continue
            try:
                parsed = json.loads(content)
            except json.JSONDecodeError:
                return {"result": content}
            return parsed if isinstance(parsed, dict) else {"result": parsed}
    return None
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def extract_data_summary(payload: dict) -> dict:
    """Extract column info and samples from SQL result payload.

    Returns:
        {
            "columns": [
                {"name": "col1", "type": "string", "sample": ["a", "b", "c"]},
                {"name": "col2", "type": "number", "sample": [1, 2, 3]},
            ],
            "row_count": 150,
            "rows": [...]  # Full rows for rendering
        }
    """
    raw_results = payload.get("results")
    if isinstance(raw_results, list):
        rows = _coerce_rows(raw_results)
    else:
        rows = []

    # Trust the payload's declared count when present; fall back to the
    # materialized row count otherwise.
    declared = payload.get("row_count")
    row_count = declared if isinstance(declared, int) else len(rows)

    return {
        "columns": _extract_columns(rows),
        "row_count": row_count,
        "rows": rows,
    }
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def infer_column_type(values: list[object]) -> str:
    """Infer column type from sample values.

    Returns: "number", "string", "time", "boolean", or "null"
    """
    present = [item for item in values if item is not None]
    if not present:
        return "null"

    def every(predicate) -> bool:
        # True when all non-null samples satisfy *predicate*.
        return all(predicate(item) for item in present)

    # bool is checked before number on purpose: bool is a subclass of int.
    if every(lambda item: isinstance(item, bool)):
        return "boolean"
    if every(lambda item: isinstance(item, (int, float))):
        return "number"
    if every(_is_time_value):
        return "time"
    return "string"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _extract_columns(rows: list[dict[str, object]]) -> list[dict[str, object]]:
    """Build per-column descriptors (name, inferred type, small sample)."""
    if not rows:
        return []

    # Union of keys over the first 50 rows, insertion-ordered, so sparse
    # columns are not missed.
    ordered_keys: dict[str, None] = {}
    for row in rows[:50]:
        for key in row:
            ordered_keys.setdefault(key, None)

    columns: list[dict[str, object]] = []
    for name in ordered_keys:
        # Sample only rows that actually carry the key, from the first 20.
        samples = [row[name] for row in rows[:20] if name in row]
        columns.append(
            {
                "name": name,
                "type": infer_column_type(samples),
                "sample": samples[:5],
            }
        )
    return columns
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _coerce_rows(rows: list[object]) -> list[dict[str, object]]:
|
|
118
|
+
coerced: list[dict[str, object]] = []
|
|
119
|
+
for row in rows:
|
|
120
|
+
if isinstance(row, dict):
|
|
121
|
+
coerced.append({str(key): value for key, value in row.items()})
|
|
122
|
+
else:
|
|
123
|
+
coerced.append({"value": row})
|
|
124
|
+
return coerced
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _is_time_value(value: object) -> bool:
|
|
128
|
+
if isinstance(value, (datetime, date, time)):
|
|
129
|
+
return True
|
|
130
|
+
if isinstance(value, str):
|
|
131
|
+
normalized = value
|
|
132
|
+
if value.endswith("Z"):
|
|
133
|
+
normalized = value[:-1] + "+00:00"
|
|
134
|
+
try:
|
|
135
|
+
datetime.fromisoformat(normalized)
|
|
136
|
+
return True
|
|
137
|
+
except ValueError:
|
|
138
|
+
try:
|
|
139
|
+
time.fromisoformat(normalized)
|
|
140
|
+
return True
|
|
141
|
+
except ValueError:
|
|
142
|
+
return False
|
|
143
|
+
return False
|
sqlsaber_viz/prompts.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Prompt definitions for viz spec generation."""

# System prompt given to the spec-generation model. This is runtime text that
# shapes model behavior — treat every character as significant.
VIZ_SYSTEM_PROMPT = """You are a visualization spec generator. Given a user's request and data summary, generate a valid JSON visualization spec.

## Workflow
1. Decide the appropriate chart type based on the request and data
2. Call `get_vizspec_template` with the chart type and file to get the correct spec structure
3. Fill in the template with actual column names from the provided data summary
4. Return ONLY the final JSON spec (no explanations, no markdown code blocks)

## Chart Type Selection
- Comparing categories → bar
- Comparing categories across series → bar with encoding.series
- Trend over time → line
- Correlation between two numbers → scatter
- Distribution of one variable → histogram
- Distribution comparison across groups → boxplot

## Transform Operations (optional, add to "transform" array)
- {"sort": [{"field": "col", "dir": "desc"}]} - Sort data
- {"limit": 20} - Limit rows (recommended for bar charts with many categories)
- {"filter": {"field": "col", "op": "!=", "value": null}} - Filter rows

## Rules
1. ALWAYS call `get_vizspec_template` first to get the correct structure
2. Use ONLY columns that exist in the provided data summary
3. Match field types: category columns for x in bar charts, numeric columns for y
4. Add limit transform for bar charts to avoid overcrowding (10-20 bars max)
5. Sort bar charts by y value descending for better readability
6. Title should describe what the chart shows
"""
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Renderer protocol for visualization outputs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Protocol
|
|
6
|
+
|
|
7
|
+
from ..spec import VizSpec
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RendererProtocol(Protocol):
    """Structural (duck-typed) interface shared by all visualization renderers."""

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Render a visualization spec with data rows."""
        ...
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Placeholder HTML renderer for future web UI support."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from ..spec import VizSpec
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class HtmlRenderer:
    """Render VizSpec to HTML.

    Placeholder implementation; currently returns an empty string.
    """

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Produce HTML for *spec* — not implemented yet, always empty."""
        del spec, rows  # intentionally unused until HTML output is built
        return ""
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""Plotext renderer for terminal charts."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from datetime import datetime, time
|
|
8
|
+
from typing import Iterable
|
|
9
|
+
|
|
10
|
+
from ..spec import (
|
|
11
|
+
BarChart,
|
|
12
|
+
BoxplotChart,
|
|
13
|
+
HistogramChart,
|
|
14
|
+
LineChart,
|
|
15
|
+
ScatterChart,
|
|
16
|
+
VizSpec,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class PlotextRenderer:
    """Render VizSpec to terminal using plotext.

    Dispatches on the concrete chart type of ``spec.chart`` and returns the
    chart as an ASCII string (never prints). Data problems are reported as
    bracketed message strings rather than exceptions.
    """

    # Palette cycled through when a chart has multiple series.
    _series_colors = [
        "cyan+",
        "yellow+",
        "red+",
        "green+",
        "blue+",
        "magenta+",
        "white+",
    ]
    # Fallback plot size (in terminal cells) when options omit width/height.
    _default_width = 80
    _default_height = 25

    def render(self, spec: VizSpec, rows: list[dict]) -> str:
        """Render spec with data to ASCII chart string.

        Returns:
            ASCII chart string from plt.build(), or error message if rendering fails.
        """
        # Imported lazily so the module can be loaded without plotext installed.
        import plotext as plt

        # NOTE(review): clf() and clear_figure() look redundant — presumably
        # defensive; confirm against the installed plotext version.
        plt.clf()
        plt.clear_figure()

        chart = spec.chart
        options = chart.options

        width = options.width or self._default_width
        height = options.height or self._default_height
        plt.plot_size(width=width, height=height)

        if spec.title:
            plt.title(spec.title)

        # Each _render_* returns None on success or an error message string.
        error_msg: str | None = None
        try:
            if isinstance(chart, BarChart):
                error_msg = self._render_bar(chart, rows, plt)
            elif isinstance(chart, LineChart):
                error_msg = self._render_line(chart, rows, plt)
            elif isinstance(chart, ScatterChart):
                error_msg = self._render_scatter(chart, rows, plt)
            elif isinstance(chart, BoxplotChart):
                error_msg = self._render_boxplot(chart, rows, plt)
            elif isinstance(chart, HistogramChart):
                error_msg = self._render_histogram(chart, rows, plt)
            else:
                return f"[Unsupported chart type: {type(chart).__name__}]"
        except Exception as e:
            # Broad catch is deliberate: rendering failures become a visible
            # message instead of crashing the caller.
            return f"[Chart rendering error: {e}]"

        if error_msg:
            return error_msg

        # Axis labels are applied after plotting so they survive the chart call.
        if options.x_label:
            plt.xlabel(options.x_label)
        if options.y_label:
            plt.ylabel(options.y_label)

        return plt.build()

    def _render_bar(self, chart: BarChart, rows: list[dict], plt) -> str | None:
        """Plot a bar chart (plain, grouped, or stacked); None on success."""
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None

        orientation = "h" if chart.orientation == "horizontal" else "v"

        if series_field:
            categories, series_names, series_values = self._build_series_matrix(
                rows, x_field, y_field, series_field
            )
            if not categories or not series_names:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            if chart.mode == "stacked":
                plt.stacked_bar(
                    categories,
                    series_values,
                    labels=series_names,
                    orientation=orientation,
                )
            else:
                plt.multiple_bar(
                    categories,
                    series_values,
                    labels=series_names,
                    orientation=orientation,
                )
            return None

        # Aggregate by category (sum) for consistency with series path
        aggregated: dict[str, float] = {}
        for row in rows:
            category = str(row.get(x_field, ""))
            value = self._to_number(row.get(y_field))
            if value is None:
                continue
            aggregated[category] = aggregated.get(category, 0.0) + value

        if not aggregated:
            return f"[No data: no valid numeric values for '{y_field}']"

        categories = list(aggregated.keys())
        values = list(aggregated.values())

        color = self._safe_color(chart.options.color, "blue+")
        plt.bar(categories, values, color=color, orientation=orientation)
        return None

    def _render_line(self, chart: LineChart, rows: list[dict], plt) -> str | None:
        """Plot one line, or one line per series; None on success."""
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None

        marker = self._safe_marker(chart.options.marker, "braille")

        if series_field:
            series_map = self._group_series(rows, series_field)
            any_plotted = False
            for idx, (series_name, series_rows) in enumerate(series_map.items()):
                # Points are sorted by x so the line segments connect in order.
                x, y = self._extract_xy_sorted(series_rows, x_field, y_field)
                if not x or not y:
                    continue
                any_plotted = True
                color = self._series_colors[idx % len(self._series_colors)]
                plt.plot(x, y, color=color, marker=marker, label=series_name)
            if not any_plotted:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            return None

        x, y = self._extract_xy_sorted(rows, x_field, y_field)
        if not x or not y:
            return f"[No data: no valid values for '{x_field}' / '{y_field}']"
        color = self._safe_color(chart.options.color, "cyan+")
        plt.plot(x, y, color=color, marker=marker)
        return None

    def _render_scatter(self, chart: ScatterChart, rows: list[dict], plt) -> str | None:
        """Plot a scatter, optionally split by series; None on success."""
        x_field = chart.encoding.x.field
        y_field = chart.encoding.y.field
        series_field = chart.encoding.series.field if chart.encoding.series else None

        marker = self._safe_marker(chart.options.marker, "dot")

        if series_field:
            series_map = self._group_series(rows, series_field)
            any_plotted = False
            for idx, (series_name, series_rows) in enumerate(series_map.items()):
                # Unlike lines, scatter points keep their original row order.
                x, y = self._extract_xy(series_rows, x_field, y_field)
                if not x or not y:
                    continue
                any_plotted = True
                color = self._series_colors[idx % len(self._series_colors)]
                plt.scatter(x, y, color=color, marker=marker, label=series_name)
            if not any_plotted:
                return f"[No data: no valid values for '{x_field}' / '{y_field}']"
            return None

        x, y = self._extract_xy(rows, x_field, y_field)
        if not x or not y:
            return f"[No data: no valid values for '{x_field}' / '{y_field}']"
        color = self._safe_color(chart.options.color, "red+")
        plt.scatter(x, y, color=color, marker=marker)
        return None

    def _render_boxplot(self, chart: BoxplotChart, rows: list[dict], plt) -> str | None:
        """Plot grouped box plots; None on success."""
        label_field = chart.boxplot.label_field
        value_field = chart.boxplot.value_field

        groups: dict[str, list[float]] = {}
        for row in rows:
            label = str(row.get(label_field, ""))
            value = self._to_number(row.get(value_field))
            if value is None:
                continue
            groups.setdefault(label, []).append(value)

        if not groups:
            return f"[No data: no valid numeric values for '{value_field}']"

        labels = list(groups.keys())
        data = [groups[label] for label in labels]

        # NOTE(review): verify plt.box(labels, data) matches the installed
        # plotext version's box-plot signature.
        plt.box(labels, data)
        return None

    def _render_histogram(self, chart: HistogramChart, rows: list[dict], plt) -> str | None:
        """Plot a histogram of one numeric field; None on success."""
        field = chart.histogram.field
        bins = chart.histogram.bins

        values: list[float] = []
        for row in rows:
            val = self._to_number(row.get(field))
            if val is not None:
                values.append(val)

        if not values:
            return f"[No data: no valid numeric values for '{field}']"

        color = self._safe_color(chart.options.color, "green+")

        plt.hist(values, bins=bins, color=color)
        return None

    def _extract_xy(
        self, rows: Iterable[dict], x_field: str, y_field: str
    ) -> tuple[list[float], list[float]]:
        """Extract paired numeric x/y values, dropping rows where either is invalid."""
        x: list[float] = []
        y: list[float] = []
        for row in rows:
            x_val = self._to_number(row.get(x_field))
            y_val = self._to_number(row.get(y_field))
            if x_val is None or y_val is None:
                continue
            x.append(x_val)
            y.append(y_val)
        return x, y

    def _extract_xy_sorted(
        self, rows: Iterable[dict], x_field: str, y_field: str
    ) -> tuple[list[float], list[float]]:
        """Extract x/y pairs and sort by x for proper line chart rendering."""
        pairs: list[tuple[float, float]] = []
        for row in rows:
            x_val = self._to_number(row.get(x_field))
            y_val = self._to_number(row.get(y_field))
            if x_val is None or y_val is None:
                continue
            pairs.append((x_val, y_val))
        pairs.sort(key=lambda p: p[0])
        x = [p[0] for p in pairs]
        y = [p[1] for p in pairs]
        return x, y

    def _group_series(
        self, rows: Iterable[dict], series_field: str
    ) -> dict[str, list[dict]]:
        """Group rows by the stringified series value, preserving first-seen order."""
        groups: dict[str, list[dict]] = defaultdict(list)
        for row in rows:
            key = str(row.get(series_field, ""))
            groups[key].append(row)
        return dict(groups)

    def _build_series_matrix(
        self,
        rows: Iterable[dict],
        x_field: str,
        y_field: str,
        series_field: str,
    ) -> tuple[list[str], list[str], list[list[float]]]:
        """Pivot rows into (categories, series names, per-series value lists).

        Values are summed per (series, category); missing cells become 0.0 so
        every series list is the same length as ``categories``.
        """
        categories: list[str] = []
        series_names: list[str] = []
        data: dict[str, dict[str, float]] = {}

        for row in rows:
            category = str(row.get(x_field, ""))
            series_name = str(row.get(series_field, ""))
            value = self._to_number(row.get(y_field))
            if value is None:
                continue

            if category not in categories:
                categories.append(category)
            if series_name not in series_names:
                series_names.append(series_name)

            data.setdefault(series_name, {})
            data[series_name][category] = data[series_name].get(category, 0.0) + value

        series_values: list[list[float]] = []
        for series_name in series_names:
            values = [
                data.get(series_name, {}).get(category, 0.0) for category in categories
            ]
            series_values.append(values)

        return categories, series_names, series_values

    def _to_number(self, value: object) -> float | None:
        """Coerce a cell value to float, or None when not plottable.

        Booleans are deliberately rejected; datetimes map to epoch seconds and
        times to seconds since midnight. Strings are tried as float, then ISO
        datetime/time, then "YYYY-MM".
        """
        if value is None:
            return None
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, datetime):
            return value.timestamp()
        if isinstance(value, time):
            return self._time_to_seconds(value)
        if isinstance(value, str):
            try:
                return float(value)
            except ValueError:
                # Handle Z suffix (e.g., "2024-01-01T00:00:00Z")
                normalized = value
                if value.endswith("Z"):
                    normalized = value[:-1] + "+00:00"
                try:
                    return datetime.fromisoformat(normalized).timestamp()
                except ValueError:
                    pass
                try:
                    return self._time_to_seconds(time.fromisoformat(normalized))
                except ValueError:
                    pass
                # Try YYYY-MM format (e.g., "2023-06")
                if re.match(r"^\d{4}-\d{2}$", value):
                    try:
                        return datetime.fromisoformat(f"{value}-01").timestamp()
                    except ValueError:
                        pass
                return None
        return None

    def _time_to_seconds(self, value: time) -> float:
        """Convert time-only values to seconds since midnight."""
        return (
            value.hour * 3600
            + value.minute * 60
            + value.second
            + value.microsecond / 1_000_000
        )

    def _safe_color(self, color: str | None, default: str) -> str:
        """Return validated color or default if invalid."""
        if not color:
            return default
        # plotext accepts color names like "red+", "blue", etc.
        # If an invalid color is used, plotext may throw; keep known-good defaults
        valid_colors = {
            "black",
            "red",
            "green",
            "yellow",
            "blue",
            "magenta",
            "cyan",
            "white",
            "black+",
            "red+",
            "green+",
            "yellow+",
            "blue+",
            "magenta+",
            "cyan+",
            "white+",
        }
        return color if color in valid_colors else default

    def _safe_marker(self, marker: str | None, default: str) -> str:
        """Return validated marker or default if invalid."""
        if not marker:
            return default
        # plotext marker options
        valid_markers = {
            "sd",
            "dot",
            "hd",
            "fhd",
            "braille",
            "heart",
            "point",
        }
        return marker if marker in valid_markers else default
|