rosetta-sql 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,83 @@
1
#!/usr/bin/env python3
"""
Generate the CSV data files required by the bench_json_mv_select benchmark.

Generated files:
- bench_sel.csv: 1M rows of string-tag data (user_id, json_array)
- bench_sel_int.csv: 1M rows of integer data (user_id, json_array)

Usage:
    python generate_csv_data.py [--output-dir DIR] [--rows N]
"""
12
+
13
+ import argparse
14
+ import csv
15
+ import json
16
+ import os
17
+ import random
18
+ from typing import List, Tuple
19
+
20
+
21
def generate_string_data(num_rows: int = 1_000_000) -> List[Tuple[int, str]]:
    """Build (user_id, json_array) rows of string tags.

    Each row pairs a random user_id in [1, 100000] with a JSON array of
    1-20 tags of the form "tag_<0-999>". RNG calls happen in the same
    order per row (user_id, array size, then each tag), so a fixed seed
    reproduces identical data.
    """
    def _row() -> Tuple[int, str]:
        uid = random.randint(1, 100000)
        count = random.randint(1, 20)
        payload = json.dumps([f"tag_{random.randint(0, 999)}" for _ in range(count)])
        return (uid, payload)

    return [_row() for _ in range(num_rows)]
31
+
32
+
33
def generate_int_data(num_rows: int = 1_000_000) -> List[Tuple[int, str]]:
    """Build (user_id, json_array) rows of integers in [0, 9999].

    Each row pairs a random user_id in [1, 100000] with a JSON array of
    1-20 random integers. RNG call order per row matches the string
    generator (user_id, array size, then each value) for reproducibility.
    """
    result: List[Tuple[int, str]] = []
    produced = 0
    while produced < num_rows:
        uid = random.randint(1, 100000)
        size = random.randint(1, 20)
        body = json.dumps([random.randint(0, 9999) for _ in range(size)])
        result.append((uid, body))
        produced += 1
    return result
43
+
44
+
45
def write_csv(filename: str, rows: List[Tuple], output_dir: str):
    """Write rows to output_dir/filename as tab-separated lines.

    No header row and no quoting: csv.writer is deliberately bypassed so
    the JSON field is emitted verbatim. Format: "user_id\tjson_data\n".
    """
    target = os.path.join(output_dir, filename)
    with open(target, 'w', newline='', encoding='utf-8') as out:
        out.writelines(f"{r[0]}\t{r[1]}\n" for r in rows)
    print(f"Generated: {target} ({len(rows)} rows)")
54
+
55
+
56
def main():
    """CLI entry point: parse args, seed the RNG, emit both benchmark CSVs."""
    parser = argparse.ArgumentParser(description='Generate CSV data files for benchmark')
    parser.add_argument('--output-dir', '-o', default='.', help='Output directory (default: current dir)')
    parser.add_argument('--rows', '-n', type=int, default=1_000_000, help='Number of rows per table (default: 1000000)')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Fixed seed keeps the generated data reproducible across runs.
    random.seed(42)

    print(f"Generating {args.rows:,} rows per table...")

    # (label, generator, output file) — generators run in this order so the
    # RNG stream matches the original sequential implementation.
    jobs = (
        ('string', generate_string_data, 'bench_sel.csv'),
        ('integer', generate_int_data, 'bench_sel_int.csv'),
    )
    for label, make_rows, csv_name in jobs:
        print(f"Generating {label} data...")
        write_csv(csv_name, make_rows(args.rows), args.output_dir)

    print("Done!")


if __name__ == '__main__':
    main()
@@ -0,0 +1,168 @@
1
#!/usr/bin/env python3
"""
Import the CSV data into the target database(s).

Usage:
    python import_data.py --config ../dbms_config.json --dbms tdsql
    python import_data.py --config ../dbms_config.json --dbms all
"""
9
+
10
+ import argparse
11
+ import json
12
+ import sys
13
+ import time
14
+
15
# pymysql is a hard requirement for this script; fail fast with an
# actionable install hint instead of a bare ImportError traceback.
try:
    import pymysql
except ImportError:
    print("Error: pymysql is required. Install via: pip install pymysql")
    sys.exit(1)
20
+
21
+
22
# DDL for the benchmark table. idx_mv_data is a functional multi-valued
# index over the JSON array elements (each cast to CHAR(64)); requires a
# MySQL-compatible server that supports multi-valued indexes.
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS bench_sel (
    id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
    user_id INT NOT NULL,
    data JSON NOT NULL,
    KEY idx_user (user_id),
    KEY idx_mv_data ((CAST(data->'$' AS CHAR(64) ARRAY)))
)
"""
31
+
32
+
33
def import_data(config: dict, database: str, csv_path: str, batch_size: int = 10000):
    """Create the benchmark table on one DBMS and bulk-load csv_path into it.

    Args:
        config: connection settings; must contain 'name', 'host', 'port',
            'user' and 'password'.
        database: target database name (created if missing).
        csv_path: tab-separated file whose lines are "user_id<TAB>json".
        batch_size: rows per executemany()/commit cycle.

    Returns:
        True on success; False when the service is unreachable or the
        import fails (errors are printed, never raised to the caller).
    """
    name = config['name']
    print(f"\n[{name}] Connecting to {config['host']}:{config['port']}...")
    print(f"[{name}] Target database: {database}")

    conn = None
    cursor = None
    try:
        # Connect, then create/select the target database.
        conn = pymysql.connect(
            host=config['host'],
            port=config['port'],
            user=config['user'],
            password=config['password'],
            local_infile=True,
            charset='utf8mb4',
        )
        cursor = conn.cursor()
        cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
        cursor.execute(f"USE `{database}`")
        conn.commit()

        # Recreate the table so every run starts from a clean state with
        # the multi-valued index in place before loading.
        print(f"[{name}] Dropping existing table...")
        cursor.execute("DROP TABLE IF EXISTS bench_sel")
        conn.commit()
        print(f"[{name}] Creating table...")
        cursor.execute(CREATE_TABLE_SQL)
        conn.commit()

        # Stream the CSV and insert in batches, committing per batch.
        print(f"[{name}] Importing data...")
        insert_sql = "INSERT INTO bench_sel (user_id, data) VALUES (%s, %s)"

        total_rows = 0
        batch = []
        batch_times = []  # per-batch wall-clock durations, for throughput stats
        import_start_time = time.time()
        batch_start_time = time.time()

        with open(csv_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Split on the first tab only: the JSON payload is opaque text.
                parts = line.rstrip('\n').split('\t', 1)
                if len(parts) == 2:
                    batch.append((int(parts[0]), parts[1]))

                if len(batch) >= batch_size:
                    cursor.executemany(insert_sql, batch)
                    conn.commit()
                    total_rows += len(batch)

                    # Record this batch's duration and report running stats.
                    batch_elapsed = time.time() - batch_start_time
                    batch_times.append(batch_elapsed)
                    avg_batch_time = sum(batch_times) / len(batch_times)
                    rows_per_sec = len(batch) / batch_elapsed if batch_elapsed > 0 else 0

                    print(f"[{name}] Imported {total_rows:,} rows | "
                          f"Batch: {batch_elapsed:.2f}s | "
                          f"Avg: {avg_batch_time:.2f}s/batch | "
                          f"Speed: {rows_per_sec:,.0f} rows/s", end='\r')
                    batch = []
                    batch_start_time = time.time()

        # Insert any remaining partial batch.
        if batch:
            cursor.executemany(insert_sql, batch)
            conn.commit()
            total_rows += len(batch)
            batch_elapsed = time.time() - batch_start_time
            if batch_elapsed > 0:
                batch_times.append(batch_elapsed)

        import_total_time = time.time() - import_start_time
        avg_rows_per_sec = total_rows / import_total_time if import_total_time > 0 else 0

        print(f"\n[{name}] Import completed: {total_rows:,} rows in {import_total_time:.2f}s "
              f"({avg_rows_per_sec:,.0f} rows/s)")
        if batch_times:
            print(f"[{name}] Average batch time: {sum(batch_times)/len(batch_times):.3f}s "
                  f"(min: {min(batch_times):.3f}s, max: {max(batch_times):.3f}s, batches: {len(batch_times)})")

        print(f"[{name}] Done!")
        return True

    except Exception as e:
        error_msg = str(e)
        # Treat an unreachable service as a soft skip rather than an error.
        # NOTE(review): the "111" substring targets errno 111 (ECONNREFUSED)
        # but would also match any message containing "111"; matching the
        # pymysql error code would be more precise — confirm before changing.
        if "Connection refused" in error_msg or "111" in error_msg:
            print(f"[{name}] Skipped: service unavailable")
            return False
        print(f"[{name}] ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Always release the connection; the previous version leaked both
        # cursor and connection on every exception path.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
132
+
133
+
134
def main():
    """Parse CLI args, load the DBMS config, and run the import per target."""
    parser = argparse.ArgumentParser(description='Import CSV data to database')
    parser.add_argument('--config', '-c', required=True, help='Path to dbms_config.json')
    parser.add_argument('--dbms', '-d', default='all', help='DBMS name to import (default: all)')
    parser.add_argument('--database', '-D', default='cross_dbms_test_db', help='Target database name')
    parser.add_argument('--csv', default='benchmark/bench_sel.csv', help='CSV file path')
    parser.add_argument('--batch-size', '-b', type=int, default=10000, help='Batch size for INSERT')
    args = parser.parse_args()

    # Load config; accept either {"databases": [...]} or a bare list.
    with open(args.config, 'r') as f:
        data = json.load(f)
    configs = data['databases'] if isinstance(data, dict) else data

    # Select targets case-insensitively; 'all' means every configured DBMS.
    wanted = args.dbms.lower()
    if wanted == 'all':
        targets = configs
    else:
        targets = [c for c in configs if c['name'].lower() == wanted]
        if not targets:
            print(f"Error: DBMS '{args.dbms}' not found")
            print(f"Available: {[c['name'] for c in configs]}")
            sys.exit(1)

    # Import into each target, tallying how many succeeded.
    success = sum(
        1 for cfg in targets
        if import_data(cfg, args.database, args.csv, args.batch_size)
    )

    print(f"\nCompleted: {success}/{len(targets)} DBMS targets")


if __name__ == '__main__':
    main()
rosetta/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """Rosetta — Cross-DBMS SQL behavioral consistency verification tool."""
2
+
3
+ __version__ = "1.0.0"
rosetta/__main__.py ADDED
@@ -0,0 +1,8 @@
1
+ """Allow running rosetta as: python -m rosetta"""
2
+
3
+ import sys
4
+
5
+ # Use new CLI module
6
+ from .cli.main import main
7
+
8
+ sys.exit(main())