sqlspec 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlspec/__init__.py +104 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +312 -0
- sqlspec/_typing.py +784 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +880 -0
- sqlspec/adapters/adbc/config.py +436 -0
- sqlspec/adapters/adbc/data_dictionary.py +537 -0
- sqlspec/adapters/adbc/driver.py +841 -0
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +153 -0
- sqlspec/adapters/aiosqlite/__init__.py +29 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +536 -0
- sqlspec/adapters/aiosqlite/config.py +310 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +260 -0
- sqlspec/adapters/aiosqlite/driver.py +463 -0
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +500 -0
- sqlspec/adapters/asyncmy/__init__.py +25 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +503 -0
- sqlspec/adapters/asyncmy/config.py +246 -0
- sqlspec/adapters/asyncmy/data_dictionary.py +241 -0
- sqlspec/adapters/asyncmy/driver.py +632 -0
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +23 -0
- sqlspec/adapters/asyncpg/_type_handlers.py +76 -0
- sqlspec/adapters/asyncpg/_types.py +23 -0
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +460 -0
- sqlspec/adapters/asyncpg/config.py +464 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +321 -0
- sqlspec/adapters/asyncpg/driver.py +720 -0
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +585 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/data_dictionary.py +256 -0
- sqlspec/adapters/bigquery/driver.py +1073 -0
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +125 -0
- sqlspec/adapters/duckdb/__init__.py +24 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +563 -0
- sqlspec/adapters/duckdb/config.py +396 -0
- sqlspec/adapters/duckdb/data_dictionary.py +264 -0
- sqlspec/adapters/duckdb/driver.py +604 -0
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +273 -0
- sqlspec/adapters/duckdb/type_converter.py +133 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +39 -0
- sqlspec/adapters/oracledb/_uuid_handlers.py +130 -0
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1632 -0
- sqlspec/adapters/oracledb/config.py +469 -0
- sqlspec/adapters/oracledb/data_dictionary.py +717 -0
- sqlspec/adapters/oracledb/driver.py +1493 -0
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +765 -0
- sqlspec/adapters/oracledb/migrations.py +532 -0
- sqlspec/adapters/oracledb/type_converter.py +207 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +12 -0
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +483 -0
- sqlspec/adapters/psqlpy/config.py +271 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +179 -0
- sqlspec/adapters/psqlpy/driver.py +892 -0
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +102 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_type_handlers.py +90 -0
- sqlspec/adapters/psycopg/_types.py +18 -0
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +962 -0
- sqlspec/adapters/psycopg/config.py +487 -0
- sqlspec/adapters/psycopg/data_dictionary.py +630 -0
- sqlspec/adapters/psycopg/driver.py +1336 -0
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/spanner/__init__.py +38 -0
- sqlspec/adapters/spanner/_type_handlers.py +186 -0
- sqlspec/adapters/spanner/_types.py +12 -0
- sqlspec/adapters/spanner/adk/__init__.py +5 -0
- sqlspec/adapters/spanner/adk/store.py +435 -0
- sqlspec/adapters/spanner/config.py +241 -0
- sqlspec/adapters/spanner/data_dictionary.py +95 -0
- sqlspec/adapters/spanner/dialect/__init__.py +6 -0
- sqlspec/adapters/spanner/dialect/_spangres.py +52 -0
- sqlspec/adapters/spanner/dialect/_spanner.py +123 -0
- sqlspec/adapters/spanner/driver.py +366 -0
- sqlspec/adapters/spanner/litestar/__init__.py +5 -0
- sqlspec/adapters/spanner/litestar/store.py +266 -0
- sqlspec/adapters/spanner/type_converter.py +46 -0
- sqlspec/adapters/sqlite/__init__.py +18 -0
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +582 -0
- sqlspec/adapters/sqlite/config.py +221 -0
- sqlspec/adapters/sqlite/data_dictionary.py +256 -0
- sqlspec/adapters/sqlite/driver.py +527 -0
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +811 -0
- sqlspec/builder/__init__.py +146 -0
- sqlspec/builder/_base.py +900 -0
- sqlspec/builder/_column.py +517 -0
- sqlspec/builder/_ddl.py +1642 -0
- sqlspec/builder/_delete.py +84 -0
- sqlspec/builder/_dml.py +381 -0
- sqlspec/builder/_expression_wrappers.py +46 -0
- sqlspec/builder/_factory.py +1537 -0
- sqlspec/builder/_insert.py +315 -0
- sqlspec/builder/_join.py +375 -0
- sqlspec/builder/_merge.py +848 -0
- sqlspec/builder/_parsing_utils.py +297 -0
- sqlspec/builder/_select.py +1615 -0
- sqlspec/builder/_update.py +161 -0
- sqlspec/builder/_vector_expressions.py +259 -0
- sqlspec/cli.py +764 -0
- sqlspec/config.py +1540 -0
- sqlspec/core/__init__.py +305 -0
- sqlspec/core/cache.py +785 -0
- sqlspec/core/compiler.py +603 -0
- sqlspec/core/filters.py +872 -0
- sqlspec/core/hashing.py +274 -0
- sqlspec/core/metrics.py +83 -0
- sqlspec/core/parameters/__init__.py +64 -0
- sqlspec/core/parameters/_alignment.py +266 -0
- sqlspec/core/parameters/_converter.py +413 -0
- sqlspec/core/parameters/_processor.py +341 -0
- sqlspec/core/parameters/_registry.py +201 -0
- sqlspec/core/parameters/_transformers.py +226 -0
- sqlspec/core/parameters/_types.py +430 -0
- sqlspec/core/parameters/_validator.py +123 -0
- sqlspec/core/pipeline.py +187 -0
- sqlspec/core/result.py +1124 -0
- sqlspec/core/splitter.py +940 -0
- sqlspec/core/stack.py +163 -0
- sqlspec/core/statement.py +835 -0
- sqlspec/core/type_conversion.py +235 -0
- sqlspec/driver/__init__.py +36 -0
- sqlspec/driver/_async.py +1027 -0
- sqlspec/driver/_common.py +1236 -0
- sqlspec/driver/_sync.py +1025 -0
- sqlspec/driver/mixins/__init__.py +7 -0
- sqlspec/driver/mixins/_result_tools.py +61 -0
- sqlspec/driver/mixins/_sql_translator.py +122 -0
- sqlspec/driver/mixins/_storage.py +311 -0
- sqlspec/exceptions.py +321 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +471 -0
- sqlspec/extensions/fastapi/__init__.py +19 -0
- sqlspec/extensions/fastapi/extension.py +341 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +72 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +402 -0
- sqlspec/extensions/litestar/__init__.py +23 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +92 -0
- sqlspec/extensions/litestar/config.py +90 -0
- sqlspec/extensions/litestar/handlers.py +316 -0
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +638 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/extensions/otel/__init__.py +58 -0
- sqlspec/extensions/prometheus/__init__.py +107 -0
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +26 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +257 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/loader.py +716 -0
- sqlspec/migrations/__init__.py +36 -0
- sqlspec/migrations/base.py +728 -0
- sqlspec/migrations/commands.py +1140 -0
- sqlspec/migrations/context.py +142 -0
- sqlspec/migrations/fix.py +203 -0
- sqlspec/migrations/loaders.py +450 -0
- sqlspec/migrations/runner.py +1024 -0
- sqlspec/migrations/templates.py +234 -0
- sqlspec/migrations/tracker.py +403 -0
- sqlspec/migrations/utils.py +256 -0
- sqlspec/migrations/validation.py +203 -0
- sqlspec/observability/__init__.py +22 -0
- sqlspec/observability/_config.py +228 -0
- sqlspec/observability/_diagnostics.py +67 -0
- sqlspec/observability/_dispatcher.py +151 -0
- sqlspec/observability/_observer.py +180 -0
- sqlspec/observability/_runtime.py +381 -0
- sqlspec/observability/_spans.py +158 -0
- sqlspec/protocols.py +530 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +46 -0
- sqlspec/storage/_utils.py +104 -0
- sqlspec/storage/backends/__init__.py +1 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +398 -0
- sqlspec/storage/backends/local.py +377 -0
- sqlspec/storage/backends/obstore.py +580 -0
- sqlspec/storage/errors.py +104 -0
- sqlspec/storage/pipeline.py +604 -0
- sqlspec/storage/registry.py +289 -0
- sqlspec/typing.py +219 -0
- sqlspec/utils/__init__.py +31 -0
- sqlspec/utils/arrow_helpers.py +95 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/correlation.py +132 -0
- sqlspec/utils/data_transformation.py +114 -0
- sqlspec/utils/dependencies.py +79 -0
- sqlspec/utils/deprecation.py +113 -0
- sqlspec/utils/fixtures.py +250 -0
- sqlspec/utils/logging.py +172 -0
- sqlspec/utils/module_loader.py +273 -0
- sqlspec/utils/portal.py +325 -0
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +396 -0
- sqlspec/utils/singleton.py +41 -0
- sqlspec/utils/sync_tools.py +277 -0
- sqlspec/utils/text.py +108 -0
- sqlspec/utils/type_converters.py +99 -0
- sqlspec/utils/type_guards.py +1324 -0
- sqlspec/utils/version.py +444 -0
- sqlspec-0.32.0.dist-info/METADATA +202 -0
- sqlspec-0.32.0.dist-info/RECORD +262 -0
- sqlspec-0.32.0.dist-info/WHEEL +4 -0
- sqlspec-0.32.0.dist-info/entry_points.txt +2 -0
- sqlspec-0.32.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,892 @@
|
|
|
1
|
+
"""Psqlpy driver implementation for PostgreSQL connectivity.
|
|
2
|
+
|
|
3
|
+
Provides parameter style conversion, type coercion, error handling,
|
|
4
|
+
and transaction management.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import datetime
|
|
8
|
+
import decimal
|
|
9
|
+
import inspect
|
|
10
|
+
import io
|
|
11
|
+
import re
|
|
12
|
+
import uuid
|
|
13
|
+
from typing import TYPE_CHECKING, Any, Final, cast
|
|
14
|
+
|
|
15
|
+
import psqlpy.exceptions
|
|
16
|
+
from psqlpy.extra_types import JSONB
|
|
17
|
+
|
|
18
|
+
from sqlspec.adapters.psqlpy.data_dictionary import PsqlpyAsyncDataDictionary
|
|
19
|
+
from sqlspec.adapters.psqlpy.type_converter import PostgreSQLTypeConverter
|
|
20
|
+
from sqlspec.core import (
|
|
21
|
+
SQL,
|
|
22
|
+
DriverParameterProfile,
|
|
23
|
+
ParameterStyle,
|
|
24
|
+
ParameterStyleConfig,
|
|
25
|
+
StatementConfig,
|
|
26
|
+
build_statement_config_from_profile,
|
|
27
|
+
get_cache_config,
|
|
28
|
+
register_driver_profile,
|
|
29
|
+
)
|
|
30
|
+
from sqlspec.driver import AsyncDriverAdapterBase
|
|
31
|
+
from sqlspec.exceptions import (
|
|
32
|
+
CheckViolationError,
|
|
33
|
+
DatabaseConnectionError,
|
|
34
|
+
DataError,
|
|
35
|
+
ForeignKeyViolationError,
|
|
36
|
+
IntegrityError,
|
|
37
|
+
NotNullViolationError,
|
|
38
|
+
OperationalError,
|
|
39
|
+
SQLParsingError,
|
|
40
|
+
SQLSpecError,
|
|
41
|
+
TransactionError,
|
|
42
|
+
UniqueViolationError,
|
|
43
|
+
)
|
|
44
|
+
from sqlspec.typing import Empty
|
|
45
|
+
from sqlspec.utils.logging import get_logger
|
|
46
|
+
from sqlspec.utils.serializers import to_json
|
|
47
|
+
from sqlspec.utils.type_converters import build_nested_decimal_normalizer
|
|
48
|
+
|
|
49
|
+
if TYPE_CHECKING:
|
|
50
|
+
from collections.abc import Callable
|
|
51
|
+
from contextlib import AbstractAsyncContextManager
|
|
52
|
+
|
|
53
|
+
from sqlspec.adapters.psqlpy._types import PsqlpyConnection
|
|
54
|
+
from sqlspec.core import ArrowResult, SQLResult
|
|
55
|
+
from sqlspec.driver import ExecutionResult
|
|
56
|
+
from sqlspec.driver._async import AsyncDataDictionaryBase
|
|
57
|
+
from sqlspec.storage import (
|
|
58
|
+
AsyncStoragePipeline,
|
|
59
|
+
StorageBridgeJob,
|
|
60
|
+
StorageDestination,
|
|
61
|
+
StorageFormat,
|
|
62
|
+
StorageTelemetry,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
__all__ = (
|
|
66
|
+
"PsqlpyCursor",
|
|
67
|
+
"PsqlpyDriver",
|
|
68
|
+
"PsqlpyExceptionHandler",
|
|
69
|
+
"build_psqlpy_statement_config",
|
|
70
|
+
"psqlpy_statement_config",
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
logger = get_logger("adapters.psqlpy")
|
|
74
|
+
|
|
75
|
+
_type_converter = PostgreSQLTypeConverter()
|
|
76
|
+
|
|
77
|
+
PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE)
|
|
78
|
+
|
|
79
|
+
_JSON_CASTS: Final[frozenset[str]] = frozenset({"JSON", "JSONB"})
|
|
80
|
+
_TIMESTAMP_CASTS: Final[frozenset[str]] = frozenset({
|
|
81
|
+
"TIMESTAMP",
|
|
82
|
+
"TIMESTAMPTZ",
|
|
83
|
+
"TIMESTAMP WITH TIME ZONE",
|
|
84
|
+
"TIMESTAMP WITHOUT TIME ZONE",
|
|
85
|
+
})
|
|
86
|
+
_UUID_CASTS: Final[frozenset[str]] = frozenset({"UUID"})
|
|
87
|
+
_DECIMAL_NORMALIZER = build_nested_decimal_normalizer(mode="float")
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class PsqlpyCursor:
    """Async context manager standing in for a cursor over a psqlpy connection.

    psqlpy exposes no separate cursor object, so this wrapper simply hands the
    connection back to the caller while tracking whether the context is active.
    """

    __slots__ = ("_in_use", "connection")

    def __init__(self, connection: "PsqlpyConnection") -> None:
        self.connection = connection
        self._in_use = False

    async def __aenter__(self) -> "PsqlpyConnection":
        """Mark the cursor as active and return the underlying connection.

        Returns:
            Psqlpy connection object
        """
        self._in_use = True
        return self.connection

    async def __aexit__(self, *_: Any) -> None:
        """Clear the active flag on context exit.

        Exception information is ignored; any exception propagates unchanged.
        """
        self._in_use = False

    def is_in_use(self) -> bool:
        """Report whether the cursor context is currently active.

        Returns:
            True if cursor is in use, False otherwise
        """
        return self._in_use
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class PsqlpyExceptionHandler:
    """Async context manager translating psqlpy errors into SQLSpec exceptions.

    psqlpy does not expose SQLSTATE codes directly, so mapping is driven by
    substring matching on the error message text.
    """

    __slots__ = ()

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Only psqlpy-originated errors are remapped; everything else propagates.
        if exc_type is None:
            return
        if issubclass(exc_type, (psqlpy.exceptions.DatabaseError, psqlpy.exceptions.Error)):
            self._map_postgres_exception(exc_val)

    def _map_postgres_exception(self, e: Any) -> None:
        """Map a psqlpy exception to the matching SQLSpec exception.

        Checks are ordered most-specific first; the generic "constraint" test
        runs only after the specific constraint kinds have been ruled out.

        Args:
            e: psqlpy exception instance

        Raises:
            Specific SQLSpec exception based on error message patterns
        """
        message = str(e).lower()

        def matches(*needles: str) -> bool:
            # True when any needle occurs in the lowercased message.
            return any(needle in message for needle in needles)

        if matches("unique", "duplicate key"):
            self._raise_unique_violation(e, None)
        elif matches("foreign key", "violates foreign key"):
            self._raise_foreign_key_violation(e, None)
        elif "not null" in message or ("null value" in message and "violates not-null" in message):
            self._raise_not_null_violation(e, None)
        elif matches("check constraint", "violates check constraint"):
            self._raise_check_violation(e, None)
        elif "constraint" in message:
            self._raise_integrity_error(e, None)
        elif matches("syntax error", "parse"):
            self._raise_parsing_error(e, None)
        elif matches("connection", "could not connect"):
            self._raise_connection_error(e, None)
        elif matches("deadlock", "serialization failure"):
            self._raise_transaction_error(e, None)
        else:
            self._raise_generic_error(e, None)

    def _raise_unique_violation(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL unique constraint violation: {e}"
        raise UniqueViolationError(msg) from e

    def _raise_foreign_key_violation(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL foreign key constraint violation: {e}"
        raise ForeignKeyViolationError(msg) from e

    def _raise_not_null_violation(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL not-null constraint violation: {e}"
        raise NotNullViolationError(msg) from e

    def _raise_check_violation(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL check constraint violation: {e}"
        raise CheckViolationError(msg) from e

    def _raise_integrity_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL integrity constraint violation: {e}"
        raise IntegrityError(msg) from e

    def _raise_parsing_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL SQL syntax error: {e}"
        raise SQLParsingError(msg) from e

    def _raise_connection_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL connection error: {e}"
        raise DatabaseConnectionError(msg) from e

    def _raise_transaction_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL transaction error: {e}"
        raise TransactionError(msg) from e

    def _raise_data_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL data error: {e}"
        raise DataError(msg) from e

    def _raise_operational_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL operational error: {e}"
        raise OperationalError(msg) from e

    def _raise_generic_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL database error: {e}"
        raise SQLSpecError(msg) from e
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class PsqlpyDriver(AsyncDriverAdapterBase):
|
|
224
|
+
"""PostgreSQL driver implementation using psqlpy.
|
|
225
|
+
|
|
226
|
+
Provides parameter style conversion, type coercion, error handling,
|
|
227
|
+
and transaction management.
|
|
228
|
+
"""
|
|
229
|
+
|
|
230
|
+
__slots__ = ("_data_dictionary",)
|
|
231
|
+
dialect = "postgres"
|
|
232
|
+
|
|
233
|
+
def __init__(
|
|
234
|
+
self,
|
|
235
|
+
connection: "PsqlpyConnection",
|
|
236
|
+
statement_config: "StatementConfig | None" = None,
|
|
237
|
+
driver_features: "dict[str, Any] | None" = None,
|
|
238
|
+
) -> None:
|
|
239
|
+
if statement_config is None:
|
|
240
|
+
cache_config = get_cache_config()
|
|
241
|
+
statement_config = psqlpy_statement_config.replace(enable_caching=cache_config.compiled_cache_enabled)
|
|
242
|
+
|
|
243
|
+
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
|
|
244
|
+
self._data_dictionary: AsyncDataDictionaryBase | None = None
|
|
245
|
+
|
|
246
|
+
def prepare_driver_parameters(
|
|
247
|
+
self,
|
|
248
|
+
parameters: Any,
|
|
249
|
+
statement_config: "StatementConfig",
|
|
250
|
+
is_many: bool = False,
|
|
251
|
+
prepared_statement: Any | None = None,
|
|
252
|
+
) -> Any:
|
|
253
|
+
"""Prepare parameters with cast-aware type coercion for psqlpy.
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
parameters: Parameters in any format
|
|
257
|
+
statement_config: Statement configuration
|
|
258
|
+
is_many: Whether this is for execute_many operation
|
|
259
|
+
prepared_statement: Prepared statement containing the original SQL statement
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
Parameters with cast-aware type coercion applied
|
|
263
|
+
"""
|
|
264
|
+
enable_cast_detection = self.driver_features.get("enable_cast_detection", True)
|
|
265
|
+
|
|
266
|
+
if enable_cast_detection and prepared_statement and self.dialect in {"postgres", "postgresql"} and not is_many:
|
|
267
|
+
parameter_casts = self._get_parameter_casts(prepared_statement)
|
|
268
|
+
prepared = self._prepare_parameters_with_casts(parameters, parameter_casts, statement_config)
|
|
269
|
+
else:
|
|
270
|
+
prepared = super().prepare_driver_parameters(parameters, statement_config, is_many, prepared_statement)
|
|
271
|
+
|
|
272
|
+
if not is_many and isinstance(prepared, list):
|
|
273
|
+
prepared = tuple(prepared)
|
|
274
|
+
|
|
275
|
+
if not is_many and isinstance(prepared, tuple):
|
|
276
|
+
return tuple(_normalize_scalar_parameter(item) for item in prepared)
|
|
277
|
+
|
|
278
|
+
return prepared
|
|
279
|
+
|
|
280
|
+
def _get_parameter_casts(self, statement: SQL) -> "dict[int, str]":
|
|
281
|
+
"""Get parameter cast metadata from compiled statement.
|
|
282
|
+
|
|
283
|
+
Args:
|
|
284
|
+
statement: SQL statement with compiled metadata
|
|
285
|
+
|
|
286
|
+
Returns:
|
|
287
|
+
Dict mapping parameter positions to cast types
|
|
288
|
+
"""
|
|
289
|
+
processed_state = statement.get_processed_state()
|
|
290
|
+
if processed_state is not Empty:
|
|
291
|
+
return processed_state.parameter_casts or {}
|
|
292
|
+
return {}
|
|
293
|
+
|
|
294
|
+
def _prepare_parameters_with_casts(
|
|
295
|
+
self, parameters: Any, parameter_casts: "dict[int, str]", statement_config: "StatementConfig"
|
|
296
|
+
) -> Any:
|
|
297
|
+
"""Prepare parameters with cast-aware type coercion.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
parameters: Parameter values (list, tuple, or scalar)
|
|
301
|
+
parameter_casts: Mapping of parameter positions to cast types
|
|
302
|
+
statement_config: Statement configuration for type coercion
|
|
303
|
+
|
|
304
|
+
Returns:
|
|
305
|
+
Parameters with cast-aware type coercion applied
|
|
306
|
+
"""
|
|
307
|
+
if isinstance(parameters, (list, tuple)):
|
|
308
|
+
result: list[Any] = []
|
|
309
|
+
serializer = statement_config.parameter_config.json_serializer or to_json
|
|
310
|
+
type_map = statement_config.parameter_config.type_coercion_map
|
|
311
|
+
for idx, param in enumerate(parameters, start=1):
|
|
312
|
+
cast_type = parameter_casts.get(idx, "")
|
|
313
|
+
prepared_value = param
|
|
314
|
+
if type_map:
|
|
315
|
+
for type_check, converter in type_map.items():
|
|
316
|
+
if isinstance(prepared_value, type_check):
|
|
317
|
+
prepared_value = converter(prepared_value)
|
|
318
|
+
break
|
|
319
|
+
if cast_type:
|
|
320
|
+
prepared_value = _coerce_parameter_for_cast(prepared_value, cast_type, serializer)
|
|
321
|
+
result.append(prepared_value)
|
|
322
|
+
return tuple(result) if isinstance(parameters, tuple) else result
|
|
323
|
+
return parameters
|
|
324
|
+
|
|
325
|
+
def with_cursor(self, connection: "PsqlpyConnection") -> "PsqlpyCursor":
|
|
326
|
+
"""Create context manager for psqlpy cursor.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
connection: Psqlpy connection object
|
|
330
|
+
|
|
331
|
+
Returns:
|
|
332
|
+
PsqlpyCursor context manager
|
|
333
|
+
"""
|
|
334
|
+
return PsqlpyCursor(connection)
|
|
335
|
+
|
|
336
|
+
def handle_database_exceptions(self) -> "AbstractAsyncContextManager[None]":
|
|
337
|
+
"""Handle database-specific exceptions.
|
|
338
|
+
|
|
339
|
+
Returns:
|
|
340
|
+
Exception handler context manager
|
|
341
|
+
"""
|
|
342
|
+
return PsqlpyExceptionHandler()
|
|
343
|
+
|
|
344
|
+
async def _try_special_handling(self, cursor: "PsqlpyConnection", statement: SQL) -> "SQLResult | None":
|
|
345
|
+
"""Hook for psqlpy-specific special operations.
|
|
346
|
+
|
|
347
|
+
Args:
|
|
348
|
+
cursor: Psqlpy connection object
|
|
349
|
+
statement: SQL statement to analyze
|
|
350
|
+
|
|
351
|
+
Returns:
|
|
352
|
+
SQLResult if special handling applied, None otherwise
|
|
353
|
+
"""
|
|
354
|
+
_ = (cursor, statement)
|
|
355
|
+
return None
|
|
356
|
+
|
|
357
|
+
async def _execute_script(self, cursor: "PsqlpyConnection", statement: SQL) -> "ExecutionResult":
|
|
358
|
+
"""Execute SQL script with statement splitting.
|
|
359
|
+
|
|
360
|
+
Args:
|
|
361
|
+
cursor: Psqlpy connection object
|
|
362
|
+
statement: SQL statement with script content
|
|
363
|
+
|
|
364
|
+
Returns:
|
|
365
|
+
ExecutionResult with script execution metadata
|
|
366
|
+
|
|
367
|
+
Notes:
|
|
368
|
+
Uses execute() with empty parameters for each statement instead of execute_batch().
|
|
369
|
+
execute_batch() uses simple query protocol which can break subsequent queries
|
|
370
|
+
that rely on extended protocol (e.g., information_schema queries with name type).
|
|
371
|
+
"""
|
|
372
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
373
|
+
statement_config = statement.statement_config
|
|
374
|
+
statements = self.split_script_statements(sql, statement_config, strip_trailing_semicolon=True)
|
|
375
|
+
|
|
376
|
+
successful_count = 0
|
|
377
|
+
last_result = None
|
|
378
|
+
|
|
379
|
+
for stmt in statements:
|
|
380
|
+
last_result = await cursor.execute(stmt, prepared_parameters or [])
|
|
381
|
+
successful_count += 1
|
|
382
|
+
|
|
383
|
+
return self.create_execution_result(
|
|
384
|
+
last_result, statement_count=len(statements), successful_statements=successful_count, is_script_result=True
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
async def _execute_many(self, cursor: "PsqlpyConnection", statement: SQL) -> "ExecutionResult":
|
|
388
|
+
"""Execute SQL with multiple parameter sets.
|
|
389
|
+
|
|
390
|
+
Args:
|
|
391
|
+
cursor: Psqlpy connection object
|
|
392
|
+
statement: SQL statement with multiple parameter sets
|
|
393
|
+
|
|
394
|
+
Returns:
|
|
395
|
+
ExecutionResult with batch execution metadata
|
|
396
|
+
"""
|
|
397
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
398
|
+
|
|
399
|
+
if not prepared_parameters:
|
|
400
|
+
return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)
|
|
401
|
+
|
|
402
|
+
driver_parameters = self.prepare_driver_parameters(
|
|
403
|
+
prepared_parameters, self.statement_config, is_many=True, prepared_statement=statement
|
|
404
|
+
)
|
|
405
|
+
|
|
406
|
+
operation_type = statement.operation_type
|
|
407
|
+
should_coerce = operation_type != "SELECT"
|
|
408
|
+
|
|
409
|
+
formatted_parameters = []
|
|
410
|
+
for param_set in driver_parameters:
|
|
411
|
+
values = list(param_set) if isinstance(param_set, (list, tuple)) else [param_set]
|
|
412
|
+
|
|
413
|
+
if should_coerce:
|
|
414
|
+
values = list(_coerce_numeric_for_write(values))
|
|
415
|
+
|
|
416
|
+
formatted_parameters.append(values)
|
|
417
|
+
|
|
418
|
+
await cursor.execute_many(sql, formatted_parameters)
|
|
419
|
+
|
|
420
|
+
rows_affected = len(formatted_parameters)
|
|
421
|
+
|
|
422
|
+
return self.create_execution_result(cursor, rowcount_override=rows_affected, is_many_result=True)
|
|
423
|
+
|
|
424
|
+
async def _execute_statement(self, cursor: "PsqlpyConnection", statement: SQL) -> "ExecutionResult":
|
|
425
|
+
"""Execute single SQL statement.
|
|
426
|
+
|
|
427
|
+
Args:
|
|
428
|
+
cursor: Psqlpy connection object
|
|
429
|
+
statement: SQL statement to execute
|
|
430
|
+
|
|
431
|
+
Returns:
|
|
432
|
+
ExecutionResult with execution metadata
|
|
433
|
+
"""
|
|
434
|
+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
|
|
435
|
+
|
|
436
|
+
driver_parameters = prepared_parameters
|
|
437
|
+
operation_type = statement.operation_type
|
|
438
|
+
should_coerce = operation_type != "SELECT"
|
|
439
|
+
effective_parameters = _coerce_numeric_for_write(driver_parameters) if should_coerce else driver_parameters
|
|
440
|
+
|
|
441
|
+
if statement.returns_rows():
|
|
442
|
+
query_result = await cursor.fetch(sql, effective_parameters or [])
|
|
443
|
+
dict_rows: list[dict[str, Any]] = query_result.result() if query_result else []
|
|
444
|
+
|
|
445
|
+
return self.create_execution_result(
|
|
446
|
+
cursor,
|
|
447
|
+
selected_data=dict_rows,
|
|
448
|
+
column_names=list(dict_rows[0].keys()) if dict_rows else [],
|
|
449
|
+
data_row_count=len(dict_rows),
|
|
450
|
+
is_select_result=True,
|
|
451
|
+
)
|
|
452
|
+
|
|
453
|
+
result = await cursor.execute(sql, effective_parameters or [])
|
|
454
|
+
rows_affected = self._extract_rows_affected(result)
|
|
455
|
+
|
|
456
|
+
return self.create_execution_result(cursor, rowcount_override=rows_affected)
|
|
457
|
+
|
|
458
|
+
def _extract_rows_affected(self, result: Any) -> int:
|
|
459
|
+
"""Extract rows affected from psqlpy result.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
result: Psqlpy execution result object
|
|
463
|
+
|
|
464
|
+
Returns:
|
|
465
|
+
Number of rows affected, -1 if unable to determine
|
|
466
|
+
"""
|
|
467
|
+
try:
|
|
468
|
+
if hasattr(result, "tag") and result.tag:
|
|
469
|
+
return self._parse_command_tag(result.tag)
|
|
470
|
+
if hasattr(result, "status") and result.status:
|
|
471
|
+
return self._parse_command_tag(result.status)
|
|
472
|
+
if isinstance(result, str):
|
|
473
|
+
return self._parse_command_tag(result)
|
|
474
|
+
except Exception as e:
|
|
475
|
+
logger.debug("Failed to parse psqlpy command tag: %s", e)
|
|
476
|
+
return -1
|
|
477
|
+
|
|
478
|
+
def _parse_command_tag(self, tag: str) -> int:
    """Parse a PostgreSQL command tag (e.g. ``INSERT 0 3``) into a row count.

    Args:
        tag: Command tag string reported by the server.

    Returns:
        Rows affected, or -1 when the tag carries no usable count.
    """
    if not tag:
        return -1

    matched = PSQLPY_STATUS_REGEX.match(tag.strip())
    if matched is None:
        return -1
    verb = matched.group(1).upper()
    count_text = matched.group(3)
    # INSERT/UPDATE/DELETE tags carry the affected-row count in group 3.
    if verb in {"INSERT", "UPDATE", "DELETE"} and count_text:
        return int(count_text)
    return -1
|
|
498
|
+
|
|
499
|
+
async def select_to_storage(
    self,
    statement: "SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: Any,
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, Any] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Execute a query and stream Arrow results to a storage backend.

    Args:
        statement: SQL statement (or raw string) to execute.
        destination: Storage location that receives the Arrow payload.
        *parameters: Positional parameters forwarded to the query.
        statement_config: Optional per-call statement configuration override.
        partitioner: Optional partitioning metadata attached to the telemetry.
        format_hint: Preferred storage serialization format, if any.
        telemetry: Optional inbound telemetry merged into the created job.
        **kwargs: Additional driver-specific execution options.

    Returns:
        StorageBridgeJob describing the completed storage write.
    """

    # Fail fast when arrow export is disabled for this driver instance.
    self._require_capability("arrow_export_enabled")
    arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline())
    telemetry_payload = await self._write_result_to_storage_async(
        arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
    )
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
|
|
521
|
+
|
|
522
|
+
async def load_from_arrow(
    self,
    table: str,
    source: "ArrowResult | Any",
    *,
    partitioner: "dict[str, Any] | None" = None,
    overwrite: bool = False,
    telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
    """Load Arrow-formatted data into PostgreSQL via psqlpy binary COPY.

    Args:
        table: Target table name, optionally schema-qualified.
        source: ArrowResult or Arrow-table-like object to ingest.
        partitioner: Optional partitioning metadata attached to the telemetry.
        overwrite: When True, TRUNCATE the target table before loading.
        telemetry: Optional inbound telemetry merged into the created job.

    Returns:
        StorageBridgeJob describing the completed ingest.
    """

    self._require_capability("arrow_import_enabled")
    arrow_table = self._coerce_arrow_table(source)
    if overwrite:
        await self._truncate_table_async(table)

    columns, records = self._arrow_table_to_rows(arrow_table)
    if records:
        schema_name, table_name = _split_schema_and_table(table)
        async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor:
            copy_kwargs: dict[str, Any] = {"columns": columns}
            if schema_name:
                copy_kwargs["schema_name"] = schema_name
            try:
                # Preferred path: server-side COPY of the pre-encoded payload.
                copy_payload = _encode_records_for_binary_copy(records)
                copy_operation = cursor.binary_copy_to_table(copy_payload, table_name, **copy_kwargs)
                if inspect.isawaitable(copy_operation):
                    await copy_operation
            except (TypeError, psqlpy.exceptions.DatabaseError) as exc:
                # Fallback path: batched positional INSERTs when COPY is unavailable.
                logger.debug("Binary COPY not available for psqlpy; falling back to INSERT statements: %s", exc)
                insert_sql = _build_psqlpy_insert_statement(table, columns)
                formatted_records = _coerce_records_for_execute_many(records)
                insert_operation = cursor.execute_many(insert_sql, formatted_records)
                if inspect.isawaitable(insert_operation):
                    await insert_operation

    telemetry_payload = self._build_ingest_telemetry(arrow_table)
    telemetry_payload["destination"] = table
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
|
|
562
|
+
|
|
563
|
+
async def load_from_storage(
    self,
    table: str,
    source: "StorageDestination",
    *,
    file_format: "StorageFormat",
    partitioner: "dict[str, Any] | None" = None,
    overwrite: bool = False,
) -> "StorageBridgeJob":
    """Read staged artifacts from storage and ingest them via ``load_from_arrow``.

    Args:
        table: Target table name, optionally schema-qualified.
        source: Storage location holding the staged artifacts.
        file_format: Serialization format of the staged artifacts.
        partitioner: Optional partitioning metadata attached to the telemetry.
        overwrite: When True, truncate the target table before loading.

    Returns:
        StorageBridgeJob describing the completed ingest.
    """

    staged_table, read_telemetry = await self._read_arrow_from_storage_async(source, file_format=file_format)
    return await self.load_from_arrow(
        table, staged_table, partitioner=partitioner, overwrite=overwrite, telemetry=read_telemetry
    )
|
|
578
|
+
|
|
579
|
+
async def begin(self) -> None:
    """Start an explicit transaction on the underlying psqlpy connection.

    Raises:
        SQLSpecError: If the BEGIN statement fails at the driver level.
    """
    try:
        await self.connection.execute("BEGIN")
    except psqlpy.exceptions.DatabaseError as error:
        raise SQLSpecError(f"Failed to begin psqlpy transaction: {error}") from error
|
|
586
|
+
|
|
587
|
+
async def rollback(self) -> None:
    """Roll back the current transaction.

    Raises:
        SQLSpecError: If the ROLLBACK statement fails at the driver level.
    """
    try:
        await self.connection.execute("ROLLBACK")
    except psqlpy.exceptions.DatabaseError as error:
        raise SQLSpecError(f"Failed to rollback psqlpy transaction: {error}") from error
|
|
594
|
+
|
|
595
|
+
async def commit(self) -> None:
    """Commit the current transaction.

    Raises:
        SQLSpecError: If the COMMIT statement fails at the driver level.
    """
    try:
        await self.connection.execute("COMMIT")
    except psqlpy.exceptions.DatabaseError as error:
        raise SQLSpecError(f"Failed to commit psqlpy transaction: {error}") from error
|
|
602
|
+
|
|
603
|
+
async def _truncate_table_async(self, table: str) -> None:
    """Issue ``TRUNCATE TABLE`` against the quoted target table."""
    target = _format_table_identifier(table)
    async with self.handle_database_exceptions(), self.with_cursor(self.connection) as cursor:
        await cursor.execute(f"TRUNCATE TABLE {target}")
|
|
607
|
+
|
|
608
|
+
@property
def data_dictionary(self) -> "AsyncDataDictionaryBase":
    """Data dictionary used for database metadata queries.

    Returns:
        Lazily-created :class:`PsqlpyAsyncDataDictionary`, cached for the
        lifetime of this driver instance.
    """
    if self._data_dictionary is None:
        # Instantiated on first access, then reused.
        self._data_dictionary = PsqlpyAsyncDataDictionary()
    return self._data_dictionary
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def _coerce_json_parameter(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any:
    """Serialize JSON parameters according to the detected cast type.

    Args:
        value: Parameter value supplied by the caller.
        cast_type: Uppercase cast identifier detected in SQL.
        serializer: JSON serialization callable from statement config.

    Returns:
        Serialized parameter suitable for driver execution.

    Raises:
        SQLSpecError: If serialization fails for JSON payloads.
    """

    if value is None:
        return None
    if cast_type == "JSONB":
        # JSONB casts: wrap containers in psqlpy's JSONB marker type.
        # String values fall through to the pass-through branch below.
        if isinstance(value, JSONB):
            return value
        if isinstance(value, dict):
            return JSONB(value)
        if isinstance(value, (list, tuple)):
            return JSONB(list(value))
    # Non-JSONB casts: tuples become lists so the driver can bind them.
    if isinstance(value, tuple):
        return list(value)
    # Containers, strings, and JSONB wrappers are passed through untouched.
    if isinstance(value, (dict, list, str, JSONB)):
        return value
    # Anything else (numbers, custom objects, ...) is serialized to a JSON string.
    try:
        serialized_value = serializer(value)
    except Exception as error:
        msg = "Failed to serialize JSON parameter for psqlpy."
        raise SQLSpecError(msg) from error
    return serialized_value
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
def _coerce_uuid_parameter(value: Any) -> Any:
|
|
657
|
+
"""Convert UUID-compatible parameters to ``uuid.UUID`` instances.
|
|
658
|
+
|
|
659
|
+
Args:
|
|
660
|
+
value: Parameter value supplied by the caller.
|
|
661
|
+
|
|
662
|
+
Returns:
|
|
663
|
+
``uuid.UUID`` instance when input is coercible, otherwise original value.
|
|
664
|
+
|
|
665
|
+
Raises:
|
|
666
|
+
SQLSpecError: If the value cannot be converted to ``uuid.UUID``.
|
|
667
|
+
"""
|
|
668
|
+
|
|
669
|
+
if isinstance(value, uuid.UUID):
|
|
670
|
+
return value
|
|
671
|
+
if isinstance(value, str):
|
|
672
|
+
try:
|
|
673
|
+
return uuid.UUID(value)
|
|
674
|
+
except ValueError as error:
|
|
675
|
+
msg = "Invalid UUID parameter for psqlpy."
|
|
676
|
+
raise SQLSpecError(msg) from error
|
|
677
|
+
return value
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
def _coerce_timestamp_parameter(value: Any) -> Any:
|
|
681
|
+
"""Convert ISO-formatted timestamp strings to ``datetime.datetime``.
|
|
682
|
+
|
|
683
|
+
Args:
|
|
684
|
+
value: Parameter value supplied by the caller.
|
|
685
|
+
|
|
686
|
+
Returns:
|
|
687
|
+
``datetime.datetime`` instance when conversion succeeds, otherwise original value.
|
|
688
|
+
|
|
689
|
+
Raises:
|
|
690
|
+
SQLSpecError: If the value cannot be parsed as an ISO timestamp.
|
|
691
|
+
"""
|
|
692
|
+
|
|
693
|
+
if isinstance(value, datetime.datetime):
|
|
694
|
+
return value
|
|
695
|
+
if isinstance(value, str):
|
|
696
|
+
normalized_value = value[:-1] + "+00:00" if value.endswith("Z") else value
|
|
697
|
+
try:
|
|
698
|
+
return datetime.datetime.fromisoformat(normalized_value)
|
|
699
|
+
except ValueError as error:
|
|
700
|
+
msg = "Invalid ISO timestamp parameter for psqlpy."
|
|
701
|
+
raise SQLSpecError(msg) from error
|
|
702
|
+
return value
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def _coerce_parameter_for_cast(value: Any, cast_type: str, serializer: "Callable[[Any], str]") -> Any:
    """Route a parameter through the coercer matching its SQL cast.

    Args:
        value: Parameter value supplied by the caller.
        cast_type: Cast identifier detected in SQL (any case).
        serializer: JSON serialization callable from statement config.

    Returns:
        Coerced value appropriate for the cast, or the original value.
    """

    normalized_cast = cast_type.upper()
    if normalized_cast in _JSON_CASTS:
        return _coerce_json_parameter(value, normalized_cast, serializer)
    for cast_family, coercer in ((_UUID_CASTS, _coerce_uuid_parameter), (_TIMESTAMP_CASTS, _coerce_timestamp_parameter)):
        if normalized_cast in cast_family:
            return coercer(value)
    return value
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
def _prepare_dict_parameter(value: "dict[str, Any]") -> dict[str, Any]:
    """Run the decimal normalizer over a dict parameter, keeping the dict shape."""
    candidate = _DECIMAL_NORMALIZER(value)
    if isinstance(candidate, dict):
        return candidate
    # Normalizer produced a non-dict; keep the caller's original mapping.
    return value
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
def _prepare_list_parameter(value: "list[Any]") -> list[Any]:
    """Normalize each element of a list parameter via the decimal normalizer."""
    return list(map(_DECIMAL_NORMALIZER, value))
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
def _prepare_tuple_parameter(value: "tuple[Any, ...]") -> tuple[Any, ...]:
    """Normalize each element of a tuple parameter via the decimal normalizer."""
    return tuple(map(_DECIMAL_NORMALIZER, value))
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def _normalize_scalar_parameter(value: Any) -> Any:
|
|
741
|
+
return value
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _coerce_numeric_for_write(value: Any) -> Any:
|
|
745
|
+
if isinstance(value, float):
|
|
746
|
+
return decimal.Decimal(str(value))
|
|
747
|
+
if isinstance(value, decimal.Decimal):
|
|
748
|
+
return value
|
|
749
|
+
if isinstance(value, list):
|
|
750
|
+
return [_coerce_numeric_for_write(item) for item in value]
|
|
751
|
+
if isinstance(value, tuple):
|
|
752
|
+
coerced = [_coerce_numeric_for_write(item) for item in value]
|
|
753
|
+
return tuple(coerced)
|
|
754
|
+
if isinstance(value, dict):
|
|
755
|
+
return {key: _coerce_numeric_for_write(item) for key, item in value.items()}
|
|
756
|
+
return value
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def _escape_copy_text(value: str) -> str:
|
|
760
|
+
return value.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n").replace("\r", "\\r")
|
|
761
|
+
|
|
762
|
+
|
|
763
|
+
def _format_copy_value(value: Any) -> str:
    """Render a single value as a COPY text-format field (escaping happens later).

    Branch order matters: bool must be tested before the generic str() fallback,
    since ``str(True)`` would otherwise be emitted instead of ``t``/``f``.
    """
    if value is None:
        return r"\N"  # COPY text-format NULL marker
    if isinstance(value, bool):
        return "t" if value else "f"  # PostgreSQL boolean literals
    if isinstance(value, (datetime.date, datetime.datetime, datetime.time)):
        return value.isoformat()
    if isinstance(value, (list, tuple, dict)):
        # Containers are JSON-encoded (not PostgreSQL array syntax).
        return to_json(value)
    if isinstance(value, (bytes, bytearray)):
        # assumes the payload is valid UTF-8 text — TODO confirm for bytea columns
        return value.decode("utf-8")
    return str(_coerce_numeric_for_write(value))
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def _encode_records_for_binary_copy(records: "list[tuple[Any, ...]]") -> bytes:
    """Encode row tuples into a bytes payload compatible with binary_copy_to_table.

    Args:
        records: Sequence of row tuples extracted from the Arrow table.

    Returns:
        UTF-8 encoded bytes buffer representing the COPY payload (one
        tab-separated, newline-terminated line per row).
    """

    encoded_rows = (
        "\t".join(_escape_copy_text(_format_copy_value(field)) for field in row) + "\n"
        for row in records
    )
    return "".join(encoded_rows).encode("utf-8")
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def _split_schema_and_table(identifier: str) -> "tuple[str | None, str]":
|
|
796
|
+
cleaned = identifier.strip()
|
|
797
|
+
if not cleaned:
|
|
798
|
+
msg = "Table name must not be empty"
|
|
799
|
+
raise SQLSpecError(msg)
|
|
800
|
+
if "." not in cleaned:
|
|
801
|
+
return None, cleaned.strip('"')
|
|
802
|
+
parts = [part for part in cleaned.split(".") if part]
|
|
803
|
+
if len(parts) == 1:
|
|
804
|
+
return None, parts[0].strip('"')
|
|
805
|
+
schema_name = ".".join(parts[:-1]).strip('"')
|
|
806
|
+
table_name = parts[-1].strip('"')
|
|
807
|
+
if not table_name:
|
|
808
|
+
msg = "Table name must not be empty"
|
|
809
|
+
raise SQLSpecError(msg)
|
|
810
|
+
return schema_name or None, table_name
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
def _quote_identifier(identifier: str) -> str:
|
|
814
|
+
normalized = identifier.replace('"', '""')
|
|
815
|
+
return f'"{normalized}"'
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
def _format_table_identifier(identifier: str) -> str:
    """Return a fully-quoted (schema-qualified when present) table identifier."""
    schema_part, table_part = _split_schema_and_table(identifier)
    quoted_table = _quote_identifier(table_part)
    if schema_part is None:
        return quoted_table
    return f"{_quote_identifier(schema_part)}.{quoted_table}"
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def _build_psqlpy_insert_statement(table: str, columns: "list[str]") -> str:
    """Build a positional-parameter ($1, $2, ...) INSERT statement for the columns."""
    quoted_columns = ", ".join(map(_quote_identifier, columns))
    value_slots = ", ".join(f"${position}" for position in range(1, len(columns) + 1))
    return f"INSERT INTO {_format_table_identifier(table)} ({quoted_columns}) VALUES ({value_slots})"
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
def _coerce_records_for_execute_many(records: "list[tuple[Any, ...]]") -> "list[list[Any]]":
    """Coerce row tuples into lists of driver-safe values for execute_many.

    Each row runs through the numeric coercer; tuples are converted to
    lists, lists kept as-is, and a bare scalar is wrapped in a one-item list.
    """
    prepared: list[list[Any]] = []
    for row in records:
        normalized = _coerce_numeric_for_write(row)
        if isinstance(normalized, tuple):
            prepared.append(list(normalized))
        elif isinstance(normalized, list):
            prepared.append(normalized)
        else:
            prepared.append([normalized])
    return prepared
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
def _build_psqlpy_profile() -> DriverParameterProfile:
    """Create the psqlpy driver parameter profile.

    Returns:
        DriverParameterProfile describing psqlpy's parameter-style support
        and type-coercion behavior.
    """

    return DriverParameterProfile(
        name="Psqlpy",
        # PostgreSQL-native $1, $2, ... placeholders are the default style.
        default_style=ParameterStyle.NUMERIC,
        supported_styles={ParameterStyle.NUMERIC, ParameterStyle.NAMED_DOLLAR, ParameterStyle.QMARK},
        # Execution always happens with numeric placeholders.
        default_execution_style=ParameterStyle.NUMERIC,
        supported_execution_styles={ParameterStyle.NUMERIC},
        has_native_list_expansion=False,
        preserve_parameter_format=True,
        needs_static_script_compilation=False,
        allow_mixed_parameter_styles=False,
        preserve_original_params_for_many=False,
        json_serializer_strategy="helper",
        # NOTE(review): Decimal→float narrows precision — presumably required
        # by psqlpy's parameter binding; confirm against the driver docs.
        custom_type_coercions={decimal.Decimal: float},
        default_dialect="postgres",
    )
|
|
862
|
+
|
|
863
|
+
|
|
864
|
+
# Shared, module-level profile instance reused by the config builders below.
_PSQLPY_PROFILE = _build_psqlpy_profile()

# Make the profile discoverable through the global driver-profile registry.
register_driver_profile("psqlpy", _PSQLPY_PROFILE)
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def _create_psqlpy_parameter_config(serializer: "Callable[[Any], str]") -> ParameterStyleConfig:
    """Derive the psqlpy parameter config, layering container coercions on top.

    Args:
        serializer: JSON serialization callable threaded through to the profile.

    Returns:
        ParameterStyleConfig with dict/list/tuple coercion hooks installed.
    """
    base_config = build_statement_config_from_profile(_PSQLPY_PROFILE, json_serializer=serializer).parameter_config

    # Extend (not replace) the profile-provided coercions with container handlers.
    updated_type_map = dict(base_config.type_coercion_map)
    updated_type_map[dict] = _prepare_dict_parameter
    updated_type_map[list] = _prepare_list_parameter
    updated_type_map[tuple] = _prepare_tuple_parameter

    return base_config.replace(type_coercion_map=updated_type_map)
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
def build_psqlpy_statement_config(*, json_serializer: "Callable[[Any], str]" = to_json) -> StatementConfig:
    """Build the default StatementConfig for the psqlpy driver.

    Args:
        json_serializer: Callable used to serialize JSON parameters.

    Returns:
        StatementConfig wired for the postgres dialect with parsing,
        validation, caching, and parameter-type wrapping enabled.
    """
    parameter_config = _create_psqlpy_parameter_config(json_serializer)
    return StatementConfig(
        dialect="postgres",
        parameter_config=parameter_config,
        enable_parsing=True,
        enable_validation=True,
        enable_caching=True,
        enable_parameter_type_wrapping=True,
    )
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
# Module-level default config built with the standard JSON serializer (to_json).
psqlpy_statement_config = build_psqlpy_statement_config()
|