sqlbench 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlbench/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """SQLBench - A multi-database SQL workbench."""
2
+
3
# Package version string; matches the 0.1.0 wheel this package is released as.
__version__ = "0.1.0"
sqlbench/__main__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Entry point for sqlbench."""
2
+
3
+ from sqlbench.app import main
4
+
5
+
6
# Executed when the package is run as ``python -m sqlbench``; a plain import
# of this module does not call main().
if __name__ == "__main__":
    main()
sqlbench/adapters.py ADDED
@@ -0,0 +1,383 @@
1
+ """Database adapters for different database types."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+
6
+ class DBAdapter(ABC):
7
+ """Base class for database adapters."""
8
+
9
+ db_type = "base"
10
+ display_name = "Base"
11
+ default_port = None
12
+ requires_database = False
13
+ supports_spool = False
14
+
15
+ @abstractmethod
16
+ def connect(self, host, user, password, port=None, database=None):
17
+ """Connect to the database and return a connection object."""
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_version(self, conn):
22
+ """Get the database version string."""
23
+ pass
24
+
25
+ def add_pagination(self, sql, limit, offset=0):
26
+ """Add pagination to a SQL statement. Default uses LIMIT/OFFSET."""
27
+ sql_stripped = sql.strip()
28
+ while sql_stripped.endswith(';'):
29
+ sql_stripped = sql_stripped[:-1].strip()
30
+
31
+ if offset > 0:
32
+ return f"{sql_stripped} LIMIT {limit} OFFSET {offset}"
33
+ return f"{sql_stripped} LIMIT {limit}"
34
+
35
+ def get_count_sql(self, sql):
36
+ """Wrap SQL in COUNT(*) to get total rows."""
37
+ sql_stripped = sql.strip()
38
+ while sql_stripped.endswith(';'):
39
+ sql_stripped = sql_stripped[:-1].strip()
40
+ return f"SELECT COUNT(*) FROM ({sql_stripped}) AS count_query"
41
+
42
+ @abstractmethod
43
+ def get_columns_query(self, tables):
44
+ """Get SQL to retrieve column metadata for given tables."""
45
+ pass
46
+
47
+ @abstractmethod
48
+ def get_tables_query(self):
49
+ """Get SQL to retrieve list of tables."""
50
+ pass
51
+
52
+ def is_numeric_type(self, type_code):
53
+ """Check if a type_code represents a numeric type."""
54
+ from decimal import Decimal
55
+ # Default: check for Python numeric types (works for pyodbc)
56
+ return type_code in (int, float, Decimal)
57
+
58
+ def get_column_display_size(self, col_info):
59
+ """Extract best display size from cursor description tuple.
60
+
61
+ col_info format varies by driver but generally:
62
+ (name, type_code, display_size, internal_size, precision, scale, null_ok)
63
+ """
64
+ if not col_info or len(col_info) < 5:
65
+ return 10
66
+
67
+ display_size = col_info[2] or 0
68
+ internal_size = col_info[3] or 0
69
+ precision = col_info[4] or 0
70
+
71
+ return display_size or internal_size or precision or 10
72
+
73
+ def get_select_limit_query(self, table_ref, limit):
74
+ """Get a SELECT query with row limit for a table."""
75
+ return f"SELECT * FROM {table_ref} LIMIT {limit}"
76
+
77
+ def get_version_query(self):
78
+ """Get the SQL to retrieve database version."""
79
+ return "SELECT VERSION()"
80
+
81
+
82
class IBMiAdapter(DBAdapter):
    """Adapter for IBM i (AS/400) via ODBC (pyodbc + IBM i Access ODBC Driver)."""

    db_type = "ibmi"
    display_name = "IBM i"
    default_port = None  # ODBC handles this
    requires_database = False
    supports_spool = True

    @staticmethod
    def _odbc_value(value):
        """Format *value* for use in an ODBC connection string.

        Values containing ``;``, ``{`` or ``}``, or with leading/trailing
        whitespace, are wrapped in braces with ``}`` doubled so that, e.g.,
        a password containing ``;`` is not split into another attribute.
        All other values are used verbatim.
        """
        text = str(value)
        if any(ch in text for ch in ';{}') or text != text.strip():
            return "{" + text.replace("}", "}}") + "}"
        return text

    @staticmethod
    def _sql_literal(value):
        """Quote *value* as a SQL string literal, doubling embedded quotes."""
        return "'" + value.replace("'", "''") + "'"

    def connect(self, host, user, password, port=None, database=None):
        """Open a pyodbc connection; *port* and *database* are not used."""
        import pyodbc
        conn_str = (
            "DRIVER={IBM i Access ODBC Driver};"
            f"SYSTEM={self._odbc_value(host)};"
            f"UID={self._odbc_value(user)};"
            f"PWD={self._odbc_value(password)};"
        )
        return pyodbc.connect(conn_str)

    def get_version(self, conn):
        """Return ``"<OS_VERSION>.<OS_RELEASE>"``, or None on any failure.

        Best effort: errors are swallowed, but the cursor is always closed.
        """
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT OS_VERSION, OS_RELEASE FROM SYSIBMADM.ENV_SYS_INFO")
            row = cursor.fetchone()
            if row:
                return f"{row[0]}.{row[1]}"
        except Exception:
            pass  # version is informational only
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
        return None

    def add_pagination(self, sql, limit, offset=0):
        """Append DB2-style ``OFFSET ... ROWS FETCH FIRST ... ROWS ONLY``.

        Trailing ``;`` are removed, *limit*/*offset* are coerced with
        ``int()`` (``offset=None`` means 0), and the clause starts on a new
        line so a trailing ``--`` comment in *sql* cannot swallow it.
        """
        body = sql.strip()
        while body.endswith(';'):
            body = body[:-1].strip()
        limit = int(limit)
        offset = int(offset or 0)

        if offset > 0:
            return f"{body}\nOFFSET {offset} ROWS FETCH FIRST {limit} ROWS ONLY"
        return f"{body}\nFETCH FIRST {limit} ROWS ONLY"

    def get_select_limit_query(self, table_ref, limit):
        """Get a SELECT query with row limit for IBM i."""
        return f"SELECT * FROM {table_ref} FETCH FIRST {int(limit)} ROWS ONLY"

    def get_version_query(self):
        """Get the SQL to retrieve IBM i version."""
        return "SELECT OS_VERSION || '.' || OS_RELEASE FROM SYSIBMADM.ENV_SYS_INFO"

    def get_columns_query(self, tables):
        """Build a QSYS2.SYSCOLUMNS query for *tables* (None if empty).

        Each entry is ``TABLE`` or ``SCHEMA.TABLE``; names are upper-cased
        and embedded as escaped string literals.
        """
        if not tables:
            return None

        table_conditions = []
        for table in tables:
            if '.' in table:
                schema, tbl = table.split('.', 1)
                table_conditions.append(
                    f"(TABLE_SCHEMA = {self._sql_literal(schema.upper())} "
                    f"AND TABLE_NAME = {self._sql_literal(tbl.upper())})"
                )
            else:
                table_conditions.append(
                    f"TABLE_NAME = {self._sql_literal(table.upper())}"
                )

        where_clause = " OR ".join(table_conditions)
        return f"""
            SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, LENGTH, NUMERIC_SCALE
            FROM QSYS2.SYSCOLUMNS
            WHERE {where_clause}
            ORDER BY TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION
        """

    def get_tables_query(self):
        """Get tables from IBM i - returns schema, table_name, table_type."""
        return """
            SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE
            FROM QSYS2.SYSTABLES
            WHERE TABLE_TYPE IN ('T', 'P', 'V')
            ORDER BY TABLE_SCHEMA, TABLE_NAME
        """
161
+
162
+
163
class MySQLAdapter(DBAdapter):
    """Adapter for MySQL (mysql-connector-python)."""

    db_type = "mysql"
    display_name = "MySQL"
    default_port = 3306
    requires_database = True
    supports_spool = False

    @staticmethod
    def _sql_literal(value):
        """Quote *value* as a MySQL string literal.

        Backslashes are escaped as well as quotes because MySQL treats the
        backslash as an escape character in string literals unless the
        NO_BACKSLASH_ESCAPES SQL mode is enabled.
        """
        escaped = value.replace("\\", "\\\\").replace("'", "''")
        return f"'{escaped}'"

    def connect(self, host, user, password, port=None, database=None):
        """Open a mysql.connector connection; *port* is optional."""
        import mysql.connector
        config = {
            'host': host,
            'user': user,
            'password': password,
            'database': database or '',
        }
        if port:
            config['port'] = int(port)
        return mysql.connector.connect(**config)

    def get_version(self, conn):
        """Return the server version string, or None on any failure.

        Best effort: errors are swallowed, but the cursor is always closed.
        """
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(self.get_version_query())
            row = cursor.fetchone()
            if row:
                return row[0]
        except Exception:
            pass  # version is informational only
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
        return None

    def get_columns_query(self, tables):
        """Build an INFORMATION_SCHEMA.COLUMNS query for *tables*.

        Each entry is ``TABLE`` or ``SCHEMA.TABLE``; names are embedded as
        escaped string literals. Returns None when *tables* is empty.
        """
        if not tables:
            return None

        table_conditions = []
        for table in tables:
            if '.' in table:
                schema, tbl = table.split('.', 1)
                table_conditions.append(
                    f"(TABLE_SCHEMA = {self._sql_literal(schema)} "
                    f"AND TABLE_NAME = {self._sql_literal(tbl)})"
                )
            else:
                table_conditions.append(f"TABLE_NAME = {self._sql_literal(table)}")

        where_clause = " OR ".join(table_conditions)
        return f"""
            SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE,
                   CHARACTER_MAXIMUM_LENGTH AS LENGTH, NUMERIC_SCALE
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE {where_clause}
            ORDER BY TABLE_SCHEMA, TABLE_NAME, ORDINAL_POSITION
        """

    def get_tables_query(self):
        """Get tables from MySQL - returns schema, table_name, table_type."""
        return """
            SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
            ORDER BY TABLE_SCHEMA, TABLE_NAME
        """

    def is_numeric_type(self, type_code):
        """Check if a type_code represents a numeric type for MySQL."""
        # mysql-connector-python uses FieldType constants
        # Numeric field types: TINY, SHORT, LONG, FLOAT, DOUBLE, LONGLONG, INT24, DECIMAL, NEWDECIMAL
        try:
            from mysql.connector import FieldType
            numeric_types = {
                FieldType.TINY, FieldType.SHORT, FieldType.LONG,
                FieldType.FLOAT, FieldType.DOUBLE, FieldType.LONGLONG,
                FieldType.INT24, FieldType.DECIMAL, FieldType.NEWDECIMAL
            }
            return type_code in numeric_types
        except (ImportError, AttributeError):
            # Fallback to checking if it's a Python numeric type
            from decimal import Decimal
            return type_code in (int, float, Decimal)

    def get_column_display_size(self, col_info):
        """Extract display size for MySQL - use internal_size as it's more reliable."""
        if not col_info or len(col_info) < 5:
            return 10

        # MySQL: (name, type_code, display_size, internal_size, precision, scale, null_ok, flags, charset)
        # internal_size (index 3) is usually populated
        internal_size = col_info[3] or 0
        display_size = col_info[2] or 0
        precision = col_info[4] or 0

        # For MySQL, internal_size is often the best indicator
        size = internal_size or display_size or precision
        # Cap at reasonable max for display purposes
        return min(size, 255) if size > 0 else 20
260
+
261
+
262
class PostgreSQLAdapter(DBAdapter):
    """Adapter for PostgreSQL (psycopg2)."""

    db_type = "postgresql"
    display_name = "PostgreSQL"
    default_port = 5432
    requires_database = True
    supports_spool = False

    @staticmethod
    def _sql_literal(value):
        """Quote *value* as a SQL string literal, doubling embedded quotes.

        Assumes standard_conforming_strings is on (backslash is literal).
        """
        return "'" + value.replace("'", "''") + "'"

    def connect(self, host, user, password, port=None, database=None):
        """Open a psycopg2 connection (defaults: db ``postgres``, port 5432)."""
        import psycopg2
        return psycopg2.connect(
            host=host,
            user=user,
            password=password,
            dbname=database or 'postgres',
            port=port or self.default_port,
        )

    def get_version_query(self):
        return "SELECT version()"

    def get_version(self, conn):
        """Return the short version number, or None on any failure.

        The token following ``PostgreSQL`` in ``version()`` output is
        returned (trailing comma removed). Otherwise the first 30 characters
        of the raw string are returned. Errors are swallowed, but the cursor
        is always closed.
        """
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(self.get_version_query())
            row = cursor.fetchone()
            if row:
                version_str = row[0]
                parts = version_str.split()
                for word, following in zip(parts, parts[1:]):
                    if word == 'PostgreSQL':
                        return following.rstrip(',')
                return version_str[:30]
        except Exception:
            pass  # version is informational only
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
        return None

    def get_columns_query(self, tables):
        """Build an information_schema.columns query for *tables*.

        Each entry is ``table`` or ``schema.table``; names are embedded as
        escaped string literals. Returns None when *tables* is empty.
        """
        if not tables:
            return None

        table_conditions = []
        for table in tables:
            if '.' in table:
                schema, tbl = table.split('.', 1)
                table_conditions.append(
                    f"(table_schema = {self._sql_literal(schema)} "
                    f"AND table_name = {self._sql_literal(tbl)})"
                )
            else:
                table_conditions.append(f"table_name = {self._sql_literal(table)}")

        where_clause = " OR ".join(table_conditions)
        return f"""
            SELECT table_schema, table_name, column_name, data_type,
                   character_maximum_length AS length, numeric_scale
            FROM information_schema.columns
            WHERE {where_clause}
            ORDER BY table_schema, table_name, ordinal_position
        """

    def get_tables_query(self):
        """Get tables from PostgreSQL - returns schema, table_name, table_type."""
        return """
            SELECT table_schema, table_name, table_type
            FROM information_schema.tables
            WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
            ORDER BY table_schema, table_name
        """

    def is_numeric_type(self, type_code):
        """Check if a type_code represents a numeric type for PostgreSQL."""
        # psycopg2 uses type OIDs for type_code:
        # 20=int8, 21=int2, 23=int4, 26=oid, 700=float4, 701=float8,
        # 790=money, 1700=numeric
        numeric_oids = {20, 21, 23, 26, 700, 701, 790, 1700}
        if isinstance(type_code, int):
            return type_code in numeric_oids
        # Fallback
        from decimal import Decimal
        return type_code in (int, float, Decimal)

    def get_column_display_size(self, col_info):
        """Extract a display width (capped at 255) for PostgreSQL columns."""
        if not col_info or len(col_info) < 5:
            return 10

        # psycopg2: (name, type_code, display_size, internal_size, precision, scale, null_ok)
        # display_size is often None, internal_size is the storage size
        display_size = col_info[2] or 0
        internal_size = col_info[3] or 0
        precision = col_info[4] or 0

        # internal_size is negative for every variable-size type, NUMERIC
        # included, so consult display_size/precision before that fallback;
        # otherwise NUMERIC(p, s) columns would lose their precision.
        size = display_size or precision
        if size > 0:
            return min(size, 255)
        if internal_size < 0:
            return 50  # Default for text/varchar without limit
        return min(internal_size, 255) if internal_size > 0 else 20
363
+
364
+
365
# Registry of available adapters, keyed by each adapter class's db_type.
# Insertion order (IBM i, MySQL, PostgreSQL) is the order get_adapter_choices
# reports them in.
ADAPTERS = {
    adapter_cls.db_type: adapter_cls
    for adapter_cls in (IBMiAdapter, MySQLAdapter, PostgreSQLAdapter)
}
371
+
372
+
373
def get_adapter(db_type):
    """Instantiate the adapter registered under *db_type* in ADAPTERS.

    Raises:
        ValueError: if no adapter is registered for *db_type*.
    """
    adapter_class = ADAPTERS.get(db_type)
    if adapter_class is None:
        raise ValueError(f"Unknown database type: {db_type}")
    return adapter_class()
379
+
380
+
381
def get_adapter_choices():
    """Return ``(db_type, display_name)`` tuples for every registered adapter.

    Entries follow the registry's insertion order; intended for populating
    a database-type selector in the UI.
    """
    return [(db_type, ADAPTERS[db_type].display_name) for db_type in ADAPTERS]