SQLPyHelper 0.1.7__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/PKG-INFO +18 -1
  2. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/PKG-INFO +18 -1
  3. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/SOURCES.txt +2 -0
  4. sqlpyhelper-0.1.8/SQLPyHelper.egg-info/requires.txt +42 -0
  5. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/pyproject.toml +3 -0
  6. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/setup.py +8 -5
  7. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/__init__.py +6 -1
  8. sqlpyhelper-0.1.8/sqlpyhelper/async_helper.py +599 -0
  9. sqlpyhelper-0.1.8/test/test_async_helper.py +478 -0
  10. sqlpyhelper-0.1.7/SQLPyHelper.egg-info/requires.txt +0 -20
  11. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/LICENSE +0 -0
  12. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/README.md +0 -0
  13. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/dependency_links.txt +0 -0
  14. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/entry_points.txt +0 -0
  15. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/top_level.txt +0 -0
  16. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/setup.cfg +0 -0
  17. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/automation_utils.py +0 -0
  18. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/cli.py +0 -0
  19. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/db_helper.py +0 -0
  20. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/migration.py +0 -0
  21. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/py.typed +0 -0
  22. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/test/test_migration.py +0 -0
  23. {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/test/test_sqlpyhelper.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SQLPyHelper
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: A simple SQL database helper package for Python.
5
5
  Home-page: https://github.com/adebayopeter/sqlpyhelper
6
6
  Author: Adebayo Olaonipekun
@@ -34,11 +34,28 @@ Provides-Extra: sqlserver
34
34
  Requires-Dist: pyodbc; extra == "sqlserver"
35
35
  Provides-Extra: oracle
36
36
  Requires-Dist: oracledb; extra == "oracle"
37
+ Provides-Extra: async-postgres
38
+ Requires-Dist: asyncpg; extra == "async-postgres"
39
+ Provides-Extra: async-mysql
40
+ Requires-Dist: aiomysql; extra == "async-mysql"
41
+ Provides-Extra: async-sqlite
42
+ Requires-Dist: aiosqlite; extra == "async-sqlite"
43
+ Provides-Extra: async-sqlserver
44
+ Requires-Dist: aioodbc; extra == "async-sqlserver"
45
+ Provides-Extra: async-all
46
+ Requires-Dist: asyncpg; extra == "async-all"
47
+ Requires-Dist: aiomysql; extra == "async-all"
48
+ Requires-Dist: aiosqlite; extra == "async-all"
49
+ Requires-Dist: aioodbc; extra == "async-all"
37
50
  Provides-Extra: all
38
51
  Requires-Dist: psycopg2; extra == "all"
39
52
  Requires-Dist: mysql-connector-python; extra == "all"
40
53
  Requires-Dist: pyodbc; extra == "all"
41
54
  Requires-Dist: oracledb; extra == "all"
55
+ Requires-Dist: asyncpg; extra == "all"
56
+ Requires-Dist: aiomysql; extra == "all"
57
+ Requires-Dist: aiosqlite; extra == "all"
58
+ Requires-Dist: aioodbc; extra == "all"
42
59
  Dynamic: author
43
60
  Dynamic: author-email
44
61
  Dynamic: classifier
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SQLPyHelper
3
- Version: 0.1.7
3
+ Version: 0.1.8
4
4
  Summary: A simple SQL database helper package for Python.
5
5
  Home-page: https://github.com/adebayopeter/sqlpyhelper
6
6
  Author: Adebayo Olaonipekun
@@ -34,11 +34,28 @@ Provides-Extra: sqlserver
34
34
  Requires-Dist: pyodbc; extra == "sqlserver"
35
35
  Provides-Extra: oracle
36
36
  Requires-Dist: oracledb; extra == "oracle"
37
+ Provides-Extra: async-postgres
38
+ Requires-Dist: asyncpg; extra == "async-postgres"
39
+ Provides-Extra: async-mysql
40
+ Requires-Dist: aiomysql; extra == "async-mysql"
41
+ Provides-Extra: async-sqlite
42
+ Requires-Dist: aiosqlite; extra == "async-sqlite"
43
+ Provides-Extra: async-sqlserver
44
+ Requires-Dist: aioodbc; extra == "async-sqlserver"
45
+ Provides-Extra: async-all
46
+ Requires-Dist: asyncpg; extra == "async-all"
47
+ Requires-Dist: aiomysql; extra == "async-all"
48
+ Requires-Dist: aiosqlite; extra == "async-all"
49
+ Requires-Dist: aioodbc; extra == "async-all"
37
50
  Provides-Extra: all
38
51
  Requires-Dist: psycopg2; extra == "all"
39
52
  Requires-Dist: mysql-connector-python; extra == "all"
40
53
  Requires-Dist: pyodbc; extra == "all"
41
54
  Requires-Dist: oracledb; extra == "all"
55
+ Requires-Dist: asyncpg; extra == "all"
56
+ Requires-Dist: aiomysql; extra == "all"
57
+ Requires-Dist: aiosqlite; extra == "all"
58
+ Requires-Dist: aioodbc; extra == "all"
42
59
  Dynamic: author
43
60
  Dynamic: author-email
44
61
  Dynamic: classifier
@@ -10,10 +10,12 @@ SQLPyHelper.egg-info/entry_points.txt
10
10
  SQLPyHelper.egg-info/requires.txt
11
11
  SQLPyHelper.egg-info/top_level.txt
12
12
  sqlpyhelper/__init__.py
13
+ sqlpyhelper/async_helper.py
13
14
  sqlpyhelper/automation_utils.py
14
15
  sqlpyhelper/cli.py
15
16
  sqlpyhelper/db_helper.py
16
17
  sqlpyhelper/migration.py
17
18
  sqlpyhelper/py.typed
19
+ test/test_async_helper.py
18
20
  test/test_migration.py
19
21
  test/test_sqlpyhelper.py
@@ -0,0 +1,42 @@
1
+ python-dotenv
2
+ click
3
+
4
+ [all]
5
+ psycopg2
6
+ mysql-connector-python
7
+ pyodbc
8
+ oracledb
9
+ asyncpg
10
+ aiomysql
11
+ aiosqlite
12
+ aioodbc
13
+
14
+ [async-all]
15
+ asyncpg
16
+ aiomysql
17
+ aiosqlite
18
+ aioodbc
19
+
20
+ [async-mysql]
21
+ aiomysql
22
+
23
+ [async-postgres]
24
+ asyncpg
25
+
26
+ [async-sqlite]
27
+ aiosqlite
28
+
29
+ [async-sqlserver]
30
+ aioodbc
31
+
32
+ [mysql]
33
+ mysql-connector-python
34
+
35
+ [oracle]
36
+ oracledb
37
+
38
+ [postgres]
39
+ psycopg2
40
+
41
+ [sqlserver]
42
+ pyodbc
@@ -4,3 +4,6 @@ line-length = 88
4
4
  [tool.isort]
5
5
  profile = "black"
6
6
  line_length = 88
7
+
8
+ [tool.pytest.ini_options]
9
+ asyncio_mode = "auto"
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as f:
5
5
 
6
6
  setup(
7
7
  name='SQLPyHelper',
8
- version='0.1.7',
8
+ version='0.1.8',
9
9
  description='A simple SQL database helper package for Python.',
10
10
  long_description=long_description,
11
11
  long_description_content_type="text/markdown",
@@ -26,11 +26,14 @@ setup(
26
26
  "mysql": ["mysql-connector-python"],
27
27
  "sqlserver": ["pyodbc"],
28
28
  "oracle": ["oracledb"],
29
+ "async-postgres": ["asyncpg"],
30
+ "async-mysql": ["aiomysql"],
31
+ "async-sqlite": ["aiosqlite"],
32
+ "async-sqlserver": ["aioodbc"],
33
+ "async-all": ["asyncpg", "aiomysql", "aiosqlite", "aioodbc"],
29
34
  "all": [
30
- "psycopg2",
31
- "mysql-connector-python",
32
- "pyodbc",
33
- "oracledb",
35
+ "psycopg2", "mysql-connector-python", "pyodbc", "oracledb",
36
+ "asyncpg", "aiomysql", "aiosqlite", "aioodbc",
34
37
  ],
35
38
  },
36
39
  keywords=[
@@ -1,6 +1,11 @@
1
1
  # Match the version in setup.py
2
- __version__ = "0.1.7"
2
+ __version__ = "0.1.8"
3
3
 
4
+ from sqlpyhelper.async_helper import ( # noqa: F401
5
+ AsyncConnectionError,
6
+ AsyncQueryError,
7
+ AsyncSQLPyHelper,
8
+ )
4
9
  from sqlpyhelper.db_helper import ( # noqa: F401
5
10
  BackupError,
6
11
  ConnectionError,
@@ -0,0 +1,599 @@
1
+ """
2
+ sqlpyhelper.async_helper
3
+ ~~~~~~~~~~~~~~~~~~~~~~~~
4
+ Async-native database helper supporting SQLite, PostgreSQL, MySQL,
5
+ SQL Server, and Oracle.
6
+
7
+ Uses async-native drivers:
8
+ - SQLite: aiosqlite
9
+ - PostgreSQL: asyncpg
10
+ - MySQL: aiomysql
11
+ - SQL Server: aioodbc
12
+ - Oracle: python-oracledb (async mode)
13
+
14
+ Example usage::
15
+
16
+ import asyncio
17
+ from sqlpyhelper.async_helper import AsyncSQLPyHelper
18
+
19
+ async def main():
20
+ async with AsyncSQLPyHelper(db_type="sqlite", database="my.db") as db:
21
+ await db.execute("CREATE TABLE IF NOT EXISTS users (id INTEGER, name TEXT)")
22
+ await db.execute("INSERT INTO users VALUES ($1, $2)", 1, "Alice")
23
+ rows = await db.fetch_all("SELECT * FROM users")
24
+ print(rows)
25
+
26
+ asyncio.run(main())
27
+ """
28
+
29
+ import logging
30
+ import os
31
+ from typing import Any, Optional
32
+
33
+ from dotenv import load_dotenv
34
+
35
# Merge variables from a local .env file into os.environ, so the DB_* fallbacks
# read by AsyncSQLPyHelper can be supplied from that file.
load_dotenv()

# Module logger. Note the fixed name "sqlpyhelper.async" (not __name__).
logger: logging.Logger = logging.getLogger("sqlpyhelper.async")
38
+
39
+
40
class AsyncSQLPyHelperError(Exception):
    """Base class for every error raised by AsyncSQLPyHelper.

    Catch this to handle connection and query failures with one clause.
    """


class AsyncConnectionError(AsyncSQLPyHelperError):
    """Raised when an async database connection fails."""


class AsyncQueryError(AsyncSQLPyHelperError):
    """Raised when an async query fails."""
46
+
47
+
48
class AsyncSQLPyHelper:
    """
    Async-native database helper with a unified API across
    SQLite, PostgreSQL, MySQL, SQL Server, and Oracle.

    Use as an async context manager::

        async with AsyncSQLPyHelper(db_type="postgres", ...) as db:
            rows = await db.fetch_all("SELECT * FROM users")

    Or manage the connection lifecycle manually::

        db = AsyncSQLPyHelper(db_type="sqlite", database="my.db")
        await db.connect()
        try:
            rows = await db.fetch_all("SELECT * FROM users")
        finally:
            await db.close()

    Queries are written with PostgreSQL-style ``$1, $2, ...`` placeholders,
    where ``$n`` always means "the n-th argument". They are rewritten for the
    active driver before execution.

    Outside an explicit transaction, ``execute()`` and ``execute_many()``
    commit after every statement. Between ``begin_transaction()`` and
    ``commit_transaction()`` / ``rollback_transaction()`` that per-statement
    commit is suppressed, so the whole unit of work can be rolled back.
    """

    # Values accepted for ``db_type``.
    _SUPPORTED_DB_TYPES = ("sqlite", "postgres", "mysql", "sqlserver", "oracle")

    # Backends whose drivers execute statements through a DB-API cursor().
    _CURSOR_DB_TYPES = ("mysql", "sqlserver", "oracle")

    def __init__(
        self,
        db_type: Optional[str] = None,
        host: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[str] = None,
        oracle_sid: Optional[str] = None,
    ) -> None:
        """
        Store connection settings. No connection is opened here.

        Each omitted (or empty) argument falls back to an environment
        variable: DB_TYPE, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_DRIVER,
        DB_PORT and ORACLE_SID respectively.

        Args:
            db_type: "sqlite", "postgres", "mysql", "sqlserver" or "oracle"
                (case-insensitive).
            host: Database server host.
            user: Login user name.
            password: Login password.
            database: Database name (the file path for SQLite).
            driver: ODBC driver name, used only for SQL Server.
            port: Server port, used for PostgreSQL and MySQL.
            oracle_sid: Oracle SID used to build the Oracle DSN.

        Raises:
            ValueError: If db_type or database is missing, or db_type is
                not supported.
        """
        self.db_type: str = (db_type or os.getenv("DB_TYPE") or "").lower()
        self.host: Optional[str] = host or os.getenv("DB_HOST")
        self.user: Optional[str] = user or os.getenv("DB_USER")
        self.password: Optional[str] = password or os.getenv("DB_PASSWORD")
        self.database: Optional[str] = database or os.getenv("DB_NAME")
        self.driver: Optional[str] = driver or os.getenv("DB_DRIVER")
        self.port: Optional[str] = port or os.getenv("DB_PORT")
        self.oracle_sid: Optional[str] = oracle_sid or os.getenv("ORACLE_SID")

        self._connection: Any = None
        self._pool: Any = None
        # asyncpg Transaction opened by begin_transaction(), if any.
        self._pg_transaction: Any = None
        # True between begin_transaction() and commit/rollback. While set,
        # execute()/execute_many() skip their per-statement commit.
        self._in_transaction: bool = False

        if not self.db_type or not self.database:
            raise ValueError("Missing required database configuration.")

        if self.db_type not in self._SUPPORTED_DB_TYPES:
            raise ValueError(f"Unsupported database type: {self.db_type!r}")

    # -----------------------------------------------------------------------
    # Connection lifecycle
    # -----------------------------------------------------------------------

    async def connect(self) -> None:
        """
        Open the database connection for the configured ``db_type``.

        The driver module is imported lazily, so only the driver for the
        backend in use has to be installed. SQLite rows are returned as
        ``aiosqlite.Row`` objects. For Oracle, the port is read from the
        ORACLE_DB_PORT environment variable (default 1521), not from ``port``.

        Raises:
            AsyncConnectionError: If the driver cannot be imported or the
                connection attempt fails.
        """
        try:
            if self.db_type == "sqlite":
                import aiosqlite

                self._connection = await aiosqlite.connect(self.database or "")  # type: ignore[arg-type]
                self._connection.row_factory = aiosqlite.Row
                logger.info("Connected to SQLite database: %s", self.database)

            elif self.db_type == "postgres":
                import asyncpg

                self._connection = await asyncpg.connect(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                )
                logger.info("Connected to PostgreSQL database: %s", self.database)

            elif self.db_type == "mysql":
                import aiomysql

                self._connection = await aiomysql.connect(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    autocommit=False,
                )
                logger.info("Connected to MySQL database: %s", self.database)

            elif self.db_type == "sqlserver":
                import aioodbc

                dsn = (
                    f"DRIVER={self.driver};"
                    f"SERVER={self.host};"
                    f"DATABASE={self.database};"
                    f"UID={self.user};"
                    f"PWD={self.password}"
                )
                self._connection = await aioodbc.connect(dsn=dsn)
                logger.info("Connected to SQL Server database: %s", self.database)

            elif self.db_type == "oracle":
                import oracledb

                oracle_port = int(os.getenv("ORACLE_DB_PORT", "1521"))
                dsn = oracledb.makedsn(
                    self.host, oracle_port, sid=self.oracle_sid  # type: ignore[arg-type]
                )
                self._connection = await oracledb.connect_async(
                    user=self.user, password=self.password, dsn=dsn
                )
                logger.info("Connected to Oracle database: %s", self.oracle_sid)

        except Exception as e:
            raise AsyncConnectionError(
                f"Failed to connect to {self.db_type}: {e}"
            ) from e

    async def close(self) -> None:
        """
        Close the database connection, if one is open.

        Calling this more than once is harmless. The helper drops its
        reference to the connection and resets transaction state even if the
        driver's close() raises. The connection pool (if any) is not touched;
        use close_pool() for that.

        Raises:
            AsyncConnectionError: If the driver fails to close the connection.
        """
        import inspect

        connection = self._connection
        if connection is None:
            return
        # Detach first, so a failing close() never leaves a half-closed
        # handle behind that later calls would try to reuse.
        self._connection = None
        self._pg_transaction = None
        self._in_transaction = False
        try:
            # aiomysql's Connection.close() is synchronous and returns None,
            # while asyncpg/aiosqlite/aioodbc/oracledb return a coroutine.
            # Awaiting None unconditionally made every MySQL close() fail.
            result = connection.close()
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            raise AsyncConnectionError(f"Failed to close connection: {e}") from e
        logger.info("Closed %s connection", self.db_type)

    async def __aenter__(self) -> "AsyncSQLPyHelper":
        """Open the connection and return this helper."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        """Close the connection. Returns False, so exceptions propagate."""
        await self.close()
        return False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _check_connection(self) -> None:
        """Raise AsyncConnectionError unless a connection is open."""
        if self._connection is None:
            raise AsyncConnectionError(
                "No active connection. Call connect() or use async with."
            )

    def _rewrite_placeholders(self, query: str) -> tuple[str, list[int]]:
        """
        Rewrite ``$n`` placeholders into the active driver's bind style.

        asyncpg understands ``$n`` natively, so PostgreSQL queries are
        returned untouched. The other drivers bind parameters by position:
        aiosqlite and aioodbc (pyodbc) use ``?``, aiomysql uses ``%s``, and
        oracledb receives ``:1, :2, ...`` numbered in textual order.

        Rewriting is purely textual, so a ``$<digits>`` sequence inside a
        string literal is rewritten as well.

        Returns:
            ``(query, order)``. ``order`` lists, for each placeholder in
            textual order, the zero-based index of the argument it refers to.
            An empty ``order`` means the arguments pass through unchanged.
        """
        import re

        if self.db_type not in ("sqlite", "mysql", "sqlserver", "oracle"):
            return query, []

        pattern = re.compile(r"\$(\d+)")
        order = [int(number) - 1 for number in pattern.findall(query)]
        if self.db_type == "oracle":
            # Give every occurrence its own positional bind name, so reused
            # and out-of-order $n map correctly once arguments are reordered.
            positions = iter(range(1, len(order) + 1))
            rewritten = pattern.sub(lambda _match: f":{next(positions)}", query)
        else:
            marker = "%s" if self.db_type == "mysql" else "?"
            rewritten = pattern.sub(marker, query)
        return rewritten, order

    @staticmethod
    def _reorder(args: Any, order: list[int]) -> Any:
        """
        Arrange *args* to match the placeholder ``order`` produced by
        ``_rewrite_placeholders``.

        *args* is returned unchanged when ``order`` is empty (PostgreSQL, or
        a query written with the driver's native markers).

        Raises:
            ValueError: If a placeholder refers to a missing argument.
        """
        if not order:
            return args
        for index in order:
            if not 0 <= index < len(args):
                raise ValueError(
                    f"Placeholder ${index + 1} has no matching argument "
                    f"({len(args)} supplied)"
                )
        return tuple(args[index] for index in order)

    def _adapt_query(self, query: str, args: tuple) -> tuple[str, tuple]:
        """
        Adapt a ``$n``-style query and its arguments for the active driver.

        Queries without arguments are returned untouched. Otherwise the
        placeholders are rewritten and the arguments reordered so that
        ``$n`` binds the n-th argument on every backend.

        Raises:
            ValueError: If a placeholder refers to a missing argument.
        """
        if not args:
            return query, args
        rewritten, order = self._rewrite_placeholders(query)
        return rewritten, self._reorder(args, order)

    async def _run_cursor(
        self,
        query: str,
        params: Any,
        fetch: Optional[str] = None,
        many: bool = False,
    ) -> Any:
        """
        Run a statement through a DB-API style cursor (mysql, sqlserver, oracle).

        Args:
            query: Driver-ready SQL (placeholders already rewritten).
            params: Parameter tuple, or a list of tuples when ``many`` is True.
            fetch: None to discard results, "one" for fetchone(), or "all"
                for fetchall().
            many: Use executemany() instead of execute().

        Returns:
            The fetched row(s), or None when ``fetch`` is None.
        """
        if self.db_type == "mysql" and not many and not params:
            # aiomysql applies ``query % params`` whenever params is not None,
            # so an empty tuple would break a literal '%' (e.g. LIKE 'a%').
            params = None

        async def run(cursor: Any) -> Any:
            if many:
                await cursor.executemany(query, params)
            else:
                await cursor.execute(query, params)
            if fetch == "one":
                return await cursor.fetchone()
            if fetch == "all":
                return await cursor.fetchall()
            return None

        if self.db_type == "oracle":
            # oracledb's cursor() is a plain method and its cursors must be
            # closed explicitly; previously every call leaked one.
            cursor = self._connection.cursor()
            try:
                return await run(cursor)
            finally:
                cursor.close()

        async with self._connection.cursor() as cursor:
            return await run(cursor)

    async def _autocommit(self) -> None:
        """Commit the last statement unless an explicit transaction is open."""
        # PostgreSQL is excluded: the code never calls commit() for asyncpg,
        # which commits statements issued outside a transaction block itself.
        if self._in_transaction or self.db_type not in ("sqlite", *self._CURSOR_DB_TYPES):
            return
        await self._connection.commit()

    # -----------------------------------------------------------------------
    # Query execution
    # -----------------------------------------------------------------------

    async def execute(self, query: str, *args: Any) -> None:
        """
        Execute a SQL statement (INSERT, UPDATE, DELETE, DDL).

        Use $1, $2, ... for parameterised values::

            await db.execute(
                "INSERT INTO users (id, name) VALUES ($1, $2)",
                1, "Alice"
            )

        The statement is committed immediately unless begin_transaction()
        is in effect.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        try:
            adapted_query, adapted_args = self._adapt_query(query, args)
            if self.db_type == "postgres":
                await self._connection.execute(adapted_query, *adapted_args)
            elif self.db_type == "sqlite":
                await self._connection.execute(adapted_query, adapted_args)
            elif self.db_type in self._CURSOR_DB_TYPES:
                await self._run_cursor(adapted_query, adapted_args)
            await self._autocommit()
            logger.debug("Executed: %s", query)
        except Exception as e:
            raise AsyncQueryError(f"Query failed: {e}") from e

    async def fetch_one(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single row, or None.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            A single row in the driver's native row type, or None if no rows
            matched.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        try:
            adapted_query, adapted_args = self._adapt_query(query, args)
            if self.db_type == "postgres":
                return await self._connection.fetchrow(adapted_query, *adapted_args)
            if self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    return await cursor.fetchone()
            if self.db_type in self._CURSOR_DB_TYPES:
                return await self._run_cursor(adapted_query, adapted_args, fetch="one")
            return None
        except Exception as e:
            raise AsyncQueryError(f"fetch_one failed: {e}") from e

    async def fetch_all(self, query: str, *args: Any) -> list[Any]:
        """
        Execute a SELECT query and return all rows.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            The rows returned by the driver (empty if no rows matched).

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        try:
            adapted_query, adapted_args = self._adapt_query(query, args)
            if self.db_type == "postgres":
                return await self._connection.fetch(adapted_query, *adapted_args)
            if self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    return await cursor.fetchall()
            if self.db_type in self._CURSOR_DB_TYPES:
                return await self._run_cursor(adapted_query, adapted_args, fetch="all")
            return []
        except Exception as e:
            raise AsyncQueryError(f"fetch_all failed: {e}") from e

    async def fetch_val(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single scalar value.

        Useful for COUNT, SUM, or any query returning one value::

            count = await db.fetch_val("SELECT COUNT(*) FROM users")

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            The first column of the first row, or None if no row matched.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        try:
            adapted_query, adapted_args = self._adapt_query(query, args)
            if self.db_type == "postgres":
                return await self._connection.fetchval(adapted_query, *adapted_args)
            if self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    row = await cursor.fetchone()
                return row[0] if row else None
            if self.db_type in self._CURSOR_DB_TYPES:
                row = await self._run_cursor(adapted_query, adapted_args, fetch="one")
                return row[0] if row else None
            return None
        except Exception as e:
            raise AsyncQueryError(f"fetch_val failed: {e}") from e

    async def execute_many(self, query: str, args_list: list[tuple]) -> None:
        """
        Execute a SQL statement multiple times with different parameters.
        Efficient for bulk inserts::

            await db.execute_many(
                "INSERT INTO users (id, name) VALUES ($1, $2)",
                [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
            )

        An empty ``args_list`` is a no-op. The batch is committed at the end
        unless begin_transaction() is in effect.

        Args:
            query: SQL query string using $1, $2 placeholders.
            args_list: List of parameter tuples.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the operation fails.
        """
        self._check_connection()
        if not args_list:
            return
        try:
            adapted_query, order = self._rewrite_placeholders(query)
            rows = [self._reorder(row, order) for row in args_list] if order else args_list
            if self.db_type in ("postgres", "sqlite"):
                await self._connection.executemany(adapted_query, rows)
            elif self.db_type in self._CURSOR_DB_TYPES:
                await self._run_cursor(adapted_query, rows, many=True)
            await self._autocommit()
            logger.debug("execute_many: %d rows", len(args_list))
        except Exception as e:
            raise AsyncQueryError(f"execute_many failed: {e}") from e

    # -----------------------------------------------------------------------
    # Transaction management
    # -----------------------------------------------------------------------

    async def begin_transaction(self) -> None:
        """
        Begin an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called,
        execute() and execute_many() no longer commit after each statement.
        For PostgreSQL an asyncpg Transaction is started and kept on the
        helper. For Oracle no statement is issued (Oracle starts transactions
        implicitly); only the auto-commit is suspended.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the transaction cannot be started.
        """
        self._check_connection()
        try:
            if self.db_type == "sqlite":
                await self._connection.execute("BEGIN")
            elif self.db_type == "mysql":
                await self._connection.begin()
            elif self.db_type == "sqlserver":
                async with self._connection.cursor() as cursor:
                    await cursor.execute("BEGIN TRANSACTION")
            elif self.db_type == "oracle":
                pass  # Oracle starts transactions implicitly
            elif self.db_type == "postgres":
                self._pg_transaction = self._connection.transaction()
                await self._pg_transaction.start()
            self._in_transaction = True
            logger.info("Transaction started on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to begin transaction: {e}") from e

    async def commit_transaction(self) -> None:
        """
        Commit the current transaction and resume per-statement auto-commit.

        If the commit fails, the transaction stays marked as open, so
        rollback_transaction() can still be called.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the commit fails, or (PostgreSQL) no
                transaction was begun.
        """
        self._check_connection()
        if self.db_type == "postgres" and self._pg_transaction is None:
            raise AsyncQueryError(
                "Failed to commit transaction: no transaction in progress"
            )
        try:
            if self.db_type == "postgres":
                await self._pg_transaction.commit()
            else:
                await self._connection.commit()
        except Exception as e:
            raise AsyncQueryError(f"Failed to commit transaction: {e}") from e
        self._pg_transaction = None
        self._in_transaction = False
        logger.info("Transaction committed on %s", self.db_type)

    async def rollback_transaction(self) -> None:
        """
        Roll back the current transaction and resume per-statement auto-commit.

        The transaction is treated as finished even if the rollback itself
        fails.

        Raises:
            AsyncConnectionError: If no connection is open.
            AsyncQueryError: If the rollback fails, or (PostgreSQL) no
                transaction was begun.
        """
        self._check_connection()
        if self.db_type == "postgres" and self._pg_transaction is None:
            raise AsyncQueryError(
                "Failed to rollback transaction: no transaction in progress"
            )
        try:
            if self.db_type == "postgres":
                await self._pg_transaction.rollback()
            else:
                await self._connection.rollback()
            logger.info("Transaction rolled back on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to rollback transaction: {e}") from e
        finally:
            self._pg_transaction = None
            self._in_transaction = False

    # -----------------------------------------------------------------------
    # Connection pooling
    # -----------------------------------------------------------------------

    async def setup_pool(self, min_size: int = 1, max_size: int = 10) -> None:
        """
        Set up an async connection pool.

        Supported for PostgreSQL and MySQL only. The pool is stored on the
        helper and released by close_pool(). The query methods of this class
        keep using the single connection opened by connect().

        Args:
            min_size: Minimum number of connections in the pool.
            max_size: Maximum number of connections in the pool.

        Raises:
            AsyncConnectionError: If pool setup fails or db_type
                does not support pooling.
        """
        try:
            if self.db_type == "postgres":
                import asyncpg

                self._pool = await asyncpg.create_pool(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    min_size=min_size,
                    max_size=max_size,
                )
                logger.info(
                    "PostgreSQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            elif self.db_type == "mysql":
                import aiomysql

                self._pool = await aiomysql.create_pool(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    minsize=min_size,
                    maxsize=max_size,
                )
                logger.info(
                    "MySQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            else:
                raise AsyncConnectionError(
                    f"Async connection pooling not supported for {self.db_type!r}. "
                    "Supported: postgres, mysql."
                )
        except AsyncConnectionError:
            raise
        except Exception as e:
            raise AsyncConnectionError(f"Failed to create async pool: {e}") from e

    async def close_pool(self) -> None:
        """Close the async connection pool, if one was created."""
        if self._pool is not None:
            if self.db_type == "mysql":
                # aiomysql: close() only initiates shutdown; wait_closed()
                # waits for it to finish.
                self._pool.close()
                await self._pool.wait_closed()
            else:
                await self._pool.close()
            self._pool = None
            logger.info("Async pool closed for %s", self.db_type)
@@ -0,0 +1,478 @@
1
+ """
2
+ Tests for sqlpyhelper.async_helper
3
+ All tests use mocking — no live database required.
4
+ Uses pytest-asyncio for async test support.
5
+ """
6
+
7
+ from unittest.mock import AsyncMock, MagicMock, patch
8
+
9
+ import pytest
10
+
11
+ from sqlpyhelper.async_helper import (
12
+ AsyncConnectionError,
13
+ AsyncQueryError,
14
+ AsyncSQLPyHelper,
15
+ )
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Helpers
19
+ # ---------------------------------------------------------------------------
20
+
21
+
22
def make_async_db(db_type: str = "sqlite") -> AsyncSQLPyHelper:
    """Build an AsyncSQLPyHelper for *db_type* without connecting.

    SQLite gets the file-style database name ``test.db``. Every other
    backend gets placeholder network credentials and the database
    ``testdb``, plus ``oracle_sid`` for oracle or ``driver`` for sqlserver.
    """
    if db_type == "sqlite":
        kwargs = {"db_type": db_type, "database": "test.db"}
    else:
        kwargs = {
            "db_type": db_type,
            "database": "testdb",
            "host": "localhost",
            "user": "user",
            "password": "pass",
        }

    backend_extras = {
        "oracle": ("oracle_sid", "XE"),
        "sqlserver": ("driver", "ODBC Driver 17 for SQL Server"),
    }
    if db_type in backend_extras:
        option, value = backend_extras[db_type]
        kwargs[option] = value

    return AsyncSQLPyHelper(**kwargs)
39
+
40
+
41
def attach_mock_connection(db: "AsyncSQLPyHelper") -> MagicMock:
    """Install a mock connection as ``db._connection`` and return it.

    The mock is a MagicMock whose connection methods are AsyncMocks, so
    they can be awaited:

    * close, commit, rollback, begin, execute and executemany return the
      AsyncMock default;
    * fetch returns ``[]``, and fetchrow and fetchval return ``None``.

    Tests override ``return_value`` / ``side_effect`` as needed.
    """
    mock_conn = MagicMock()
    # ``begin`` is included because the MySQL begin_transaction() path calls
    # conn.begin(); left as a plain MagicMock attribute, awaiting its result
    # would raise TypeError.
    for method in ("close", "commit", "rollback", "begin", "execute", "executemany"):
        setattr(mock_conn, method, AsyncMock())
    mock_conn.fetch = AsyncMock(return_value=[])
    mock_conn.fetchrow = AsyncMock(return_value=None)
    mock_conn.fetchval = AsyncMock(return_value=None)
    db._connection = mock_conn
    return mock_conn
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # __init__ validation
58
+ # ---------------------------------------------------------------------------
59
+
60
+
61
class TestInit:
    """Argument validation performed by AsyncSQLPyHelper.__init__."""

    def test_missing_database_raises(self):
        # Empty environment: presumably so no env-var configuration can
        # supply the missing value.
        with patch.dict("os.environ", {}, clear=True), pytest.raises(
            ValueError, match="Missing required database configuration"
        ):
            AsyncSQLPyHelper(db_type="sqlite", database="")

    def test_missing_db_type_raises(self):
        with patch.dict("os.environ", {}, clear=True), pytest.raises(
            ValueError, match="Missing required database configuration"
        ):
            AsyncSQLPyHelper(db_type="", database="test.db")

    def test_unsupported_db_type_raises(self):
        with pytest.raises(ValueError, match="Unsupported database type"):
            AsyncSQLPyHelper(db_type="mongodb", database="test")

    def test_valid_sqlite_init(self):
        helper = AsyncSQLPyHelper(db_type="sqlite", database="test.db")
        assert (helper.db_type, helper.database) == ("sqlite", "test.db")
        # Construction alone must not open a connection.
        assert helper._connection is None

    def test_valid_postgres_init(self):
        pg_options = dict(
            db_type="postgres",
            host="localhost",
            user="user",
            password="pass",
            database="testdb",
        )
        helper = AsyncSQLPyHelper(**pg_options)
        assert helper.db_type == "postgres"
95
+
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # connect / close
99
+ # ---------------------------------------------------------------------------
100
+
101
+
102
class TestConnect:
    """connect() for each driver, plus close() on success and on failure."""

    @staticmethod
    async def _connect_with(db_type, driver_connect):
        """Connect a fresh helper with *driver_connect* patched to a mock.

        Returns ``(db, fake_conn)``, where ``fake_conn`` is the object the
        patched driver function handed back.
        """
        db = make_async_db(db_type)
        fake_conn = MagicMock()
        fake_conn.close = AsyncMock()
        with patch(driver_connect, new_callable=AsyncMock, return_value=fake_conn):
            await db.connect()
        return db, fake_conn

    @pytest.mark.asyncio
    async def test_sqlite_connect(self):
        db, fake_conn = await self._connect_with("sqlite", "aiosqlite.connect")
        assert db._connection is fake_conn

    @pytest.mark.asyncio
    async def test_postgres_connect(self):
        db, fake_conn = await self._connect_with("postgres", "asyncpg.connect")
        assert db._connection is fake_conn

    @pytest.mark.asyncio
    async def test_mysql_connect(self):
        db, fake_conn = await self._connect_with("mysql", "aiomysql.connect")
        assert db._connection is fake_conn

    @pytest.mark.asyncio
    async def test_connect_failure_raises_async_connection_error(self):
        db = make_async_db("sqlite")
        driver_error = Exception("disk full")
        with patch("aiosqlite.connect", side_effect=driver_error), pytest.raises(
            AsyncConnectionError, match="Failed to connect"
        ):
            await db.connect()

    @pytest.mark.asyncio
    async def test_close(self):
        db = make_async_db("sqlite")
        conn = attach_mock_connection(db)
        await db.close()
        conn.close.assert_called_once()
        # The helper must drop its reference once the connection is closed.
        assert db._connection is None

    @pytest.mark.asyncio
    async def test_close_raises_on_failure(self):
        db = make_async_db("sqlite")
        attach_mock_connection(db).close.side_effect = Exception("already closed")
        with pytest.raises(AsyncConnectionError, match="Failed to close"):
            await db.close()
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # Context manager
156
+ # ---------------------------------------------------------------------------
157
+
158
+
159
class TestContextManager:
    """The async context-manager protocol (__aenter__ / __aexit__)."""

    @staticmethod
    def _fake_connection():
        """Return a mock connection whose close() can be awaited."""
        conn = MagicMock()
        conn.close = AsyncMock()
        return conn

    @pytest.mark.asyncio
    async def test_aenter_returns_self(self):
        db = make_async_db("sqlite")
        fake_conn = self._fake_connection()
        with patch("aiosqlite.connect", new_callable=AsyncMock, return_value=fake_conn):
            entered = await db.__aenter__()
            assert entered is db

    @pytest.mark.asyncio
    async def test_aexit_closes_connection(self):
        db = make_async_db("sqlite")
        conn = attach_mock_connection(db)
        await db.__aexit__(None, None, None)
        conn.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_aexit_returns_false(self):
        db = make_async_db("sqlite")
        attach_mock_connection(db)
        # A False return means exceptions from the with-body are not suppressed.
        assert await db.__aexit__(None, None, None) is False

    @pytest.mark.asyncio
    async def test_async_with_statement(self):
        fake_conn = self._fake_connection()
        with patch("aiosqlite.connect", new_callable=AsyncMock, return_value=fake_conn):
            async with AsyncSQLPyHelper(db_type="sqlite", database="test.db") as db:
                assert db._connection is fake_conn
            fake_conn.close.assert_called_once()
191
+
192
+
193
+ # ---------------------------------------------------------------------------
194
+ # _check_connection
195
+ # ---------------------------------------------------------------------------
196
+
197
+
198
class TestCheckConnection:
    """_check_connection() raises until a connection has been attached."""

    def test_raises_when_no_connection(self):
        unconnected = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="No active connection"):
            unconnected._check_connection()

    def test_passes_when_connected(self):
        connected = make_async_db("sqlite")
        attach_mock_connection(connected)
        connected._check_connection()  # must not raise
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # _adapt_query
212
+ # ---------------------------------------------------------------------------
213
+
214
+
215
class TestAdaptQuery:
    """Rewriting of ``$N`` placeholders into each driver's paramstyle."""

    @staticmethod
    def _adapted_sql(db, sql, params):
        """Return only the SQL text produced by ``db._adapt_query``."""
        adapted_sql, _ = db._adapt_query(sql, params)
        return adapted_sql

    def test_postgres_passes_through(self):
        db = make_async_db("postgres")
        assert self._adapted_sql(db, "SELECT $1", (1,)) == "SELECT $1"

    def test_sqlite_replaces_with_question_mark(self):
        db = make_async_db("sqlite")
        assert self._adapted_sql(db, "SELECT $1, $2", (1, 2)) == "SELECT ?, ?"

    def test_mysql_replaces_with_percent_s(self):
        db = make_async_db("mysql")
        assert self._adapted_sql(db, "SELECT $1, $2", (1, 2)) == "SELECT %s, %s"

    def test_oracle_replaces_with_colon(self):
        oracle_db = AsyncSQLPyHelper(
            db_type="oracle",
            host="localhost",
            user="u",
            password="p",
            database="d",
            oracle_sid="XE",
        )
        assert self._adapted_sql(oracle_db, "SELECT $1, $2", (1, 2)) == "SELECT :1, :2"

    def test_empty_args_returns_unchanged(self):
        db = make_async_db("sqlite")
        sql, params = db._adapt_query("SELECT 1", ())
        assert (sql, params) == ("SELECT 1", ())
248
+
249
+
250
+ # ---------------------------------------------------------------------------
251
+ # execute
252
+ # ---------------------------------------------------------------------------
253
+
254
+
255
class TestExecute:
    """execute(): driver dispatch and wrapping of driver errors."""

    @pytest.mark.asyncio
    async def test_sqlite_execute(self):
        db = make_async_db("sqlite")
        conn = attach_mock_connection(db)
        await db.execute("CREATE TABLE users (id INTEGER)")
        conn.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_postgres_execute(self):
        db = make_async_db("postgres")
        conn = attach_mock_connection(db)
        await db.execute("INSERT INTO users VALUES ($1)", 1)
        conn.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_raises_async_query_error_on_failure(self):
        db = make_async_db("sqlite")
        attach_mock_connection(db).execute.side_effect = Exception("disk full")
        with pytest.raises(AsyncQueryError, match="Query failed"):
            await db.execute("INSERT INTO users VALUES ($1)", 1)

    @pytest.mark.asyncio
    async def test_raises_when_not_connected(self):
        unconnected = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="No active connection"):
            await unconnected.execute("SELECT 1")
283
+
284
+
285
+ # ---------------------------------------------------------------------------
286
+ # fetch_one
287
+ # ---------------------------------------------------------------------------
288
+
289
+
290
class TestFetchOne:
    """fetch_one(): single-row retrieval."""

    @pytest.mark.asyncio
    async def test_postgres_fetch_one(self):
        db = make_async_db("postgres")
        expected = {"id": 1, "name": "Alice"}
        attach_mock_connection(db).fetchrow.return_value = dict(expected)
        row = await db.fetch_one("SELECT * FROM users WHERE id = $1", 1)
        assert row == expected

    @pytest.mark.asyncio
    async def test_raises_on_failure(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).fetchrow.side_effect = Exception("connection lost")
        with pytest.raises(AsyncQueryError, match="fetch_one failed"):
            await db.fetch_one("SELECT * FROM users")

    @pytest.mark.asyncio
    async def test_raises_when_not_connected(self):
        unconnected = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError):
            await unconnected.fetch_one("SELECT 1")
312
+
313
+
314
+ # ---------------------------------------------------------------------------
315
+ # fetch_all
316
+ # ---------------------------------------------------------------------------
317
+
318
+
319
class TestFetchAll:
    """fetch_all(): multi-row retrieval."""

    @pytest.mark.asyncio
    async def test_postgres_fetch_all(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).fetch.return_value = [{"id": 1}, {"id": 2}]
        rows = await db.fetch_all("SELECT * FROM users")
        assert len(rows) == 2

    @pytest.mark.asyncio
    async def test_raises_on_failure(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).fetch.side_effect = Exception("timeout")
        with pytest.raises(AsyncQueryError, match="fetch_all failed"):
            await db.fetch_all("SELECT * FROM users")

    @pytest.mark.asyncio
    async def test_raises_when_not_connected(self):
        unconnected = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError):
            await unconnected.fetch_all("SELECT 1")
341
+
342
+
343
+ # ---------------------------------------------------------------------------
344
+ # fetch_val
345
+ # ---------------------------------------------------------------------------
346
+
347
+
348
class TestFetchVal:
    """fetch_val(): scalar retrieval."""

    @pytest.mark.asyncio
    async def test_postgres_fetch_val(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).fetchval.return_value = 42
        count = await db.fetch_val("SELECT COUNT(*) FROM users")
        assert count == 42

    @pytest.mark.asyncio
    async def test_raises_on_failure(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).fetchval.side_effect = Exception("error")
        with pytest.raises(AsyncQueryError, match="fetch_val failed"):
            await db.fetch_val("SELECT COUNT(*) FROM users")
364
+
365
+
366
+ # ---------------------------------------------------------------------------
367
+ # execute_many
368
+ # ---------------------------------------------------------------------------
369
+
370
+
371
class TestExecuteMany:
    """execute_many(): one statement run over a batch of parameter tuples."""

    @pytest.mark.asyncio
    async def test_postgres_execute_many(self):
        db = make_async_db("postgres")
        conn = attach_mock_connection(db)  # executemany is already an AsyncMock
        batch = [(1, "Alice"), (2, "Bob")]
        await db.execute_many("INSERT INTO users VALUES ($1, $2)", batch)
        conn.executemany.assert_called_once()

    @pytest.mark.asyncio
    async def test_empty_list_does_nothing(self):
        db = make_async_db("postgres")
        conn = attach_mock_connection(db)
        await db.execute_many("INSERT INTO users VALUES ($1)", [])
        conn.executemany.assert_not_called()

    @pytest.mark.asyncio
    async def test_raises_on_failure(self):
        db = make_async_db("postgres")
        attach_mock_connection(db).executemany.side_effect = Exception("error")
        with pytest.raises(AsyncQueryError, match="execute_many failed"):
            await db.execute_many("INSERT INTO users VALUES ($1)", [(1,)])
396
+
397
+
398
+ # ---------------------------------------------------------------------------
399
+ # Transactions
400
+ # ---------------------------------------------------------------------------
401
+
402
+
403
class TestTransactions:
    """begin / commit / rollback against a mocked MySQL connection."""

    @pytest.mark.asyncio
    async def test_mysql_begin_calls_begin(self):
        db = make_async_db("mysql")
        conn = attach_mock_connection(db)
        conn.begin = AsyncMock()
        await db.begin_transaction()
        conn.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_commit(self):
        db = make_async_db("mysql")
        conn = attach_mock_connection(db)
        await db.commit_transaction()
        conn.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_rollback(self):
        db = make_async_db("mysql")
        conn = attach_mock_connection(db)
        await db.rollback_transaction()
        conn.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_rollback_raises_on_failure(self):
        db = make_async_db("mysql")
        attach_mock_connection(db).rollback.side_effect = Exception("rollback error")
        with pytest.raises(AsyncQueryError, match="Failed to rollback"):
            await db.rollback_transaction()
433
+
434
+
435
+ # ---------------------------------------------------------------------------
436
+ # Connection pool
437
+ # ---------------------------------------------------------------------------
438
+
439
+
440
class TestPool:
    """setup_pool() / close_pool() for the pooled backends (postgres, mysql)."""

    @pytest.mark.asyncio
    async def test_postgres_pool_setup(self):
        db = make_async_db("postgres")
        mock_pool = MagicMock()
        mock_pool.close = AsyncMock()
        with patch(
            "asyncpg.create_pool", new_callable=AsyncMock, return_value=mock_pool
        ):
            await db.setup_pool(min_size=1, max_size=5)
            assert db._pool is mock_pool

    @pytest.mark.asyncio
    async def test_mysql_pool_setup(self):
        db = make_async_db("mysql")
        mock_pool = MagicMock()
        mock_pool.close = MagicMock()
        mock_pool.wait_closed = AsyncMock()
        with patch(
            "aiomysql.create_pool", new_callable=AsyncMock, return_value=mock_pool
        ):
            await db.setup_pool(min_size=1, max_size=5)
            assert db._pool is mock_pool

    @pytest.mark.asyncio
    async def test_unsupported_db_raises(self):
        db = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="not supported"):
            await db.setup_pool()

    @pytest.mark.asyncio
    async def test_pool_setup_failure_raises_async_connection_error(self):
        # Driver errors must surface as AsyncConnectionError, not leak raw.
        db = make_async_db("postgres")
        with patch(
            "asyncpg.create_pool",
            new_callable=AsyncMock,
            side_effect=Exception("connection refused"),
        ):
            with pytest.raises(
                AsyncConnectionError, match="Failed to create async pool"
            ):
                await db.setup_pool()

    @pytest.mark.asyncio
    async def test_close_pool_postgres(self):
        db = make_async_db("postgres")
        mock_pool = MagicMock()
        mock_pool.close = AsyncMock()
        db._pool = mock_pool
        await db.close_pool()
        mock_pool.close.assert_called_once()
        assert db._pool is None

    @pytest.mark.asyncio
    async def test_close_pool_mysql(self):
        # The MySQL branch calls close() without awaiting it, then awaits
        # wait_closed(); cover both calls.
        db = make_async_db("mysql")
        mock_pool = MagicMock()
        mock_pool.close = MagicMock()
        mock_pool.wait_closed = AsyncMock()
        db._pool = mock_pool
        await db.close_pool()
        mock_pool.close.assert_called_once()
        mock_pool.wait_closed.assert_awaited_once()
        assert db._pool is None

    @pytest.mark.asyncio
    async def test_close_pool_without_pool_is_noop(self):
        # Closing before any setup_pool() call must be harmless.
        db = make_async_db("postgres")
        await db.close_pool()
        assert db._pool is None
@@ -1,20 +0,0 @@
1
- python-dotenv
2
- click
3
-
4
- [all]
5
- psycopg2
6
- mysql-connector-python
7
- pyodbc
8
- oracledb
9
-
10
- [mysql]
11
- mysql-connector-python
12
-
13
- [oracle]
14
- oracledb
15
-
16
- [postgres]
17
- psycopg2
18
-
19
- [sqlserver]
20
- pyodbc
File without changes
File without changes
File without changes