rosetta-sql 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/generate_csv_data.py +83 -0
- benchmark/import_data.py +168 -0
- rosetta/__init__.py +3 -0
- rosetta/__main__.py +8 -0
- rosetta/benchmark.py +1678 -0
- rosetta/buglist.py +108 -0
- rosetta/cli/__init__.py +11 -0
- rosetta/cli/config_cmd.py +243 -0
- rosetta/cli/exec.py +219 -0
- rosetta/cli/interactive_cmd.py +124 -0
- rosetta/cli/list_cmd.py +215 -0
- rosetta/cli/main.py +617 -0
- rosetta/cli/output.py +545 -0
- rosetta/cli/result.py +61 -0
- rosetta/cli/result_cmd.py +247 -0
- rosetta/cli/run.py +625 -0
- rosetta/cli/status.py +161 -0
- rosetta/comparator.py +205 -0
- rosetta/config.py +139 -0
- rosetta/executor.py +403 -0
- rosetta/flamegraph.py +630 -0
- rosetta/interactive.py +1790 -0
- rosetta/models.py +197 -0
- rosetta/parser.py +308 -0
- rosetta/reporter/__init__.py +1 -0
- rosetta/reporter/bench_html.py +1457 -0
- rosetta/reporter/bench_text.py +162 -0
- rosetta/reporter/history.py +1686 -0
- rosetta/reporter/html.py +644 -0
- rosetta/reporter/text.py +110 -0
- rosetta/runner.py +3089 -0
- rosetta/ui.py +736 -0
- rosetta/whitelist.py +161 -0
- rosetta_sql-1.0.0.dist-info/LICENSE +21 -0
- rosetta_sql-1.0.0.dist-info/METADATA +379 -0
- rosetta_sql-1.0.0.dist-info/RECORD +42 -0
- rosetta_sql-1.0.0.dist-info/WHEEL +5 -0
- rosetta_sql-1.0.0.dist-info/entry_points.txt +2 -0
- rosetta_sql-1.0.0.dist-info/top_level.txt +4 -0
- skills/rosetta/scripts/install_rosetta.py +469 -0
- skills/rosetta/scripts/rosetta_wrapper.py +377 -0
- tests/test_cli.py +749 -0
rosetta/executor.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
"""SQL execution engine for Rosetta."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
import socket
|
|
6
|
+
import subprocess
|
|
7
|
+
import time
|
|
8
|
+
import traceback
|
|
9
|
+
from typing import List
|
|
10
|
+
|
|
11
|
+
from .models import DBMSConfig, Statement, StmtResult, StmtType
|
|
12
|
+
|
|
13
|
+
log = logging.getLogger("rosetta")
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import mysql.connector
|
|
17
|
+
except ImportError:
|
|
18
|
+
mysql_connector_available = False
|
|
19
|
+
else:
|
|
20
|
+
mysql_connector_available = True
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
import pymysql
|
|
24
|
+
except ImportError:
|
|
25
|
+
pymysql_available = False
|
|
26
|
+
else:
|
|
27
|
+
pymysql_available = True
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def format_cell(value) -> str:
    """Format a single cell value for output.

    Args:
        value: A raw cell value as returned by the database driver.

    Returns:
        ``"NULL"`` for ``None``; the UTF-8 decoding (undecodable bytes
        replaced) for binary values; ``"1"``/``"0"`` for booleans; and
        ``str(value)`` for everything else.
    """
    if value is None:
        return "NULL"
    # Treat every binary buffer type alike so the rendered text does not
    # depend on which concrete type a driver hands back. Without this,
    # a bytearray would be rendered as "bytearray(b'...')".
    # NOTE(review): presumably mysql.connector returns bytearray for some
    # binary columns while PyMySQL returns bytes — confirm against drivers.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value).decode("utf-8", errors="replace")
    # bool must be checked explicitly: str(True) would give "True".
    if isinstance(value, bool):
        return "1" if value else "0"
    return str(value)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def format_result(stmt: Statement, result: StmtResult,
                  dbms_config: DBMSConfig) -> List[str]:
    """Render one executed statement as MTR ``.result``-style text lines.

    Args:
        stmt: The statement that was run (its text, type, line number and
            ``expected_error`` flag are read).
        result: The captured outcome (error, columns, rows, warnings).
        dbms_config: Accepted for interface compatibility; not read here.

    Returns:
        The output lines. ECHO statements yield only their text. Other
        statements yield the SQL (first line tagged with ``[L<line_no>]``),
        then either a single error line, or the tab-separated header and
        rows followed by any warnings.
    """
    if stmt.stmt_type == StmtType.ECHO:
        return [stmt.text]

    # Tag only the first SQL line with its source line number so that
    # identical statements remain distinguishable in reports.
    head, *tail = stmt.text.split("\n")
    lines: List[str] = [f"[L{stmt.line_no}] {head}", *tail]

    if result.error:
        label = "ERROR" if stmt.expected_error else "ERROR (unexpected)"
        lines.append(f"{label}: {result.error}")
        return lines

    if result.columns and result.rows is not None:
        lines.append("\t".join(result.columns))
        lines.extend(
            "\t".join(format_cell(cell) for cell in row)
            for row in result.rows
        )

    if result.warnings:
        lines.append("Warnings:")
        lines.extend(result.warnings)

    return lines
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class DBConnection:
    """Wraps a MySQL-protocol database connection.

    Holds one driver connection and one cursor for a given DBMS config
    and target database, and provides statement execution with result
    capture plus reconnect-on-loss handling.
    """

    def __init__(self, config: DBMSConfig, database: str):
        """Store settings; no network activity happens until connect().

        Args:
            config: DBMS connection config (host, port, user, password,
                driver, skip_patterns, init_sql and name are read).
            database: Name of the database to create if missing and use.
        """
        self.config = config
        self.database = database
        self.conn = None
        self.cursor = None
        # Last non-negative timeout passed to connect(); reused on reconnect.
        self._query_timeout = 0
        self._skip_patterns = [re.compile(p, re.IGNORECASE)
                               for p in config.skip_patterns]

    def _quoted_database(self) -> str:
        """Return the database name as a backtick-quoted identifier.

        Embedded backticks are doubled so the name cannot terminate the
        quoted identifier early.
        """
        return "`" + str(self.database).replace("`", "``") + "`"

    def connect(self, query_timeout: int = 0):
        """Connect to the database and prepare the session.

        Opens the connection, enables autocommit, creates the target
        database if it does not exist (``CREATE DATABASE IF NOT EXISTS``)
        and switches to it, applies the query timeout, then runs the
        configured ``init_sql`` statements. The database is never
        dropped here.

        Args:
            query_timeout: Query timeout in seconds. If > 0, set max_execution_time.
                Pass -1 to reuse the previously saved value (for reconnect).

        Raises:
            ImportError: If the configured driver package is not installed.
            Exception: Any driver error raised while connecting or while
                creating/selecting the database.
        """
        if query_timeout >= 0:
            self._query_timeout = query_timeout
        qt = self._query_timeout
        # NOTE(review): read_timeout is a PyMySQL option; confirm that the
        # installed mysql-connector-python version accepts it as well.
        kwargs = dict(
            host=self.config.host,
            port=self.config.port,
            user=self.config.user,
            password=self.config.password,
            connect_timeout=10,  # Connection timeout in seconds
            # Give the socket at least 60s, or twice the query timeout.
            read_timeout=max(60, qt * 2) if qt > 0 else 60,
        )

        if self.config.driver == "mysql.connector":
            if not mysql_connector_available:
                raise ImportError(
                    "mysql-connector-python is not installed. "
                    "Install via: pip install mysql-connector-python"
                )
            # Enable LOCAL INFILE for LOAD DATA operations
            kwargs["allow_local_infile"] = True
            self.conn = mysql.connector.connect(**kwargs)
        else:
            if not pymysql_available:
                raise ImportError(
                    "PyMySQL is not installed. "
                    "Install via: pip install pymysql"
                )
            # Enable LOCAL INFILE for LOAD DATA operations
            kwargs["local_infile"] = True
            self.conn = pymysql.connect(**kwargs)

        # Enable autocommit - use method for pymysql, attribute for mysql.connector
        if hasattr(self.conn, 'autocommit') and callable(self.conn.autocommit):
            self.conn.autocommit(True)  # pymysql style
        else:
            self.conn.autocommit = True  # mysql.connector style
        self.cursor = self.conn.cursor()

        # Ensure the database exists, then switch to it
        quoted_db = self._quoted_database()
        self.cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_db}")
        self.cursor.execute(f"USE {quoted_db}")

        # Set query timeout at database level
        if qt > 0:
            timeout_ms = qt * 1000
            # Try different timeout settings for various DBMS
            for sql in [
                f"SET SESSION max_execution_time = {timeout_ms}",  # MySQL/TiDB
                f"SET SESSION tidb_max_execution_time = {timeout_ms}",  # TiDB specific
            ]:
                try:
                    self.cursor.execute(sql)
                except Exception as e:
                    # Best effort: not every DBMS knows every variable.
                    log.debug("[%s] timeout setting not applied: %s — %s",
                              self.config.name, sql, e)

        for sql in self.config.init_sql:
            try:
                self.cursor.execute(sql)
            except Exception as e:
                log.warning("[%s] init_sql failed: %s — %s",
                            self.config.name, sql, e)

    def close(self):
        """Close the cursor and the connection, ignoring close errors.

        The ``cursor`` and ``conn`` attributes are left in place (not
        reset to ``None``).
        """
        if self.cursor:
            try:
                self.cursor.close()
            except Exception:
                pass
        if self.conn:
            try:
                self.conn.close()
            except Exception:
                pass

    def should_skip(self, sql: str) -> bool:
        """Check if this SQL should be skipped for this DBMS.

        Returns:
            True if any configured skip pattern matches anywhere in *sql*
            (case-insensitive), False otherwise.
        """
        return any(pat.search(sql) for pat in self._skip_patterns)

    def _is_connection_lost(self, err: Exception) -> bool:
        """Check if the error indicates a lost connection.

        Args:
            err: The exception raised while executing a statement.

        Returns:
            True if the first exception argument is one of the codes
            0, 2006 or 2013, if the message contains a known
            lost-connection phrase, or if this object no longer has a
            connection or cursor; False otherwise.
        """
        err_str = str(err)
        # An exception raised without arguments has an empty args tuple;
        # guard the index so classification itself can never raise.
        args = getattr(err, 'args', None) or ()
        code = args[0] if args else None
        if code in (0, 2006, 2013):
            return True
        if "Lost connection" in err_str or "gone away" in err_str:
            return True
        if "Connection refused" in err_str:
            return True
        # Connection object became None (e.g. after socket timeout)
        if "NoneType" in err_str and ("settimeout" in err_str or "attribute" in err_str):
            return True
        if self.conn is None or self.cursor is None:
            return True
        return False

    def reconnect(self):
        """Attempt to reconnect after a lost connection.

        Makes up to three attempts, sleeping 1, 2 and 4 seconds before
        them respectively, and reuses the saved query timeout.

        Returns:
            True on the first successful attempt, False if all fail.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.close()
                time.sleep(2 ** attempt)
                self.connect(query_timeout=-1)  # reuse saved timeout
                log.info("[%s] Reconnected successfully (attempt %d)",
                         self.config.name, attempt + 1)
                return True
            except Exception as e:
                log.warning("[%s] Reconnect attempt %d failed: %s",
                            self.config.name, attempt + 1, e)
        log.error("[%s] All reconnect attempts failed", self.config.name)
        return False

    def execute(self, sql: str, sort_result: bool = False) -> StmtResult:
        """Execute a SQL statement and capture the result.

        Args:
            sql: The statement text to run.
            sort_result: If True, sort fetched rows by the string form of
                their cells before storing them.

        Returns:
            A StmtResult holding columns and rows (when the statement
            produced a result set) or the affected row count, plus any
            rows reported by ``SHOW WARNINGS``. On failure ``error`` is
            set to the exception text; the statement is not retried,
            though a reconnect is attempted if the connection was lost.
        """
        result = StmtResult(stmt=Statement(StmtType.SQL, sql, 0))
        try:
            self.cursor.execute(sql)

            if self.cursor.description:
                result.columns = [desc[0]
                                  for desc in self.cursor.description]
                rows = self.cursor.fetchall()
                if sort_result:
                    rows = sorted(rows,
                                  key=lambda r: [str(c) for c in r])
                result.rows = rows
            else:
                result.affected_rows = self.cursor.rowcount or 0

            # Warning capture is best effort and must not fail the statement.
            try:
                self.cursor.execute("SHOW WARNINGS")
                warns = self.cursor.fetchall()
                if warns:
                    result.warnings = [
                        f"{w[0]}\t{w[1]}\t{w[2]}" for w in warns
                    ]
            except Exception as e:
                log.debug("[%s] SHOW WARNINGS failed: %s",
                          self.config.name, e)

        except Exception as e:
            result.error = str(e)
            if self._is_connection_lost(e):
                log.warning("[%s] Connection lost, attempting reconnect...",
                            self.config.name)
                if self.reconnect():
                    try:
                        self.cursor.execute(f"USE {self._quoted_database()}")
                    except Exception as use_err:
                        log.debug("[%s] USE after reconnect failed: %s",
                                  self.config.name, use_err)

        return result
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def check_port(host: str, port: int, timeout: float = 3.0) -> bool:
    """Report whether a TCP connection to ``host:port`` can be opened.

    Args:
        host: Host name or address to connect to.
        port: TCP port number.
        timeout: Maximum time in seconds to wait for the connection.

    Returns:
        True if the connection was established (it is closed again
        immediately), False if the attempt failed or timed out.
    """
    try:
        probe = socket.create_connection((host, port), timeout=timeout)
        probe.close()
    except (OSError, socket.timeout):
        return False
    return True
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def ensure_service(config: DBMSConfig) -> bool:
    """Ensure the DBMS service is up; try restart if configured.

    If the port is not reachable and ``config.restart_cmd`` is set, the
    command is run through the shell (60 second limit) and the port is
    then polled up to 10 times at 3 second intervals.

    Args:
        config: DBMS config; ``name``, ``host``, ``port`` and
            ``restart_cmd`` are read.

    Returns True if reachable, False otherwise.
    """
    max_polls = 10
    poll_interval = 3  # seconds between reachability checks

    name = config.name
    if check_port(config.host, config.port):
        return True

    log.warning("[%s] Port %s:%d is not reachable",
                name, config.host, config.port)

    if not config.restart_cmd:
        log.error("[%s] No restart_cmd configured, cannot recover", name)
        return False

    log.info("[%s] Attempting restart via: %s", name, config.restart_cmd)
    try:
        # restart_cmd is a shell command line taken from the config, so it
        # is deliberately run with shell=True.
        proc = subprocess.run(
            config.restart_cmd, shell=True,
            timeout=60, check=False,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
    except Exception as e:
        log.error("[%s] restart_cmd failed: %s", name, e)
        return False

    # A non-zero exit is reported but not fatal: the port poll below is
    # what decides whether the service actually came back.
    if proc.returncode != 0:
        stderr_text = (proc.stderr or b"").decode(
            "utf-8", errors="replace").strip()
        log.warning("[%s] restart_cmd exited with code %d: %s",
                    name, proc.returncode, stderr_text)

    for attempt in range(max_polls):
        time.sleep(poll_interval)
        if check_port(config.host, config.port):
            log.info("[%s] Service is back up after restart", name)
            return True
        log.info("[%s] Waiting for service... (%d/%d)",
                 name, attempt + 1, max_polls)

    log.error("[%s] Service did not come back after restart", name)
    return False
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def run_on_dbms(config: DBMSConfig, statements: List[Statement],
                database: str,
                should_skip_fn=None,
                on_connect=None,
                on_progress=None,
                on_done=None) -> List[str]:
    """Execute all statements on a single DBMS and return output lines.

    Checks that the service is reachable, connects, runs each statement
    in order and formats its result. The database is never dropped here;
    any such DDL must be explicit in the statements themselves.

    Args:
        config: DBMS connection config.
        statements: Parsed statements to execute.
        database: Test database name.
        should_skip_fn: Optional callable(stmt) -> bool for global skips.
        on_connect: Optional callback(name, success, msg) called after connect.
        on_progress: Optional callback(error: bool) called per statement.
        on_done: Optional callback(name, executed, errors) called when done.

    Returns:
        List of output lines, or None if the service is unavailable or
        the connection failed (``on_done`` is not called in that case).
    """
    name = config.name

    if not ensure_service(config):
        log.error("[%s] Service unavailable, skipping", name)
        if on_connect:
            on_connect(name, False, "Service unavailable")
        return None

    db = DBConnection(config, database)
    output_lines: List[str] = []

    try:
        db.connect()
        log.debug("[%s] Connected, using database '%s'", name, database)
        if on_connect:
            on_connect(name, True, f"Connected ({config.host}:{config.port})")
    except Exception as e:
        log.error("[%s] Connection failed: %s", name, e)
        # connect() can fail after the socket is already open (e.g. while
        # creating or selecting the database); release it before bailing out.
        db.close()
        if on_connect:
            on_connect(name, False, str(e))
        return None

    executed = 0
    errors = 0

    try:
        for stmt in statements:
            if stmt.stmt_type == StmtType.ECHO:
                output_lines.extend(
                    format_result(stmt, StmtResult(stmt=stmt), config)
                )
                if on_progress:
                    on_progress(error=False)
                continue

            # Skipped statements (global filter first, then the per-DBMS
            # patterns) produce no output but still count as progress.
            if (should_skip_fn and should_skip_fn(stmt)) \
                    or db.should_skip(stmt.text):
                if on_progress:
                    on_progress(error=False)
                continue

            result = db.execute(stmt.text, sort_result=stmt.sort_result)
            # execute() builds a placeholder Statement; attach the real one.
            result.stmt = stmt
            output_lines.extend(format_result(stmt, result, config))
            executed += 1

            # Only errors the test did not anticipate count as failures.
            has_error = bool(result.error and not stmt.expected_error)
            if has_error:
                errors += 1
                log.warning("[%s] Error at line %d: %s — %s",
                            name, stmt.line_no,
                            stmt.text[:80], result.error)

            if on_progress:
                on_progress(error=has_error)

    except Exception as e:
        log.error("[%s] Fatal error: %s", name, e)
        log.error(traceback.format_exc())
        output_lines.append(f"FATAL ERROR: {e}")
    finally:
        db.close()

    log.debug("[%s] Done: %d executed, %d errors", name, executed, errors)
    if on_done:
        on_done(name, executed, errors)
    return output_lines
|