rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/interactive.py ADDED
@@ -0,0 +1,1790 @@
1
+ """Interactive terminal session for Rosetta.
2
+
3
+ Allows users to repeatedly submit MTR test paths and execute them without
4
+ restarting the program. Base parameters (config, dbms, baseline, etc.) are
5
+ fixed at launch; only the test file path changes between iterations.
6
+ """
7
+
8
+ import glob
9
+ import http.server
10
+ import json
11
+ import logging
12
+ import os
13
+ import socket
14
+ import subprocess
15
+ import threading
16
+ import time as _time
17
+ from pathlib import Path
18
+ from typing import Dict, List, Optional
19
+
20
+ from prompt_toolkit import PromptSession
21
+ from prompt_toolkit.completion import Completer, Completion
22
+ from prompt_toolkit.formatted_text import HTML
23
+ from prompt_toolkit.history import InMemoryHistory, FileHistory
24
+ from prompt_toolkit.styles import Style
25
+
26
+ from .config import DEFAULT_TEST_DB
27
+ from .models import DBMSConfig
28
+ from .reporter.history import generate_index_html
29
+ from .ui import (console, flush_all, print_error, print_info,
30
+ print_summary, print_warning)
31
+
32
+ log = logging.getLogger("rosetta")
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Path auto-completion
37
+ # ---------------------------------------------------------------------------
38
+
39
class TestFileCompleter(Completer):
    """Auto-complete ``.test`` file paths and directories."""

    def get_completions(self, document, complete_event):
        """Yield completions for the path typed before the cursor.

        Directories are offered with a trailing ``/`` and files are offered
        only when their name ends in ``.test``.  An empty input completes
        relative to ``./`` and a leading ``~`` is expanded.

        Args:
            document: prompt_toolkit document; only ``text_before_cursor``
                is read.
            complete_event: unused.

        Yields:
            ``Completion`` objects in sorted path order.
        """
        typed = document.text_before_cursor
        text = typed.strip() or "./"
        expanded = os.path.expanduser(text)
        if os.path.isdir(expanded) and not expanded.endswith("/"):
            expanded += "/"

        # Escape glob metacharacters so that a literal "[" / "*" / "?" in a
        # directory name is matched verbatim rather than as a pattern.
        pattern = glob.escape(expanded) + "*"

        # Replace exactly what is in front of the cursor.  Using the length
        # of the stripped (or substituted "./") text would point at the wrong
        # offset whenever the input is empty or has surrounding whitespace.
        start = -len(typed)

        for path in sorted(glob.glob(pattern)):
            if os.path.isdir(path):
                yield Completion(path + "/", start_position=start,
                                 display=os.path.basename(path) + "/",
                                 display_meta="dir")
            elif path.endswith(".test"):
                yield Completion(path, start_position=start,
                                 display=os.path.basename(path),
                                 display_meta="test")
61
+
62
+
63
+ # ---------------------------------------------------------------------------
64
+ # Prompt style
65
+ # ---------------------------------------------------------------------------
66
+
67
# Style classes for the REPL prompt; the class names ("prompt", "path",
# "placeholder") match the tags used in the HTML prompt fragments.
_PROMPT_STYLE = Style.from_dict(dict(
    prompt="bold cyan",
    path="bold white",
    placeholder="dim #888888",
))
72
+
73
+
74
+ # ---------------------------------------------------------------------------
75
+ # HTTP server management
76
+ # ---------------------------------------------------------------------------
77
+
78
class _SilentHTTPServer(http.server.HTTPServer):
    """HTTPServer that never prints request-handling errors to stderr."""

    def handle_error(self, request, client_address):
        """Handle an exception raised while serving one request.

        Connection resets, broken pipes and timeouts are normal when a
        client disconnects abruptly, so they are dropped silently.  Any
        other exception is recorded at DEBUG level on the ``rosetta`` logger
        (instead of the default traceback on stderr) so that handler bugs
        remain diagnosable without cluttering the interactive terminal.

        Args:
            request: the request object (unused).
            client_address: address of the peer, included in the log record.
        """
        import sys

        exc = sys.exc_info()[1]
        if exc is None or isinstance(exc, (ConnectionError, TimeoutError)):
            return
        logging.getLogger("rosetta").debug(
            "Unhandled error while serving %s", client_address,
            exc_info=True)
85
+
86
+
87
class _APIHandler(http.server.SimpleHTTPRequestHandler):
    """Static file handler with JSON API endpoints and suppressed logging.

    Files are served from the ``directory`` given to the constructor.  On top
    of that the handler routes:

    * ``GET  /`` — redirect to ``/index.html``.
    * ``GET  /api/dbms`` — list the configured DBMS targets.
    * ``POST /api/whitelist/<action>`` and ``POST /api/buglist/<action>``
      with ``<action>`` one of ``add``, ``remove``, ``clear``, ``list``.
    * ``POST /api/execute`` and ``POST /api/execute/stream`` — run SQL on
      the selected DBMS targets.
    * ``POST /api/runs/delete`` — delete one run directory.

    NOTE(security): API responses carry ``Access-Control-Allow-Origin: *``
    and the execute endpoints run caller-supplied SQL, so the server should
    only be reachable from trusted clients.
    """

    # Class-level references set by ReportServer before creating instances.
    _whitelist = None  # type: ignore
    _buglist = None  # type: ignore
    _configs: List["DBMSConfig"] = []
    _all_configs: List["DBMSConfig"] = []
    _database: str = ""

    def log_message(self, format, *args):  # noqa: A002
        """Suppress all request logs."""

    def end_headers(self):  # noqa: N802
        """Add cache-disabling headers to every response."""
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
        self.send_header("Pragma", "no-cache")
        self.send_header("Expires", "0")
        super().end_headers()

    # -- GET routing (redirect / -> /index.html, serve API) -----------------

    def do_GET(self):  # noqa: N802
        """Route GET requests; anything unknown is served as a static file."""
        if self.path == "/":
            self.send_response(302)
            self.send_header("Location", "/index.html")
            self.end_headers()
            return
        if self.path == "/api/dbms":
            self._handle_dbms_list()
            return
        super().do_GET()

    # -- CORS ---------------------------------------------------------------

    def _send_cors_headers(self):
        """Emit the permissive CORS headers used by all API responses."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods",
                         "GET, POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "Content-Type")

    def do_OPTIONS(self):  # noqa: N802
        """Answer CORS pre-flight requests."""
        self.send_response(200)
        self._send_cors_headers()
        self.end_headers()

    # -- API routing --------------------------------------------------------

    def do_POST(self):  # noqa: N802
        """Dispatch POST requests to the matching API handler (404 otherwise)."""
        if self.path.startswith("/api/whitelist/"):
            self._handle_whitelist_api()
        elif self.path.startswith("/api/buglist/"):
            self._handle_buglist_api()
        elif self.path == "/api/execute":
            self._handle_execute_api()
        elif self.path == "/api/execute/stream":
            self._handle_execute_stream_api()
        elif self.path == "/api/runs/delete":
            self._handle_runs_delete_api()
        else:
            self.send_error(404)

    def _read_json(self) -> dict:
        """Read the request body and decode it as a JSON object.

        An absent or zero-length body yields ``{}``.

        Raises:
            ValueError: if Content-Length is not an integer, the body is not
                valid JSON, or the decoded value is not a JSON object.
        """
        length = int(self.headers.get("Content-Length", 0))
        # A negative length would make read() consume the stream until EOF.
        raw = self.rfile.read(length) if length > 0 else b"{}"
        data = json.loads(raw)
        if not isinstance(data, dict):
            raise ValueError("JSON body must be an object")
        return data

    def _respond_json(self, data: dict, status: int = 200):
        """Send *data* as a UTF-8 JSON response with the given HTTP status."""
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self._send_cors_headers()
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)

    # -- Whitelist / buglist API --------------------------------------------

    def _handle_marklist_api(self, prefix: str, store, label: str):
        """Serve add/remove/clear/list for a whitelist-like store.

        Args:
            prefix: URL prefix preceding the action name.
            store: object providing ``add``, ``remove``, ``clear`` and
                ``entries``; ``None`` produces a 500 response.
            label: store name used in the "not loaded" error message.
        """
        action = self.path.split(prefix, 1)[-1].strip("/")
        if store is None:
            self._respond_json({"ok": False, "error": f"{label} not loaded"},
                               500)
            return
        try:
            body = self._read_json()
        except Exception:
            # Deliberately lenient: an unreadable body is treated as empty.
            body = {}

        if action == "add":
            fp = body.get("fingerprint", "")
            if not fp:
                self._respond_json({"ok": False,
                                    "error": "fingerprint required"}, 400)
                return
            entry = store.add(
                fingerprint=fp,
                stmt=body.get("stmt", ""),
                dbms_a=body.get("dbms_a", ""),
                dbms_b=body.get("dbms_b", ""),
                block=body.get("block", 0),
                reason=body.get("reason", ""),
            )
            self._respond_json({"ok": True, "entry": entry})
        elif action == "remove":
            fp = body.get("fingerprint", "")
            removed = store.remove(fp) if fp else False
            self._respond_json({"ok": removed})
        elif action == "clear":
            store.clear()
            self._respond_json({"ok": True})
        elif action == "list":
            self._respond_json({"ok": True, "entries": store.entries})
        else:
            self._respond_json({"ok": False, "error": "unknown action"}, 404)

    def _handle_whitelist_api(self):
        """POST /api/whitelist/<action>."""
        self._handle_marklist_api("/api/whitelist/", self._whitelist,
                                  "whitelist")

    def _handle_buglist_api(self):
        """POST /api/buglist/<action>."""
        self._handle_marklist_api("/api/buglist/", self._buglist, "buglist")

    # -- Runs delete API ----------------------------------------------------

    @staticmethod
    def _is_safe_run_dir_name(name) -> bool:
        """Return True if *name* is a plain single path component.

        ``"."`` is rejected explicitly: joined to the served directory it
        resolves to that directory itself, which would delete every run.
        """
        return (isinstance(name, str)
                and name not in ("", ".")
                and ".." not in name
                and "/" not in name
                and "\\" not in name
                and "\x00" not in name)

    def _handle_runs_delete_api(self):
        """POST /api/runs/delete — delete a run directory.

        Request body: {"dir_name": "test_name_20250101_120000"}
        Response: {"ok": true} or {"ok": false, "error": "..."}
        """
        import shutil

        try:
            body = self._read_json()
        except Exception:
            self._respond_json({"ok": False, "error": "invalid JSON"}, 400)
            return

        dir_name = body.get("dir_name", "")
        if not dir_name:
            self._respond_json({"ok": False, "error": "dir_name required"}, 400)
            return

        # Security: prevent path traversal
        if not self._is_safe_run_dir_name(dir_name):
            self._respond_json({"ok": False, "error": "invalid dir_name"}, 400)
            return

        # The handler is created with directory=<output dir>.
        target_dir = os.path.join(self.directory, dir_name)

        # A symlink (e.g. "latest") is not a run directory; rmtree refuses
        # to operate on one anyway.
        if os.path.islink(target_dir):
            self._respond_json({"ok": False, "error": "invalid dir_name"}, 400)
            return

        if not os.path.isdir(target_dir):
            self._respond_json({"ok": False, "error": "directory not found"}, 404)
            return

        try:
            shutil.rmtree(target_dir)
            logging.getLogger("rosetta").info(
                "Deleted run directory: %s", target_dir)
            # Regenerate index.html after deletion
            from .reporter.history import generate_index_html
            generate_index_html(self.directory)
            self._respond_json({"ok": True})
        except Exception as e:
            logging.getLogger("rosetta").error(
                "Failed to delete directory %s: %s", target_dir, e)
            self._respond_json({"ok": False, "error": str(e)}, 500)

    # -- Playground API -----------------------------------------------------

    def _handle_dbms_list(self):
        """GET /api/dbms — return all DBMS from config with active flags."""
        active_names = {c.name for c in self._configs}
        dbms_list = [{"name": c.name, "host": c.host, "port": c.port,
                      "active": c.name in active_names}
                     for c in self._all_configs]
        self._respond_json({
            "ok": True,
            "database": self._database,
            "dbms": dbms_list,
        })

    def _parse_execute_request(self):
        """Validate the body shared by both execute endpoints.

        Sends a 400 JSON response and returns ``None`` when the request is
        unusable; otherwise returns ``(stmts, targets)`` where *stmts* is the
        list of statement texts and *targets* the selected DBMS configs
        (duplicates removed, request order kept).
        """
        try:
            body = self._read_json()
        except Exception:
            self._respond_json({"ok": False, "error": "invalid JSON"}, 400)
            return None

        sql_text = body.get("sql", "")
        sql_text = sql_text.strip() if isinstance(sql_text, str) else ""
        if not sql_text:
            self._respond_json({"ok": False, "error": "sql is required"}, 400)
            return None

        configs_map = {c.name: c for c in self._all_configs}
        requested = body.get("dbms") or []
        if isinstance(requested, str):
            requested = [requested]
        if not isinstance(requested, list):
            self._respond_json(
                {"ok": False, "error": "dbms must be a list of names"}, 400)
            return None
        if not requested:
            requested = list(configs_map)

        targets = []
        seen = set()
        for name in requested:
            if isinstance(name, str) and name in configs_map \
                    and name not in seen:
                seen.add(name)
                targets.append(configs_map[name])

        if not targets:
            self._respond_json(
                {"ok": False, "error": "no valid DBMS targets"}, 400)
            return None

        # Reuse the full MTR parser to extract SQL statements,
        # filtering out all MTR directives (--echo, --error, etc.)
        from .parser import TestFileParser
        try:
            parsed = TestFileParser.parse_text(sql_text)
        except Exception as e:
            self._respond_json(
                {"ok": False, "error": f"failed to parse SQL: {e}"}, 400)
            return None
        return [s.text for s in parsed], targets

    @staticmethod
    def _execute_one(db, sql) -> dict:
        """Run one statement on *db* and return its result dict.

        The dict always has ``sql``, ``columns``, ``rows``, ``error``,
        ``affected_rows`` and ``elapsed_ms``; ``error_code`` is added only
        when the statement raised.
        """
        stmt_result = {"sql": sql, "columns": None,
                       "rows": None, "error": None,
                       "affected_rows": 0,
                       "elapsed_ms": 0}
        t0 = _time.monotonic()
        try:
            db.cursor.execute(sql)
            if db.cursor.description:
                stmt_result["columns"] = [
                    desc[0] for desc in db.cursor.description]
                # Convert to a JSON-serialisable format
                stmt_result["rows"] = [
                    [_format_val(c) for c in row]
                    for row in db.cursor.fetchall()]
            else:
                stmt_result["affected_rows"] = db.cursor.rowcount or 0
        except Exception as e:
            stmt_result["error"] = str(e)
            # Most DB-API exceptions carry the error code in args[0];
            # fall back to an ``errno`` attribute when present.
            error_code = None
            if e.args and isinstance(e.args[0], int):
                error_code = e.args[0]
            elif hasattr(e, "errno"):
                error_code = e.errno
            stmt_result["error_code"] = error_code
        stmt_result["elapsed_ms"] = round(
            (_time.monotonic() - t0) * 1000, 3)
        return stmt_result

    @staticmethod
    def _execute_statements(config, database, stmts) -> dict:
        """Execute all statements on one DBMS and return the result dict.

        Unreachable hosts and connection failures are reported through the
        ``error`` field; statement failures are reported per statement.
        """
        from .executor import DBConnection, check_port

        result = {
            "name": config.name,
            "statements": [],
            "error": None,
        }
        if not check_port(config.host, config.port):
            result["error"] = (f"Cannot reach {config.host}:"
                               f"{config.port}")
            return result

        db = DBConnection(config, database)
        try:
            db.connect()
        except Exception as e:
            result["error"] = f"Connection failed: {e}"
            return result

        try:
            for sql in stmts:
                result["statements"].append(
                    _APIHandler._execute_one(db, sql))
        finally:
            db.close()
        return result

    def _handle_execute_api(self):
        """POST /api/execute — execute SQL on selected DBMS targets.

        Request body: {"sql": "...", "dbms": ["tdsql", "mysql"]}
        Response: {"ok": true, "results": {"tdsql": {...}, "mysql": {...}}}
        """
        import concurrent.futures

        request = self._parse_execute_request()
        if request is None:
            return
        stmts, targets = request
        database = self._database

        # Execute in parallel across all DBMS targets
        results = {}
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(targets)) as pool:
            futures = [pool.submit(self._execute_statements, c, database,
                                   stmts)
                       for c in targets]
            for fut in concurrent.futures.as_completed(futures):
                r = fut.result()
                results[r["name"]] = r

        self._respond_json({"ok": True, "results": results})

    def _handle_execute_stream_api(self):
        """POST /api/execute/stream — execute SQL on selected DBMS targets
        with Server-Sent Events progress updates.

        Request body: {"sql": "...", "dbms": ["tdsql", "mysql"]}
        SSE events:
        - event: progress data: {"name": "...", "index": N, "total": N, "result": {...}}
        - event: done data: {"ok": true}
        - event: error data: {"error": "..."}
        """
        import concurrent.futures

        request = self._parse_execute_request()
        if request is None:
            return
        stmts, targets = request
        database = self._database
        total = len(targets)

        # Set up SSE response headers
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream; charset=utf-8")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("Connection", "keep-alive")
        self._send_cors_headers()
        self.end_headers()

        sse_lock = threading.Lock()

        def _send_sse(event: str, data: dict):
            """Send a single SSE event to the client (thread-safe)."""
            with sse_lock:
                try:
                    payload = json.dumps(data, ensure_ascii=False)
                    self.wfile.write(
                        f"event: {event}\ndata: {payload}\n\n".encode("utf-8"))
                    self.wfile.flush()
                except Exception:
                    # Best effort: the client may already have disconnected.
                    pass

        def _exec_on_dbms(config, index):
            """Execute on one DBMS and publish its progress event."""
            result = self._execute_statements(config, database, stmts)
            _send_sse("progress", {
                "name": config.name,
                "index": index,
                "total": total,
                "result": result,
            })
            return result

        # Execute in parallel across all DBMS targets
        try:
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=len(targets)) as pool:
                futures = [pool.submit(_exec_on_dbms, c, i)
                           for i, c in enumerate(targets, 1)]
                for fut in concurrent.futures.as_completed(futures):
                    fut.result()  # propagate exceptions if any

            _send_sse("done", {"ok": True})
        except Exception as e:
            _send_sse("error", {"error": str(e)})
        finally:
            # The stream has no Content-Length, so its end is signalled by
            # closing the socket.  Sending "Connection: keep-alive" above
            # cleared close_connection; without resetting it the handler
            # would block waiting for another request on this connection.
            self.close_connection = True
593
+
594
+
595
def _format_val(value) -> str:
    """Format a cell value for JSON serialisation.

    ``None`` becomes ``"NULL"``, booleans become ``"1"``/``"0"`` and binary
    values (``bytes``, ``bytearray``, ``memoryview``) are decoded as UTF-8
    with undecodable bytes replaced.  Everything else goes through ``str``.
    """
    if value is None:
        return "NULL"
    # bytearray / memoryview would otherwise be rendered as their repr
    # (e.g. "bytearray(b'..')") by the str() fallback below.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value).decode("utf-8", errors="replace")
    if isinstance(value, bool):
        return "1" if value else "0"
    return str(value)
604
+
605
+
606
class ReportServer:
    """Manages a background HTTP server for viewing HTML reports."""

    def __init__(self, directory: str, port: int = 0, whitelist=None,
                 buglist=None, configs: Optional[List["DBMSConfig"]] = None,
                 all_configs: Optional[List["DBMSConfig"]] = None,
                 database: str = "", host: str = "0.0.0.0"):
        """Store the server settings; nothing is started yet.

        Args:
            directory: directory to serve (stored as an absolute path).
            port: TCP port; ``0`` lets the OS choose one when starting.
            whitelist: whitelist store handed to the API handler.
            buglist: buglist store handed to the API handler.
            configs: active DBMS configs.
            all_configs: all known DBMS configs; defaults to *configs*.
            database: database name handed to the API handler.
            host: address to bind.  The default listens on all interfaces;
                NOTE(security): the API executes SQL, so pass
                ``"127.0.0.1"`` to restrict access to the local machine.
        """
        self.directory = os.path.abspath(directory)
        self.port = port
        self.host = host
        self.whitelist = whitelist
        self.buglist = buglist
        self.configs = configs or []
        self.all_configs = all_configs or self.configs
        self.database = database
        self._server: Optional[http.server.HTTPServer] = None
        self._thread: Optional[threading.Thread] = None

    @property
    def running(self) -> bool:
        """True while the serving thread exists and is alive."""
        return self._thread is not None and self._thread.is_alive()

    @property
    def base_url(self) -> str:
        """URL of the server root on localhost."""
        return f"http://localhost:{self.port}"

    def start(self) -> str:
        """Start the server and return the base URL.

        Does nothing but return the URL when the server is already running.

        Raises:
            OSError: if the address cannot be bound.
        """
        if self.running:
            return self.base_url
        os.makedirs(self.directory, exist_ok=True)
        # Pre-generate index/whitelist/buglist pages so / redirects work
        from .reporter.history import (generate_buglist_html,
                                       generate_index_html,
                                       generate_playground_html,
                                       generate_whitelist_html)
        generate_index_html(self.directory)
        generate_whitelist_html(self.directory)
        generate_buglist_html(self.directory)
        generate_playground_html(self.directory)

        # Inject references into handler class
        _APIHandler._whitelist = self.whitelist
        _APIHandler._buglist = self.buglist
        _APIHandler._configs = self.configs
        _APIHandler._all_configs = self.all_configs
        _APIHandler._database = self.database

        directory = self.directory

        def handler(*args, **kwargs):
            return _APIHandler(*args, directory=directory, **kwargs)

        # Bind directly, even for port 0: the OS then assigns a free port
        # to this very socket.  Probing for a free port with a throw-away
        # socket and re-binding afterwards can lose the port in between.
        self._server = _SilentHTTPServer((self.host, self.port), handler)
        self.port = self._server.server_address[1]
        self._thread = threading.Thread(target=self._server.serve_forever,
                                        daemon=True)
        self._thread.start()
        return self.base_url

    def stop(self):
        """Stop serving and release the listening socket (no-op if idle)."""
        if self._server:
            # shutdown() blocks until serve_forever() returns, so run it in
            # a helper thread and give up waiting after a few seconds.
            t = threading.Thread(target=self._server.shutdown, daemon=True)
            t.start()
            t.join(timeout=3)
            # Close the listening socket so the port is released immediately.
            # shutdown() only stops serve_forever(); without server_close()
            # the socket stays open and the port remains occupied.
            try:
                self._server.server_close()
            except OSError as e:
                logging.getLogger("rosetta").debug(
                    "Error while closing report server: %s", e)
            self._server = None
        self._thread = None
681
+
682
+
683
+ # ---------------------------------------------------------------------------
684
+ # Interactive session
685
+ # ---------------------------------------------------------------------------
686
+
687
+ class InteractiveSession:
688
+ """Interactive REPL that accepts repeated test file submissions."""
689
+
690
+ COMMANDS = {
691
+ "help": "Show available commands",
692
+ "status": "Show current configuration",
693
+ "history": "Show executed tests in this session",
694
+ "server": "Show report server URL",
695
+ "open": "Open latest HTML report in IDE",
696
+ "clear": "Clear the screen",
697
+ "back": "Back to mode selection (also: b)",
698
+ "quit": "Exit (also: exit, q)",
699
+ }
700
+
701
+ def __init__(self, configs: List[DBMSConfig], output_dir: str,
702
+ database: str = DEFAULT_TEST_DB,
703
+ baseline: Optional[str] = None,
704
+ skip_explain: bool = False,
705
+ skip_analyze: bool = False,
706
+ skip_show_create: bool = False,
707
+ output_format: str = "all",
708
+ serve: bool = False, port: int = 19527,
709
+ all_configs: Optional[List[DBMSConfig]] = None):
710
+ self.configs = configs
711
+ self.all_configs = all_configs or configs
712
+ self.output_dir = os.path.abspath(output_dir)
713
+ self.database = database
714
+ self.baseline = baseline
715
+ self.skip_explain = skip_explain
716
+ self.skip_analyze = skip_analyze
717
+ self.skip_show_create = skip_show_create
718
+ self.output_format = output_format
719
+ self.serve = serve
720
+ self.port = port
721
+ self._run_history: List[Dict] = []
722
+ self._report_server: Optional[ReportServer] = None
723
+ # Whitelist — shared across all runs in this session
724
+ from .whitelist import Whitelist
725
+ self._whitelist = Whitelist(self.output_dir)
726
+ # Buglist — shared across all runs in this session
727
+ from .buglist import Buglist
728
+ self._buglist = Buglist(self.output_dir)
729
+
730
+ # -- server helpers -----------------------------------------------------
731
+
732
+ def _ensure_server(self) -> Optional[ReportServer]:
733
+ if not self.serve:
734
+ return None
735
+ if self._report_server and self._report_server.running:
736
+ return self._report_server
737
+ # Stop previous server if it exists but is no longer running
738
+ if self._report_server:
739
+ self._report_server.stop()
740
+ self._report_server = ReportServer(self.output_dir, self.port,
741
+ whitelist=self._whitelist,
742
+ buglist=self._buglist,
743
+ configs=self.configs,
744
+ all_configs=self.all_configs,
745
+ database=self.database)
746
+ try:
747
+ self._report_server.start()
748
+ return self._report_server
749
+ except OSError as e:
750
+ console.print(f" [red]✗[/red] Server failed: {e}")
751
+ return None
752
+
753
+ def _open_in_ide(self, url: str):
754
+ try:
755
+ subprocess.Popen(["code", "--open-url", url],
756
+ stdout=subprocess.DEVNULL,
757
+ stderr=subprocess.DEVNULL)
758
+ except FileNotFoundError:
759
+ pass
760
+
761
+ # -- test execution -----------------------------------------------------
762
+
763
+ def _run_test(self, test_file: str) -> bool:
764
+ from .runner import RosettaRunner
765
+ from .reporter.history import generate_buglist_html, generate_whitelist_html
766
+
767
+ if not os.path.isfile(test_file):
768
+ print_error(f"Test file not found: {test_file}")
769
+ flush_all()
770
+ return False
771
+
772
+ run_stamp = _time.strftime("%Y%m%d_%H%M%S")
773
+ test_name = Path(test_file).stem
774
+ run_dir = os.path.join(self.output_dir, f"{test_name}_{run_stamp}")
775
+
776
+ # Reload whitelist and buglist to pick up any changes from the web UI
777
+ self._whitelist.load()
778
+ self._buglist.load()
779
+
780
+ print_info("DBMS targets:",
781
+ ", ".join(c.name for c in self.configs))
782
+
783
+ runner = RosettaRunner(
784
+ test_file=test_file, configs=self.configs,
785
+ output_dir=run_dir, database=self.database,
786
+ baseline=self.baseline, skip_explain=self.skip_explain,
787
+ skip_analyze=self.skip_analyze,
788
+ skip_show_create=self.skip_show_create,
789
+ output_format=self.output_format,
790
+ whitelist=self._whitelist,
791
+ buglist=self._buglist)
792
+
793
+ comparisons = runner.run()
794
+
795
+ if not comparisons:
796
+ flush_all()
797
+ self._run_history.append({
798
+ "test": test_file, "time": _time.strftime("%H:%M:%S"),
799
+ "status": "FAIL", "run_dir": run_dir})
800
+ return False
801
+
802
+ # Update 'latest' symlink
803
+ latest_link = os.path.join(self.output_dir, "latest")
804
+ try:
805
+ if os.path.islink(latest_link):
806
+ os.remove(latest_link)
807
+ os.symlink(os.path.basename(run_dir), latest_link)
808
+ except OSError:
809
+ pass
810
+
811
+ generate_index_html(self.output_dir)
812
+ generate_whitelist_html(self.output_dir)
813
+ generate_buglist_html(self.output_dir)
814
+
815
+ # Print whitelist summary
816
+ wl_count = sum(cmp.whitelisted for cmp in comparisons.values())
817
+ if wl_count:
818
+ console.print(
819
+ f" [yellow]⚡ {wl_count} diff(s) matched whitelist"
820
+ f"[/yellow]")
821
+
822
+ # Print bug summary
823
+ bug_count = sum(cmp.bug_marked for cmp in comparisons.values())
824
+ if bug_count:
825
+ console.print(
826
+ f" [red]🐛 {bug_count} diff(s) marked as bug"
827
+ f"[/red]")
828
+
829
+ all_pass = print_summary(comparisons, runner.failed_connections)
830
+ flush_all()
831
+
832
+ passed = all_pass and not runner.failed_connections
833
+ self._run_history.append({
834
+ "test": test_file, "time": _time.strftime("%H:%M:%S"),
835
+ "status": "PASS" if passed else "FAIL", "run_dir": run_dir})
836
+
837
+ # Open in browser
838
+ srv = self._ensure_server()
839
+ if srv:
840
+ html_file = f"{test_name}.html"
841
+ html_path = os.path.join(run_dir, html_file)
842
+ if os.path.isfile(html_path):
843
+ url = (f"{srv.base_url}"
844
+ f"/{os.path.basename(run_dir)}/{html_file}")
845
+ console.print(
846
+ f"\n [cyan]📊 Report:[/cyan] "
847
+ f"[bold link={url}]{url}[/bold link]\n")
848
+ self._open_in_ide(url)
849
+
850
+ return passed
851
+
852
+ # -- command handlers ---------------------------------------------------
853
+
854
+ def _cmd_help(self):
855
+ console.print("\n [bold cyan]Available commands:[/bold cyan]")
856
+ for cmd, desc in self.COMMANDS.items():
857
+ console.print(f" [bold]{cmd:10s}[/bold] {desc}")
858
+ console.print(
859
+ "\n Or enter a [bold].test[/bold] file path to execute.\n")
860
+
861
+ def _cmd_status(self):
862
+ console.print(f"\n [cyan]Config:[/cyan]")
863
+ console.print(
864
+ f" DBMS: "
865
+ f"[bold]{', '.join(c.name for c in self.configs)}[/bold]")
866
+ console.print(f" Baseline: [bold]{self.baseline or 'none'}[/bold]")
867
+ console.print(f" Database: [bold]{self.database}[/bold]")
868
+ console.print(f" Output: [bold]{self.output_dir}[/bold]")
869
+ console.print(f" Format: [bold]{self.output_format}[/bold]")
870
+ console.print(f" Runs: [bold]{len(self._run_history)}[/bold]")
871
+ if self._report_server and self._report_server.running:
872
+ console.print(
873
+ f" Server: "
874
+ f"[bold green]{self._report_server.base_url}[/bold green]")
875
+ console.print()
876
+
877
+ def _cmd_history(self):
878
+ if not self._run_history:
879
+ console.print("\n [dim]No tests executed yet.[/dim]\n")
880
+ return
881
+ console.print(f"\n [bold cyan]Session history "
882
+ f"({len(self._run_history)} runs):[/bold cyan]")
883
+ for i, entry in enumerate(self._run_history, 1):
884
+ status_style = ("green" if entry["status"] == "PASS"
885
+ else "red")
886
+ console.print(
887
+ f" {i:3d}. [{status_style}]{entry['status']:4s}"
888
+ f"[/{status_style}] "
889
+ f"[dim]{entry['time']}[/dim] {entry['test']}")
890
+ console.print()
891
+
892
+ def _cmd_server(self):
893
+ srv = self._ensure_server()
894
+ if srv and srv.running:
895
+ idx_url = f"{srv.base_url}/index.html"
896
+ console.print(
897
+ f"\n [green]●[/green] Server running: "
898
+ f"[bold link={idx_url}]{idx_url}[/bold link]\n")
899
+ else:
900
+ console.print("\n [dim]Server not running "
901
+ "(use --serve to enable).[/dim]\n")
902
+
903
+ def _cmd_open(self):
904
+ latest = os.path.join(self.output_dir, "latest")
905
+ if not os.path.islink(latest):
906
+ console.print("\n [dim]No results yet.[/dim]\n")
907
+ return
908
+ real_dir = os.path.realpath(latest)
909
+ htmls = [f for f in os.listdir(real_dir) if f.endswith(".html")]
910
+ if not htmls:
911
+ console.print("\n [dim]No HTML report found.[/dim]\n")
912
+ return
913
+ srv = self._ensure_server()
914
+ if not srv:
915
+ console.print("\n [dim]Server not available.[/dim]\n")
916
+ return
917
+ url = (f"{srv.base_url}"
918
+ f"/{os.path.basename(real_dir)}/{htmls[0]}")
919
+ console.print(f"\n Opening: [bold]{url}[/bold]\n")
920
+ self._open_in_ide(url)
921
+
922
+ # -- main loop ----------------------------------------------------------
923
+
924
def run(self):
    """Run the interactive REPL until the user leaves it.

    Each non-empty input line is either a built-in command or a test file
    path that is passed to ``self._run_test``.

    Returns ``"back"`` if the user typed ``back``/``b``,
    ``"quit"`` otherwise (including EOF / KeyboardInterrupt).
    """
    os.makedirs(self.output_dir, exist_ok=True)
    history_file = os.path.join(self.output_dir, ".rosetta_history")
    session: PromptSession = PromptSession(
        history=FileHistory(history_file),
        completer=TestFileCompleter(),
        style=_PROMPT_STYLE,
        complete_while_typing=True,
        multiline=False,
    )
    placeholder = HTML('<placeholder>Type a path, \'help\', \'back\', or \'quit\'</placeholder>')

    # Welcome banner: a box whose inner width is 55 visible characters.
    border = "═" * 55
    title = "Rosetta Interactive Mode"
    hint = "Enter .test file paths to execute, or 'help'"
    title_line = f" {title} ".center(55)
    hint_line = f" {hint} ".center(55)
    side = "[bold cyan]║[/bold cyan]"
    console.print(f" [bold cyan]╔{border}╗[/bold cyan]")
    console.print(f" {side}[bold white]{title_line}[/bold white]{side}")
    console.print(f" {side}[dim]{hint_line}[/dim]{side}")
    console.print(f" [bold cyan]╚{border}╝[/bold cyan]")

    # Fixed session parameters.
    dbms_names = ', '.join(c.name for c in self.configs)
    console.print(
        f" [dim]DBMS:[/dim] "
        f"[bold]{dbms_names}[/bold] "
        f"[dim]Baseline:[/dim] "
        f"[bold]{self.baseline or 'auto'}[/bold] "
        f"[dim]Database:[/dim] [bold]{self.database}[/bold]")

    # Bring the report server up before the first prompt, if one is wanted.
    server = self._ensure_server()
    if server and server.running:
        console.print(
            f" [dim]Server:[/dim] "
            f"[bold green]{server.base_url}[/bold green]")
    console.print()

    # Built-in commands that simply call a handler and return to the prompt.
    handlers = {
        "help": self._cmd_help,
        "status": self._cmd_status,
        "history": self._cmd_history,
        "server": self._cmd_server,
        "open": self._cmd_open,
        "clear": console.clear,
    }
    prompt_msg = HTML('<prompt>rosetta</prompt> <path>▶</path> ')

    run_count = 0
    exit_reason = "quit"
    while True:
        try:
            user_input = session.prompt(
                prompt_msg, placeholder=placeholder).strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input:
            continue

        cmd = user_input.lower()
        if cmd in ("back", "b"):
            # Back to mode selection.
            exit_reason = "back"
            break
        if cmd in ("quit", "exit", "q"):
            break

        handler = handlers.get(cmd)
        if handler is not None:
            handler()
            continue

        # Anything else is treated as a test file path.
        expanded = os.path.expanduser(user_input)
        test_path = (expanded if os.path.isabs(expanded)
                     else os.path.abspath(expanded))

        run_count += 1
        console.print()
        console.rule(
            f"[bold cyan] Run #{run_count}: "
            f"{os.path.basename(test_path)} [/bold cyan]")
        console.print()

        self._run_test(test_path)

        console.print(
            " [dim]Ready for next test. "
            "Type a path, 'help', 'back', or 'quit'.[/dim]\n")

    # Cleanup: "back" shuts the server down silently, "quit" says goodbye.
    if exit_reason == "back":
        if self._report_server:
            self._report_server.stop()
        return exit_reason

    console.print()
    if self._run_history:
        console.print(
            f" [dim]Session complete: "
            f"{len(self._run_history)} test(s) executed.[/dim]")
    if self._report_server:
        self._report_server.stop()
        console.print(" [dim]Report server stopped.[/dim]")
    console.print(" [bold cyan]Goodbye! 👋[/bold cyan]\n")
    return exit_reason
1053
+
1054
+
1055
+ # ---------------------------------------------------------------------------
1056
+ # Benchmark file auto-completion
1057
+ # ---------------------------------------------------------------------------
1058
+
1059
class BenchFileCompleter(Completer):
    """Auto-complete .json / .sql benchmark file paths and directories."""

    def get_completions(self, document, complete_event):
        """Yield completions for the path typed before the cursor.

        Directories are always offered (with a trailing ``/``); files are
        offered only when they end in ``.json`` or ``.sql``.  An empty input
        completes against the current directory.

        Args:
            document: prompt_toolkit document; only ``text_before_cursor``
                is read.
            complete_event: Unused.

        Yields:
            ``Completion`` objects that replace everything before the cursor.
        """
        typed = document.text_before_cursor
        text = typed.strip() or "./"
        expanded = os.path.expanduser(text)
        if os.path.isdir(expanded) and not expanded.endswith("/"):
            expanded += "/"

        # The replaced span must match what is really in the buffer.  Using
        # the stripped/defaulted text here would give a wrong offset for
        # empty input ("./" is not in the buffer) or surrounding whitespace.
        start = -len(typed)

        for path in sorted(glob.glob(expanded + "*")):
            if os.path.isdir(path):
                yield Completion(path + "/", start_position=start,
                                 display=os.path.basename(path) + "/",
                                 display_meta="dir")
            elif path.endswith((".json", ".sql")):
                yield Completion(path, start_position=start,
                                 display=os.path.basename(path),
                                 display_meta="bench")
1081
+
1082
+
1083
+ # ---------------------------------------------------------------------------
1084
+ # Benchmark interactive session
1085
+ # ---------------------------------------------------------------------------
1086
+
1087
class BenchInteractiveSession:
    """Interactive REPL for benchmark mode.

    Base parameters (config, dbms, iterations, warmup, concurrency, etc.)
    are fixed at launch; only the bench file path changes between runs.
    """

    # Built-in REPL commands mapped to the description printed by
    # ``_cmd_help``.  The aliases mentioned in the descriptions (b, exit, q)
    # are matched in ``run`` and are not keys of this mapping.
    COMMANDS = {
        "help": "Show available commands",
        "status": "Show current configuration",
        "history": "Show executed benchmarks in this session",
        "server": "Show report server URL",
        "open": "Open latest HTML report in IDE",
        "clear": "Clear the screen",
        "back": "Back to parameter selection (also: b)",
        "quit": "Exit (also: exit, q)",
    }
1104
+
1105
def __init__(self, configs: List[DBMSConfig], output_dir: str,
             database: str = DEFAULT_TEST_DB,
             iterations: int = 100,
             warmup: int = 5,
             concurrency: int = 0,
             duration: float = 30.0,
             ramp_up: float = 0.0,
             bench_filter: Optional[str] = None,
             repeat: int = 1,
             parallel_dbms: bool = True,
             output_format: str = "all",
             serve: bool = False,
             port: int = 19527,
             profile: bool = False,
             perf_freq: int = 99,
             query_timeout: int = 5,
             flamegraph_min_ms: int = 1000,
             bench_mode: str = "serial"):
    """Store the parameters that stay fixed for the whole session.

    Every argument is kept on the instance under its own name.  Two values
    are normalised: ``output_dir`` is made absolute and ``repeat`` is
    clamped to at least 1.  The run history starts empty and no report
    server is created here.
    """
    # Targets and output location.
    self.configs = configs
    self.database = database
    self.output_dir = os.path.abspath(output_dir)
    self.output_format = output_format

    # Workload shape.
    self.bench_mode = bench_mode
    self.bench_filter = bench_filter
    self.iterations = iterations
    self.warmup = warmup
    self.concurrency = concurrency
    self.duration = duration
    self.ramp_up = ramp_up
    self.repeat = max(1, repeat)
    self.parallel_dbms = parallel_dbms
    self.query_timeout = query_timeout

    # Profiling.
    self.profile = profile
    self.perf_freq = perf_freq
    self.flamegraph_min_ms = flamegraph_min_ms

    # Report server settings.
    self.serve = serve
    self.port = port

    # Mutable session state.
    self._run_history: List[Dict] = []
    self._report_server: Optional[ReportServer] = None
1144
+
1145
+ # -- server helpers -----------------------------------------------------
1146
+
1147
def _ensure_server(self) -> Optional[ReportServer]:
    """Return a running report server, starting one if necessary.

    Returns:
        The running ``ReportServer``, or ``None`` when serving is disabled
        (``self.serve`` is false) or the server could not be started.
    """
    if not self.serve:
        return None

    current = self._report_server
    if current and current.running:
        return current
    # Stop previous server if it exists but is no longer running
    if current:
        current.stop()

    server = ReportServer(self.output_dir, self.port)
    try:
        server.start()
    except OSError as e:
        console.print(f" [red]✗[/red] Server failed: {e}")
        # Do not keep a server that never started: session cleanup would
        # otherwise call stop() on it and announce that a server was stopped.
        self._report_server = None
        return None

    self._report_server = server
    return server
1162
+
1163
+ def _open_in_ide(self, url: str):
1164
+ try:
1165
+ subprocess.Popen(["code", "--open-url", url],
1166
+ stdout=subprocess.DEVNULL,
1167
+ stderr=subprocess.DEVNULL)
1168
+ except FileNotFoundError:
1169
+ pass
1170
+
1171
+ # -- bench execution ----------------------------------------------------
1172
+
1173
def _run_bench(self, bench_file: str) -> bool:
    """Execute one benchmark run (possibly with --repeat rounds).

    Loads the workload from *bench_file*, builds a ``BenchmarkConfig`` from
    the session parameters, runs ``self.repeat`` rounds through
    ``run_benchmark_with_progress`` and appends the outcome to
    ``self._run_history``.  When a report server is available the HTML
    report of the last round is opened.

    Args:
        bench_file: Path of the benchmark definition to run.

    Returns:
        True when at least one round returned a run directory.  False when
        the file is missing or cannot be loaded, the query filter is
        rejected, or no round completed.
    """
    import threading
    import time as _time

    from .benchmark import BenchmarkLoader
    from .models import BenchmarkConfig, WorkloadMode
    from .ui import (BenchProgress, flush_all, print_bench_summary,
                     print_error, print_info, print_phase,
                     print_report_file)

    # A positive concurrency selects the time-based concurrent mode.
    if self.concurrency > 0:
        mode = WorkloadMode.CONCURRENT
    else:
        mode = WorkloadMode.SERIAL

    # Load workload
    if not os.path.isfile(bench_file):
        print_error(f"Bench file not found: {bench_file}")
        flush_all()
        return False

    try:
        workload = BenchmarkLoader.from_file(bench_file)
    except (FileNotFoundError, ValueError) as e:
        print_error(str(e))
        flush_all()
        return False

    # Read extra config from JSON file (database, skip_setup, skip_teardown).
    # This is best effort: the workload itself has already been loaded, so a
    # failure here only means the extra keys are not applied.
    json_extra_config = {}
    if bench_file.endswith('.json'):
        import json as _json
        try:
            with open(bench_file, 'r') as f:
                json_data = _json.load(f)
        except (OSError, ValueError) as e:
            log.debug("Could not read extra config from %s: %s",
                      bench_file, e)
        else:
            if isinstance(json_data, dict):
                json_extra_config = {
                    'database': json_data.get('database'),
                    'skip_setup': json_data.get('skip_setup'),
                    'skip_teardown': json_data.get('skip_teardown'),
                }

    # Determine skip_setup/skip_teardown: a truthy instance attribute wins,
    # otherwise the JSON value is used (False when the JSON has none).
    json_skip_setup = json_extra_config.get('skip_setup')
    json_skip_teardown = json_extra_config.get('skip_teardown')
    inst_skip_setup = getattr(self, 'skip_setup', False)
    inst_skip_teardown = getattr(self, 'skip_teardown', False)
    final_skip_setup = inst_skip_setup or (
        json_skip_setup if json_skip_setup is not None else False)
    final_skip_teardown = inst_skip_teardown or (
        json_skip_teardown if json_skip_teardown is not None else False)

    filter_queries = []
    if self.bench_filter:
        filter_queries = [
            n.strip() for n in self.bench_filter.split(",")
            if n.strip()
        ]

    bench_cfg = BenchmarkConfig(
        mode=mode,
        iterations=self.iterations,
        warmup=self.warmup,
        concurrency=self.concurrency if self.concurrency > 0 else 1,
        duration=self.duration,
        ramp_up=self.ramp_up,
        filter_queries=filter_queries,
        profile=self.profile,
        perf_freq=self.perf_freq,
        query_timeout=self.query_timeout,
        flamegraph_min_ms=self.flamegraph_min_ms,
        skip_setup=final_skip_setup,
        skip_teardown=final_skip_teardown,
    )

    # The filtered workload is used for display and progress sizing only;
    # the unfiltered workload plus bench_cfg is what gets executed below.
    display_workload = workload
    if filter_queries:
        try:
            display_workload = BenchmarkLoader.filter_queries(
                workload, filter_queries)
        except ValueError as e:
            print_error(str(e))
            flush_all()
            return False

    # Display plan
    print_phase("Benchmark", workload.name)
    print_info("Mode:", mode.name)
    print_info("DBMS targets:",
               ", ".join(c.name for c in self.configs))
    if len(self.configs) > 1:
        if self.parallel_dbms:
            print_info("DBMS execution:",
                       "[bold green]parallel[/bold green]")
        else:
            print_info("DBMS execution:", "sequential")

    print_info("Queries:",
               ", ".join(q.name for q in display_workload.queries))
    if mode == WorkloadMode.SERIAL:
        print_info("Iterations:",
                   f"{bench_cfg.iterations} "
                   f"Warmup: {bench_cfg.warmup}")
    else:
        print_info("Concurrency:",
                   f"{bench_cfg.concurrency} "
                   f"Duration: {bench_cfg.duration}s")
    if filter_queries:
        print_info("Filter:", ", ".join(filter_queries))
    if self.repeat > 1:
        print_info("Repeat:", f"{self.repeat} rounds")

    fmt = self.output_format
    output_dir = self.output_dir
    configs = self.configs

    def _run_one_round(round_num: int):
        """Execute a single benchmark round and return its run directory."""
        if self.repeat > 1:
            console.print(
                f"\n[bold cyan]{'━' * 60}[/bold cyan]")
            console.print(
                f"[bold cyan] Round {round_num}/"
                f"{self.repeat}[/bold cyan]")
            console.print(
                f"[bold cyan]{'━' * 60}[/bold cyan]\n")

        # NOTE(review): this directory is created up front, but ``run_dir``
        # is rebound to the value returned by run_benchmark_with_progress
        # below -- confirm both refer to the same directory.
        run_stamp = _time.strftime("%Y%m%d_%H%M%S")
        run_dir = os.path.join(
            output_dir,
            f"bench_{workload.name}_{run_stamp}")
        os.makedirs(run_dir, exist_ok=True)

        print_phase("Execute")

        # Progress tracking: one BenchProgress per DBMS name.
        progress_bars: Dict[str, BenchProgress] = {}
        _progress_lock = threading.Lock()

        n_queries = len(display_workload.queries)
        # CONCURRENT mode uses time-based progress
        is_time_based = (mode == WorkloadMode.CONCURRENT)
        if is_time_based:
            duration = bench_cfg.duration if bench_cfg.duration > 0 else 30.0
            per_query = 100  # placeholder, not used for time-based
        else:
            duration = 0.0
            per_query = bench_cfg.iterations + bench_cfg.warmup

        # Create progress bars upfront (they will show "setup..." initially)
        if self.parallel_dbms and len(configs) > 1:
            for c in configs:
                bp = BenchProgress(
                    c.name, n_queries, per_query,
                    is_concurrent=is_time_based, duration=duration)
                bp.__enter__()
                bp.set_status("[yellow]正在setup...[/yellow]")
                progress_bars[c.name] = bp

        def on_setup_start(dbms_name):
            # Lazily create the bar when it was not created upfront.
            with _progress_lock:
                if dbms_name not in progress_bars:
                    bp = BenchProgress(
                        dbms_name, n_queries, per_query,
                        is_concurrent=is_time_based, duration=duration)
                    bp.__enter__()
                    bp.set_status("[yellow]正在setup...[/yellow]")
                    progress_bars[dbms_name] = bp

        def on_setup_done(dbms_name, success):
            bp = progress_bars.get(dbms_name)
            if bp:
                if success:
                    bp.set_status("[green]setup完毕[/green]")
                else:
                    bp.set_status("[red]setup失败 — 跳过该DBMS[/red]")
                    # Close progress bar for failed DBMS
                    bp.__exit__(None, None, None)
                    bp.write_summary_to_buffer()

        def on_dbms_start(dbms_name):
            with _progress_lock:
                if dbms_name not in progress_bars:
                    bp = BenchProgress(
                        dbms_name, n_queries, per_query,
                        is_concurrent=is_time_based, duration=duration)
                    bp.__enter__()
                    progress_bars[dbms_name] = bp

        def on_run_start():
            # Reset timers when query phase begins (all setups complete).
            # The setup status stays visible until queries actually start.
            with _progress_lock:
                for bp in progress_bars.values():
                    bp.reset_timer()
            # Record start time for timer thread
            timer_start_time[0] = _time.monotonic()
            # Signal timer thread to start updating
            query_phase_started.set()

        def on_progress(dbms_name, query_name, iteration,
                        total, is_warmup=False):
            bp = progress_bars.get(dbms_name)
            if bp:
                if is_time_based:
                    # In time-based mode (CONCURRENT), update time progress
                    bp.update_time(status=f"[cyan]{query_name}[/cyan]")
                else:
                    # In serial mode, show per-query iteration progress
                    bp.advance(query_name=query_name,
                               iteration=iteration,
                               total=total,
                               is_warmup=is_warmup)

        def on_dbms_done(dbms_name, dbms_result):
            bp = progress_bars.get(dbms_name)
            if bp:
                bp.set_status(
                    f"[green]{dbms_result.total_queries} queries, "
                    f"{dbms_result.overall_qps:.1f} QPS[/green]")
                bp.__exit__(None, None, None)
                bp.write_summary_to_buffer()

        def on_profile_start(dbms_name, query_name):
            bp = progress_bars.get(dbms_name)
            if bp:
                bp.set_status(
                    f"[red]🔥 profiling {query_name}[/red]")

        def on_profile_done(dbms_name, query_name, sample_count):
            bp = progress_bars.get(dbms_name)
            if bp:
                bp.set_status(
                    f"[dim]🔥 {query_name}: "
                    f"{sample_count} samples[/dim]")

        # For time-based mode (CONCURRENT), timer thread updates progress
        timer_stop_event = None
        timer_thread = None
        query_phase_started = threading.Event()
        timer_start_time = [None]  # Will be set in on_run_start

        if is_time_based:
            timer_stop_event = threading.Event()

            def _timer_update():
                # Wait until query phase starts (all setups complete)
                query_phase_started.wait()
                while not timer_stop_event.is_set():
                    # Stop updating once the configured duration has
                    # elapsed (the benchmark itself may take longer due
                    # to cleanup).
                    if timer_start_time[0] is not None:
                        elapsed = _time.monotonic() - timer_start_time[0]
                        if elapsed >= duration:
                            break
                    for bp in list(progress_bars.values()):
                        bp.update_time(status="")
                    _time.sleep(0.5)

            timer_thread = threading.Thread(target=_timer_update, daemon=True)
            timer_thread.start()

        try:
            # Determine database: JSON config overrides instance default
            json_database = json_extra_config.get('database')
            final_database = json_database if json_database else self.database

            # Prepare callbacks for progress tracking
            callbacks = {
                'on_progress': on_progress,
                'on_dbms_start': on_dbms_start,
                'on_dbms_done': on_dbms_done,
                'on_profile_start': on_profile_start if bench_cfg.profile else None,
                'on_profile_done': on_profile_done if bench_cfg.profile else None,
                'on_run_start': on_run_start,
                'on_setup_start': on_setup_start,
                'on_setup_done': on_setup_done,
            }

            # Use shared core function for benchmark execution
            from .runner import run_benchmark_with_progress
            run_dir, result = run_benchmark_with_progress(
                configs=configs,
                workload=workload,
                bench_cfg=bench_cfg,
                database=final_database,
                output_dir=output_dir,
                output_format=fmt,
                parallel_dbms=self.parallel_dbms,
                json_extra_config=json_extra_config,
                callbacks=callbacks,
                bench_file=bench_file,
            )
        finally:
            # Stop timer thread
            if timer_stop_event is not None:
                timer_stop_event.set()
            if timer_thread is not None:
                timer_thread.join(timeout=1.0)

        # Reports - already generated by run_benchmark_with_progress;
        # only their paths are printed here.
        print_phase("Reports")

        if fmt in ("text", "all"):
            text_path = os.path.join(run_dir, f"bench_{workload.name}.report.txt")
            print_report_file(text_path, label="text")

        if fmt in ("html", "all"):
            html_path = os.path.join(run_dir, f"bench_{workload.name}.html")
            print_report_file(html_path, label="html")

        # JSON - already saved by run_benchmark_with_progress
        json_path = os.path.join(run_dir, "bench_result.json")
        print_report_file(json_path, label="json")

        # Latest symlink and history index - already updated by
        # run_benchmark_with_progress

        print_bench_summary(result)
        flush_all()

        return run_dir

    # Main loop for repeat rounds
    last_run_dir = None
    interrupted = False
    for rnd in range(1, self.repeat + 1):
        try:
            last_run_dir = _run_one_round(rnd)
            # The pause between rounds is inside the handler so Ctrl-C
            # during it stops the run instead of escaping to the caller.
            if rnd < self.repeat:
                _time.sleep(1)
        except KeyboardInterrupt:
            console.print(
                f"\n[yellow]Interrupted at round {rnd}/"
                f"{self.repeat}. Stopping.[/yellow]")
            flush_all()
            interrupted = True
            break

    # Only claim completion when every round actually ran.
    if self.repeat > 1 and not interrupted:
        console.print(
            f"\n[bold green]All {self.repeat} rounds "
            f"completed.[/bold green]")
        flush_all()

    success = last_run_dir is not None
    self._run_history.append({
        "bench_file": bench_file,
        "workload": workload.name,
        "time": _time.strftime("%H:%M:%S"),
        "status": "OK" if success else "FAIL",
        "run_dir": last_run_dir or "",
    })

    # Open in browser via server
    srv = self._ensure_server()
    if (srv and last_run_dir
            and fmt in ("html", "all")):
        html_file = f"bench_{workload.name}.html"
        html_path = os.path.join(last_run_dir, html_file)
        if os.path.isfile(html_path):
            url = (f"{srv.base_url}"
                   f"/{os.path.basename(last_run_dir)}"
                   f"/{html_file}")
            console.print(
                f"\n [cyan]📊 Report:[/cyan] "
                f"[bold link={url}]{url}[/bold link]\n")
            self._open_in_ide(url)

    return success
1548
+
1549
+ # -- command handlers ---------------------------------------------------
1550
+
1551
def _cmd_help(self):
    """Print every entry of ``COMMANDS`` followed by the file-path hint."""
    console.print("\n [bold cyan]Available commands:[/bold cyan]")
    for name in self.COMMANDS:
        description = self.COMMANDS[name]
        console.print(f" [bold]{name:10s}[/bold] {description}")
    footer = ("\n Or enter a [bold].json / .sql[/bold] bench file path"
              " to execute.\n")
    console.print(footer)
1558
+
1559
def _cmd_status(self):
    """Print the session's fixed parameters, run count and server URL.

    Concurrency and duration are listed only when concurrency is positive;
    the server line only when a report server is currently running.
    """
    dbms_names = ', '.join(c.name for c in self.configs)
    lines = [
        f"\n [cyan]Config:[/cyan]",
        f" DBMS: [bold]{dbms_names}[/bold]",
        f" Database: [bold]{self.database}[/bold]",
        f" Iterations: [bold]{self.iterations}[/bold]",
        f" Warmup: [bold]{self.warmup}[/bold]",
    ]
    if self.concurrency > 0:
        lines.append(f" Concurrency: [bold]{self.concurrency}[/bold]")
        lines.append(f" Duration: [bold]{self.duration}s[/bold]")
    lines += [
        f" Repeat: [bold]{self.repeat}[/bold]",
        f" Output: [bold]{self.output_dir}[/bold]",
        f" Format: [bold]{self.output_format}[/bold]",
        f" Runs: [bold]{len(self._run_history)}[/bold]",
    ]
    server = self._report_server
    if server and server.running:
        lines.append(
            f" Server: [bold green]{server.base_url}[/bold green]")

    for line in lines:
        console.print(line)
    console.print()
1583
+
1584
def _cmd_history(self):
    """List the benchmarks run in this session, numbered from 1.

    Each line shows the status (green for ``OK``, red otherwise), the time,
    the bench file and the workload name of one ``_run_history`` entry.
    """
    history = self._run_history
    if not history:
        console.print("\n [dim]No benchmarks executed yet.[/dim]\n")
        return

    console.print(
        f"\n [bold cyan]Session history "
        f"({len(history)} runs):[/bold cyan]")
    index = 0
    for record in history:
        index += 1
        colour = "green" if record["status"] == "OK" else "red"
        console.print(
            f" {index:3d}. [{colour}]{record['status']:4s}"
            f"[/{colour}] "
            f"[dim]{record['time']}[/dim] "
            f"{record['bench_file']} "
            f"[dim]({record['workload']})[/dim]")
    console.print()
1601
+
1602
def _cmd_server(self):
    """Print the report server's index URL, or a hint if it is not running."""
    server = self._ensure_server()
    if not (server and server.running):
        console.print("\n [dim]Server not running "
                      "(use --serve to enable).[/dim]\n")
        return

    index_url = f"{server.base_url}/index.html"
    message = (f"\n [green]●[/green] Server running: "
               f"[bold link={index_url}]{index_url}[/bold link]\n")
    console.print(message)
1612
+
1613
def _cmd_open(self):
    """Open an HTML report from the most recent run through the report server.

    The most recent run is whatever the ``latest`` symlink inside
    ``self.output_dir`` resolves to.  The first ``.html`` file (in sorted
    order) of that directory is served via the report server and handed to
    ``self._open_in_ide``.  A short notice is printed instead when there is
    no ``latest`` link, no HTML report, or no server.
    """
    latest = os.path.join(self.output_dir, "latest")
    if not os.path.islink(latest):
        console.print("\n [dim]No results yet.[/dim]\n")
        return

    real_dir = os.path.realpath(latest)
    try:
        # Sorted so the chosen report is deterministic; os.listdir() returns
        # entries in arbitrary order.
        htmls = sorted(
            name for name in os.listdir(real_dir) if name.endswith(".html"))
    except OSError:
        # "latest" can be a dangling symlink (run directory removed) or point
        # at something unreadable; treat that like "no report" rather than
        # crashing the REPL.
        htmls = []
    if not htmls:
        console.print("\n [dim]No HTML report found.[/dim]\n")
        return

    srv = self._ensure_server()
    if not srv:
        console.print("\n [dim]Server not available.[/dim]\n")
        return

    url = (f"{srv.base_url}"
           f"/{os.path.basename(real_dir)}/{htmls[0]}")
    console.print(f"\n Opening: [bold]{url}[/bold]\n")
    self._open_in_ide(url)
1631
+
1632
+ # -- main loop ----------------------------------------------------------
1633
+
1634
def run(self):
    """Run the interactive benchmark REPL until the user leaves it.

    Each non-empty input line is either a built-in command or a bench file
    path that is passed to ``self._run_bench``.

    Returns ``"back"`` if the user typed ``back``/``b``,
    ``"quit"`` otherwise (including EOF / KeyboardInterrupt).
    """
    os.makedirs(self.output_dir, exist_ok=True)
    history_file = os.path.join(self.output_dir, ".rosetta_bench_history")
    session: PromptSession = PromptSession(
        history=FileHistory(history_file),
        completer=BenchFileCompleter(),
        style=_PROMPT_STYLE,
        complete_while_typing=True,
        multiline=False,
    )
    placeholder = HTML('<placeholder>Type a path, \'help\', \'back\', or \'quit\'</placeholder>')

    # Welcome banner: a box whose inner width is 55 visible characters.
    border = "═" * 55
    title = "Rosetta Benchmark Interactive Mode"
    hint = "Enter bench file (.json/.sql) to execute, or 'help'"
    title_line = f" {title} ".center(55)
    hint_line = f" {hint} ".center(55)
    side = "[bold cyan]║[/bold cyan]"
    console.print(f" [bold cyan]╔{border}╗[/bold cyan]")
    console.print(f" {side}[bold white]{title_line}[/bold white]{side}")
    console.print(f" {side}[dim]{hint_line}[/dim]{side}")
    console.print(f" [bold cyan]╚{border}╝[/bold cyan]")

    # Summarise the fixed configuration; the fields depend on the mode.
    if self.concurrency > 0:
        config_parts = [
            f"[dim]Mode:[/dim] [bold]{'CONCURRENT'}[/bold]",
            f"[dim]Concurrency:[/dim] [bold]{self.concurrency}[/bold]",
        ]
        optional_parts = (
            (self.duration, f"[dim]Duration:[/dim] [bold]{self.duration}s[/bold]"),
            (self.ramp_up, f"[dim]Ramp-up:[/dim] [bold]{self.ramp_up}s[/bold]"),
            (self.warmup, f"[dim]Warmup:[/dim] [bold]{self.warmup}[/bold]"),
        )
        config_parts.extend(
            text for value, text in optional_parts if value > 0)
    else:
        config_parts = [
            f"[dim]Mode:[/dim] [bold]{'SERIAL'}[/bold]",
            f"[dim]Iterations:[/dim] [bold]{self.iterations}[/bold]",
            f"[dim]Warmup:[/dim] [bold]{self.warmup}[/bold]",
        ]
    dbms_names = ', '.join(c.name for c in self.configs)
    console.print(
        f" [dim]DBMS:[/dim] "
        f"[bold]{dbms_names}[/bold] "
        + " ".join(config_parts))

    database_part = f"[dim]Database:[/dim] [bold]{self.database}[/bold]"
    if self.repeat > 1:
        console.print(
            f" [dim]Repeat:[/dim] [bold]{self.repeat}[/bold] "
            + database_part)
    else:
        console.print(" " + database_part)

    # Bring the report server up before the first prompt, if one is wanted.
    server = self._ensure_server()
    if server and server.running:
        console.print(
            f" [dim]Server:[/dim] "
            f"[bold green]{server.base_url}[/bold green]")
    console.print()

    # Built-in commands that simply call a handler and return to the prompt.
    handlers = {
        "help": self._cmd_help,
        "status": self._cmd_status,
        "history": self._cmd_history,
        "server": self._cmd_server,
        "open": self._cmd_open,
        "clear": console.clear,
    }
    prompt_msg = HTML('<prompt>rosetta</prompt> <path>▶</path> ')

    run_count = 0
    exit_reason = "quit"
    while True:
        try:
            user_input = session.prompt(
                prompt_msg, placeholder=placeholder).strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input:
            continue

        cmd = user_input.lower()
        if cmd in ("back", "b"):
            # Back to parameter selection.
            exit_reason = "back"
            break
        if cmd in ("quit", "exit", "q"):
            break

        handler = handlers.get(cmd)
        if handler is not None:
            handler()
            continue

        # Anything else is treated as a bench file path.
        expanded = os.path.expanduser(user_input)
        bench_path = (expanded if os.path.isabs(expanded)
                      else os.path.abspath(expanded))

        run_count += 1
        console.print()
        console.rule(
            f"[bold cyan] Bench #{run_count}: "
            f"{os.path.basename(bench_path)} [/bold cyan]")
        console.print()

        self._run_bench(bench_path)

        console.print(
            " [dim]Ready for next benchmark. "
            "Type a path, 'help', 'back', or 'quit'.[/dim]\n")

    # Cleanup: "back" is silent (the caller clears the screen), "quit"
    # prints a short farewell.
    if exit_reason == "back":
        if self._report_server:
            self._report_server.stop()
        return exit_reason

    console.print()
    if self._run_history:
        console.print(
            f" [dim]Session complete: "
            f"{len(self._run_history)} benchmark(s) "
            f"executed.[/dim]")
    if self._report_server:
        self._report_server.stop()
        console.print(" [dim]Report server stopped.[/dim]")
    console.print(" [bold cyan]Goodbye! 👋[/bold cyan]\n")
    return exit_reason