kontra 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kontra/__init__.py +1871 -0
- kontra/api/__init__.py +22 -0
- kontra/api/compare.py +340 -0
- kontra/api/decorators.py +153 -0
- kontra/api/results.py +2121 -0
- kontra/api/rules.py +681 -0
- kontra/cli/__init__.py +0 -0
- kontra/cli/commands/__init__.py +1 -0
- kontra/cli/commands/config.py +153 -0
- kontra/cli/commands/diff.py +450 -0
- kontra/cli/commands/history.py +196 -0
- kontra/cli/commands/profile.py +289 -0
- kontra/cli/commands/validate.py +468 -0
- kontra/cli/constants.py +6 -0
- kontra/cli/main.py +48 -0
- kontra/cli/renderers.py +304 -0
- kontra/cli/utils.py +28 -0
- kontra/config/__init__.py +34 -0
- kontra/config/loader.py +127 -0
- kontra/config/models.py +49 -0
- kontra/config/settings.py +797 -0
- kontra/connectors/__init__.py +0 -0
- kontra/connectors/db_utils.py +251 -0
- kontra/connectors/detection.py +323 -0
- kontra/connectors/handle.py +368 -0
- kontra/connectors/postgres.py +127 -0
- kontra/connectors/sqlserver.py +226 -0
- kontra/engine/__init__.py +0 -0
- kontra/engine/backends/duckdb_session.py +227 -0
- kontra/engine/backends/duckdb_utils.py +18 -0
- kontra/engine/backends/polars_backend.py +47 -0
- kontra/engine/engine.py +1205 -0
- kontra/engine/executors/__init__.py +15 -0
- kontra/engine/executors/base.py +50 -0
- kontra/engine/executors/database_base.py +528 -0
- kontra/engine/executors/duckdb_sql.py +607 -0
- kontra/engine/executors/postgres_sql.py +162 -0
- kontra/engine/executors/registry.py +69 -0
- kontra/engine/executors/sqlserver_sql.py +163 -0
- kontra/engine/materializers/__init__.py +14 -0
- kontra/engine/materializers/base.py +42 -0
- kontra/engine/materializers/duckdb.py +110 -0
- kontra/engine/materializers/factory.py +22 -0
- kontra/engine/materializers/polars_connector.py +131 -0
- kontra/engine/materializers/postgres.py +157 -0
- kontra/engine/materializers/registry.py +138 -0
- kontra/engine/materializers/sqlserver.py +160 -0
- kontra/engine/result.py +15 -0
- kontra/engine/sql_utils.py +611 -0
- kontra/engine/sql_validator.py +609 -0
- kontra/engine/stats.py +194 -0
- kontra/engine/types.py +138 -0
- kontra/errors.py +533 -0
- kontra/logging.py +85 -0
- kontra/preplan/__init__.py +5 -0
- kontra/preplan/planner.py +253 -0
- kontra/preplan/postgres.py +179 -0
- kontra/preplan/sqlserver.py +191 -0
- kontra/preplan/types.py +24 -0
- kontra/probes/__init__.py +20 -0
- kontra/probes/compare.py +400 -0
- kontra/probes/relationship.py +283 -0
- kontra/reporters/__init__.py +0 -0
- kontra/reporters/json_reporter.py +190 -0
- kontra/reporters/rich_reporter.py +11 -0
- kontra/rules/__init__.py +35 -0
- kontra/rules/base.py +186 -0
- kontra/rules/builtin/__init__.py +40 -0
- kontra/rules/builtin/allowed_values.py +156 -0
- kontra/rules/builtin/compare.py +188 -0
- kontra/rules/builtin/conditional_not_null.py +213 -0
- kontra/rules/builtin/conditional_range.py +310 -0
- kontra/rules/builtin/contains.py +138 -0
- kontra/rules/builtin/custom_sql_check.py +182 -0
- kontra/rules/builtin/disallowed_values.py +140 -0
- kontra/rules/builtin/dtype.py +203 -0
- kontra/rules/builtin/ends_with.py +129 -0
- kontra/rules/builtin/freshness.py +240 -0
- kontra/rules/builtin/length.py +193 -0
- kontra/rules/builtin/max_rows.py +35 -0
- kontra/rules/builtin/min_rows.py +46 -0
- kontra/rules/builtin/not_null.py +121 -0
- kontra/rules/builtin/range.py +222 -0
- kontra/rules/builtin/regex.py +143 -0
- kontra/rules/builtin/starts_with.py +129 -0
- kontra/rules/builtin/unique.py +124 -0
- kontra/rules/condition_parser.py +203 -0
- kontra/rules/execution_plan.py +455 -0
- kontra/rules/factory.py +103 -0
- kontra/rules/predicates.py +25 -0
- kontra/rules/registry.py +24 -0
- kontra/rules/static_predicates.py +120 -0
- kontra/scout/__init__.py +9 -0
- kontra/scout/backends/__init__.py +17 -0
- kontra/scout/backends/base.py +111 -0
- kontra/scout/backends/duckdb_backend.py +359 -0
- kontra/scout/backends/postgres_backend.py +519 -0
- kontra/scout/backends/sqlserver_backend.py +577 -0
- kontra/scout/dtype_mapping.py +150 -0
- kontra/scout/patterns.py +69 -0
- kontra/scout/profiler.py +801 -0
- kontra/scout/reporters/__init__.py +39 -0
- kontra/scout/reporters/json_reporter.py +165 -0
- kontra/scout/reporters/markdown_reporter.py +152 -0
- kontra/scout/reporters/rich_reporter.py +144 -0
- kontra/scout/store.py +208 -0
- kontra/scout/suggest.py +200 -0
- kontra/scout/types.py +652 -0
- kontra/state/__init__.py +29 -0
- kontra/state/backends/__init__.py +79 -0
- kontra/state/backends/base.py +348 -0
- kontra/state/backends/local.py +480 -0
- kontra/state/backends/postgres.py +1010 -0
- kontra/state/backends/s3.py +543 -0
- kontra/state/backends/sqlserver.py +969 -0
- kontra/state/fingerprint.py +166 -0
- kontra/state/types.py +1061 -0
- kontra/version.py +1 -0
- kontra-0.5.2.dist-info/METADATA +122 -0
- kontra-0.5.2.dist-info/RECORD +124 -0
- kontra-0.5.2.dist-info/WHEEL +5 -0
- kontra-0.5.2.dist-info/entry_points.txt +2 -0
- kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
- kontra-0.5.2.dist-info/top_level.txt +1 -0
|
File without changes
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
# src/kontra/connectors/db_utils.py
|
|
2
|
+
"""
|
|
3
|
+
Shared utilities for database connectors.
|
|
4
|
+
|
|
5
|
+
This module provides common functionality for resolving connection parameters
|
|
6
|
+
from URIs and environment variables, reducing duplication between
|
|
7
|
+
postgres.py and sqlserver.py.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
|
|
14
|
+
from urllib.parse import urlparse, unquote
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class DbConnectionConfig:
    """
    Configuration for resolving database connection parameters.

    One instance describes a dialect for ``resolve_connection_params``. It
    bundles three things:

    * the fallback connection values;
    * the names of the environment variables consulted during resolution;
    * strings used to build "required field" error messages.

    No field has a default, so every field must be supplied.
    """

    # Defaults
    default_host: str  # starting host before env / URL / URI overrides
    default_port: int  # starting port before env / URL / URI overrides
    default_user: str  # starting user before env / URL / URI overrides
    default_schema: str  # used when the URI path names no schema

    # Environment variable names
    env_host: str
    env_port: str  # a value that does not parse as int is ignored
    env_user: str
    env_password: str
    env_database: str  # also named in the "required" error messages
    env_url: Optional[str]  # e.g., DATABASE_URL, SQLSERVER_URL; falsy disables that layer

    # Error message context
    db_name: str  # e.g., "PostgreSQL", "SQL Server"
    uri_example: str  # e.g., "postgres://user:pass@host:5432/database/schema.table"
    # NOTE(review): env_example is not read anywhere in this module;
    # presumably consumed by callers — confirm.
    env_example: str  # e.g., "PGDATABASE"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class ResolvedConnectionParams:
    """
    Generic resolved connection parameters.

    Produced by ``resolve_connection_params`` after all resolution layers
    have been applied. Dialect-specific connectors convert this to their
    own dataclass.
    """

    host: str
    port: int
    user: str
    password: Optional[str]  # None when no layer supplied a password
    # Optional in the type, but resolve_connection_params raises ValueError
    # instead of returning a missing database or table.
    database: Optional[str]
    schema: str
    table: Optional[str]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def resolve_connection_params(
    uri: str,
    config: DbConnectionConfig,
) -> ResolvedConnectionParams:
    """
    Resolve database connection parameters from URI + environment.

    Values start at the dialect defaults in ``config``. They are then
    overridden layer by layer, with later layers winning:

        1. Standard environment variables (PGXXX, MSSQL_XXX, etc.)
        2. The URL environment variable named by ``config.env_url``
           (DATABASE_URL, SQLSERVER_URL), when one is configured
        3. Explicit values in ``uri`` (highest priority)

    Finally the database, schema and table are read from the URI path.

    Args:
        uri: The connection URI
        config: Dialect-specific configuration

    Returns:
        ResolvedConnectionParams with all values resolved

    Raises:
        ValueError: If required parameters (database, table) cannot be
            resolved, or if a parsed URL carries a non-integer port.
    """
    parsed_uri = urlparse(uri)

    # Layer 1 starts from the dialect defaults; there is no password or
    # database until some layer provides one.
    host, port, user, password, database = _apply_env_vars(
        config.default_host,
        config.default_port,
        config.default_user,
        None,
        None,
        config,
    )

    # Layer 2 is optional per dialect.
    if config.env_url:
        host, port, user, password, database = _apply_url_env_var(
            host, port, user, password, database, config.env_url
        )

    # Layer 3: whatever the explicit URI spells out wins.
    host, port, user, password = _apply_uri_connection(
        host, port, user, password, parsed_uri
    )

    database, schema, table = _parse_uri_path(
        parsed_uri.path, database, config.default_schema
    )
    _validate_required_fields(database, table, config)

    return ResolvedConnectionParams(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database,
        schema=schema,
        table=table,
    )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _apply_env_vars(
    host: str,
    port: int,
    user: str,
    password: Optional[str],
    database: Optional[str],
    config: DbConnectionConfig,
) -> Tuple[str, int, str, Optional[str], Optional[str]]:
    """
    Overlay the dialect's standard environment variables (Layer 1).

    Each variable named in ``config`` replaces the matching incoming value
    only when it is set to a non-empty string. A port value that is not a
    valid integer is ignored and the incoming port is kept.

    Returns:
        The updated ``(host, port, user, password, database)`` tuple.
    """
    env_host = os.getenv(config.env_host)
    if env_host:
        host = env_host

    env_port = os.getenv(config.env_port)
    if env_port:
        try:
            port = int(env_port)
        except ValueError:
            # Best-effort: a malformed port keeps the previous value.
            pass

    env_user = os.getenv(config.env_user)
    if env_user:
        user = env_user

    env_password = os.getenv(config.env_password)
    if env_password:
        password = env_password

    env_database = os.getenv(config.env_database)
    if env_database:
        database = env_database

    return host, port, user, password, database
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _apply_url_env_var(
    host: str,
    port: int,
    user: str,
    password: Optional[str],
    database: Optional[str],
    env_url_name: str,
) -> Tuple[str, int, str, Optional[str], Optional[str]]:
    """
    Overlay values from a URL-style environment variable such as DATABASE_URL (Layer 2).

    When the variable named ``env_url_name`` is unset or empty, the inputs
    are returned unchanged. Otherwise it is parsed with ``urlparse``:

    * each non-empty component replaces the matching input;
    * username and password are percent-decoded;
    * the first path segment becomes the database.

    Returns:
        The updated ``(host, port, user, password, database)`` tuple.
    """
    raw_url = os.getenv(env_url_name)
    if raw_url:
        pieces = urlparse(raw_url)
        host = pieces.hostname or host
        port = pieces.port or port
        if pieces.username:
            user = unquote(pieces.username)
        if pieces.password:
            password = unquote(pieces.password)
        if pieces.path not in ("", "/"):
            database = pieces.path.strip("/").split("/")[0]

    return host, port, user, password, database
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _apply_uri_connection(
    host: str,
    port: int,
    user: str,
    password: Optional[str],
    parsed,
) -> Tuple[str, int, str, Optional[str]]:
    """
    Overlay connection values written explicitly in the URI (Layer 3).

    ``parsed`` is the ``urllib.parse.urlparse`` result for the URI. Every
    component present (non-empty) in it wins over the incoming value.
    Username and password are percent-decoded. Reading ``parsed.port``
    raises ValueError when the URI's port is not a valid integer.

    Returns:
        The updated ``(host, port, user, password)`` tuple.
    """
    new_host = parsed.hostname or host
    new_port = parsed.port or port
    new_user = unquote(parsed.username) if parsed.username else user
    new_password = unquote(parsed.password) if parsed.password else password
    return new_host, new_port, new_user, new_password
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _parse_uri_path(
    path: str,
    current_database: Optional[str],
    default_schema: str,
) -> Tuple[Optional[str], str, Optional[str]]:
    """
    Parse database, schema, and table from URI path.

    Accepted forms:
        /database/schema.table  -> (database, schema, table)
        /database/table         -> (database, default_schema, table)
        /database               -> (database, default_schema, None)
        /schema.table           -> (current_database, schema, table)
        (empty path)            -> (current_database, default_schema, None)

    The single-segment ``/schema.table`` form keeps the database resolved by
    an earlier layer (environment variable or URL env var). It is the short
    form recommended by the "table name is required" error, e.g.
    ``postgres:///public.users`` with PGDATABASE set. A single segment without
    a dot is still treated as the database name.

    An empty schema component (``/db/.table``) falls back to
    ``default_schema``. Path segments are not percent-decoded, and segments
    after the second are ignored.

    Args:
        path: The path component of the parsed URI.
        current_database: Database resolved by earlier layers, if any.
        default_schema: Schema to use when the path names none.

    Returns:
        Tuple of (database, schema, table); table is None when the path
        does not name one.
    """
    database = current_database
    schema = default_schema
    table: Optional[str] = None

    path_parts = [p for p in path.strip("/").split("/") if p]
    if not path_parts:
        return database, schema, table

    if len(path_parts) == 1 and "." in path_parts[0]:
        # "/schema.table": keep the database from the earlier layers.
        schema_table = path_parts[0]
    else:
        database = path_parts[0]
        if len(path_parts) < 2:
            return database, schema, table
        schema_table = path_parts[1]

    if "." in schema_table:
        parsed_schema, table = schema_table.split(".", 1)
        # "/db/.table" would otherwise yield an empty schema identifier.
        schema = parsed_schema or default_schema
    else:
        table = schema_table

    return database, schema, table
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def _validate_required_fields(
    database: Optional[str],
    table: Optional[str],
    config: DbConnectionConfig,
) -> None:
    """
    Ensure both a database name and a table name were resolved.

    The database is checked first. When both are missing, only the
    database error is raised. Error text is built from ``config.db_name``,
    ``config.env_database``, ``config.uri_example`` and
    ``config.default_schema``.

    Raises:
        ValueError: If ``database`` or ``table`` is empty or None.
    """
    if not database:
        missing_db_message = (
            f"{config.db_name} database name is required.\n\n"
            f"Set {config.env_database} environment variable or use full URI:\n"
            f" {config.uri_example}"
        )
        raise ValueError(missing_db_message)

    if table:
        return

    # Scheme prefix of the example URI, e.g. "postgres:".
    scheme = config.uri_example.split("/")[0]
    raise ValueError(
        f"{config.db_name} table name is required.\n\n"
        "Specify schema.table in URI:\n"
        f" {config.uri_example}\n"
        f" {scheme}///{config.default_schema}.users "
        f"(with {config.env_database} set)"
    )
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
# src/kontra/connectors/detection.py
|
|
2
|
+
"""
|
|
3
|
+
Connection type detection for BYOC (Bring Your Own Connection) pattern.
|
|
4
|
+
|
|
5
|
+
Detects database dialect from connection objects so Kontra can use the
|
|
6
|
+
correct SQL executor and materializer.
|
|
7
|
+
|
|
8
|
+
Supported connection types:
|
|
9
|
+
- psycopg / psycopg2 / psycopg3 → PostgreSQL
|
|
10
|
+
- pg8000 → PostgreSQL
|
|
11
|
+
- pyodbc → SQL Server (or detected via getinfo)
|
|
12
|
+
- pymssql → SQL Server
|
|
13
|
+
- SQLAlchemy engine/connection → detected from dialect
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Any, List, Optional, Tuple
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Dialect constants
# Canonical dialect identifiers returned by detect_connection_dialect() and
# compared against in get_default_schema().
POSTGRESQL = "postgresql"
SQLSERVER = "sqlserver"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def detect_connection_dialect(conn: Any) -> str:
    """
    Work out which SQL dialect a user-supplied connection object speaks.

    Detection is based on the module that defines the connection's class.
    Checks run in this order:

    1. psycopg / pg8000, or any module whose name contains "postgres"
       -> PostgreSQL
    2. pymssql -> SQL Server
    3. pyodbc -> delegated to ``_detect_pyodbc_dialect``, since ODBC can
       front any database
    4. SQLAlchemy -> delegated to ``_detect_sqlalchemy_dialect``

    Args:
        conn: A database connection object (psycopg, pyodbc, SQLAlchemy, etc.)

    Returns:
        Dialect string: "postgresql" or "sqlserver"

    Raises:
        ValueError: If the connection type is not recognized (or the
            SQLAlchemy helper cannot classify it).

    Example:
        A ``psycopg2.connect(...)`` connection yields ``'postgresql'``; a
        ``pyodbc.connect("DRIVER={ODBC Driver 17};SERVER=...")`` connection
        yields ``'sqlserver'``.
    """
    conn_type = type(conn)
    module = conn_type.__module__

    if module.startswith(("psycopg", "pg8000")) or "postgres" in module.lower():
        return POSTGRESQL

    if module.startswith("pymssql"):
        return SQLSERVER

    if module == "pyodbc":
        return _detect_pyodbc_dialect(conn)

    if module.startswith("sqlalchemy"):
        return _detect_sqlalchemy_dialect(conn)

    raise ValueError(
        f"Unknown connection type: {module}.{conn_type.__name__}\n\n"
        "Supported connection types:\n"
        " - psycopg / psycopg2 / psycopg3 (PostgreSQL)\n"
        " - pg8000 (PostgreSQL)\n"
        " - pyodbc (SQL Server, PostgreSQL via ODBC)\n"
        " - pymssql (SQL Server)\n"
        " - SQLAlchemy engine or connection"
    )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _detect_pyodbc_dialect(conn: Any) -> str:
    """
    Classify a pyodbc connection by asking the driver for its DBMS name.

    pyodbc is database-agnostic, so the class alone says nothing about the
    backend. The connection's ``SQL_DBMS_NAME`` is inspected instead:

    * names mentioning "sql server" or "microsoft" -> SQL Server;
    * otherwise names mentioning "postgres" -> PostgreSQL;
    * anything else -> SQL Server.

    Any failure while importing pyodbc or querying the name also falls
    back to SQL Server, the most common pyodbc target.
    """
    try:
        import pyodbc
        dbms = conn.getinfo(pyodbc.SQL_DBMS_NAME).lower()
        looks_like_mssql = "sql server" in dbms or "microsoft" in dbms
        if not looks_like_mssql and "postgres" in dbms:
            return POSTGRESQL
    except Exception:
        # getinfo unavailable or failed: assume SQL Server.
        return SQLSERVER
    return SQLSERVER
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _detect_sqlalchemy_dialect(conn: Any) -> str:
    """
    Classify a SQLAlchemy object by the name of its dialect.

    The dialect is taken from ``conn.dialect`` when present (e.g. an
    Engine), otherwise from ``conn.engine.dialect``.

    Raises:
        ValueError: If no dialect name can be found, or if the dialect is
            neither PostgreSQL nor SQL Server (mssql).
    """
    if hasattr(conn, "dialect"):
        name = conn.dialect.name
    elif hasattr(conn, "engine") and hasattr(conn.engine, "dialect"):
        name = conn.engine.dialect.name
    else:
        name = None

    if not name:
        raise ValueError(
            "Could not detect dialect from SQLAlchemy connection.\n"
            "Make sure you're passing an Engine or Connection object."
        )

    lowered = name.lower()
    if "postgres" in lowered:
        return POSTGRESQL
    if any(tag in lowered for tag in ("mssql", "sqlserver")):
        return SQLSERVER

    raise ValueError(
        f"Unsupported SQLAlchemy dialect: {name}\n\n"
        "Supported dialects: postgresql, mssql"
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def is_cursor_object(obj: Any) -> bool:
    """
    Heuristically decide whether ``obj`` is a database cursor rather than a connection.

    This guards against the common mistake of passing ``conn.cursor()``
    where a connection is expected. An object counts as a cursor when
    either of these holds:

    * its class name contains "cursor" (case-insensitive);
    * it has callable ``execute`` and ``fetchone`` attributes but no
      callable ``cursor`` attribute. Cursors come from
      ``connection.cursor()`` and do not create cursors themselves.

    Args:
        obj: Any Python object

    Returns:
        True if the object looks like a cursor; False otherwise (including None).
    """
    if obj is None:
        return False

    if "cursor" in type(obj).__name__.lower():
        return True

    def has_method(attr_name: str) -> bool:
        return callable(getattr(obj, attr_name, None))

    return has_method("execute") and has_method("fetchone") and not has_method("cursor")
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def is_database_connection(obj: Any) -> bool:
    """
    Heuristically decide whether ``obj`` is a database connection.

    None and anything that ``is_cursor_object`` flags are rejected first.
    Otherwise the object is accepted when any of these holds:

    * its class lives in a known driver package (psycopg, pg8000, pyodbc,
      pymssql, sqlalchemy), matched by module-name prefix;
    * it has a callable ``cursor`` attribute;
    * it has both ``dialect`` and ``connect`` attributes (Engine-like).

    Args:
        obj: Any Python object

    Returns:
        True if the object appears to be a database connection
    """
    if obj is None or is_cursor_object(obj):
        return False

    # Prefix match, so "psycopg" already covers "psycopg2".
    known_prefixes = (
        "psycopg",
        "psycopg2",
        "pg8000",
        "pyodbc",
        "pymssql",
        "sqlalchemy",
    )
    if type(obj).__module__.startswith(known_prefixes):
        return True

    # DB-API style: connections expose a cursor() factory.
    if callable(getattr(obj, "cursor", None)):
        return True

    # SQLAlchemy engine
    return hasattr(obj, "dialect") and hasattr(obj, "connect")
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def is_sqlalchemy_object(obj: Any) -> bool:
    """
    Report whether ``obj``'s class is defined in a ``sqlalchemy`` module.

    The check is a plain module-name prefix test. It is intended to
    recognize SQLAlchemy Engine and Connection objects, but any object from
    a module starting with "sqlalchemy" matches.

    Args:
        obj: Any Python object

    Returns:
        True if the object's class module starts with "sqlalchemy"
    """
    defining_module: str = type(obj).__module__
    return defining_module.startswith("sqlalchemy")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def unwrap_sqlalchemy_connection(obj: Any) -> Any:
    """
    Return a DBAPI-level connection for a SQLAlchemy Engine or Connection.

    SQLAlchemy wrappers lack a ``.cursor()`` method, so the underlying
    DBAPI connection is dug out instead. Non-SQLAlchemy objects are
    returned unchanged. Strategies are tried in order:

    1. ``obj.raw_connection()``
    2. ``obj.connection``, further unwrapped via ``.dbapi_connection``
       when that attribute exists
    3. ``obj.connect()``, then its ``.connection`` when present

    Args:
        obj: A SQLAlchemy Engine or Connection object

    Returns:
        The raw DBAPI connection object

    Raises:
        ValueError: If none of the strategies applies.
    """
    if not is_sqlalchemy_object(obj):
        return obj

    # Engine (also present on SQLAlchemy 1.x engines).
    if hasattr(obj, "raw_connection"):
        # NOTE(review): presumably the caller owns/closes this connection — confirm.
        return obj.raw_connection()

    # Connection: .connection may itself wrap the real DBAPI connection.
    if hasattr(obj, "connection"):
        wrapped = obj.connection
        return getattr(wrapped, "dbapi_connection", wrapped)

    # Legacy fallback for objects exposing only connect().
    if hasattr(obj, "connect"):
        # NOTE(review): the SQLAlchemy connection opened here is neither kept
        # nor closed; confirm the DBAPI connection stays valid after it is
        # garbage-collected.
        opened = obj.connect()
        return getattr(opened, "connection", opened)

    raise ValueError(
        f"Cannot extract DBAPI connection from SQLAlchemy object: {type(obj).__name__}\n"
        "Try passing engine.raw_connection() instead."
    )
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def parse_table_reference(table: str) -> Tuple[Optional[str], Optional[str], str]:
    """
    Parse a table reference into (database, schema, table) components.

    Formats:
        - "table"                 → (None, None, "table")
        - "schema.table"          → (None, "schema", "table")
        - "database.schema.table" → ("database", "schema", "table")
        - "database..table"       → ("database", None, "table")

    The last form is the T-SQL shorthand for "default schema". In general,
    empty database/schema components are returned as None (never ``""``),
    so callers can fall back to their defaults. The reference is split on
    every ".", so quoted identifiers that contain dots are not supported.

    Args:
        table: Table reference string

    Returns:
        Tuple of (database, schema, table_name)

    Raises:
        ValueError: If the table reference has too many parts or an empty
            table name (e.g. "" or "schema.").
    """
    parts = table.split(".")

    if len(parts) > 3 or not parts[-1]:
        raise ValueError(
            f"Invalid table reference: {table}\n\n"
            "Expected format: table, schema.table, or database.schema.table"
        )

    table_name = parts[-1]
    # Empty qualifiers mean "not specified" rather than an empty identifier.
    qualifiers: List[Optional[str]] = [part or None for part in parts[:-1]]
    # Left-pad so that missing leading qualifiers become None.
    database, schema = [None] * (2 - len(qualifiers)) + qualifiers
    return database, schema, table_name
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def get_default_schema(dialect: str) -> str:
    """
    Return the schema assumed when a table reference omits one.

    SQL Server uses ``dbo``. PostgreSQL, and any dialect not recognized
    here, use ``public``.
    """
    if dialect == SQLSERVER:
        return "dbo"
    # POSTGRESQL and unknown dialects share the same fallback.
    return "public"
|