rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/executor.py ADDED
@@ -0,0 +1,403 @@
1
+ """SQL execution engine for Rosetta."""
2
+
3
+ import logging
4
+ import re
5
+ import socket
6
+ import subprocess
7
+ import time
8
+ import traceback
9
+ from typing import List
10
+
11
+ from .models import DBMSConfig, Statement, StmtResult, StmtType
12
+
13
+ log = logging.getLogger("rosetta")
14
+
15
+ try:
16
+ import mysql.connector
17
+ except ImportError:
18
+ mysql_connector_available = False
19
+ else:
20
+ mysql_connector_available = True
21
+
22
+ try:
23
+ import pymysql
24
+ except ImportError:
25
+ pymysql_available = False
26
+ else:
27
+ pymysql_available = True
28
+
29
+
30
def format_cell(value) -> str:
    """Format a single result-set cell as text for the output report.

    Args:
        value: A raw cell value as returned by the database driver.

    Returns:
        ``"NULL"`` for ``None``; the UTF-8 decoding (undecodable bytes
        replaced) for ``bytes``/``bytearray``; ``"1"``/``"0"`` for booleans;
        ``str(value)`` for anything else.
    """
    if value is None:
        return "NULL"
    # A driver may hand back binary data as ``bytearray`` instead of ``bytes``;
    # decode both so the rendered text does not depend on the driver in use
    # (``str(bytearray(...))`` would otherwise yield "bytearray(b'...')").
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="replace")
    # Booleans are rendered numerically, like the other integer cells.
    if isinstance(value, bool):
        return "1" if value else "0"
    return str(value)
39
+
40
+
41
def format_result(stmt: Statement, result: StmtResult,
                  dbms_config: DBMSConfig) -> List[str]:
    """Render one statement and its outcome as MTR ``.result``-style lines.

    Args:
        stmt: The statement that was run (or an ECHO directive).
        result: The outcome captured for ``stmt``.
        dbms_config: Config of the DBMS the statement ran on. It is accepted
            for interface symmetry but not read by this function.

    Returns:
        The output lines: for ECHO just the echoed text; otherwise the SQL
        (first line tagged with ``[L<line_no>]``), followed by either an
        error line, or the tab-separated result set and any warnings.
    """
    if stmt.stmt_type == StmtType.ECHO:
        return [stmt.text]

    # Tag only the first SQL line with its source line number so that
    # identical statements stay distinguishable in reports.
    head, *tail = stmt.text.split("\n")
    lines: List[str] = [f"[L{stmt.line_no}] {head}", *tail]

    if result.error:
        # An error ends the record: no result set or warnings are printed.
        lines.append(
            f"ERROR: {result.error}" if stmt.expected_error
            else f"ERROR (unexpected): {result.error}"
        )
        return lines

    if result.columns and result.rows is not None:
        lines.append("\t".join(result.columns))
        lines.extend(
            "\t".join(format_cell(cell) for cell in row)
            for row in result.rows
        )

    if result.warnings:
        lines.append("Warnings:")
        lines.extend(result.warnings)

    return lines
78
+
79
+
80
class DBConnection:
    """Wraps a MySQL-protocol database connection.

    Holds one driver connection and one cursor for a single DBMS. The driver
    is chosen by ``config.driver``: ``"mysql.connector"`` selects
    mysql-connector-python, any other value selects PyMySQL.
    """

    def __init__(self, config: "DBMSConfig", database: str):
        """Store the settings; nothing is opened until :meth:`connect`.

        Args:
            config: Connection settings for the target DBMS. Its
                ``skip_patterns`` are compiled here as case-insensitive
                regular expressions.
            database: Name of the database to work in.
        """
        self.config = config
        self.database = database
        self.conn = None
        self.cursor = None
        # Last non-negative timeout passed to connect(); reused on reconnect.
        self._query_timeout = 0
        self._skip_patterns = [re.compile(p, re.IGNORECASE)
                               for p in config.skip_patterns]

    def _quoted_database(self) -> str:
        """Return the database name as a backtick-quoted identifier.

        Embedded backticks are doubled so the name cannot terminate the
        quoting early.
        """
        return "`" + str(self.database).replace("`", "``") + "`"

    def connect(self, query_timeout: int = 0):
        """Connect to the server and switch to the target database.

        The target database is created if it does not exist yet
        (``CREATE DATABASE IF NOT EXISTS``) and then selected with ``USE``.
        Nothing is dropped here. Afterwards the session query timeout is
        applied (best effort) and every statement in ``config.init_sql`` is
        run; a failing init statement is logged and skipped.

        Args:
            query_timeout: Query timeout in seconds. If > 0, the session
                ``max_execution_time`` settings are set accordingly.
                Pass -1 to reuse the previously saved value (for reconnect).

        Raises:
            ImportError: If the driver selected by ``config.driver`` is not
                installed.
            Exception: Any driver error raised while connecting or while
                creating/selecting the database.
        """
        if query_timeout >= 0:
            self._query_timeout = query_timeout
        qt = self._query_timeout
        # NOTE(review): read_timeout is passed to both drivers — confirm the
        # installed mysql-connector-python version accepts this argument.
        kwargs = dict(
            host=self.config.host,
            port=self.config.port,
            user=self.config.user,
            password=self.config.password,
            connect_timeout=10,  # Connection timeout in seconds
            # Socket read timeout: at least 60s, or twice the query timeout.
            read_timeout=max(60, qt * 2) if qt > 0 else 60,
        )

        if self.config.driver == "mysql.connector":
            if not mysql_connector_available:
                raise ImportError(
                    "mysql-connector-python is not installed. "
                    "Install via: pip install mysql-connector-python"
                )
            # Enable LOCAL INFILE for LOAD DATA operations
            kwargs["allow_local_infile"] = True
            self.conn = mysql.connector.connect(**kwargs)
        else:
            if not pymysql_available:
                raise ImportError(
                    "PyMySQL is not installed. "
                    "Install via: pip install pymysql"
                )
            # Enable LOCAL INFILE for LOAD DATA operations
            kwargs["local_infile"] = True
            self.conn = pymysql.connect(**kwargs)

        # Enable autocommit - use method for pymysql, attribute for mysql.connector
        if hasattr(self.conn, 'autocommit') and callable(self.conn.autocommit):
            self.conn.autocommit(True)  # pymysql style
        else:
            self.conn.autocommit = True  # mysql.connector style
        self.cursor = self.conn.cursor()

        # Ensure the database exists, then switch to it
        quoted_db = self._quoted_database()
        self.cursor.execute(f"CREATE DATABASE IF NOT EXISTS {quoted_db}")
        self.cursor.execute(f"USE {quoted_db}")

        # Set query timeout at database level
        if qt > 0:
            timeout_ms = qt * 1000
            # Try different timeout settings for various DBMS
            for sql in [
                f"SET SESSION max_execution_time = {timeout_ms}",  # MySQL/TiDB
                f"SET SESSION tidb_max_execution_time = {timeout_ms}",  # TiDB specific
            ]:
                try:
                    self.cursor.execute(sql)
                except Exception as e:
                    # Best effort: a DBMS that lacks the variable rejects it.
                    log.debug("[%s] timeout setting not applied: %s — %s",
                              self.config.name, sql, e)

        for sql in self.config.init_sql:
            try:
                self.cursor.execute(sql)
            except Exception as e:
                log.warning("[%s] init_sql failed: %s — %s",
                            self.config.name, sql, e)

    def close(self):
        """Close the cursor and the connection, ignoring any close errors.

        Safe to call when nothing was opened. The ``cursor``/``conn``
        attributes keep referring to the closed objects afterwards.
        """
        if self.cursor:
            try:
                self.cursor.close()
            except Exception:
                pass  # teardown is best effort; the link may already be dead
        if self.conn:
            try:
                self.conn.close()
            except Exception:
                pass  # teardown is best effort; the link may already be dead

    def should_skip(self, sql: str) -> bool:
        """Check if this SQL should be skipped for this DBMS.

        Returns True when any of the configured skip patterns is found
        anywhere in ``sql`` (case-insensitive search).
        """
        return any(pat.search(sql) for pat in self._skip_patterns)

    def _is_connection_lost(self, err: Exception) -> bool:
        """Check if the error indicates a lost connection.

        The decision is heuristic: the first exception argument is compared
        against client error codes (2006 "server has gone away", 2013 "lost
        connection", and 0), and the message is scanned for well-known
        phrases. A missing connection or cursor object also counts as lost.
        """
        err_str = str(err)
        # Exceptions raised without arguments have empty ``args``; guard the
        # lookup so classifying an error can never raise IndexError itself.
        args = getattr(err, 'args', None) or ()
        code = args[0] if args else None
        if code in (0, 2006, 2013):
            return True
        if "Lost connection" in err_str or "gone away" in err_str:
            return True
        if "Connection refused" in err_str:
            return True
        # Connection object became None (e.g. after socket timeout)
        if "NoneType" in err_str and ("settimeout" in err_str or "attribute" in err_str):
            return True
        if self.conn is None or self.cursor is None:
            return True
        return False

    def reconnect(self):
        """Attempt to reconnect after a lost connection.

        Makes up to three attempts, sleeping 1s, 2s and 4s before them, and
        reuses the query timeout saved by the last :meth:`connect` call.

        Returns:
            True once a connection is re-established, False if every attempt
            failed.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                self.close()
                time.sleep(2 ** attempt)  # exponential back-off
                self.connect(query_timeout=-1)  # reuse saved timeout
                log.info("[%s] Reconnected successfully (attempt %d)",
                         self.config.name, attempt + 1)
                return True
            except Exception as e:
                log.warning("[%s] Reconnect attempt %d failed: %s",
                            self.config.name, attempt + 1, e)
        log.error("[%s] All reconnect attempts failed", self.config.name)
        return False

    def execute(self, sql: str, sort_result: bool = False) -> "StmtResult":
        """Execute a SQL statement and capture the result.

        Never raises for a failing statement: the error text is stored in
        ``result.error`` instead. If the failure looks like a lost
        connection, a reconnect is attempted so that later statements can
        proceed; the failed statement itself is not retried.

        Args:
            sql: The statement text to run.
            sort_result: If True, rows are sorted by the string form of
                their cells before being stored.

        Returns:
            A result holding the columns and rows (for statements that
            produce a result set) or the affected row count, plus any
            warnings reported by ``SHOW WARNINGS``. Its ``stmt`` is a
            placeholder built from ``sql`` with line number 0.
        """
        result = StmtResult(stmt=Statement(StmtType.SQL, sql, 0))
        try:
            self.cursor.execute(sql)

            if self.cursor.description:
                result.columns = [desc[0]
                                  for desc in self.cursor.description]
                rows = self.cursor.fetchall()
                if sort_result:
                    # Compare cells as strings so mixed/NULL values sort
                    # without type errors.
                    rows = sorted(rows,
                                  key=lambda r: [str(c) for c in r])
                result.rows = rows
            else:
                result.affected_rows = self.cursor.rowcount or 0

            try:
                self.cursor.execute("SHOW WARNINGS")
                warns = self.cursor.fetchall()
                if warns:
                    result.warnings = [
                        f"{w[0]}\t{w[1]}\t{w[2]}" for w in warns
                    ]
            except Exception as e:
                # Warnings are optional extras; keep the statement result.
                log.debug("[%s] SHOW WARNINGS failed: %s",
                          self.config.name, e)

        except Exception as e:
            result.error = str(e)
            if self._is_connection_lost(e):
                log.warning("[%s] Connection lost, attempting reconnect...",
                            self.config.name)
                if self.reconnect():
                    try:
                        self.cursor.execute(f"USE {self._quoted_database()}")
                    except Exception:
                        pass  # connect() already selected the database

        return result
259
+
260
+
261
def check_port(host: str, port: int, timeout: float = 3.0) -> bool:
    """Report whether a TCP connection to ``host:port`` can be opened.

    Args:
        host: Host name or address to probe.
        port: TCP port to probe.
        timeout: Seconds to wait for the connection attempt.

    Returns:
        True if the connection was established (it is closed again right
        away), False on any socket error or timeout.
    """
    try:
        probe = socket.create_connection((host, port), timeout=timeout)
        probe.close()
    except (OSError, socket.timeout):
        return False
    return True
268
+
269
+
270
def ensure_service(config: DBMSConfig) -> bool:
    """Ensure the DBMS service is up; try restart if configured.

    If the configured port is not reachable and ``config.restart_cmd`` is
    set, the command is run through the shell (60s limit) and the port is
    then polled up to 10 times, 3 seconds apart.

    Args:
        config: DBMS settings; reads ``name``, ``host``, ``port`` and
            ``restart_cmd``.

    Returns:
        True if reachable, False otherwise.
    """
    name = config.name
    if check_port(config.host, config.port):
        return True

    log.warning("[%s] Port %s:%d is not reachable",
                name, config.host, config.port)

    if not config.restart_cmd:
        log.error("[%s] No restart_cmd configured, cannot recover", name)
        return False

    log.info("[%s] Attempting restart via: %s", name, config.restart_cmd)
    try:
        # restart_cmd is handed to the shell as one command line; its output
        # is captured so it does not leak into the tool's own output.
        proc = subprocess.run(
            config.restart_cmd, shell=True,
            timeout=60, check=False,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
        )
    except Exception as e:
        # Includes the command exceeding the 60s limit.
        log.error("[%s] restart_cmd failed: %s", name, e)
        return False

    if proc.returncode != 0:
        # Surface the failure instead of discarding it, but keep polling:
        # the service may still come up despite a non-zero exit status.
        stderr_text = (proc.stderr or b"").decode(
            "utf-8", errors="replace").strip()
        log.warning("[%s] restart_cmd exited with status %d: %s",
                    name, proc.returncode, stderr_text)

    for attempt in range(10):
        time.sleep(3)
        if check_port(config.host, config.port):
            log.info("[%s] Service is back up after restart", name)
            return True
        log.info("[%s] Waiting for service... (%d/10)", name, attempt + 1)

    log.error("[%s] Service did not come back after restart", name)
    return False
306
+
307
+
308
def run_on_dbms(config: DBMSConfig, statements: List[Statement],
                database: str,
                should_skip_fn=None,
                on_connect=None,
                on_progress=None,
                on_done=None) -> List[str]:
    """Execute all statements on a single DBMS and return output lines.

    Checks that the service is reachable, connects, and runs the statements
    in order. ECHO statements are copied to the output without touching the
    database; statements matched by ``should_skip_fn`` or by the DBMS's own
    skip patterns are silently omitted. The connection is always closed
    before returning.

    Args:
        config: DBMS connection config.
        statements: Parsed statements to execute.
        database: Test database name.
        should_skip_fn: Optional callable(stmt) -> bool for global skips.
        on_connect: Optional callback(name, success, msg) called after connect.
        on_progress: Optional callback(error: bool) called per statement.
        on_done: Optional callback(name, executed, errors) called when done.

    Returns:
        List of output lines, or None if the service was unavailable or the
        connection failed (note that the annotation does not mention None).
        An unexpected exception during execution is logged and recorded as
        a trailing ``FATAL ERROR`` line rather than raised.
    """
    name = config.name

    if not ensure_service(config):
        log.error("[%s] Service unavailable, skipping", name)
        if on_connect:
            on_connect(name, False, "Service unavailable")
        return None

    db = DBConnection(config, database)
    output_lines: List[str] = []

    try:
        db.connect()
        log.debug("[%s] Connected, using database '%s'", name, database)
        if on_connect:
            on_connect(name, True, f"Connected ({config.host}:{config.port})")
    except Exception as e:
        log.error("[%s] Connection failed: %s", name, e)
        # The failure may have happened after the socket was opened (for
        # example while selecting the database); release it before bailing
        # out so the connection is not leaked.
        db.close()
        if on_connect:
            on_connect(name, False, str(e))
        return None

    executed = 0
    errors = 0

    try:
        for stmt in statements:
            if stmt.stmt_type == StmtType.ECHO:
                output_lines.extend(
                    format_result(stmt, StmtResult(stmt=stmt), config)
                )
                if on_progress:
                    on_progress(error=False)
                continue

            # Global skip rules are consulted first, then the per-DBMS ones.
            if ((should_skip_fn and should_skip_fn(stmt))
                    or db.should_skip(stmt.text)):
                if on_progress:
                    on_progress(error=False)
                continue

            result = db.execute(stmt.text, sort_result=stmt.sort_result)
            # execute() only knows the SQL text; attach the real statement.
            result.stmt = stmt
            output_lines.extend(format_result(stmt, result, config))
            executed += 1

            # Errors the test expects are not counted as failures.
            has_error = bool(result.error and not stmt.expected_error)
            if has_error:
                errors += 1
                log.warning("[%s] Error at line %d: %s — %s",
                            name, stmt.line_no,
                            stmt.text[:80], result.error)

            if on_progress:
                on_progress(error=has_error)

    except Exception as e:
        log.error("[%s] Fatal error: %s", name, e)
        log.error(traceback.format_exc())
        output_lines.append(f"FATAL ERROR: {e}")
    finally:
        db.close()

    log.debug("[%s] Done: %d executed, %d errors", name, executed, errors)
    if on_done:
        on_done(name, executed, errors)
    return output_lines