structured2graph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +47 -0
- core/__init__.py +23 -0
- core/hygm/__init__.py +74 -0
- core/hygm/hygm.py +2351 -0
- core/hygm/models/__init__.py +82 -0
- core/hygm/models/graph_models.py +667 -0
- core/hygm/models/llm_models.py +229 -0
- core/hygm/models/operations.py +176 -0
- core/hygm/models/sources.py +68 -0
- core/hygm/models/user_operations.py +139 -0
- core/hygm/strategies/__init__.py +17 -0
- core/hygm/strategies/base.py +36 -0
- core/hygm/strategies/deterministic.py +262 -0
- core/hygm/strategies/llm.py +904 -0
- core/hygm/validation/__init__.py +38 -0
- core/hygm/validation/base.py +194 -0
- core/hygm/validation/graph_schema_validator.py +687 -0
- core/hygm/validation/memgraph_data_validator.py +991 -0
- core/migration_agent.py +1369 -0
- core/schema/spec.json +155 -0
- core/utils/meta_graph.py +108 -0
- database/__init__.py +36 -0
- database/adapters/__init__.py +11 -0
- database/adapters/memgraph.py +318 -0
- database/adapters/mysql.py +311 -0
- database/adapters/postgresql.py +335 -0
- database/analyzer.py +396 -0
- database/factory.py +219 -0
- database/models.py +209 -0
- main.py +518 -0
- query_generation/__init__.py +20 -0
- query_generation/cypher_generator.py +129 -0
- query_generation/schema_utilities.py +88 -0
- structured2graph-0.1.1.dist-info/METADATA +197 -0
- structured2graph-0.1.1.dist-info/RECORD +41 -0
- structured2graph-0.1.1.dist-info/WHEEL +4 -0
- structured2graph-0.1.1.dist-info/entry_points.txt +2 -0
- structured2graph-0.1.1.dist-info/licenses/LICENSE +21 -0
- utils/__init__.py +57 -0
- utils/config.py +235 -0
- utils/environment.py +404 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MySQL-specific database analyzer implementation.
|
|
3
|
+
|
|
4
|
+
This module provides MySQL-specific implementation of the DatabaseAnalyzer interface.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import mysql.connector
|
|
8
|
+
from typing import Dict, List, Any, Optional
|
|
9
|
+
import logging
|
|
10
|
+
from ..analyzer import (
|
|
11
|
+
DatabaseAnalyzer,
|
|
12
|
+
ColumnInfo,
|
|
13
|
+
ForeignKeyInfo,
|
|
14
|
+
TableInfo,
|
|
15
|
+
TableType,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MySQLAnalyzer(DatabaseAnalyzer):
    """MySQL-specific implementation of DatabaseAnalyzer.

    Catalogue queries are scoped to the database given at construction time
    (``connection_config["database"]``). Every query method raises
    ``ConnectionError`` when no connection handle is present.
    """

    # Types whose single parenthesised parameter is a length, e.g. varchar(255).
    _LENGTH_TYPES = ("varchar", "char", "varbinary", "binary", "bit")
    # Numeric types whose parameters are precision[, scale], e.g. decimal(10,2).
    _NUMERIC_TYPES = ("decimal", "numeric", "float", "double")

    def __init__(
        self, host: str, user: str, password: str, database: str, port: int = 3306
    ):
        """
        Initialize MySQL analyzer.

        Args:
            host: MySQL server hostname
            user: MySQL username
            password: MySQL password
            database: Database name
            port: MySQL port (default: 3306)
        """
        connection_config = {
            "host": host,
            "user": user,
            "password": password,
            "database": database,
            "port": port,
        }
        super().__init__(connection_config)

    def _get_database_type(self) -> str:
        """Return the database type."""
        return "mysql"

    def connect(self) -> bool:
        """Establish connection to MySQL database.

        Returns:
            True on success. False if ``mysql.connector`` raised its ``Error``;
            the error is logged and ``self.connection`` is left untouched.
        """
        try:
            self.connection = mysql.connector.connect(**self.connection_config)
        except mysql.connector.Error as e:
            logger.error("Error connecting to MySQL: %s", e)
            return False
        logger.info("Successfully connected to MySQL database")
        return True

    def disconnect(self) -> None:
        """Close the MySQL connection if it is open, then drop the handle."""
        if self.connection and self.connection.is_connected():
            self.connection.close()
            logger.info("MySQL connection closed")
        # Clear the handle so later calls fail fast with "Not connected"
        # instead of using a closed connection object.
        self.connection = None

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _require_connection(self) -> Any:
        """Return the live connection handle or raise ``ConnectionError``."""
        if not self.connection:
            raise ConnectionError("Not connected to database")
        return self.connection

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote *name* as a MySQL identifier.

        Embedded backticks are doubled, so names containing spaces, reserved
        words or quote characters cannot break out of the identifier. The
        name is treated as a single unqualified identifier.
        """
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _as_text(value: Any) -> Any:
        """Decode ``bytes``/``bytearray`` values to ``str``; pass others through.

        NOTE(review): some mysql-connector versions return SHOW/DESCRIBE
        text columns as bytes; this keeps string comparisons working there.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8", errors="replace")
        return value

    def _fetch_all(
        self, query: str, params: Optional[tuple] = None, dictionary: bool = False
    ) -> List[Any]:
        """Execute *query* on a fresh cursor and return all rows.

        The cursor is always closed, even when execution raises. All rows are
        fetched so no unread result is left on the connection.

        Args:
            query: SQL text. ``%s`` placeholders are only interpreted when
                *params* is given.
            params: Positional parameters for the query, or None.
            dictionary: Return each row as a dict keyed by column name.
        """
        connection = self._require_connection()
        cursor = (
            connection.cursor(dictionary=True) if dictionary else connection.cursor()
        )
        try:
            if params is None:
                cursor.execute(query)
            else:
                cursor.execute(query, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    @classmethod
    def _parse_column_type(cls, data_type: str) -> tuple:
        """Split a ``DESCRIBE`` type string into its parts.

        Returns:
            ``(data_type, max_length, precision, scale)``.
            - Length types (``varchar(255)``) and numeric types
              (``decimal(10,2)``, ``float(7,4) unsigned``) are reduced to
              their lower-cased base name, with the parameters extracted.
            - Any other ``a,b`` integer pair is extracted as
              precision/scale, but the type string is kept as is.
            - When the parameters are not integers (``enum('a','b')``,
              ``set(...)``), the type is returned verbatim with no parsed
              parameters.
        """
        if "(" not in data_type:
            return data_type, None, None, None

        head, _, tail = data_type.partition("(")
        type_part = head.strip().lower()
        # Stop at the first ")" so suffixes such as " unsigned" are not
        # treated as part of the parameter list.
        params_part = tail.split(")", 1)[0]

        max_length = precision = scale = None
        try:
            if "," in params_part:
                # Decimal-style type with precision and scale
                precision, scale = map(int, params_part.split(","))
            elif type_part in cls._LENGTH_TYPES:
                max_length = int(params_part)
            elif type_part in cls._NUMERIC_TYPES and params_part.isdigit():
                precision = int(params_part)
        except (ValueError, TypeError) as e:
            # Non-integer parameters (enum/set value lists etc.): keep the
            # full type definition and still report the column.
            logger.debug(
                "Could not parse type parameters for " "%s: %s", data_type, e
            )
            return data_type, None, None, None

        if type_part in cls._LENGTH_TYPES or type_part in cls._NUMERIC_TYPES:
            data_type = type_part
        return data_type, max_length, precision, scale

    # ------------------------------------------------------------------ #
    # DatabaseAnalyzer interface
    # ------------------------------------------------------------------ #

    def get_tables(self) -> List[str]:
        """Get list of all tables in the database (``SHOW TABLES``, includes views)."""
        return [self._as_text(row[0]) for row in self._fetch_all("SHOW TABLES")]

    def get_table_schema(self, table_name: str) -> List[ColumnInfo]:
        """Get schema information for a specific table.

        Every column reported by ``DESCRIBE`` is returned, including
        enum/set columns whose type parameters cannot be parsed; those keep
        their full type string.

        Raises:
            ConnectionError: if not connected.
        """
        rows = self._fetch_all(f"DESCRIBE {self._quote_identifier(table_name)}")
        fk_column_names = {fk.column_name for fk in self.get_foreign_keys(table_name)}

        columns: List[ColumnInfo] = []
        for row in rows:
            # DESCRIBE columns: Field, Type, Null, Key, Default, Extra
            field_name, raw_type, nullable, key_type, default_value, extra = (
                self._as_text(value) for value in row[:6]
            )
            data_type, max_length, precision, scale = self._parse_column_type(
                raw_type
            )
            columns.append(
                ColumnInfo(
                    name=field_name,
                    data_type=data_type,
                    is_nullable=nullable == "YES",
                    is_primary_key=key_type == "PRI",
                    is_foreign_key=field_name in fk_column_names,
                    default_value=default_value,
                    auto_increment="auto_increment" in (extra or "").lower(),
                    max_length=max_length,
                    precision=precision,
                    scale=scale,
                )
            )
        return columns

    def get_foreign_keys(self, table_name: str) -> List[ForeignKeyInfo]:
        """Get foreign key relationships for a table.

        Returns one entry per FK column, ordered by constraint name and the
        column's position within the constraint.
        """
        query = """
            SELECT
                COLUMN_NAME,
                REFERENCED_TABLE_NAME,
                REFERENCED_COLUMN_NAME,
                CONSTRAINT_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s
              AND TABLE_NAME = %s
              AND REFERENCED_TABLE_NAME IS NOT NULL
            ORDER BY CONSTRAINT_NAME, ORDINAL_POSITION
        """
        rows = self._fetch_all(
            query, (self.connection_config["database"], table_name)
        )
        return [
            ForeignKeyInfo(
                column_name=row[0],
                referenced_table=row[1],
                referenced_column=row[2],
                constraint_name=row[3],
            )
            for row in rows
        ]

    def get_table_data(
        self, table_name: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Get data from a specific table as a list of column-name dicts.

        Args:
            table_name: Table to read; it is quoted as an identifier.
            limit: Maximum number of rows, or None for all rows. ``0``
                returns no rows, matching the PostgreSQL adapter.
        """
        query = f"SELECT * FROM {self._quote_identifier(table_name)}"
        if limit is not None:
            # int() guarantees only a numeric literal is spliced in.
            query += f" LIMIT {int(limit)}"
        return self._fetch_all(query, dictionary=True)

    def get_table_row_count(self, table_name: str) -> int:
        """Get the number of rows in a table."""
        rows = self._fetch_all(
            f"SELECT COUNT(*) FROM {self._quote_identifier(table_name)}"
        )
        return int(rows[0][0])

    def is_view(self, table_name: str) -> bool:
        """Check whether *table_name* is a view; False if it does not exist."""
        query = """
            SELECT TABLE_TYPE
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = %s
              AND TABLE_NAME = %s
        """
        rows = self._fetch_all(
            query, (self.connection_config["database"], table_name)
        )
        return bool(rows) and self._as_text(rows[0][0]) == "VIEW"

    def get_tables_excluding_views(self) -> List[str]:
        """Get list of all tables in the database, excluding views."""
        query = """
            SELECT TABLE_NAME
            FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = %s
              AND TABLE_TYPE = 'BASE TABLE'
        """
        rows = self._fetch_all(query, (self.connection_config["database"],))
        return [self._as_text(row[0]) for row in rows]

    def get_indexes(self, table_name: str) -> List[Dict[str, Any]]:
        """
        Get index information for a table.

        Args:
            table_name: Name of the table

        Returns:
            List of index information dictionaries with keys ``name``,
            ``columns`` (in index order), ``is_unique`` and ``type``.
        """
        query = """
            SELECT
                INDEX_NAME,
                COLUMN_NAME,
                NON_UNIQUE,
                INDEX_TYPE
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = %s
              AND TABLE_NAME = %s
            ORDER BY INDEX_NAME, SEQ_IN_INDEX
        """
        rows = self._fetch_all(
            query, (self.connection_config["database"], table_name)
        )

        indexes: Dict[str, Dict[str, Any]] = {}
        for index_name, column_name, non_unique, index_type in rows:
            entry = indexes.get(index_name)
            if entry is None:
                entry = indexes[index_name] = {
                    "name": index_name,
                    "columns": [],
                    "is_unique": int(non_unique) == 0,
                    "type": index_type,
                }
            entry["columns"].append(column_name)
        return list(indexes.values())
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""PostgreSQL-specific database analyzer implementation."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import psycopg2 # type: ignore[import-not-found]
|
|
8
|
+
import psycopg2.extras # type: ignore[import-not-found]
|
|
9
|
+
from psycopg2 import sql # type: ignore[import-not-found]
|
|
10
|
+
except ImportError as import_error: # pragma: no cover - optional dependency
|
|
11
|
+
psycopg2 = None # type: ignore[assignment]
|
|
12
|
+
sql = None # type: ignore[assignment]
|
|
13
|
+
_PSYCOPG2_IMPORT_ERROR = import_error
|
|
14
|
+
else:
|
|
15
|
+
_PSYCOPG2_IMPORT_ERROR = None
|
|
16
|
+
|
|
17
|
+
from ..analyzer import DatabaseAnalyzer
|
|
18
|
+
from ..models import ColumnInfo, ForeignKeyInfo
|
|
19
|
+
|
|
20
|
+
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PostgreSQLAnalyzer(DatabaseAnalyzer):
    """PostgreSQL-specific implementation of DatabaseAnalyzer.

    Catalogue lookups are scoped to one schema (``public`` by default).
    Query methods raise ``ConnectionError`` when no connection handle is
    present.
    """

    def __init__(
        self,
        host: str,
        user: str,
        password: str,
        database: str,
        port: int = 5432,
        schema: str = "public",
    ):
        """Initialize PostgreSQL analyzer.

        Args:
            host: PostgreSQL server hostname.
            user: PostgreSQL username.
            password: PostgreSQL password.
            database: Database name.
            port: PostgreSQL port (default: 5432).
            schema: Schema whose tables are analysed (default: "public").
        """
        connection_config = {
            "host": host,
            "user": user,
            "password": password,
            "database": database,
            "port": port,
            "schema": schema,
        }
        self._schema = schema
        super().__init__(connection_config)

    def _get_database_type(self) -> str:
        return "postgresql"

    def connect(self) -> bool:
        """Open a psycopg2 connection.

        Returns:
            True on success; False if psycopg2 raised its ``Error`` (logged,
            and ``self.connection`` is reset to None).

        Raises:
            ImportError: if psycopg2 is not installed.
        """
        if psycopg2 is None:
            raise ImportError(
                "psycopg2 is required for PostgreSQL support"
            ) from _PSYCOPG2_IMPORT_ERROR

        # "schema" is our own setting, not a libpq connection parameter.
        connect_config = {
            key: value
            for key, value in self.connection_config.items()
            if key != "schema"
        }
        try:
            self.connection = psycopg2.connect(**connect_config)
        except psycopg2.Error as exc:
            logger.error("Error connecting to PostgreSQL: %s", exc)
            self.connection = None
            return False
        logger.info("Successfully connected to PostgreSQL database")
        return True

    def disconnect(self) -> None:
        """Close the connection, if any, and clear the handle."""
        if self.connection:
            self.connection.close()
            self.connection = None
            logger.info("PostgreSQL connection closed")

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _fetch_all(
        self,
        query: Any,
        params: Optional[Tuple[Any, ...]] = None,
        *,
        as_dicts: bool = False,
    ) -> List[Any]:
        """Run *query* on a short-lived cursor and return every result row.

        The cursor is closed even if execution fails. On a psycopg2 error
        the connection is rolled back before the error is re-raised.
        PostgreSQL aborts the open transaction after a failed statement, and
        every later query on this connection would otherwise be rejected.

        Args:
            query: SQL string or ``psycopg2.sql.Composable``.
            params: Positional query parameters, or None.
            as_dicts: Return ``RealDictCursor`` rows instead of tuples.
        """
        connection = self._require_connection()
        cursor_kwargs: Dict[str, Any] = {}
        if as_dicts:
            cursor_kwargs["cursor_factory"] = psycopg2.extras.RealDictCursor
        try:
            with connection.cursor(**cursor_kwargs) as cursor:
                cursor.execute(query, params)
                return cursor.fetchall()
        except psycopg2.Error:
            try:
                connection.rollback()
            except psycopg2.Error as rollback_error:
                logger.debug(
                    "Rollback after failed query also failed: %s", rollback_error
                )
            raise

    @staticmethod
    def _optional_int(value: Any) -> Optional[int]:
        """Convert *value* to ``int``, preserving None."""
        return int(value) if value is not None else None

    def _schema_name(self) -> str:
        """Return the configured schema, falling back to "public" when unset/None."""
        return self.connection_config.get("schema") or self._schema or "public"

    def _require_connection(self) -> Any:
        """Return the live connection handle or raise ``ConnectionError``."""
        if self.connection is None:
            raise ConnectionError("Not connected to database")
        return self.connection

    # ------------------------------------------------------------------ #
    # DatabaseAnalyzer interface
    # ------------------------------------------------------------------ #

    def get_tables(self) -> List[str]:
        """List base tables and views in the configured schema, sorted by name."""
        self._require_connection()
        query = """
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = %s
              AND table_type IN ('BASE TABLE', 'VIEW')
            ORDER BY table_name
        """
        return [row[0] for row in self._fetch_all(query, (self._schema_name(),))]

    def get_table_schema(self, table_name: str) -> List[ColumnInfo]:
        """Describe the columns of *table_name* in ordinal order.

        ``auto_increment`` is True both for serial-style columns (default
        starts with ``nextval(``) and for ``GENERATED ... AS IDENTITY``
        columns. Identity columns have no column_default, so they are
        detected through ``is_identity``.
        """
        self._require_connection()
        schema = self._schema_name()

        column_query = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                character_maximum_length,
                numeric_precision,
                numeric_scale,
                is_identity
            FROM information_schema.columns
            WHERE table_schema = %s
              AND table_name = %s
            ORDER BY ordinal_position
        """
        column_rows = self._fetch_all(column_query, (schema, table_name))

        primary_keys = set(self._get_primary_key_columns(table_name))
        fk_column_names = {fk.column_name for fk in self.get_foreign_keys(table_name)}

        columns: List[ColumnInfo] = []
        for (
            column_name,
            data_type,
            is_nullable,
            column_default,
            char_max_length,
            numeric_precision,
            numeric_scale,
            is_identity,
        ) in column_rows:
            serial_default = isinstance(
                column_default, str
            ) and column_default.lower().startswith("nextval(")
            columns.append(
                ColumnInfo(
                    name=column_name,
                    data_type=data_type,
                    is_nullable=is_nullable == "YES",
                    is_primary_key=column_name in primary_keys,
                    is_foreign_key=column_name in fk_column_names,
                    default_value=column_default,
                    auto_increment=serial_default or is_identity == "YES",
                    max_length=self._optional_int(char_max_length),
                    precision=self._optional_int(numeric_precision),
                    scale=self._optional_int(numeric_scale),
                )
            )
        return columns

    def get_foreign_keys(self, table_name: str) -> List[ForeignKeyInfo]:
        """Return one entry per FK column of *table_name*.

        Rows are ordered by constraint name and column position. The query
        reads ``pg_constraint`` directly and pairs each local column with
        its referenced column by position (``conkey``/``confkey``). The
        information_schema join it replaces had three problems: it
        cross-multiplied composite keys, it could mix in same-named
        constraints of other tables, and it dropped FKs that point into
        another schema.
        """
        self._require_connection()
        query = """
            SELECT
                src_att.attname AS column_name,
                ref_cls.relname AS referenced_table,
                ref_att.attname AS referenced_column,
                con.conname AS constraint_name
            FROM pg_catalog.pg_constraint AS con
            JOIN pg_catalog.pg_class AS src_cls
                ON src_cls.oid = con.conrelid
            JOIN pg_catalog.pg_namespace AS src_ns
                ON src_ns.oid = src_cls.relnamespace
            JOIN pg_catalog.pg_class AS ref_cls
                ON ref_cls.oid = con.confrelid
            CROSS JOIN LATERAL unnest(con.conkey, con.confkey)
                WITH ORDINALITY AS cols(src_attnum, ref_attnum, ord)
            JOIN pg_catalog.pg_attribute AS src_att
                ON src_att.attrelid = con.conrelid
               AND src_att.attnum = cols.src_attnum
            JOIN pg_catalog.pg_attribute AS ref_att
                ON ref_att.attrelid = con.confrelid
               AND ref_att.attnum = cols.ref_attnum
            WHERE con.contype = 'f'
              AND src_ns.nspname = %s
              AND src_cls.relname = %s
            ORDER BY con.conname, cols.ord
        """
        rows = self._fetch_all(query, (self._schema_name(), table_name))
        return [
            ForeignKeyInfo(
                column_name=row[0],
                referenced_table=row[1],
                referenced_column=row[2],
                constraint_name=row[3],
            )
            for row in rows
        ]

    def get_table_data(
        self, table_name: str, limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Return rows of *table_name* as plain dicts.

        Args:
            table_name: Table in the configured schema; quoted as an identifier.
            limit: Maximum number of rows (bound as a parameter), or None for all.
        """
        self._require_connection()
        schema = self._schema_name()
        if sql is None or psycopg2 is None:  # pragma: no cover - import guard
            raise ImportError("psycopg2 is required for PostgreSQL support")

        query = sql.SQL("SELECT * FROM {}.{}").format(
            sql.Identifier(schema), sql.Identifier(table_name)
        )
        params: Optional[Tuple[int, ...]] = None
        if limit is not None:
            query = query + sql.SQL(" LIMIT %s")
            params = (limit,)

        rows = self._fetch_all(query, params, as_dicts=True)
        return [dict(row) for row in rows]

    def get_table_row_count(self, table_name: str) -> int:
        """Return ``COUNT(*)`` for *table_name* in the configured schema."""
        self._require_connection()
        schema = self._schema_name()
        if sql is None:  # pragma: no cover - import guard
            raise ImportError("psycopg2 is required for PostgreSQL support")

        query = sql.SQL("SELECT COUNT(*) FROM {}.{}").format(
            sql.Identifier(schema), sql.Identifier(table_name)
        )
        rows = self._fetch_all(query)
        return int(rows[0][0])

    def is_view(self, table_name: str) -> bool:
        """Return True if *table_name* is a view; False if a table or missing."""
        self._require_connection()
        query = """
            SELECT table_type
            FROM information_schema.tables
            WHERE table_schema = %s
              AND table_name = %s
        """
        rows = self._fetch_all(query, (self._schema_name(), table_name))
        return bool(rows) and rows[0][0] == "VIEW"

    def get_indexes(self, table_name: str) -> List[Dict[str, Any]]:
        """List indexes of *table_name* from ``pg_indexes``, ordered by name.

        Each entry has ``name``, ``columns`` (key elements only), ``is_unique``,
        ``type`` (access method, e.g. ``btree``) and the raw ``definition``.
        """
        self._require_connection()
        query = """
            SELECT indexname, indexdef
            FROM pg_indexes
            WHERE schemaname = %s
              AND tablename = %s
            ORDER BY indexname
        """
        rows = self._fetch_all(query, (self._schema_name(), table_name))
        return [
            {
                "name": index_name,
                "columns": self._parse_index_columns(index_def),
                "is_unique": index_def.upper().startswith("CREATE UNIQUE"),
                "type": self._parse_index_type(index_def),
                "definition": index_def,
            }
            for index_name, index_def in rows
        ]

    def _get_primary_key_columns(self, table_name: str) -> List[str]:
        """Return the primary-key column names of *table_name* in key order."""
        self._require_connection()
        query = """
            SELECT kcu.column_name
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
               AND tc.table_schema = kcu.table_schema
               AND tc.table_name = kcu.table_name
            WHERE tc.table_schema = %s
              AND tc.table_name = %s
              AND tc.constraint_type = 'PRIMARY KEY'
            ORDER BY kcu.ordinal_position
        """
        rows = self._fetch_all(query, (self._schema_name(), table_name))
        return [row[0] for row in rows]

    def _parse_index_columns(self, index_def: str) -> List[str]:
        """Extract the key elements of an index from its ``indexdef`` text.

        Only the first parenthesised list is read, so ``INCLUDE (...)`` and
        a partial-index ``WHERE (...)`` clause are ignored. Commas nested
        inside parentheses or quotes do not split elements.
        - Plain column elements are reduced to the column name, dropping
          ``DESC``, ``NULLS LAST``, operator classes and ``COLLATE``, and
          are unquoted.
        - Expression elements such as ``lower((name)::text)`` are returned
          verbatim.
        """
        import re

        start = index_def.find("(")
        if start == -1:
            return []

        elements: List[str] = []
        current: List[str] = []
        depth = 0
        quote: Optional[str] = None
        for char in index_def[start:]:
            if quote is not None:
                current.append(char)
                if char == quote:
                    quote = None
            elif char in ('"', "'"):
                quote = char
                current.append(char)
            elif char == "(":
                depth += 1
                if depth > 1:
                    current.append(char)
            elif char == ")":
                depth -= 1
                if depth == 0:
                    elements.append("".join(current))
                    break
                current.append(char)
            elif char == "," and depth == 1:
                elements.append("".join(current))
                current = []
            else:
                current.append(char)
        else:
            # The element list was never closed: nothing reliable to report.
            return []

        # A (quoted or bare) identifier, optionally followed by
        # whitespace-separated options such as DESC / opclass / COLLATE.
        simple_element = re.compile(
            r'^(?:"((?:[^"]|"")+)"|([^\s"()]+))(?:\s.*)?$', re.DOTALL
        )
        columns: List[str] = []
        for element in elements:
            text = element.strip()
            if not text:
                continue
            match = simple_element.match(text)
            if match is None:
                columns.append(text)
            elif match.group(1) is not None:
                columns.append(match.group(1).replace('""', '"'))
            else:
                columns.append(match.group(2))
        return columns

    def _parse_index_type(self, index_def: str) -> Optional[str]:
        """Return the lower-cased access method after ``USING``, or None.

        The match is case-insensitive on the original text. The previous
        version computed an offset on ``.upper()``, which can change the
        string length for some Unicode characters and shift the slice.
        """
        import re

        match = re.search(r"\sUSING\s+(\S+)", index_def, flags=re.IGNORECASE)
        if match is None:
            return None
        return match.group(1).lower()
|