rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rosetta/benchmark.py ADDED
@@ -0,0 +1,1678 @@
1
+ """Benchmark engine for Rosetta cross-DBMS performance comparison."""
2
+
3
+ import concurrent.futures
4
+ import json
5
+ import logging
6
+ import math
7
+ import random
8
+ import re
9
+ import string
10
+ import threading
11
+ import time as _time
12
+ from pathlib import Path
13
+ from typing import Callable, Dict, List, Optional
14
+
15
+ from .executor import DBConnection, ensure_service
16
+ from .models import (
17
+ BenchmarkConfig, BenchmarkResult, BenchQuery, BenchWorkload,
18
+ DBMSBenchResult, DBMSConfig, QueryLatencyStats, TestCase, WorkloadMode,
19
+ )
20
+
21
+ log = logging.getLogger("rosetta")
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Template variable engine
26
+ # ---------------------------------------------------------------------------
27
+
28
class TemplateEngine:
    """Render template variables in SQL strings.

    Supported placeholders:
        {{rand_int(min,max)}}      — random integer in [min, max]
        {{rand_str(len)}}          — random alphanumeric string of given length
        {{rand_choice(a,b,c,...)}} — random pick from comma-separated values
        {{seq_int()}}              — monotonically increasing integer (per engine)
    """

    # Matches {{func_name(args)}}; args may be empty, e.g. {{seq_int()}}.
    _PATTERN = re.compile(r"\{\{\s*(\w+)\s*\(([^)]*)\)\s*\}\}")

    def __init__(self, seed: Optional[int] = None):
        # A dedicated Random instance keeps rendering reproducible per engine.
        self._rng = random.Random(seed)
        self._seq_counter = 0
        self._seq_lock = threading.Lock()

    def render(self, sql: str) -> str:
        """Replace all template placeholders in *sql* with concrete values."""
        return self._PATTERN.sub(self._replace_match, sql)

    def _replace_match(self, match: re.Match) -> str:
        """Dispatch one placeholder match to its handler, or leave it as-is."""
        name = match.group(1).lower()
        arg_text = match.group(2).strip()

        dispatch = {
            "rand_int": self._rand_int,
            "rand_str": self._rand_str,
            "rand_choice": self._rand_choice,
            "seq_int": self._seq_int,
        }
        if name not in dispatch:
            log.warning("Unknown template function: %s", name)
            return match.group(0)  # leave unchanged
        return dispatch[name](arg_text)

    # -- handler implementations ---------------------------------------------

    def _rand_int(self, args: str) -> str:
        """Return a random integer in [lo, hi]; "0" on malformed args."""
        pieces = [piece.strip() for piece in args.split(",")]
        if len(pieces) != 2:
            log.warning("rand_int expects 2 args, got: %s", args)
            return "0"
        try:
            low, high = int(pieces[0]), int(pieces[1])
        except ValueError:
            log.warning("rand_int args must be integers: %s", args)
            return "0"
        return str(self._rng.randint(low, high))

    def _rand_str(self, args: str) -> str:
        """Return a random alphanumeric string; length defaults to 8 on bad input."""
        text = args.strip()
        try:
            length = int(text)
        except ValueError:
            log.warning("rand_str expects an integer length: %s", text)
            length = 8
        alphabet = string.ascii_letters + string.digits
        return "".join(self._rng.choice(alphabet) for _ in range(length))

    def _rand_choice(self, args: str) -> str:
        """Return one of the comma-separated options; "" when none given."""
        options = [item.strip() for item in args.split(",") if item.strip()]
        if not options:
            log.warning("rand_choice received empty choices")
            return ""
        return self._rng.choice(options)

    def _seq_int(self, _args: str) -> str:
        """Return the next value of the per-engine monotonic counter."""
        with self._seq_lock:
            self._seq_counter += 1
            return str(self._seq_counter)

    def reset_seq(self):
        """Reset the sequential counter (useful between runs)."""
        with self._seq_lock:
            self._seq_counter = 0
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Test case generator
111
+ # ---------------------------------------------------------------------------
112
+
113
class TestCaseGenerator:
    """Pre-generate a fixed set of test cases for fair comparison.

    All DBMS instances will execute the exact same sequence of SQL statements,
    ensuring fair and reproducible performance comparison.
    """

    def __init__(self, seed: int = 42):
        self.seed = seed
        self.engine = TemplateEngine(seed=seed)

    def _render_case(self, query: BenchQuery) -> TestCase:
        """Render one query template (and its cleanup) into a concrete TestCase."""
        rendered = self.engine.render(query.sql)
        cleanup = self.engine.render(query.cleanup_sql) if query.cleanup_sql else ""
        return TestCase(
            query_name=query.name,
            sql=rendered,
            original_sql=query.sql,
            cleanup_sql=cleanup,
        )

    def generate_serial_cases(
        self,
        workload: BenchWorkload,
        iterations: int,
        warmup: int = 0,
    ) -> List[TestCase]:
        """Generate test cases for serial benchmark mode.

        Args:
            workload: The workload definition
            iterations: Number of iterations per query
            warmup: Number of warmup iterations (not included in test cases)

        Returns:
            List of TestCase objects, one for each query execution
        """
        # Queries in workload order, each rendered *iterations* times.
        return [
            self._render_case(query)
            for query in workload.queries
            for _ in range(iterations)
        ]

    def generate_concurrent_cases(
        self,
        workload: BenchWorkload,
        total_queries: int,
    ) -> List[TestCase]:
        """Generate test cases for concurrent benchmark mode.

        Uses weighted random selection to choose queries, ensuring all DBMS
        execute the same sequence.

        Args:
            workload: The workload definition
            total_queries: Total number of queries to generate

        Returns:
            List of TestCase objects
        """
        # Expand each query *weight* times so a uniform pick is weighted.
        weighted_pool: List[BenchQuery] = []
        for query in workload.queries:
            weighted_pool.extend([query] * query.weight)

        if not weighted_pool:
            return []

        # Use a separate RNG for query selection (deterministic)
        selector = random.Random(self.seed)

        return [
            self._render_case(selector.choice(weighted_pool))
            for _ in range(total_queries)
        ]
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Benchmark loader
201
+ # ---------------------------------------------------------------------------
202
+
203
class BenchmarkLoader:
    """Load benchmark workload definitions from various sources.

    Supported sources:
      1. Built-in template name — e.g. "oltp_read_write"
      2. Plain .sql file — each non-empty/non-comment line is a query
      3. JSON definition file — full control (setup, queries, teardown)
    """

    @staticmethod
    def _strip_comment_lines(statements: List[str]) -> List[str]:
        """Drop empty entries and comment-only lines ('--', '===', '#').

        Shared by the setup/teardown handling in from_json_file; previously
        this filter was duplicated inline for both lists.
        """
        return [
            s for s in statements
            if s and not s.strip().startswith("--")
            and not s.strip().startswith("===")
            and not s.strip().startswith("#")
        ]

    @staticmethod
    def from_builtin(template_name: str) -> BenchWorkload:
        """Load a built-in benchmark template by name.

        Raises:
            ValueError: if *template_name* is not a registered template.
        """
        template_name = template_name.lower()
        if template_name not in BUILTIN_TEMPLATES:
            available = ", ".join(sorted(BUILTIN_TEMPLATES.keys()))
            raise ValueError(
                f"Unknown built-in template: '{template_name}'. "
                f"Available: {available}"
            )
        return BUILTIN_TEMPLATES[template_name]()

    @staticmethod
    def from_sql_file(path: str) -> BenchWorkload:
        """Load a benchmark workload from a plain .sql file.

        Each non-empty, non-comment line becomes a query.
        Multi-line statements are NOT supported in plain mode;
        use JSON definition for complex workloads.

        Raises:
            FileNotFoundError: if *path* does not exist.
            ValueError: if no usable queries are found.
        """
        filepath = Path(path)
        if not filepath.exists():
            raise FileNotFoundError(f"SQL file not found: {path}")

        text = filepath.read_text(encoding="utf-8")
        queries: List[BenchQuery] = []
        idx = 0

        for line in text.splitlines():
            stripped = line.strip()
            # skip empty lines and comments
            if not stripped or stripped.startswith("--") or stripped.startswith("#"):
                continue
            # remove trailing semicolon for consistency
            if stripped.endswith(";"):
                stripped = stripped[:-1].rstrip()
                if not stripped:
                    continue
            idx += 1
            # use a readable default name
            name = f"query_{idx}"
            queries.append(BenchQuery(name=name, sql=stripped, weight=1))

        if not queries:
            raise ValueError(f"No valid queries found in: {path}")

        return BenchWorkload(
            name=filepath.stem,
            queries=queries,
        )

    @staticmethod
    def from_json_file(path: str) -> BenchWorkload:
        """Load a benchmark workload from a JSON definition file.

        Expected JSON schema:
            {
              "name": "my_workload",            // optional
              "setup": ["CREATE TABLE ..."],    // optional
              "queries": [
                {
                  "name": "point_select",       // optional, auto-generated
                  "sql": "SELECT ...",          // required
                  "weight": 5                   // optional, default 1
                }
              ],
              "teardown": ["DROP TABLE ..."]    // optional
            }

        Raises:
            FileNotFoundError: if *path* does not exist.
            ValueError: on schema violations (non-object root, missing
                'queries', entries of the wrong type, missing 'sql').
        """
        filepath = Path(path)
        if not filepath.exists():
            raise FileNotFoundError(f"JSON file not found: {path}")

        data = json.loads(filepath.read_text(encoding="utf-8"))

        if not isinstance(data, dict):
            raise ValueError("JSON benchmark file must be a JSON object")

        raw_queries = data.get("queries", [])
        if not raw_queries:
            raise ValueError("JSON benchmark file must contain 'queries'")

        queries: List[BenchQuery] = []
        for i, q in enumerate(raw_queries):
            if isinstance(q, str):
                # shorthand: just a SQL string
                queries.append(BenchQuery(
                    name=f"query_{i + 1}",
                    sql=q.rstrip(";").strip(),
                    weight=1,
                ))
            elif isinstance(q, dict):
                sql = q.get("sql", "").rstrip(";").strip()
                if not sql:
                    raise ValueError(
                        f"Query at index {i} is missing 'sql' field"
                    )
                cleanup = q.get("cleanup_sql", "").rstrip(";").strip()
                queries.append(BenchQuery(
                    name=q.get("name", f"query_{i + 1}"),
                    sql=sql,
                    # clamp to >= 1 so a zero/negative weight cannot
                    # silently drop a query from the weighted pool
                    weight=max(1, int(q.get("weight", 1))),
                    description=q.get("description", ""),
                    cleanup_sql=cleanup,
                ))
            else:
                raise ValueError(
                    f"Query at index {i}: expected string or object, "
                    f"got {type(q).__name__}"
                )

        setup = data.get("setup", [])
        if isinstance(setup, str):
            setup = [setup]
        setup = BenchmarkLoader._strip_comment_lines(setup)

        teardown = data.get("teardown", [])
        if isinstance(teardown, str):
            teardown = [teardown]
        teardown = BenchmarkLoader._strip_comment_lines(teardown)

        return BenchWorkload(
            name=data.get("name", filepath.stem),
            setup=setup,
            queries=queries,
            teardown=teardown,
        )

    @staticmethod
    def from_file(path: str) -> BenchWorkload:
        """Auto-detect file type (by extension) and load accordingly."""
        lower = path.lower()
        if lower.endswith(".json"):
            return BenchmarkLoader.from_json_file(path)
        elif lower.endswith(".sql"):
            return BenchmarkLoader.from_sql_file(path)
        else:
            raise ValueError(
                f"Unsupported benchmark file extension: {path}. "
                "Use .json or .sql"
            )

    @staticmethod
    def list_builtin_templates() -> List[str]:
        """Return names of all available built-in templates."""
        return sorted(BUILTIN_TEMPLATES.keys())

    @staticmethod
    def filter_queries(
        workload: BenchWorkload, names: List[str]
    ) -> BenchWorkload:
        """Return a new workload containing only queries whose names match.

        Matching is case-insensitive and supports substring matching.

        Raises:
            ValueError: if no query matches any of *names*.
        """
        if not names:
            return workload

        lower_names = [n.lower() for n in names]
        filtered = [
            q for q in workload.queries
            if any(n in q.name.lower() for n in lower_names)
        ]

        if not filtered:
            available = ", ".join(q.name for q in workload.queries)
            raise ValueError(
                f"No queries match filter {names}. "
                f"Available queries: {available}"
            )

        return BenchWorkload(
            name=workload.name,
            setup=workload.setup,
            queries=filtered,
            teardown=workload.teardown,
        )
401
+
402
+
403
+ # ---------------------------------------------------------------------------
404
+ # Built-in benchmark templates
405
+ # ---------------------------------------------------------------------------
406
+
407
def _template_oltp_read_write() -> BenchWorkload:
    """OLTP Read-Write mixed workload.

    Requires a pre-created table. Setup creates 'bench_accounts' with
    10 000 rows; teardown drops it.

    Query mix by weight: 5× point select, 3× range select, 3× balance
    update, 2× insert, 1× delete, 1× full-table aggregate.
    """
    # NOTE(review): setup statements use MySQL-flavoured syntax
    # (AUTO_INCREMENT, RAND(), @user variables) — confirm how non-MySQL
    # targets are expected to handle these.
    return BenchWorkload(
        name="oltp_read_write",
        setup=[
            "CREATE TABLE IF NOT EXISTS bench_accounts ("
            " id INT PRIMARY KEY AUTO_INCREMENT,"
            " name VARCHAR(100) NOT NULL,"
            " balance DECIMAL(15,2) NOT NULL DEFAULT 0.00,"
            " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
            ")",
            # Seed 10 000 rows via a quick INSERT-SELECT trick.
            # We insert in small batches to avoid overlong SQL.
            "INSERT INTO bench_accounts (name, balance) "
            "SELECT CONCAT('user_', seq), ROUND(RAND() * 10000, 2) "
            "FROM (SELECT @rownum := @rownum + 1 AS seq "
            "      FROM information_schema.columns a, "
            "      information_schema.columns b, "
            "      (SELECT @rownum := 0) r "
            "      LIMIT 10000) t",
        ],
        queries=[
            BenchQuery(
                name="point_select",
                sql="SELECT * FROM bench_accounts WHERE id = {{rand_int(1,10000)}}",
                weight=5,
            ),
            BenchQuery(
                name="range_select",
                sql="SELECT * FROM bench_accounts WHERE id BETWEEN "
                    "{{rand_int(1,5000)}} AND {{rand_int(5001,10000)}} "
                    "ORDER BY id LIMIT 100",
                weight=3,
            ),
            BenchQuery(
                name="update_balance",
                sql="UPDATE bench_accounts SET balance = balance + "
                    "{{rand_int(1,100)}} WHERE id = {{rand_int(1,10000)}}",
                weight=3,
            ),
            BenchQuery(
                name="insert_row",
                sql="INSERT INTO bench_accounts (name, balance) VALUES "
                    "('{{rand_str(10)}}', {{rand_int(100,9999)}})",
                weight=2,
            ),
            # NOTE(review): ORDER BY RAND() forces a full-table sort per
            # delete and permanently shrinks the table over the run.
            BenchQuery(
                name="delete_row",
                sql="DELETE FROM bench_accounts ORDER BY RAND() LIMIT 1",
                weight=1,
            ),
            BenchQuery(
                name="aggregate_sum",
                sql="SELECT COUNT(*), SUM(balance), AVG(balance) "
                    "FROM bench_accounts",
                weight=1,
            ),
        ],
        teardown=[
            "DROP TABLE IF EXISTS bench_accounts",
        ],
    )
473
+
474
+
475
def _template_oltp_read_only() -> BenchWorkload:
    """OLTP Read-Only workload — no writes after setup.

    Same 'bench_accounts' table and 10 000-row seed as the read-write
    template; queries are selects/aggregates only.
    """
    return BenchWorkload(
        name="oltp_read_only",
        setup=[
            "CREATE TABLE IF NOT EXISTS bench_accounts ("
            " id INT PRIMARY KEY AUTO_INCREMENT,"
            " name VARCHAR(100) NOT NULL,"
            " balance DECIMAL(15,2) NOT NULL DEFAULT 0.00,"
            " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
            ")",
            # Seed 10 000 rows (MySQL-style @rownum counter trick).
            "INSERT INTO bench_accounts (name, balance) "
            "SELECT CONCAT('user_', seq), ROUND(RAND() * 10000, 2) "
            "FROM (SELECT @rownum := @rownum + 1 AS seq "
            "      FROM information_schema.columns a, "
            "      information_schema.columns b, "
            "      (SELECT @rownum := 0) r "
            "      LIMIT 10000) t",
        ],
        queries=[
            BenchQuery(
                name="point_select",
                sql="SELECT * FROM bench_accounts WHERE id = {{rand_int(1,10000)}}",
                weight=5,
            ),
            BenchQuery(
                name="range_select",
                sql="SELECT * FROM bench_accounts WHERE id BETWEEN "
                    "{{rand_int(1,5000)}} AND {{rand_int(5001,10000)}} "
                    "ORDER BY id LIMIT 100",
                weight=3,
            ),
            BenchQuery(
                name="aggregate_sum",
                sql="SELECT COUNT(*), SUM(balance), AVG(balance) "
                    "FROM bench_accounts",
                weight=2,
            ),
            BenchQuery(
                name="order_by_limit",
                sql="SELECT * FROM bench_accounts ORDER BY balance DESC "
                    "LIMIT {{rand_int(10,50)}}",
                weight=2,
            ),
            BenchQuery(
                name="like_search",
                sql="SELECT * FROM bench_accounts WHERE name LIKE "
                    "'user_{{rand_int(1,999)}}%' LIMIT 20",
                weight=1,
            ),
        ],
        teardown=[
            "DROP TABLE IF EXISTS bench_accounts",
        ],
    )
530
+
531
+
532
def _template_oltp_write_only() -> BenchWorkload:
    """OLTP Write-Only workload — inserts, updates, deletes.

    Same 'bench_accounts' table and 10 000-row seed as the other OLTP
    templates; queries mutate the table only (no reads).
    """
    return BenchWorkload(
        name="oltp_write_only",
        setup=[
            "CREATE TABLE IF NOT EXISTS bench_accounts ("
            " id INT PRIMARY KEY AUTO_INCREMENT,"
            " name VARCHAR(100) NOT NULL,"
            " balance DECIMAL(15,2) NOT NULL DEFAULT 0.00,"
            " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
            ")",
            # Seed 10 000 rows (MySQL-style @rownum counter trick).
            "INSERT INTO bench_accounts (name, balance) "
            "SELECT CONCAT('user_', seq), ROUND(RAND() * 10000, 2) "
            "FROM (SELECT @rownum := @rownum + 1 AS seq "
            "      FROM information_schema.columns a, "
            "      information_schema.columns b, "
            "      (SELECT @rownum := 0) r "
            "      LIMIT 10000) t",
        ],
        queries=[
            BenchQuery(
                name="insert_row",
                sql="INSERT INTO bench_accounts (name, balance) VALUES "
                    "('{{rand_str(10)}}', {{rand_int(100,9999)}})",
                weight=4,
            ),
            BenchQuery(
                name="update_balance",
                sql="UPDATE bench_accounts SET balance = balance + "
                    "{{rand_int(1,100)}} WHERE id = {{rand_int(1,10000)}}",
                weight=4,
            ),
            BenchQuery(
                name="update_name",
                sql="UPDATE bench_accounts SET name = '{{rand_str(10)}}' "
                    "WHERE id = {{rand_int(1,10000)}}",
                weight=2,
            ),
            BenchQuery(
                name="delete_row",
                sql="DELETE FROM bench_accounts WHERE id = {{rand_int(1,10000)}}",
                weight=2,
            ),
            # NOTE(review): REPLACE INTO is MySQL-specific (delete+insert
            # semantics) — confirm support on non-MySQL targets.
            BenchQuery(
                name="replace_row",
                sql="REPLACE INTO bench_accounts (id, name, balance) VALUES "
                    "({{rand_int(1,10000)}}, '{{rand_str(8)}}', {{rand_int(100,9999)}})",
                weight=1,
            ),
        ],
        teardown=[
            "DROP TABLE IF EXISTS bench_accounts",
        ],
    )
586
+
587
+
588
# Registry of built-in templates (name → factory function).
# Fix: the annotation previously used the builtin `callable` as a type
# argument (`Dict[str, callable]`), which is not a valid typing construct;
# `Callable` is already imported from typing.
BUILTIN_TEMPLATES: Dict[str, Callable[[], BenchWorkload]] = {
    "oltp_read_write": _template_oltp_read_write,
    "oltp_read_only": _template_oltp_read_only,
    "oltp_write_only": _template_oltp_write_only,
}
594
+
595
+
596
+ # ---------------------------------------------------------------------------
597
+ # Statistics computation (pure Python, no numpy)
598
+ # ---------------------------------------------------------------------------
599
+
600
+ def _percentile(sorted_data: List[float], pct: float) -> float:
601
+ """Compute the *pct*-th percentile from pre-sorted data."""
602
+ if not sorted_data:
603
+ return 0.0
604
+ k = (len(sorted_data) - 1) * (pct / 100.0)
605
+ f = math.floor(k)
606
+ c = math.ceil(k)
607
+ if f == c:
608
+ return sorted_data[int(k)]
609
+ d0 = sorted_data[int(f)] * (c - k)
610
+ d1 = sorted_data[int(c)] * (k - f)
611
+ return d0 + d1
612
+
613
+
614
def compute_stats(
    latencies: List[float], total_errors: int, elapsed_s: float,
    query_name: str, sql_template: str = "",
) -> QueryLatencyStats:
    """Compute latency statistics from a list of execution times (ms).

    Returns a QueryLatencyStats with min/max/avg, p50/p95/p99 and QPS
    filled in; with no successful executions only the counters are set.
    """
    result = QueryLatencyStats(query_name=query_name, sql_template=sql_template)
    result.total_executions = len(latencies) + total_errors
    result.total_errors = total_errors

    if not latencies:
        return result

    ordered = sorted(latencies)
    n = len(ordered)
    result.latencies_ms = latencies  # keep raw data for reports
    result.min_ms, result.max_ms = ordered[0], ordered[-1]
    result.avg_ms = sum(ordered) / n
    result.p50_ms = _percentile(ordered, 50)
    result.p95_ms = _percentile(ordered, 95)
    result.p99_ms = _percentile(ordered, 99)

    # QPS counts successful executions only, over the measured window.
    if elapsed_s > 0:
        result.qps = n / elapsed_s

    return result
639
+
640
+
641
+ # ---------------------------------------------------------------------------
642
+ # Base benchmark runner
643
+ # ---------------------------------------------------------------------------
644
+
645
class BaseBenchmarkRunner:
    """Base class for benchmark runners with common setup/teardown logic.

    Holds the target DBMS config, the workload, the pre-generated test
    cases, and optional UI callbacks; subclasses implement the actual
    execution loop.
    """

    def __init__(
        self, config: DBMSConfig, workload: BenchWorkload,
        bench_cfg: BenchmarkConfig, test_cases: List[TestCase],
        database: str = "rosetta_bench",
        on_progress: Optional[Callable] = None,
        on_profile_start: Optional[Callable] = None,
        on_profile_done: Optional[Callable] = None,
        on_run_start: Optional[Callable] = None,
    ):
        self.config = config
        self.workload = workload
        self.bench_cfg = bench_cfg
        self.test_cases = test_cases
        self.database = database
        # Optional observer callbacks (progress bar / status updates).
        self.on_progress = on_progress
        self.on_profile_start = on_profile_start
        self.on_profile_done = on_profile_done
        self.on_run_start = on_run_start
        # Lazily resolved and cached by _resolve_mysqld_pid().
        self._mysqld_pid: Optional[int] = None

    def _resolve_mysqld_pid(self) -> Optional[int]:
        """Resolve the mysqld PID for perf profiling (cached)."""
        if self._mysqld_pid is not None:
            return self._mysqld_pid
        # Local import keeps the profiling dependency optional.
        from .flamegraph import find_mysqld_pid
        pid = find_mysqld_pid(port=self.config.port)
        if pid:
            self._mysqld_pid = pid
            log.info("[%s] Resolved mysqld PID: %d (port %d)",
                     self.config.name, pid, self.config.port)
        else:
            log.warning("[%s] Could not find mysqld PID for port %d",
                        self.config.name, self.config.port)
        return self._mysqld_pid

    def _check_profiling_support(self) -> bool:
        """Check if profiling is supported for this DBMS.

        Returns False (with a log message) when profiling is requested but
        unavailable: non-tdsql target, perf missing, or PID not found.
        """
        profiling = self.bench_cfg.profile
        # Only the 'tdsql' target is ever profiled.
        if profiling and self.config.name.lower() != "tdsql":
            log.info("[%s] Profiling skipped (only tdsql is profiled)",
                     self.config.name)
            return False
        if profiling:
            from .flamegraph import check_perf_available
            ok, msg = check_perf_available()
            if not ok:
                log.warning("[%s] Profiling disabled: %s",
                            self.config.name, msg)
                return False
            mysqld_pid = self._resolve_mysqld_pid()
            if not mysqld_pid:
                log.warning("[%s] Profiling disabled: mysqld PID not found",
                            self.config.name)
                return False
        return profiling

    def _run_setup(self, db: DBConnection, result: DBMSBenchResult):
        """Run setup phase: execute setup SQL and count table rows.

        Setup failures are logged but never abort the benchmark
        (best-effort: a later statement may still succeed).
        """
        for sql in self.workload.setup:
            try:
                db.cursor.execute(sql)
            except Exception as e:
                log.warning("[%s] Setup failed: %s — %s",
                            self.config.name, sql[:80], e)

        # Query total rows and schema after setup
        # NOTE(review): self.database is interpolated into the SQL via
        # f-string — safe only while the database name comes from trusted
        # config; confirm it is never user-supplied.
        try:
            db.cursor.execute(
                "SELECT TABLE_NAME FROM information_schema.TABLES "
                f"WHERE TABLE_SCHEMA = '{self.database}' "
                "AND TABLE_TYPE = 'BASE TABLE'")
            tables = [r[0] for r in db.cursor.fetchall()]
            total = 0
            detail = {}
            schema = {}
            for tbl in tables:
                # Get row count
                try:
                    db.cursor.execute(f"SELECT COUNT(*) FROM `{tbl}`")
                    cnt = db.cursor.fetchone()
                    if cnt and cnt[0]:
                        row_count = int(cnt[0])
                        total += row_count
                        detail[tbl] = row_count
                except Exception:
                    pass
                # Get CREATE TABLE statement
                try:
                    db.cursor.execute(f"SHOW CREATE TABLE `{tbl}`")
                    row = db.cursor.fetchone()
                    if row and len(row) >= 2:
                        schema[tbl] = row[1]  # Second column is the CREATE TABLE stmt
                except Exception as e:
                    log.debug("[%s] Could not get schema for %s: %s",
                              self.config.name, tbl, e)
            result.table_rows = total
            result.table_rows_detail = detail
            result.table_schema = schema
        except Exception as e:
            # Row/schema inventory is informational only — never fatal.
            log.debug("[%s] Could not query table rows: %s",
                      self.config.name, e)

    def _run_teardown(self, db: DBConnection):
        """Run teardown phase: execute user-defined teardown SQL only."""
        if self.bench_cfg.skip_teardown:
            log.info("[%s] Skipping teardown (--skip-teardown)", self.config.name)
            return
        for sql in self.workload.teardown:
            try:
                db.cursor.execute(sql)
            except Exception as e:
                # Best-effort: keep going so later teardown statements run.
                log.warning("[%s] Teardown failed: %s — %s",
                            self.config.name, sql[:80], e)
761
+
762
+
763
+ # ---------------------------------------------------------------------------
764
+ # Serial benchmark runner
765
+ # ---------------------------------------------------------------------------
766
+
767
class SerialBenchmarkRunner(BaseBenchmarkRunner):
    """Execute each query N times sequentially, one after another.

    Per query the phases are: warmup (freshly rendered SQL, untimed) →
    one-shot EXPLAIN capture → timed execution of the pre-generated test
    cases (optionally under perf profiling) → stats aggregation.
    """

    def run(self, skip_setup: bool = False) -> DBMSBenchResult:
        """Run the serial benchmark and return results.

        Args:
            skip_setup: If True, skip setup phase (already done in separate pass).

        Returns:
            A DBMSBenchResult; returned early (mostly empty) when the
            service is unavailable or the connection fails.
        """
        result = DBMSBenchResult(dbms_name=self.config.name)

        if not ensure_service(self.config):
            log.error("[%s] Service unavailable, skipping benchmark",
                      self.config.name)
            return result

        db = DBConnection(self.config, self.database)
        try:
            db.connect()
        except Exception as e:
            log.error("[%s] Connection failed: %s", self.config.name, e)
            return result

        profiling = self._check_profiling_support()
        if profiling:
            # Import only when needed; PerfProfiler is used in Phase 2 below.
            from .flamegraph import PerfProfiler

        try:
            # Setup phase (skip if already done)
            if not skip_setup:
                self._run_setup(db, result)

            overall_start = None  # set at first test case start

            # Group test cases by query name for statistics
            from collections import defaultdict
            query_cases = defaultdict(list)
            for tc in self.test_cases:
                query_cases[tc.query_name].append(tc)

            for query in self.workload.queries:
                latencies: List[float] = []
                errors = 0
                error_logs: List[Dict] = []  # [{sql, error}]

                # --- Phase 1: Warmup (use original template for warmup, not test cases) ---
                # Warmup doesn't affect fairness, so we can render fresh SQL
                warmup_engine = TemplateEngine(seed=self.bench_cfg.seed)
                for i in range(self.bench_cfg.warmup):
                    rendered_sql = warmup_engine.render(query.sql)
                    try:
                        db.cursor.execute(rendered_sql)
                        if db.cursor.description:
                            db.cursor.fetchall()
                    except Exception as e:
                        log.debug("[%s] Warmup error: %s — %s",
                                  self.config.name, query.name, e)
                        # NOTE(review): relies on DBConnection's private
                        # _is_connection_lost helper — confirm it stays stable.
                        if db._is_connection_lost(e):
                            if db.reconnect():
                                try:
                                    db.cursor.execute(
                                        f"USE `{self.database}`")
                                except Exception:
                                    pass
                        # Still run cleanup to restore state
                        if query.cleanup_sql:
                            try:
                                cleanup = warmup_engine.render(query.cleanup_sql)
                                db.cursor.execute(cleanup)
                                if db.cursor.description:
                                    db.cursor.fetchall()
                            except Exception:
                                pass
                        continue

                    # Run cleanup SQL after warmup iteration
                    if query.cleanup_sql:
                        try:
                            cleanup = warmup_engine.render(query.cleanup_sql)
                            db.cursor.execute(cleanup)
                            if db.cursor.description:
                                db.cursor.fetchall()
                        except Exception:
                            pass

                    if self.on_progress:
                        self.on_progress(
                            query.name, i + 1,
                            self.bench_cfg.warmup,
                            is_warmup=True,
                        )

                # --- Phase 1.5: Capture EXPLAIN plan (once, after warmup) ---
                explain_text = ""
                explain_tree_text = ""
                try:
                    rendered_sql = warmup_engine.render(query.sql)
                    db.cursor.execute("EXPLAIN " + rendered_sql)
                    if db.cursor.description:
                        cols = [desc[0] for desc in db.cursor.description]
                        rows = db.cursor.fetchall()
                        # Format as aligned text table
                        col_widths = [len(c) for c in cols]
                        for row in rows:
                            for j, cell in enumerate(row):
                                col_widths[j] = max(
                                    col_widths[j], len(str(cell)))
                        header = " | ".join(
                            c.ljust(col_widths[j])
                            for j, c in enumerate(cols))
                        sep = "-+-".join(
                            "-" * col_widths[j]
                            for j in range(len(cols)))
                        lines = [header, sep]
                        for row in rows:
                            lines.append(" | ".join(
                                str(cell).ljust(col_widths[j])
                                for j, cell in enumerate(row)))
                        explain_text = "\n".join(lines)
                except Exception as e:
                    log.debug("[%s] EXPLAIN failed for %s: %s",
                              self.config.name, query.name, e)

                # --- Phase 1.6: Capture EXPLAIN FORMAT=TREE (tdsql only) ---
                if self.config.name.lower() == "tdsql":
                    try:
                        rendered_sql = warmup_engine.render(query.sql)
                        db.cursor.execute(
                            "EXPLAIN FORMAT=TREE " + rendered_sql)
                        tree_rows = db.cursor.fetchall()
                        if tree_rows:
                            explain_tree_text = "\n".join(
                                str(row[0]) if row else ""
                                for row in tree_rows)
                    except Exception as e:
                        log.debug("[%s] EXPLAIN FORMAT=TREE failed for %s: %s",
                                  self.config.name, query.name, e)

                # --- Phase 2: Execute pre-generated test cases ---
                # Begin timing: only actual SQL execution counts
                q_start = _time.monotonic()
                if overall_start is None:
                    overall_start = q_start

                profiler = None
                if profiling:
                    profiler = PerfProfiler(
                        mysqld_pid=self._mysqld_pid,
                        perf_freq=self.bench_cfg.perf_freq,
                    )
                    profiler.start()

                # Execute all test cases for this query
                test_cases_for_query = query_cases.get(query.name, [])
                for i, tc in enumerate(test_cases_for_query):
                    # Use pre-rendered SQL from test case
                    sql = tc.sql

                    t0 = _time.monotonic()
                    try:
                        db.cursor.execute(sql)
                        # Consume result set to measure full round-trip
                        if db.cursor.description:
                            db.cursor.fetchall()
                    except Exception as e:
                        # Failed iterations count as errors and record no latency.
                        errors += 1
                        err_msg = str(e)
                        log.debug("[%s] Query error: %s — %s",
                                  self.config.name, query.name, e)
                        # Collect error details (limit to avoid memory bloat)
                        if len(error_logs) < 50:
                            error_logs.append({
                                "sql": sql[:500],
                                "error": err_msg[:500],
                            })
                        # Try reconnect on connection loss
                        if db._is_connection_lost(e):
                            if db.reconnect():
                                try:
                                    db.cursor.execute(
                                        f"USE `{self.database}`")
                                except Exception:
                                    pass
                        # Still run cleanup to restore state for next iteration
                        if tc.cleanup_sql:
                            try:
                                db.cursor.execute(tc.cleanup_sql)
                                if db.cursor.description:
                                    db.cursor.fetchall()
                            except Exception:
                                pass
                        continue
                    t1 = _time.monotonic()

                    latencies.append((t1 - t0) * 1000.0)  # ms

                    # Run cleanup SQL to restore state (not timed)
                    if tc.cleanup_sql:
                        try:
                            db.cursor.execute(tc.cleanup_sql)
                            if db.cursor.description:
                                db.cursor.fetchall()
                        except Exception as ce:
                            log.debug("[%s] Cleanup error: %s — %s",
                                      self.config.name, query.name, ce)

                    if self.on_progress:
                        self.on_progress(
                            query.name, i + 1,
                            len(test_cases_for_query),
                            is_warmup=False,
                        )

                # --- Phase 3: Stop perf immediately after iterations ---
                # q_elapsed includes per-iteration cleanup time; individual
                # latencies do not.
                q_elapsed = _time.monotonic() - q_start

                fg_svg = ""
                if profiler is not None:
                    # Update status to indicate perf processing (can be slow)
                    if self.on_profile_start:
                        self.on_profile_start(query.name)
                    fg_data = profiler.stop(query_name=query.name)
                    # Only show flamegraph if total duration exceeds threshold
                    if fg_data.svg_content:
                        min_ms = self.bench_cfg.flamegraph_min_ms
                        if min_ms > 0 and q_elapsed * 1000 < min_ms:
                            log.debug("[%s] Skipping flamegraph for %s: total duration %.0fms < %dms threshold",
                                      self.config.name, query.name, q_elapsed * 1000, min_ms)
                        else:
                            fg_svg = fg_data.svg_content
                    elif fg_data.error:
                        log.warning("[%s] Flame graph for %s: %s",
                                    self.config.name, query.name,
                                    fg_data.error)
                    profiler.cleanup()
                    if self.on_profile_done:
                        self.on_profile_done(
                            query.name, fg_data.sample_count)
                # Aggregate this query's latencies into the result.
                stats = compute_stats(
                    latencies, errors, q_elapsed, query.name,
                    sql_template=query.sql)
                stats.flamegraph_svg = fg_svg
                stats.explain_plan = explain_text
                stats.explain_tree = explain_tree_text
                stats.error_logs = error_logs
                result.query_stats.append(stats)
                result.total_queries += len(latencies) + errors
                result.total_errors += errors

            # Overall wall-clock spans from the first timed test case.
            result.total_duration_s = (
                (_time.monotonic() - overall_start)
                if overall_start is not None else 0.0
            )
            if result.total_duration_s > 0:
                result.overall_qps = (
                    result.total_queries / result.total_duration_s
                )

        finally:
            # Teardown phase
            self._run_teardown(db)
            db.close()

        return result
1031
+
1032
+
1033
+ # ---------------------------------------------------------------------------
1034
+ # Concurrent benchmark runner
1035
+ # ---------------------------------------------------------------------------
1036
+
1037
class ConcurrentBenchmarkRunner(BaseBenchmarkRunner):
    """Multi-threaded stress test with duration-based execution.

    Each worker thread cycles through the shared, pre-generated test-case
    pool until a stop event fires; latencies and errors are merged into
    per-query accumulators under a single lock.

    Callbacks:
        on_progress: Called after each query execution.
        on_run_start: Called when steady-state execution begins (after setup/ramp-up).

    Outlier Detection:
        Queries exceeding outlier_threshold_ms are logged but still counted.
        This helps identify long-running queries without distorting statistics.
    """

    def run(self, skip_setup: bool = False) -> DBMSBenchResult:
        """Run the concurrent benchmark and return results.

        In concurrent mode, workers execute queries continuously for the
        specified duration. Each worker loops through
        the pre-generated test cases repeatedly until time expires.

        Args:
            skip_setup: If True, skip setup phase (already done in separate pass).

        Returns:
            DBMSBenchResult with one QueryLatencyStats per workload query.
            Returned with empty query_stats if the service is unavailable,
            setup fails, or no test cases were generated.
        """
        result = DBMSBenchResult(dbms_name=self.config.name)

        if not ensure_service(self.config):
            log.error("[%s] Service unavailable, skipping benchmark",
                      self.config.name)
            return result

        # Setup phase (single connection) - skip if already done
        if not skip_setup:
            setup_db = DBConnection(self.config, self.database)
            try:
                setup_db.connect()
                self._run_setup(setup_db, result)
            except Exception as e:
                log.error("[%s] Connection failed for setup: %s",
                          self.config.name, e)
                return result
            finally:
                setup_db.close()

        if not self.test_cases:
            log.error("[%s] No test cases generated", self.config.name)
            return result

        # Capture EXPLAIN plans before the concurrent run (single connection)
        # Use pre-generated test cases for EXPLAIN
        explain_plans: Dict[str, str] = {}
        explain_tree_plans: Dict[str, str] = {}
        try:
            explain_db = DBConnection(self.config, self.database)
            explain_db.connect()

            # Group test cases by query name
            from collections import defaultdict
            query_cases_map = defaultdict(list)
            for tc in self.test_cases:
                query_cases_map[tc.query_name].append(tc)

            for query in self.workload.queries:
                # Get first test case for this query
                cases_for_query = query_cases_map.get(query.name, [])
                if not cases_for_query:
                    continue

                # Use the first pre-rendered SQL for EXPLAIN
                sample_sql = cases_for_query[0].sql

                try:
                    explain_db.cursor.execute("EXPLAIN " + sample_sql)
                    if explain_db.cursor.description:
                        # Render the EXPLAIN result as an aligned ASCII table.
                        cols = [desc[0]
                                for desc in explain_db.cursor.description]
                        rows = explain_db.cursor.fetchall()
                        col_widths = [len(c) for c in cols]
                        for row in rows:
                            for j, cell in enumerate(row):
                                col_widths[j] = max(
                                    col_widths[j], len(str(cell)))
                        header = " | ".join(
                            c.ljust(col_widths[j])
                            for j, c in enumerate(cols))
                        sep = "-+-".join(
                            "-" * col_widths[j]
                            for j in range(len(cols)))
                        lines = [header, sep]
                        for row in rows:
                            lines.append(" | ".join(
                                str(cell).ljust(col_widths[j])
                                for j, cell in enumerate(row)))
                        explain_plans[query.name] = "\n".join(lines)
                except Exception as e:
                    log.debug("[%s] EXPLAIN failed for %s: %s",
                              self.config.name, query.name, e)

                # EXPLAIN FORMAT=TREE (tdsql only)
                if self.config.name.lower() == "tdsql":
                    try:
                        explain_db.cursor.execute(
                            "EXPLAIN FORMAT=TREE " + sample_sql)
                        tree_rows = explain_db.cursor.fetchall()
                        if tree_rows:
                            explain_tree_plans[query.name] = "\n".join(
                                str(row[0]) if row else ""
                                for row in tree_rows)
                    except Exception as e:
                        log.debug(
                            "[%s] EXPLAIN FORMAT=TREE failed for %s: %s",
                            self.config.name, query.name, e)
            # NOTE(review): close() is skipped if an unexpected error
            # escapes the loop above (it is not in a finally block);
            # the connection would then leak until GC.
            explain_db.close()
        except Exception as e:
            log.debug("[%s] EXPLAIN connection failed: %s",
                      self.config.name, e)

        # Determine run duration
        duration = self.bench_cfg.duration
        if duration <= 0:
            duration = 30.0  # default 30s

        concurrency = max(1, self.bench_cfg.concurrency)
        ramp_up = self.bench_cfg.ramp_up

        # Per-query latencies collected across all threads
        latency_lock = threading.Lock()
        per_query_latencies: Dict[str, List[float]] = {
            q.name: [] for q in self.workload.queries
        }
        per_query_errors: Dict[str, int] = {
            q.name: 0 for q in self.workload.queries
        }
        per_query_error_logs: Dict[str, List[Dict]] = {
            q.name: [] for q in self.workload.queries
        }
        stop_event = threading.Event()
        total_executed = [0]  # mutable counter (list cell so the worker closure can update it)
        active_connections: List[DBConnection] = []  # Track for forced close
        conn_lock = threading.Lock()

        # Query timeout configuration
        query_timeout = self.bench_cfg.query_timeout
        outlier_threshold_ms = query_timeout * 1000 if query_timeout > 0 else 0
        # NOTE(review): this set is mutated from worker threads without
        # holding latency_lock — relies on CPython's atomic set.add.
        outliers_logged = set()  # Avoid spamming logs for same query

        # Profiling setup — in concurrent mode, capture a single mixed
        # flame graph for the entire run duration.
        profiling = self._check_profiling_support()
        profiler = None
        if profiling:
            from .flamegraph import PerfProfiler
            if self.on_profile_start:
                self.on_profile_start("concurrent_mix")
            profiler = PerfProfiler(
                mysqld_pid=self._mysqld_pid,
                perf_freq=self.bench_cfg.perf_freq,
            )

        def worker(thread_id: int, start_delay: float) -> None:
            """Worker thread that executes queries continuously until stop_event.

            Each worker loops through the pre-generated test cases repeatedly,
            cycling back to the start when reaching the end. This ensures
            time-based execution similar to the duration-based mode.
            """
            if start_delay > 0:
                _time.sleep(start_delay)

            db = DBConnection(self.config, self.database)
            try:
                db.connect(query_timeout=query_timeout)
            except Exception as e:
                log.warning("[%s] Worker %d connect failed: %s",
                            self.config.name, thread_id, e)
                return

            # Track connection for forced close
            with conn_lock:
                active_connections.append(db)

            try:
                # Local index for cycling through test cases
                local_idx = 0
                n_cases = len(self.test_cases)

                while not stop_event.is_set():
                    # Cycle through test cases repeatedly
                    tc = self.test_cases[local_idx % n_cases]
                    local_idx += 1

                    # Execute pre-rendered SQL
                    t0 = _time.monotonic()
                    try:
                        db.cursor.execute(tc.sql)
                        # Consume the result set so latency covers the
                        # full round-trip, not just statement dispatch.
                        if db.cursor.description:
                            db.cursor.fetchall()
                    except Exception as e:
                        with latency_lock:
                            per_query_errors[tc.query_name] += 1
                            # Collect error details (limit per query)
                            logs = per_query_error_logs[tc.query_name]
                            if len(logs) < 50:
                                logs.append({
                                    "sql": tc.sql[:500],
                                    "error": str(e)[:500],
                                })
                        if db._is_connection_lost(e):
                            if not db.reconnect():
                                break
                            try:
                                db.cursor.execute(
                                    f"USE `{self.database}`")
                            except Exception:
                                pass
                        # Still run cleanup to restore state for next iteration
                        if tc.cleanup_sql:
                            try:
                                db.cursor.execute(tc.cleanup_sql)
                                if db.cursor.description:
                                    db.cursor.fetchall()
                            except Exception:
                                pass
                        continue
                    t1 = _time.monotonic()

                    lat_ms = (t1 - t0) * 1000.0

                    # Run cleanup SQL to restore state (not timed)
                    if tc.cleanup_sql:
                        try:
                            db.cursor.execute(tc.cleanup_sql)
                            if db.cursor.description:
                                db.cursor.fetchall()
                        except Exception:
                            pass

                    # Log outlier queries (exceeding threshold)
                    if outlier_threshold_ms > 0 and lat_ms > outlier_threshold_ms:
                        # Key on (query, whole seconds) so repeated outliers
                        # of similar magnitude are logged only once.
                        outlier_key = (tc.query_name, int(lat_ms / 1000))
                        if outlier_key not in outliers_logged:
                            outliers_logged.add(outlier_key)
                            log.warning(
                                "[%s] Slow query detected: %s took %.0fms (>%ds threshold)",
                                self.config.name, tc.query_name, lat_ms, query_timeout
                            )

                    with latency_lock:
                        per_query_latencies[tc.query_name].append(lat_ms)
                        total_executed[0] += 1

                    if self.on_progress:
                        self.on_progress(
                            tc.query_name, total_executed[0], 0,
                            is_warmup=False,
                        )
            finally:
                # Remove from active connections
                with conn_lock:
                    try:
                        active_connections.remove(db)
                    except ValueError:
                        pass
                db.close()

        # Launch threads with ramp-up
        exec_start = None  # set after ramp-up, before steady-state timing
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=concurrency) as pool:
            futures = []
            for i in range(concurrency):
                # Stagger worker start times evenly over the ramp-up window.
                delay = (ramp_up / concurrency) * i if ramp_up > 0 else 0
                futures.append(pool.submit(worker, i, delay))

            # Wait for ramp-up to complete before starting profiler
            # so it only captures steady-state load
            if ramp_up > 0:
                _time.sleep(ramp_up)

            # Notify that steady-state execution is starting
            if self.on_run_start:
                self.on_run_start()

            # Begin timing: only steady-state execution counts
            exec_start = _time.monotonic()

            # Start profiler after ramp-up, when all workers are active
            if profiler is not None:
                profiler.start()

            # Wait for duration to expire
            _time.sleep(duration)

            # Signal all workers to stop
            stop_event.set()

            # Force close any lingering connections to interrupt slow queries
            # Wait a brief moment for workers to finish gracefully
            _time.sleep(0.5)
            with conn_lock:
                for conn in list(active_connections):
                    try:
                        conn.close()
                    except Exception:
                        pass

            # Wait for all threads to finish
            for f in futures:
                try:
                    f.result(timeout=5)  # Reduced timeout since connections are closed
                except Exception as e:
                    log.debug("[%s] Worker cleanup: %s",
                              self.config.name, e)

        overall_elapsed = _time.monotonic() - exec_start

        # Stop profiler immediately after all workers finish
        concurrent_fg_svg = ""
        if profiler is not None:
            fg_data = profiler.stop(query_name="concurrent_mix")
            if fg_data.svg_content:
                concurrent_fg_svg = fg_data.svg_content
            elif fg_data.error:
                log.warning("[%s] Flame graph: %s",
                            self.config.name, fg_data.error)
            profiler.cleanup()
            if self.on_profile_done:
                self.on_profile_done("concurrent_mix", fg_data.sample_count)

        # Compute stats per query
        for query in self.workload.queries:
            lats = per_query_latencies[query.name]
            errs = per_query_errors[query.name]
            stats = compute_stats(lats, errs, overall_elapsed, query.name,
                                  sql_template=query.sql)
            # In concurrent mode, all queries share the same flame graph
            stats.flamegraph_svg = concurrent_fg_svg
            stats.explain_plan = explain_plans.get(query.name, "")
            stats.explain_tree = explain_tree_plans.get(query.name, "")
            stats.error_logs = per_query_error_logs.get(query.name, [])
            result.query_stats.append(stats)
            result.total_queries += len(lats) + errs
            result.total_errors += errs

        result.total_duration_s = overall_elapsed
        if overall_elapsed > 0:
            result.overall_qps = result.total_queries / overall_elapsed

        # Teardown phase (single connection)
        teardown_db = DBConnection(self.config, self.database)
        try:
            teardown_db.connect()
            self._run_teardown(teardown_db)
        except Exception:
            pass
        finally:
            teardown_db.close()

        return result
1394
+
1395
+
1396
+ # ---------------------------------------------------------------------------
1397
+ # Top-level benchmark runner
1398
+ # ---------------------------------------------------------------------------
1399
+
1400
def run_benchmark(
    configs: List[DBMSConfig],
    workload: BenchWorkload,
    bench_cfg: BenchmarkConfig,
    database: str = "rosetta_bench",
    on_progress: Optional[Callable] = None,
    on_dbms_start: Optional[Callable] = None,
    on_dbms_done: Optional[Callable] = None,
    on_profile_start: Optional[Callable] = None,
    on_profile_done: Optional[Callable] = None,
    on_run_start: Optional[Callable] = None,
    on_setup_start: Optional[Callable] = None,
    on_setup_done: Optional[Callable] = None,
    parallel_dbms: bool = False,
) -> BenchmarkResult:
    """Run benchmark on all DBMS targets and return aggregated results.

    The execution is split into two phases for fairness:
        1. Setup phase: All DBMS targets run setup SQL in parallel
        2. Query phase: After all setups complete, run queries in parallel

    This ensures all DBMS start the query phase at the same time with
    identical data states.

    Args:
        configs: List of DBMS connection configs.
        workload: The workload definition.
        bench_cfg: Benchmark runtime configuration.
        database: Database name for the benchmark.
        on_progress: Optional callback(dbms_name, query_name, iteration, total).
        on_dbms_start: Optional callback(dbms_name).
        on_dbms_done: Optional callback(dbms_name, dbms_result).
        on_profile_start: Optional callback(dbms_name, query_name).
        on_profile_done: Optional callback(dbms_name, query_name, sample_count).
        on_run_start: Optional callback() - called when query phase begins.
        on_setup_start: Optional callback(dbms_name) - called when setup starts.
        on_setup_done: Optional callback(dbms_name, success) - called when setup finishes.
        parallel_dbms: If True, run benchmarks on all DBMS targets in
            parallel (each DBMS gets its own thread and TemplateEngine).

    Returns:
        BenchmarkResult with results from all DBMS instances.
    """
    # NOTE(review): queries_sql is captured here, BEFORE filter_queries is
    # applied below, so the report records the full workload even when only
    # a subset is executed — presumably intentional; confirm with report code.
    result = BenchmarkResult(
        workload_name=workload.name,
        mode=bench_cfg.mode,
        config=bench_cfg,
        timestamp=_time.strftime("%Y-%m-%d %H:%M:%S"),
        setup_sql=list(workload.setup),
        teardown_sql=list(workload.teardown),
        queries_sql=[
            {"name": q.name, "sql": q.sql, "weight": q.weight,
             "description": q.description, "cleanup_sql": q.cleanup_sql}
            for q in workload.queries
        ],
    )

    # Apply query filter
    if bench_cfg.filter_queries:
        workload = BenchmarkLoader.filter_queries(
            workload, bench_cfg.filter_queries)

    # Pre-generate test cases for fair comparison
    # All DBMS instances will execute the exact same sequence of SQL statements
    log.info("Pre-generating test cases with seed=%d", bench_cfg.seed)
    generator = TestCaseGenerator(seed=bench_cfg.seed)

    if bench_cfg.mode == WorkloadMode.SERIAL:
        test_cases = generator.generate_serial_cases(
            workload, bench_cfg.iterations, bench_cfg.warmup)
        log.info("Generated %d test cases for serial mode", len(test_cases))
    else:
        # For concurrent mode, generate a pool of test cases that workers
        # will cycle through repeatedly during the duration.
        # Use a reasonable pool size (~100 per query) for variety.
        pool_size = max(100, len(workload.queries) * 20)
        test_cases = generator.generate_concurrent_cases(workload, pool_size)
        log.info("Generated %d test cases for concurrent mode (will cycle)", len(test_cases))

    # Track setup results per DBMS.  These dicts are written from setup
    # worker threads (one writer per distinct key) and read afterwards.
    setup_results: Dict[str, bool] = {}
    setup_table_rows: Dict[str, int] = {}
    setup_table_rows_detail: Dict[str, Dict[str, int]] = {}
    setup_table_schema: Dict[str, Dict[str, str]] = {}  # {dbms_name: {table_name: CREATE TABLE stmt}}

    def _run_setup(config: DBMSConfig) -> bool:
        """Run setup phase on a single DBMS. Returns True on success."""
        if on_setup_start:
            on_setup_start(config.name)

        if not ensure_service(config):
            log.error("[%s] Service unavailable for setup", config.name)
            if on_setup_done:
                on_setup_done(config.name, False)
            return False

        db = DBConnection(config, database)
        try:
            db.connect()
            if bench_cfg.skip_setup:
                log.info("[%s] Skipping setup (--skip-setup), reusing existing tables", config.name)
            else:
                # Run setup SQL
                for sql in workload.setup:
                    try:
                        db.cursor.execute(sql)
                    except Exception as e:
                        log.warning("[%s] Setup failed: %s — %s",
                                    config.name, sql[:80], e)

            # Query total rows and schema after setup
            try:
                # NOTE(review): the database name is interpolated directly
                # into the SQL string; safe only because it is an internal
                # config value, not user input — confirm upstream.
                db.cursor.execute(
                    "SELECT TABLE_NAME FROM information_schema.TABLES "
                    f"WHERE TABLE_SCHEMA = '{database}' "
                    "AND TABLE_TYPE = 'BASE TABLE'")
                tables = [r[0] for r in db.cursor.fetchall()]
                total = 0
                detail = {}
                schema = {}
                for tbl in tables:
                    # Get row count
                    try:
                        db.cursor.execute(f"SELECT COUNT(*) FROM `{tbl}`")
                        cnt = db.cursor.fetchone()
                        if cnt and cnt[0]:
                            row_count = int(cnt[0])
                            total += row_count
                            detail[tbl] = row_count
                    except Exception:
                        pass
                    # Get CREATE TABLE statement
                    try:
                        db.cursor.execute(f"SHOW CREATE TABLE `{tbl}`")
                        row = db.cursor.fetchone()
                        if row and len(row) >= 2:
                            schema[tbl] = row[1]  # Second column is the CREATE TABLE stmt
                    except Exception:
                        pass
                setup_table_rows[config.name] = total
                setup_table_rows_detail[config.name] = detail
                setup_table_schema[config.name] = schema
            except Exception as e:
                log.debug("[%s] Could not query table rows: %s",
                          config.name, e)

            log.info("[%s] Setup completed", config.name)
            if on_setup_done:
                on_setup_done(config.name, True)
            return True

        except Exception as e:
            log.error("[%s] Setup connection failed: %s", config.name, e)
            if on_setup_done:
                on_setup_done(config.name, False)
            return False
        finally:
            db.close()

    # --- Phase 1: Run setup on all DBMS targets ---
    log.info("Starting setup phase for %d DBMS targets", len(configs))

    if parallel_dbms and len(configs) > 1:
        # Parallel setup
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(configs)) as pool:
            futures = {pool.submit(_run_setup, c): c for c in configs}
            for fut in concurrent.futures.as_completed(futures):
                config = futures[fut]
                try:
                    setup_results[config.name] = fut.result()
                except Exception as e:
                    log.error("[%s] Setup failed: %s", config.name, e)
                    setup_results[config.name] = False
    else:
        # Sequential setup
        for config in configs:
            setup_results[config.name] = _run_setup(config)

    # Check if any setup succeeded
    successful_configs = [c for c in configs if setup_results.get(c.name, False)]
    if not successful_configs:
        log.error("All DBMS setups failed, aborting benchmark")
        return result

    # --- Phase 2: Run query phase on all DBMS targets ---
    log.info("All setups complete, starting query phase")

    # Brief pause to ensure "setup complete" status is visible to user
    _time.sleep(1.0)

    # Notify that query phase is starting (for UI timing reset)
    if on_run_start:
        on_run_start()

    def _run_query_phase(config: DBMSConfig) -> DBMSBenchResult:
        """Run query phase on a single DBMS target (setup already done)."""
        if on_dbms_start:
            on_dbms_start(config.name)

        # The _dbms=config.name defaults bind the DBMS name at definition
        # time so each closure reports for its own target.
        def _progress_cb(qname, it, total, is_warmup=False,
                         _dbms=config.name):
            if on_progress:
                on_progress(_dbms, qname, it, total, is_warmup)

        def _profile_start_cb(qname, _dbms=config.name):
            if on_profile_start:
                on_profile_start(_dbms, qname)

        def _profile_done_cb(qname, samples, _dbms=config.name):
            if on_profile_done:
                on_profile_done(_dbms, qname, samples)

        if bench_cfg.mode == WorkloadMode.SERIAL:
            runner = SerialBenchmarkRunner(
                config, workload, bench_cfg, test_cases, database,
                on_progress=_progress_cb,
                on_profile_start=_profile_start_cb,
                on_profile_done=_profile_done_cb,
            )
        else:
            runner = ConcurrentBenchmarkRunner(
                config, workload, bench_cfg, test_cases, database,
                on_progress=_progress_cb,
                on_profile_start=_profile_start_cb,
                on_profile_done=_profile_done_cb,
                on_run_start=None,  # Already called above
            )

        # Skip setup since it's already done
        dbms_result = runner.run(skip_setup=True)

        # Add table_rows and schema from setup phase
        if config.name in setup_table_rows:
            dbms_result.table_rows = setup_table_rows[config.name]
        if config.name in setup_table_rows_detail:
            dbms_result.table_rows_detail = setup_table_rows_detail[config.name]
        if config.name in setup_table_schema:
            dbms_result.table_schema = setup_table_schema[config.name]

        if on_dbms_done:
            on_dbms_done(config.name, dbms_result)

        return dbms_result

    if parallel_dbms and len(successful_configs) > 1:
        # Run all DBMS targets in parallel
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=len(successful_configs)) as pool:
            futures = {
                pool.submit(_run_query_phase, c): c for c in successful_configs
            }
            for fut in concurrent.futures.as_completed(futures):
                try:
                    dbms_result = fut.result()
                    result.dbms_results.append(dbms_result)
                except Exception as e:
                    config = futures[fut]
                    log.error("[%s] Benchmark failed: %s", config.name, e)
    else:
        # Sequential execution
        for config in successful_configs:
            dbms_result = _run_query_phase(config)
            result.dbms_results.append(dbms_result)

    # Ensure results are in the same order as configs for consistent reports
    name_order = {c.name: i for i, c in enumerate(configs)}
    result.dbms_results.sort(
        key=lambda r: name_order.get(r.dbms_name, 999))

    # Populate table_rows and schema from first DBMS that reported a value
    for dr in result.dbms_results:
        if dr.table_rows > 0:
            result.table_rows = dr.table_rows
            result.table_rows_detail = dr.table_rows_detail
            result.table_schema = dr.table_schema
            break

    return result