rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/cli/output.py ADDED
@@ -0,0 +1,545 @@
1
+ """
2
+ Output formatter for CLI commands.
3
+
4
+ Provides JSON output by default (AI Agent friendly) and human-readable output
5
+ as an option.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING
9
+
10
+ if TYPE_CHECKING:
11
+ from .result import CommandResult
12
+
13
+
14
class OutputFormatter:
    """
    Format command results for output.

    Supports two output formats:
    - json: Machine-readable JSON (default, AI Agent friendly)
    - human: Human-readable format with colors and tables

    Any format other than "json" selects human-readable output. The human
    renderer uses Rich when it is importable and plain ``print`` otherwise.
    Values taken from command data (names, SQL text, error messages, result
    rows, paths) are escaped before being embedded in Rich markup, so text
    such as "[/x]" or "[field list]" is shown literally instead of being
    parsed as a style tag (which would hide the text or raise MarkupError).
    """

    def __init__(self, format: str = "json"):
        """
        Initialize output formatter.

        Args:
            format: Output format, either "json" or "human". Any value other
                than "json" falls through to human-readable output.
        """
        self.format = format

    def print(self, result: "CommandResult") -> None:
        """
        Print the command result.

        Args:
            result: CommandResult to print
        """
        if self.format == "json":
            self._print_json(result)
        else:
            self._print_human(result)

    # ------------------------------------------------------------------
    # Small formatting helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _esc(value) -> str:
        """Return *value* as text escaped for Rich markup (None -> "")."""
        from rich.markup import escape

        return "" if value is None else escape(str(value))

    @staticmethod
    def _fmt_float(value, precision: int = 2, suffix: str = "") -> str:
        """
        Format a number with fixed precision, e.g. 1.234 -> "1.23".

        Returns "-" when *value* is None or not numeric, so a missing metric
        renders as a placeholder instead of raising inside a format spec.
        """
        try:
            return f"{float(value):.{precision}f}{suffix}"
        except (TypeError, ValueError, OverflowError):
            return "-"

    # ------------------------------------------------------------------
    # Top-level printers
    # ------------------------------------------------------------------

    def _print_json(self, result: "CommandResult") -> None:
        """Print result as JSON (serialization is delegated to result.to_json())."""
        print(result.to_json())

    def _print_human(self, result: "CommandResult") -> None:
        """
        Print result in human-readable format.

        Only the Rich import is guarded: an ImportError raised later while
        rendering must not be mistaken for "rich is missing", which would
        print the result a second time through the plain-text fallback.
        """
        try:
            from rich.console import Console
        except ImportError:
            self._print_plain(result)
            return

        console = Console()
        esc = self._esc

        if result.ok:
            console.print(f"[green]✓[/green] {esc(result.command)}")
            if result.data:
                self._print_data_human(console, result.data)
        elif result.command and result.command != "unknown":
            # Show command if meaningful, then the error on its own line.
            console.print(f"[red]✗[/red] {esc(result.command)}")
            if result.error:
                console.print(f"[red]Error:[/red] {esc(result.error)}")
        else:
            console.print(f"[red]✗[/red] {esc(result.error)}")

    @staticmethod
    def _print_plain(result: "CommandResult") -> None:
        """Plain-text fallback used when Rich is not installed."""
        if result.ok:
            print(f"✓ {result.command}")
            if result.data:
                print(result.data)
        elif result.command and result.command != "unknown":
            print(f"✗ {result.command}")
            if result.error:
                print(f"Error: {result.error}")
        else:
            print(f"✗ {result.error}")

    def _print_data_human(self, console, data: dict) -> None:
        """
        Print data in human-readable format with smart formatting.

        The command type is detected from the keys present in *data*; the
        first matching rule wins, so the order of checks matters.

        Args:
            console: Rich console instance
            data: Data dictionary to print
        """
        import json

        from rich.panel import Panel
        from rich.table import Table

        esc = self._esc

        if "run_id" in data and "report_files" in data:
            # result show (must be checked before generic "dbms" list match)
            self._print_result_show(console, data)
        elif isinstance(data.get("dbms"), list):
            # status dbms or list dbms
            if "connected" in data:  # status dbms
                self._print_dbms_status(console, data)
            else:  # list dbms
                self._print_dbms_list(console, data)
        elif isinstance(data.get("templates"), list):
            # list templates
            self._print_templates(console, data)
        elif isinstance(data.get("runs"), list):
            # result list / history
            self._print_history(console, data)
        elif isinstance(data.get("dbms_results"), list):
            # run bench result
            self._print_bench_result(console, data)
        elif isinstance(data.get("comparisons"), dict):
            # run mtr result
            self._print_mtr_result(console, data)
        elif isinstance(data.get("results"), dict):
            # exec result
            self._print_exec_result(console, data)
        elif isinstance(data.get("databases"), list):
            # config show
            self._print_config_show(console, data)
        elif all(not isinstance(v, (dict, list)) for v in data.values()):
            # Simple key-value data
            table = Table(show_header=False)
            table.add_column("Key", style="cyan")
            table.add_column("Value")
            for key, value in data.items():
                table.add_row(esc(str(key)), esc(str(value)))
            console.print(table)
        else:
            # Fallback: print as formatted dict. default=str keeps display
            # working for values json cannot encode (Decimal, datetime, ...).
            console.print(Panel(
                esc(json.dumps(data, indent=2, ensure_ascii=False, default=str)),
                title="Result Data",
            ))

    # ------------------------------------------------------------------
    # Per-command renderers
    # ------------------------------------------------------------------

    def _print_dbms_status(self, console, data: dict) -> None:
        """Print DBMS connection status, one row per DBMS, summary in title."""
        from rich.table import Table

        esc = self._esc
        dbms_list = data.get("dbms", [])
        if not dbms_list:
            return

        total = esc(data.get("total", 0))
        connected = esc(data.get("connected", 0))
        disconnected = esc(data.get("disconnected", 0))
        table = Table(title=f"Total: {total} Connected: {connected} Disconnected: {disconnected}")
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Host", no_wrap=True)
        table.add_column("Port", justify="right")
        table.add_column("Driver", no_wrap=True)
        table.add_column("Status", justify="center")
        table.add_column("Version")
        table.add_column("Latency", justify="right")

        for db in dbms_list:
            # A reachable port without a connection means authentication failed.
            if db.get("connected"):
                status = "[green]✓ Connected[/green]"
                version = esc(db.get("version", ""))
            elif db.get("port_reachable"):
                status = "[red]✗ Auth Failed[/red]"
                version = ""
            else:
                status = "[red]✗ Unreachable[/red]"
                version = ""

            table.add_row(
                esc(db.get("name", "")),
                esc(db.get("host", "")),
                esc(db.get("port", "")),
                esc(db.get("driver", "")),
                status,
                version,
                # None -> "-", but a genuine 0.0 latency is still shown.
                self._fmt_float(db.get("latency_ms"), suffix="ms"),
            )

        console.print(table)

    def _print_dbms_list(self, console, data: dict) -> None:
        """Print configured DBMS list; versions are shown only for enabled DBMS."""
        from rich.table import Table

        esc = self._esc
        dbms_list = data.get("dbms", [])
        if not dbms_list:
            return

        table = Table(title=f"Total: {esc(data.get('total', 0))} Configured DBMS")
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Host", no_wrap=True)
        table.add_column("Port", justify="right")
        table.add_column("Driver", no_wrap=True)
        table.add_column("Version")
        table.add_column("Enabled", justify="center")

        for db in dbms_list:
            enabled = "[green]✓[/green]" if db.get("enabled") else "[red]✗[/red]"
            raw_version = db.get("version", "") if db.get("enabled") else ""
            version = "" if raw_version is None else str(raw_version)
            # Truncate version if too long (before escaping, so an escape
            # sequence is never cut in half).
            if len(version) > 20:
                version = version[:17] + "..."
            table.add_row(
                esc(db.get("name", "")),
                esc(db.get("host", "")),
                esc(db.get("port", "")),
                esc(db.get("driver", "")),
                esc(version),
                enabled,
            )

        console.print(table)

    def _print_config_show(self, console, data: dict) -> None:
        """Print configuration summary lines followed by a per-database table."""
        from rich.table import Table

        esc = self._esc
        console.print(f"[cyan]Config Path:[/cyan] {esc(data.get('config_path', ''))}")
        console.print(f"[cyan]Total DBMS:[/cyan] {esc(data.get('total_dbms', 0))}")
        console.print(f"[cyan]Enabled DBMS:[/cyan] {esc(data.get('enabled_dbms', 0))}")
        console.print()

        databases = data.get("databases", [])
        if not databases:
            return

        table = Table(title="Database Configurations")
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Host", no_wrap=True)
        table.add_column("Port", justify="right")
        table.add_column("User", no_wrap=True)
        table.add_column("Enabled", justify="center")
        table.add_column("Init SQL", justify="center")
        table.add_column("Skip Patterns", justify="right")

        for db in databases:
            enabled = "[green]✓[/green]" if db.get("enabled") else "[red]✗[/red]"
            has_init = "[green]✓[/green]" if db.get("has_init_sql") else "-"
            table.add_row(
                esc(db.get("name", "")),
                esc(db.get("host", "")),
                esc(db.get("port", "")),
                esc(db.get("user", "")),
                enabled,
                has_init,
                esc(db.get("skip_patterns_count", 0)),
            )

        console.print(table)

    def _print_templates(self, console, data: dict) -> None:
        """Print benchmark templates list (name and description)."""
        from rich.table import Table

        esc = self._esc
        console.print(f"[cyan]Total Templates:[/cyan] {esc(data.get('total', 0))}")
        console.print()

        templates = data.get("templates", [])
        if not templates:
            return

        table = Table()
        table.add_column("Name", style="cyan", no_wrap=True)
        table.add_column("Description")
        for tmpl in templates:
            table.add_row(esc(tmpl.get("name", "")), esc(tmpl.get("description", "")))

        console.print(table)

    def _print_history(self, console, data: dict) -> None:
        """Print execution history (result list) with pagination hints."""
        from rich.table import Table

        esc = self._esc
        total = data.get("total", 0)
        page = data.get("page", 1)
        total_pages = data.get("total_pages", 1)

        runs = data.get("runs", [])
        if not runs:
            console.print("[dim]No runs found.[/dim]")
            return

        title = f"History (page {esc(page)}/{esc(total_pages)}, {esc(total)} total)"
        table = Table(
            title=title,
            show_header=True,
            header_style="bold cyan",
            border_style="dim",
            pad_edge=True,
        )
        table.add_column("#", style="dim", justify="right", no_wrap=True)
        table.add_column("Run ID", style="cyan")
        table.add_column("Type", no_wrap=True)
        table.add_column("DBMS")
        table.add_column("Timestamp", no_wrap=True)

        badges = {
            "bench": "[orange1]bench[/orange1]",
            "mtr": "[green]mtr[/green]",
        }
        for run in runs:
            rtype = run.get("type", "")
            table.add_row(
                esc(run.get("idx", "")),
                esc(run.get("id", "")),
                badges.get(rtype, esc(rtype)),
                esc(run.get("dbms", "")),
                esc(run.get("timestamp", "")),
            )

        console.print(table)
        if total_pages > 1:
            hints = []
            if page < total_pages:
                hints.append(f"-p {page + 1}")
            if page > 1:
                hints.append(f"-p {page - 1}")
            console.print(
                f"[dim]Page {page}/{total_pages}. "
                f"Use {' / '.join(hints)} to navigate.[/dim]"
            )

    def _print_result_show(self, console, data: dict) -> None:
        """Print run details panel, optional bench summary, and report files."""
        import os

        from rich.panel import Panel
        from rich.table import Table

        esc = self._esc
        run_path = data.get("path", "")
        abs_path = os.path.abspath(run_path) if run_path else ""

        # Header info
        info_lines = [
            f"[bold]Run ID[/bold] {esc(data.get('run_id', ''))}",
            f"[bold]Type[/bold] {esc(data.get('type', ''))}",
            f"[bold]Workload[/bold] {esc(data.get('workload', ''))}",
            f"[bold]Timestamp[/bold] {esc(data.get('timestamp', ''))}",
        ]
        dbms_list = data.get("dbms", [])
        if dbms_list:
            info_lines.append(f"[bold]DBMS[/bold] {', '.join(esc(d) for d in dbms_list)}")
        if data.get("mode"):
            info_lines.append(f"[bold]Mode[/bold] {esc(data.get('mode', ''))}")
        info_lines.append(f"[bold]Path[/bold] {esc(abs_path)}")

        console.print(Panel(
            "\n".join(info_lines),
            title="[bold cyan]Run Details[/bold cyan]",
            title_align="left",
            border_style="cyan",
            padding=(0, 1),
        ))

        # Bench summary
        bench_summary = data.get("bench_summary", [])
        if bench_summary:
            console.print()
            table = Table(
                title="[bold]Performance Summary[/bold]",
                title_style="",
                show_header=True, header_style="bold cyan",
                border_style="dim", pad_edge=True,
            )
            table.add_column("DBMS", style="bold", no_wrap=True)
            table.add_column("QPS", justify="right")
            table.add_column("Duration", justify="right")
            table.add_column("Queries", justify="right")
            table.add_column("Errors", justify="right")

            for s in bench_summary:
                errors = s.get("errors", 0)
                errors_str = esc(errors)
                # Only highlight real positive counts; None/non-numeric is not an error.
                if isinstance(errors, (int, float)) and errors > 0:
                    errors_str = f"[red]{errors_str}[/red]"
                table.add_row(
                    esc(s.get("dbms", "")),
                    self._fmt_float(s.get("qps", 0)),
                    self._fmt_float(s.get("duration_s", 0), suffix="s"),
                    esc(s.get("queries", 0)),
                    errors_str,
                )
            console.print(table)

        # Report files (already absolute paths from data)
        report_files = data.get("report_files", [])
        if report_files:
            console.print()
            console.print("[bold]Reports:[/bold]")
            for path in report_files:
                console.print(f" [dim]•[/dim] {esc(path)}")

    def _print_bench_result(self, console, data: dict) -> None:
        """Print benchmark result summary with one row per DBMS."""
        from rich.table import Table

        esc = self._esc
        console.print(f"[cyan]Workload:[/cyan] {esc(data.get('workload', 'unknown'))}")
        console.print(f"[cyan]Mode:[/cyan] {esc(data.get('mode', 'unknown'))}")
        console.print()

        dbms_results = data.get("dbms_results", [])
        if dbms_results:
            table = Table(title="Benchmark Results")
            table.add_column("DBMS", style="cyan", no_wrap=True)
            table.add_column("QPS", justify="right", no_wrap=True)
            table.add_column("Duration", justify="right", no_wrap=True)
            table.add_column("Queries", justify="right")
            table.add_column("Errors", justify="right")

            for dr in dbms_results:
                table.add_row(
                    esc(dr.get("dbms_name", "")),
                    self._fmt_float(dr.get("overall_qps", 0)),
                    self._fmt_float(dr.get("total_duration_s", 0), suffix="s"),
                    esc(dr.get("total_queries", 0)),
                    esc(dr.get("total_errors", 0)),
                )

            console.print(table)

        console.print()
        console.print(f"[dim]Report directory:[/dim] {esc(data.get('report_directory', ''))}")

    def _print_mtr_result(self, console, data: dict) -> None:
        """Print MTR test result summary: one row per pairwise comparison."""
        from rich.table import Table

        esc = self._esc
        dbms_targets = ", ".join(esc(t) for t in data.get("dbms_targets") or [])

        console.print(f"[cyan]Test File:[/cyan] {esc(data.get('test_file', 'unknown'))}")

        comparisons = data.get("comparisons", {})
        if comparisons:
            table = Table(title=f"DBMS Targets: {dbms_targets}")
            table.add_column("Comparison", style="cyan", no_wrap=True)
            table.add_column("Matched", justify="right")
            table.add_column("Mismatched", justify="right")
            table.add_column("Pass Rate", justify="right", no_wrap=True)

            for key, cmp in comparisons.items():
                table.add_row(
                    esc(key),
                    esc(cmp.get("matched", 0)),
                    esc(cmp.get("mismatched", 0)),
                    self._fmt_float(cmp.get("pass_rate", 0), precision=1, suffix="%"),
                )

            console.print(table)

        failed = data.get("failed_connections")
        if failed:
            console.print()
            console.print(f"[red]Failed Connections:[/red] {', '.join(esc(f) for f in failed)}")

        console.print()
        console.print(f"[dim]Report directory:[/dim] {esc(data.get('report_directory', ''))}")

    def _print_exec_result(self, console, data: dict) -> None:
        """
        Print SQL execution result — one column per DBMS.

        Reads ``data["results"]`` as a mapping of DBMS name to a dict holding
        either ``error`` (connection-level failure, printed above the table)
        or ``statements`` (a list of per-statement dicts rendered by
        ``_format_exec_cell``).
        """
        from rich.table import Table

        esc = self._esc
        results = data.get("results", {})

        # Print connection-level errors first
        failed = [name for name, r in results.items() if r.get("error")]
        for name in failed:
            console.print(f"[red]✗ {esc(name)}:[/red] {esc(results[name]['error'])}")
        if failed:
            console.print()

        ok_dbms = [name for name in results if not results[name].get("error")]
        if not ok_dbms:
            return

        # Preserve ok_dbms order; tolerate a missing/None statements list.
        statements = {name: results[name].get("statements") or [] for name in ok_dbms}
        n_stmts = max(len(stmts) for stmts in statements.values())
        if n_stmts == 0:
            return

        # Build table: # | SQL | dbms1 (time) | dbms2 (time) | ...
        table = Table(
            show_header=True,
            header_style="bold cyan",
            border_style="dim",
            pad_edge=True,
            expand=True,
        )
        table.add_column("#", style="bold cyan", no_wrap=True, justify="right", width=3)
        table.add_column("SQL", style="dim", no_wrap=True, max_width=40)
        for name in ok_dbms:
            table.add_column(esc(name), no_wrap=False)

        for si in range(n_stmts):
            # SQL text comes from the first DBMS that has this statement.
            sql_text = next(
                (str(stmts[si].get("sql") or "") for stmts in statements.values() if si < len(stmts)),
                "",
            )
            # Collapse whitespace so multi-line SQL previews as a single line.
            sql_text = " ".join(sql_text.split())
            sql_display = sql_text if len(sql_text) <= 40 else sql_text[:37] + "..."

            cells = [
                self._format_exec_cell(stmts[si] if si < len(stmts) else None)
                for stmts in statements.values()
            ]

            if si > 0:
                table.add_section()
            table.add_row(str(si + 1), esc(sql_display), *cells)

        console.print(table)

    def _format_exec_cell(self, stmt, max_rows: int = 5) -> str:
        """
        Render one statement's outcome for one DBMS as a Rich markup cell.

        Args:
            stmt: Per-statement dict (``error`` / ``columns`` + ``rows`` /
                ``affected_rows``, plus ``elapsed_ms``), or None when this
                DBMS reported fewer statements than another one.
            max_rows: Number of result rows shown before summarizing the rest.

        Returns:
            Markup string; a missing statement renders as a dim "-" rather
            than a fabricated "OK, 0 rows".
        """
        esc = self._esc
        if stmt is None:
            return "[dim]-[/dim]"

        elapsed = self._fmt_float(stmt.get("elapsed_ms", 0), suffix="ms")

        if stmt.get("error"):
            return f"[red]ERROR: {esc(stmt['error'])}[/red]\n[dim]{elapsed}[/dim]"

        cols = stmt.get("columns")
        if not cols:
            return f"[dim]OK, {esc(stmt.get('affected_rows', 0))} rows {elapsed}[/dim]"

        rows = stmt.get("rows") or []
        if not rows:
            return f"[dim]Empty {elapsed}[/dim]"

        lines = []
        for row in rows[:max_rows]:
            values = row if isinstance(row, (list, tuple)) else [row]
            if len(cols) == 1:
                # str() keeps SQL NULL visible as "None" rather than blank.
                lines.append(esc(str(values[0])) if values else "")
            else:
                # zip() tolerates rows shorter than the column list.
                lines.append(", ".join(
                    f"{esc(str(col))}={esc(str(val))}" for col, val in zip(cols, values)
                ))
        if len(rows) > max_rows:
            lines.append(f"[dim]... +{len(rows) - max_rows} rows[/dim]")
        lines.append(f"[dim]{elapsed}[/dim]")
        return "\n".join(lines)
rosetta/cli/result.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ Unified command result structure for all CLI commands.
3
+ """
4
+
5
+ from dataclasses import dataclass, asdict
6
+ from datetime import datetime
7
+ from typing import Any, Dict, Optional
8
+ import json
9
+
10
+
11
@dataclass
class CommandResult:
    """
    Unified result structure for all CLI commands.

    This ensures consistent output format for AI Agent consumption.

    Attributes:
        ok: Whether the command succeeded
        command: The command that was executed (e.g., "run mtr")
        timestamp: ISO format timestamp of execution (local time without a
            UTC offset, as produced by ``datetime.now().isoformat()``)
        data: Command-specific result data
        error: Error message if command failed
    """
    ok: bool
    command: str
    timestamp: str
    data: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    @classmethod
    def success(cls, command: str, data: Optional[Dict[str, Any]] = None) -> "CommandResult":
        """Create a successful result stamped with the current local time."""
        return cls(
            ok=True,
            command=command,
            timestamp=datetime.now().isoformat(),
            data=data,
        )

    @classmethod
    def failure(cls, error: str, command: str = "unknown") -> "CommandResult":
        """
        Create a failed result stamped with the current local time.

        Note the argument order differs from ``success``: the error message
        comes first, and ``command`` defaults to the placeholder "unknown".
        """
        return cls(
            ok=False,
            command=command,
            timestamp=datetime.now().isoformat(),
            error=error,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary (``dataclasses.asdict`` deep-copies ``data``)."""
        return asdict(self)

    def to_json(self) -> str:
        """
        Convert to a pretty-printed JSON string.

        Values the ``json`` module cannot encode natively (e.g. ``Decimal``,
        ``datetime``/``date``, ``bytes``, ``set`` — typical of raw database
        rows) are converted by ``_json_default`` instead of raising
        ``TypeError``, so JSON output never aborts mid-command.
        """
        return json.dumps(
            self.to_dict(),
            indent=2,
            ensure_ascii=False,
            default=self._json_default,
        )

    @staticmethod
    def _json_default(value: Any) -> Any:
        """
        Fallback encoder for values ``json.dumps`` cannot serialize.

        - objects with an ``isoformat()`` method (datetime/date/time) -> ISO string
        - bytes / bytearray -> UTF-8 text (undecodable bytes replaced)
        - set / frozenset -> list
        - anything else (e.g. Decimal) -> ``str(value)``
        """
        isoformat = getattr(value, "isoformat", None)
        if callable(isoformat):
            return isoformat()
        if isinstance(value, (bytes, bytearray)):
            return bytes(value).decode("utf-8", errors="replace")
        if isinstance(value, (set, frozenset)):
            return list(value)
        return str(value)

    def exit_code(self) -> int:
        """Return the process exit code: 0 on success, 1 on failure."""
        return 0 if self.ok else 1