rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/buglist.py ADDED
@@ -0,0 +1,108 @@
1
+ """Bug-mark management for Rosetta.
2
+
3
+ A *bug entry* records the fingerprint of a diff that has been identified as a
4
+ genuine bug. Unlike whitelisted diffs, bug-marked diffs **still count toward
5
+ the failure rate** — the mark is purely informational so that users can track
6
+ known bugs across test runs.
7
+
8
+ The bug list is persisted as a single JSON file (``buglist.json``) in the
9
+ output directory. The fingerprint algorithm is identical to the whitelist
10
+ (MD5 over normalised SQL + output), so the same ``diff_fingerprint`` helper
11
+ is reused.
12
+ """
13
+
14
+ import json
15
+ import logging
16
+ import os
17
+ import time as _time
18
+ from typing import Dict, Optional
19
+
20
+ log = logging.getLogger("rosetta")
21
+
22
+ _BUGLIST_FILE = "buglist.json"
23
+
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Buglist store
27
+ # ---------------------------------------------------------------------------
28
+
29
class Buglist:
    """In-memory bug list backed by a JSON file.

    Structure of ``buglist.json``::

        {
          "<fingerprint>": {
            "stmt": "SELECT ...",
            "dbms_a": "tdsql",
            "dbms_b": "mysql",
            "block": 42,
            "reason": "Known bug #123",
            "added_at": "2026-03-10 18:00:00"
          },
          ...
        }

    Every mutating method (:meth:`add`, :meth:`remove`, :meth:`clear`)
    rewrites the file immediately.
    """

    def __init__(self, output_dir: str):
        """Bind the list to ``buglist.json`` inside *output_dir* and load it."""
        self._path = os.path.join(output_dir, _BUGLIST_FILE)
        self._data: Dict[str, dict] = {}
        self.load()

    # -- persistence --------------------------------------------------------

    def load(self):
        """Replace the in-memory entries with the file's content.

        A missing, unreadable, undecodable, or malformed file yields an
        empty list; every failure except "file does not exist" is logged
        as a warning.
        """
        self._data = {}
        if not os.path.isfile(self._path):
            return
        try:
            with open(self._path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError (file is not valid UTF-8).
            log.warning("Failed to load buglist: %s", e)
            return
        if not isinstance(loaded, dict):
            # Valid JSON, wrong shape (e.g. a list): every other method
            # assumes a mapping, so refuse it here.
            log.warning(
                "Failed to load buglist: %s",
                "top-level JSON value is not an object",
            )
            return
        self._data = loaded

    def save(self):
        """Write the entries to disk, creating the directory if needed.

        The data is written to a sibling temporary file which is then
        moved over the real file, so an interrupted write never leaves a
        truncated ``buglist.json`` behind.
        """
        os.makedirs(os.path.dirname(self._path) or ".", exist_ok=True)
        tmp_path = self._path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, self._path)
        except BaseException:
            # Do not leave the half-written temporary file around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    # -- query / mutate -----------------------------------------------------

    @property
    def entries(self) -> Dict[str, dict]:
        """Shallow copy of the fingerprint -> entry mapping."""
        return dict(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __contains__(self, fingerprint: object) -> bool:
        """Support ``fingerprint in buglist``."""
        return fingerprint in self._data

    def contains(self, fingerprint: str) -> bool:
        """Return True if *fingerprint* is marked as a bug."""
        return fingerprint in self._data

    def add(self, fingerprint: str, stmt: str, dbms_a: str, dbms_b: str,
            block: int = 0, reason: str = "") -> dict:
        """Add (or overwrite) an entry and persist. Returns the stored dict.

        Only the first 300 characters of *stmt* are stored.
        """
        entry = {
            "stmt": (stmt or "")[:300],
            "dbms_a": dbms_a,
            "dbms_b": dbms_b,
            "block": block,
            "reason": reason,
            "added_at": _time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        self._data[fingerprint] = entry
        self.save()
        return entry

    def remove(self, fingerprint: str) -> bool:
        """Remove an entry and persist. Returns True if it existed."""
        if fingerprint in self._data:
            del self._data[fingerprint]
            self.save()
            return True
        return False

    def clear(self):
        """Remove all entries and persist the empty list."""
        self._data.clear()
        self.save()
@@ -0,0 +1,11 @@
1
+ """
2
+ Rosetta CLI - Modern command-line interface for AI Agents and humans.
3
+
4
+ Human-readable output by default; use -j/--json for JSON output.
5
+ """
6
+
7
+ from .main import main
8
+ from .result import CommandResult
9
+ from ..runner import _enter_interactive, parse_args
10
+
11
+ __all__ = ["main", "CommandResult", "_enter_interactive", "parse_args"]
@@ -0,0 +1,243 @@
1
+ """
2
+ Handler for the 'config' subcommand.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import TYPE_CHECKING
8
+
9
+ from .result import CommandResult
10
+
11
+ if TYPE_CHECKING:
12
+ from .output import OutputFormatter
13
+
14
+
15
def handle_config(args, output: "OutputFormatter") -> CommandResult:
    """
    Route the 'config' subcommand to the handler for ``args.action``.

    Args:
        args: Parsed command-line arguments; ``args.action`` selects the
            handler ("show", "validate" or "init")
        output: Output formatter, passed through to the handler

    Returns:
        The chosen handler's CommandResult, or a failure result when the
        action is not recognised
    """
    routes = (
        ("show", _handle_config_show),
        ("validate", _handle_config_validate),
        ("init", _handle_config_init),
    )
    for action_name, handler in routes:
        if args.action == action_name:
            return handler(args, output)

    return CommandResult.failure(
        f"Unknown config action: {args.action}",
    )
36
+
37
+
38
def _handle_config_show(args, output: "OutputFormatter") -> CommandResult:
    """
    Show current configuration.

    Args:
        args: Parsed arguments; ``args.config`` is the config file path
        output: Output formatter (not used by this handler)

    Returns:
        CommandResult with a per-database summary plus the raw JSON
        document, or a failure result when the file is missing or cannot
        be loaded
    """
    from ..config import load_config

    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    try:
        configs = load_config(args.config)
        # Read raw JSON for display. Kept inside the try so that a read or
        # parse failure is reported as a result instead of a traceback.
        with open(args.config, "r", encoding="utf-8") as f:
            raw_config = json.load(f)
    except Exception as e:
        return CommandResult.failure(f"Failed to load config: {str(e)}")

    # NOTE(review): raw_config is echoed verbatim, so anything stored in the
    # file (including credentials, if present) ends up in the output even
    # though the per-database summary omits it — confirm this is intended.
    return CommandResult.success(
        "config show",
        {
            "config_path": os.path.abspath(args.config),
            "total_dbms": len(configs),
            "enabled_dbms": sum(1 for c in configs if c.enabled),
            "databases": [
                {
                    "name": c.name,
                    "host": c.host,
                    "port": c.port,
                    "user": c.user,
                    "driver": c.driver,
                    "enabled": c.enabled,
                    "has_init_sql": bool(c.init_sql),
                    "skip_patterns_count": len(c.skip_patterns),
                }
                for c in configs
            ],
            "raw_config": raw_config,
        },
    )
87
+
88
+
89
def _handle_config_validate(args, output: "OutputFormatter") -> CommandResult:
    """
    Validate configuration file.

    Checks, in order: the file exists and is readable JSON; the top level
    is an object with a non-empty ``databases`` array; each entry is an
    object with a ``name`` and well-typed ``host`` / ``port``; the file
    loads through ``load_config``. Only when all of that passes are the
    enabled databases probed with ``check_port``; an unreachable database
    produces a warning, not an error.

    Args:
        args: Parsed arguments; ``args.config`` is the config file path
        output: Output formatter (not used by this handler)

    Returns:
        CommandResult with validation results; on failure the message
        lists every problem that was found
    """
    from ..config import load_config
    from ..executor import check_port

    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    errors = []
    warnings = []

    # Validate JSON structure
    try:
        with open(args.config, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return CommandResult.failure(
            f"Invalid JSON: {str(e)}",
        )
    except (OSError, UnicodeDecodeError) as e:
        return CommandResult.failure(
            f"Failed to read config: {str(e)}",
        )

    # Check databases array. A non-object top level (e.g. a bare list)
    # cannot carry the key, so it is reported the same way.
    if not isinstance(data, dict) or "databases" not in data:
        return CommandResult.failure(
            "Missing 'databases' key in config",
        )

    if not isinstance(data["databases"], list):
        return CommandResult.failure(
            "'databases' must be an array",
        )

    if len(data["databases"]) == 0:
        return CommandResult.failure(
            "No databases configured",
        )

    # Validate each database config
    for i, db in enumerate(data["databases"]):
        prefix = f"databases[{i}]"

        if not isinstance(db, dict):
            errors.append(f"{prefix}: must be an object")
            continue

        # Required fields
        if "name" not in db:
            errors.append(f"{prefix}: missing 'name' field")

        # Optional fields with defaults
        host = db.get("host", "127.0.0.1")
        port = db.get("port", 3306)

        # Validate types
        if not isinstance(host, str):
            errors.append(f"{prefix}.host: must be a string")

        # bool is a subclass of int in Python; JSON true/false is not a port.
        if isinstance(port, bool) or not isinstance(port, int):
            errors.append(f"{prefix}.port: must be an integer")
        elif not 1 <= port <= 65535:
            errors.append(f"{prefix}.port: must be between 1 and 65535")

    # Try to load config
    try:
        configs = load_config(args.config)
    except Exception as e:
        errors.append(f"Failed to load config: {str(e)}")
        configs = []

    # Report before probing the network: the result is a failure either
    # way, and the message must say what is actually wrong.
    if errors:
        return CommandResult.failure(
            "Config validation failed: " + "; ".join(errors),
        )

    # Check connectivity for enabled databases
    connectivity = []
    for config in configs:
        if not config.enabled:
            continue

        reachable = check_port(config.host, config.port, timeout=2)
        connectivity.append({
            "name": config.name,
            "host": config.host,
            "port": config.port,
            "reachable": reachable,
        })

        if not reachable:
            warnings.append(
                f"{config.name} ({config.host}:{config.port}): not reachable"
            )

    return CommandResult.success(
        "config validate",
        {
            "config_path": os.path.abspath(args.config),
            "valid": True,
            "total_dbms": len(configs),
            "enabled_dbms": sum(1 for c in configs if c.enabled),
            "errors": errors,
            "warnings": warnings,
            "connectivity": connectivity,
        },
    )
203
+
204
+
205
def _handle_config_init(args, output: "OutputFormatter") -> CommandResult:
    """
    Write a sample configuration file.

    Args:
        args: Parsed arguments; ``args.output`` optionally names the
            destination file
        output: Output formatter (not used by this handler)

    Returns:
        CommandResult carrying the absolute path of the generated file,
        or a failure result if the destination already exists or the
        generator raises
    """
    from ..config import generate_sample_config

    # Fall back to the default sample name when no destination was given.
    target = args.output or "dbms_config.sample.json"

    # Never overwrite an existing file.
    if os.path.isfile(target):
        return CommandResult.failure(
            f"File already exists: {target}. Use --output to specify a different path",
            command="config init",
        )

    try:
        generate_sample_config(target)
    except Exception as exc:
        return CommandResult.failure(
            f"Failed to generate config: {str(exc)}",
            command="config init",
        )

    payload = {
        "config_path": os.path.abspath(target),
        "message": f"Sample config written to {target}",
    }
    return CommandResult.success("config init", payload)
rosetta/cli/exec.py ADDED
@@ -0,0 +1,219 @@
1
+ """
2
+ Handler for the 'exec' subcommand - execute SQL statements.
3
+ """
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from .result import CommandResult
8
+
9
+ if TYPE_CHECKING:
10
+ from .output import OutputFormatter
11
+
12
+
13
def handle_exec(args, output: "OutputFormatter") -> CommandResult:
    """
    Handle the 'exec' subcommand.

    Loads the config, selects the target DBMS (``args.dbms`` if given,
    otherwise every enabled one), splits the SQL from ``args.sql`` or
    ``args.file`` into statements and runs them on all targets in
    parallel, one worker thread per DBMS.

    Args:
        args: Parsed command-line arguments
        output: Output formatter (not used by this handler)

    Returns:
        CommandResult with execution results keyed by DBMS name. A
        per-DBMS or per-statement failure is recorded in that entry's
        ``error`` field; the command itself still succeeds.
    """
    import os
    import concurrent.futures
    from ..config import load_config, filter_configs
    from ..executor import DBConnection, check_port
    from ..parser import TestFileParser

    # Load config
    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    try:
        all_configs = load_config(args.config)
    except Exception as e:
        return CommandResult.failure(f"Failed to load config: {str(e)}")
    if not all_configs:
        return CommandResult.failure(
            f"No databases configured in {args.config}",
        )

    # Filter configs
    if args.dbms:
        try:
            configs = filter_configs(all_configs, args.dbms)
        except ValueError as e:
            return CommandResult.failure(str(e))
    else:
        configs = [c for c in all_configs if c.enabled]

    if not configs:
        return CommandResult.failure("No databases selected")

    # Get SQL statements
    sql_text = None
    if args.sql:
        sql_text = args.sql
    elif args.file:
        if not os.path.isfile(args.file):
            return CommandResult.failure(
                f"SQL file not found: {args.file}",
            )
        try:
            with open(args.file, "r", encoding="utf-8") as f:
                sql_text = f.read()
        except (OSError, UnicodeDecodeError) as e:
            return CommandResult.failure(
                f"Failed to read SQL file: {str(e)}",
            )
    else:
        return CommandResult.failure(
            "Either --sql or --file is required",
        )

    # Parse SQL statements
    try:
        parsed = TestFileParser.parse_text(sql_text)
        statements = [s.text for s in parsed]
    except Exception as e:
        return CommandResult.failure(f"Parse error: {str(e)}")

    # Determine database (None means connect without selecting a database)
    database = args.database if args.database else None

    def _close_quietly(handle):
        """Best-effort close; a failure while tearing down is not reported."""
        if handle is None:
            return
        try:
            handle.close()
        except Exception:
            pass

    def _connect_direct(config):
        """Open a driver connection without selecting a database."""
        connect_kwargs = dict(
            host=config.host,
            port=config.port,
            user=config.user,
            password=config.password,
            connect_timeout=10,
            # Passed as a connect argument because PyMySQL exposes
            # ``autocommit`` as a method: assigning ``conn.autocommit = True``
            # would only overwrite that method and leave autocommit off.
            autocommit=True,
        )
        if config.driver == "mysql.connector":
            import mysql.connector
            connect_kwargs["allow_local_infile"] = True
            return mysql.connector.connect(**connect_kwargs)
        import pymysql
        connect_kwargs["local_infile"] = True
        return pymysql.connect(**connect_kwargs)

    # Execute on each DBMS
    def _exec_on_dbms(config):
        """Execute all statements on one DBMS."""
        result = {
            "name": config.name,
            "statements": [],
            "error": None,
        }

        # Check port first
        if not check_port(config.host, config.port):
            result["error"] = f"Cannot reach {config.host}:{config.port}"
            return result

        # For exec without --database, connect directly without USE/CREATE
        if database is None:
            conn = None
            cursor = None
            try:
                conn = _connect_direct(config)
                cursor = conn.cursor()
            except Exception as e:
                # The connection may be open even though cursor() failed.
                _close_quietly(conn)
                result["error"] = f"Connection failed: {str(e)}"
                return result

            try:
                for sql in statements:
                    result["statements"].append(_exec_stmt(cursor, sql))
            finally:
                _close_quietly(cursor)
                _close_quietly(conn)
            return result

        # With explicit --database, use DBConnection (creates DB + USE)
        try:
            db = DBConnection(config, database)
            db.connect()
        except Exception as e:
            result["error"] = f"Connection failed: {str(e)}"
            return result

        try:
            for sql in statements:
                result["statements"].append(_exec_stmt(db.cursor, sql))
        finally:
            db.close()

        return result

    # Execute in parallel
    results = {}
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(configs)) as pool:
        futures = {pool.submit(_exec_on_dbms, c): c for c in configs}
        for fut in concurrent.futures.as_completed(futures):
            try:
                r = fut.result()
            except Exception as e:
                # One target blowing up must not discard the others' results.
                r = {
                    "name": futures[fut].name,
                    "statements": [],
                    "error": f"Execution failed: {str(e)}",
                }
            results[r["name"]] = r

    return CommandResult.success(
        "exec",
        {
            "sql": sql_text[:500],  # Truncate for JSON
            "total_statements": len(statements),
            "database": database,
            "dbms_targets": [c.name for c in configs],
            "results": results,
        },
    )
174
+
175
+
176
+ def _exec_stmt(cursor, sql: str) -> dict:
177
+ """Execute a single SQL statement and return the result dict."""
178
+ import time as _time
179
+ stmt_result = {
180
+ "sql": sql,
181
+ "columns": None,
182
+ "rows": None,
183
+ "error": None,
184
+ "affected_rows": 0,
185
+ "elapsed_ms": 0,
186
+ }
187
+ try:
188
+ t0 = _time.monotonic()
189
+ cursor.execute(sql)
190
+ if cursor.description:
191
+ stmt_result["columns"] = [
192
+ desc[0] for desc in cursor.description
193
+ ]
194
+ rows = cursor.fetchall()
195
+ stmt_result["rows"] = [
196
+ [_format_val(c) for c in row]
197
+ for row in rows
198
+ ]
199
+ stmt_result["row_count"] = len(rows)
200
+ else:
201
+ stmt_result["affected_rows"] = cursor.rowcount or 0
202
+ t1 = _time.monotonic()
203
+ stmt_result["elapsed_ms"] = round((t1 - t0) * 1000, 3)
204
+ except Exception as e:
205
+ t1 = _time.monotonic()
206
+ stmt_result["error"] = str(e)
207
+ stmt_result["elapsed_ms"] = round((t1 - t0) * 1000, 3)
208
+ return stmt_result
209
+
210
+
211
+ def _format_val(value) -> str:
212
+ """Format a cell value for JSON serialization."""
213
+ if value is None:
214
+ return "NULL"
215
+ if isinstance(value, bytes):
216
+ return value.decode("utf-8", errors="replace")
217
+ if isinstance(value, bool):
218
+ return "1" if value else "0"
219
+ return str(value)
@@ -0,0 +1,124 @@
1
+ """
2
+ Handler for the 'interactive' subcommand (and aliases 'repl', 'i').
3
+ """
4
+
5
+ import sys
6
+ from typing import TYPE_CHECKING
7
+
8
+ from .result import CommandResult
9
+
10
+ if TYPE_CHECKING:
11
+ from .output import OutputFormatter
12
+
13
+
14
def handle_interactive(args, output: "OutputFormatter") -> CommandResult:
    """
    Handle the 'interactive' subcommand.

    Selects the target DBMS (``args.dbms`` if given, otherwise every
    configured DBMS for which ``ensure_service`` returns true) and hands
    over to the legacy interactive entry point. In JSON output mode no
    session is started; only a description of what would be used is
    returned.

    Args:
        args: Parsed command-line arguments
        output: Output formatter; ``output.format`` selects JSON mode

    Returns:
        CommandResult with session summary
    """
    import os
    from ..config import load_config, filter_configs
    from ..executor import ensure_service

    # Load config
    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    try:
        all_configs = load_config(args.config)
    except Exception as e:
        return CommandResult.failure(f"Failed to load config: {str(e)}")
    if not all_configs:
        return CommandResult.failure(
            f"No databases configured in {args.config}",
        )

    # Filter configs
    if args.dbms:
        try:
            configs = filter_configs(all_configs, args.dbms)
        except ValueError as e:
            return CommandResult.failure(str(e))
    else:
        # Auto-detect reachable DBMS
        reachable_configs = [
            config for config in all_configs if ensure_service(config)
        ]

        if not reachable_configs:
            return CommandResult.failure(
                "No reachable DBMS found. Check your dbms_config.json"
            )

        configs = reachable_configs

    if not configs:
        return CommandResult.failure("No databases selected")

    # JSON output mode: an interactive prompt makes no sense for a machine
    # consumer, so no session is started here; the selected settings are
    # reported instead.
    # NOTE(review): the "Interactive mode launched" message below does not
    # match that behaviour — confirm whether consumers rely on the wording
    # before changing it.
    if output.format == "json":
        return CommandResult.success(
            "interactive",
            {
                "message": "Interactive mode launched",
                "note": "Interactive mode is designed for human users. Run without -j/--json for best experience.",
                "dbms_targets": [c.name for c in configs],
                "database": args.database,
                "output_dir": os.path.abspath(args.output_dir),
                "serve": args.serve,
                "port": args.port,
            },
        )

    # For human mode, actually launch the interactive session
    try:
        # Import the existing interactive logic from old CLI
        from ..cli import _enter_interactive, parse_args

        # Build args for legacy interactive mode. --database is only
        # forwarded when one was given: a None entry is not a valid
        # command-line token.
        legacy_argv = ["-i", "--config", args.config]
        if args.database is not None:
            legacy_argv += ["--database", args.database]
        legacy_argv += ["--output-dir", args.output_dir]
        legacy_args = parse_args(legacy_argv)

        # Use filtered configs (either user-specified or auto-detected reachable)
        legacy_args.dbms = ",".join(c.name for c in configs)
        if args.serve:
            legacy_args.serve = args.serve
        if args.port:
            legacy_args.port = args.port

        # Launch interactive session
        exit_code = _enter_interactive(legacy_args)

        return CommandResult.success(
            "interactive",
            {
                "exit_code": exit_code,
                "message": "Interactive session ended",
            },
        )

    except KeyboardInterrupt:
        return CommandResult.success(
            "interactive",
            {
                "message": "Interactive session interrupted",
            },
        )
    except Exception as e:
        return CommandResult.failure(
            f"Interactive session failed: {str(e)}",
        )