sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +100 -130
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +125 -167
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +114 -111
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +18 -31
- sqlspec/adapters/psycopg/driver.py +283 -236
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +103 -97
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +67 -181
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +25 -90
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +67 -22
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +91 -67
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +863 -391
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +53 -8
- sqlspec/storage/backends/obstore.py +15 -19
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -173
- sqlspec-0.12.2.dist-info/RECORD +0 -145
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
|
8
7
|
|
|
9
8
|
from psqlpy import ConnectionPool
|
|
@@ -302,7 +301,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
302
301
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
303
302
|
}
|
|
304
303
|
|
|
305
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
306
304
|
config.update(self.extras)
|
|
307
305
|
|
|
308
306
|
return config
|
|
@@ -359,11 +357,9 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
359
357
|
Returns:
|
|
360
358
|
A psqlpy Connection instance.
|
|
361
359
|
"""
|
|
362
|
-
# Ensure pool exists
|
|
363
360
|
if not self.pool_instance:
|
|
364
361
|
self.pool_instance = await self._create_pool()
|
|
365
362
|
|
|
366
|
-
# Get connection from pool
|
|
367
363
|
return await self.pool_instance.connection()
|
|
368
364
|
|
|
369
365
|
@asynccontextmanager
|
|
@@ -377,7 +373,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
377
373
|
Yields:
|
|
378
374
|
A psqlpy Connection instance.
|
|
379
375
|
"""
|
|
380
|
-
# Ensure pool exists
|
|
381
376
|
if not self.pool_instance:
|
|
382
377
|
self.pool_instance = await self._create_pool()
|
|
383
378
|
|
|
@@ -396,15 +391,16 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
396
391
|
A PsqlpyDriver instance.
|
|
397
392
|
"""
|
|
398
393
|
async with self.provide_connection(*args, **kwargs) as conn:
|
|
399
|
-
# Create statement config with parameter style info if not already set
|
|
400
394
|
statement_config = self.statement_config
|
|
395
|
+
# Inject parameter style info if not already set
|
|
401
396
|
if statement_config.allowed_parameter_styles is None:
|
|
397
|
+
from dataclasses import replace
|
|
398
|
+
|
|
402
399
|
statement_config = replace(
|
|
403
400
|
statement_config,
|
|
404
401
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
405
402
|
target_parameter_style=self.preferred_parameter_style,
|
|
406
403
|
)
|
|
407
|
-
|
|
408
404
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
409
405
|
yield driver
|
|
410
406
|
|
|
@@ -2,11 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
import logging
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Optional,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
6
6
|
|
|
7
7
|
from psqlpy import Connection
|
|
8
8
|
|
|
9
9
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
10
11
|
from sqlspec.driver.mixins import (
|
|
11
12
|
AsyncPipelinedExecutionMixin,
|
|
12
13
|
AsyncStorageMixin,
|
|
@@ -14,10 +15,10 @@ from sqlspec.driver.mixins import (
|
|
|
14
15
|
ToSchemaMixin,
|
|
15
16
|
TypeCoercionMixin,
|
|
16
17
|
)
|
|
17
|
-
from sqlspec.statement.parameters import ParameterStyle
|
|
18
|
-
from sqlspec.statement.result import
|
|
18
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
19
|
+
from sqlspec.statement.result import SQLResult
|
|
19
20
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
-
from sqlspec.typing import DictRow,
|
|
21
|
+
from sqlspec.typing import DictRow, RowT
|
|
21
22
|
|
|
22
23
|
if TYPE_CHECKING:
|
|
23
24
|
from sqlglot.dialects.dialect import DialectType
|
|
@@ -76,13 +77,36 @@ class PsqlpyDriver(
|
|
|
76
77
|
|
|
77
78
|
async def _execute_statement(
|
|
78
79
|
self, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
79
|
-
) ->
|
|
80
|
+
) -> SQLResult[RowT]:
|
|
80
81
|
if statement.is_script:
|
|
81
82
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
82
83
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
83
84
|
|
|
84
|
-
#
|
|
85
|
-
|
|
85
|
+
# Detect parameter styles in the SQL
|
|
86
|
+
detected_styles = set()
|
|
87
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
88
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
89
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
90
|
+
if param_infos:
|
|
91
|
+
detected_styles = {p.style for p in param_infos}
|
|
92
|
+
|
|
93
|
+
# Determine target style based on what's in the SQL
|
|
94
|
+
target_style = self.default_parameter_style
|
|
95
|
+
|
|
96
|
+
# Check if there are unsupported styles
|
|
97
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
98
|
+
if unsupported_styles:
|
|
99
|
+
# Force conversion to default style
|
|
100
|
+
target_style = self.default_parameter_style
|
|
101
|
+
elif detected_styles:
|
|
102
|
+
# Prefer the first supported style found
|
|
103
|
+
for style in detected_styles:
|
|
104
|
+
if style in self.supported_parameter_styles:
|
|
105
|
+
target_style = style
|
|
106
|
+
break
|
|
107
|
+
|
|
108
|
+
# Compile with the determined style
|
|
109
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
86
110
|
params = self._process_parameters(params)
|
|
87
111
|
|
|
88
112
|
if statement.is_many:
|
|
@@ -92,43 +116,99 @@ class PsqlpyDriver(
|
|
|
92
116
|
|
|
93
117
|
async def _execute(
|
|
94
118
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
95
|
-
) ->
|
|
96
|
-
|
|
97
|
-
if self.
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
119
|
+
) -> SQLResult[RowT]:
|
|
120
|
+
# Use provided connection or driver's default connection
|
|
121
|
+
conn = connection if connection is not None else self._connection(None)
|
|
122
|
+
|
|
123
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
124
|
+
# PSQLPy expects parameters as a list (for $1, $2, etc.) or dict
|
|
125
|
+
# Ensure we always pass a sequence or mapping, never a scalar
|
|
126
|
+
final_params: Any
|
|
127
|
+
if isinstance(parameters, (list, tuple)):
|
|
128
|
+
final_params = list(parameters)
|
|
129
|
+
elif isinstance(parameters, dict):
|
|
130
|
+
final_params = parameters
|
|
131
|
+
elif parameters is None:
|
|
132
|
+
final_params = []
|
|
133
|
+
else:
|
|
134
|
+
# Single parameter - wrap in list for NUMERIC style ($1)
|
|
135
|
+
final_params = [parameters]
|
|
136
|
+
|
|
137
|
+
if self.returns_rows(statement.expression):
|
|
138
|
+
query_result = await txn_conn.fetch(sql, parameters=final_params)
|
|
139
|
+
dict_rows: list[dict[str, Any]] = []
|
|
140
|
+
if query_result:
|
|
141
|
+
# psqlpy QueryResult has a result() method that returns list of dicts
|
|
142
|
+
dict_rows = query_result.result()
|
|
143
|
+
column_names = list(dict_rows[0].keys()) if dict_rows else []
|
|
144
|
+
return SQLResult(
|
|
145
|
+
statement=statement,
|
|
146
|
+
data=cast("list[RowT]", dict_rows),
|
|
147
|
+
column_names=column_names,
|
|
148
|
+
rows_affected=len(dict_rows),
|
|
149
|
+
operation_type="SELECT",
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
query_result = await txn_conn.execute(sql, parameters=final_params)
|
|
153
|
+
# Note: psqlpy doesn't provide rows_affected for DML operations
|
|
154
|
+
# The QueryResult object only has result(), as_class(), and row_factory() methods
|
|
155
|
+
affected_count = -1 # Unknown, as psqlpy doesn't provide this info
|
|
156
|
+
return SQLResult(
|
|
157
|
+
statement=statement,
|
|
158
|
+
data=[],
|
|
159
|
+
rows_affected=affected_count,
|
|
160
|
+
operation_type=self._determine_operation_type(statement),
|
|
161
|
+
metadata={"status_message": "OK"},
|
|
162
|
+
)
|
|
112
163
|
|
|
113
164
|
async def _execute_many(
|
|
114
165
|
self, sql: str, param_list: Any, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
115
|
-
) ->
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
166
|
+
) -> SQLResult[RowT]:
|
|
167
|
+
# Use provided connection or driver's default connection
|
|
168
|
+
conn = connection if connection is not None else self._connection(None)
|
|
169
|
+
|
|
170
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
171
|
+
# PSQLPy expects a list of parameter lists/tuples for execute_many
|
|
172
|
+
if param_list is None:
|
|
173
|
+
final_param_list = []
|
|
174
|
+
elif isinstance(param_list, (list, tuple)):
|
|
175
|
+
# Ensure each parameter set is a list/tuple
|
|
176
|
+
final_param_list = [
|
|
177
|
+
list(params) if isinstance(params, (list, tuple)) else [params] for params in param_list
|
|
178
|
+
]
|
|
179
|
+
else:
|
|
180
|
+
# Single parameter set - wrap it
|
|
181
|
+
final_param_list = [list(param_list) if isinstance(param_list, (list, tuple)) else [param_list]]
|
|
182
|
+
|
|
183
|
+
await txn_conn.execute_many(sql, final_param_list)
|
|
184
|
+
# execute_many doesn't return a value with rows_affected
|
|
185
|
+
affected_count = -1
|
|
186
|
+
return SQLResult(
|
|
187
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
188
|
+
data=[],
|
|
189
|
+
rows_affected=affected_count,
|
|
190
|
+
operation_type="EXECUTE",
|
|
191
|
+
metadata={"status_message": "OK"},
|
|
192
|
+
)
|
|
121
193
|
|
|
122
194
|
async def _execute_script(
|
|
123
195
|
self, script: str, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
124
|
-
) ->
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
196
|
+
) -> SQLResult[RowT]:
|
|
197
|
+
# Use provided connection or driver's default connection
|
|
198
|
+
conn = connection if connection is not None else self._connection(None)
|
|
199
|
+
|
|
200
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
201
|
+
# psqlpy can execute multi-statement scripts directly
|
|
202
|
+
await txn_conn.execute(script)
|
|
203
|
+
return SQLResult(
|
|
204
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
205
|
+
data=[],
|
|
206
|
+
rows_affected=0,
|
|
207
|
+
operation_type="SCRIPT",
|
|
208
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
209
|
+
total_statements=-1, # Not directly supported, but script is executed
|
|
210
|
+
successful_statements=-1,
|
|
211
|
+
)
|
|
132
212
|
|
|
133
213
|
async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|
|
134
214
|
self._ensure_pyarrow_installed()
|
|
@@ -154,61 +234,6 @@ class PsqlpyDriver(
|
|
|
154
234
|
msg = "Connection does not support COPY operations"
|
|
155
235
|
raise NotImplementedError(msg)
|
|
156
236
|
|
|
157
|
-
async def _wrap_select_result(
|
|
158
|
-
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
159
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
160
|
-
dict_rows = result["data"]
|
|
161
|
-
column_names = result["column_names"]
|
|
162
|
-
rows_affected = result["rows_affected"]
|
|
163
|
-
|
|
164
|
-
if schema_type:
|
|
165
|
-
converted_data = self.to_schema(data=dict_rows, schema_type=schema_type)
|
|
166
|
-
return SQLResult[ModelDTOT](
|
|
167
|
-
statement=statement,
|
|
168
|
-
data=list(converted_data),
|
|
169
|
-
column_names=column_names,
|
|
170
|
-
rows_affected=rows_affected,
|
|
171
|
-
operation_type="SELECT",
|
|
172
|
-
)
|
|
173
|
-
return SQLResult[RowT](
|
|
174
|
-
statement=statement,
|
|
175
|
-
data=cast("list[RowT]", dict_rows),
|
|
176
|
-
column_names=column_names,
|
|
177
|
-
rows_affected=rows_affected,
|
|
178
|
-
operation_type="SELECT",
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
async def _wrap_execute_result(
|
|
182
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
183
|
-
) -> SQLResult[RowT]:
|
|
184
|
-
operation_type = "UNKNOWN"
|
|
185
|
-
if statement.expression:
|
|
186
|
-
operation_type = str(statement.expression.key).upper()
|
|
187
|
-
|
|
188
|
-
if "statements_executed" in result:
|
|
189
|
-
script_result = cast("ScriptResultDict", result)
|
|
190
|
-
return SQLResult[RowT](
|
|
191
|
-
statement=statement,
|
|
192
|
-
data=[],
|
|
193
|
-
rows_affected=0,
|
|
194
|
-
operation_type="SCRIPT",
|
|
195
|
-
metadata={
|
|
196
|
-
"status_message": script_result.get("status_message", ""),
|
|
197
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
198
|
-
},
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
dml_result = cast("DMLResultDict", result)
|
|
202
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
203
|
-
status_message = dml_result.get("status_message", "")
|
|
204
|
-
return SQLResult[RowT](
|
|
205
|
-
statement=statement,
|
|
206
|
-
data=[],
|
|
207
|
-
rows_affected=rows_affected,
|
|
208
|
-
operation_type=operation_type,
|
|
209
|
-
metadata={"status_message": status_message},
|
|
210
|
-
)
|
|
211
|
-
|
|
212
237
|
def _connection(self, connection: Optional[PsqlpyConnection] = None) -> PsqlpyConnection:
|
|
213
238
|
"""Get the connection to use for the operation."""
|
|
214
239
|
return connection or self.connection
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import contextlib
|
|
4
4
|
import logging
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
|
|
8
7
|
|
|
9
8
|
from psycopg.rows import dict_row
|
|
@@ -211,7 +210,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
211
210
|
self.configure = configure
|
|
212
211
|
self.kwargs = kwargs or {}
|
|
213
212
|
|
|
214
|
-
# Handle extras and additional kwargs
|
|
215
213
|
self.extras = extras or {}
|
|
216
214
|
self.extras.update(additional_kwargs)
|
|
217
215
|
|
|
@@ -240,7 +238,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
240
238
|
if self.kwargs:
|
|
241
239
|
config.update(self.kwargs)
|
|
242
240
|
|
|
243
|
-
# Set DictRow as the row factory
|
|
244
241
|
config["row_factory"] = dict_row
|
|
245
242
|
|
|
246
243
|
return config
|
|
@@ -263,7 +260,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
263
260
|
if self.kwargs:
|
|
264
261
|
config.update(self.kwargs)
|
|
265
262
|
|
|
266
|
-
# Set DictRow as the row factory
|
|
267
263
|
config["row_factory"] = dict_row
|
|
268
264
|
|
|
269
265
|
return config
|
|
@@ -273,7 +269,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
273
269
|
logger.info("Creating Psycopg connection pool", extra={"adapter": "psycopg"})
|
|
274
270
|
|
|
275
271
|
try:
|
|
276
|
-
# Get all config (creates a new dict)
|
|
277
272
|
all_config = self.pool_config_dict.copy()
|
|
278
273
|
|
|
279
274
|
# Separate pool-specific parameters that ConnectionPool accepts directly
|
|
@@ -289,28 +284,27 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
289
284
|
"num_workers": all_config.pop("num_workers", 3),
|
|
290
285
|
}
|
|
291
286
|
|
|
292
|
-
#
|
|
287
|
+
# Capture autocommit setting before configuring the pool
|
|
288
|
+
autocommit_setting = all_config.get("autocommit")
|
|
289
|
+
|
|
293
290
|
def configure_connection(conn: "PsycopgSyncConnection") -> None:
|
|
294
|
-
# Set DictRow as the row factory
|
|
295
291
|
conn.row_factory = dict_row
|
|
292
|
+
# Apply autocommit setting if specified
|
|
293
|
+
if autocommit_setting is not None:
|
|
294
|
+
conn.autocommit = autocommit_setting
|
|
296
295
|
|
|
297
296
|
pool_params["configure"] = all_config.pop("configure", configure_connection)
|
|
298
297
|
|
|
299
|
-
# Remove None values from pool_params
|
|
300
298
|
pool_params = {k: v for k, v in pool_params.items() if v is not None}
|
|
301
299
|
|
|
302
|
-
# Handle conninfo vs individual connection parameters
|
|
303
300
|
conninfo = all_config.pop("conninfo", None)
|
|
304
301
|
if conninfo:
|
|
305
302
|
# If conninfo is provided, use it directly
|
|
306
303
|
# Don't pass kwargs when using conninfo string
|
|
307
304
|
pool = ConnectionPool(conninfo, open=True, **pool_params)
|
|
308
305
|
else:
|
|
309
|
-
# Otherwise, pass connection parameters via kwargs
|
|
310
|
-
# Remove any non-connection parameters
|
|
311
306
|
# row_factory is already popped out earlier
|
|
312
307
|
all_config.pop("row_factory", None)
|
|
313
|
-
# Remove pool-specific settings that may have been left
|
|
314
308
|
all_config.pop("kwargs", None)
|
|
315
309
|
pool = ConnectionPool("", kwargs=all_config, open=True, **pool_params)
|
|
316
310
|
|
|
@@ -328,7 +322,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
328
322
|
logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
|
|
329
323
|
|
|
330
324
|
try:
|
|
331
|
-
# Set a flag to prevent __del__ from running cleanup
|
|
332
325
|
# This avoids the "cannot join current thread" error during garbage collection
|
|
333
326
|
if hasattr(self.pool_instance, "_closed"):
|
|
334
327
|
self.pool_instance._closed = True
|
|
@@ -339,7 +332,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
339
332
|
logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)})
|
|
340
333
|
raise
|
|
341
334
|
finally:
|
|
342
|
-
# Clear the reference to help garbage collection
|
|
343
335
|
self.pool_instance = None
|
|
344
336
|
|
|
345
337
|
def create_connection(self) -> "PsycopgSyncConnection":
|
|
@@ -385,15 +377,16 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
385
377
|
A PsycopgSyncDriver instance.
|
|
386
378
|
"""
|
|
387
379
|
with self.provide_connection(*args, **kwargs) as conn:
|
|
388
|
-
# Create statement config with parameter style info if not already set
|
|
389
380
|
statement_config = self.statement_config
|
|
381
|
+
# Inject parameter style info if not already set
|
|
390
382
|
if statement_config.allowed_parameter_styles is None:
|
|
383
|
+
from dataclasses import replace
|
|
384
|
+
|
|
391
385
|
statement_config = replace(
|
|
392
386
|
statement_config,
|
|
393
387
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
394
388
|
target_parameter_style=self.preferred_parameter_style,
|
|
395
389
|
)
|
|
396
|
-
|
|
397
390
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
398
391
|
yield driver
|
|
399
392
|
|
|
@@ -555,7 +548,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
555
548
|
self.configure = configure
|
|
556
549
|
self.kwargs = kwargs or {}
|
|
557
550
|
|
|
558
|
-
# Handle extras and additional kwargs
|
|
559
551
|
self.extras = extras or {}
|
|
560
552
|
self.extras.update(additional_kwargs)
|
|
561
553
|
|
|
@@ -584,7 +576,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
584
576
|
if self.kwargs:
|
|
585
577
|
config.update(self.kwargs)
|
|
586
578
|
|
|
587
|
-
# Set DictRow as the row factory
|
|
588
579
|
config["row_factory"] = dict_row
|
|
589
580
|
|
|
590
581
|
return config
|
|
@@ -607,7 +598,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
607
598
|
if self.kwargs:
|
|
608
599
|
config.update(self.kwargs)
|
|
609
600
|
|
|
610
|
-
# Set DictRow as the row factory
|
|
611
601
|
config["row_factory"] = dict_row
|
|
612
602
|
|
|
613
603
|
return config
|
|
@@ -615,7 +605,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
615
605
|
async def _create_pool(self) -> "AsyncConnectionPool":
|
|
616
606
|
"""Create the actual async connection pool."""
|
|
617
607
|
|
|
618
|
-
# Get all config (creates a new dict)
|
|
619
608
|
all_config = self.pool_config_dict.copy()
|
|
620
609
|
|
|
621
610
|
# Separate pool-specific parameters that AsyncConnectionPool accepts directly
|
|
@@ -631,28 +620,27 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
631
620
|
"num_workers": all_config.pop("num_workers", 3),
|
|
632
621
|
}
|
|
633
622
|
|
|
634
|
-
#
|
|
623
|
+
# Capture autocommit setting before configuring the pool
|
|
624
|
+
autocommit_setting = all_config.get("autocommit")
|
|
625
|
+
|
|
635
626
|
async def configure_connection(conn: "PsycopgAsyncConnection") -> None:
|
|
636
|
-
# Set DictRow as the row factory
|
|
637
627
|
conn.row_factory = dict_row
|
|
628
|
+
# Apply autocommit setting if specified (async version requires await)
|
|
629
|
+
if autocommit_setting is not None:
|
|
630
|
+
await conn.set_autocommit(autocommit_setting)
|
|
638
631
|
|
|
639
632
|
pool_params["configure"] = all_config.pop("configure", configure_connection)
|
|
640
633
|
|
|
641
|
-
# Remove None values from pool_params
|
|
642
634
|
pool_params = {k: v for k, v in pool_params.items() if v is not None}
|
|
643
635
|
|
|
644
|
-
# Handle conninfo vs individual connection parameters
|
|
645
636
|
conninfo = all_config.pop("conninfo", None)
|
|
646
637
|
if conninfo:
|
|
647
638
|
# If conninfo is provided, use it directly
|
|
648
639
|
# Don't pass kwargs when using conninfo string
|
|
649
640
|
pool = AsyncConnectionPool(conninfo, open=False, **pool_params)
|
|
650
641
|
else:
|
|
651
|
-
# Otherwise, pass connection parameters via kwargs
|
|
652
|
-
# Remove any non-connection parameters
|
|
653
642
|
# row_factory is already popped out earlier
|
|
654
643
|
all_config.pop("row_factory", None)
|
|
655
|
-
# Remove pool-specific settings that may have been left
|
|
656
644
|
all_config.pop("kwargs", None)
|
|
657
645
|
pool = AsyncConnectionPool("", kwargs=all_config, open=False, **pool_params)
|
|
658
646
|
|
|
@@ -666,14 +654,12 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
666
654
|
return
|
|
667
655
|
|
|
668
656
|
try:
|
|
669
|
-
# Set a flag to prevent __del__ from running cleanup
|
|
670
657
|
# This avoids the "cannot join current thread" error during garbage collection
|
|
671
658
|
if hasattr(self.pool_instance, "_closed"):
|
|
672
659
|
self.pool_instance._closed = True
|
|
673
660
|
|
|
674
661
|
await self.pool_instance.close()
|
|
675
662
|
finally:
|
|
676
|
-
# Clear the reference to help garbage collection
|
|
677
663
|
self.pool_instance = None
|
|
678
664
|
|
|
679
665
|
async def create_connection(self) -> "PsycopgAsyncConnection": # pyright: ignore
|
|
@@ -719,15 +705,16 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
719
705
|
A PsycopgAsyncDriver instance.
|
|
720
706
|
"""
|
|
721
707
|
async with self.provide_connection(*args, **kwargs) as conn:
|
|
722
|
-
# Create statement config with parameter style info if not already set
|
|
723
708
|
statement_config = self.statement_config
|
|
709
|
+
# Inject parameter style info if not already set
|
|
724
710
|
if statement_config.allowed_parameter_styles is None:
|
|
711
|
+
from dataclasses import replace
|
|
712
|
+
|
|
725
713
|
statement_config = replace(
|
|
726
714
|
statement_config,
|
|
727
715
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
728
716
|
target_parameter_style=self.preferred_parameter_style,
|
|
729
717
|
)
|
|
730
|
-
|
|
731
718
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
732
719
|
yield driver
|
|
733
720
|
|