sqlspec 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlspec/__init__.py +104 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +312 -0
- sqlspec/_typing.py +784 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/adk/__init__.py +5 -0
- sqlspec/adapters/adbc/adk/store.py +880 -0
- sqlspec/adapters/adbc/config.py +436 -0
- sqlspec/adapters/adbc/data_dictionary.py +537 -0
- sqlspec/adapters/adbc/driver.py +841 -0
- sqlspec/adapters/adbc/litestar/__init__.py +5 -0
- sqlspec/adapters/adbc/litestar/store.py +504 -0
- sqlspec/adapters/adbc/type_converter.py +153 -0
- sqlspec/adapters/aiosqlite/__init__.py +29 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/adk/store.py +536 -0
- sqlspec/adapters/aiosqlite/config.py +310 -0
- sqlspec/adapters/aiosqlite/data_dictionary.py +260 -0
- sqlspec/adapters/aiosqlite/driver.py +463 -0
- sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
- sqlspec/adapters/aiosqlite/pool.py +500 -0
- sqlspec/adapters/asyncmy/__init__.py +25 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
- sqlspec/adapters/asyncmy/adk/store.py +503 -0
- sqlspec/adapters/asyncmy/config.py +246 -0
- sqlspec/adapters/asyncmy/data_dictionary.py +241 -0
- sqlspec/adapters/asyncmy/driver.py +632 -0
- sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncmy/litestar/store.py +296 -0
- sqlspec/adapters/asyncpg/__init__.py +23 -0
- sqlspec/adapters/asyncpg/_type_handlers.py +76 -0
- sqlspec/adapters/asyncpg/_types.py +23 -0
- sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
- sqlspec/adapters/asyncpg/adk/store.py +460 -0
- sqlspec/adapters/asyncpg/config.py +464 -0
- sqlspec/adapters/asyncpg/data_dictionary.py +321 -0
- sqlspec/adapters/asyncpg/driver.py +720 -0
- sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
- sqlspec/adapters/asyncpg/litestar/store.py +253 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/adk/__init__.py +5 -0
- sqlspec/adapters/bigquery/adk/store.py +585 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/data_dictionary.py +256 -0
- sqlspec/adapters/bigquery/driver.py +1073 -0
- sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
- sqlspec/adapters/bigquery/litestar/store.py +327 -0
- sqlspec/adapters/bigquery/type_converter.py +125 -0
- sqlspec/adapters/duckdb/__init__.py +24 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/adk/__init__.py +14 -0
- sqlspec/adapters/duckdb/adk/store.py +563 -0
- sqlspec/adapters/duckdb/config.py +396 -0
- sqlspec/adapters/duckdb/data_dictionary.py +264 -0
- sqlspec/adapters/duckdb/driver.py +604 -0
- sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
- sqlspec/adapters/duckdb/litestar/store.py +332 -0
- sqlspec/adapters/duckdb/pool.py +273 -0
- sqlspec/adapters/duckdb/type_converter.py +133 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
- sqlspec/adapters/oracledb/_types.py +39 -0
- sqlspec/adapters/oracledb/_uuid_handlers.py +130 -0
- sqlspec/adapters/oracledb/adk/__init__.py +5 -0
- sqlspec/adapters/oracledb/adk/store.py +1632 -0
- sqlspec/adapters/oracledb/config.py +469 -0
- sqlspec/adapters/oracledb/data_dictionary.py +717 -0
- sqlspec/adapters/oracledb/driver.py +1493 -0
- sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
- sqlspec/adapters/oracledb/litestar/store.py +765 -0
- sqlspec/adapters/oracledb/migrations.py +532 -0
- sqlspec/adapters/oracledb/type_converter.py +207 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
- sqlspec/adapters/psqlpy/_types.py +12 -0
- sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
- sqlspec/adapters/psqlpy/adk/store.py +483 -0
- sqlspec/adapters/psqlpy/config.py +271 -0
- sqlspec/adapters/psqlpy/data_dictionary.py +179 -0
- sqlspec/adapters/psqlpy/driver.py +892 -0
- sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
- sqlspec/adapters/psqlpy/litestar/store.py +272 -0
- sqlspec/adapters/psqlpy/type_converter.py +102 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_type_handlers.py +90 -0
- sqlspec/adapters/psycopg/_types.py +18 -0
- sqlspec/adapters/psycopg/adk/__init__.py +5 -0
- sqlspec/adapters/psycopg/adk/store.py +962 -0
- sqlspec/adapters/psycopg/config.py +487 -0
- sqlspec/adapters/psycopg/data_dictionary.py +630 -0
- sqlspec/adapters/psycopg/driver.py +1336 -0
- sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
- sqlspec/adapters/psycopg/litestar/store.py +554 -0
- sqlspec/adapters/spanner/__init__.py +38 -0
- sqlspec/adapters/spanner/_type_handlers.py +186 -0
- sqlspec/adapters/spanner/_types.py +12 -0
- sqlspec/adapters/spanner/adk/__init__.py +5 -0
- sqlspec/adapters/spanner/adk/store.py +435 -0
- sqlspec/adapters/spanner/config.py +241 -0
- sqlspec/adapters/spanner/data_dictionary.py +95 -0
- sqlspec/adapters/spanner/dialect/__init__.py +6 -0
- sqlspec/adapters/spanner/dialect/_spangres.py +52 -0
- sqlspec/adapters/spanner/dialect/_spanner.py +123 -0
- sqlspec/adapters/spanner/driver.py +366 -0
- sqlspec/adapters/spanner/litestar/__init__.py +5 -0
- sqlspec/adapters/spanner/litestar/store.py +266 -0
- sqlspec/adapters/spanner/type_converter.py +46 -0
- sqlspec/adapters/sqlite/__init__.py +18 -0
- sqlspec/adapters/sqlite/_type_handlers.py +86 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/adk/__init__.py +5 -0
- sqlspec/adapters/sqlite/adk/store.py +582 -0
- sqlspec/adapters/sqlite/config.py +221 -0
- sqlspec/adapters/sqlite/data_dictionary.py +256 -0
- sqlspec/adapters/sqlite/driver.py +527 -0
- sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
- sqlspec/adapters/sqlite/litestar/store.py +318 -0
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +811 -0
- sqlspec/builder/__init__.py +146 -0
- sqlspec/builder/_base.py +900 -0
- sqlspec/builder/_column.py +517 -0
- sqlspec/builder/_ddl.py +1642 -0
- sqlspec/builder/_delete.py +84 -0
- sqlspec/builder/_dml.py +381 -0
- sqlspec/builder/_expression_wrappers.py +46 -0
- sqlspec/builder/_factory.py +1537 -0
- sqlspec/builder/_insert.py +315 -0
- sqlspec/builder/_join.py +375 -0
- sqlspec/builder/_merge.py +848 -0
- sqlspec/builder/_parsing_utils.py +297 -0
- sqlspec/builder/_select.py +1615 -0
- sqlspec/builder/_update.py +161 -0
- sqlspec/builder/_vector_expressions.py +259 -0
- sqlspec/cli.py +764 -0
- sqlspec/config.py +1540 -0
- sqlspec/core/__init__.py +305 -0
- sqlspec/core/cache.py +785 -0
- sqlspec/core/compiler.py +603 -0
- sqlspec/core/filters.py +872 -0
- sqlspec/core/hashing.py +274 -0
- sqlspec/core/metrics.py +83 -0
- sqlspec/core/parameters/__init__.py +64 -0
- sqlspec/core/parameters/_alignment.py +266 -0
- sqlspec/core/parameters/_converter.py +413 -0
- sqlspec/core/parameters/_processor.py +341 -0
- sqlspec/core/parameters/_registry.py +201 -0
- sqlspec/core/parameters/_transformers.py +226 -0
- sqlspec/core/parameters/_types.py +430 -0
- sqlspec/core/parameters/_validator.py +123 -0
- sqlspec/core/pipeline.py +187 -0
- sqlspec/core/result.py +1124 -0
- sqlspec/core/splitter.py +940 -0
- sqlspec/core/stack.py +163 -0
- sqlspec/core/statement.py +835 -0
- sqlspec/core/type_conversion.py +235 -0
- sqlspec/driver/__init__.py +36 -0
- sqlspec/driver/_async.py +1027 -0
- sqlspec/driver/_common.py +1236 -0
- sqlspec/driver/_sync.py +1025 -0
- sqlspec/driver/mixins/__init__.py +7 -0
- sqlspec/driver/mixins/_result_tools.py +61 -0
- sqlspec/driver/mixins/_sql_translator.py +122 -0
- sqlspec/driver/mixins/_storage.py +311 -0
- sqlspec/exceptions.py +321 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/adk/__init__.py +53 -0
- sqlspec/extensions/adk/_types.py +51 -0
- sqlspec/extensions/adk/converters.py +172 -0
- sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
- sqlspec/extensions/adk/migrations/__init__.py +0 -0
- sqlspec/extensions/adk/service.py +181 -0
- sqlspec/extensions/adk/store.py +536 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +471 -0
- sqlspec/extensions/fastapi/__init__.py +19 -0
- sqlspec/extensions/fastapi/extension.py +341 -0
- sqlspec/extensions/fastapi/providers.py +543 -0
- sqlspec/extensions/flask/__init__.py +36 -0
- sqlspec/extensions/flask/_state.py +72 -0
- sqlspec/extensions/flask/_utils.py +40 -0
- sqlspec/extensions/flask/extension.py +402 -0
- sqlspec/extensions/litestar/__init__.py +23 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +92 -0
- sqlspec/extensions/litestar/config.py +90 -0
- sqlspec/extensions/litestar/handlers.py +316 -0
- sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
- sqlspec/extensions/litestar/migrations/__init__.py +3 -0
- sqlspec/extensions/litestar/plugin.py +638 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/extensions/litestar/store.py +265 -0
- sqlspec/extensions/otel/__init__.py +58 -0
- sqlspec/extensions/prometheus/__init__.py +107 -0
- sqlspec/extensions/starlette/__init__.py +10 -0
- sqlspec/extensions/starlette/_state.py +26 -0
- sqlspec/extensions/starlette/_utils.py +52 -0
- sqlspec/extensions/starlette/extension.py +257 -0
- sqlspec/extensions/starlette/middleware.py +154 -0
- sqlspec/loader.py +716 -0
- sqlspec/migrations/__init__.py +36 -0
- sqlspec/migrations/base.py +728 -0
- sqlspec/migrations/commands.py +1140 -0
- sqlspec/migrations/context.py +142 -0
- sqlspec/migrations/fix.py +203 -0
- sqlspec/migrations/loaders.py +450 -0
- sqlspec/migrations/runner.py +1024 -0
- sqlspec/migrations/templates.py +234 -0
- sqlspec/migrations/tracker.py +403 -0
- sqlspec/migrations/utils.py +256 -0
- sqlspec/migrations/validation.py +203 -0
- sqlspec/observability/__init__.py +22 -0
- sqlspec/observability/_config.py +228 -0
- sqlspec/observability/_diagnostics.py +67 -0
- sqlspec/observability/_dispatcher.py +151 -0
- sqlspec/observability/_observer.py +180 -0
- sqlspec/observability/_runtime.py +381 -0
- sqlspec/observability/_spans.py +158 -0
- sqlspec/protocols.py +530 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +46 -0
- sqlspec/storage/_utils.py +104 -0
- sqlspec/storage/backends/__init__.py +1 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +398 -0
- sqlspec/storage/backends/local.py +377 -0
- sqlspec/storage/backends/obstore.py +580 -0
- sqlspec/storage/errors.py +104 -0
- sqlspec/storage/pipeline.py +604 -0
- sqlspec/storage/registry.py +289 -0
- sqlspec/typing.py +219 -0
- sqlspec/utils/__init__.py +31 -0
- sqlspec/utils/arrow_helpers.py +95 -0
- sqlspec/utils/config_resolver.py +153 -0
- sqlspec/utils/correlation.py +132 -0
- sqlspec/utils/data_transformation.py +114 -0
- sqlspec/utils/dependencies.py +79 -0
- sqlspec/utils/deprecation.py +113 -0
- sqlspec/utils/fixtures.py +250 -0
- sqlspec/utils/logging.py +172 -0
- sqlspec/utils/module_loader.py +273 -0
- sqlspec/utils/portal.py +325 -0
- sqlspec/utils/schema.py +288 -0
- sqlspec/utils/serializers.py +396 -0
- sqlspec/utils/singleton.py +41 -0
- sqlspec/utils/sync_tools.py +277 -0
- sqlspec/utils/text.py +108 -0
- sqlspec/utils/type_converters.py +99 -0
- sqlspec/utils/type_guards.py +1324 -0
- sqlspec/utils/version.py +444 -0
- sqlspec-0.32.0.dist-info/METADATA +202 -0
- sqlspec-0.32.0.dist-info/RECORD +262 -0
- sqlspec-0.32.0.dist-info/WHEEL +4 -0
- sqlspec-0.32.0.dist-info/entry_points.txt +2 -0
- sqlspec-0.32.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1336 @@
|
|
|
1
|
+
"""PostgreSQL psycopg driver implementation."""
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import io
|
|
5
|
+
from contextlib import AsyncExitStack, ExitStack
|
|
6
|
+
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, cast
|
|
7
|
+
|
|
8
|
+
import psycopg
|
|
9
|
+
from psycopg import sql as psycopg_sql
|
|
10
|
+
|
|
11
|
+
from sqlspec.adapters.psycopg._types import PsycopgAsyncConnection, PsycopgSyncConnection
|
|
12
|
+
from sqlspec.core import (
|
|
13
|
+
SQL,
|
|
14
|
+
DriverParameterProfile,
|
|
15
|
+
ParameterStyle,
|
|
16
|
+
ParameterStyleConfig,
|
|
17
|
+
SQLResult,
|
|
18
|
+
StackOperation,
|
|
19
|
+
StackResult,
|
|
20
|
+
Statement,
|
|
21
|
+
StatementConfig,
|
|
22
|
+
StatementStack,
|
|
23
|
+
build_statement_config_from_profile,
|
|
24
|
+
get_cache_config,
|
|
25
|
+
is_copy_from_operation,
|
|
26
|
+
is_copy_operation,
|
|
27
|
+
is_copy_to_operation,
|
|
28
|
+
register_driver_profile,
|
|
29
|
+
)
|
|
30
|
+
from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase
|
|
31
|
+
from sqlspec.driver._common import StackExecutionObserver, describe_stack_statement
|
|
32
|
+
from sqlspec.exceptions import (
|
|
33
|
+
CheckViolationError,
|
|
34
|
+
DatabaseConnectionError,
|
|
35
|
+
DataError,
|
|
36
|
+
ForeignKeyViolationError,
|
|
37
|
+
IntegrityError,
|
|
38
|
+
NotNullViolationError,
|
|
39
|
+
OperationalError,
|
|
40
|
+
SQLParsingError,
|
|
41
|
+
SQLSpecError,
|
|
42
|
+
StackExecutionError,
|
|
43
|
+
TransactionError,
|
|
44
|
+
UniqueViolationError,
|
|
45
|
+
)
|
|
46
|
+
from sqlspec.utils.logging import get_logger
|
|
47
|
+
from sqlspec.utils.serializers import to_json
|
|
48
|
+
from sqlspec.utils.type_converters import build_json_list_converter, build_json_tuple_converter
|
|
49
|
+
|
|
50
|
+
if TYPE_CHECKING:
|
|
51
|
+
from collections.abc import Callable
|
|
52
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
53
|
+
|
|
54
|
+
from sqlspec.builder import QueryBuilder
|
|
55
|
+
from sqlspec.core import ArrowResult
|
|
56
|
+
from sqlspec.driver._async import AsyncDataDictionaryBase
|
|
57
|
+
from sqlspec.driver._common import ExecutionResult
|
|
58
|
+
from sqlspec.driver._sync import SyncDataDictionaryBase
|
|
59
|
+
from sqlspec.storage import (
|
|
60
|
+
AsyncStoragePipeline,
|
|
61
|
+
StorageBridgeJob,
|
|
62
|
+
StorageDestination,
|
|
63
|
+
StorageFormat,
|
|
64
|
+
StorageTelemetry,
|
|
65
|
+
SyncStoragePipeline,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
class _PipelineDriver(Protocol):
    """Structural interface the pipeline mixin expects from a concrete driver.

    Lets PsycopgPipelineMixin call back into the sync or async driver that
    mixes it in without inheriting from either driver base class.
    """

    # Active statement configuration of the driver instance.
    statement_config: StatementConfig

    # Mirrors the driver base-class method: normalizes a statement plus
    # parameters/kwargs into a prepared SQL object.
    def prepare_statement(
        self,
        statement: "SQL | Statement | QueryBuilder",
        parameters: Any,
        *,
        statement_config: StatementConfig,
        kwargs: dict[str, Any],
    ) -> SQL: ...

    # Private driver helper: compiles a SQL object to (sql_text, parameters).
    def _get_compiled_sql(self, statement: SQL, statement_config: StatementConfig) -> tuple[str, Any]: ...
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
# Public API of this module.
__all__ = (
    "PsycopgAsyncCursor",
    "PsycopgAsyncDriver",
    "PsycopgAsyncExceptionHandler",
    "PsycopgSyncCursor",
    "PsycopgSyncDriver",
    "PsycopgSyncExceptionHandler",
    "build_psycopg_statement_config",
    "psycopg_statement_config",
)

# Module-level logger shared by the sync and async psycopg drivers.
logger = get_logger("adapters.psycopg")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _psycopg_pipeline_supported() -> bool:
    """Check whether the installed libpq exposes pipeline support."""
    caps = getattr(psycopg, "capabilities", None)
    if caps is not None:
        try:
            return bool(caps.has_pipeline())
        except Exception:  # pragma: no cover - defensive guard for unexpected capability implementations
            pass
    return False
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class _PreparedStackOperation(NamedTuple):
    """Precompiled stack operation metadata for psycopg pipeline execution."""

    operation_index: int  # position of the operation within the originating stack
    operation: "StackOperation"  # the original, unmodified stack operation
    statement: "SQL"  # prepared SQL statement object
    sql: str  # compiled SQL text ready for the cursor
    parameters: "tuple[Any, ...] | dict[str, Any] | None"  # compiled bind parameters, if any
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class _PipelineCursorEntry(NamedTuple):
    """Cursor pending result data for psycopg pipeline execution."""

    prepared: "_PreparedStackOperation"  # the compiled operation this cursor was issued for
    cursor: Any  # psycopg cursor queued in the pipeline for this operation
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class PsycopgPipelineMixin:
    """Shared helpers for psycopg sync/async pipeline execution."""

    __slots__ = ()

    def _prepare_pipeline_operations(self, stack: "StatementStack") -> "list[_PreparedStackOperation] | None":
        """Normalize every operation in the stack for pipeline execution.

        Returns the full list of compiled operations, or None as soon as any
        single operation cannot be run in pipeline mode (caller falls back).
        """
        results: list[_PreparedStackOperation] = []
        for position, op in enumerate(stack.operations):
            entry = self._normalize_stack_operation_for_pipeline(position, op)
            if entry is None:
                return None
            results.append(entry)
        return results

    def _normalize_stack_operation_for_pipeline(
        self, index: int, operation: "StackOperation"
    ) -> "_PreparedStackOperation | None":
        """Compile one execute-style operation; return None for unsupported shapes."""
        if operation.method != "execute":
            return None

        extra_kwargs: dict[str, Any] = dict(operation.keyword_arguments or {})
        override_config = extra_kwargs.pop("statement_config", None)
        driver = cast("_PipelineDriver", self)
        active_config = override_config or driver.statement_config

        prepared_sql = driver.prepare_statement(
            operation.statement, operation.arguments, statement_config=active_config, kwargs=extra_kwargs
        )

        # Scripts and executemany batches cannot be queued in a pipeline.
        if prepared_sql.is_script or prepared_sql.is_many:
            return None

        compiled_text, compiled_params = driver._get_compiled_sql(  # pyright: ignore[reportPrivateUsage]
            prepared_sql, active_config
        )
        return _PreparedStackOperation(
            operation_index=index,
            operation=operation,
            statement=prepared_sql,
            sql=compiled_text,
            parameters=compiled_params,
        )
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# Numeric libpq transaction-status values read via connection.info.transaction_status.
# NOTE(review): these appear to mirror psycopg's pq.TransactionStatus enum
# (IDLE=0 .. UNKNOWN=4) — confirm against the installed psycopg version.
TRANSACTION_STATUS_IDLE = 0
TRANSACTION_STATUS_ACTIVE = 1
TRANSACTION_STATUS_INTRANS = 2
TRANSACTION_STATUS_INERROR = 3
TRANSACTION_STATUS_UNKNOWN = 4
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def _compose_table_identifier(table: str) -> "psycopg_sql.Composed":
    """Turn a possibly dot-qualified table name into a safely quoted identifier."""
    identifiers = [psycopg_sql.Identifier(segment) for segment in table.split(".") if segment]
    if not identifiers:
        msg = "Table name must not be empty"
        raise SQLSpecError(msg)
    return psycopg_sql.SQL(".").join(identifiers)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _build_copy_from_command(table: str, columns: "list[str]") -> "psycopg_sql.Composed":
    """Build a ``COPY ... FROM STDIN`` command with quoted table and column names."""
    quoted_columns = [psycopg_sql.Identifier(name) for name in columns]
    return psycopg_sql.SQL("COPY {} ({}) FROM STDIN").format(
        _compose_table_identifier(table), psycopg_sql.SQL(", ").join(quoted_columns)
    )
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _build_truncate_command(table: str) -> "psycopg_sql.Composed":
    """Build a ``TRUNCATE TABLE`` command for the given (possibly qualified) table."""
    target = _compose_table_identifier(table)
    return psycopg_sql.SQL("TRUNCATE TABLE {}").format(target)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class PsycopgSyncCursor:
    """Context manager that opens a psycopg cursor and closes it on exit."""

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: PsycopgSyncConnection) -> None:
        self.connection = connection
        self.cursor: Any | None = None

    def __enter__(self) -> Any:
        # Lazily open the cursor on entry so construction stays cheap.
        self.cursor = self.connection.cursor()
        return self.cursor

    def __exit__(self, *_: Any) -> None:
        if self.cursor is None:
            return
        self.cursor.close()
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class PsycopgSyncExceptionHandler:
    """Context manager for handling PostgreSQL psycopg database exceptions.

    Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
    for better error handling in application code.
    """

    __slots__ = ()

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Only translate psycopg errors; anything else propagates untouched.
        if exc_type is not None and issubclass(exc_type, psycopg.Error):
            self._map_postgres_exception(exc_val)

    def _map_postgres_exception(self, e: Any) -> None:
        """Map PostgreSQL exception to SQLSpec exception.

        Args:
            e: psycopg.Error instance

        Raises:
            Specific SQLSpec exception based on SQLSTATE code
        """
        error_code = getattr(e, "sqlstate", None)
        if not error_code:
            self._raise_generic_error(e, None)
            return

        # Exact SQLSTATE codes take priority over class (two-character) prefixes.
        exact_handlers = {
            "23505": self._raise_unique_violation,
            "23503": self._raise_foreign_key_violation,
            "23502": self._raise_not_null_violation,
            "23514": self._raise_check_violation,
        }
        exact = exact_handlers.get(error_code)
        if exact is not None:
            exact(e, error_code)
            return

        # SQLSTATE class prefixes, checked in the same order as the original chain.
        prefix_handlers = (
            (("23",), self._raise_integrity_error),
            (("42",), self._raise_parsing_error),
            (("08",), self._raise_connection_error),
            (("40",), self._raise_transaction_error),
            (("22",), self._raise_data_error),
            (("53", "54", "55", "57", "58"), self._raise_operational_error),
        )
        for prefixes, handler in prefix_handlers:
            if error_code.startswith(prefixes):
                handler(e, error_code)
                return

        self._raise_generic_error(e, error_code)

    def _raise_unique_violation(self, e: Any, code: str) -> None:
        raise UniqueViolationError(f"PostgreSQL unique constraint violation [{code}]: {e}") from e

    def _raise_foreign_key_violation(self, e: Any, code: str) -> None:
        raise ForeignKeyViolationError(f"PostgreSQL foreign key constraint violation [{code}]: {e}") from e

    def _raise_not_null_violation(self, e: Any, code: str) -> None:
        raise NotNullViolationError(f"PostgreSQL not-null constraint violation [{code}]: {e}") from e

    def _raise_check_violation(self, e: Any, code: str) -> None:
        raise CheckViolationError(f"PostgreSQL check constraint violation [{code}]: {e}") from e

    def _raise_integrity_error(self, e: Any, code: str) -> None:
        raise IntegrityError(f"PostgreSQL integrity constraint violation [{code}]: {e}") from e

    def _raise_parsing_error(self, e: Any, code: str) -> None:
        raise SQLParsingError(f"PostgreSQL SQL syntax error [{code}]: {e}") from e

    def _raise_connection_error(self, e: Any, code: str) -> None:
        raise DatabaseConnectionError(f"PostgreSQL connection error [{code}]: {e}") from e

    def _raise_transaction_error(self, e: Any, code: str) -> None:
        raise TransactionError(f"PostgreSQL transaction error [{code}]: {e}") from e

    def _raise_data_error(self, e: Any, code: str) -> None:
        raise DataError(f"PostgreSQL data error [{code}]: {e}") from e

    def _raise_operational_error(self, e: Any, code: str) -> None:
        raise OperationalError(f"PostgreSQL operational error [{code}]: {e}") from e

    def _raise_generic_error(self, e: Any, code: "str | None") -> None:
        message = f"PostgreSQL database error [{code}]: {e}" if code else f"PostgreSQL database error: {e}"
        raise SQLSpecError(message) from e
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class PsycopgSyncDriver(PsycopgPipelineMixin, SyncDriverAdapterBase):
|
|
316
|
+
"""PostgreSQL psycopg synchronous driver.
|
|
317
|
+
|
|
318
|
+
Provides synchronous database operations for PostgreSQL using psycopg3.
|
|
319
|
+
Supports SQL statement execution with parameter binding, transaction
|
|
320
|
+
management, result processing with column metadata, parameter style
|
|
321
|
+
conversion, PostgreSQL arrays and JSON handling, COPY operations for
|
|
322
|
+
bulk data transfer, and PostgreSQL-specific error handling.
|
|
323
|
+
"""
|
|
324
|
+
|
|
325
|
+
__slots__ = ("_data_dictionary",)
|
|
326
|
+
dialect = "postgres"
|
|
327
|
+
|
|
328
|
+
def __init__(
|
|
329
|
+
self,
|
|
330
|
+
connection: PsycopgSyncConnection,
|
|
331
|
+
statement_config: "StatementConfig | None" = None,
|
|
332
|
+
driver_features: "dict[str, Any] | None" = None,
|
|
333
|
+
) -> None:
|
|
334
|
+
if statement_config is None:
|
|
335
|
+
cache_config = get_cache_config()
|
|
336
|
+
default_config = psycopg_statement_config.replace(
|
|
337
|
+
enable_caching=cache_config.compiled_cache_enabled,
|
|
338
|
+
enable_parsing=True,
|
|
339
|
+
enable_validation=True,
|
|
340
|
+
dialect="postgres",
|
|
341
|
+
)
|
|
342
|
+
statement_config = default_config
|
|
343
|
+
|
|
344
|
+
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
|
|
345
|
+
self._data_dictionary: SyncDataDictionaryBase | None = None
|
|
346
|
+
|
|
347
|
+
    def with_cursor(self, connection: PsycopgSyncConnection) -> PsycopgSyncCursor:
        """Create context manager for PostgreSQL cursor.

        Args:
            connection: Open psycopg connection to open the cursor on.

        Returns:
            PsycopgSyncCursor that opens a cursor on enter and closes it on exit.
        """
        return PsycopgSyncCursor(connection)
|
|
350
|
+
|
|
351
|
+
def begin(self) -> None:
|
|
352
|
+
"""Begin a database transaction on the current connection."""
|
|
353
|
+
try:
|
|
354
|
+
if hasattr(self.connection, "autocommit") and not self.connection.autocommit:
|
|
355
|
+
pass
|
|
356
|
+
else:
|
|
357
|
+
self.connection.autocommit = False
|
|
358
|
+
except Exception as e:
|
|
359
|
+
msg = f"Failed to begin transaction: {e}"
|
|
360
|
+
raise SQLSpecError(msg) from e
|
|
361
|
+
|
|
362
|
+
def rollback(self) -> None:
|
|
363
|
+
"""Rollback the current transaction on the current connection."""
|
|
364
|
+
try:
|
|
365
|
+
self.connection.rollback()
|
|
366
|
+
except Exception as e:
|
|
367
|
+
msg = f"Failed to rollback transaction: {e}"
|
|
368
|
+
raise SQLSpecError(msg) from e
|
|
369
|
+
|
|
370
|
+
def commit(self) -> None:
|
|
371
|
+
"""Commit the current transaction on the current connection."""
|
|
372
|
+
try:
|
|
373
|
+
self.connection.commit()
|
|
374
|
+
except Exception as e:
|
|
375
|
+
msg = f"Failed to commit transaction: {e}"
|
|
376
|
+
raise SQLSpecError(msg) from e
|
|
377
|
+
|
|
378
|
+
    def handle_database_exceptions(self) -> "AbstractContextManager[None]":
        """Handle database-specific exceptions and wrap them appropriately.

        Returns:
            Context manager that maps psycopg errors to SQLSpec exceptions.
        """
        return PsycopgSyncExceptionHandler()
|
|
381
|
+
|
|
382
|
+
def _handle_transaction_error_cleanup(self) -> None:
|
|
383
|
+
"""Handle transaction cleanup after database errors."""
|
|
384
|
+
try:
|
|
385
|
+
if hasattr(self.connection, "info") and hasattr(self.connection.info, "transaction_status"):
|
|
386
|
+
status = self.connection.info.transaction_status
|
|
387
|
+
|
|
388
|
+
if status == TRANSACTION_STATUS_INERROR:
|
|
389
|
+
logger.debug("Connection in aborted transaction state, performing rollback")
|
|
390
|
+
self.connection.rollback()
|
|
391
|
+
except Exception as cleanup_error:
|
|
392
|
+
logger.warning("Failed to cleanup transaction state: %s", cleanup_error)
|
|
393
|
+
|
|
394
|
+
def _try_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None":
|
|
395
|
+
"""Hook for PostgreSQL-specific special operations.
|
|
396
|
+
|
|
397
|
+
Args:
|
|
398
|
+
cursor: Psycopg cursor object
|
|
399
|
+
statement: SQL statement to analyze
|
|
400
|
+
|
|
401
|
+
Returns:
|
|
402
|
+
SQLResult if special handling was applied, None otherwise
|
|
403
|
+
"""
|
|
404
|
+
|
|
405
|
+
statement.compile()
|
|
406
|
+
|
|
407
|
+
if is_copy_operation(statement.operation_type):
|
|
408
|
+
return self._handle_copy_operation(cursor, statement)
|
|
409
|
+
|
|
410
|
+
return None
|
|
411
|
+
|
|
412
|
+
    def _handle_copy_operation(self, cursor: Any, statement: "SQL") -> "SQLResult":
        """Handle PostgreSQL COPY operations using copy_expert.

        Dispatches on the statement's operation type: COPY ... FROM STDIN
        writes the provided payload into the copy context; COPY ... TO STDOUT
        drains the copy context into a single string; any other COPY form
        (e.g. a server-side file) is executed directly on the cursor.

        Args:
            cursor: Psycopg cursor object
            statement: SQL statement with COPY operation

        Returns:
            SQLResult with COPY operation results
        """

        sql = statement.sql
        operation_type = statement.operation_type
        copy_data = statement.parameters
        # A single-element parameter list is unwrapped so the raw payload is used.
        if isinstance(copy_data, list) and len(copy_data) == 1:
            copy_data = copy_data[0]

        if is_copy_from_operation(operation_type):
            # Normalize the payload to a file-like object for uniform reading.
            if isinstance(copy_data, (str, bytes)):
                data_file = io.StringIO(copy_data) if isinstance(copy_data, str) else io.BytesIO(copy_data)
            elif hasattr(copy_data, "read"):
                data_file = copy_data
            else:
                data_file = io.StringIO(str(copy_data))

            with cursor.copy(sql) as copy_ctx:
                data_to_write = data_file.read() if hasattr(data_file, "read") else str(copy_data)  # pyright: ignore
                # psycopg's copy context expects bytes; encode text payloads.
                if isinstance(data_to_write, str):
                    data_to_write = data_to_write.encode()
                copy_ctx.write(data_to_write)

            # rowcount may be -1 for COPY; clamp to zero for the result object.
            rows_affected = max(cursor.rowcount, 0)

            return SQLResult(
                data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"}
            )

        if is_copy_to_operation(operation_type):
            # Drain the copy context row by row and join into one string.
            output_data: list[str] = []
            with cursor.copy(sql) as copy_ctx:
                output_data.extend(row.decode() if isinstance(row, bytes) else str(row) for row in copy_ctx)

            exported_data = "".join(output_data)

            return SQLResult(
                data=[{"copy_output": exported_data}],
                rows_affected=0,
                statement=statement,
                metadata={"copy_operation": "TO_STDOUT"},
            )

        # Neither FROM STDIN nor TO STDOUT: run the COPY statement directly
        # (presumably a server-side file target — verify against callers).
        cursor.execute(sql)
        rows_affected = max(cursor.rowcount, 0)

        return SQLResult(
            data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"}
        )
|
|
469
|
+
|
|
470
|
+
def _execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Run a multi-statement SQL script sequentially on one cursor.

    Args:
        cursor: Database cursor
        statement: SQL statement containing multiple commands

    Returns:
        ExecutionResult with script execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    script_parts = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

    executed = 0
    for part in script_parts:
        # The same bound parameters (if any) are reused for every statement.
        if prepared_parameters:
            cursor.execute(part, prepared_parameters)
        else:
            cursor.execute(part)
        executed += 1

    return self.create_execution_result(
        cursor, statement_count=len(script_parts), successful_statements=executed, is_script_result=True
    )
|
|
496
|
+
|
|
497
|
+
def execute_stack(self, stack: "StatementStack", *, continue_on_error: bool = False) -> "tuple[StackResult, ...]":
    """Execute a StatementStack using psycopg pipeline mode when supported.

    Falls back to the base sequential implementation when the stack is not a
    StatementStack or is empty, when native stack execution is disabled, when
    the installed psycopg build lacks pipeline support, or when
    continue-on-error semantics are requested (the pipeline path here is
    fail-fast only).

    Args:
        stack: Stack of statements to execute.
        continue_on_error: When True, always delegates to the sequential path.

    Returns:
        Tuple of StackResult objects, one per stack operation.
    """
    if (
        not isinstance(stack, StatementStack)
        or not stack
        or self.stack_native_disabled
        or not _psycopg_pipeline_supported()
        or continue_on_error
    ):
        return super().execute_stack(stack, continue_on_error=continue_on_error)

    prepared_ops = self._prepare_pipeline_operations(stack)
    if prepared_ops is None:
        # Operations that cannot be represented as pipeline commands fall back.
        return super().execute_stack(stack, continue_on_error=continue_on_error)

    return self._execute_stack_pipeline(stack, prepared_ops)
|
|
514
|
+
|
|
515
|
+
def _execute_stack_pipeline(
    self, stack: "StatementStack", prepared_ops: "list[_PreparedStackOperation]"
) -> "tuple[StackResult, ...]":
    """Execute prepared stack operations inside a psycopg pipeline.

    Begins a transaction if one is not already active, queues each operation
    on its own cursor inside ``connection.pipeline()``, synchronizes the
    pipeline once, then materializes all results. Fail-fast: the first
    failing operation raises StackExecutionError, and any transaction this
    method started is rolled back.

    Args:
        stack: The stack being executed (used for observability).
        prepared_ops: Pre-compiled operations to enqueue.

    Returns:
        Tuple of StackResult objects in operation order.

    Raises:
        StackExecutionError: If any queued statement fails.
    """
    results: list[StackResult] = []
    started_transaction = False

    with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True):
        try:
            if not self._connection_in_transaction():
                self.begin()
                started_transaction = True

            with ExitStack() as resource_stack:
                pipeline = resource_stack.enter_context(self.connection.pipeline())
                pending: list[_PipelineCursorEntry] = []

                for prepared in prepared_ops:
                    # Each operation gets its own exception mapper and cursor;
                    # all stay open until the ExitStack unwinds so results can
                    # be fetched after the pipeline sync.
                    exception_ctx = self.handle_database_exceptions()
                    resource_stack.enter_context(exception_ctx)
                    cursor = resource_stack.enter_context(self.with_cursor(self.connection))

                    try:
                        if prepared.parameters:
                            cursor.execute(prepared.sql, prepared.parameters)
                        else:
                            cursor.execute(prepared.sql)
                    except Exception as exc:
                        stack_error = StackExecutionError(
                            prepared.operation_index,
                            describe_stack_statement(prepared.operation.statement),
                            exc,
                            adapter=type(self).__name__,
                            mode="fail-fast",
                        )
                        raise stack_error from exc

                    pending.append(_PipelineCursorEntry(prepared=prepared, cursor=cursor))

                # Flush queued commands and surface any server-side errors.
                pipeline.sync()

                results.extend(self._build_pipeline_stack_result(entry) for entry in pending)

            if started_transaction:
                self.commit()
        except Exception:
            if started_transaction:
                try:
                    self.rollback()
                except Exception as rollback_error:  # pragma: no cover - diagnostics only
                    logger.debug("Rollback after psycopg pipeline failure failed: %s", rollback_error)
            raise

    return tuple(results)
|
|
568
|
+
|
|
569
|
+
def _build_pipeline_stack_result(self, entry: "_PipelineCursorEntry") -> StackResult:
    """Materialize one executed pipeline cursor into a StackResult.

    Args:
        entry: Prepared operation paired with its executed cursor.

    Returns:
        StackResult wrapping the statement's SQLResult.
    """
    statement = entry.prepared.statement
    cursor = entry.cursor

    if statement.returns_rows():
        fetched_data = cursor.fetchall()
        column_names = [col.name for col in cursor.description or []]
        execution_result = self.create_execution_result(
            cursor,
            selected_data=fetched_data,
            column_names=column_names,
            data_row_count=len(fetched_data),
            is_select_result=True,
        )
    else:
        # cursor.rowcount can be -1 when unknown; clamp to zero.
        affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
        execution_result = self.create_execution_result(cursor, rowcount_override=affected_rows)

    sql_result = self.build_statement_result(statement, execution_result)
    return StackResult.from_sql_result(sql_result)
|
|
589
|
+
|
|
590
|
+
def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute SQL with multiple parameter sets.

    Args:
        cursor: Database cursor
        statement: SQL statement with parameter list

    Returns:
        ExecutionResult with batch execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

    # An empty batch is a no-op; report zero affected rows.
    if not prepared_parameters:
        return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)

    cursor.executemany(sql, prepared_parameters)

    reported = cursor.rowcount
    affected = reported if reported and reported > 0 else 0

    return self.create_execution_result(cursor, rowcount_override=affected, is_many_result=True)
|
|
610
|
+
|
|
611
|
+
def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute single SQL statement.

    Args:
        cursor: Database cursor
        statement: SQL statement to execute

    Returns:
        ExecutionResult with statement execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

    if prepared_parameters:
        cursor.execute(sql, prepared_parameters)
    else:
        cursor.execute(sql)

    # Non-row-returning statements only report an affected-row count.
    if not statement.returns_rows():
        reported = cursor.rowcount
        return self.create_execution_result(
            cursor, rowcount_override=reported if reported and reported > 0 else 0
        )

    rows = cursor.fetchall()
    columns = [col.name for col in cursor.description or []]
    return self.create_execution_result(
        cursor, selected_data=rows, column_names=columns, data_row_count=len(rows), is_select_result=True
    )
|
|
642
|
+
|
|
643
|
+
def select_to_storage(
    self,
    statement: "SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: Any,
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, Any] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Execute a query and stream Arrow results to storage (sync).

    Args:
        statement: Query to execute.
        destination: Storage location the Arrow payload is written to.
        *parameters: Positional query parameters.
        statement_config: Optional statement configuration override.
        partitioner: Optional partition metadata attached to telemetry.
        format_hint: Optional output format override for the storage writer.
        telemetry: Optional caller-supplied telemetry merged into the job.
        **kwargs: Extra arguments forwarded to ``select_to_arrow``.

    Returns:
        StorageBridgeJob describing the completed write.
    """
    self._require_capability("arrow_export_enabled")
    arrow_result = self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    sync_pipeline: SyncStoragePipeline = cast("SyncStoragePipeline", self._storage_pipeline())
    telemetry_payload = self._write_result_to_storage_sync(
        arrow_result, destination, format_hint=format_hint, pipeline=sync_pipeline
    )
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
|
|
665
|
+
|
|
666
|
+
def load_from_arrow(
    self,
    table: str,
    source: "ArrowResult | Any",
    *,
    partitioner: "dict[str, Any] | None" = None,
    overwrite: bool = False,
    telemetry: "StorageTelemetry | None" = None,
) -> "StorageBridgeJob":
    """Load Arrow data into PostgreSQL using COPY.

    Args:
        table: Target table name.
        source: ArrowResult or Arrow-compatible table to ingest.
        partitioner: Optional partition metadata attached to telemetry.
        overwrite: When True, truncate the table before loading.
        telemetry: Optional caller-supplied telemetry merged into the job.

    Returns:
        StorageBridgeJob describing the completed ingest.
    """
    self._require_capability("arrow_import_enabled")
    arrow_table = self._coerce_arrow_table(source)
    if overwrite:
        self._truncate_table_sync(table)
    columns, records = self._arrow_table_to_rows(arrow_table)
    if records:
        copy_sql = _build_copy_from_command(table, columns)
        # Entered in order: exception mapper, cursor, then COPY context, so
        # they unwind in reverse (COPY finishes before the cursor closes).
        with ExitStack() as stack:
            stack.enter_context(self.handle_database_exceptions())
            cursor = stack.enter_context(self.with_cursor(self.connection))
            copy_ctx = stack.enter_context(cursor.copy(copy_sql))
            for record in records:
                copy_ctx.write_row(record)
    telemetry_payload = self._build_ingest_telemetry(arrow_table)
    telemetry_payload["destination"] = table
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
|
|
694
|
+
|
|
695
|
+
def load_from_storage(
    self,
    table: str,
    source: "StorageDestination",
    *,
    file_format: "StorageFormat",
    partitioner: "dict[str, Any] | None" = None,
    overwrite: bool = False,
) -> "StorageBridgeJob":
    """Load staged artifacts into PostgreSQL via COPY.

    Args:
        table: Target table name.
        source: Storage location containing the staged artifact.
        file_format: Format of the staged artifact.
        partitioner: Optional partition metadata attached to telemetry.
        overwrite: When True, truncate the table before loading.

    Returns:
        StorageBridgeJob describing the completed ingest.
    """
    arrow_table, inbound = self._read_arrow_from_storage_sync(source, file_format=file_format)
    # Delegate the ingest to the Arrow path; read telemetry is forwarded.
    return self.load_from_arrow(table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound)
|
|
708
|
+
|
|
709
|
+
@property
def data_dictionary(self) -> "SyncDataDictionaryBase":
    """Get the data dictionary for this driver.

    Lazily instantiated on first access and cached on the instance.

    Returns:
        Data dictionary instance for metadata queries
    """
    if self._data_dictionary is None:
        # Imported lazily so the data dictionary module is only loaded
        # when metadata queries are actually used.
        from sqlspec.adapters.psycopg.data_dictionary import PostgresSyncDataDictionary

        self._data_dictionary = PostgresSyncDataDictionary()
    return self._data_dictionary
|
|
721
|
+
|
|
722
|
+
def _truncate_table_sync(self, table: str) -> None:
    """Remove all rows from ``table`` with a TRUNCATE statement."""
    command = _build_truncate_command(table)
    with self.with_cursor(self.connection) as cursor:
        with self.handle_database_exceptions():
            cursor.execute(command)
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
class PsycopgAsyncCursor:
    """Async context manager for PostgreSQL psycopg cursor management."""

    __slots__ = ("connection", "cursor")

    def __init__(self, connection: "PsycopgAsyncConnection") -> None:
        self.connection = connection
        # Populated on __aenter__; None while the context is not active.
        self.cursor: Any | None = None

    async def __aenter__(self) -> Any:
        # psycopg3 async connections create cursors synchronously.
        self.cursor = self.connection.cursor()
        return self.cursor

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        _ = (exc_type, exc_val, exc_tb)
        open_cursor = self.cursor
        if open_cursor is not None:
            await open_cursor.close()
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
class PsycopgAsyncExceptionHandler:
    """Async context manager for handling PostgreSQL psycopg database exceptions.

    Maps PostgreSQL SQLSTATE error codes to specific SQLSpec exceptions
    for better error handling in application code.
    """

    __slots__ = ()

    async def __aenter__(self) -> None:
        return None

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Only psycopg errors are translated; anything else propagates as-is.
        if exc_type is None:
            return
        if issubclass(exc_type, psycopg.Error):
            self._map_postgres_exception(exc_val)

    def _map_postgres_exception(self, e: Any) -> None:
        """Map PostgreSQL exception to SQLSpec exception.

        Args:
            e: psycopg.Error instance

        Raises:
            Specific SQLSpec exception based on SQLSTATE code
        """
        error_code = getattr(e, "sqlstate", None)

        if not error_code:
            self._raise_generic_error(e, None)
            return

        # Specific integrity codes first, then broader SQLSTATE class prefixes.
        if error_code == "23505":
            self._raise_unique_violation(e, error_code)
        elif error_code == "23503":
            self._raise_foreign_key_violation(e, error_code)
        elif error_code == "23502":
            self._raise_not_null_violation(e, error_code)
        elif error_code == "23514":
            self._raise_check_violation(e, error_code)
        elif error_code.startswith("23"):
            # Class 23: other integrity constraint violations.
            self._raise_integrity_error(e, error_code)
        elif error_code.startswith("42"):
            # Class 42: syntax errors / access rule violations.
            self._raise_parsing_error(e, error_code)
        elif error_code.startswith("08"):
            # Class 08: connection exceptions.
            self._raise_connection_error(e, error_code)
        elif error_code.startswith("40"):
            # Class 40: transaction rollback (serialization failure, deadlock).
            self._raise_transaction_error(e, error_code)
        elif error_code.startswith("22"):
            # Class 22: data exceptions (bad casts, out-of-range values).
            self._raise_data_error(e, error_code)
        elif error_code.startswith(("53", "54", "55", "57", "58")):
            # Resource exhaustion / operator intervention / system errors.
            self._raise_operational_error(e, error_code)
        else:
            self._raise_generic_error(e, error_code)

    def _raise_unique_violation(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL unique constraint violation [{code}]: {e}"
        raise UniqueViolationError(msg) from e

    def _raise_foreign_key_violation(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL foreign key constraint violation [{code}]: {e}"
        raise ForeignKeyViolationError(msg) from e

    def _raise_not_null_violation(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL not-null constraint violation [{code}]: {e}"
        raise NotNullViolationError(msg) from e

    def _raise_check_violation(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL check constraint violation [{code}]: {e}"
        raise CheckViolationError(msg) from e

    def _raise_integrity_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL integrity constraint violation [{code}]: {e}"
        raise IntegrityError(msg) from e

    def _raise_parsing_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL SQL syntax error [{code}]: {e}"
        raise SQLParsingError(msg) from e

    def _raise_connection_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL connection error [{code}]: {e}"
        raise DatabaseConnectionError(msg) from e

    def _raise_transaction_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL transaction error [{code}]: {e}"
        raise TransactionError(msg) from e

    def _raise_data_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL data error [{code}]: {e}"
        raise DataError(msg) from e

    def _raise_operational_error(self, e: Any, code: str) -> None:
        msg = f"PostgreSQL operational error [{code}]: {e}"
        raise OperationalError(msg) from e

    def _raise_generic_error(self, e: Any, code: "str | None") -> None:
        msg = f"PostgreSQL database error [{code}]: {e}" if code else f"PostgreSQL database error: {e}"
        raise SQLSpecError(msg) from e
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
class PsycopgAsyncDriver(PsycopgPipelineMixin, AsyncDriverAdapterBase):
|
|
849
|
+
"""PostgreSQL psycopg asynchronous driver.
|
|
850
|
+
|
|
851
|
+
Provides asynchronous database operations for PostgreSQL using psycopg3.
|
|
852
|
+
Supports async SQL statement execution with parameter binding, async
|
|
853
|
+
transaction management, async result processing with column metadata,
|
|
854
|
+
parameter style conversion, PostgreSQL arrays and JSON handling, COPY
|
|
855
|
+
operations for bulk data transfer, PostgreSQL-specific error handling,
|
|
856
|
+
and async pub/sub support.
|
|
857
|
+
"""
|
|
858
|
+
|
|
859
|
+
__slots__ = ("_data_dictionary",)
|
|
860
|
+
dialect = "postgres"
|
|
861
|
+
|
|
862
|
+
def __init__(
    self,
    connection: "PsycopgAsyncConnection",
    statement_config: "StatementConfig | None" = None,
    driver_features: "dict[str, Any] | None" = None,
) -> None:
    """Initialize the async psycopg driver.

    Args:
        connection: Open psycopg async connection.
        statement_config: Optional statement configuration. When omitted,
            a default postgres configuration is derived from the global
            cache settings with parsing and validation enabled.
        driver_features: Optional adapter feature flags.
    """
    if statement_config is None:
        cache_config = get_cache_config()
        default_config = psycopg_statement_config.replace(
            enable_caching=cache_config.compiled_cache_enabled,
            enable_parsing=True,
            enable_validation=True,
            dialect="postgres",
        )
        statement_config = default_config

    super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
    # Lazily populated by the data_dictionary property.
    self._data_dictionary: AsyncDataDictionaryBase | None = None
|
|
880
|
+
|
|
881
|
+
def with_cursor(self, connection: "PsycopgAsyncConnection") -> "PsycopgAsyncCursor":
    """Create async context manager for PostgreSQL cursor.

    Args:
        connection: Connection the cursor will be opened on.

    Returns:
        Context manager that yields a new async cursor on entry.
    """
    return PsycopgAsyncCursor(connection)
|
|
884
|
+
|
|
885
|
+
async def begin(self) -> None:
    """Begin a database transaction on the current connection.

    Raises:
        SQLSpecError: If toggling autocommit fails.
    """
    try:
        autocommit_flag = getattr(self.connection, "autocommit", None)
        if isinstance(autocommit_flag, bool) and not autocommit_flag:
            # Autocommit is already off, so psycopg opens a transaction
            # implicitly on the next statement; nothing more to do.
            return
        await self.connection.set_autocommit(False)
    except Exception as e:
        msg = f"Failed to begin transaction: {e}"
        raise SQLSpecError(msg) from e
|
|
895
|
+
|
|
896
|
+
async def rollback(self) -> None:
    """Rollback the current transaction on the current connection."""
    try:
        await self.connection.rollback()
    except Exception as exc:
        # Wrap driver-level failures in the SQLSpec error hierarchy.
        raise SQLSpecError(f"Failed to rollback transaction: {exc}") from exc
|
|
903
|
+
|
|
904
|
+
async def commit(self) -> None:
    """Commit the current transaction on the current connection."""
    try:
        await self.connection.commit()
    except Exception as exc:
        # Wrap driver-level failures in the SQLSpec error hierarchy.
        raise SQLSpecError(f"Failed to commit transaction: {exc}") from exc
|
|
911
|
+
|
|
912
|
+
def handle_database_exceptions(self) -> "AbstractAsyncContextManager[None]":
    """Handle database-specific exceptions and wrap them appropriately.

    Returns:
        Async context manager that maps psycopg errors to SQLSpec errors.
    """
    return PsycopgAsyncExceptionHandler()
|
|
915
|
+
|
|
916
|
+
async def _handle_transaction_error_cleanup_async(self) -> None:
    """Handle async transaction cleanup after database errors.

    Rolls back the connection when it is stuck in an aborted transaction
    state; best-effort — cleanup failures are logged, never raised, so the
    original database error is not masked.
    """
    try:
        if hasattr(self.connection, "info") and hasattr(self.connection.info, "transaction_status"):
            status = self.connection.info.transaction_status

            if status == TRANSACTION_STATUS_INERROR:
                # An aborted transaction blocks further statements until
                # it is rolled back.
                logger.debug("Connection in aborted transaction state, performing async rollback")
                await self.connection.rollback()
    except Exception as cleanup_error:
        logger.warning("Failed to cleanup transaction state: %s", cleanup_error)
|
|
927
|
+
|
|
928
|
+
async def _try_special_handling(self, cursor: Any, statement: "SQL") -> "SQLResult | None":
    """Hook for PostgreSQL-specific special operations.

    Args:
        cursor: Psycopg async cursor object
        statement: SQL statement to analyze

    Returns:
        SQLResult if special handling was applied, None otherwise
    """
    # Compile first so operation_type reflects the processed statement.
    statement.compile()

    if is_copy_operation(statement.operation_type):
        return await self._handle_copy_operation_async(cursor, statement)

    return None
|
|
945
|
+
|
|
946
|
+
async def _handle_copy_operation_async(self, cursor: Any, statement: "SQL") -> "SQLResult":
    """Handle PostgreSQL COPY operations (async).

    COPY ... FROM STDIN streams the statement's parameter payload to the
    server, COPY ... TO STDOUT collects the server's output rows, and any
    other COPY form (server-side file COPY) is executed directly.

    Args:
        cursor: Psycopg async cursor object
        statement: SQL statement with COPY operation

    Returns:
        SQLResult with COPY operation results
    """
    sql = statement.sql
    sql_upper = sql.upper()
    operation_type = statement.operation_type
    copy_data = statement.parameters
    # A single-element parameter list carries the payload itself.
    if isinstance(copy_data, list) and len(copy_data) == 1:
        copy_data = copy_data[0]

    if is_copy_from_operation(operation_type) and "FROM STDIN" in sql_upper:
        if isinstance(copy_data, (str, bytes)):
            data_file = io.StringIO(copy_data) if isinstance(copy_data, str) else io.BytesIO(copy_data)
        elif hasattr(copy_data, "read"):
            # File-like payloads are read as-is.
            data_file = copy_data
        else:
            # Fallback: stringify unknown payloads.
            data_file = io.StringIO(str(copy_data))

        async with cursor.copy(sql) as copy_ctx:
            data_to_write = data_file.read() if hasattr(data_file, "read") else str(copy_data)  # pyright: ignore
            if isinstance(data_to_write, str):
                # The COPY protocol expects bytes.
                data_to_write = data_to_write.encode()
            await copy_ctx.write(data_to_write)

        rows_affected = max(cursor.rowcount, 0)

        return SQLResult(
            data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FROM_STDIN"}
        )

    if is_copy_to_operation(operation_type) and "TO STDOUT" in sql_upper:
        output_data: list[str] = []
        async with cursor.copy(sql) as copy_ctx:
            output_data.extend([row.decode() if isinstance(row, bytes) else str(row) async for row in copy_ctx])

        exported_data = "".join(output_data)

        return SQLResult(
            data=[{"copy_output": exported_data}],
            rows_affected=0,
            statement=statement,
            metadata={"copy_operation": "TO_STDOUT"},
        )

    # Server-side file COPY: no client-side data transfer involved.
    await cursor.execute(sql)
    rows_affected = max(cursor.rowcount, 0)

    return SQLResult(
        data=None, rows_affected=rows_affected, statement=statement, metadata={"copy_operation": "FILE"}
    )
|
|
1004
|
+
|
|
1005
|
+
async def _execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Run a multi-statement SQL script sequentially on one cursor (async).

    Args:
        cursor: Database cursor
        statement: SQL statement containing multiple commands

    Returns:
        ExecutionResult with script execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
    script_parts = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True)

    executed = 0
    for part in script_parts:
        # The same bound parameters (if any) are reused for every statement.
        if prepared_parameters:
            await cursor.execute(part, prepared_parameters)
        else:
            await cursor.execute(part)
        executed += 1

    return self.create_execution_result(
        cursor, statement_count=len(script_parts), successful_statements=executed, is_script_result=True
    )
|
|
1031
|
+
|
|
1032
|
+
async def execute_stack(
    self, stack: "StatementStack", *, continue_on_error: bool = False
) -> "tuple[StackResult, ...]":
    """Execute a StatementStack using psycopg async pipeline when supported.

    Falls back to the base sequential implementation when the stack is not a
    StatementStack or is empty, when native stack execution is disabled, when
    the installed psycopg build lacks pipeline support, or when
    continue-on-error semantics are requested (the pipeline path here is
    fail-fast only).

    Args:
        stack: Stack of statements to execute.
        continue_on_error: When True, always delegates to the sequential path.

    Returns:
        Tuple of StackResult objects, one per stack operation.
    """
    if (
        not isinstance(stack, StatementStack)
        or not stack
        or self.stack_native_disabled
        or not _psycopg_pipeline_supported()
        or continue_on_error
    ):
        return await super().execute_stack(stack, continue_on_error=continue_on_error)

    prepared_ops = self._prepare_pipeline_operations(stack)
    if prepared_ops is None:
        # Operations that cannot be represented as pipeline commands fall back.
        return await super().execute_stack(stack, continue_on_error=continue_on_error)

    return await self._execute_stack_pipeline(stack, prepared_ops)
|
|
1051
|
+
|
|
1052
|
+
async def _execute_stack_pipeline(
    self, stack: "StatementStack", prepared_ops: "list[_PreparedStackOperation]"
) -> "tuple[StackResult, ...]":
    """Execute prepared stack operations inside a psycopg async pipeline.

    Begins a transaction if one is not already active, queues each operation
    on its own cursor inside ``connection.pipeline()``, synchronizes the
    pipeline once, then materializes all results. Fail-fast: the first
    failing operation raises StackExecutionError, and any transaction this
    method started is rolled back.

    Args:
        stack: The stack being executed (used for observability).
        prepared_ops: Pre-compiled operations to enqueue.

    Returns:
        Tuple of StackResult objects in operation order.

    Raises:
        StackExecutionError: If any queued statement fails.
    """
    results: list[StackResult] = []
    started_transaction = False

    with StackExecutionObserver(self, stack, continue_on_error=False, native_pipeline=True):
        try:
            if not self._connection_in_transaction():
                await self.begin()
                started_transaction = True

            async with AsyncExitStack() as resource_stack:
                pipeline = await resource_stack.enter_async_context(self.connection.pipeline())
                pending: list[_PipelineCursorEntry] = []

                for prepared in prepared_ops:
                    # Each operation gets its own exception mapper and cursor;
                    # all stay open until the stack unwinds so results can be
                    # fetched after the pipeline sync.
                    exception_ctx = self.handle_database_exceptions()
                    await resource_stack.enter_async_context(exception_ctx)
                    cursor = await resource_stack.enter_async_context(self.with_cursor(self.connection))

                    try:
                        if prepared.parameters:
                            await cursor.execute(prepared.sql, prepared.parameters)
                        else:
                            await cursor.execute(prepared.sql)
                    except Exception as exc:
                        stack_error = StackExecutionError(
                            prepared.operation_index,
                            describe_stack_statement(prepared.operation.statement),
                            exc,
                            adapter=type(self).__name__,
                            mode="fail-fast",
                        )
                        raise stack_error from exc

                    pending.append(_PipelineCursorEntry(prepared=prepared, cursor=cursor))

                # Flush queued commands and surface any server-side errors.
                await pipeline.sync()

                results.extend([await self._build_pipeline_stack_result_async(entry) for entry in pending])

            if started_transaction:
                await self.commit()
        except Exception:
            if started_transaction:
                try:
                    await self.rollback()
                except Exception as rollback_error:  # pragma: no cover - diagnostics only
                    logger.debug("Rollback after psycopg pipeline failure failed: %s", rollback_error)
            raise

    return tuple(results)
|
|
1105
|
+
|
|
1106
|
+
async def _build_pipeline_stack_result_async(self, entry: "_PipelineCursorEntry") -> StackResult:
    """Materialize one executed pipeline cursor into a StackResult (async).

    Args:
        entry: Prepared operation paired with its executed cursor.

    Returns:
        StackResult wrapping the statement's SQLResult.
    """
    statement = entry.prepared.statement
    cursor = entry.cursor

    if statement.returns_rows():
        fetched_data = await cursor.fetchall()
        column_names = [col.name for col in cursor.description or []]
        execution_result = self.create_execution_result(
            cursor,
            selected_data=fetched_data,
            column_names=column_names,
            data_row_count=len(fetched_data),
            is_select_result=True,
        )
    else:
        # cursor.rowcount can be -1 when unknown; clamp to zero.
        affected_rows = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0
        execution_result = self.create_execution_result(cursor, rowcount_override=affected_rows)

    sql_result = self.build_statement_result(statement, execution_result)
    return StackResult.from_sql_result(sql_result)
|
|
1126
|
+
|
|
1127
|
+
async def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute SQL with multiple parameter sets (async).

    Args:
        cursor: Database cursor
        statement: SQL statement with parameter list

    Returns:
        ExecutionResult with batch execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

    # An empty batch is a no-op; report zero affected rows.
    if not prepared_parameters:
        return self.create_execution_result(cursor, rowcount_override=0, is_many_result=True)

    await cursor.executemany(sql, prepared_parameters)

    reported = cursor.rowcount
    affected = reported if reported and reported > 0 else 0

    return self.create_execution_result(cursor, rowcount_override=affected, is_many_result=True)
|
|
1147
|
+
|
|
1148
|
+
async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
    """Execute single SQL statement (async).

    Args:
        cursor: Database cursor
        statement: SQL statement to execute

    Returns:
        ExecutionResult with statement execution details
    """
    sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

    if prepared_parameters:
        await cursor.execute(sql, prepared_parameters)
    else:
        await cursor.execute(sql)

    # Non-row-returning statements only report an affected-row count.
    if not statement.returns_rows():
        reported = cursor.rowcount
        return self.create_execution_result(
            cursor, rowcount_override=reported if reported and reported > 0 else 0
        )

    rows = await cursor.fetchall()
    columns = [col.name for col in cursor.description or []]
    return self.create_execution_result(
        cursor, selected_data=rows, column_names=columns, data_row_count=len(rows), is_select_result=True
    )
|
|
1179
|
+
|
|
1180
|
+
async def select_to_storage(
    self,
    statement: "SQL | str",
    destination: "StorageDestination",
    /,
    *parameters: Any,
    statement_config: "StatementConfig | None" = None,
    partitioner: "dict[str, Any] | None" = None,
    format_hint: "StorageFormat | None" = None,
    telemetry: "StorageTelemetry | None" = None,
    **kwargs: Any,
) -> "StorageBridgeJob":
    """Execute a query and stream Arrow data to storage asynchronously.

    Args:
        statement: Query to execute.
        destination: Storage location the Arrow payload is written to.
        *parameters: Positional query parameters.
        statement_config: Optional statement configuration override.
        partitioner: Optional partition metadata attached to telemetry.
        format_hint: Optional output format override for the storage writer.
        telemetry: Optional caller-supplied telemetry merged into the job.
        **kwargs: Extra arguments forwarded to ``select_to_arrow``.

    Returns:
        StorageBridgeJob describing the completed write.
    """
    self._require_capability("arrow_export_enabled")
    arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs)
    async_pipeline: AsyncStoragePipeline = cast("AsyncStoragePipeline", self._storage_pipeline())
    telemetry_payload = await self._write_result_to_storage_async(
        arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline
    )
    self._attach_partition_telemetry(telemetry_payload, partitioner)
    return self._create_storage_job(telemetry_payload, telemetry)
|
|
1202
|
+
|
|
1203
|
+
async def load_from_arrow(
|
|
1204
|
+
self,
|
|
1205
|
+
table: str,
|
|
1206
|
+
source: "ArrowResult | Any",
|
|
1207
|
+
*,
|
|
1208
|
+
partitioner: "dict[str, Any] | None" = None,
|
|
1209
|
+
overwrite: bool = False,
|
|
1210
|
+
telemetry: "StorageTelemetry | None" = None,
|
|
1211
|
+
) -> "StorageBridgeJob":
|
|
1212
|
+
"""Load Arrow data into PostgreSQL asynchronously via COPY."""
|
|
1213
|
+
|
|
1214
|
+
self._require_capability("arrow_import_enabled")
|
|
1215
|
+
arrow_table = self._coerce_arrow_table(source)
|
|
1216
|
+
if overwrite:
|
|
1217
|
+
await self._truncate_table_async(table)
|
|
1218
|
+
columns, records = self._arrow_table_to_rows(arrow_table)
|
|
1219
|
+
if records:
|
|
1220
|
+
copy_sql = _build_copy_from_command(table, columns)
|
|
1221
|
+
async with AsyncExitStack() as stack:
|
|
1222
|
+
await stack.enter_async_context(self.handle_database_exceptions())
|
|
1223
|
+
cursor = await stack.enter_async_context(self.with_cursor(self.connection))
|
|
1224
|
+
copy_ctx = await stack.enter_async_context(cursor.copy(copy_sql))
|
|
1225
|
+
for record in records:
|
|
1226
|
+
await copy_ctx.write_row(record)
|
|
1227
|
+
telemetry_payload = self._build_ingest_telemetry(arrow_table)
|
|
1228
|
+
telemetry_payload["destination"] = table
|
|
1229
|
+
self._attach_partition_telemetry(telemetry_payload, partitioner)
|
|
1230
|
+
return self._create_storage_job(telemetry_payload, telemetry)
|
|
1231
|
+
|
|
1232
|
+
async def load_from_storage(
|
|
1233
|
+
self,
|
|
1234
|
+
table: str,
|
|
1235
|
+
source: "StorageDestination",
|
|
1236
|
+
*,
|
|
1237
|
+
file_format: "StorageFormat",
|
|
1238
|
+
partitioner: "dict[str, Any] | None" = None,
|
|
1239
|
+
overwrite: bool = False,
|
|
1240
|
+
) -> "StorageBridgeJob":
|
|
1241
|
+
"""Load staged artifacts asynchronously."""
|
|
1242
|
+
|
|
1243
|
+
arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format)
|
|
1244
|
+
return await self.load_from_arrow(
|
|
1245
|
+
table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound
|
|
1246
|
+
)
|
|
1247
|
+
|
|
1248
|
+
@property
|
|
1249
|
+
def data_dictionary(self) -> "AsyncDataDictionaryBase":
|
|
1250
|
+
"""Get the data dictionary for this driver.
|
|
1251
|
+
|
|
1252
|
+
Returns:
|
|
1253
|
+
Data dictionary instance for metadata queries
|
|
1254
|
+
"""
|
|
1255
|
+
if self._data_dictionary is None:
|
|
1256
|
+
from sqlspec.adapters.psycopg.data_dictionary import PostgresAsyncDataDictionary
|
|
1257
|
+
|
|
1258
|
+
self._data_dictionary = PostgresAsyncDataDictionary()
|
|
1259
|
+
return self._data_dictionary
|
|
1260
|
+
|
|
1261
|
+
async def _truncate_table_async(self, table: str) -> None:
|
|
1262
|
+
truncate_sql = _build_truncate_command(table)
|
|
1263
|
+
async with self.with_cursor(self.connection) as cursor, self.handle_database_exceptions():
|
|
1264
|
+
await cursor.execute(truncate_sql)
|
|
1265
|
+
|
|
1266
|
+
|
|
1267
|
+
def _identity(value: Any) -> Any:
|
|
1268
|
+
return value
|
|
1269
|
+
|
|
1270
|
+
|
|
1271
|
+
def _build_psycopg_custom_type_coercions() -> dict[type, "Callable[[Any], Any]"]:
    """Return custom type coercions for psycopg.

    psycopg handles temporal types natively, so each is mapped to a
    pass-through coercion rather than a converting one.
    """
    temporal_types = (datetime.datetime, datetime.date, datetime.time)
    return dict.fromkeys(temporal_types, _identity)
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
def _build_psycopg_profile() -> DriverParameterProfile:
    """Create the psycopg driver parameter profile."""
    # Styles accepted in incoming SQL text vs. styles psycopg can execute.
    accepted_styles = {
        ParameterStyle.POSITIONAL_PYFORMAT,
        ParameterStyle.NAMED_PYFORMAT,
        ParameterStyle.NUMERIC,
        ParameterStyle.QMARK,
    }
    executable_styles = {ParameterStyle.POSITIONAL_PYFORMAT, ParameterStyle.NAMED_PYFORMAT}
    return DriverParameterProfile(
        name="Psycopg",
        default_style=ParameterStyle.POSITIONAL_PYFORMAT,
        supported_styles=accepted_styles,
        default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT,
        supported_execution_styles=executable_styles,
        has_native_list_expansion=True,
        preserve_parameter_format=True,
        needs_static_script_compilation=False,
        allow_mixed_parameter_styles=False,
        preserve_original_params_for_many=False,
        json_serializer_strategy="helper",
        custom_type_coercions=_build_psycopg_custom_type_coercions(),
        default_dialect="postgres",
    )
|
|
1300
|
+
|
|
1301
|
+
|
|
1302
|
+
# Shared module-level profile instance; registered once at import time so the
# driver-profile registry can resolve "psycopg" to this parameter profile.
_PSYCOPG_PROFILE = _build_psycopg_profile()

register_driver_profile("psycopg", _PSYCOPG_PROFILE)
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
def _create_psycopg_parameter_config(serializer: "Callable[[Any], str]") -> ParameterStyleConfig:
    """Construct parameter configuration with shared JSON serializer support."""
    profile_config = build_statement_config_from_profile(_PSYCOPG_PROFILE, json_serializer=serializer).parameter_config
    # Extend the profile's coercion map so Python lists/tuples are serialized
    # as JSON using the supplied serializer.
    coercions = {
        **profile_config.type_coercion_map,
        list: build_json_list_converter(serializer),
        tuple: build_json_tuple_converter(serializer),
    }
    return profile_config.replace(type_coercion_map=coercions)
|
|
1317
|
+
|
|
1318
|
+
|
|
1319
|
+
def build_psycopg_statement_config(*, json_serializer: "Callable[[Any], str]" = to_json) -> StatementConfig:
    """Construct the psycopg statement configuration with optional JSON codecs.

    Args:
        json_serializer: Callable that encodes Python values to JSON text.

    Returns:
        A fully-populated ``StatementConfig`` for the psycopg driver.
    """
    param_config = _create_psycopg_parameter_config(json_serializer)
    return StatementConfig(
        dialect="postgres",
        pre_process_steps=None,
        post_process_steps=None,
        enable_parsing=True,
        enable_transformations=True,
        enable_validation=True,
        enable_caching=True,
        enable_parameter_type_wrapping=True,
        parameter_config=param_config,
    )
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
# Module-level default statement config, built with the standard JSON serializer.
psycopg_statement_config = build_psycopg_statement_config()
|