sql-query-mcp 0.4.0__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sql_query_mcp-0.4.0/sql_query_mcp.egg-info → sql_query_mcp-0.4.1}/PKG-INFO +1 -1
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/pyproject.toml +1 -1
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/__init__.py +1 -1
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/adapters/hive.py +18 -9
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/adapters/mysql.py +38 -13
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/adapters/postgres.py +5 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/importer.py +27 -4
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1/sql_query_mcp.egg-info}/PKG-INFO +1 -1
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_importer.py +110 -3
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_validator.py +149 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/LICENSE +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/README.md +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/setup.cfg +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/__main__.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/adapters/__init__.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/app.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/async_queries.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/audit.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/config.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/errors.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/executor.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/exporter.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/introspection.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/namespace.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/registry.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/release_metadata.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp/validator.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp.egg-info/SOURCES.txt +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp.egg-info/dependency_links.txt +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp.egg-info/entry_points.txt +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp.egg-info/requires.txt +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/sql_query_mcp.egg-info/top_level.txt +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_app.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_async_queries.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_audit.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_config.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_executor.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_exporter.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_metadata.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_namespace.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_registry.py +0 -0
- {sql_query_mcp-0.4.0 → sql_query_mcp-0.4.1}/tests/test_release_metadata.py +0 -0
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from contextlib import contextmanager
|
|
6
|
-
from typing import Iterator, List
|
|
6
|
+
from typing import Any, Iterator, List
|
|
7
7
|
from urllib.parse import parse_qs, unquote, urlparse
|
|
8
8
|
|
|
9
9
|
try:
|
|
@@ -52,6 +52,10 @@ class HiveAdapter:
|
|
|
52
52
|
def column_names(self, description) -> List[str]:
|
|
53
53
|
return [column[0] for column in (description or [])]
|
|
54
54
|
|
|
55
|
+
def normalize_identifier(self, value: str) -> str:
|
|
56
|
+
# Hive table and column identifiers are case-insensitive.
|
|
57
|
+
return value.casefold()
|
|
58
|
+
|
|
55
59
|
def normalize_rows(self, rows, columns: List[str]) -> List[dict]:
|
|
56
60
|
return [dict(zip(columns, row)) for row in rows]
|
|
57
61
|
|
|
@@ -80,7 +84,7 @@ class HiveAdapter:
|
|
|
80
84
|
columns = []
|
|
81
85
|
in_partitions = False
|
|
82
86
|
for row in rows:
|
|
83
|
-
name = self.
|
|
87
|
+
name = self._describe_value(row, "col_name", 0)
|
|
84
88
|
if not name:
|
|
85
89
|
continue
|
|
86
90
|
if str(name).startswith("# Partition Information"):
|
|
@@ -88,9 +92,8 @@ class HiveAdapter:
|
|
|
88
92
|
continue
|
|
89
93
|
if str(name).startswith("#"):
|
|
90
94
|
continue
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
comment = values[2] if len(values) > 2 else None
|
|
95
|
+
data_type = self._describe_value(row, "data_type", 1)
|
|
96
|
+
comment = self._describe_value(row, "comment", 2)
|
|
94
97
|
columns.append(
|
|
95
98
|
{
|
|
96
99
|
"column_name": name,
|
|
@@ -141,7 +144,13 @@ class HiveAdapter:
|
|
|
141
144
|
return next(iter(row.values()))
|
|
142
145
|
return row[0]
|
|
143
146
|
|
|
144
|
-
def
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
147
|
+
def _describe_value(self, row, key: str, index: int) -> Any:
|
|
148
|
+
# Hive table and column identifiers are case-insensitive. DESCRIBE may
|
|
149
|
+
# return tuples or dict rows, so dict key lookup follows Hive semantics.
|
|
150
|
+
if not isinstance(row, dict):
|
|
151
|
+
return row[index] if len(row) > index else None
|
|
152
|
+
lowered_key = key.lower()
|
|
153
|
+
for existing_key, value in row.items():
|
|
154
|
+
if existing_key.lower() == lowered_key:
|
|
155
|
+
return value
|
|
156
|
+
return None
|
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
6
|
from contextlib import contextmanager
|
|
7
|
-
from typing import Iterator, List
|
|
7
|
+
from typing import Any, Iterator, List
|
|
8
8
|
from urllib.parse import parse_qs, unquote, urlparse
|
|
9
9
|
|
|
10
10
|
try:
|
|
@@ -57,7 +57,7 @@ class MySQLAdapter:
|
|
|
57
57
|
ORDER BY schema_name
|
|
58
58
|
"""
|
|
59
59
|
)
|
|
60
|
-
return [row
|
|
60
|
+
return [_row_value(row, "database_name") for row in cur.fetchall()]
|
|
61
61
|
|
|
62
62
|
def list_tables(self, conn: object, database: str):
|
|
63
63
|
with conn.cursor() as cur:
|
|
@@ -70,7 +70,14 @@ class MySQLAdapter:
|
|
|
70
70
|
""",
|
|
71
71
|
(database,),
|
|
72
72
|
)
|
|
73
|
-
return
|
|
73
|
+
return [
|
|
74
|
+
{
|
|
75
|
+
"database_name": _row_value(row, "database_name"),
|
|
76
|
+
"table_name": _row_value(row, "table_name"),
|
|
77
|
+
"table_type": _row_value(row, "table_type"),
|
|
78
|
+
}
|
|
79
|
+
for row in cur.fetchall()
|
|
80
|
+
]
|
|
74
81
|
|
|
75
82
|
def describe_table(self, conn: object, database: str, table_name: str):
|
|
76
83
|
with conn.cursor() as cur:
|
|
@@ -101,13 +108,13 @@ class MySQLAdapter:
|
|
|
101
108
|
return {
|
|
102
109
|
"columns": [
|
|
103
110
|
{
|
|
104
|
-
"column_name": row
|
|
105
|
-
"data_type": row
|
|
111
|
+
"column_name": _row_value(row, "column_name"),
|
|
112
|
+
"data_type": _row_value(row, "column_type"),
|
|
106
113
|
"udt_name": None,
|
|
107
|
-
"nullable": row
|
|
108
|
-
"default": row
|
|
109
|
-
"primary_key": row
|
|
110
|
-
"extra": row
|
|
114
|
+
"nullable": _row_value(row, "is_nullable") == "YES",
|
|
115
|
+
"default": _row_value(row, "column_default"),
|
|
116
|
+
"primary_key": _row_value(row, "column_key") == "PRI",
|
|
117
|
+
"extra": _row_value(row, "extra"),
|
|
111
118
|
}
|
|
112
119
|
for row in columns
|
|
113
120
|
],
|
|
@@ -136,7 +143,7 @@ class MySQLAdapter:
|
|
|
136
143
|
def extract_plan(self, rows):
|
|
137
144
|
if not rows:
|
|
138
145
|
return []
|
|
139
|
-
plan = rows[0]
|
|
146
|
+
plan = _row_value(rows[0], "EXPLAIN")
|
|
140
147
|
if isinstance(plan, str):
|
|
141
148
|
try:
|
|
142
149
|
return json.loads(plan)
|
|
@@ -147,6 +154,11 @@ class MySQLAdapter:
|
|
|
147
154
|
def column_names(self, description) -> List[str]:
|
|
148
155
|
return [column[0] for column in (description or [])]
|
|
149
156
|
|
|
157
|
+
def normalize_identifier(self, value: str) -> str:
|
|
158
|
+
# MySQL column names, index names, and column aliases are
|
|
159
|
+
# case-insensitive on every platform.
|
|
160
|
+
return value.casefold()
|
|
161
|
+
|
|
150
162
|
def _parse_dsn(self, dsn: str) -> dict:
|
|
151
163
|
parsed = urlparse(dsn)
|
|
152
164
|
if parsed.scheme not in {"mysql", "mysql+pymysql"}:
|
|
@@ -169,16 +181,29 @@ class MySQLAdapter:
|
|
|
169
181
|
def _normalize_indexes(self, rows: List[dict]) -> List[dict]:
|
|
170
182
|
grouped = {}
|
|
171
183
|
for row in rows:
|
|
172
|
-
index_name = row
|
|
184
|
+
index_name = _row_value(row, "index_name")
|
|
173
185
|
item = grouped.setdefault(
|
|
174
186
|
index_name,
|
|
175
187
|
{
|
|
176
188
|
"index_name": index_name,
|
|
177
189
|
"columns": [],
|
|
178
|
-
"unique": row
|
|
190
|
+
"unique": _row_value(row, "non_unique") == 0,
|
|
179
191
|
"primary_key": index_name == "PRIMARY",
|
|
180
192
|
"definition": None,
|
|
181
193
|
},
|
|
182
194
|
)
|
|
183
|
-
item["columns"].append(row
|
|
195
|
+
item["columns"].append(_row_value(row, "column_name"))
|
|
184
196
|
return [grouped[name] for name in sorted(grouped)]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _row_value(row: dict, key: str) -> Any:
|
|
200
|
+
# MySQL column names, index names, and column aliases are case-insensitive,
|
|
201
|
+
# and drivers may expose information_schema labels as COLUMN_NAME or
|
|
202
|
+
# column_name. Keep this normalization local to the MySQL adapter.
|
|
203
|
+
if key in row:
|
|
204
|
+
return row[key]
|
|
205
|
+
lowered_key = key.lower()
|
|
206
|
+
for existing_key, value in row.items():
|
|
207
|
+
if existing_key.lower() == lowered_key:
|
|
208
|
+
return value
|
|
209
|
+
raise KeyError(key)
|
|
@@ -188,6 +188,11 @@ class PostgresAdapter:
|
|
|
188
188
|
def column_names(self, description) -> List[str]:
|
|
189
189
|
return [column.name for column in (description or [])]
|
|
190
190
|
|
|
191
|
+
def normalize_identifier(self, value: str) -> str:
|
|
192
|
+
# PostgreSQL quoted identifiers are case-sensitive, and this adapter
|
|
193
|
+
# quotes import columns with sql.Identifier, so header matching is exact.
|
|
194
|
+
return value
|
|
195
|
+
|
|
191
196
|
def _get_pool(self, connection_id: str, dsn: str) -> ConnectionPool:
|
|
192
197
|
if ConnectionPool is None or dict_row is None:
|
|
193
198
|
raise ConfigurationError("缺少 psycopg / psycopg-pool 依赖,请先安装项目依赖。")
|
|
@@ -68,7 +68,8 @@ class TableFileImporter:
|
|
|
68
68
|
f"未找到表 {namespace.value}.{table_name},或当前用户没有访问权限"
|
|
69
69
|
)
|
|
70
70
|
table_columns = [item["column_name"] for item in description["columns"]]
|
|
71
|
-
|
|
71
|
+
normalize_identifier = getattr(adapter, "normalize_identifier", _exact_identifier)
|
|
72
|
+
_validate_headers(headers, table_columns, normalize_identifier)
|
|
72
73
|
query = adapter.build_insert_query(namespace.value, table_name, headers)
|
|
73
74
|
_execute_insert(conn, config.engine, query, rows)
|
|
74
75
|
|
|
@@ -170,20 +171,42 @@ def _normalize_row(row: Sequence[object], expected_length: int) -> Tuple[object,
|
|
|
170
171
|
return tuple(None if value == "" else value for value in row)
|
|
171
172
|
|
|
172
173
|
|
|
173
|
-
def _validate_headers(
|
|
174
|
+
def _validate_headers(
|
|
175
|
+
headers: Sequence[str],
|
|
176
|
+
table_columns: Sequence[str],
|
|
177
|
+
normalize_identifier=None,
|
|
178
|
+
) -> None:
|
|
179
|
+
if normalize_identifier is None:
|
|
180
|
+
normalize_identifier = _exact_identifier
|
|
174
181
|
if not headers:
|
|
175
182
|
raise QueryExecutionError("文件表头不能为空。")
|
|
176
183
|
empty_headers = [index + 1 for index, header in enumerate(headers) if not header]
|
|
177
184
|
if empty_headers:
|
|
178
185
|
raise QueryExecutionError(f"文件表头存在空字段,位置: {empty_headers}")
|
|
179
|
-
|
|
186
|
+
normalized_headers = [normalize_identifier(header) for header in headers]
|
|
187
|
+
duplicates = sorted(
|
|
188
|
+
{
|
|
189
|
+
header
|
|
190
|
+
for header, normalized in zip(headers, normalized_headers)
|
|
191
|
+
if normalized_headers.count(normalized) > 1
|
|
192
|
+
}
|
|
193
|
+
)
|
|
180
194
|
if duplicates:
|
|
181
195
|
raise QueryExecutionError(f"文件表头存在重复字段: {duplicates}")
|
|
182
|
-
|
|
196
|
+
normalized_table_columns = {normalize_identifier(column) for column in table_columns}
|
|
197
|
+
unknown = sorted(
|
|
198
|
+
header
|
|
199
|
+
for header, normalized in zip(headers, normalized_headers)
|
|
200
|
+
if normalized not in normalized_table_columns
|
|
201
|
+
)
|
|
183
202
|
if unknown:
|
|
184
203
|
raise QueryExecutionError(f"文件表头包含目标表不存在的字段: {unknown}")
|
|
185
204
|
|
|
186
205
|
|
|
206
|
+
def _exact_identifier(value: str) -> str:
|
|
207
|
+
return value
|
|
208
|
+
|
|
209
|
+
|
|
187
210
|
def _execute_insert(conn: Any, engine: str, query: object, rows: List[Tuple[object, ...]]) -> None:
|
|
188
211
|
if engine == "postgres" and hasattr(conn, "transaction"):
|
|
189
212
|
with conn.transaction():
|
|
@@ -108,6 +108,14 @@ class _AdapterStub:
|
|
|
108
108
|
def build_insert_query(self, namespace: str, table_name: str, columns):
|
|
109
109
|
return f"insert {namespace}.{table_name} ({','.join(columns)})"
|
|
110
110
|
|
|
111
|
+
def normalize_identifier(self, value: str) -> str:
|
|
112
|
+
return value
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class _CaseInsensitiveAdapterStub(_AdapterStub):
|
|
116
|
+
def normalize_identifier(self, value: str) -> str:
|
|
117
|
+
return value.casefold()
|
|
118
|
+
|
|
111
119
|
|
|
112
120
|
class _RegistryStub:
|
|
113
121
|
def __init__(self, config: ConnectionConfig, adapter: object, conn: object) -> None:
|
|
@@ -160,6 +168,56 @@ class TableFileImporterTestCase(unittest.TestCase):
|
|
|
160
168
|
self.assertEqual(2, records[0]["row_count"])
|
|
161
169
|
self.assertEqual(".csv", records[0]["extra"]["file_extension"])
|
|
162
170
|
|
|
171
|
+
def test_mysql_import_csv_accepts_header_case_by_adapter_semantics(self) -> None:
|
|
172
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
173
|
+
csv_path = _write_csv(Path(temp_dir) / "users.csv", [["ID", "NAME"], ["1", "Alice"]])
|
|
174
|
+
conn = _ConnectionStub()
|
|
175
|
+
importer = _build_importer(
|
|
176
|
+
Path(temp_dir) / "audit.jsonl",
|
|
177
|
+
conn,
|
|
178
|
+
adapter=_CaseInsensitiveAdapterStub(),
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
result = importer.import_table_file(
|
|
182
|
+
"crm_mysql_prod_main_rw",
|
|
183
|
+
"users",
|
|
184
|
+
str(csv_path),
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
self.assertEqual(1, result["inserted_row_count"])
|
|
188
|
+
self.assertEqual(
|
|
189
|
+
[("insert crm.users (ID,NAME)", [("1", "Alice")])],
|
|
190
|
+
conn.cursor_stub.executed_many,
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
def test_mysql_import_csv_rejects_case_only_duplicate_header(self) -> None:
|
|
194
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
195
|
+
csv_path = _write_csv(Path(temp_dir) / "users.csv", [["id", "ID"], ["1", "2"]])
|
|
196
|
+
conn = _ConnectionStub()
|
|
197
|
+
importer = _build_importer(
|
|
198
|
+
Path(temp_dir) / "audit.jsonl",
|
|
199
|
+
conn,
|
|
200
|
+
adapter=_CaseInsensitiveAdapterStub(),
|
|
201
|
+
)
|
|
202
|
+
|
|
203
|
+
with self.assertRaises(QueryExecutionError) as caught:
|
|
204
|
+
importer.import_table_file("crm_mysql_prod_main_rw", "users", str(csv_path))
|
|
205
|
+
|
|
206
|
+
self.assertIn("重复字段", str(caught.exception))
|
|
207
|
+
self.assertEqual([], conn.cursor_stub.executed_many)
|
|
208
|
+
|
|
209
|
+
def test_postgres_import_csv_rejects_header_with_different_case(self) -> None:
|
|
210
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
211
|
+
csv_path = _write_csv(Path(temp_dir) / "users.csv", [["ID"], ["1"]])
|
|
212
|
+
conn = _ConnectionStub()
|
|
213
|
+
importer = _build_postgres_importer(Path(temp_dir) / "audit.jsonl", conn)
|
|
214
|
+
|
|
215
|
+
with self.assertRaises(QueryExecutionError) as caught:
|
|
216
|
+
importer.import_table_file("crm_postgres_prod_main_rw", "users", str(csv_path))
|
|
217
|
+
|
|
218
|
+
self.assertIn("不存在的字段", str(caught.exception))
|
|
219
|
+
self.assertEqual([], conn.cursor_stub.executed_many)
|
|
220
|
+
|
|
163
221
|
def test_hive_import_csv_uses_existing_import_tool_path(self) -> None:
|
|
164
222
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
165
223
|
csv_path = _write_csv(
|
|
@@ -183,6 +241,28 @@ class TableFileImporterTestCase(unittest.TestCase):
|
|
|
183
241
|
conn.cursor_stub.executed,
|
|
184
242
|
)
|
|
185
243
|
|
|
244
|
+
def test_hive_import_csv_accepts_header_case_by_adapter_semantics(self) -> None:
|
|
245
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
246
|
+
csv_path = _write_csv(Path(temp_dir) / "users.csv", [["NAME"], ["Alice"]])
|
|
247
|
+
conn = _HiveConnectionStub()
|
|
248
|
+
importer = _build_hive_importer(
|
|
249
|
+
Path(temp_dir) / "audit.jsonl",
|
|
250
|
+
conn,
|
|
251
|
+
adapter=_CaseInsensitiveAdapterStub(),
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
result = importer.import_table_file(
|
|
255
|
+
"warehouse_hive_prod_main_rw",
|
|
256
|
+
"users",
|
|
257
|
+
str(csv_path),
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
self.assertEqual(1, result["inserted_row_count"])
|
|
261
|
+
self.assertEqual(
|
|
262
|
+
[("insert analytics.users (NAME)", ("Alice",))],
|
|
263
|
+
conn.cursor_stub.executed,
|
|
264
|
+
)
|
|
265
|
+
|
|
186
266
|
def test_hive_import_csv_executes_each_row_without_result_set_error(self) -> None:
|
|
187
267
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
188
268
|
csv_path = _write_csv(
|
|
@@ -384,7 +464,11 @@ class TableFileImporterTestCase(unittest.TestCase):
|
|
|
384
464
|
self.assertEqual(0, conn.begin_calls)
|
|
385
465
|
|
|
386
466
|
|
|
387
|
-
def _build_importer(
|
|
467
|
+
def _build_importer(
|
|
468
|
+
log_path: Path,
|
|
469
|
+
conn: _ConnectionStub,
|
|
470
|
+
adapter: object | None = None,
|
|
471
|
+
) -> TableFileImporter:
|
|
388
472
|
config = ConnectionConfig(
|
|
389
473
|
connection_id="crm_mysql_prod_main_rw",
|
|
390
474
|
engine="mysql",
|
|
@@ -396,6 +480,25 @@ def _build_importer(log_path: Path, conn: _ConnectionStub) -> TableFileImporter:
|
|
|
396
480
|
enabled=True,
|
|
397
481
|
default_database="crm",
|
|
398
482
|
)
|
|
483
|
+
return TableFileImporter(
|
|
484
|
+
registry=_RegistryStub(config, adapter or _AdapterStub(), conn),
|
|
485
|
+
settings=ServerSettings(audit_log_path=log_path),
|
|
486
|
+
audit_logger=AuditLogger(log_path),
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _build_postgres_importer(log_path: Path, conn: object) -> TableFileImporter:
|
|
491
|
+
config = ConnectionConfig(
|
|
492
|
+
connection_id="crm_postgres_prod_main_rw",
|
|
493
|
+
engine="postgres",
|
|
494
|
+
label="CRM PostgreSQL",
|
|
495
|
+
env="prod",
|
|
496
|
+
tenant="main",
|
|
497
|
+
role="rw",
|
|
498
|
+
dsn_env="PG_CONN",
|
|
499
|
+
enabled=True,
|
|
500
|
+
default_schema="public",
|
|
501
|
+
)
|
|
399
502
|
return TableFileImporter(
|
|
400
503
|
registry=_RegistryStub(config, _AdapterStub(), conn),
|
|
401
504
|
settings=ServerSettings(audit_log_path=log_path),
|
|
@@ -403,7 +506,11 @@ def _build_importer(log_path: Path, conn: _ConnectionStub) -> TableFileImporter:
|
|
|
403
506
|
)
|
|
404
507
|
|
|
405
508
|
|
|
406
|
-
def _build_hive_importer(
|
|
509
|
+
def _build_hive_importer(
|
|
510
|
+
log_path: Path,
|
|
511
|
+
conn: object,
|
|
512
|
+
adapter: object | None = None,
|
|
513
|
+
) -> TableFileImporter:
|
|
407
514
|
config = ConnectionConfig(
|
|
408
515
|
connection_id="warehouse_hive_prod_main_rw",
|
|
409
516
|
engine="hive",
|
|
@@ -416,7 +523,7 @@ def _build_hive_importer(log_path: Path, conn: object) -> TableFileImporter:
|
|
|
416
523
|
default_database="analytics",
|
|
417
524
|
)
|
|
418
525
|
return TableFileImporter(
|
|
419
|
-
registry=_RegistryStub(config, _AdapterStub(), conn),
|
|
526
|
+
registry=_RegistryStub(config, adapter or _AdapterStub(), conn),
|
|
420
527
|
settings=ServerSettings(audit_log_path=log_path),
|
|
421
528
|
audit_logger=AuditLogger(log_path),
|
|
422
529
|
)
|
|
@@ -40,6 +40,32 @@ class _HiveConnectionStub:
|
|
|
40
40
|
return self.cursor_stub
|
|
41
41
|
|
|
42
42
|
|
|
43
|
+
class _MySQLCursorStub:
|
|
44
|
+
def __init__(self, result_sets) -> None:
|
|
45
|
+
self._result_sets = list(result_sets)
|
|
46
|
+
self.executed = []
|
|
47
|
+
|
|
48
|
+
def __enter__(self):
|
|
49
|
+
return self
|
|
50
|
+
|
|
51
|
+
def __exit__(self, exc_type, exc, tb) -> None:
|
|
52
|
+
return None
|
|
53
|
+
|
|
54
|
+
def execute(self, sql: str, params=None) -> None:
|
|
55
|
+
self.executed.append((sql, params))
|
|
56
|
+
|
|
57
|
+
def fetchall(self):
|
|
58
|
+
return self._result_sets.pop(0)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class _MySQLConnectionStub:
|
|
62
|
+
def __init__(self, result_sets) -> None:
|
|
63
|
+
self.cursor_stub = _MySQLCursorStub(result_sets)
|
|
64
|
+
|
|
65
|
+
def cursor(self) -> _MySQLCursorStub:
|
|
66
|
+
return self.cursor_stub
|
|
67
|
+
|
|
68
|
+
|
|
43
69
|
class ValidatorTestCase(unittest.TestCase):
|
|
44
70
|
def test_accepts_plain_select(self) -> None:
|
|
45
71
|
self.assertEqual("SELECT 1", validate_select_sql("SELECT 1;", "postgres"))
|
|
@@ -253,12 +279,135 @@ class ValidatorTestCase(unittest.TestCase):
|
|
|
253
279
|
)
|
|
254
280
|
self.assertEqual([], description["indexes"])
|
|
255
281
|
|
|
282
|
+
def test_hive_describe_table_reads_dict_keys_case_insensitively(self) -> None:
|
|
283
|
+
conn = _HiveConnectionStub(
|
|
284
|
+
[
|
|
285
|
+
{
|
|
286
|
+
"COMMENT": "customer identifier",
|
|
287
|
+
"DATA_TYPE": "int",
|
|
288
|
+
"COL_NAME": "id",
|
|
289
|
+
}
|
|
290
|
+
]
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
description = HiveAdapter().describe_table(conn, "analytics", "orders")
|
|
294
|
+
|
|
295
|
+
self.assertEqual(
|
|
296
|
+
[
|
|
297
|
+
{
|
|
298
|
+
"column_name": "id",
|
|
299
|
+
"data_type": "int",
|
|
300
|
+
"udt_name": None,
|
|
301
|
+
"nullable": True,
|
|
302
|
+
"default": None,
|
|
303
|
+
"primary_key": False,
|
|
304
|
+
"extra": "customer identifier",
|
|
305
|
+
"partition_key": False,
|
|
306
|
+
}
|
|
307
|
+
],
|
|
308
|
+
description["columns"],
|
|
309
|
+
)
|
|
310
|
+
|
|
256
311
|
def test_mysql_explain_plan_is_parsed_to_structured_json(self) -> None:
|
|
257
312
|
plan = MySQLAdapter().extract_plan(
|
|
258
313
|
[{"EXPLAIN": json.dumps({"query_block": {"select_id": 1}})}]
|
|
259
314
|
)
|
|
260
315
|
self.assertEqual({"query_block": {"select_id": 1}}, plan)
|
|
261
316
|
|
|
317
|
+
def test_mysql_explain_plan_reads_label_case_insensitively(self) -> None:
|
|
318
|
+
plan = MySQLAdapter().extract_plan(
|
|
319
|
+
[{"explain": json.dumps({"query_block": {"select_id": 1}})}]
|
|
320
|
+
)
|
|
321
|
+
|
|
322
|
+
self.assertEqual({"query_block": {"select_id": 1}}, plan)
|
|
323
|
+
|
|
324
|
+
def test_mysql_list_databases_reads_row_keys_case_insensitively(self) -> None:
|
|
325
|
+
conn = _MySQLConnectionStub([[{"DATABASE_NAME": "crm"}]])
|
|
326
|
+
|
|
327
|
+
databases = MySQLAdapter().list_databases(conn)
|
|
328
|
+
|
|
329
|
+
self.assertEqual(["crm"], databases)
|
|
330
|
+
|
|
331
|
+
def test_mysql_list_tables_returns_normalized_rows(self) -> None:
|
|
332
|
+
conn = _MySQLConnectionStub(
|
|
333
|
+
[
|
|
334
|
+
[
|
|
335
|
+
{
|
|
336
|
+
"DATABASE_NAME": "crm",
|
|
337
|
+
"TABLE_NAME": "orders",
|
|
338
|
+
"TABLE_TYPE": "BASE TABLE",
|
|
339
|
+
}
|
|
340
|
+
]
|
|
341
|
+
]
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
tables = MySQLAdapter().list_tables(conn, "crm")
|
|
345
|
+
|
|
346
|
+
self.assertEqual(
|
|
347
|
+
[
|
|
348
|
+
{
|
|
349
|
+
"database_name": "crm",
|
|
350
|
+
"table_name": "orders",
|
|
351
|
+
"table_type": "BASE TABLE",
|
|
352
|
+
}
|
|
353
|
+
],
|
|
354
|
+
tables,
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
def test_mysql_describe_table_reads_metadata_keys_case_insensitively(self) -> None:
|
|
358
|
+
conn = _MySQLConnectionStub(
|
|
359
|
+
[
|
|
360
|
+
[
|
|
361
|
+
{
|
|
362
|
+
"COLUMN_NAME": "id",
|
|
363
|
+
"COLUMN_TYPE": "bigint",
|
|
364
|
+
"IS_NULLABLE": "NO",
|
|
365
|
+
"COLUMN_DEFAULT": None,
|
|
366
|
+
"EXTRA": "auto_increment",
|
|
367
|
+
"COLUMN_KEY": "PRI",
|
|
368
|
+
"ORDINAL_POSITION": 1,
|
|
369
|
+
}
|
|
370
|
+
],
|
|
371
|
+
[
|
|
372
|
+
{
|
|
373
|
+
"INDEX_NAME": "PRIMARY",
|
|
374
|
+
"NON_UNIQUE": 0,
|
|
375
|
+
"SEQ_IN_INDEX": 1,
|
|
376
|
+
"COLUMN_NAME": "id",
|
|
377
|
+
}
|
|
378
|
+
],
|
|
379
|
+
]
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
description = MySQLAdapter().describe_table(conn, "crm", "orders")
|
|
383
|
+
|
|
384
|
+
self.assertEqual(
|
|
385
|
+
[
|
|
386
|
+
{
|
|
387
|
+
"column_name": "id",
|
|
388
|
+
"data_type": "bigint",
|
|
389
|
+
"udt_name": None,
|
|
390
|
+
"nullable": False,
|
|
391
|
+
"default": None,
|
|
392
|
+
"primary_key": True,
|
|
393
|
+
"extra": "auto_increment",
|
|
394
|
+
}
|
|
395
|
+
],
|
|
396
|
+
description["columns"],
|
|
397
|
+
)
|
|
398
|
+
self.assertEqual(
|
|
399
|
+
[
|
|
400
|
+
{
|
|
401
|
+
"index_name": "PRIMARY",
|
|
402
|
+
"columns": ["id"],
|
|
403
|
+
"unique": True,
|
|
404
|
+
"primary_key": True,
|
|
405
|
+
"definition": None,
|
|
406
|
+
}
|
|
407
|
+
],
|
|
408
|
+
description["indexes"],
|
|
409
|
+
)
|
|
410
|
+
|
|
262
411
|
def test_postgres_build_insert_query_quotes_identifiers(self) -> None:
|
|
263
412
|
query = PostgresAdapter().build_insert_query(
|
|
264
413
|
"public", "orders", ["order", "status"]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|