sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import logging
|
|
4
4
|
from collections.abc import AsyncGenerator
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
|
8
7
|
|
|
9
8
|
from psqlpy import ConnectionPool
|
|
@@ -302,7 +301,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
302
301
|
if getattr(self, field, None) is not None and getattr(self, field) is not Empty
|
|
303
302
|
}
|
|
304
303
|
|
|
305
|
-
# Add connection-specific extras (not pool-specific ones)
|
|
306
304
|
config.update(self.extras)
|
|
307
305
|
|
|
308
306
|
return config
|
|
@@ -359,11 +357,9 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
359
357
|
Returns:
|
|
360
358
|
A psqlpy Connection instance.
|
|
361
359
|
"""
|
|
362
|
-
# Ensure pool exists
|
|
363
360
|
if not self.pool_instance:
|
|
364
361
|
self.pool_instance = await self._create_pool()
|
|
365
362
|
|
|
366
|
-
# Get connection from pool
|
|
367
363
|
return await self.pool_instance.connection()
|
|
368
364
|
|
|
369
365
|
@asynccontextmanager
|
|
@@ -377,7 +373,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
377
373
|
Yields:
|
|
378
374
|
A psqlpy Connection instance.
|
|
379
375
|
"""
|
|
380
|
-
# Ensure pool exists
|
|
381
376
|
if not self.pool_instance:
|
|
382
377
|
self.pool_instance = await self._create_pool()
|
|
383
378
|
|
|
@@ -396,15 +391,16 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
|
|
|
396
391
|
A PsqlpyDriver instance.
|
|
397
392
|
"""
|
|
398
393
|
async with self.provide_connection(*args, **kwargs) as conn:
|
|
399
|
-
# Create statement config with parameter style info if not already set
|
|
400
394
|
statement_config = self.statement_config
|
|
395
|
+
# Inject parameter style info if not already set
|
|
401
396
|
if statement_config.allowed_parameter_styles is None:
|
|
397
|
+
from dataclasses import replace
|
|
398
|
+
|
|
402
399
|
statement_config = replace(
|
|
403
400
|
statement_config,
|
|
404
401
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
405
402
|
target_parameter_style=self.preferred_parameter_style,
|
|
406
403
|
)
|
|
407
|
-
|
|
408
404
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
409
405
|
yield driver
|
|
410
406
|
|
|
@@ -2,11 +2,12 @@
|
|
|
2
2
|
|
|
3
3
|
import io
|
|
4
4
|
import logging
|
|
5
|
-
from typing import TYPE_CHECKING, Any, Optional,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
6
6
|
|
|
7
7
|
from psqlpy import Connection
|
|
8
8
|
|
|
9
9
|
from sqlspec.driver import AsyncDriverAdapterProtocol
|
|
10
|
+
from sqlspec.driver.connection import managed_transaction_async
|
|
10
11
|
from sqlspec.driver.mixins import (
|
|
11
12
|
AsyncPipelinedExecutionMixin,
|
|
12
13
|
AsyncStorageMixin,
|
|
@@ -14,10 +15,10 @@ from sqlspec.driver.mixins import (
|
|
|
14
15
|
ToSchemaMixin,
|
|
15
16
|
TypeCoercionMixin,
|
|
16
17
|
)
|
|
17
|
-
from sqlspec.statement.parameters import ParameterStyle
|
|
18
|
-
from sqlspec.statement.result import
|
|
18
|
+
from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
|
|
19
|
+
from sqlspec.statement.result import SQLResult
|
|
19
20
|
from sqlspec.statement.sql import SQL, SQLConfig
|
|
20
|
-
from sqlspec.typing import DictRow,
|
|
21
|
+
from sqlspec.typing import DictRow, RowT
|
|
21
22
|
|
|
22
23
|
if TYPE_CHECKING:
|
|
23
24
|
from sqlglot.dialects.dialect import DialectType
|
|
@@ -76,13 +77,36 @@ class PsqlpyDriver(
|
|
|
76
77
|
|
|
77
78
|
async def _execute_statement(
|
|
78
79
|
self, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
79
|
-
) ->
|
|
80
|
+
) -> SQLResult[RowT]:
|
|
80
81
|
if statement.is_script:
|
|
81
82
|
sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
|
|
82
83
|
return await self._execute_script(sql, connection=connection, **kwargs)
|
|
83
84
|
|
|
84
|
-
#
|
|
85
|
-
|
|
85
|
+
# Detect parameter styles in the SQL
|
|
86
|
+
detected_styles = set()
|
|
87
|
+
sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
|
|
88
|
+
validator = self.config.parameter_validator if self.config else ParameterValidator()
|
|
89
|
+
param_infos = validator.extract_parameters(sql_str)
|
|
90
|
+
if param_infos:
|
|
91
|
+
detected_styles = {p.style for p in param_infos}
|
|
92
|
+
|
|
93
|
+
# Determine target style based on what's in the SQL
|
|
94
|
+
target_style = self.default_parameter_style
|
|
95
|
+
|
|
96
|
+
# Check if there are unsupported styles
|
|
97
|
+
unsupported_styles = detected_styles - set(self.supported_parameter_styles)
|
|
98
|
+
if unsupported_styles:
|
|
99
|
+
# Force conversion to default style
|
|
100
|
+
target_style = self.default_parameter_style
|
|
101
|
+
elif detected_styles:
|
|
102
|
+
# Prefer the first supported style found
|
|
103
|
+
for style in detected_styles:
|
|
104
|
+
if style in self.supported_parameter_styles:
|
|
105
|
+
target_style = style
|
|
106
|
+
break
|
|
107
|
+
|
|
108
|
+
# Compile with the determined style
|
|
109
|
+
sql, params = statement.compile(placeholder_style=target_style)
|
|
86
110
|
params = self._process_parameters(params)
|
|
87
111
|
|
|
88
112
|
if statement.is_many:
|
|
@@ -92,43 +116,99 @@ class PsqlpyDriver(
|
|
|
92
116
|
|
|
93
117
|
async def _execute(
|
|
94
118
|
self, sql: str, parameters: Any, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
95
|
-
) ->
|
|
96
|
-
|
|
97
|
-
if self.
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
119
|
+
) -> SQLResult[RowT]:
|
|
120
|
+
# Use provided connection or driver's default connection
|
|
121
|
+
conn = connection if connection is not None else self._connection(None)
|
|
122
|
+
|
|
123
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
124
|
+
# PSQLPy expects parameters as a list (for $1, $2, etc.) or dict
|
|
125
|
+
# Ensure we always pass a sequence or mapping, never a scalar
|
|
126
|
+
final_params: Any
|
|
127
|
+
if isinstance(parameters, (list, tuple)):
|
|
128
|
+
final_params = list(parameters)
|
|
129
|
+
elif isinstance(parameters, dict):
|
|
130
|
+
final_params = parameters
|
|
131
|
+
elif parameters is None:
|
|
132
|
+
final_params = []
|
|
133
|
+
else:
|
|
134
|
+
# Single parameter - wrap in list for NUMERIC style ($1)
|
|
135
|
+
final_params = [parameters]
|
|
136
|
+
|
|
137
|
+
if self.returns_rows(statement.expression):
|
|
138
|
+
query_result = await txn_conn.fetch(sql, parameters=final_params)
|
|
139
|
+
dict_rows: list[dict[str, Any]] = []
|
|
140
|
+
if query_result:
|
|
141
|
+
# psqlpy QueryResult has a result() method that returns list of dicts
|
|
142
|
+
dict_rows = query_result.result()
|
|
143
|
+
column_names = list(dict_rows[0].keys()) if dict_rows else []
|
|
144
|
+
return SQLResult(
|
|
145
|
+
statement=statement,
|
|
146
|
+
data=cast("list[RowT]", dict_rows),
|
|
147
|
+
column_names=column_names,
|
|
148
|
+
rows_affected=len(dict_rows),
|
|
149
|
+
operation_type="SELECT",
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
query_result = await txn_conn.execute(sql, parameters=final_params)
|
|
153
|
+
# Note: psqlpy doesn't provide rows_affected for DML operations
|
|
154
|
+
# The QueryResult object only has result(), as_class(), and row_factory() methods
|
|
155
|
+
affected_count = -1 # Unknown, as psqlpy doesn't provide this info
|
|
156
|
+
return SQLResult(
|
|
157
|
+
statement=statement,
|
|
158
|
+
data=[],
|
|
159
|
+
rows_affected=affected_count,
|
|
160
|
+
operation_type=self._determine_operation_type(statement),
|
|
161
|
+
metadata={"status_message": "OK"},
|
|
162
|
+
)
|
|
112
163
|
|
|
113
164
|
async def _execute_many(
|
|
114
165
|
self, sql: str, param_list: Any, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
115
|
-
) ->
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
166
|
+
) -> SQLResult[RowT]:
|
|
167
|
+
# Use provided connection or driver's default connection
|
|
168
|
+
conn = connection if connection is not None else self._connection(None)
|
|
169
|
+
|
|
170
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
171
|
+
# PSQLPy expects a list of parameter lists/tuples for execute_many
|
|
172
|
+
if param_list is None:
|
|
173
|
+
final_param_list = []
|
|
174
|
+
elif isinstance(param_list, (list, tuple)):
|
|
175
|
+
# Ensure each parameter set is a list/tuple
|
|
176
|
+
final_param_list = [
|
|
177
|
+
list(params) if isinstance(params, (list, tuple)) else [params] for params in param_list
|
|
178
|
+
]
|
|
179
|
+
else:
|
|
180
|
+
# Single parameter set - wrap it
|
|
181
|
+
final_param_list = [list(param_list) if isinstance(param_list, (list, tuple)) else [param_list]]
|
|
182
|
+
|
|
183
|
+
await txn_conn.execute_many(sql, final_param_list)
|
|
184
|
+
# execute_many doesn't return a value with rows_affected
|
|
185
|
+
affected_count = -1
|
|
186
|
+
return SQLResult(
|
|
187
|
+
statement=SQL(sql, _dialect=self.dialect),
|
|
188
|
+
data=[],
|
|
189
|
+
rows_affected=affected_count,
|
|
190
|
+
operation_type="EXECUTE",
|
|
191
|
+
metadata={"status_message": "OK"},
|
|
192
|
+
)
|
|
121
193
|
|
|
122
194
|
async def _execute_script(
|
|
123
195
|
self, script: str, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
|
|
124
|
-
) ->
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
196
|
+
) -> SQLResult[RowT]:
|
|
197
|
+
# Use provided connection or driver's default connection
|
|
198
|
+
conn = connection if connection is not None else self._connection(None)
|
|
199
|
+
|
|
200
|
+
async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
|
|
201
|
+
# psqlpy can execute multi-statement scripts directly
|
|
202
|
+
await txn_conn.execute(script)
|
|
203
|
+
return SQLResult(
|
|
204
|
+
statement=SQL(script, _dialect=self.dialect).as_script(),
|
|
205
|
+
data=[],
|
|
206
|
+
rows_affected=0,
|
|
207
|
+
operation_type="SCRIPT",
|
|
208
|
+
metadata={"status_message": "SCRIPT EXECUTED"},
|
|
209
|
+
total_statements=-1, # Not directly supported, but script is executed
|
|
210
|
+
successful_statements=-1,
|
|
211
|
+
)
|
|
132
212
|
|
|
133
213
|
async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
|
|
134
214
|
self._ensure_pyarrow_installed()
|
|
@@ -154,61 +234,6 @@ class PsqlpyDriver(
|
|
|
154
234
|
msg = "Connection does not support COPY operations"
|
|
155
235
|
raise NotImplementedError(msg)
|
|
156
236
|
|
|
157
|
-
async def _wrap_select_result(
|
|
158
|
-
self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
|
|
159
|
-
) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
|
|
160
|
-
dict_rows = result["data"]
|
|
161
|
-
column_names = result["column_names"]
|
|
162
|
-
rows_affected = result["rows_affected"]
|
|
163
|
-
|
|
164
|
-
if schema_type:
|
|
165
|
-
converted_data = self.to_schema(data=dict_rows, schema_type=schema_type)
|
|
166
|
-
return SQLResult[ModelDTOT](
|
|
167
|
-
statement=statement,
|
|
168
|
-
data=list(converted_data),
|
|
169
|
-
column_names=column_names,
|
|
170
|
-
rows_affected=rows_affected,
|
|
171
|
-
operation_type="SELECT",
|
|
172
|
-
)
|
|
173
|
-
return SQLResult[RowT](
|
|
174
|
-
statement=statement,
|
|
175
|
-
data=cast("list[RowT]", dict_rows),
|
|
176
|
-
column_names=column_names,
|
|
177
|
-
rows_affected=rows_affected,
|
|
178
|
-
operation_type="SELECT",
|
|
179
|
-
)
|
|
180
|
-
|
|
181
|
-
async def _wrap_execute_result(
|
|
182
|
-
self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
|
|
183
|
-
) -> SQLResult[RowT]:
|
|
184
|
-
operation_type = "UNKNOWN"
|
|
185
|
-
if statement.expression:
|
|
186
|
-
operation_type = str(statement.expression.key).upper()
|
|
187
|
-
|
|
188
|
-
if "statements_executed" in result:
|
|
189
|
-
script_result = cast("ScriptResultDict", result)
|
|
190
|
-
return SQLResult[RowT](
|
|
191
|
-
statement=statement,
|
|
192
|
-
data=[],
|
|
193
|
-
rows_affected=0,
|
|
194
|
-
operation_type="SCRIPT",
|
|
195
|
-
metadata={
|
|
196
|
-
"status_message": script_result.get("status_message", ""),
|
|
197
|
-
"statements_executed": script_result.get("statements_executed", -1),
|
|
198
|
-
},
|
|
199
|
-
)
|
|
200
|
-
|
|
201
|
-
dml_result = cast("DMLResultDict", result)
|
|
202
|
-
rows_affected = dml_result.get("rows_affected", -1)
|
|
203
|
-
status_message = dml_result.get("status_message", "")
|
|
204
|
-
return SQLResult[RowT](
|
|
205
|
-
statement=statement,
|
|
206
|
-
data=[],
|
|
207
|
-
rows_affected=rows_affected,
|
|
208
|
-
operation_type=operation_type,
|
|
209
|
-
metadata={"status_message": status_message},
|
|
210
|
-
)
|
|
211
|
-
|
|
212
237
|
def _connection(self, connection: Optional[PsqlpyConnection] = None) -> PsqlpyConnection:
|
|
213
238
|
"""Get the connection to use for the operation."""
|
|
214
239
|
return connection or self.connection
|
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
import contextlib
|
|
4
4
|
import logging
|
|
5
5
|
from contextlib import asynccontextmanager
|
|
6
|
-
from dataclasses import replace
|
|
7
6
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
|
|
8
7
|
|
|
9
8
|
from psycopg.rows import dict_row
|
|
@@ -211,7 +210,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
211
210
|
self.configure = configure
|
|
212
211
|
self.kwargs = kwargs or {}
|
|
213
212
|
|
|
214
|
-
# Handle extras and additional kwargs
|
|
215
213
|
self.extras = extras or {}
|
|
216
214
|
self.extras.update(additional_kwargs)
|
|
217
215
|
|
|
@@ -240,7 +238,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
240
238
|
if self.kwargs:
|
|
241
239
|
config.update(self.kwargs)
|
|
242
240
|
|
|
243
|
-
# Set DictRow as the row factory
|
|
244
241
|
config["row_factory"] = dict_row
|
|
245
242
|
|
|
246
243
|
return config
|
|
@@ -263,7 +260,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
263
260
|
if self.kwargs:
|
|
264
261
|
config.update(self.kwargs)
|
|
265
262
|
|
|
266
|
-
# Set DictRow as the row factory
|
|
267
263
|
config["row_factory"] = dict_row
|
|
268
264
|
|
|
269
265
|
return config
|
|
@@ -273,7 +269,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
273
269
|
logger.info("Creating Psycopg connection pool", extra={"adapter": "psycopg"})
|
|
274
270
|
|
|
275
271
|
try:
|
|
276
|
-
# Get all config (creates a new dict)
|
|
277
272
|
all_config = self.pool_config_dict.copy()
|
|
278
273
|
|
|
279
274
|
# Separate pool-specific parameters that ConnectionPool accepts directly
|
|
@@ -289,30 +284,29 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
289
284
|
"num_workers": all_config.pop("num_workers", 3),
|
|
290
285
|
}
|
|
291
286
|
|
|
292
|
-
#
|
|
287
|
+
# Capture autocommit setting before configuring the pool
|
|
288
|
+
autocommit_setting = all_config.get("autocommit")
|
|
289
|
+
|
|
293
290
|
def configure_connection(conn: "PsycopgSyncConnection") -> None:
|
|
294
|
-
# Set DictRow as the row factory
|
|
295
291
|
conn.row_factory = dict_row
|
|
292
|
+
# Apply autocommit setting if specified
|
|
293
|
+
if autocommit_setting is not None:
|
|
294
|
+
conn.autocommit = autocommit_setting
|
|
296
295
|
|
|
297
296
|
pool_params["configure"] = all_config.pop("configure", configure_connection)
|
|
298
297
|
|
|
299
|
-
# Remove None values from pool_params
|
|
300
298
|
pool_params = {k: v for k, v in pool_params.items() if v is not None}
|
|
301
299
|
|
|
302
|
-
# Handle conninfo vs individual connection parameters
|
|
303
300
|
conninfo = all_config.pop("conninfo", None)
|
|
304
301
|
if conninfo:
|
|
305
302
|
# If conninfo is provided, use it directly
|
|
306
303
|
# Don't pass kwargs when using conninfo string
|
|
307
|
-
pool = ConnectionPool(conninfo, **pool_params)
|
|
304
|
+
pool = ConnectionPool(conninfo, open=True, **pool_params)
|
|
308
305
|
else:
|
|
309
|
-
# Otherwise, pass connection parameters via kwargs
|
|
310
|
-
# Remove any non-connection parameters
|
|
311
306
|
# row_factory is already popped out earlier
|
|
312
307
|
all_config.pop("row_factory", None)
|
|
313
|
-
# Remove pool-specific settings that may have been left
|
|
314
308
|
all_config.pop("kwargs", None)
|
|
315
|
-
pool = ConnectionPool("", kwargs=all_config, **pool_params)
|
|
309
|
+
pool = ConnectionPool("", kwargs=all_config, open=True, **pool_params)
|
|
316
310
|
|
|
317
311
|
logger.info("Psycopg connection pool created successfully", extra={"adapter": "psycopg"})
|
|
318
312
|
except Exception as e:
|
|
@@ -328,11 +322,17 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
328
322
|
logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
|
|
329
323
|
|
|
330
324
|
try:
|
|
325
|
+
# This avoids the "cannot join current thread" error during garbage collection
|
|
326
|
+
if hasattr(self.pool_instance, "_closed"):
|
|
327
|
+
self.pool_instance._closed = True
|
|
328
|
+
|
|
331
329
|
self.pool_instance.close()
|
|
332
330
|
logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
|
|
333
331
|
except Exception as e:
|
|
334
332
|
logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)})
|
|
335
333
|
raise
|
|
334
|
+
finally:
|
|
335
|
+
self.pool_instance = None
|
|
336
336
|
|
|
337
337
|
def create_connection(self) -> "PsycopgSyncConnection":
|
|
338
338
|
"""Create a single connection (not from pool).
|
|
@@ -377,15 +377,16 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
|
|
|
377
377
|
A PsycopgSyncDriver instance.
|
|
378
378
|
"""
|
|
379
379
|
with self.provide_connection(*args, **kwargs) as conn:
|
|
380
|
-
# Create statement config with parameter style info if not already set
|
|
381
380
|
statement_config = self.statement_config
|
|
381
|
+
# Inject parameter style info if not already set
|
|
382
382
|
if statement_config.allowed_parameter_styles is None:
|
|
383
|
+
from dataclasses import replace
|
|
384
|
+
|
|
383
385
|
statement_config = replace(
|
|
384
386
|
statement_config,
|
|
385
387
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
386
388
|
target_parameter_style=self.preferred_parameter_style,
|
|
387
389
|
)
|
|
388
|
-
|
|
389
390
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
390
391
|
yield driver
|
|
391
392
|
|
|
@@ -547,7 +548,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
547
548
|
self.configure = configure
|
|
548
549
|
self.kwargs = kwargs or {}
|
|
549
550
|
|
|
550
|
-
# Handle extras and additional kwargs
|
|
551
551
|
self.extras = extras or {}
|
|
552
552
|
self.extras.update(additional_kwargs)
|
|
553
553
|
|
|
@@ -576,7 +576,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
576
576
|
if self.kwargs:
|
|
577
577
|
config.update(self.kwargs)
|
|
578
578
|
|
|
579
|
-
# Set DictRow as the row factory
|
|
580
579
|
config["row_factory"] = dict_row
|
|
581
580
|
|
|
582
581
|
return config
|
|
@@ -599,7 +598,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
599
598
|
if self.kwargs:
|
|
600
599
|
config.update(self.kwargs)
|
|
601
600
|
|
|
602
|
-
# Set DictRow as the row factory
|
|
603
601
|
config["row_factory"] = dict_row
|
|
604
602
|
|
|
605
603
|
return config
|
|
@@ -607,7 +605,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
607
605
|
async def _create_pool(self) -> "AsyncConnectionPool":
|
|
608
606
|
"""Create the actual async connection pool."""
|
|
609
607
|
|
|
610
|
-
# Get all config (creates a new dict)
|
|
611
608
|
all_config = self.pool_config_dict.copy()
|
|
612
609
|
|
|
613
610
|
# Separate pool-specific parameters that AsyncConnectionPool accepts directly
|
|
@@ -623,28 +620,27 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
623
620
|
"num_workers": all_config.pop("num_workers", 3),
|
|
624
621
|
}
|
|
625
622
|
|
|
626
|
-
#
|
|
623
|
+
# Capture autocommit setting before configuring the pool
|
|
624
|
+
autocommit_setting = all_config.get("autocommit")
|
|
625
|
+
|
|
627
626
|
async def configure_connection(conn: "PsycopgAsyncConnection") -> None:
|
|
628
|
-
# Set DictRow as the row factory
|
|
629
627
|
conn.row_factory = dict_row
|
|
628
|
+
# Apply autocommit setting if specified (async version requires await)
|
|
629
|
+
if autocommit_setting is not None:
|
|
630
|
+
await conn.set_autocommit(autocommit_setting)
|
|
630
631
|
|
|
631
632
|
pool_params["configure"] = all_config.pop("configure", configure_connection)
|
|
632
633
|
|
|
633
|
-
# Remove None values from pool_params
|
|
634
634
|
pool_params = {k: v for k, v in pool_params.items() if v is not None}
|
|
635
635
|
|
|
636
|
-
# Handle conninfo vs individual connection parameters
|
|
637
636
|
conninfo = all_config.pop("conninfo", None)
|
|
638
637
|
if conninfo:
|
|
639
638
|
# If conninfo is provided, use it directly
|
|
640
639
|
# Don't pass kwargs when using conninfo string
|
|
641
640
|
pool = AsyncConnectionPool(conninfo, open=False, **pool_params)
|
|
642
641
|
else:
|
|
643
|
-
# Otherwise, pass connection parameters via kwargs
|
|
644
|
-
# Remove any non-connection parameters
|
|
645
642
|
# row_factory is already popped out earlier
|
|
646
643
|
all_config.pop("row_factory", None)
|
|
647
|
-
# Remove pool-specific settings that may have been left
|
|
648
644
|
all_config.pop("kwargs", None)
|
|
649
645
|
pool = AsyncConnectionPool("", kwargs=all_config, open=False, **pool_params)
|
|
650
646
|
|
|
@@ -657,7 +653,14 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
657
653
|
if not self.pool_instance:
|
|
658
654
|
return
|
|
659
655
|
|
|
660
|
-
|
|
656
|
+
try:
|
|
657
|
+
# This avoids the "cannot join current thread" error during garbage collection
|
|
658
|
+
if hasattr(self.pool_instance, "_closed"):
|
|
659
|
+
self.pool_instance._closed = True
|
|
660
|
+
|
|
661
|
+
await self.pool_instance.close()
|
|
662
|
+
finally:
|
|
663
|
+
self.pool_instance = None
|
|
661
664
|
|
|
662
665
|
async def create_connection(self) -> "PsycopgAsyncConnection": # pyright: ignore
|
|
663
666
|
"""Create a single async connection (not from pool).
|
|
@@ -702,15 +705,16 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
|
|
|
702
705
|
A PsycopgAsyncDriver instance.
|
|
703
706
|
"""
|
|
704
707
|
async with self.provide_connection(*args, **kwargs) as conn:
|
|
705
|
-
# Create statement config with parameter style info if not already set
|
|
706
708
|
statement_config = self.statement_config
|
|
709
|
+
# Inject parameter style info if not already set
|
|
707
710
|
if statement_config.allowed_parameter_styles is None:
|
|
711
|
+
from dataclasses import replace
|
|
712
|
+
|
|
708
713
|
statement_config = replace(
|
|
709
714
|
statement_config,
|
|
710
715
|
allowed_parameter_styles=self.supported_parameter_styles,
|
|
711
716
|
target_parameter_style=self.preferred_parameter_style,
|
|
712
717
|
)
|
|
713
|
-
|
|
714
718
|
driver = self.driver_type(connection=conn, config=statement_config)
|
|
715
719
|
yield driver
|
|
716
720
|
|