rosetta-sql 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/generate_csv_data.py +83 -0
- benchmark/import_data.py +168 -0
- rosetta/__init__.py +3 -0
- rosetta/__main__.py +8 -0
- rosetta/benchmark.py +1678 -0
- rosetta/buglist.py +108 -0
- rosetta/cli/__init__.py +11 -0
- rosetta/cli/config_cmd.py +243 -0
- rosetta/cli/exec.py +219 -0
- rosetta/cli/interactive_cmd.py +124 -0
- rosetta/cli/list_cmd.py +215 -0
- rosetta/cli/main.py +617 -0
- rosetta/cli/output.py +545 -0
- rosetta/cli/result.py +61 -0
- rosetta/cli/result_cmd.py +247 -0
- rosetta/cli/run.py +625 -0
- rosetta/cli/status.py +161 -0
- rosetta/comparator.py +205 -0
- rosetta/config.py +139 -0
- rosetta/executor.py +403 -0
- rosetta/flamegraph.py +630 -0
- rosetta/interactive.py +1790 -0
- rosetta/models.py +197 -0
- rosetta/parser.py +308 -0
- rosetta/reporter/__init__.py +1 -0
- rosetta/reporter/bench_html.py +1457 -0
- rosetta/reporter/bench_text.py +162 -0
- rosetta/reporter/history.py +1686 -0
- rosetta/reporter/html.py +644 -0
- rosetta/reporter/text.py +110 -0
- rosetta/runner.py +3089 -0
- rosetta/ui.py +736 -0
- rosetta/whitelist.py +161 -0
- rosetta_sql-1.0.0.dist-info/LICENSE +21 -0
- rosetta_sql-1.0.0.dist-info/METADATA +379 -0
- rosetta_sql-1.0.0.dist-info/RECORD +42 -0
- rosetta_sql-1.0.0.dist-info/WHEEL +5 -0
- rosetta_sql-1.0.0.dist-info/entry_points.txt +2 -0
- rosetta_sql-1.0.0.dist-info/top_level.txt +4 -0
- skills/rosetta/scripts/install_rosetta.py +469 -0
- skills/rosetta/scripts/rosetta_wrapper.py +377 -0
- tests/test_cli.py +749 -0
rosetta/ui.py
ADDED
|
@@ -0,0 +1,736 @@
|
|
|
1
|
+
"""Rich terminal UI for Rosetta."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
from typing import Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from rich.console import Console, Group
|
|
9
|
+
from rich.panel import Panel
|
|
10
|
+
from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
|
|
11
|
+
SpinnerColumn, TextColumn, TimeElapsedColumn,
|
|
12
|
+
TimeRemainingColumn)
|
|
13
|
+
from rich.rule import Rule
|
|
14
|
+
from rich.table import Table
|
|
15
|
+
from rich.text import Text
|
|
16
|
+
|
|
17
|
+
from .models import CompareResult
|
|
18
|
+
|
|
19
|
+
# Shared stderr console: receives both the final buffered Panel and the
# live progress bars drawn by the progress classes below.
console = Console(stderr=True)

# Colour used for every Panel border and Rule in this module.
_BOX = "cyan"

# Logger for the "rosetta" package.
_log = logging.getLogger("rosetta")

# Rich renderables queued by the print_* helpers; flush_all() prints them
# inside one Panel and then empties the list.
_renderables: List = []
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _add(renderable):
    """Queue *renderable* for the next flush_all() call.

    Nothing is printed here; the object is only appended to the
    module-level ``_renderables`` buffer.
    """
    buffer = _renderables
    buffer.append(renderable)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def flush_all(title: str = "Rosetta"):
    """Print every buffered renderable inside a single bordered Panel.

    Does nothing when the buffer is empty. After a successful print the
    buffer is cleared.

    Args:
        title: Text shown in bold in the Panel's top border.
    """
    if _renderables:
        panel = Panel(
            Group(*_renderables),
            title=f"[bold]{title}[/bold]",
            border_style=_BOX,
            expand=True,
            padding=(0, 1),
        )
        console.print(panel)
        _renderables.clear()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# ---------------------------------------------------------------------------
|
|
51
|
+
# Banner
|
|
52
|
+
# ---------------------------------------------------------------------------
|
|
53
|
+
|
|
54
|
+
# Rich-markup banner: five ASCII-art logo rows in bold cyan followed by a
# dimmed tagline. Rows are raw strings because the backslashes are literal.
BANNER_TEXT = (
    "[bold cyan]"
    + "\n".join([
        r" ____ _ _",
        r" | _ \ ___ ___ ___| |_| |_ __ _",
        r" | |_) / _ \/ __|/ _ \ __| __/ _` |",
        r" | _ < (_) \__ \ __/ |_| || (_| |",
        r" |_| \_\___/|___/\___|\__|\__\__,_|",
    ])
    + "[/bold cyan]\n"
    + "[dim]Cross-DBMS SQL Behavioral Consistency Verification[/dim]\n"
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def print_banner():
    """Queue the Rosetta logo and tagline (BANNER_TEXT) for output."""
    banner = Text.from_markup(BANNER_TEXT)
    _add(banner)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
# Phase headers
|
|
73
|
+
# ---------------------------------------------------------------------------
|
|
74
|
+
|
|
75
|
+
def print_phase(title: str, detail: str = ""):
    """Queue a section header: a blank spacer line, then a titled Rule.

    Args:
        title: Heading text, shown bold white (interpreted as rich markup).
        detail: Optional extra text appended after the title, dimmed.
    """
    pieces = [f"[bold white]{title}[/bold white]"]
    if detail:
        pieces.append(f"[dim]{detail}[/dim]")
    _add(Text(""))
    _add(Rule(Text.from_markup(" ".join(pieces)), style=_BOX))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ---------------------------------------------------------------------------
|
|
85
|
+
# Info messages
|
|
86
|
+
# ---------------------------------------------------------------------------
|
|
87
|
+
|
|
88
|
+
def print_info(msg: str, highlight: str = ""):
    """Queue an informational line prefixed with a cyan ``>`` marker.

    Args:
        msg: Message text (interpreted as rich markup).
        highlight: Optional value appended after *msg* in bold.
    """
    line = f" [cyan]>[/cyan] {msg}"
    if highlight:
        line = f"{line} [bold]{highlight}[/bold]"
    _add(Text.from_markup(line))
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def print_success(msg: str):
    """Queue a line prefixed with a green check mark.

    Args:
        msg: Message text (interpreted as rich markup).
    """
    marker = "[green]✓[/green]"
    _add(Text.from_markup(f" {marker} {msg}"))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def print_warning(msg: str):
    """Queue a line prefixed with a yellow warning sign.

    Args:
        msg: Message text (interpreted as rich markup).
    """
    # U+26A0 WARNING SIGN, written as an escape: the previous literal held
    # its UTF-8 bytes mis-decoded as cp1252 ("âš" + NBSP) and rendered as
    # garbage in the terminal.
    _add(Text.from_markup(f" [yellow]\u26a0[/yellow] {msg}"))
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def print_error(msg: str):
    """Queue a line prefixed with a red cross mark.

    Args:
        msg: Message text (interpreted as rich markup).
    """
    marker = "[red]✗[/red]"
    _add(Text.from_markup(f" {marker} {msg}"))
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# ---------------------------------------------------------------------------
|
|
112
|
+
# Execution progress bar
|
|
113
|
+
# ---------------------------------------------------------------------------
|
|
114
|
+
|
|
115
|
+
class ExecutionProgress:
    """Live progress bar for one DBMS's statement execution.

    Use as a context manager. Every instance that is inside its ``with``
    block at the same time draws a row in one shared rich ``Progress``, so
    parallel DBMS executions are displayed together. The first
    ``__enter__`` creates and starts that Progress; it is stopped and
    discarded when the last instance exits. The display is transient, so
    call :meth:`write_summary_to_buffer` afterwards to keep a permanent
    one-line record in the output buffer.
    """

    # Class-wide state shared by every instance; mutated only under _lock.
    _lock = threading.Lock()
    _shared_progress: Optional[Progress] = None
    _ref_count = 0

    def __init__(self, dbms_name: str, total: int):
        """Prepare (but do not display) a progress row.

        Args:
            dbms_name: Label shown at the start of the row.
            total: Number of steps that make up 100%.
        """
        self.dbms_name = dbms_name
        self.total = total
        self._task_id = None  # rich task id, assigned in __enter__
        self._errors = self._executed = 0
        self._start_time = self._elapsed = 0.0

    # -- shared Progress lifecycle ------------------------------------------

    @staticmethod
    def _new_progress() -> Progress:
        """Build (without starting) the Progress that all instances share."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.fields[dbms]}[/bold blue]"),
            BarColumn(bar_width=40),
            MofNCompleteColumn(),
            TextColumn("[dim]|[/dim]"),
            TimeElapsedColumn(),
            TextColumn("[dim]|[/dim]"),
            TimeRemainingColumn(),
            TextColumn("{task.fields[status]}"),
            console=console,
            transient=True,
        )

    @classmethod
    def _acquire(cls) -> Progress:
        """Register one more user of the shared Progress and return it."""
        with cls._lock:
            if cls._shared_progress is None:
                cls._shared_progress = cls._new_progress()
                cls._shared_progress.start()
            cls._ref_count += 1
            return cls._shared_progress

    @classmethod
    def _release(cls):
        """Unregister one user; stop the shared Progress after the last one."""
        with cls._lock:
            cls._ref_count -= 1
            if cls._ref_count > 0:
                return
            if cls._shared_progress is not None:
                cls._shared_progress.stop()
                cls._shared_progress = None
            cls._ref_count = 0  # never let the counter stay negative
            # Make sure the terminal cursor is visible again (best effort).
            try:
                console.show_cursor()
            except Exception:
                pass

    # -- context manager ----------------------------------------------------

    def __enter__(self):
        self._start_time = time.monotonic()
        progress = self._acquire()
        self._task_id = progress.add_task(
            "exec",
            total=self.total,
            dbms=self.dbms_name,
            status="",
        )
        return self

    def __exit__(self, *args):
        self._elapsed = time.monotonic() - self._start_time
        self._release()

    def _update(self, **changes):
        """Forward *changes* to this row's task while the display is live."""
        progress = type(self)._shared_progress
        if progress is not None:
            progress.update(self._task_id, **changes)

    def advance(self, error: bool = False):
        """Count one finished step; ``error=True`` also counts it as failed."""
        self._executed += 1
        if error:
            self._errors += 1
        if self._errors:
            status = f"[red]{self._errors} err[/red]"
        else:
            status = "[green]ok[/green]"
        self._update(advance=1, status=status)

    def set_status(self, text: str):
        """Replace the row's status column with *text* (rich markup)."""
        self._update(status=text)

    def write_summary_to_buffer(self):
        """Queue a static one-line result summary (call after ``__exit__``)."""
        if self._errors:
            outcome = (f"[yellow]{self._executed} done, "
                       f"{self._errors} err[/yellow]")
        else:
            outcome = f"[green]{self._executed} done[/green]"
        _add(Text.from_markup(
            f" [bold blue]{self.dbms_name}[/bold blue] "
            f"{self._executed}/{self.total} "
            f"[dim]{self._elapsed:.1f}s[/dim] {outcome}"
        ))
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# ---------------------------------------------------------------------------
|
|
222
|
+
# Summary table
|
|
223
|
+
# ---------------------------------------------------------------------------
|
|
224
|
+
|
|
225
|
+
def print_summary(comparisons: Dict[str, CompareResult],
                  failed_connections: Optional[set] = None):
    """Buffer a rich table summarising every comparison, plus a verdict line.

    One SKIP row is emitted per entry in *failed_connections* and one row
    per entry in *comparisons*. The "Whitelist" and "Bug" columns appear
    only when at least one comparison has a non-zero count for them.

    Args:
        comparisons: Mapping of row label -> CompareResult.
        failed_connections: Labels of connections that failed; listed as
            skipped rows in sorted order. ``None`` or empty means none.

    Returns:
        True when no comparison has a positive ``effective_mismatched``
        count. Failed connections do not make this False; they only change
        the wording of the verdict line.
    """
    _add(Text(""))
    _add(Rule(Text.from_markup("[bold white]Summary[/bold white]"), style=_BOX))
    _add(Text(""))

    # Optional columns are shown only when some comparison needs them.
    has_wl = any(cmp.whitelisted > 0 for cmp in comparisons.values())
    has_bug = any(cmp.bug_marked > 0 for cmp in comparisons.values())

    table = Table(
        header_style="bold",
        show_lines=False,
        padding=(0, 1),
        expand=True,
        show_edge=False,
    )

    table.add_column("Comparison", style="white", ratio=3)
    table.add_column("Status", justify="center", min_width=6)
    table.add_column("Match", justify="right", style="green")
    table.add_column("Mismatch", justify="right")
    if has_wl:
        table.add_column("Whitelist", justify="right")
    if has_bug:
        table.add_column("Bug", justify="right")
    table.add_column("Skip", justify="right", style="dim")
    table.add_column("Total", justify="right")
    table.add_column("Rate", justify="right", min_width=14)

    # Connections that failed were never compared, so every data cell is
    # "-". Sorted because iterating a set gives a run-dependent order.
    for name in sorted(failed_connections or (), key=str):
        cols = [name, "[yellow]SKIP[/yellow]", "-", "-"]
        if has_wl:
            cols.append("-")
        if has_bug:
            cols.append("-")
        cols += ["-", "-", "[dim]conn failed[/dim]"]
        table.add_row(*cols)

    bar_len = 8  # width of the pass-rate bar, in cells
    # U+2588 FULL BLOCK / U+2591 LIGHT SHADE, written as escapes: the old
    # literals were mojibake (UTF-8 bytes decoded as cp1252).
    full_cell = "\u2588"
    empty_cell = "\u2591"

    all_pass = True
    for key, cmp in comparisons.items():
        effective_mismatch = cmp.effective_mismatched
        is_pass = effective_mismatch <= 0
        if not is_pass:
            all_pass = False

        status = ("[bold green]PASS[/bold green]" if is_pass
                  else "[bold red]FAIL[/bold red]")
        mismatch_style = "red bold" if effective_mismatch > 0 else "dim"
        rate = cmp.pass_rate
        rate_color = ("green" if rate >= 100
                      else "yellow" if rate >= 90
                      else "red")

        # Clamp so an out-of-range rate can neither overflow nor underflow
        # the fixed-width bar.
        filled = max(0, min(bar_len, int(rate / 100 * bar_len)))
        bar = (f"[{rate_color}]{full_cell * filled}"
               f"{empty_cell * (bar_len - filled)}"
               f"[/{rate_color}] {rate:.1f}%")

        cols = [
            key, status,
            str(cmp.matched),
            Text(str(effective_mismatch if effective_mismatch > 0 else 0),
                 style=mismatch_style),
        ]
        if has_wl:
            cols.append(Text(str(cmp.whitelisted), style="yellow")
                        if cmp.whitelisted > 0
                        else Text("0", style="dim"))
        if has_bug:
            cols.append(Text(str(cmp.bug_marked), style="red")
                        if cmp.bug_marked > 0
                        else Text("0", style="dim"))
        cols += [
            str(cmp.skipped),
            str(cmp.total_stmts),
            bar,
        ]
        table.add_row(*cols)

    _add(table)

    # Overall verdict
    _add(Text(""))
    if all_pass and not failed_connections:
        _add(Text.from_markup("[bold green] ★ OVERALL: ALL PASSED[/bold green]"))
    elif all_pass:
        _add(Text.from_markup(
            "[bold yellow] ★ OVERALL: ALL COMPARED PASSED[/bold yellow]"
            " [dim](some connections failed)[/dim]"))
    else:
        _add(Text.from_markup(
            "[bold red] ★ OVERALL: DIFFERENCES FOUND[/bold red]"))

    return all_pass
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
# ---------------------------------------------------------------------------
|
|
329
|
+
# Report output
|
|
330
|
+
# ---------------------------------------------------------------------------
|
|
331
|
+
|
|
332
|
+
def print_report_file(path: str, label: str = ""):
    """Queue a "report written" line: green check, optional label, bold path.

    Args:
        path: Location of the generated report (interpreted as rich markup).
        label: Optional dimmed text placed before the path.
    """
    prefix = f"[dim]{label}[/dim] " if label else ""
    line = " [green]✓[/green] " + prefix + f"[bold]{path}[/bold]"
    _add(Text.from_markup(line))
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
# ---------------------------------------------------------------------------
|
|
339
|
+
# HTTP server panel
|
|
340
|
+
# ---------------------------------------------------------------------------
|
|
341
|
+
|
|
342
|
+
def print_server_info(url: str, directory: str, *,
                      history_url: str = ""):
    """Print the HTML report server Panel directly to the console.

    Unlike the print_* helpers above, this does not use the buffer: it
    prints a blank line and then the Panel immediately.

    Args:
        url: Address the report server is reachable at.
        directory: Directory being served.
        history_url: Optional address of the history page; omitted if empty.
    """
    history = [f"[dim]History:[/dim] {history_url}"] if history_url else []
    body = "\n".join([
        f"[bold cyan]URL:[/bold cyan] {url}",
        f"[dim]Dir:[/dim] {directory}",
        *history,
        "",
        "Press [bold]Ctrl+C[/bold] to stop",
    ])
    console.print()
    console.print(Panel(body,
                        title="[bold]HTML Report Server[/bold]",
                        border_style=_BOX, expand=False))
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# ---------------------------------------------------------------------------
|
|
360
|
+
# Benchmark progress & summary
|
|
361
|
+
# ---------------------------------------------------------------------------
|
|
362
|
+
|
|
363
|
+
class BenchProgress:
    """Live progress display for benchmark execution, one row per DBMS.

    Use as a context manager. Like ExecutionProgress, all instances that are
    active at the same time share one rich ``Progress``: the first
    ``__enter__`` creates and starts it, the last ``__exit__`` stops it.

    Two layouts exist:

    * iteration-based (SERIAL mode): the bar counts executions (N/M) and is
      driven by :meth:`advance`;
    * time-based (CONCURRENT mode with ``duration > 0``): the bar tracks
      elapsed seconds against ``duration`` (e.g. 20s/30s) and is driven by
      :meth:`update_time` / :meth:`set_status`.

    The layout is fixed by the instance that creates the shared Progress.
    While the display is live, echo and canonical (line-buffered) input are
    switched off on stdin when it is a TTY, so keystrokes do not interfere
    with the display. The saved terminal settings are restored when the
    display stops.
    """

    # Class-wide state shared by every instance; mutated only under _lock.
    _lock = threading.Lock()
    _shared_progress: Optional[Progress] = None
    _ref_count = 0
    # termios attributes saved before echo was disabled (None = nothing to
    # restore).
    _stdin_old_settings = None

    def __init__(self, dbms_name: str, total_queries: int, iterations: int,
                 is_concurrent: bool = False, duration: float = 0.0):
        """Prepare (but do not display) a progress row.

        Args:
            dbms_name: Label shown at the start of the row.
            total_queries: Number of queries in the workload.
            iterations: Executions per query. In the iteration-based layout
                the bar's total is ``total_queries * iterations``.
            is_concurrent: Whether the workload runs in CONCURRENT mode.
            duration: Planned run length in seconds. Together with
                ``is_concurrent`` it selects the time-based layout, whose
                total is ``int(duration)``.
        """
        self.dbms_name = dbms_name
        self.is_concurrent = is_concurrent
        self.duration = duration
        if self._time_based:
            self.total = int(duration)  # seconds for time-based progress
        else:
            self.total = total_queries * iterations
        self._task_id = None
        self._completed = 0
        self._start_time = 0.0
        self._elapsed = 0.0

    @property
    def _time_based(self) -> bool:
        """True when this row tracks elapsed seconds instead of executions."""
        return self.is_concurrent and self.duration > 0

    # -- shared Progress lifecycle ------------------------------------------

    @staticmethod
    def _new_progress(is_concurrent: bool) -> Progress:
        """Build (without starting) the shared Progress for a layout."""
        if is_concurrent:
            # Time-based: "<elapsed>s / <total>s"
            middle = [
                TextColumn("[cyan]{task.fields[elapsed_s]}s[/cyan]"),
                TextColumn("[dim]/[/dim]"),
                TextColumn("[cyan]{task.fields[total_s]}s[/cyan]"),
            ]
        else:
            # Iteration-based: "N/M | elapsed"
            middle = [
                MofNCompleteColumn(),
                TextColumn("[dim]|[/dim]"),
                TimeElapsedColumn(),
            ]
        return Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.fields[dbms]}[/bold blue]"),
            BarColumn(bar_width=40),
            *middle,
            TextColumn("[dim]|[/dim]"),
            TextColumn("{task.fields[status]}"),
            console=console,
            transient=True,
        )

    @classmethod
    def _disable_stdin_echo(cls):
        """Turn off echo and canonical mode on a TTY stdin (best effort).

        The previous settings are saved in ``_stdin_old_settings``. Any
        failure, such as no ``termios`` module on Windows, leaves the
        terminal untouched.
        """
        try:
            import sys
            import termios
            if not sys.stdin.isatty():
                return
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            new_settings = termios.tcgetattr(fd)
            # Index 3 of the tcgetattr list is the local-mode flags (lflag).
            new_settings[3] &= ~(termios.ECHO | termios.ICANON)
            termios.tcsetattr(fd, termios.TCSANOW, new_settings)
            cls._stdin_old_settings = old_settings
        except Exception:
            pass

    @classmethod
    def _restore_stdin(cls):
        """Re-apply the settings saved by _disable_stdin_echo (best effort)."""
        old_settings = cls._stdin_old_settings
        if old_settings is None:
            return
        try:
            import sys
            import termios
            if sys.stdin.isatty():
                termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW,
                                  old_settings)
                cls._stdin_old_settings = None
        except Exception:
            pass

    @classmethod
    def _acquire(cls, is_concurrent: bool = False) -> Progress:
        """Register one more user of the shared Progress and return it.

        Args:
            is_concurrent: Selects the time-based layout. It is only used
                when this call creates the shared Progress.
        """
        with cls._lock:
            if cls._shared_progress is None:
                cls._shared_progress = cls._new_progress(is_concurrent)
                cls._shared_progress.start()
                cls._disable_stdin_echo()
            cls._ref_count += 1
            return cls._shared_progress

    @classmethod
    def _release(cls):
        """Unregister one user; tear the display down after the last one."""
        with cls._lock:
            cls._ref_count -= 1
            if cls._ref_count > 0:
                return
            try:
                if cls._shared_progress is not None:
                    cls._shared_progress.stop()
            finally:
                # Even if stop() fails, restore the terminal: never leave
                # stdin with echo disabled.
                cls._shared_progress = None
                cls._ref_count = 0
                cls._restore_stdin()
                # Ensure cursor is visible after progress bar stops
                try:
                    console.show_cursor()
                except Exception:
                    pass

    # -- context manager ----------------------------------------------------

    def __enter__(self):
        self._start_time = time.monotonic()
        # Pass the same predicate that selects the task fields below, so a
        # time-based layout is never created for a task lacking
        # elapsed_s/total_s (which would raise KeyError when rendering).
        progress = self._acquire(is_concurrent=self._time_based)
        try:
            if self._time_based:
                self._task_id = progress.add_task(
                    "bench", total=self.total,
                    dbms=self.dbms_name, status="[dim]setup...[/dim]",
                    elapsed_s=0, total_s=int(self.duration),
                )
            else:
                self._task_id = progress.add_task(
                    "bench", total=self.total,
                    dbms=self.dbms_name, status="[dim]warmup[/dim]",
                    # Unused by the iteration layout. Present so the row
                    # still renders if it joins a time-based display.
                    elapsed_s=0, total_s=0,
                )
        except BaseException:
            # __exit__ does not run when __enter__ fails; undo the acquire so
            # the display and terminal settings are not leaked.
            self._release()
            raise
        return self

    def reset_timer(self):
        """Restart the elapsed clock (e.g. after setup in time-based mode)."""
        self._start_time = time.monotonic()

    def __exit__(self, *args):
        self._elapsed = time.monotonic() - self._start_time
        self._release()

    def _update(self, **changes):
        """Forward *changes* to this row's task while the display is live."""
        progress = type(self)._shared_progress
        if progress is not None:
            progress.update(self._task_id, **changes)

    def advance(self, query_name: str = "", iteration: int = 0,
                total: int = 0, is_warmup: bool = False):
        """Count one finished execution (iteration-based mode).

        The bar advances by one per call. The status column shows "warmup"
        for warmup runs and *query_name* otherwise.

        Args:
            query_name: Query that was just executed.
            iteration: Per-query iteration number. Accepted for API
                compatibility; not currently displayed.
            total: Per-query iteration count. Accepted for API
                compatibility; not currently displayed.
            is_warmup: Whether this execution was a warmup run.

        NOTE(review): the bar's total is ``total_queries * iterations``
        (see __init__). If warmup runs are also reported here, the count
        can exceed that total. Confirm with callers.
        """
        self._completed += 1
        status = "[dim]warmup[/dim]" if is_warmup else f"{query_name}"
        self._update(advance=1, status=status)

    def update_time(self, status: str = ""):
        """Move the bar to the whole seconds elapsed since the timer start.

        Args:
            status: New status text. An empty string keeps the current one.
        """
        elapsed_int = int(time.monotonic() - self._start_time)
        changes = {"completed": elapsed_int, "elapsed_s": elapsed_int}
        if status:
            changes["status"] = status
        self._update(**changes)

    def set_status(self, text: str):
        """Set the status text.

        In time-based mode the elapsed time is refreshed too. If *text*
        contains "setup" (case-insensitive), elapsed is pinned to 0 so that
        setup time does not count toward the benchmark duration.

        Args:
            text: Status text to display (rich markup).
        """
        if not self._time_based:
            self._update(status=text)
            return
        if "setup" in text.lower():
            elapsed_int = 0
        else:
            elapsed_int = int(time.monotonic() - self._start_time)
        self._update(status=text, elapsed_s=elapsed_int, completed=elapsed_int)

    def write_summary_to_buffer(self):
        """Queue a one-line summary (call after ``__exit__``).

        The query count is ``_completed``, which only grows through
        :meth:`advance`.
        """
        elapsed = f"{self._elapsed:.1f}s"
        _add(Text.from_markup(
            f" [bold blue]{self.dbms_name}[/bold blue] "
            f"{self._completed} queries "
            f"[dim]{elapsed}[/dim] [green]done[/green]"
        ))
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def print_bench_summary(result):
    """Buffer a rich summary of a benchmark run.

    The buffer receives, in order:
    1. a "Benchmark Summary" rule;
    2. one line with the run configuration (fields depend on the mode);
    3. the profiling status;
    4. a per-DBMS totals table;
    5. when two or more DBMS were benchmarked, one latency table per query
       comparing them.

    Args:
        result: BenchmarkResult instance.
    """
    from .models import WorkloadMode  # avoid circular import at module level

    _add(Text(""))
    _add(Rule(Text.from_markup(
        "[bold white]Benchmark Summary[/bold white]"), style=_BOX))
    _add(Text(""))

    # Config info: which settings are shown depends on the workload mode.
    cfg = result.config
    mode_str = result.mode.name
    config_parts = [f"Mode: [cyan]{mode_str}[/cyan]"]
    if result.mode == WorkloadMode.CONCURRENT:
        config_parts.append(f"Concurrency: [cyan]{cfg.concurrency}[/cyan]")
        config_parts.append(f"Duration: [cyan]{cfg.duration}s[/cyan]")
        if cfg.ramp_up > 0:
            config_parts.append(f"Ramp-up: [cyan]{cfg.ramp_up}s[/cyan]")
        if cfg.warmup > 0:
            config_parts.append(f"Warmup: [cyan]{cfg.warmup}[/cyan]")
    else:
        config_parts.append(f"Iterations: [cyan]{cfg.iterations}[/cyan]")
        config_parts.append(f"Warmup: [cyan]{cfg.warmup}[/cyan]")

    _add(Text.from_markup(
        f" Workload: [bold]{result.workload_name}[/bold] "
        + " ".join(config_parts) +
        f" Timestamp: [dim]{result.timestamp}[/dim]"
    ))

    # Profiling status is always shown.
    if getattr(cfg, 'profile', False):
        fg_count = sum(
            1 for dr in result.dbms_results
            for qs in dr.query_stats
            if qs.flamegraph_svg
        )
        _add(Text.from_markup(
            f" Profiling: [bold red]🔥 ON[/bold red] "
            f"[dim]{fg_count} flame graph(s) captured[/dim]"
        ))
    else:
        _add(Text.from_markup(" Profiling: [dim]OFF[/dim]"))

    _add(Text(""))

    # Per-DBMS summary table
    table = Table(
        header_style="bold",
        show_lines=False,
        padding=(0, 1),
        expand=True,
        show_edge=False,
    )
    table.add_column("DBMS", style="bold blue", ratio=2)
    table.add_column("Queries", justify="right")
    table.add_column("Errors", justify="right")
    table.add_column("Duration", justify="right")
    table.add_column("QPS", justify="right", style="green")

    for dr in result.dbms_results:
        err_style = "red bold" if dr.total_errors > 0 else "dim"
        table.add_row(
            dr.dbms_name,
            str(dr.total_queries),
            Text(str(dr.total_errors), style=err_style),
            f"{dr.total_duration_s:.2f}s",
            f"{dr.overall_qps:.1f}",
        )

    _add(table)
    _add(Text(""))

    # Per-query comparison only makes sense with two or more DBMS.
    if len(result.dbms_results) < 2:
        return

    _add(Rule(Text.from_markup(
        "[bold white]Per-Query Comparison[/bold white]"), style=_BOX))
    _add(Text(""))

    # Index each DBMS's stats by query name. setdefault keeps the first
    # entry for a name, matching a first-match linear search.
    per_dbms = []
    for dr in result.dbms_results:
        by_name = {}
        for qs in dr.query_stats:
            by_name.setdefault(qs.query_name, qs)
        per_dbms.append((dr.dbms_name, by_name))

    # Every query name, de-duplicated, in first-seen order across all DBMS.
    all_queries = list(dict.fromkeys(
        name for _, by_name in per_dbms for name in by_name))

    for qname in all_queries:
        qtable = Table(
            title=f"[cyan]{qname}[/cyan]",
            header_style="bold",
            show_lines=False,
            padding=(0, 1),
            expand=True,
            show_edge=False,
        )
        qtable.add_column("DBMS", style="blue", ratio=2)
        qtable.add_column("Avg(ms)", justify="right")
        qtable.add_column("P50", justify="right")
        qtable.add_column("P95", justify="right")
        qtable.add_column("P99", justify="right")
        qtable.add_column("QPS", justify="right", style="green")

        for dbms_name, by_name in per_dbms:
            qs = by_name.get(qname)
            if qs is None:
                continue  # this DBMS did not run the query
            qtable.add_row(
                dbms_name,
                f"{qs.avg_ms:.2f}",
                f"{qs.p50_ms:.2f}",
                f"{qs.p95_ms:.2f}",
                f"{qs.p99_ms:.2f}",
                f"{qs.qps:.1f}",
            )

        _add(qtable)
        _add(Text(""))
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
# ---------------------------------------------------------------------------
|
|
722
|
+
# Logging handler that uses rich
|
|
723
|
+
# ---------------------------------------------------------------------------
|
|
724
|
+
|
|
725
|
+
class RichLogHandler(logging.Handler):
    """Logging handler that routes WARNING+ records into the output buffer.

    ERROR and above go through print_error() and WARNING goes through
    print_warning(). Lower levels are dropped. Nothing is written to the
    console here; the lines appear when flush_all() runs.
    """

    def emit(self, record):
        # Filter first so dropped records are never formatted.
        if record.levelno < logging.WARNING:
            return
        try:
            from rich.markup import escape

            # Log text is plain text, not markup. Escape it so bracketed
            # text in messages (e.g. "[b]" or "[/x]" inside SQL or driver
            # errors) is shown literally. Otherwise it is swallowed as a
            # tag or raises MarkupError and the record is lost.
            msg = escape(self.format(record))
            if record.levelno >= logging.ERROR:
                print_error(msg)
            else:
                print_warning(msg)
        except Exception:
            self.handleError(record)
|