sqlspec-0.12.1-py3-none-any.whl → sqlspec-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (113):
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +116 -141
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +231 -181
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +132 -124
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +34 -30
  19. sqlspec/adapters/psycopg/driver.py +342 -214
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +150 -104
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +149 -216
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +31 -118
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +70 -23
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +102 -65
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +885 -379
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +82 -35
  91. sqlspec/storage/backends/obstore.py +66 -49
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -170
  110. sqlspec-0.12.1.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -3,7 +3,6 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional
8
7
 
9
8
  from psqlpy import ConnectionPool
@@ -302,7 +301,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
302
301
  if getattr(self, field, None) is not None and getattr(self, field) is not Empty
303
302
  }
304
303
 
305
- # Add connection-specific extras (not pool-specific ones)
306
304
  config.update(self.extras)
307
305
 
308
306
  return config
@@ -359,11 +357,9 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
359
357
  Returns:
360
358
  A psqlpy Connection instance.
361
359
  """
362
- # Ensure pool exists
363
360
  if not self.pool_instance:
364
361
  self.pool_instance = await self._create_pool()
365
362
 
366
- # Get connection from pool
367
363
  return await self.pool_instance.connection()
368
364
 
369
365
  @asynccontextmanager
@@ -377,7 +373,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
377
373
  Yields:
378
374
  A psqlpy Connection instance.
379
375
  """
380
- # Ensure pool exists
381
376
  if not self.pool_instance:
382
377
  self.pool_instance = await self._create_pool()
383
378
 
@@ -396,15 +391,16 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
396
391
  A PsqlpyDriver instance.
397
392
  """
398
393
  async with self.provide_connection(*args, **kwargs) as conn:
399
- # Create statement config with parameter style info if not already set
400
394
  statement_config = self.statement_config
395
+ # Inject parameter style info if not already set
401
396
  if statement_config.allowed_parameter_styles is None:
397
+ from dataclasses import replace
398
+
402
399
  statement_config = replace(
403
400
  statement_config,
404
401
  allowed_parameter_styles=self.supported_parameter_styles,
405
402
  target_parameter_style=self.preferred_parameter_style,
406
403
  )
407
-
408
404
  driver = self.driver_type(connection=conn, config=statement_config)
409
405
  yield driver
410
406
 
@@ -2,11 +2,12 @@
2
2
 
3
3
  import io
4
4
  import logging
5
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
5
+ from typing import TYPE_CHECKING, Any, Optional, cast
6
6
 
7
7
  from psqlpy import Connection
8
8
 
9
9
  from sqlspec.driver import AsyncDriverAdapterProtocol
10
+ from sqlspec.driver.connection import managed_transaction_async
10
11
  from sqlspec.driver.mixins import (
11
12
  AsyncPipelinedExecutionMixin,
12
13
  AsyncStorageMixin,
@@ -14,10 +15,10 @@ from sqlspec.driver.mixins import (
14
15
  ToSchemaMixin,
15
16
  TypeCoercionMixin,
16
17
  )
17
- from sqlspec.statement.parameters import ParameterStyle
18
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
18
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
19
+ from sqlspec.statement.result import SQLResult
19
20
  from sqlspec.statement.sql import SQL, SQLConfig
20
- from sqlspec.typing import DictRow, ModelDTOT, RowT
21
+ from sqlspec.typing import DictRow, RowT
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from sqlglot.dialects.dialect import DialectType
@@ -76,13 +77,36 @@ class PsqlpyDriver(
76
77
 
77
78
  async def _execute_statement(
78
79
  self, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
79
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
80
+ ) -> SQLResult[RowT]:
80
81
  if statement.is_script:
81
82
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
82
83
  return await self._execute_script(sql, connection=connection, **kwargs)
83
84
 
84
- # Let the SQL object handle parameter style conversion based on dialect support
85
- sql, params = statement.compile(placeholder_style=self.default_parameter_style)
85
+ # Detect parameter styles in the SQL
86
+ detected_styles = set()
87
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
88
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
89
+ param_infos = validator.extract_parameters(sql_str)
90
+ if param_infos:
91
+ detected_styles = {p.style for p in param_infos}
92
+
93
+ # Determine target style based on what's in the SQL
94
+ target_style = self.default_parameter_style
95
+
96
+ # Check if there are unsupported styles
97
+ unsupported_styles = detected_styles - set(self.supported_parameter_styles)
98
+ if unsupported_styles:
99
+ # Force conversion to default style
100
+ target_style = self.default_parameter_style
101
+ elif detected_styles:
102
+ # Prefer the first supported style found
103
+ for style in detected_styles:
104
+ if style in self.supported_parameter_styles:
105
+ target_style = style
106
+ break
107
+
108
+ # Compile with the determined style
109
+ sql, params = statement.compile(placeholder_style=target_style)
86
110
  params = self._process_parameters(params)
87
111
 
88
112
  if statement.is_many:
@@ -92,43 +116,99 @@ class PsqlpyDriver(
92
116
 
93
117
  async def _execute(
94
118
  self, sql: str, parameters: Any, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
95
- ) -> Union[SelectResultDict, DMLResultDict]:
96
- conn = self._connection(connection)
97
- if self.returns_rows(statement.expression):
98
- query_result = await conn.fetch(sql, parameters=parameters)
99
- # Convert query_result to list of dicts
100
- dict_rows: list[dict[str, Any]] = []
101
- if query_result:
102
- # psqlpy QueryResult has a result() method that returns list of dicts
103
- dict_rows = query_result.result()
104
- column_names = list(dict_rows[0].keys()) if dict_rows else []
105
- return {"data": dict_rows, "column_names": column_names, "rows_affected": len(dict_rows)}
106
- query_result = await conn.execute(sql, parameters=parameters)
107
- # Note: psqlpy doesn't provide rows_affected for DML operations
108
- # The QueryResult object only has result(), as_class(), and row_factory() methods
109
- # For accurate row counts, use RETURNING clause
110
- affected_count = -1 # Unknown, as psqlpy doesn't provide this info
111
- return {"rows_affected": affected_count, "status_message": "OK"}
119
+ ) -> SQLResult[RowT]:
120
+ # Use provided connection or driver's default connection
121
+ conn = connection if connection is not None else self._connection(None)
122
+
123
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
124
+ # PSQLPy expects parameters as a list (for $1, $2, etc.) or dict
125
+ # Ensure we always pass a sequence or mapping, never a scalar
126
+ final_params: Any
127
+ if isinstance(parameters, (list, tuple)):
128
+ final_params = list(parameters)
129
+ elif isinstance(parameters, dict):
130
+ final_params = parameters
131
+ elif parameters is None:
132
+ final_params = []
133
+ else:
134
+ # Single parameter - wrap in list for NUMERIC style ($1)
135
+ final_params = [parameters]
136
+
137
+ if self.returns_rows(statement.expression):
138
+ query_result = await txn_conn.fetch(sql, parameters=final_params)
139
+ dict_rows: list[dict[str, Any]] = []
140
+ if query_result:
141
+ # psqlpy QueryResult has a result() method that returns list of dicts
142
+ dict_rows = query_result.result()
143
+ column_names = list(dict_rows[0].keys()) if dict_rows else []
144
+ return SQLResult(
145
+ statement=statement,
146
+ data=cast("list[RowT]", dict_rows),
147
+ column_names=column_names,
148
+ rows_affected=len(dict_rows),
149
+ operation_type="SELECT",
150
+ )
151
+
152
+ query_result = await txn_conn.execute(sql, parameters=final_params)
153
+ # Note: psqlpy doesn't provide rows_affected for DML operations
154
+ # The QueryResult object only has result(), as_class(), and row_factory() methods
155
+ affected_count = -1 # Unknown, as psqlpy doesn't provide this info
156
+ return SQLResult(
157
+ statement=statement,
158
+ data=[],
159
+ rows_affected=affected_count,
160
+ operation_type=self._determine_operation_type(statement),
161
+ metadata={"status_message": "OK"},
162
+ )
112
163
 
113
164
  async def _execute_many(
114
165
  self, sql: str, param_list: Any, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
115
- ) -> DMLResultDict:
116
- conn = self._connection(connection)
117
- await conn.execute_many(sql, param_list or [])
118
- # execute_many doesn't return a value with rows_affected
119
- affected_count = -1
120
- return {"rows_affected": affected_count, "status_message": "OK"}
166
+ ) -> SQLResult[RowT]:
167
+ # Use provided connection or driver's default connection
168
+ conn = connection if connection is not None else self._connection(None)
169
+
170
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
171
+ # PSQLPy expects a list of parameter lists/tuples for execute_many
172
+ if param_list is None:
173
+ final_param_list = []
174
+ elif isinstance(param_list, (list, tuple)):
175
+ # Ensure each parameter set is a list/tuple
176
+ final_param_list = [
177
+ list(params) if isinstance(params, (list, tuple)) else [params] for params in param_list
178
+ ]
179
+ else:
180
+ # Single parameter set - wrap it
181
+ final_param_list = [list(param_list) if isinstance(param_list, (list, tuple)) else [param_list]]
182
+
183
+ await txn_conn.execute_many(sql, final_param_list)
184
+ # execute_many doesn't return a value with rows_affected
185
+ affected_count = -1
186
+ return SQLResult(
187
+ statement=SQL(sql, _dialect=self.dialect),
188
+ data=[],
189
+ rows_affected=affected_count,
190
+ operation_type="EXECUTE",
191
+ metadata={"status_message": "OK"},
192
+ )
121
193
 
122
194
  async def _execute_script(
123
195
  self, script: str, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
124
- ) -> ScriptResultDict:
125
- conn = self._connection(connection)
126
- # psqlpy can execute multi-statement scripts directly
127
- await conn.execute(script)
128
- return {
129
- "statements_executed": -1, # Not directly supported, but script is executed
130
- "status_message": "SCRIPT EXECUTED",
131
- }
196
+ ) -> SQLResult[RowT]:
197
+ # Use provided connection or driver's default connection
198
+ conn = connection if connection is not None else self._connection(None)
199
+
200
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
201
+ # psqlpy can execute multi-statement scripts directly
202
+ await txn_conn.execute(script)
203
+ return SQLResult(
204
+ statement=SQL(script, _dialect=self.dialect).as_script(),
205
+ data=[],
206
+ rows_affected=0,
207
+ operation_type="SCRIPT",
208
+ metadata={"status_message": "SCRIPT EXECUTED"},
209
+ total_statements=-1, # Not directly supported, but script is executed
210
+ successful_statements=-1,
211
+ )
132
212
 
133
213
  async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
134
214
  self._ensure_pyarrow_installed()
@@ -154,61 +234,6 @@ class PsqlpyDriver(
154
234
  msg = "Connection does not support COPY operations"
155
235
  raise NotImplementedError(msg)
156
236
 
157
- async def _wrap_select_result(
158
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
159
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
160
- dict_rows = result["data"]
161
- column_names = result["column_names"]
162
- rows_affected = result["rows_affected"]
163
-
164
- if schema_type:
165
- converted_data = self.to_schema(data=dict_rows, schema_type=schema_type)
166
- return SQLResult[ModelDTOT](
167
- statement=statement,
168
- data=list(converted_data),
169
- column_names=column_names,
170
- rows_affected=rows_affected,
171
- operation_type="SELECT",
172
- )
173
- return SQLResult[RowT](
174
- statement=statement,
175
- data=cast("list[RowT]", dict_rows),
176
- column_names=column_names,
177
- rows_affected=rows_affected,
178
- operation_type="SELECT",
179
- )
180
-
181
- async def _wrap_execute_result(
182
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
183
- ) -> SQLResult[RowT]:
184
- operation_type = "UNKNOWN"
185
- if statement.expression:
186
- operation_type = str(statement.expression.key).upper()
187
-
188
- if "statements_executed" in result:
189
- script_result = cast("ScriptResultDict", result)
190
- return SQLResult[RowT](
191
- statement=statement,
192
- data=[],
193
- rows_affected=0,
194
- operation_type="SCRIPT",
195
- metadata={
196
- "status_message": script_result.get("status_message", ""),
197
- "statements_executed": script_result.get("statements_executed", -1),
198
- },
199
- )
200
-
201
- dml_result = cast("DMLResultDict", result)
202
- rows_affected = dml_result.get("rows_affected", -1)
203
- status_message = dml_result.get("status_message", "")
204
- return SQLResult[RowT](
205
- statement=statement,
206
- data=[],
207
- rows_affected=rows_affected,
208
- operation_type=operation_type,
209
- metadata={"status_message": status_message},
210
- )
211
-
212
237
  def _connection(self, connection: Optional[PsqlpyConnection] = None) -> PsqlpyConnection:
213
238
  """Get the connection to use for the operation."""
214
239
  return connection or self.connection
@@ -3,7 +3,6 @@
3
3
  import contextlib
4
4
  import logging
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
8
7
 
9
8
  from psycopg.rows import dict_row
@@ -211,7 +210,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
211
210
  self.configure = configure
212
211
  self.kwargs = kwargs or {}
213
212
 
214
- # Handle extras and additional kwargs
215
213
  self.extras = extras or {}
216
214
  self.extras.update(additional_kwargs)
217
215
 
@@ -240,7 +238,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
240
238
  if self.kwargs:
241
239
  config.update(self.kwargs)
242
240
 
243
- # Set DictRow as the row factory
244
241
  config["row_factory"] = dict_row
245
242
 
246
243
  return config
@@ -263,7 +260,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
263
260
  if self.kwargs:
264
261
  config.update(self.kwargs)
265
262
 
266
- # Set DictRow as the row factory
267
263
  config["row_factory"] = dict_row
268
264
 
269
265
  return config
@@ -273,7 +269,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
273
269
  logger.info("Creating Psycopg connection pool", extra={"adapter": "psycopg"})
274
270
 
275
271
  try:
276
- # Get all config (creates a new dict)
277
272
  all_config = self.pool_config_dict.copy()
278
273
 
279
274
  # Separate pool-specific parameters that ConnectionPool accepts directly
@@ -289,30 +284,29 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
289
284
  "num_workers": all_config.pop("num_workers", 3),
290
285
  }
291
286
 
292
- # Create a configure callback to set row_factory
287
+ # Capture autocommit setting before configuring the pool
288
+ autocommit_setting = all_config.get("autocommit")
289
+
293
290
  def configure_connection(conn: "PsycopgSyncConnection") -> None:
294
- # Set DictRow as the row factory
295
291
  conn.row_factory = dict_row
292
+ # Apply autocommit setting if specified
293
+ if autocommit_setting is not None:
294
+ conn.autocommit = autocommit_setting
296
295
 
297
296
  pool_params["configure"] = all_config.pop("configure", configure_connection)
298
297
 
299
- # Remove None values from pool_params
300
298
  pool_params = {k: v for k, v in pool_params.items() if v is not None}
301
299
 
302
- # Handle conninfo vs individual connection parameters
303
300
  conninfo = all_config.pop("conninfo", None)
304
301
  if conninfo:
305
302
  # If conninfo is provided, use it directly
306
303
  # Don't pass kwargs when using conninfo string
307
- pool = ConnectionPool(conninfo, **pool_params)
304
+ pool = ConnectionPool(conninfo, open=True, **pool_params)
308
305
  else:
309
- # Otherwise, pass connection parameters via kwargs
310
- # Remove any non-connection parameters
311
306
  # row_factory is already popped out earlier
312
307
  all_config.pop("row_factory", None)
313
- # Remove pool-specific settings that may have been left
314
308
  all_config.pop("kwargs", None)
315
- pool = ConnectionPool("", kwargs=all_config, **pool_params)
309
+ pool = ConnectionPool("", kwargs=all_config, open=True, **pool_params)
316
310
 
317
311
  logger.info("Psycopg connection pool created successfully", extra={"adapter": "psycopg"})
318
312
  except Exception as e:
@@ -328,11 +322,17 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
328
322
  logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
329
323
 
330
324
  try:
325
+ # This avoids the "cannot join current thread" error during garbage collection
326
+ if hasattr(self.pool_instance, "_closed"):
327
+ self.pool_instance._closed = True
328
+
331
329
  self.pool_instance.close()
332
330
  logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
333
331
  except Exception as e:
334
332
  logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)})
335
333
  raise
334
+ finally:
335
+ self.pool_instance = None
336
336
 
337
337
  def create_connection(self) -> "PsycopgSyncConnection":
338
338
  """Create a single connection (not from pool).
@@ -377,15 +377,16 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
377
377
  A PsycopgSyncDriver instance.
378
378
  """
379
379
  with self.provide_connection(*args, **kwargs) as conn:
380
- # Create statement config with parameter style info if not already set
381
380
  statement_config = self.statement_config
381
+ # Inject parameter style info if not already set
382
382
  if statement_config.allowed_parameter_styles is None:
383
+ from dataclasses import replace
384
+
383
385
  statement_config = replace(
384
386
  statement_config,
385
387
  allowed_parameter_styles=self.supported_parameter_styles,
386
388
  target_parameter_style=self.preferred_parameter_style,
387
389
  )
388
-
389
390
  driver = self.driver_type(connection=conn, config=statement_config)
390
391
  yield driver
391
392
 
@@ -547,7 +548,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
547
548
  self.configure = configure
548
549
  self.kwargs = kwargs or {}
549
550
 
550
- # Handle extras and additional kwargs
551
551
  self.extras = extras or {}
552
552
  self.extras.update(additional_kwargs)
553
553
 
@@ -576,7 +576,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
576
576
  if self.kwargs:
577
577
  config.update(self.kwargs)
578
578
 
579
- # Set DictRow as the row factory
580
579
  config["row_factory"] = dict_row
581
580
 
582
581
  return config
@@ -599,7 +598,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
599
598
  if self.kwargs:
600
599
  config.update(self.kwargs)
601
600
 
602
- # Set DictRow as the row factory
603
601
  config["row_factory"] = dict_row
604
602
 
605
603
  return config
@@ -607,7 +605,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
607
605
  async def _create_pool(self) -> "AsyncConnectionPool":
608
606
  """Create the actual async connection pool."""
609
607
 
610
- # Get all config (creates a new dict)
611
608
  all_config = self.pool_config_dict.copy()
612
609
 
613
610
  # Separate pool-specific parameters that AsyncConnectionPool accepts directly
@@ -623,28 +620,27 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
623
620
  "num_workers": all_config.pop("num_workers", 3),
624
621
  }
625
622
 
626
- # Create a configure callback to set row_factory
623
+ # Capture autocommit setting before configuring the pool
624
+ autocommit_setting = all_config.get("autocommit")
625
+
627
626
  async def configure_connection(conn: "PsycopgAsyncConnection") -> None:
628
- # Set DictRow as the row factory
629
627
  conn.row_factory = dict_row
628
+ # Apply autocommit setting if specified (async version requires await)
629
+ if autocommit_setting is not None:
630
+ await conn.set_autocommit(autocommit_setting)
630
631
 
631
632
  pool_params["configure"] = all_config.pop("configure", configure_connection)
632
633
 
633
- # Remove None values from pool_params
634
634
  pool_params = {k: v for k, v in pool_params.items() if v is not None}
635
635
 
636
- # Handle conninfo vs individual connection parameters
637
636
  conninfo = all_config.pop("conninfo", None)
638
637
  if conninfo:
639
638
  # If conninfo is provided, use it directly
640
639
  # Don't pass kwargs when using conninfo string
641
640
  pool = AsyncConnectionPool(conninfo, open=False, **pool_params)
642
641
  else:
643
- # Otherwise, pass connection parameters via kwargs
644
- # Remove any non-connection parameters
645
642
  # row_factory is already popped out earlier
646
643
  all_config.pop("row_factory", None)
647
- # Remove pool-specific settings that may have been left
648
644
  all_config.pop("kwargs", None)
649
645
  pool = AsyncConnectionPool("", kwargs=all_config, open=False, **pool_params)
650
646
 
@@ -657,7 +653,14 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
657
653
  if not self.pool_instance:
658
654
  return
659
655
 
660
- await self.pool_instance.close()
656
+ try:
657
+ # This avoids the "cannot join current thread" error during garbage collection
658
+ if hasattr(self.pool_instance, "_closed"):
659
+ self.pool_instance._closed = True
660
+
661
+ await self.pool_instance.close()
662
+ finally:
663
+ self.pool_instance = None
661
664
 
662
665
  async def create_connection(self) -> "PsycopgAsyncConnection": # pyright: ignore
663
666
  """Create a single async connection (not from pool).
@@ -702,15 +705,16 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
702
705
  A PsycopgAsyncDriver instance.
703
706
  """
704
707
  async with self.provide_connection(*args, **kwargs) as conn:
705
- # Create statement config with parameter style info if not already set
706
708
  statement_config = self.statement_config
709
+ # Inject parameter style info if not already set
707
710
  if statement_config.allowed_parameter_styles is None:
711
+ from dataclasses import replace
712
+
708
713
  statement_config = replace(
709
714
  statement_config,
710
715
  allowed_parameter_styles=self.supported_parameter_styles,
711
716
  target_parameter_style=self.preferred_parameter_style,
712
717
  )
713
-
714
718
  driver = self.driver_type(connection=conn, config=statement_config)
715
719
  yield driver
716
720