rosetta-sql 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/generate_csv_data.py +83 -0
- benchmark/import_data.py +168 -0
- rosetta/__init__.py +3 -0
- rosetta/__main__.py +8 -0
- rosetta/benchmark.py +1678 -0
- rosetta/buglist.py +108 -0
- rosetta/cli/__init__.py +11 -0
- rosetta/cli/config_cmd.py +243 -0
- rosetta/cli/exec.py +219 -0
- rosetta/cli/interactive_cmd.py +124 -0
- rosetta/cli/list_cmd.py +215 -0
- rosetta/cli/main.py +617 -0
- rosetta/cli/output.py +545 -0
- rosetta/cli/result.py +61 -0
- rosetta/cli/result_cmd.py +247 -0
- rosetta/cli/run.py +625 -0
- rosetta/cli/status.py +161 -0
- rosetta/comparator.py +205 -0
- rosetta/config.py +139 -0
- rosetta/executor.py +403 -0
- rosetta/flamegraph.py +630 -0
- rosetta/interactive.py +1790 -0
- rosetta/models.py +197 -0
- rosetta/parser.py +308 -0
- rosetta/reporter/__init__.py +1 -0
- rosetta/reporter/bench_html.py +1457 -0
- rosetta/reporter/bench_text.py +162 -0
- rosetta/reporter/history.py +1686 -0
- rosetta/reporter/html.py +644 -0
- rosetta/reporter/text.py +110 -0
- rosetta/runner.py +3089 -0
- rosetta/ui.py +736 -0
- rosetta/whitelist.py +161 -0
- rosetta_sql-1.0.0.dist-info/LICENSE +21 -0
- rosetta_sql-1.0.0.dist-info/METADATA +379 -0
- rosetta_sql-1.0.0.dist-info/RECORD +42 -0
- rosetta_sql-1.0.0.dist-info/WHEEL +5 -0
- rosetta_sql-1.0.0.dist-info/entry_points.txt +2 -0
- rosetta_sql-1.0.0.dist-info/top_level.txt +4 -0
- skills/rosetta/scripts/install_rosetta.py +469 -0
- skills/rosetta/scripts/rosetta_wrapper.py +377 -0
- tests/test_cli.py +749 -0
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
生成 bench_json_mv_select 测试所需的 CSV 数据文件。
|
|
4
|
+
|
|
5
|
+
生成文件:
|
|
6
|
+
- bench_sel.csv: 100万行字符串tag数据 (user_id, json_array)
|
|
7
|
+
- bench_sel_int.csv: 100万行整数数据 (user_id, json_array)
|
|
8
|
+
|
|
9
|
+
用法:
|
|
10
|
+
python generate_csv_data.py [--output-dir DIR] [--rows N]
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import csv
|
|
15
|
+
import json
|
|
16
|
+
import os
|
|
17
|
+
import random
|
|
18
|
+
from typing import List, Tuple
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def generate_string_data(num_rows: int = 1_000_000) -> List[Tuple[int, str]]:
|
|
22
|
+
"""生成字符串tag数据,数组大小在 [1, 20] 范围内随机变化。"""
|
|
23
|
+
rows = []
|
|
24
|
+
for _ in range(num_rows):
|
|
25
|
+
user_id = random.randint(1, 100000)
|
|
26
|
+
array_size = random.randint(1, 20)
|
|
27
|
+
tags = [f"tag_{random.randint(0, 999)}" for _ in range(array_size)]
|
|
28
|
+
json_data = json.dumps(tags)
|
|
29
|
+
rows.append((user_id, json_data))
|
|
30
|
+
return rows
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def generate_int_data(num_rows: int = 1_000_000) -> List[Tuple[int, str]]:
|
|
34
|
+
"""生成整数数据,数组大小在 [1, 20] 范围内随机变化 (值域0-9999)。"""
|
|
35
|
+
rows = []
|
|
36
|
+
for _ in range(num_rows):
|
|
37
|
+
user_id = random.randint(1, 100000)
|
|
38
|
+
array_size = random.randint(1, 20)
|
|
39
|
+
values = [random.randint(0, 9999) for _ in range(array_size)]
|
|
40
|
+
json_data = json.dumps(values)
|
|
41
|
+
rows.append((user_id, json_data))
|
|
42
|
+
return rows
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def write_csv(filename: str, rows: List[Tuple], output_dir: str):
    """Write rows to a header-less, tab-separated file.

    csv.writer is deliberately bypassed so the JSON column is emitted
    without surrounding quotes. Each line is: user_id<TAB>json_data<LF>.
    """
    filepath = os.path.join(output_dir, filename)
    with open(filepath, 'w', newline='', encoding='utf-8') as out:
        out.writelines(f"{r[0]}\t{r[1]}\n" for r in rows)
    print(f"Generated: {filepath} ({len(rows)} rows)")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def main():
    """CLI entry point: generate both CSV fixtures into the chosen directory."""
    arg_parser = argparse.ArgumentParser(description='Generate CSV data files for benchmark')
    arg_parser.add_argument('--output-dir', '-o', default='.', help='Output directory (default: current dir)')
    arg_parser.add_argument('--rows', '-n', type=int, default=1_000_000, help='Number of rows per table (default: 1000000)')
    opts = arg_parser.parse_args()

    os.makedirs(opts.output_dir, exist_ok=True)

    # Fixed seed so regenerated data sets are reproducible.
    random.seed(42)

    print(f"Generating {opts.rows:,} rows per table...")

    print("Generating string data...")
    write_csv('bench_sel.csv', generate_string_data(opts.rows), opts.output_dir)

    print("Generating integer data...")
    write_csv('bench_sel_int.csv', generate_int_data(opts.rows), opts.output_dir)

    print("Done!")


if __name__ == '__main__':
    main()
|
benchmark/import_data.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
将 CSV 数据导入到目标数据库。
|
|
4
|
+
|
|
5
|
+
用法:
|
|
6
|
+
python import_data.py --config ../dbms_config.json --dbms tdsql
|
|
7
|
+
python import_data.py --config ../dbms_config.json --dbms all
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import argparse
|
|
11
|
+
import json
|
|
12
|
+
import sys
|
|
13
|
+
import time
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import pymysql
|
|
17
|
+
except ImportError:
|
|
18
|
+
print("Error: pymysql is required. Install via: pip install pymysql")
|
|
19
|
+
sys.exit(1)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
CREATE_TABLE_SQL = """
|
|
23
|
+
CREATE TABLE IF NOT EXISTS bench_sel (
|
|
24
|
+
id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
|
|
25
|
+
user_id INT NOT NULL,
|
|
26
|
+
data JSON NOT NULL,
|
|
27
|
+
KEY idx_user (user_id),
|
|
28
|
+
KEY idx_mv_data ((CAST(data->'$' AS CHAR(64) ARRAY)))
|
|
29
|
+
)
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def import_data(config: dict, database: str, csv_path: str, batch_size: int = 10000):
    """Import one tab-separated CSV file into a single DBMS target.

    Creates the target database if needed, drops and recreates the bench_sel
    table (so the multi-valued index definition is fresh), then streams the
    CSV in with batched INSERTs while printing progress statistics.

    Args:
        config: connection settings with 'name', 'host', 'port', 'user', 'password'.
        database: target database name (created if missing).
        csv_path: path to a header-less file of lines "user_id<TAB>json_data".
        batch_size: rows per executemany/commit cycle.

    Returns:
        True on success; False on failure. A refused connection is reported
        as a skip (service unavailable) rather than an error.
    """
    name = config['name']
    print(f"\n[{name}] Connecting to {config['host']}:{config['port']}...")
    print(f"[{name}] Target database: {database}")

    conn = None
    cursor = None
    try:
        conn = pymysql.connect(
            host=config['host'],
            port=config['port'],
            user=config['user'],
            password=config['password'],
            local_infile=True,
            charset='utf8mb4',
        )
        cursor = conn.cursor()
        cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
        cursor.execute(f"USE `{database}`")
        conn.commit()

        # Drop existing table and create a new one with the MVI index.
        print(f"[{name}] Dropping existing table...")
        cursor.execute("DROP TABLE IF EXISTS bench_sel")
        conn.commit()
        print(f"[{name}] Creating table...")
        cursor.execute(CREATE_TABLE_SQL)
        conn.commit()

        # Import CSV using batched INSERTs.
        print(f"[{name}] Importing data...")
        insert_sql = "INSERT INTO bench_sel (user_id, data) VALUES (%s, %s)"

        total_rows = 0
        batch = []
        batch_times = []  # elapsed seconds of each committed batch
        import_start_time = time.time()
        batch_start_time = time.time()

        with open(csv_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Split on the first tab only, so the JSON payload is kept intact.
                parts = line.rstrip('\n').split('\t', 1)
                if len(parts) == 2:
                    batch.append((int(parts[0]), parts[1]))

                if len(batch) >= batch_size:
                    cursor.executemany(insert_sql, batch)
                    conn.commit()
                    total_rows += len(batch)

                    batch_elapsed = time.time() - batch_start_time
                    batch_times.append(batch_elapsed)

                    # Progress line is overwritten in place via '\r'.
                    avg_batch_time = sum(batch_times) / len(batch_times)
                    rows_per_sec = len(batch) / batch_elapsed if batch_elapsed > 0 else 0
                    print(f"[{name}] Imported {total_rows:,} rows | "
                          f"Batch: {batch_elapsed:.2f}s | "
                          f"Avg: {avg_batch_time:.2f}s/batch | "
                          f"Speed: {rows_per_sec:,.0f} rows/s", end='\r')
                    batch = []
                    batch_start_time = time.time()

        # Insert remaining rows (final partial batch).
        if batch:
            cursor.executemany(insert_sql, batch)
            conn.commit()
            total_rows += len(batch)
            batch_elapsed = time.time() - batch_start_time
            if batch_elapsed > 0:
                batch_times.append(batch_elapsed)

        import_total_time = time.time() - import_start_time
        avg_rows_per_sec = total_rows / import_total_time if import_total_time > 0 else 0

        print(f"\n[{name}] Import completed: {total_rows:,} rows in {import_total_time:.2f}s "
              f"({avg_rows_per_sec:,.0f} rows/s)")
        if batch_times:
            print(f"[{name}] Average batch time: {sum(batch_times)/len(batch_times):.3f}s "
                  f"(min: {min(batch_times):.3f}s, max: {max(batch_times):.3f}s, batches: {len(batch_times)})")

        print(f"[{name}] Done!")
        return True

    except Exception as e:
        error_msg = str(e)
        # A refused connection means the service is down: skip quietly.
        if "Connection refused" in error_msg or "111" in error_msg:
            print(f"[{name}] Skipped: service unavailable")
            return False
        print(f"[{name}] ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # BUGFIX: the original closed cursor/connection only on the success
        # path, leaking the connection whenever an exception fired mid-import.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def main():
    """CLI entry point: select DBMS targets from the config and import into each."""
    cli = argparse.ArgumentParser(description='Import CSV data to database')
    cli.add_argument('--config', '-c', required=True, help='Path to dbms_config.json')
    cli.add_argument('--dbms', '-d', default='all', help='DBMS name to import (default: all)')
    cli.add_argument('--database', '-D', default='cross_dbms_test_db', help='Target database name')
    cli.add_argument('--csv', default='benchmark/bench_sel.csv', help='CSV file path')
    cli.add_argument('--batch-size', '-b', type=int, default=10000, help='Batch size for INSERT')
    args = cli.parse_args()

    # Config is either {"databases": [...]} or a bare list of entries.
    with open(args.config, 'r') as f:
        data = json.load(f)
    configs = data['databases'] if isinstance(data, dict) else data

    # Resolve which DBMS entries to import into.
    wanted = args.dbms.lower()
    if wanted == 'all':
        targets = configs
    else:
        targets = [c for c in configs if c['name'].lower() == wanted]
        if not targets:
            print(f"Error: DBMS '{args.dbms}' not found")
            print(f"Available: {[c['name'] for c in configs]}")
            sys.exit(1)

    # Import into each target, counting successes.
    success = 0
    for cfg in targets:
        if import_data(cfg, args.database, args.csv, args.batch_size):
            success += 1

    print(f"\nCompleted: {success}/{len(targets)} DBMS targets")


if __name__ == '__main__':
    main()
|
rosetta/__init__.py
ADDED