sqlspec 0.12.2__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (113)
  1. sqlspec/_sql.py +21 -180
  2. sqlspec/adapters/adbc/config.py +10 -12
  3. sqlspec/adapters/adbc/driver.py +120 -118
  4. sqlspec/adapters/aiosqlite/config.py +3 -3
  5. sqlspec/adapters/aiosqlite/driver.py +100 -130
  6. sqlspec/adapters/asyncmy/config.py +3 -4
  7. sqlspec/adapters/asyncmy/driver.py +123 -135
  8. sqlspec/adapters/asyncpg/config.py +3 -7
  9. sqlspec/adapters/asyncpg/driver.py +98 -140
  10. sqlspec/adapters/bigquery/config.py +4 -5
  11. sqlspec/adapters/bigquery/driver.py +125 -167
  12. sqlspec/adapters/duckdb/config.py +3 -6
  13. sqlspec/adapters/duckdb/driver.py +114 -111
  14. sqlspec/adapters/oracledb/config.py +6 -5
  15. sqlspec/adapters/oracledb/driver.py +242 -259
  16. sqlspec/adapters/psqlpy/config.py +3 -7
  17. sqlspec/adapters/psqlpy/driver.py +118 -93
  18. sqlspec/adapters/psycopg/config.py +18 -31
  19. sqlspec/adapters/psycopg/driver.py +283 -236
  20. sqlspec/adapters/sqlite/config.py +3 -3
  21. sqlspec/adapters/sqlite/driver.py +103 -97
  22. sqlspec/config.py +0 -4
  23. sqlspec/driver/_async.py +89 -98
  24. sqlspec/driver/_common.py +52 -17
  25. sqlspec/driver/_sync.py +81 -105
  26. sqlspec/driver/connection.py +207 -0
  27. sqlspec/driver/mixins/_csv_writer.py +91 -0
  28. sqlspec/driver/mixins/_pipeline.py +38 -49
  29. sqlspec/driver/mixins/_result_utils.py +27 -9
  30. sqlspec/driver/mixins/_storage.py +67 -181
  31. sqlspec/driver/mixins/_type_coercion.py +3 -4
  32. sqlspec/driver/parameters.py +138 -0
  33. sqlspec/exceptions.py +10 -2
  34. sqlspec/extensions/aiosql/adapter.py +0 -10
  35. sqlspec/extensions/litestar/handlers.py +0 -1
  36. sqlspec/extensions/litestar/plugin.py +0 -3
  37. sqlspec/extensions/litestar/providers.py +0 -14
  38. sqlspec/loader.py +25 -90
  39. sqlspec/protocols.py +542 -0
  40. sqlspec/service/__init__.py +3 -2
  41. sqlspec/service/_util.py +147 -0
  42. sqlspec/service/base.py +1116 -9
  43. sqlspec/statement/builder/__init__.py +42 -32
  44. sqlspec/statement/builder/_ddl_utils.py +0 -10
  45. sqlspec/statement/builder/_parsing_utils.py +10 -4
  46. sqlspec/statement/builder/base.py +67 -22
  47. sqlspec/statement/builder/column.py +283 -0
  48. sqlspec/statement/builder/ddl.py +91 -67
  49. sqlspec/statement/builder/delete.py +23 -7
  50. sqlspec/statement/builder/insert.py +29 -15
  51. sqlspec/statement/builder/merge.py +4 -4
  52. sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
  53. sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
  54. sqlspec/statement/builder/mixins/_delete_from.py +1 -1
  55. sqlspec/statement/builder/mixins/_from.py +10 -8
  56. sqlspec/statement/builder/mixins/_group_by.py +0 -1
  57. sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
  58. sqlspec/statement/builder/mixins/_insert_values.py +0 -2
  59. sqlspec/statement/builder/mixins/_join.py +20 -13
  60. sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
  61. sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
  62. sqlspec/statement/builder/mixins/_order_by.py +2 -2
  63. sqlspec/statement/builder/mixins/_pivot.py +4 -7
  64. sqlspec/statement/builder/mixins/_select_columns.py +6 -5
  65. sqlspec/statement/builder/mixins/_unpivot.py +6 -9
  66. sqlspec/statement/builder/mixins/_update_from.py +2 -1
  67. sqlspec/statement/builder/mixins/_update_set.py +11 -8
  68. sqlspec/statement/builder/mixins/_where.py +61 -34
  69. sqlspec/statement/builder/select.py +32 -17
  70. sqlspec/statement/builder/update.py +25 -11
  71. sqlspec/statement/filters.py +39 -14
  72. sqlspec/statement/parameter_manager.py +220 -0
  73. sqlspec/statement/parameters.py +210 -79
  74. sqlspec/statement/pipelines/__init__.py +166 -23
  75. sqlspec/statement/pipelines/analyzers/_analyzer.py +21 -20
  76. sqlspec/statement/pipelines/context.py +35 -39
  77. sqlspec/statement/pipelines/transformers/__init__.py +2 -3
  78. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
  79. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +628 -58
  80. sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
  81. sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
  82. sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
  83. sqlspec/statement/pipelines/validators/_performance.py +38 -23
  84. sqlspec/statement/pipelines/validators/_security.py +39 -62
  85. sqlspec/statement/result.py +37 -129
  86. sqlspec/statement/splitter.py +0 -12
  87. sqlspec/statement/sql.py +863 -391
  88. sqlspec/statement/sql_compiler.py +140 -0
  89. sqlspec/storage/__init__.py +10 -2
  90. sqlspec/storage/backends/fsspec.py +53 -8
  91. sqlspec/storage/backends/obstore.py +15 -19
  92. sqlspec/storage/capabilities.py +101 -0
  93. sqlspec/storage/registry.py +56 -83
  94. sqlspec/typing.py +6 -434
  95. sqlspec/utils/cached_property.py +25 -0
  96. sqlspec/utils/correlation.py +0 -2
  97. sqlspec/utils/logging.py +0 -6
  98. sqlspec/utils/sync_tools.py +0 -4
  99. sqlspec/utils/text.py +0 -5
  100. sqlspec/utils/type_guards.py +892 -0
  101. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
  102. sqlspec-0.13.0.dist-info/RECORD +150 -0
  103. sqlspec/statement/builder/protocols.py +0 -20
  104. sqlspec/statement/pipelines/base.py +0 -315
  105. sqlspec/statement/pipelines/result_types.py +0 -41
  106. sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
  107. sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
  108. sqlspec/statement/pipelines/validators/base.py +0 -67
  109. sqlspec/storage/protocol.py +0 -173
  110. sqlspec-0.12.2.dist-info/RECORD +0 -145
  111. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
  112. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
  113. {sqlspec-0.12.2.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
@@ -3,7 +3,6 @@
3
3
  import logging
4
4
  from collections.abc import AsyncGenerator
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional
8
7
 
9
8
  from psqlpy import ConnectionPool
@@ -302,7 +301,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
302
301
  if getattr(self, field, None) is not None and getattr(self, field) is not Empty
303
302
  }
304
303
 
305
- # Add connection-specific extras (not pool-specific ones)
306
304
  config.update(self.extras)
307
305
 
308
306
  return config
@@ -359,11 +357,9 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
359
357
  Returns:
360
358
  A psqlpy Connection instance.
361
359
  """
362
- # Ensure pool exists
363
360
  if not self.pool_instance:
364
361
  self.pool_instance = await self._create_pool()
365
362
 
366
- # Get connection from pool
367
363
  return await self.pool_instance.connection()
368
364
 
369
365
  @asynccontextmanager
@@ -377,7 +373,6 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
377
373
  Yields:
378
374
  A psqlpy Connection instance.
379
375
  """
380
- # Ensure pool exists
381
376
  if not self.pool_instance:
382
377
  self.pool_instance = await self._create_pool()
383
378
 
@@ -396,15 +391,16 @@ class PsqlpyConfig(AsyncDatabaseConfig[PsqlpyConnection, ConnectionPool, PsqlpyD
396
391
  A PsqlpyDriver instance.
397
392
  """
398
393
  async with self.provide_connection(*args, **kwargs) as conn:
399
- # Create statement config with parameter style info if not already set
400
394
  statement_config = self.statement_config
395
+ # Inject parameter style info if not already set
401
396
  if statement_config.allowed_parameter_styles is None:
397
+ from dataclasses import replace
398
+
402
399
  statement_config = replace(
403
400
  statement_config,
404
401
  allowed_parameter_styles=self.supported_parameter_styles,
405
402
  target_parameter_style=self.preferred_parameter_style,
406
403
  )
407
-
408
404
  driver = self.driver_type(connection=conn, config=statement_config)
409
405
  yield driver
410
406
 
@@ -2,11 +2,12 @@
2
2
 
3
3
  import io
4
4
  import logging
5
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
5
+ from typing import TYPE_CHECKING, Any, Optional, cast
6
6
 
7
7
  from psqlpy import Connection
8
8
 
9
9
  from sqlspec.driver import AsyncDriverAdapterProtocol
10
+ from sqlspec.driver.connection import managed_transaction_async
10
11
  from sqlspec.driver.mixins import (
11
12
  AsyncPipelinedExecutionMixin,
12
13
  AsyncStorageMixin,
@@ -14,10 +15,10 @@ from sqlspec.driver.mixins import (
14
15
  ToSchemaMixin,
15
16
  TypeCoercionMixin,
16
17
  )
17
- from sqlspec.statement.parameters import ParameterStyle
18
- from sqlspec.statement.result import DMLResultDict, ScriptResultDict, SelectResultDict, SQLResult
18
+ from sqlspec.statement.parameters import ParameterStyle, ParameterValidator
19
+ from sqlspec.statement.result import SQLResult
19
20
  from sqlspec.statement.sql import SQL, SQLConfig
20
- from sqlspec.typing import DictRow, ModelDTOT, RowT
21
+ from sqlspec.typing import DictRow, RowT
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from sqlglot.dialects.dialect import DialectType
@@ -76,13 +77,36 @@ class PsqlpyDriver(
76
77
 
77
78
  async def _execute_statement(
78
79
  self, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
79
- ) -> Union[SelectResultDict, DMLResultDict, ScriptResultDict]:
80
+ ) -> SQLResult[RowT]:
80
81
  if statement.is_script:
81
82
  sql, _ = statement.compile(placeholder_style=ParameterStyle.STATIC)
82
83
  return await self._execute_script(sql, connection=connection, **kwargs)
83
84
 
84
- # Let the SQL object handle parameter style conversion based on dialect support
85
- sql, params = statement.compile(placeholder_style=self.default_parameter_style)
85
+ # Detect parameter styles in the SQL
86
+ detected_styles = set()
87
+ sql_str = statement.to_sql(placeholder_style=None) # Get raw SQL
88
+ validator = self.config.parameter_validator if self.config else ParameterValidator()
89
+ param_infos = validator.extract_parameters(sql_str)
90
+ if param_infos:
91
+ detected_styles = {p.style for p in param_infos}
92
+
93
+ # Determine target style based on what's in the SQL
94
+ target_style = self.default_parameter_style
95
+
96
+ # Check if there are unsupported styles
97
+ unsupported_styles = detected_styles - set(self.supported_parameter_styles)
98
+ if unsupported_styles:
99
+ # Force conversion to default style
100
+ target_style = self.default_parameter_style
101
+ elif detected_styles:
102
+ # Prefer the first supported style found
103
+ for style in detected_styles:
104
+ if style in self.supported_parameter_styles:
105
+ target_style = style
106
+ break
107
+
108
+ # Compile with the determined style
109
+ sql, params = statement.compile(placeholder_style=target_style)
86
110
  params = self._process_parameters(params)
87
111
 
88
112
  if statement.is_many:
@@ -92,43 +116,99 @@ class PsqlpyDriver(
92
116
 
93
117
  async def _execute(
94
118
  self, sql: str, parameters: Any, statement: SQL, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
95
- ) -> Union[SelectResultDict, DMLResultDict]:
96
- conn = self._connection(connection)
97
- if self.returns_rows(statement.expression):
98
- query_result = await conn.fetch(sql, parameters=parameters)
99
- # Convert query_result to list of dicts
100
- dict_rows: list[dict[str, Any]] = []
101
- if query_result:
102
- # psqlpy QueryResult has a result() method that returns list of dicts
103
- dict_rows = query_result.result()
104
- column_names = list(dict_rows[0].keys()) if dict_rows else []
105
- return {"data": dict_rows, "column_names": column_names, "rows_affected": len(dict_rows)}
106
- query_result = await conn.execute(sql, parameters=parameters)
107
- # Note: psqlpy doesn't provide rows_affected for DML operations
108
- # The QueryResult object only has result(), as_class(), and row_factory() methods
109
- # For accurate row counts, use RETURNING clause
110
- affected_count = -1 # Unknown, as psqlpy doesn't provide this info
111
- return {"rows_affected": affected_count, "status_message": "OK"}
119
+ ) -> SQLResult[RowT]:
120
+ # Use provided connection or driver's default connection
121
+ conn = connection if connection is not None else self._connection(None)
122
+
123
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
124
+ # PSQLPy expects parameters as a list (for $1, $2, etc.) or dict
125
+ # Ensure we always pass a sequence or mapping, never a scalar
126
+ final_params: Any
127
+ if isinstance(parameters, (list, tuple)):
128
+ final_params = list(parameters)
129
+ elif isinstance(parameters, dict):
130
+ final_params = parameters
131
+ elif parameters is None:
132
+ final_params = []
133
+ else:
134
+ # Single parameter - wrap in list for NUMERIC style ($1)
135
+ final_params = [parameters]
136
+
137
+ if self.returns_rows(statement.expression):
138
+ query_result = await txn_conn.fetch(sql, parameters=final_params)
139
+ dict_rows: list[dict[str, Any]] = []
140
+ if query_result:
141
+ # psqlpy QueryResult has a result() method that returns list of dicts
142
+ dict_rows = query_result.result()
143
+ column_names = list(dict_rows[0].keys()) if dict_rows else []
144
+ return SQLResult(
145
+ statement=statement,
146
+ data=cast("list[RowT]", dict_rows),
147
+ column_names=column_names,
148
+ rows_affected=len(dict_rows),
149
+ operation_type="SELECT",
150
+ )
151
+
152
+ query_result = await txn_conn.execute(sql, parameters=final_params)
153
+ # Note: psqlpy doesn't provide rows_affected for DML operations
154
+ # The QueryResult object only has result(), as_class(), and row_factory() methods
155
+ affected_count = -1 # Unknown, as psqlpy doesn't provide this info
156
+ return SQLResult(
157
+ statement=statement,
158
+ data=[],
159
+ rows_affected=affected_count,
160
+ operation_type=self._determine_operation_type(statement),
161
+ metadata={"status_message": "OK"},
162
+ )
112
163
 
113
164
  async def _execute_many(
114
165
  self, sql: str, param_list: Any, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
115
- ) -> DMLResultDict:
116
- conn = self._connection(connection)
117
- await conn.execute_many(sql, param_list or [])
118
- # execute_many doesn't return a value with rows_affected
119
- affected_count = -1
120
- return {"rows_affected": affected_count, "status_message": "OK"}
166
+ ) -> SQLResult[RowT]:
167
+ # Use provided connection or driver's default connection
168
+ conn = connection if connection is not None else self._connection(None)
169
+
170
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
171
+ # PSQLPy expects a list of parameter lists/tuples for execute_many
172
+ if param_list is None:
173
+ final_param_list = []
174
+ elif isinstance(param_list, (list, tuple)):
175
+ # Ensure each parameter set is a list/tuple
176
+ final_param_list = [
177
+ list(params) if isinstance(params, (list, tuple)) else [params] for params in param_list
178
+ ]
179
+ else:
180
+ # Single parameter set - wrap it
181
+ final_param_list = [list(param_list) if isinstance(param_list, (list, tuple)) else [param_list]]
182
+
183
+ await txn_conn.execute_many(sql, final_param_list)
184
+ # execute_many doesn't return a value with rows_affected
185
+ affected_count = -1
186
+ return SQLResult(
187
+ statement=SQL(sql, _dialect=self.dialect),
188
+ data=[],
189
+ rows_affected=affected_count,
190
+ operation_type="EXECUTE",
191
+ metadata={"status_message": "OK"},
192
+ )
121
193
 
122
194
  async def _execute_script(
123
195
  self, script: str, connection: Optional[PsqlpyConnection] = None, **kwargs: Any
124
- ) -> ScriptResultDict:
125
- conn = self._connection(connection)
126
- # psqlpy can execute multi-statement scripts directly
127
- await conn.execute(script)
128
- return {
129
- "statements_executed": -1, # Not directly supported, but script is executed
130
- "status_message": "SCRIPT EXECUTED",
131
- }
196
+ ) -> SQLResult[RowT]:
197
+ # Use provided connection or driver's default connection
198
+ conn = connection if connection is not None else self._connection(None)
199
+
200
+ async with managed_transaction_async(conn, auto_commit=True) as txn_conn:
201
+ # psqlpy can execute multi-statement scripts directly
202
+ await txn_conn.execute(script)
203
+ return SQLResult(
204
+ statement=SQL(script, _dialect=self.dialect).as_script(),
205
+ data=[],
206
+ rows_affected=0,
207
+ operation_type="SCRIPT",
208
+ metadata={"status_message": "SCRIPT EXECUTED"},
209
+ total_statements=-1, # Not directly supported, but script is executed
210
+ successful_statements=-1,
211
+ )
132
212
 
133
213
  async def _ingest_arrow_table(self, table: "Any", table_name: str, mode: str = "append", **options: Any) -> int:
134
214
  self._ensure_pyarrow_installed()
@@ -154,61 +234,6 @@ class PsqlpyDriver(
154
234
  msg = "Connection does not support COPY operations"
155
235
  raise NotImplementedError(msg)
156
236
 
157
- async def _wrap_select_result(
158
- self, statement: SQL, result: SelectResultDict, schema_type: Optional[type[ModelDTOT]] = None, **kwargs: Any
159
- ) -> Union[SQLResult[ModelDTOT], SQLResult[RowT]]:
160
- dict_rows = result["data"]
161
- column_names = result["column_names"]
162
- rows_affected = result["rows_affected"]
163
-
164
- if schema_type:
165
- converted_data = self.to_schema(data=dict_rows, schema_type=schema_type)
166
- return SQLResult[ModelDTOT](
167
- statement=statement,
168
- data=list(converted_data),
169
- column_names=column_names,
170
- rows_affected=rows_affected,
171
- operation_type="SELECT",
172
- )
173
- return SQLResult[RowT](
174
- statement=statement,
175
- data=cast("list[RowT]", dict_rows),
176
- column_names=column_names,
177
- rows_affected=rows_affected,
178
- operation_type="SELECT",
179
- )
180
-
181
- async def _wrap_execute_result(
182
- self, statement: SQL, result: Union[DMLResultDict, ScriptResultDict], **kwargs: Any
183
- ) -> SQLResult[RowT]:
184
- operation_type = "UNKNOWN"
185
- if statement.expression:
186
- operation_type = str(statement.expression.key).upper()
187
-
188
- if "statements_executed" in result:
189
- script_result = cast("ScriptResultDict", result)
190
- return SQLResult[RowT](
191
- statement=statement,
192
- data=[],
193
- rows_affected=0,
194
- operation_type="SCRIPT",
195
- metadata={
196
- "status_message": script_result.get("status_message", ""),
197
- "statements_executed": script_result.get("statements_executed", -1),
198
- },
199
- )
200
-
201
- dml_result = cast("DMLResultDict", result)
202
- rows_affected = dml_result.get("rows_affected", -1)
203
- status_message = dml_result.get("status_message", "")
204
- return SQLResult[RowT](
205
- statement=statement,
206
- data=[],
207
- rows_affected=rows_affected,
208
- operation_type=operation_type,
209
- metadata={"status_message": status_message},
210
- )
211
-
212
237
  def _connection(self, connection: Optional[PsqlpyConnection] = None) -> PsqlpyConnection:
213
238
  """Get the connection to use for the operation."""
214
239
  return connection or self.connection
@@ -3,7 +3,6 @@
3
3
  import contextlib
4
4
  import logging
5
5
  from contextlib import asynccontextmanager
6
- from dataclasses import replace
7
6
  from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast
8
7
 
9
8
  from psycopg.rows import dict_row
@@ -211,7 +210,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
211
210
  self.configure = configure
212
211
  self.kwargs = kwargs or {}
213
212
 
214
- # Handle extras and additional kwargs
215
213
  self.extras = extras or {}
216
214
  self.extras.update(additional_kwargs)
217
215
 
@@ -240,7 +238,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
240
238
  if self.kwargs:
241
239
  config.update(self.kwargs)
242
240
 
243
- # Set DictRow as the row factory
244
241
  config["row_factory"] = dict_row
245
242
 
246
243
  return config
@@ -263,7 +260,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
263
260
  if self.kwargs:
264
261
  config.update(self.kwargs)
265
262
 
266
- # Set DictRow as the row factory
267
263
  config["row_factory"] = dict_row
268
264
 
269
265
  return config
@@ -273,7 +269,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
273
269
  logger.info("Creating Psycopg connection pool", extra={"adapter": "psycopg"})
274
270
 
275
271
  try:
276
- # Get all config (creates a new dict)
277
272
  all_config = self.pool_config_dict.copy()
278
273
 
279
274
  # Separate pool-specific parameters that ConnectionPool accepts directly
@@ -289,28 +284,27 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
289
284
  "num_workers": all_config.pop("num_workers", 3),
290
285
  }
291
286
 
292
- # Create a configure callback to set row_factory
287
+ # Capture autocommit setting before configuring the pool
288
+ autocommit_setting = all_config.get("autocommit")
289
+
293
290
  def configure_connection(conn: "PsycopgSyncConnection") -> None:
294
- # Set DictRow as the row factory
295
291
  conn.row_factory = dict_row
292
+ # Apply autocommit setting if specified
293
+ if autocommit_setting is not None:
294
+ conn.autocommit = autocommit_setting
296
295
 
297
296
  pool_params["configure"] = all_config.pop("configure", configure_connection)
298
297
 
299
- # Remove None values from pool_params
300
298
  pool_params = {k: v for k, v in pool_params.items() if v is not None}
301
299
 
302
- # Handle conninfo vs individual connection parameters
303
300
  conninfo = all_config.pop("conninfo", None)
304
301
  if conninfo:
305
302
  # If conninfo is provided, use it directly
306
303
  # Don't pass kwargs when using conninfo string
307
304
  pool = ConnectionPool(conninfo, open=True, **pool_params)
308
305
  else:
309
- # Otherwise, pass connection parameters via kwargs
310
- # Remove any non-connection parameters
311
306
  # row_factory is already popped out earlier
312
307
  all_config.pop("row_factory", None)
313
- # Remove pool-specific settings that may have been left
314
308
  all_config.pop("kwargs", None)
315
309
  pool = ConnectionPool("", kwargs=all_config, open=True, **pool_params)
316
310
 
@@ -328,7 +322,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
328
322
  logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
329
323
 
330
324
  try:
331
- # Set a flag to prevent __del__ from running cleanup
332
325
  # This avoids the "cannot join current thread" error during garbage collection
333
326
  if hasattr(self.pool_instance, "_closed"):
334
327
  self.pool_instance._closed = True
@@ -339,7 +332,6 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
339
332
  logger.exception("Failed to close Psycopg connection pool", extra={"adapter": "psycopg", "error": str(e)})
340
333
  raise
341
334
  finally:
342
- # Clear the reference to help garbage collection
343
335
  self.pool_instance = None
344
336
 
345
337
  def create_connection(self) -> "PsycopgSyncConnection":
@@ -385,15 +377,16 @@ class PsycopgSyncConfig(SyncDatabaseConfig[PsycopgSyncConnection, ConnectionPool
385
377
  A PsycopgSyncDriver instance.
386
378
  """
387
379
  with self.provide_connection(*args, **kwargs) as conn:
388
- # Create statement config with parameter style info if not already set
389
380
  statement_config = self.statement_config
381
+ # Inject parameter style info if not already set
390
382
  if statement_config.allowed_parameter_styles is None:
383
+ from dataclasses import replace
384
+
391
385
  statement_config = replace(
392
386
  statement_config,
393
387
  allowed_parameter_styles=self.supported_parameter_styles,
394
388
  target_parameter_style=self.preferred_parameter_style,
395
389
  )
396
-
397
390
  driver = self.driver_type(connection=conn, config=statement_config)
398
391
  yield driver
399
392
 
@@ -555,7 +548,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
555
548
  self.configure = configure
556
549
  self.kwargs = kwargs or {}
557
550
 
558
- # Handle extras and additional kwargs
559
551
  self.extras = extras or {}
560
552
  self.extras.update(additional_kwargs)
561
553
 
@@ -584,7 +576,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
584
576
  if self.kwargs:
585
577
  config.update(self.kwargs)
586
578
 
587
- # Set DictRow as the row factory
588
579
  config["row_factory"] = dict_row
589
580
 
590
581
  return config
@@ -607,7 +598,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
607
598
  if self.kwargs:
608
599
  config.update(self.kwargs)
609
600
 
610
- # Set DictRow as the row factory
611
601
  config["row_factory"] = dict_row
612
602
 
613
603
  return config
@@ -615,7 +605,6 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
615
605
  async def _create_pool(self) -> "AsyncConnectionPool":
616
606
  """Create the actual async connection pool."""
617
607
 
618
- # Get all config (creates a new dict)
619
608
  all_config = self.pool_config_dict.copy()
620
609
 
621
610
  # Separate pool-specific parameters that AsyncConnectionPool accepts directly
@@ -631,28 +620,27 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
631
620
  "num_workers": all_config.pop("num_workers", 3),
632
621
  }
633
622
 
634
- # Create a configure callback to set row_factory
623
+ # Capture autocommit setting before configuring the pool
624
+ autocommit_setting = all_config.get("autocommit")
625
+
635
626
  async def configure_connection(conn: "PsycopgAsyncConnection") -> None:
636
- # Set DictRow as the row factory
637
627
  conn.row_factory = dict_row
628
+ # Apply autocommit setting if specified (async version requires await)
629
+ if autocommit_setting is not None:
630
+ await conn.set_autocommit(autocommit_setting)
638
631
 
639
632
  pool_params["configure"] = all_config.pop("configure", configure_connection)
640
633
 
641
- # Remove None values from pool_params
642
634
  pool_params = {k: v for k, v in pool_params.items() if v is not None}
643
635
 
644
- # Handle conninfo vs individual connection parameters
645
636
  conninfo = all_config.pop("conninfo", None)
646
637
  if conninfo:
647
638
  # If conninfo is provided, use it directly
648
639
  # Don't pass kwargs when using conninfo string
649
640
  pool = AsyncConnectionPool(conninfo, open=False, **pool_params)
650
641
  else:
651
- # Otherwise, pass connection parameters via kwargs
652
- # Remove any non-connection parameters
653
642
  # row_factory is already popped out earlier
654
643
  all_config.pop("row_factory", None)
655
- # Remove pool-specific settings that may have been left
656
644
  all_config.pop("kwargs", None)
657
645
  pool = AsyncConnectionPool("", kwargs=all_config, open=False, **pool_params)
658
646
 
@@ -666,14 +654,12 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
666
654
  return
667
655
 
668
656
  try:
669
- # Set a flag to prevent __del__ from running cleanup
670
657
  # This avoids the "cannot join current thread" error during garbage collection
671
658
  if hasattr(self.pool_instance, "_closed"):
672
659
  self.pool_instance._closed = True
673
660
 
674
661
  await self.pool_instance.close()
675
662
  finally:
676
- # Clear the reference to help garbage collection
677
663
  self.pool_instance = None
678
664
 
679
665
  async def create_connection(self) -> "PsycopgAsyncConnection": # pyright: ignore
@@ -719,15 +705,16 @@ class PsycopgAsyncConfig(AsyncDatabaseConfig[PsycopgAsyncConnection, AsyncConnec
719
705
  A PsycopgAsyncDriver instance.
720
706
  """
721
707
  async with self.provide_connection(*args, **kwargs) as conn:
722
- # Create statement config with parameter style info if not already set
723
708
  statement_config = self.statement_config
709
+ # Inject parameter style info if not already set
724
710
  if statement_config.allowed_parameter_styles is None:
711
+ from dataclasses import replace
712
+
725
713
  statement_config = replace(
726
714
  statement_config,
727
715
  allowed_parameter_styles=self.supported_parameter_styles,
728
716
  target_parameter_style=self.preferred_parameter_style,
729
717
  )
730
-
731
718
  driver = self.driver_type(connection=conn, config=statement_config)
732
719
  yield driver
733
720