rosetta-sql 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/generate_csv_data.py +83 -0
- benchmark/import_data.py +168 -0
- rosetta/__init__.py +3 -0
- rosetta/__main__.py +8 -0
- rosetta/benchmark.py +1678 -0
- rosetta/buglist.py +108 -0
- rosetta/cli/__init__.py +11 -0
- rosetta/cli/config_cmd.py +243 -0
- rosetta/cli/exec.py +219 -0
- rosetta/cli/interactive_cmd.py +124 -0
- rosetta/cli/list_cmd.py +215 -0
- rosetta/cli/main.py +617 -0
- rosetta/cli/output.py +545 -0
- rosetta/cli/result.py +61 -0
- rosetta/cli/result_cmd.py +247 -0
- rosetta/cli/run.py +625 -0
- rosetta/cli/status.py +161 -0
- rosetta/comparator.py +205 -0
- rosetta/config.py +139 -0
- rosetta/executor.py +403 -0
- rosetta/flamegraph.py +630 -0
- rosetta/interactive.py +1790 -0
- rosetta/models.py +197 -0
- rosetta/parser.py +308 -0
- rosetta/reporter/__init__.py +1 -0
- rosetta/reporter/bench_html.py +1457 -0
- rosetta/reporter/bench_text.py +162 -0
- rosetta/reporter/history.py +1686 -0
- rosetta/reporter/html.py +644 -0
- rosetta/reporter/text.py +110 -0
- rosetta/runner.py +3089 -0
- rosetta/ui.py +736 -0
- rosetta/whitelist.py +161 -0
- rosetta_sql-1.0.0.dist-info/LICENSE +21 -0
- rosetta_sql-1.0.0.dist-info/METADATA +379 -0
- rosetta_sql-1.0.0.dist-info/RECORD +42 -0
- rosetta_sql-1.0.0.dist-info/WHEEL +5 -0
- rosetta_sql-1.0.0.dist-info/entry_points.txt +2 -0
- rosetta_sql-1.0.0.dist-info/top_level.txt +4 -0
- skills/rosetta/scripts/install_rosetta.py +469 -0
- skills/rosetta/scripts/rosetta_wrapper.py +377 -0
- tests/test_cli.py +749 -0
rosetta/buglist.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Bug-mark management for Rosetta.
|
|
2
|
+
|
|
3
|
+
A *bug entry* records the fingerprint of a diff that has been identified as a
|
|
4
|
+
genuine bug. Unlike whitelisted diffs, bug-marked diffs **still count toward
|
|
5
|
+
the failure rate** — the mark is purely informational so that users can track
|
|
6
|
+
known bugs across test runs.
|
|
7
|
+
|
|
8
|
+
The bug list is persisted as a single JSON file (``buglist.json``) in the
|
|
9
|
+
output directory. The fingerprint algorithm is identical to the whitelist
|
|
10
|
+
(MD5 over normalised SQL + output), so the same ``diff_fingerprint`` helper
|
|
11
|
+
is reused.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
import time as _time
|
|
18
|
+
from typing import Dict, Optional
|
|
19
|
+
|
|
20
|
+
log = logging.getLogger("rosetta")
|
|
21
|
+
|
|
22
|
+
_BUGLIST_FILE = "buglist.json"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# ---------------------------------------------------------------------------
|
|
26
|
+
# Buglist store
|
|
27
|
+
# ---------------------------------------------------------------------------
|
|
28
|
+
|
|
29
|
+
class Buglist:
    """In-memory bug list backed by a JSON file.

    Structure of ``buglist.json``::

        {
            "<fingerprint>": {
                "stmt": "SELECT ...",
                "dbms_a": "tdsql",
                "dbms_b": "mysql",
                "block": 42,
                "reason": "Known bug #123",
                "added_at": "2026-03-10 18:00:00"
            },
            ...
        }
    """

    def __init__(self, output_dir: str):
        # Full path of the persisted JSON file inside the output directory.
        self._path = os.path.join(output_dir, _BUGLIST_FILE)
        self._data: Dict[str, dict] = {}
        self.load()

    # -- persistence --------------------------------------------------------

    def load(self):
        """(Re)load entries from disk, resetting to empty on any error.

        A missing file is not an error — it simply yields an empty list.
        """
        self._data = {}
        if not os.path.isfile(self._path):
            return
        try:
            with open(self._path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            log.warning("Failed to load buglist: %s", e)
            return
        # Guard against a corrupt file whose top-level value is not an
        # object (e.g. a JSON array): every mutator below assumes a dict.
        if isinstance(data, dict):
            self._data = data
        else:
            log.warning("Ignoring malformed buglist (expected object): %s",
                        self._path)

    def save(self):
        """Persist the current entries to ``buglist.json``."""
        os.makedirs(os.path.dirname(self._path) or ".", exist_ok=True)
        with open(self._path, "w", encoding="utf-8") as f:
            json.dump(self._data, f, indent=2, ensure_ascii=False)

    # -- query / mutate -----------------------------------------------------

    @property
    def entries(self) -> Dict[str, dict]:
        # Shallow copy so callers cannot mutate internal state.
        return dict(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def contains(self, fingerprint: str) -> bool:
        return fingerprint in self._data

    def add(self, fingerprint: str, stmt: str, dbms_a: str, dbms_b: str,
            block: int = 0, reason: str = "") -> dict:
        """Add an entry and persist. Returns the stored dict."""
        entry = {
            # Keep only a preview of the statement to bound file size.
            "stmt": stmt[:300],
            "dbms_a": dbms_a,
            "dbms_b": dbms_b,
            "block": block,
            "reason": reason,
            "added_at": _time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        self._data[fingerprint] = entry
        self.save()
        return entry

    def remove(self, fingerprint: str) -> bool:
        """Remove an entry. Returns True if it existed."""
        if fingerprint in self._data:
            del self._data[fingerprint]
            self.save()
            return True
        return False

    def clear(self):
        """Remove all entries and persist the now-empty list."""
        self._data.clear()
        self.save()
|
rosetta/cli/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Rosetta CLI - Modern command-line interface for AI Agents and humans.
|
|
3
|
+
|
|
4
|
+
Human-readable output by default; use -j/--json for JSON output.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .main import main
|
|
8
|
+
from .result import CommandResult
|
|
9
|
+
from ..runner import _enter_interactive, parse_args
|
|
10
|
+
|
|
11
|
+
__all__ = ["main", "CommandResult", "_enter_interactive", "parse_args"]
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Handler for the 'config' subcommand.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
from .result import CommandResult
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .output import OutputFormatter
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def handle_config(args, output: "OutputFormatter") -> CommandResult:
    """
    Handle the 'config' subcommand by dispatching on ``args.action``.

    Args:
        args: Parsed command-line arguments
        output: Output formatter

    Returns:
        CommandResult with config information
    """
    # Table of supported actions -> dedicated handlers.
    handlers = {
        "show": _handle_config_show,
        "validate": _handle_config_validate,
        "init": _handle_config_init,
    }
    handler = handlers.get(args.action)
    if handler is None:
        return CommandResult.failure(
            f"Unknown config action: {args.action}",
        )
    return handler(args, output)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _handle_config_show(args, output: "OutputFormatter") -> CommandResult:
    """
    Show current configuration.

    Args:
        args: Parsed arguments
        output: Output formatter

    Returns:
        CommandResult with config details
    """
    from ..config import load_config

    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    try:
        configs = load_config(args.config)
    except Exception as e:
        return CommandResult.failure(f"Failed to load config: {str(e)}")

    # Re-read the raw JSON so the unprocessed file content is shown too.
    with open(args.config, "r", encoding="utf-8") as f:
        raw_config = json.load(f)

    # Per-database summary (note: the password is deliberately omitted).
    databases = [
        {
            "name": c.name,
            "host": c.host,
            "port": c.port,
            "user": c.user,
            "driver": c.driver,
            "enabled": c.enabled,
            "has_init_sql": bool(c.init_sql),
            "skip_patterns_count": len(c.skip_patterns),
        }
        for c in configs
    ]

    payload = {
        "config_path": os.path.abspath(args.config),
        "total_dbms": len(configs),
        "enabled_dbms": sum(1 for c in configs if c.enabled),
        "databases": databases,
        "raw_config": raw_config,
    }
    return CommandResult.success("config show", payload)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _handle_config_validate(args, output: "OutputFormatter") -> CommandResult:
    """
    Validate configuration file.

    Performs three layers of checks:
      1. JSON structure (parseable; non-empty 'databases' array)
      2. Per-database required fields and field types
      3. TCP reachability of every enabled database (warnings only)

    Args:
        args: Parsed arguments
        output: Output formatter

    Returns:
        CommandResult with validation results
    """
    from ..config import load_config
    from ..executor import check_port

    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    errors = []
    warnings = []

    # Validate JSON structure
    try:
        with open(args.config, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.JSONDecodeError as e:
        return CommandResult.failure(
            f"Invalid JSON: {str(e)}",
        )

    # Check databases array
    if "databases" not in data:
        return CommandResult.failure(
            "Missing 'databases' key in config",
        )

    if not isinstance(data["databases"], list):
        return CommandResult.failure(
            "'databases' must be an array",
        )

    if len(data["databases"]) == 0:
        return CommandResult.failure(
            "No databases configured",
        )

    # Validate each database config
    for i, db in enumerate(data["databases"]):
        prefix = f"databases[{i}]"

        # Each entry must itself be an object; otherwise the field checks
        # below would raise AttributeError instead of reporting cleanly.
        if not isinstance(db, dict):
            errors.append(f"{prefix}: must be an object")
            continue

        # Required fields
        if "name" not in db:
            errors.append(f"{prefix}: missing 'name' field")

        # Optional fields with defaults
        host = db.get("host", "127.0.0.1")
        port = db.get("port", 3306)

        # Validate types
        if not isinstance(host, str):
            errors.append(f"{prefix}.host: must be a string")

        if not isinstance(port, int):
            errors.append(f"{prefix}.port: must be an integer")

        # Check if port is in the valid TCP range
        if isinstance(port, int) and (port < 1 or port > 65535):
            errors.append(f"{prefix}.port: must be between 1 and 65535")

    # Try to load config through the normal loader as well
    try:
        configs = load_config(args.config)
    except Exception as e:
        errors.append(f"Failed to load config: {str(e)}")
        configs = []

    # Check connectivity for enabled databases (unreachable -> warning only)
    connectivity = []
    for config in configs:
        if not config.enabled:
            continue

        reachable = check_port(config.host, config.port, timeout=2)
        connectivity.append({
            "name": config.name,
            "host": config.host,
            "port": config.port,
            "reachable": reachable,
        })

        if not reachable:
            warnings.append(
                f"{config.name} ({config.host}:{config.port}): not reachable"
            )

    if errors:
        return CommandResult.failure(
            "Config validation failed",
        )

    return CommandResult.success(
        "config validate",
        {
            "config_path": os.path.abspath(args.config),
            "valid": True,
            "total_dbms": len(configs),
            "enabled_dbms": sum(1 for c in configs if c.enabled),
            "errors": errors,
            "warnings": warnings,
            "connectivity": connectivity,
        },
    )
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _handle_config_init(args, output: "OutputFormatter") -> CommandResult:
    """
    Generate sample configuration file.

    Args:
        args: Parsed arguments
        output: Output formatter

    Returns:
        CommandResult with generated config path
    """
    from ..config import generate_sample_config

    # Fall back to the default sample file name when --output is not given.
    output_path = args.output or "dbms_config.sample.json"

    # Refuse to clobber an existing file.
    if os.path.isfile(output_path):
        return CommandResult.failure(
            f"File already exists: {output_path}. Use --output to specify a different path",
            command="config init",
        )

    try:
        generate_sample_config(output_path)
    except Exception as e:
        return CommandResult.failure(
            f"Failed to generate config: {str(e)}",
            command="config init",
        )

    return CommandResult.success(
        "config init",
        {
            "config_path": os.path.abspath(output_path),
            "message": f"Sample config written to {output_path}",
        },
    )
|
rosetta/cli/exec.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Handler for the 'exec' subcommand - execute SQL statements.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from .result import CommandResult
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from .output import OutputFormatter
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def handle_exec(args, output: "OutputFormatter") -> CommandResult:
    """
    Handle the 'exec' subcommand.

    Loads the DBMS config, obtains SQL from --sql or --file, parses it into
    individual statements, then runs every statement against each selected
    DBMS in parallel (one worker thread per DBMS).

    Args:
        args: Parsed command-line arguments
        output: Output formatter

    Returns:
        CommandResult with execution results
    """
    import os
    import concurrent.futures
    from ..config import load_config, filter_configs
    from ..executor import DBConnection, check_port
    from ..parser import TestFileParser

    # Load config
    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    all_configs = load_config(args.config)
    if not all_configs:
        return CommandResult.failure(
            f"No databases configured in {args.config}",
        )

    # Filter configs: explicit --dbms list, or every enabled entry.
    if args.dbms:
        try:
            configs = filter_configs(all_configs, args.dbms)
        except ValueError as e:
            return CommandResult.failure(str(e))
    else:
        configs = [c for c in all_configs if c.enabled]

    if not configs:
        return CommandResult.failure("No databases selected")

    # Get SQL statements
    sql_text = None
    if args.sql:
        sql_text = args.sql
    elif args.file:
        if not os.path.isfile(args.file):
            return CommandResult.failure(
                f"SQL file not found: {args.file}",
            )
        with open(args.file, "r", encoding="utf-8") as f:
            sql_text = f.read()
    else:
        return CommandResult.failure(
            "Either --sql or --file is required",
        )

    # Parse SQL statements
    try:
        parsed = TestFileParser.parse_text(sql_text)
        statements = [s.text for s in parsed]
    except Exception as e:
        return CommandResult.failure(f"Parse error: {str(e)}")

    # Determine database (None means connect without selecting a database)
    database = args.database if args.database else None

    # Execute on each DBMS
    def _exec_on_dbms(config):
        """Execute all statements on one DBMS; never raises."""
        result = {
            "name": config.name,
            "statements": [],
            "error": None,
        }

        # Check port first to fail fast with a clear message.
        if not check_port(config.host, config.port):
            result["error"] = f"Cannot reach {config.host}:{config.port}"
            return result

        # For exec without --database, connect directly without USE/CREATE
        if database is None:
            conn = None
            cursor = None
            try:
                connect_kwargs = dict(
                    host=config.host,
                    port=config.port,
                    user=config.user,
                    password=config.password,
                    connect_timeout=10,
                    # Request autocommit at connect time; both drivers
                    # accept this keyword. (Assigning conn.autocommit=True
                    # after connecting is a no-op for pymysql, whose
                    # ``autocommit`` is a method, not a property.)
                    autocommit=True,
                )
                if config.driver == "mysql.connector":
                    import mysql.connector
                    connect_kwargs["allow_local_infile"] = True
                    conn = mysql.connector.connect(**connect_kwargs)
                else:
                    import pymysql
                    connect_kwargs["local_infile"] = True
                    conn = pymysql.connect(**connect_kwargs)
                cursor = conn.cursor()
            except Exception as e:
                result["error"] = f"Connection failed: {str(e)}"
                return result

            try:
                for sql in statements:
                    stmt_result = _exec_stmt(cursor, sql)
                    result["statements"].append(stmt_result)
            finally:
                # Best-effort cleanup; a broken connection must not mask
                # the statement results already collected.
                if cursor:
                    try:
                        cursor.close()
                    except Exception:
                        pass
                if conn:
                    try:
                        conn.close()
                    except Exception:
                        pass
            return result

        # With explicit --database, use DBConnection (creates DB + USE)
        db = DBConnection(config, database)
        try:
            db.connect()
        except Exception as e:
            result["error"] = f"Connection failed: {str(e)}"
            return result

        try:
            for sql in statements:
                stmt_result = _exec_stmt(db.cursor, sql)
                result["statements"].append(stmt_result)
        finally:
            db.close()

        return result

    # Execute in parallel, one worker thread per target DBMS.
    results = {}
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=len(configs)) as pool:
        futures = {pool.submit(_exec_on_dbms, c): c for c in configs}
        for fut in concurrent.futures.as_completed(futures):
            r = fut.result()
            results[r["name"]] = r

    return CommandResult.success(
        "exec",
        {
            "sql": sql_text[:500],  # Truncate for JSON
            "total_statements": len(statements),
            "database": database,
            "dbms_targets": [c.name for c in configs],
            "results": results,
        },
    )
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _exec_stmt(cursor, sql: str) -> dict:
|
|
177
|
+
"""Execute a single SQL statement and return the result dict."""
|
|
178
|
+
import time as _time
|
|
179
|
+
stmt_result = {
|
|
180
|
+
"sql": sql,
|
|
181
|
+
"columns": None,
|
|
182
|
+
"rows": None,
|
|
183
|
+
"error": None,
|
|
184
|
+
"affected_rows": 0,
|
|
185
|
+
"elapsed_ms": 0,
|
|
186
|
+
}
|
|
187
|
+
try:
|
|
188
|
+
t0 = _time.monotonic()
|
|
189
|
+
cursor.execute(sql)
|
|
190
|
+
if cursor.description:
|
|
191
|
+
stmt_result["columns"] = [
|
|
192
|
+
desc[0] for desc in cursor.description
|
|
193
|
+
]
|
|
194
|
+
rows = cursor.fetchall()
|
|
195
|
+
stmt_result["rows"] = [
|
|
196
|
+
[_format_val(c) for c in row]
|
|
197
|
+
for row in rows
|
|
198
|
+
]
|
|
199
|
+
stmt_result["row_count"] = len(rows)
|
|
200
|
+
else:
|
|
201
|
+
stmt_result["affected_rows"] = cursor.rowcount or 0
|
|
202
|
+
t1 = _time.monotonic()
|
|
203
|
+
stmt_result["elapsed_ms"] = round((t1 - t0) * 1000, 3)
|
|
204
|
+
except Exception as e:
|
|
205
|
+
t1 = _time.monotonic()
|
|
206
|
+
stmt_result["error"] = str(e)
|
|
207
|
+
stmt_result["elapsed_ms"] = round((t1 - t0) * 1000, 3)
|
|
208
|
+
return stmt_result
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _format_val(value) -> str:
|
|
212
|
+
"""Format a cell value for JSON serialization."""
|
|
213
|
+
if value is None:
|
|
214
|
+
return "NULL"
|
|
215
|
+
if isinstance(value, bytes):
|
|
216
|
+
return value.decode("utf-8", errors="replace")
|
|
217
|
+
if isinstance(value, bool):
|
|
218
|
+
return "1" if value else "0"
|
|
219
|
+
return str(value)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Handler for the 'interactive' subcommand (and aliases 'repl', 'i').
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from .result import CommandResult
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from .output import OutputFormatter
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def handle_interactive(args, output: "OutputFormatter") -> CommandResult:
    """
    Handle the 'interactive' subcommand.

    Args:
        args: Parsed command-line arguments
        output: Output formatter

    Returns:
        CommandResult with session summary
    """
    import os
    from ..config import load_config, filter_configs
    from ..executor import ensure_service

    # Load config
    if not os.path.isfile(args.config):
        return CommandResult.failure(
            f"Config file not found: {args.config}",
        )

    all_configs = load_config(args.config)
    if not all_configs:
        return CommandResult.failure(
            f"No databases configured in {args.config}",
        )

    # Filter configs: explicit --dbms list, or auto-detect reachable ones.
    if args.dbms:
        try:
            configs = filter_configs(all_configs, args.dbms)
        except ValueError as e:
            return CommandResult.failure(str(e))
    else:
        # Keep only DBMS whose service is actually reachable.
        configs = [c for c in all_configs if ensure_service(c)]
        if not configs:
            return CommandResult.failure(
                "No reachable DBMS found. Check your dbms_config.json"
            )

    if not configs:
        return CommandResult.failure("No databases selected")

    # Start interactive session
    # Note: For JSON output mode, we still report the plan but inform the
    # user that interactive mode is meant for humans.
    if output.format == "json":
        return CommandResult.success(
            "interactive",
            {
                "message": "Interactive mode launched",
                "note": "Interactive mode is designed for human users. Run without -j/--json for best experience.",
                "dbms_targets": [c.name for c in configs],
                "database": args.database,
                "output_dir": os.path.abspath(args.output_dir),
                "serve": args.serve,
                "port": args.port,
            },
        )

    # For human mode, actually launch the interactive session
    try:
        # Import the existing interactive logic from old CLI
        from ..cli import _enter_interactive, parse_args

        # Build args for legacy interactive mode.
        # NOTE(review): assumes args.database is always a string here; a
        # None value would break parse_args — confirm the argparse default.
        legacy_args = parse_args([
            "-i",
            "--config", args.config,
            "--database", args.database,
            "--output-dir", args.output_dir,
        ])

        # Use filtered configs (either user-specified or auto-detected reachable)
        legacy_args.dbms = ",".join(c.name for c in configs)
        if args.serve:
            legacy_args.serve = args.serve
        if args.port:
            legacy_args.port = args.port

        # Launch interactive session
        exit_code = _enter_interactive(legacy_args)

        return CommandResult.success(
            "interactive",
            {
                "exit_code": exit_code,
                "message": "Interactive session ended",
            },
        )

    except KeyboardInterrupt:
        return CommandResult.success(
            "interactive",
            {
                "message": "Interactive session interrupted",
            },
        )
    except Exception as e:
        return CommandResult.failure(
            f"Interactive session failed: {str(e)}",
        )
|