rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/ui.py ADDED
@@ -0,0 +1,736 @@
1
+ """Rich terminal UI for Rosetta."""
2
+
3
+ import logging
4
+ import threading
5
+ import time
6
+ from typing import Dict, List, Optional
7
+
8
+ from rich.console import Console, Group
9
+ from rich.panel import Panel
10
+ from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
11
+ SpinnerColumn, TextColumn, TimeElapsedColumn,
12
+ TimeRemainingColumn)
13
+ from rich.rule import Rule
14
+ from rich.table import Table
15
+ from rich.text import Text
16
+
17
+ from .models import CompareResult
18
+
19
# Module-wide console bound to stderr. It prints the final buffered panel and
# is also handed to the live progress bars via ``console=console``.
console = Console(stderr=True)

# Border / rule colour shared by every panel and rule built in this module.
_BOX = "cyan"

# Logger for the "rosetta" namespace.
_log = logging.getLogger("rosetta")

# Pending rich renderables: _add() appends, flush_all() prints them inside a
# single Panel and then empties the list.
_renderables: List = []
27
+
28
+
29
def _add(renderable):
    """Queue *renderable* so that the next flush_all() call includes it."""
    _renderables.extend((renderable,))
32
+
33
+
34
def flush_all(title: str = "Rosetta"):
    """Print every queued renderable inside one titled Panel, then empty the queue.

    Nothing is printed when the queue is empty.

    Args:
        title: Panel title; it is wrapped in bold markup.
    """
    if _renderables:
        panel = Panel(
            Group(*_renderables),
            title=f"[bold]{title}[/bold]",
            border_style=_BOX,
            expand=True,
            padding=(0, 1),
        )
        console.print(panel)
        _renderables.clear()
48
+
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Banner
52
+ # ---------------------------------------------------------------------------
53
+
54
# Rich-markup source of the start-up banner (queued by print_banner()):
# ASCII-art "Rosetta" lettering in bold cyan followed by a dim tagline.
# Adjacent string literals are concatenated at compile time; the art lines
# are raw strings so their backslashes are kept literally.
# NOTE(review): the art's alignment depends on the exact runs of spaces in
# each raw string; re-check the rendered banner after any edit.
BANNER_TEXT = (
    "[bold cyan]"
    r" ____ _ _" "\n"
    r" | _ \ ___ ___ ___| |_| |_ __ _" "\n"
    r" | |_) / _ \/ __|/ _ \ __| __/ _` |" "\n"
    r" | _ < (_) \__ \ __/ |_| || (_| |" "\n"
    r" |_| \_\___/|___/\___|\__|\__\__,_|"
    "[/bold cyan]\n"
    "[dim]Cross-DBMS SQL Behavioral Consistency Verification[/dim]\n"
)
64
+
65
+
66
def print_banner():
    """Queue the Rosetta ASCII-art banner for the next flush_all()."""
    banner = Text.from_markup(BANNER_TEXT)
    _add(banner)
69
+
70
+
71
+ # ---------------------------------------------------------------------------
72
+ # Phase headers
73
+ # ---------------------------------------------------------------------------
74
+
75
def print_phase(title: str, detail: str = ""):
    """Queue a blank spacer line and a horizontal rule headed by *title*.

    Args:
        title: Phase name, shown in bold white.
        detail: Optional extra text appended to the heading in dim style.
    """
    suffix = f" [dim]{detail}[/dim]" if detail else ""
    _add(Text(""))
    heading = Text.from_markup(f"[bold white]{title}[/bold white]{suffix}")
    _add(Rule(heading, style=_BOX))
82
+
83
+
84
+ # ---------------------------------------------------------------------------
85
+ # Info messages
86
+ # ---------------------------------------------------------------------------
87
+
88
def print_info(msg: str, highlight: str = ""):
    """Queue an informational "> msg" line.

    Args:
        msg: Message text (parsed as rich markup).
        highlight: Optional trailing text rendered in bold.
    """
    tail = f" [bold]{highlight}[/bold]" if highlight else ""
    _add(Text.from_markup(f" [cyan]>[/cyan] {msg}{tail}"))
94
+
95
+
96
def print_success(msg: str):
    """Queue *msg* (rich markup) prefixed with a green check mark."""
    line = f" [green]✓[/green] {msg}"
    _add(Text.from_markup(line))
99
+
100
+
101
def print_warning(msg: str):
    """Queue *msg* prefixed with a yellow warning sign.

    Args:
        msg: Message text; it is parsed as rich markup, so bracketed tags in
            it are applied as styles.
    """
    # U+26A0 WARNING SIGN, written as an escape so the source stays
    # ASCII-safe. The previous literal had been mangled into the mojibake
    # "âš " (its UTF-8 bytes decoded as cp1252).
    _add(Text.from_markup(f" [yellow]\u26a0[/yellow] {msg}"))
104
+
105
+
106
def print_error(msg: str):
    """Queue *msg* (rich markup) prefixed with a red cross mark."""
    marker = "[red]✗[/red]"
    _add(Text.from_markup(f" {marker} {msg}"))
109
+
110
+
111
+ # ---------------------------------------------------------------------------
112
+ # Execution progress bar
113
+ # ---------------------------------------------------------------------------
114
+
115
class ExecutionProgress:
    """Context manager for a DBMS execution progress bar.

    Multiple instances share a single rich Progress bar so that parallel
    DBMS executions are displayed simultaneously. The shared Progress is
    created on the first ``__enter__`` and stopped when the last instance
    exits. A class-level reference count, guarded by ``_lock``, tracks the
    active instances. Each instance contributes one task to the bar.
    """

    _lock = threading.Lock()
    _shared_progress: Optional[Progress] = None
    _ref_count = 0

    def __init__(self, dbms_name: str, total: int):
        self.dbms_name = dbms_name
        self.total = total
        self._task_id = None
        # The shared Progress this instance's task was added to (set in
        # __enter__). Task ids are only meaningful within the Progress that
        # issued them, so updates go through it only while it is still the
        # current shared Progress.
        self._progress: Optional[Progress] = None
        self._errors = 0
        self._executed = 0
        self._elapsed = 0.0
        self._start_time = 0.0

    # -- shared Progress lifecycle ------------------------------------------

    @classmethod
    def _acquire(cls) -> Progress:
        """Return the shared Progress, creating and starting it if needed."""
        with cls._lock:
            if cls._shared_progress is None:
                progress = Progress(
                    SpinnerColumn(),
                    TextColumn("[bold blue]{task.fields[dbms]}[/bold blue]"),
                    BarColumn(bar_width=40),
                    MofNCompleteColumn(),
                    TextColumn("[dim]|[/dim]"),
                    TimeElapsedColumn(),
                    TextColumn("[dim]|[/dim]"),
                    TimeRemainingColumn(),
                    TextColumn("{task.fields[status]}"),
                    console=console,
                    transient=True,
                )
                progress.start()
                # Publish only after start() succeeded, so a failed start
                # never leaves an un-started Progress for other instances.
                cls._shared_progress = progress
            cls._ref_count += 1
            return cls._shared_progress

    @classmethod
    def _release(cls):
        """Drop one reference; stop the shared Progress when none remain."""
        with cls._lock:
            cls._ref_count -= 1
            if cls._ref_count > 0:
                return
            cls._ref_count = 0
            progress = cls._shared_progress
            cls._shared_progress = None
            try:
                if progress is not None:
                    progress.stop()
            finally:
                # Ensure cursor is visible after progress bar stops
                # (best effort: never let this mask the real outcome).
                try:
                    console.show_cursor()
                except Exception:
                    pass

    def _live_progress(self) -> Optional[Progress]:
        """Return the Progress holding this task, or None if it was stopped."""
        progress = self._progress
        if progress is not None and progress is type(self)._shared_progress:
            return progress
        return None

    # -- context manager ----------------------------------------------------

    def __enter__(self):
        self._start_time = time.monotonic()
        progress = self._acquire()
        try:
            task_id = progress.add_task(
                "exec", total=self.total,
                dbms=self.dbms_name, status="",
            )
        except BaseException:
            # __exit__ is not called when __enter__ raises; give the
            # reference back here or the shared bar would never stop.
            self._release()
            raise
        self._task_id = task_id
        self._progress = progress
        return self

    def __exit__(self, *args):
        self._elapsed = time.monotonic() - self._start_time
        self._release()

    def advance(self, error: bool = False):
        """Advance progress by 1.

        Args:
            error: Count this step as an error; the status column then
                shows the running error count instead of "ok".
        """
        if error:
            self._errors += 1
        self._executed += 1
        status = (f"[red]{self._errors} err[/red]"
                  if self._errors else "[green]ok[/green]")
        progress = self._live_progress()
        if progress is not None:
            progress.update(self._task_id, advance=1, status=status)

    def set_status(self, text: str):
        """Set a custom status text (rich markup) for this task."""
        progress = self._live_progress()
        if progress is not None:
            progress.update(self._task_id, status=text)

    def write_summary_to_buffer(self):
        """Write a static one-line summary into the buffer (call after exit)."""
        elapsed = f"{self._elapsed:.1f}s"
        if self._errors:
            status = f"[yellow]{self._executed} done, {self._errors} err[/yellow]"
        else:
            status = f"[green]{self._executed} done[/green]"
        _add(Text.from_markup(
            f" [bold blue]{self.dbms_name}[/bold blue] "
            f"{self._executed}/{self.total} "
            f"[dim]{elapsed}[/dim] {status}"
        ))
219
+
220
+
221
+ # ---------------------------------------------------------------------------
222
+ # Summary table
223
+ # ---------------------------------------------------------------------------
224
+
225
def print_summary(comparisons: Dict[str, CompareResult],
                  failed_connections: Optional[set] = None):
    """Buffer a rich summary table of comparison results.

    Args:
        comparisons: Mapping of comparison label to its CompareResult. One
            table row is queued per entry, in mapping order.
        failed_connections: Names of connections that could not be
            established. Each gets a "SKIP" row; these rows are sorted by
            name so the output is stable between runs.

    Returns:
        True when no comparison has effective mismatches. Failed
        connections do not change the return value; they only change the
        wording of the overall verdict line.
    """
    _add(Text(""))
    _add(Rule(Text.from_markup("[bold white]Summary[/bold white]"), style=_BOX))
    _add(Text(""))

    # Optional columns are shown only when some comparison needs them.
    has_wl = any(cmp.whitelisted > 0 for cmp in comparisons.values())
    has_bug = any(cmp.bug_marked > 0 for cmp in comparisons.values())

    table = Table(
        header_style="bold",
        show_lines=False,
        padding=(0, 1),
        expand=True,
        show_edge=False,
    )

    table.add_column("Comparison", style="white", ratio=3)
    table.add_column("Status", justify="center", min_width=6)
    table.add_column("Match", justify="right", style="green")
    table.add_column("Mismatch", justify="right")
    if has_wl:
        table.add_column("Whitelist", justify="right")
    if has_bug:
        table.add_column("Bug", justify="right")
    table.add_column("Skip", justify="right", style="dim")
    table.add_column("Total", justify="right")
    table.add_column("Rate", justify="right", min_width=14)

    def count_cell(value, style):
        """Render a non-zero count in *style*, and a zero count dimmed."""
        if value > 0:
            return Text(str(value), style=style)
        return Text("0", style="dim")

    # Sorted: iterating a set directly yields an arbitrary row order.
    for name in sorted(failed_connections or ()):
        cols = [name, "[yellow]SKIP[/yellow]", "-", "-"]
        if has_wl:
            cols.append("-")
        if has_bug:
            cols.append("-")
        cols += ["-", "-", "[dim]conn failed[/dim]"]
        table.add_row(*cols)

    bar_len = 8
    # U+2588 FULL BLOCK / U+2591 LIGHT SHADE, as escapes so the source stays
    # ASCII-safe (the previous literals had been mangled into mojibake).
    full_cell, empty_cell = "\u2588", "\u2591"

    all_pass = True
    for key, cmp in comparisons.items():
        effective_mismatch = cmp.effective_mismatched
        is_pass = effective_mismatch <= 0
        if not is_pass:
            all_pass = False

        status = ("[bold green]PASS[/bold green]" if is_pass
                  else "[bold red]FAIL[/bold red]")
        mismatch_style = "red bold" if effective_mismatch > 0 else "dim"
        rate = cmp.pass_rate
        rate_color = ("green" if rate >= 100
                      else "yellow" if rate >= 90
                      else "red")

        # Clamp so an out-of-range rate cannot overflow or underflow the bar.
        filled = max(0, min(bar_len, int(rate / 100 * bar_len)))
        cells = full_cell * filled + empty_cell * (bar_len - filled)
        bar = f"[{rate_color}]{cells}[/{rate_color}] {rate:.1f}%"

        cols = [
            key, status,
            str(cmp.matched),
            Text(str(effective_mismatch if effective_mismatch > 0 else 0),
                 style=mismatch_style),
        ]
        if has_wl:
            cols.append(count_cell(cmp.whitelisted, "yellow"))
        if has_bug:
            cols.append(count_cell(cmp.bug_marked, "red"))
        cols += [
            str(cmp.skipped),
            str(cmp.total_stmts),
            bar,
        ]
        table.add_row(*cols)

    _add(table)

    # Overall verdict
    _add(Text(""))
    if all_pass and not failed_connections:
        _add(Text.from_markup("[bold green] ★ OVERALL: ALL PASSED[/bold green]"))
    elif all_pass:
        _add(Text.from_markup(
            "[bold yellow] ★ OVERALL: ALL COMPARED PASSED[/bold yellow]"
            " [dim](some connections failed)[/dim]"))
    else:
        _add(Text.from_markup(
            "[bold red] ★ OVERALL: DIFFERENCES FOUND[/bold red]"))

    return all_pass
326
+
327
+
328
+ # ---------------------------------------------------------------------------
329
+ # Report output
330
+ # ---------------------------------------------------------------------------
331
+
332
def print_report_file(path: str, label: str = ""):
    """Queue a green-checked line naming a generated report file.

    Args:
        path: Report path, shown in bold.
        label: Optional dim label placed before the path.
    """
    parts = [" [green]✓[/green] "]
    if label:
        parts.append(f"[dim]{label}[/dim] ")
    parts.append(f"[bold]{path}[/bold]")
    _add(Text.from_markup("".join(parts)))
336
+
337
+
338
+ # ---------------------------------------------------------------------------
339
+ # HTTP server panel
340
+ # ---------------------------------------------------------------------------
341
+
342
def print_server_info(url: str, directory: str, *,
                      history_url: str = ""):
    """Print the HTML report server panel directly to the console.

    This bypasses the buffer. It prints a blank line and then a
    non-expanding Panel listing the URL, the served directory, the optional
    history URL and a Ctrl+C hint.
    """
    history = [f"[dim]History:[/dim] {history_url}"] if history_url else []
    body = "\n".join([
        f"[bold cyan]URL:[/bold cyan] {url}",
        f"[dim]Dir:[/dim] {directory}",
        *history,
        "",
        "Press [bold]Ctrl+C[/bold] to stop",
    ])
    console.print()
    console.print(Panel(body,
                        title="[bold]HTML Report Server[/bold]",
                        border_style=_BOX, expand=False))
357
+
358
+
359
+ # ---------------------------------------------------------------------------
360
+ # Benchmark progress & summary
361
+ # ---------------------------------------------------------------------------
362
+
363
class BenchProgress:
    """Context manager for benchmark execution progress.

    Shows a live progress bar per DBMS during benchmark execution. Active
    instances share a single rich Progress through a class-level reference
    count (the same approach as ExecutionProgress).

    For SERIAL mode: shows iteration count (N/M)
    For CONCURRENT mode: shows time progress (20s/30s)

    A concurrent instance without a positive ``duration`` is count-based
    and is displayed like SERIAL mode.

    While the shared bar is live and stdin is a POSIX terminal, echo and
    canonical mode are switched off so typed input does not disturb the
    display. The previous settings are restored when the last instance exits.

    NOTE(review): the column layout is chosen by whichever instance creates
    the shared Progress. Running time-based and count-based instances at the
    same time is not supported: count-based tasks lack the ``elapsed_s`` /
    ``total_s`` fields that the time-based layout renders.
    """

    _lock = threading.Lock()
    _shared_progress: Optional[Progress] = None
    _ref_count = 0
    # termios attributes of stdin saved when echo was disabled; None otherwise.
    _stdin_old_settings = None

    def __init__(self, dbms_name: str, total_queries: int, iterations: int,
                 is_concurrent: bool = False, duration: float = 0.0):
        self.dbms_name = dbms_name
        self.is_concurrent = is_concurrent
        if is_concurrent and duration > 0:
            self.total = int(duration)  # seconds for time-based progress
        else:
            self.total = total_queries * iterations
        self.duration = duration
        self._task_id = None
        # Progress this instance's task lives in (see _live_progress()).
        self._progress: Optional[Progress] = None
        self._completed = 0
        self._start_time = 0.0
        self._elapsed = 0.0

    def _is_time_based(self) -> bool:
        """True when progress is measured in seconds of a fixed duration."""
        return bool(self.is_concurrent and self.duration > 0)

    # -- shared Progress lifecycle ------------------------------------------

    @staticmethod
    def _make_progress(is_concurrent: bool) -> Progress:
        """Build (but do not start) the shared Progress for the given layout."""
        if is_concurrent:
            # Time-based progress for concurrent mode
            columns = (
                SpinnerColumn(),
                TextColumn("[bold blue]{task.fields[dbms]}[/bold blue]"),
                BarColumn(bar_width=40),
                TextColumn("[cyan]{task.fields[elapsed_s]}s[/cyan]"),
                TextColumn("[dim]/[/dim]"),
                TextColumn("[cyan]{task.fields[total_s]}s[/cyan]"),
                TextColumn("[dim]|[/dim]"),
                TextColumn("{task.fields[status]}"),
            )
        else:
            # Iteration-based progress for serial mode
            columns = (
                SpinnerColumn(),
                TextColumn("[bold blue]{task.fields[dbms]}[/bold blue]"),
                BarColumn(bar_width=40),
                MofNCompleteColumn(),
                TextColumn("[dim]|[/dim]"),
                TimeElapsedColumn(),
                TextColumn("[dim]|[/dim]"),
                TextColumn("{task.fields[status]}"),
            )
        return Progress(*columns, console=console, transient=True)

    @classmethod
    def _disable_stdin_echo(cls):
        """Turn off echo and canonical mode on a terminal stdin (best effort).

        The previous attributes are saved in ``cls._stdin_old_settings`` for
        _restore_stdin(). Does nothing when termios is unavailable (e.g.
        Windows) or stdin is not a terminal.
        """
        try:
            import sys
            import termios
            if not sys.stdin.isatty():
                return
            fd = sys.stdin.fileno()
            old_settings = termios.tcgetattr(fd)
            new_settings = termios.tcgetattr(fd)
            # Index 3 of the tcgetattr() list is the local-modes (lflag) word.
            new_settings[3] = new_settings[3] & ~termios.ECHO & ~termios.ICANON
            termios.tcsetattr(fd, termios.TCSANOW, new_settings)
            cls._stdin_old_settings = old_settings
        except Exception:
            # Deliberately best effort: terminal tweaks must never break a run.
            pass

    @classmethod
    def _restore_stdin(cls):
        """Restore the stdin attributes saved by _disable_stdin_echo()."""
        old_settings = cls._stdin_old_settings
        if old_settings is None:
            return
        # Forget the saved state even if restoring fails, so it never goes stale.
        cls._stdin_old_settings = None
        try:
            import sys
            import termios
            if sys.stdin.isatty():
                termios.tcsetattr(sys.stdin.fileno(), termios.TCSANOW,
                                  old_settings)
        except Exception:
            pass

    @classmethod
    def _acquire(cls, is_concurrent: bool = False) -> Progress:
        """Return the shared Progress, creating and starting it if needed.

        Args:
            is_concurrent: Use the time-based column layout when the
                Progress has to be created.
        """
        with cls._lock:
            if cls._shared_progress is None:
                progress = cls._make_progress(is_concurrent)
                progress.start()
                # Publish only after start() succeeded.
                cls._shared_progress = progress
                # Disable stdin echo and line buffering to prevent user input
                # from interfering with progress display
                cls._disable_stdin_echo()
            cls._ref_count += 1
            return cls._shared_progress

    @classmethod
    def _release(cls):
        """Drop one reference; stop the bar and restore stdin when none remain."""
        with cls._lock:
            cls._ref_count -= 1
            if cls._ref_count > 0:
                return
            cls._ref_count = 0
            progress = cls._shared_progress
            cls._shared_progress = None
            try:
                if progress is not None:
                    progress.stop()
            finally:
                cls._restore_stdin()
                # Ensure cursor is visible after progress bar stops
                try:
                    console.show_cursor()
                except Exception:
                    pass

    def _live_progress(self) -> Optional[Progress]:
        """Return the Progress holding this task, or None if it was stopped.

        Task ids are only meaningful within the Progress that issued them,
        so a stale id must never be applied to a newer shared Progress.
        """
        progress = self._progress
        if progress is not None and progress is type(self)._shared_progress:
            return progress
        return None

    # -- context manager ----------------------------------------------------

    def __enter__(self):
        self._start_time = time.monotonic()
        time_based = self._is_time_based()
        # Pick the layout from the kind of task added below. A concurrent run
        # without a positive duration is count-based, and the time-based
        # layout could not render its task (no elapsed_s/total_s fields).
        progress = self._acquire(is_concurrent=time_based)
        try:
            if time_based:
                task_id = progress.add_task(
                    "bench", total=self.total,
                    dbms=self.dbms_name, status="[dim]setup...[/dim]",
                    elapsed_s=0, total_s=int(self.duration),
                )
            else:
                task_id = progress.add_task(
                    "bench", total=self.total,
                    dbms=self.dbms_name, status="[dim]warmup[/dim]",
                )
        except BaseException:
            # __exit__ is not called when __enter__ raises; give the
            # reference back so the bar stops and stdin is restored.
            self._release()
            raise
        self._task_id = task_id
        self._progress = progress
        return self

    def reset_timer(self):
        """Reset the start time for concurrent mode (call after setup)."""
        self._start_time = time.monotonic()

    def __exit__(self, *args):
        self._elapsed = time.monotonic() - self._start_time
        self._release()

    def advance(self, query_name: str = "", iteration: int = 0,
                total: int = 0, is_warmup: bool = False):
        """Advance overall progress by 1 (for serial mode).

        The progress bar tracks the overall test case count (warmup plus
        iterations across all queries). The status text shows either
        "warmup" or the current query name; ``iteration`` and ``total`` are
        accepted for callers but not displayed.

        Args:
            query_name: Current query name
            iteration: Current iteration for this query (1-indexed)
            total: Total iterations for this query
            is_warmup: Whether this is a warmup iteration
        """
        self._completed += 1
        status = "[dim]warmup[/dim]" if is_warmup else f"{query_name}"
        progress = self._live_progress()
        if progress is not None:
            progress.update(self._task_id, advance=1, status=status)

    def update_time(self, status: str = ""):
        """Update progress based on elapsed time (for concurrent mode).

        Args:
            status: Optional status text to display. If empty, keeps current status.
        """
        elapsed_int = int(time.monotonic() - self._start_time)
        progress = self._live_progress()
        if progress is None:
            return
        fields = {"completed": elapsed_int, "elapsed_s": elapsed_int}
        # Don't overwrite existing status with empty string
        if status:
            fields["status"] = status
        progress.update(self._task_id, **fields)

    def set_status(self, text: str):
        """Set custom status.

        In time-based mode the elapsed counter is refreshed at the same
        time. It is held at 0 while *text* mentions "setup" (case-insensitive)
        so that setup time does not count toward the benchmark duration.

        Args:
            text: Status text to display.
        """
        progress = self._live_progress()
        if progress is None:
            return
        if not self._is_time_based():
            progress.update(self._task_id, status=text)
            return
        if "setup" in text.lower():
            elapsed_int = 0
        else:
            elapsed_int = int(time.monotonic() - self._start_time)
        progress.update(
            self._task_id,
            status=text,
            elapsed_s=elapsed_int,
            completed=elapsed_int,
        )

    def write_summary_to_buffer(self):
        """Write a one-line summary into the buffer."""
        elapsed = f"{self._elapsed:.1f}s"
        _add(Text.from_markup(
            f" [bold blue]{self.dbms_name}[/bold blue] "
            f"{self._completed} queries "
            f"[dim]{elapsed}[/dim] [green]done[/green]"
        ))
584
+
585
+
586
def print_bench_summary(result):
    """Buffer a rich benchmark summary table.

    Queues, in order:
    - a "Benchmark Summary" rule;
    - a one-line workload/config header;
    - a profiling status line;
    - a per-DBMS totals table;
    - when at least two DBMS results are present, one latency table per
      query name.

    Args:
        result: BenchmarkResult instance.
    """
    from .models import WorkloadMode  # avoid circular at module level

    _add(Text(""))
    _add(Rule(Text.from_markup(
        "[bold white]Benchmark Summary[/bold white]"), style=_BOX))
    _add(Text(""))

    is_concurrent = result.mode == WorkloadMode.CONCURRENT
    _add(Text.from_markup(_bench_config_markup(result, is_concurrent)))
    # Show profiling status (always visible)
    _add(Text.from_markup(_bench_profiling_markup(result)))
    _add(Text(""))

    _add(_bench_dbms_table(result))
    _add(Text(""))

    # Per-query comparison (if multiple DBMS)
    if len(result.dbms_results) >= 2:
        _add(Rule(Text.from_markup(
            "[bold white]Per-Query Comparison[/bold white]"), style=_BOX))
        _add(Text(""))
        for qtable in _bench_query_tables(result):
            _add(qtable)
            _add(Text(""))


def _bench_config_markup(result, is_concurrent: bool) -> str:
    """Return markup for the workload, mode-specific config and timestamp line."""
    cfg = result.config
    mode_str = result.mode.name
    if is_concurrent:
        config_parts = [
            f"Mode: [cyan]{mode_str}[/cyan]",
            f"Concurrency: [cyan]{cfg.concurrency}[/cyan]",
            f"Duration: [cyan]{cfg.duration}s[/cyan]",
        ]
        if cfg.ramp_up > 0:
            config_parts.append(f"Ramp-up: [cyan]{cfg.ramp_up}s[/cyan]")
        if cfg.warmup > 0:
            config_parts.append(f"Warmup: [cyan]{cfg.warmup}[/cyan]")
    else:
        config_parts = [
            f"Mode: [cyan]{mode_str}[/cyan]",
            f"Iterations: [cyan]{cfg.iterations}[/cyan]",
            f"Warmup: [cyan]{cfg.warmup}[/cyan]",
        ]
    return (
        f" Workload: [bold]{result.workload_name}[/bold] "
        + " ".join(config_parts)
        + f" Timestamp: [dim]{result.timestamp}[/dim]"
    )


def _bench_profiling_markup(result) -> str:
    """Return markup for the profiling status line (flame-graph count when on)."""
    if not getattr(result.config, "profile", False):
        return " Profiling: [dim]OFF[/dim]"
    fg_count = sum(
        1 for dr in result.dbms_results
        for qs in dr.query_stats
        if qs.flamegraph_svg
    )
    return (f" Profiling: [bold red]🔥 ON[/bold red] "
            f"[dim]{fg_count} flame graph(s) captured[/dim]")


def _bench_dbms_table(result) -> Table:
    """Build the per-DBMS totals table (queries, errors, duration, QPS)."""
    table = Table(
        header_style="bold",
        show_lines=False,
        padding=(0, 1),
        expand=True,
        show_edge=False,
    )
    table.add_column("DBMS", style="bold blue", ratio=2)
    table.add_column("Queries", justify="right")
    table.add_column("Errors", justify="right")
    table.add_column("Duration", justify="right")
    table.add_column("QPS", justify="right", style="green")

    for dr in result.dbms_results:
        err_style = "red bold" if dr.total_errors > 0 else "dim"
        table.add_row(
            dr.dbms_name,
            str(dr.total_queries),
            Text(str(dr.total_errors), style=err_style),
            f"{dr.total_duration_s:.2f}s",
            f"{dr.overall_qps:.1f}",
        )
    return table


def _bench_query_tables(result):
    """Yield one latency table per distinct query name, in first-seen order."""
    # Per DBMS, the first stats entry for each query name. This matches the
    # former linear first-match search without rescanning the list.
    lookups = []
    for dr in result.dbms_results:
        by_name = {}
        for qs in dr.query_stats:
            by_name.setdefault(qs.query_name, qs)
        lookups.append((dr, by_name))

    # dict.fromkeys keeps first-seen order without O(n^2) list-membership tests.
    query_names = dict.fromkeys(
        qs.query_name
        for dr in result.dbms_results
        for qs in dr.query_stats
    )

    for qname in query_names:
        qtable = Table(
            title=f"[cyan]{qname}[/cyan]",
            header_style="bold",
            show_lines=False,
            padding=(0, 1),
            expand=True,
            show_edge=False,
        )
        qtable.add_column("DBMS", style="blue", ratio=2)
        qtable.add_column("Avg(ms)", justify="right")
        qtable.add_column("P50", justify="right")
        qtable.add_column("P95", justify="right")
        qtable.add_column("P99", justify="right")
        qtable.add_column("QPS", justify="right", style="green")

        for dr, by_name in lookups:
            qs = by_name.get(qname)
            if qs is not None:
                qtable.add_row(
                    dr.dbms_name,
                    f"{qs.avg_ms:.2f}",
                    f"{qs.p50_ms:.2f}",
                    f"{qs.p95_ms:.2f}",
                    f"{qs.p99_ms:.2f}",
                    f"{qs.qps:.1f}",
                )
        yield qtable
719
+
720
+
721
+ # ---------------------------------------------------------------------------
722
+ # Logging handler that uses rich
723
+ # ---------------------------------------------------------------------------
724
+
725
class RichLogHandler(logging.Handler):
    """Route WARNING-and-above log records into the rich output buffer.

    Each record is formatted with the handler's formatter:
    - ERROR and above is queued via print_error();
    - WARNING is queued via print_warning();
    - records below WARNING are dropped.

    Queued lines are printed by the next flush_all().
    """

    def emit(self, record):
        try:
            if record.levelno < logging.WARNING:
                return
            from rich.markup import escape
            # Log text is plain text, not rich markup. Escape it so bracketed
            # content (e.g. SQL identifiers such as "[dbo].[t]") is shown
            # literally, rather than being taken as style tags (and silently
            # dropped) or raising a MarkupError.
            msg = escape(self.format(record))
            if record.levelno >= logging.ERROR:
                print_error(msg)
            else:
                print_warning(msg)
        except Exception:
            self.handleError(record)