SQLPyHelper 0.1.7__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/PKG-INFO +18 -1
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/PKG-INFO +18 -1
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/SOURCES.txt +2 -0
- sqlpyhelper-0.1.8/SQLPyHelper.egg-info/requires.txt +42 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/pyproject.toml +3 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/setup.py +8 -5
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/__init__.py +6 -1
- sqlpyhelper-0.1.8/sqlpyhelper/async_helper.py +599 -0
- sqlpyhelper-0.1.8/test/test_async_helper.py +478 -0
- sqlpyhelper-0.1.7/SQLPyHelper.egg-info/requires.txt +0 -20
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/LICENSE +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/README.md +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/dependency_links.txt +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/entry_points.txt +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/SQLPyHelper.egg-info/top_level.txt +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/setup.cfg +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/automation_utils.py +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/cli.py +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/db_helper.py +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/migration.py +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/sqlpyhelper/py.typed +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/test/test_migration.py +0 -0
- {sqlpyhelper-0.1.7 → sqlpyhelper-0.1.8}/test/test_sqlpyhelper.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: SQLPyHelper
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: A simple SQL database helper package for Python.
|
|
5
5
|
Home-page: https://github.com/adebayopeter/sqlpyhelper
|
|
6
6
|
Author: Adebayo Olaonipekun
|
|
@@ -34,11 +34,28 @@ Provides-Extra: sqlserver
|
|
|
34
34
|
Requires-Dist: pyodbc; extra == "sqlserver"
|
|
35
35
|
Provides-Extra: oracle
|
|
36
36
|
Requires-Dist: oracledb; extra == "oracle"
|
|
37
|
+
Provides-Extra: async-postgres
|
|
38
|
+
Requires-Dist: asyncpg; extra == "async-postgres"
|
|
39
|
+
Provides-Extra: async-mysql
|
|
40
|
+
Requires-Dist: aiomysql; extra == "async-mysql"
|
|
41
|
+
Provides-Extra: async-sqlite
|
|
42
|
+
Requires-Dist: aiosqlite; extra == "async-sqlite"
|
|
43
|
+
Provides-Extra: async-sqlserver
|
|
44
|
+
Requires-Dist: aioodbc; extra == "async-sqlserver"
|
|
45
|
+
Provides-Extra: async-all
|
|
46
|
+
Requires-Dist: asyncpg; extra == "async-all"
|
|
47
|
+
Requires-Dist: aiomysql; extra == "async-all"
|
|
48
|
+
Requires-Dist: aiosqlite; extra == "async-all"
|
|
49
|
+
Requires-Dist: aioodbc; extra == "async-all"
|
|
37
50
|
Provides-Extra: all
|
|
38
51
|
Requires-Dist: psycopg2; extra == "all"
|
|
39
52
|
Requires-Dist: mysql-connector-python; extra == "all"
|
|
40
53
|
Requires-Dist: pyodbc; extra == "all"
|
|
41
54
|
Requires-Dist: oracledb; extra == "all"
|
|
55
|
+
Requires-Dist: asyncpg; extra == "all"
|
|
56
|
+
Requires-Dist: aiomysql; extra == "all"
|
|
57
|
+
Requires-Dist: aiosqlite; extra == "all"
|
|
58
|
+
Requires-Dist: aioodbc; extra == "all"
|
|
42
59
|
Dynamic: author
|
|
43
60
|
Dynamic: author-email
|
|
44
61
|
Dynamic: classifier
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: SQLPyHelper
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.8
|
|
4
4
|
Summary: A simple SQL database helper package for Python.
|
|
5
5
|
Home-page: https://github.com/adebayopeter/sqlpyhelper
|
|
6
6
|
Author: Adebayo Olaonipekun
|
|
@@ -34,11 +34,28 @@ Provides-Extra: sqlserver
|
|
|
34
34
|
Requires-Dist: pyodbc; extra == "sqlserver"
|
|
35
35
|
Provides-Extra: oracle
|
|
36
36
|
Requires-Dist: oracledb; extra == "oracle"
|
|
37
|
+
Provides-Extra: async-postgres
|
|
38
|
+
Requires-Dist: asyncpg; extra == "async-postgres"
|
|
39
|
+
Provides-Extra: async-mysql
|
|
40
|
+
Requires-Dist: aiomysql; extra == "async-mysql"
|
|
41
|
+
Provides-Extra: async-sqlite
|
|
42
|
+
Requires-Dist: aiosqlite; extra == "async-sqlite"
|
|
43
|
+
Provides-Extra: async-sqlserver
|
|
44
|
+
Requires-Dist: aioodbc; extra == "async-sqlserver"
|
|
45
|
+
Provides-Extra: async-all
|
|
46
|
+
Requires-Dist: asyncpg; extra == "async-all"
|
|
47
|
+
Requires-Dist: aiomysql; extra == "async-all"
|
|
48
|
+
Requires-Dist: aiosqlite; extra == "async-all"
|
|
49
|
+
Requires-Dist: aioodbc; extra == "async-all"
|
|
37
50
|
Provides-Extra: all
|
|
38
51
|
Requires-Dist: psycopg2; extra == "all"
|
|
39
52
|
Requires-Dist: mysql-connector-python; extra == "all"
|
|
40
53
|
Requires-Dist: pyodbc; extra == "all"
|
|
41
54
|
Requires-Dist: oracledb; extra == "all"
|
|
55
|
+
Requires-Dist: asyncpg; extra == "all"
|
|
56
|
+
Requires-Dist: aiomysql; extra == "all"
|
|
57
|
+
Requires-Dist: aiosqlite; extra == "all"
|
|
58
|
+
Requires-Dist: aioodbc; extra == "all"
|
|
42
59
|
Dynamic: author
|
|
43
60
|
Dynamic: author-email
|
|
44
61
|
Dynamic: classifier
|
|
@@ -10,10 +10,12 @@ SQLPyHelper.egg-info/entry_points.txt
|
|
|
10
10
|
SQLPyHelper.egg-info/requires.txt
|
|
11
11
|
SQLPyHelper.egg-info/top_level.txt
|
|
12
12
|
sqlpyhelper/__init__.py
|
|
13
|
+
sqlpyhelper/async_helper.py
|
|
13
14
|
sqlpyhelper/automation_utils.py
|
|
14
15
|
sqlpyhelper/cli.py
|
|
15
16
|
sqlpyhelper/db_helper.py
|
|
16
17
|
sqlpyhelper/migration.py
|
|
17
18
|
sqlpyhelper/py.typed
|
|
19
|
+
test/test_async_helper.py
|
|
18
20
|
test/test_migration.py
|
|
19
21
|
test/test_sqlpyhelper.py
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
python-dotenv
|
|
2
|
+
click
|
|
3
|
+
|
|
4
|
+
[all]
|
|
5
|
+
psycopg2
|
|
6
|
+
mysql-connector-python
|
|
7
|
+
pyodbc
|
|
8
|
+
oracledb
|
|
9
|
+
asyncpg
|
|
10
|
+
aiomysql
|
|
11
|
+
aiosqlite
|
|
12
|
+
aioodbc
|
|
13
|
+
|
|
14
|
+
[async-all]
|
|
15
|
+
asyncpg
|
|
16
|
+
aiomysql
|
|
17
|
+
aiosqlite
|
|
18
|
+
aioodbc
|
|
19
|
+
|
|
20
|
+
[async-mysql]
|
|
21
|
+
aiomysql
|
|
22
|
+
|
|
23
|
+
[async-postgres]
|
|
24
|
+
asyncpg
|
|
25
|
+
|
|
26
|
+
[async-sqlite]
|
|
27
|
+
aiosqlite
|
|
28
|
+
|
|
29
|
+
[async-sqlserver]
|
|
30
|
+
aioodbc
|
|
31
|
+
|
|
32
|
+
[mysql]
|
|
33
|
+
mysql-connector-python
|
|
34
|
+
|
|
35
|
+
[oracle]
|
|
36
|
+
oracledb
|
|
37
|
+
|
|
38
|
+
[postgres]
|
|
39
|
+
psycopg2
|
|
40
|
+
|
|
41
|
+
[sqlserver]
|
|
42
|
+
pyodbc
|
|
@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as f:
|
|
|
5
5
|
|
|
6
6
|
setup(
|
|
7
7
|
name='SQLPyHelper',
|
|
8
|
-
version='0.1.
|
|
8
|
+
version='0.1.8',
|
|
9
9
|
description='A simple SQL database helper package for Python.',
|
|
10
10
|
long_description=long_description,
|
|
11
11
|
long_description_content_type="text/markdown",
|
|
@@ -26,11 +26,14 @@ setup(
|
|
|
26
26
|
"mysql": ["mysql-connector-python"],
|
|
27
27
|
"sqlserver": ["pyodbc"],
|
|
28
28
|
"oracle": ["oracledb"],
|
|
29
|
+
"async-postgres": ["asyncpg"],
|
|
30
|
+
"async-mysql": ["aiomysql"],
|
|
31
|
+
"async-sqlite": ["aiosqlite"],
|
|
32
|
+
"async-sqlserver": ["aioodbc"],
|
|
33
|
+
"async-all": ["asyncpg", "aiomysql", "aiosqlite", "aioodbc"],
|
|
29
34
|
"all": [
|
|
30
|
-
"psycopg2",
|
|
31
|
-
"
|
|
32
|
-
"pyodbc",
|
|
33
|
-
"oracledb",
|
|
35
|
+
"psycopg2", "mysql-connector-python", "pyodbc", "oracledb",
|
|
36
|
+
"asyncpg", "aiomysql", "aiosqlite", "aioodbc",
|
|
34
37
|
],
|
|
35
38
|
},
|
|
36
39
|
keywords=[
|
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
# Match the version in setup.py
|
|
2
|
-
__version__ = "0.1.
|
|
2
|
+
__version__ = "0.1.8"
|
|
3
3
|
|
|
4
|
+
from sqlpyhelper.async_helper import ( # noqa: F401
|
|
5
|
+
AsyncConnectionError,
|
|
6
|
+
AsyncQueryError,
|
|
7
|
+
AsyncSQLPyHelper,
|
|
8
|
+
)
|
|
4
9
|
from sqlpyhelper.db_helper import ( # noqa: F401
|
|
5
10
|
BackupError,
|
|
6
11
|
ConnectionError,
|
|
@@ -0,0 +1,599 @@
|
|
|
1
|
+
"""
|
|
2
|
+
sqlpyhelper.async_helper
|
|
3
|
+
~~~~~~~~~~~~~~~~~~~~~~~~
|
|
4
|
+
Async-native database helper supporting SQLite, PostgreSQL, MySQL,
|
|
5
|
+
SQL Server, and Oracle.
|
|
6
|
+
|
|
7
|
+
Uses async-native drivers:
|
|
8
|
+
- SQLite: aiosqlite
|
|
9
|
+
- PostgreSQL: asyncpg
|
|
10
|
+
- MySQL: aiomysql
|
|
11
|
+
- SQL Server: aioodbc
|
|
12
|
+
- Oracle: python-oracledb (async mode)
|
|
13
|
+
|
|
14
|
+
Example usage::
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
from sqlpyhelper.async_helper import AsyncSQLPyHelper
|
|
18
|
+
|
|
19
|
+
async def main():
|
|
20
|
+
async with AsyncSQLPyHelper(db_type="sqlite", database="my.db") as db:
|
|
21
|
+
await db.execute("CREATE TABLE IF NOT EXISTS users (id INTEGER, name TEXT)")
|
|
22
|
+
await db.execute("INSERT INTO users VALUES ($1, $2)", 1, "Alice")
|
|
23
|
+
rows = await db.fetch_all("SELECT * FROM users")
|
|
24
|
+
print(rows)
|
|
25
|
+
|
|
26
|
+
asyncio.run(main())
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
import logging
|
|
30
|
+
import os
|
|
31
|
+
from typing import Any, Optional
|
|
32
|
+
|
|
33
|
+
from dotenv import load_dotenv
|
|
34
|
+
|
|
35
|
+
# Load a .env file into os.environ at import time so the DB_* / ORACLE_*
# fallbacks read by AsyncSQLPyHelper can be supplied that way.
load_dotenv()

# Shared module logger; all records go to the "sqlpyhelper.async" logger name.
logger: logging.Logger = logging.getLogger("sqlpyhelper.async")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class AsyncConnectionError(Exception):
    """Error for async connection problems: connecting, closing, no live connection, or pool setup."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class AsyncQueryError(Exception):
    """Error for a failed async statement, fetch, or transaction operation."""
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AsyncSQLPyHelper:
    """
    Async-native database helper with a unified API across
    SQLite, PostgreSQL, MySQL, SQL Server, and Oracle.

    Queries are written with ``$1, $2, ...`` placeholders (asyncpg style).
    For the other backends they are rewritten to the driver's own parameter
    style and the arguments are re-ordered so that ``$N`` always refers to
    the N-th argument, even when placeholders are reordered or repeated.

    Use as an async context manager::

        async with AsyncSQLPyHelper(db_type="postgres", ...) as db:
            rows = await db.fetch_all("SELECT * FROM users")

    Or manage the connection lifecycle manually::

        db = AsyncSQLPyHelper(db_type="sqlite", database="my.db")
        await db.connect()
        try:
            rows = await db.fetch_all("SELECT * FROM users")
        finally:
            await db.close()

    Outside an explicit transaction each execute()/execute_many() call is
    committed immediately. Between begin_transaction() and
    commit_transaction()/rollback_transaction() those per-statement commits
    are skipped, so the whole unit is committed or rolled back together.
    """

    def __init__(
        self,
        db_type: Optional[str] = None,
        host: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
        database: Optional[str] = None,
        driver: Optional[str] = None,
        port: Optional[str] = None,
        oracle_sid: Optional[str] = None,
    ) -> None:
        """
        Store connection settings; no connection is opened here.

        Each argument falls back to an environment variable when not given:
        DB_TYPE, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_DRIVER, DB_PORT
        and ORACLE_SID respectively. db_type is lower-cased.

        Raises:
            ValueError: If db_type or database is missing, or db_type is not
                one of sqlite, postgres, mysql, sqlserver, oracle.
        """
        self.db_type: str = (db_type or os.getenv("DB_TYPE") or "").lower()
        self.host: Optional[str] = host or os.getenv("DB_HOST")
        self.user: Optional[str] = user or os.getenv("DB_USER")
        self.password: Optional[str] = password or os.getenv("DB_PASSWORD")
        self.database: Optional[str] = database or os.getenv("DB_NAME")
        self.driver: Optional[str] = driver or os.getenv("DB_DRIVER")
        self.port: Optional[str] = port or os.getenv("DB_PORT")
        self.oracle_sid: Optional[str] = oracle_sid or os.getenv("ORACLE_SID")

        self._connection: Any = None
        self._pool: Any = None
        # asyncpg Transaction started by begin_transaction() (postgres only).
        self._pg_transaction: Any = None
        # True while an explicit transaction is open; suppresses auto-commit.
        self._in_transaction: bool = False

        if not self.db_type or not self.database:
            raise ValueError("Missing required database configuration.")

        if self.db_type not in ("sqlite", "postgres", "mysql", "sqlserver", "oracle"):
            raise ValueError(f"Unsupported database type: {self.db_type!r}")

    # -----------------------------------------------------------------------
    # Connection lifecycle
    # -----------------------------------------------------------------------

    async def connect(self) -> None:
        """
        Open the database connection for the configured db_type.

        Raises:
            AsyncConnectionError: If the driver import or the connection fails.
        """
        try:
            if self.db_type == "sqlite":
                import aiosqlite

                self._connection = await aiosqlite.connect(self.database or "")  # type: ignore[arg-type]
                self._connection.row_factory = aiosqlite.Row
                logger.info("Connected to SQLite database: %s", self.database)

            elif self.db_type == "postgres":
                import asyncpg

                self._connection = await asyncpg.connect(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                )
                logger.info("Connected to PostgreSQL database: %s", self.database)

            elif self.db_type == "mysql":
                import aiomysql

                self._connection = await aiomysql.connect(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    autocommit=False,
                )
                logger.info("Connected to MySQL database: %s", self.database)

            elif self.db_type == "sqlserver":
                import aioodbc

                # Values containing ';', '{', '}' or edge whitespace are
                # brace-quoted so e.g. a password "a;b" cannot split the
                # connection string into extra attributes.
                dsn = (
                    f"DRIVER={self._odbc_value(self.driver)};"
                    f"SERVER={self._odbc_value(self.host)};"
                    f"DATABASE={self._odbc_value(self.database)};"
                    f"UID={self._odbc_value(self.user)};"
                    f"PWD={self._odbc_value(self.password)}"
                )
                self._connection = await aioodbc.connect(dsn=dsn)
                logger.info("Connected to SQL Server database: %s", self.database)

            elif self.db_type == "oracle":
                import oracledb

                # ORACLE_DB_PORT keeps precedence (previous behaviour); otherwise
                # use the configured port (argument or DB_PORT), then 1521.
                oracle_port = int(os.getenv("ORACLE_DB_PORT") or self.port or 1521)
                dsn = oracledb.makedsn(
                    self.host, oracle_port, sid=self.oracle_sid  # type: ignore[arg-type]
                )
                self._connection = await oracledb.connect_async(
                    user=self.user, password=self.password, dsn=dsn
                )
                logger.info("Connected to Oracle database: %s", self.oracle_sid)

        except Exception as e:
            raise AsyncConnectionError(
                f"Failed to connect to {self.db_type}: {e}"
            ) from e

    async def close(self) -> None:
        """
        Close the database connection if one is open.

        The connection reference and any transaction state are cleared before
        the driver is asked to close, so a failing close never leaves a stale
        handle behind and connect() can be called again.

        Raises:
            AsyncConnectionError: If the driver fails to close the connection.
        """
        connection, self._connection = self._connection, None
        self._pg_transaction = None
        self._in_transaction = False
        if connection is None:
            return
        try:
            await self._close_handle(connection)
            logger.info("Closed %s connection", self.db_type)
        except Exception as e:
            raise AsyncConnectionError(f"Failed to close connection: {e}") from e

    async def __aenter__(self) -> "AsyncSQLPyHelper":
        await self.connect()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool:
        await self.close()
        return False

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _check_connection(self) -> None:
        """Raise AsyncConnectionError unless connect() has succeeded."""
        if self._connection is None:
            raise AsyncConnectionError(
                "No active connection. Call connect() or use async with."
            )

    @staticmethod
    async def _close_handle(handle: Any) -> None:
        """
        Call ``handle.close()`` and await the result only if it is awaitable.

        aiomysql's Connection.close() and python-oracledb's cursor close() are
        plain methods, while the other drivers used here return a coroutine.
        """
        import inspect

        result = handle.close()
        if inspect.isawaitable(result):
            await result

    @staticmethod
    def _odbc_value(value: Optional[str]) -> str:
        """
        Format *value* for an ODBC connection-string attribute.

        None becomes "". Values already wrapped in braces, or free of ';',
        '{', '}' and leading/trailing whitespace, are returned unchanged.
        Anything else is wrapped in braces with '}' doubled.
        """
        text = "" if value is None else str(value)
        if text.startswith("{") and text.endswith("}"):
            return text
        if any(ch in text for ch in ";{}") or text != text.strip():
            return "{" + text.replace("}", "}}") + "}"
        return text

    def _commit_needed(self) -> bool:
        """True when statements must be committed individually (no open transaction)."""
        return not self._in_transaction

    def _rewrite_placeholders(self, query: str) -> tuple[str, list[int]]:
        """
        Translate ``$N`` placeholders into the active (non-postgres) driver style.

        sqlite (aiosqlite) and sqlserver (aioodbc) get ``?``, mysql
        (aiomysql) gets ``%s``, and oracle gets ``:1, :2, ...`` numbered in
        order of appearance. For mysql, literal ``%`` is doubled, because
        aiomysql %-formats the query when arguments are supplied.

        Returns:
            The rewritten query, plus the zero-based argument index for every
            placeholder in order of appearance. A query without ``$N``
            placeholders is returned unchanged with an empty list.
        """
        import re

        pattern = re.compile(r"\$(\d+)")
        if pattern.search(query) is None:
            return query, []

        order: list[int] = []
        if self.db_type == "mysql":
            query = query.replace("%", "%%")

        def substitute(match: Any) -> str:
            order.append(int(match.group(1)) - 1)
            if self.db_type == "oracle":
                return f":{len(order)}"
            if self.db_type == "mysql":
                return "%s"
            return "?"

        return pattern.sub(substitute, query), order

    @staticmethod
    def _reorder_args(args: Any, order: list[int]) -> Any:
        """
        Arrange *args* so the i-th value matches the i-th placeholder.

        *order* comes from _rewrite_placeholders(). An empty *order* (no
        ``$N`` placeholders) or an identity order returns *args* unchanged.

        Raises:
            AsyncQueryError: If a placeholder number is out of range or a
                supplied argument is never referenced.
        """
        if not order:
            return args
        count = len(args)
        if any(i < 0 or i >= count for i in order) or len(set(order)) != count:
            raise AsyncQueryError(
                f"Query placeholders {sorted({i + 1 for i in order})} do not "
                f"match the {count} supplied argument(s)."
            )
        if order == list(range(count)):
            return args
        return tuple(args[i] for i in order)

    def _adapt_query(self, query: str, args: tuple) -> tuple[str, tuple]:
        """
        Adapt a ``$1, $2, ...`` query and its arguments for the active driver.

        PostgreSQL (asyncpg) uses ``$N`` natively, and queries without
        arguments are left untouched. Otherwise placeholders are rewritten
        (see _rewrite_placeholders) and the arguments are reordered or
        duplicated so that ``$N`` always binds the N-th argument.

        Raises:
            AsyncQueryError: If placeholders and arguments do not match.
        """
        if not args or self.db_type == "postgres":
            return query, args
        adapted, order = self._rewrite_placeholders(query)
        return adapted, self._reorder_args(args, order)

    # -----------------------------------------------------------------------
    # Query execution
    # -----------------------------------------------------------------------

    async def execute(self, query: str, *args: Any) -> None:
        """
        Execute a SQL statement (INSERT, UPDATE, DELETE, DDL).

        Use $1, $2, ... for parameterised values::

            await db.execute(
                "INSERT INTO users (id, name) VALUES ($1, $2)",
                1, "Alice"
            )

        The statement is committed immediately unless begin_transaction()
        is active.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Raises:
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                await self._connection.execute(adapted_query, *adapted_args)

            elif self.db_type == "sqlite":
                await self._connection.execute(adapted_query, adapted_args)
                if self._commit_needed():
                    await self._connection.commit()

            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
                    if self._commit_needed():
                        await self._connection.commit()

            elif self.db_type == "oracle":
                cursor = self._connection.cursor()
                try:
                    await cursor.execute(adapted_query, adapted_args)
                finally:
                    await self._close_handle(cursor)
                if self._commit_needed():
                    await self._connection.commit()

            logger.debug("Executed: %s", query)

        except Exception as e:
            raise AsyncQueryError(f"Query failed: {e}") from e

    async def fetch_one(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single row, or None.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            A single row, or None if no rows matched.

        Raises:
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetchrow(adapted_query, *adapted_args)

            elif self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    return await cursor.fetchone()

            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
                    return await cursor.fetchone()

            elif self.db_type == "oracle":
                cursor = self._connection.cursor()
                try:
                    await cursor.execute(adapted_query, adapted_args)
                    return await cursor.fetchone()
                finally:
                    await self._close_handle(cursor)

            return None

        except Exception as e:
            raise AsyncQueryError(f"fetch_one failed: {e}") from e

    async def fetch_all(self, query: str, *args: Any) -> list[Any]:
        """
        Execute a SELECT query and return all rows.

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            A list of rows (empty list if no rows matched).

        Raises:
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetch(adapted_query, *adapted_args)

            elif self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    return await cursor.fetchall()

            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
                    return await cursor.fetchall()

            elif self.db_type == "oracle":
                cursor = self._connection.cursor()
                try:
                    await cursor.execute(adapted_query, adapted_args)
                    return await cursor.fetchall()
                finally:
                    await self._close_handle(cursor)

            return []

        except Exception as e:
            raise AsyncQueryError(f"fetch_all failed: {e}") from e

    async def fetch_val(self, query: str, *args: Any) -> Optional[Any]:
        """
        Execute a SELECT query and return a single scalar value.

        Useful for COUNT, SUM, or any query returning one value::

            count = await db.fetch_val("SELECT COUNT(*) FROM users")

        Args:
            query: SQL query string using $1, $2 placeholders.
            *args: Query parameters.

        Returns:
            The first column of the first row, or None if no row matched.

        Raises:
            AsyncQueryError: If the query fails.
        """
        self._check_connection()
        adapted_query, adapted_args = self._adapt_query(query, args)
        try:
            if self.db_type == "postgres":
                return await self._connection.fetchval(adapted_query, *adapted_args)

            elif self.db_type == "sqlite":
                async with self._connection.execute(
                    adapted_query, adapted_args
                ) as cursor:
                    row = await cursor.fetchone()
                    return row[0] if row else None

            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.execute(adapted_query, adapted_args)
                    row = await cursor.fetchone()
                    return row[0] if row else None

            elif self.db_type == "oracle":
                cursor = self._connection.cursor()
                try:
                    await cursor.execute(adapted_query, adapted_args)
                    row = await cursor.fetchone()
                    return row[0] if row else None
                finally:
                    await self._close_handle(cursor)

            return None

        except Exception as e:
            raise AsyncQueryError(f"fetch_val failed: {e}") from e

    async def execute_many(self, query: str, args_list: list[tuple]) -> None:
        """
        Execute a SQL statement multiple times with different parameters.

        Efficient for bulk inserts::

            await db.execute_many(
                "INSERT INTO users (id, name) VALUES ($1, $2)",
                [(1, "Alice"), (2, "Bob"), (3, "Charlie")]
            )

        An empty args_list is a no-op. The batch is committed immediately
        unless begin_transaction() is active.

        Args:
            query: SQL query string using $1, $2 placeholders.
            args_list: List of parameter tuples.

        Raises:
            AsyncQueryError: If the operation fails or a row's parameters do
                not match the query's placeholders.
        """
        self._check_connection()
        if not args_list:
            return

        if self.db_type == "postgres":
            adapted_query, rows = query, args_list
        else:
            # Same translation as execute(), applied once to the query and
            # then to every parameter row.
            adapted_query, order = self._rewrite_placeholders(query)
            rows = [self._reorder_args(row, order) for row in args_list]

        try:
            if self.db_type == "postgres":
                await self._connection.executemany(adapted_query, rows)

            elif self.db_type == "sqlite":
                await self._connection.executemany(adapted_query, rows)
                if self._commit_needed():
                    await self._connection.commit()

            elif self.db_type in ("mysql", "sqlserver"):
                async with self._connection.cursor() as cursor:
                    await cursor.executemany(adapted_query, rows)
                    if self._commit_needed():
                        await self._connection.commit()

            elif self.db_type == "oracle":
                cursor = self._connection.cursor()
                try:
                    await cursor.executemany(adapted_query, rows)
                finally:
                    await self._close_handle(cursor)
                if self._commit_needed():
                    await self._connection.commit()

            logger.debug("execute_many: %d rows", len(args_list))

        except Exception as e:
            raise AsyncQueryError(f"execute_many failed: {e}") from e

    # -----------------------------------------------------------------------
    # Transaction management
    # -----------------------------------------------------------------------

    async def begin_transaction(self) -> None:
        """
        Begin an explicit transaction.

        Until commit_transaction() or rollback_transaction() is called,
        execute() and execute_many() stop committing after each statement.
        For PostgreSQL an asyncpg Transaction is started and kept on the
        instance. For Oracle nothing is sent, since Oracle starts
        transactions implicitly.

        Raises:
            AsyncQueryError: If the transaction cannot be started.
        """
        self._check_connection()
        try:
            if self.db_type == "sqlite":
                await self._connection.execute("BEGIN")
            elif self.db_type == "mysql":
                await self._connection.begin()
            elif self.db_type == "sqlserver":
                async with self._connection.cursor() as cursor:
                    await cursor.execute("BEGIN TRANSACTION")
            elif self.db_type == "oracle":
                pass  # Oracle starts transactions implicitly
            elif self.db_type == "postgres":
                # asyncpg is autocommit outside a Transaction object; keep it
                # so commit/rollback_transaction() can finish it.
                self._pg_transaction = self._connection.transaction()
                await self._pg_transaction.start()
            self._in_transaction = True
            logger.info("Transaction started on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to begin transaction: {e}") from e

    async def commit_transaction(self) -> None:
        """
        Commit the current transaction and restore per-statement auto-commit.

        Raises:
            AsyncQueryError: If the commit fails, or (PostgreSQL) no
                transaction was started with begin_transaction().
        """
        self._check_connection()
        if self.db_type == "postgres" and self._pg_transaction is None:
            raise AsyncQueryError(
                "Failed to commit transaction: no active transaction. "
                "Call begin_transaction() first."
            )
        try:
            if self.db_type == "postgres":
                await self._pg_transaction.commit()
                self._pg_transaction = None
            else:
                await self._connection.commit()
            self._in_transaction = False
            logger.info("Transaction committed on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to commit transaction: {e}") from e

    async def rollback_transaction(self) -> None:
        """
        Roll back the current transaction and restore per-statement auto-commit.

        Raises:
            AsyncQueryError: If the rollback fails, or (PostgreSQL) no
                transaction was started with begin_transaction().
        """
        self._check_connection()
        if self.db_type == "postgres" and self._pg_transaction is None:
            raise AsyncQueryError(
                "Failed to rollback transaction: no active transaction. "
                "Call begin_transaction() first."
            )
        try:
            if self.db_type == "postgres":
                await self._pg_transaction.rollback()
            else:
                await self._connection.rollback()
            logger.info("Transaction rolled back on %s", self.db_type)
        except Exception as e:
            raise AsyncQueryError(f"Failed to rollback transaction: {e}") from e
        finally:
            # The transaction is over either way; don't keep auto-commit off.
            self._in_transaction = False
            self._pg_transaction = None

    # -----------------------------------------------------------------------
    # Connection pooling
    # -----------------------------------------------------------------------

    async def setup_pool(self, min_size: int = 1, max_size: int = 10) -> None:
        """
        Set up an async connection pool.

        Supported for PostgreSQL and MySQL only. The pool is separate from
        the single connection opened by connect(): execute()/fetch_*() keep
        using that connection. Use get_connection_from_pool() to acquire
        pooled connections.

        Args:
            min_size: Minimum number of connections in the pool.
            max_size: Maximum number of connections in the pool.

        Raises:
            AsyncConnectionError: If pool setup fails or db_type
                does not support pooling.
        """
        try:
            if self.db_type == "postgres":
                import asyncpg

                self._pool = await asyncpg.create_pool(
                    host=self.host,
                    port=int(self.port or 5432),
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    min_size=min_size,
                    max_size=max_size,
                )
                logger.info(
                    "PostgreSQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            elif self.db_type == "mysql":
                import aiomysql

                self._pool = await aiomysql.create_pool(
                    host=self.host or "localhost",
                    port=int(self.port or 3306),
                    user=self.user,
                    password=self.password or "",
                    db=self.database,
                    minsize=min_size,
                    maxsize=max_size,
                )
                logger.info(
                    "MySQL async pool created (min=%d, max=%d)", min_size, max_size
                )

            else:
                raise AsyncConnectionError(
                    f"Async connection pooling not supported for {self.db_type!r}. "
                    "Supported: postgres, mysql."
                )
        except AsyncConnectionError:
            raise
        except Exception as e:
            raise AsyncConnectionError(f"Failed to create async pool: {e}") from e

    def get_connection_from_pool(self) -> Any:
        """
        Return the pool's ``acquire()`` handle for a pooled connection.

        Use it as ``async with db.get_connection_from_pool() as conn:``. The
        yielded object is the raw driver connection (asyncpg or aiomysql), so
        use that driver's own API and placeholder style on it.

        Raises:
            AsyncConnectionError: If setup_pool() has not been called.
        """
        if self._pool is None:
            raise AsyncConnectionError("No active pool. Call setup_pool() first.")
        return self._pool.acquire()

    async def close_pool(self) -> None:
        """Close the async connection pool, if one was created."""
        if self._pool is not None:
            if self.db_type == "mysql":
                # aiomysql: close() is synchronous; wait_closed() awaits shutdown.
                self._pool.close()
                await self._pool.wait_closed()
            else:
                await self._pool.close()
            self._pool = None
            logger.info("Async pool closed for %s", self.db_type)
|
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for sqlpyhelper.async_helper
|
|
3
|
+
All tests use mocking — no live database required.
|
|
4
|
+
Uses pytest-asyncio for async test support.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
8
|
+
|
|
9
|
+
import pytest
|
|
10
|
+
|
|
11
|
+
from sqlpyhelper.async_helper import (
|
|
12
|
+
AsyncConnectionError,
|
|
13
|
+
AsyncQueryError,
|
|
14
|
+
AsyncSQLPyHelper,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
# Helpers
|
|
19
|
+
# ---------------------------------------------------------------------------
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def make_async_db(db_type: str = "sqlite") -> AsyncSQLPyHelper:
    """Build an AsyncSQLPyHelper for ``db_type`` without opening a connection.

    SQLite only needs a database file name. Every other backend gets dummy
    server credentials, plus any option that backend specifically requires
    (an Oracle SID, a SQL Server ODBC driver name).
    """
    if db_type == "sqlite":
        return AsyncSQLPyHelper(db_type=db_type, database="test.db")

    options = {
        "db_type": db_type,
        "host": "localhost",
        "user": "user",
        "password": "pass",
        "database": "testdb",
    }
    backend_extras = {
        "oracle": {"oracle_sid": "XE"},
        "sqlserver": {"driver": "ODBC Driver 17 for SQL Server"},
    }
    options.update(backend_extras.get(db_type, {}))
    return AsyncSQLPyHelper(**options)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def attach_mock_connection(db: AsyncSQLPyHelper) -> MagicMock:
    """Install a fake connection on ``db`` and return it.

    Every driver coroutine the helper may await is replaced by an
    ``AsyncMock``. ``fetch`` yields an empty list; ``fetchrow`` and
    ``fetchval`` yield ``None``.
    """
    conn = MagicMock()
    for coro_name in ("close", "commit", "rollback", "execute", "executemany"):
        setattr(conn, coro_name, AsyncMock())
    conn.fetch = AsyncMock(return_value=[])
    for coro_name in ("fetchrow", "fetchval"):
        setattr(conn, coro_name, AsyncMock(return_value=None))
    db._connection = conn
    return conn
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ---------------------------------------------------------------------------
|
|
57
|
+
# __init__ validation
|
|
58
|
+
# ---------------------------------------------------------------------------
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TestInit:
    """Constructor validation for AsyncSQLPyHelper."""

    def test_missing_database_raises(self):
        # Clear the environment in case the constructor falls back to
        # environment variables for missing settings.
        with patch.dict("os.environ", {}, clear=True):
            with pytest.raises(
                ValueError, match="Missing required database configuration"
            ):
                AsyncSQLPyHelper(db_type="sqlite", database="")

    def test_missing_db_type_raises(self):
        with patch.dict("os.environ", {}, clear=True):
            with pytest.raises(
                ValueError, match="Missing required database configuration"
            ):
                AsyncSQLPyHelper(db_type="", database="test.db")

    def test_unsupported_db_type_raises(self):
        with pytest.raises(ValueError, match="Unsupported database type"):
            AsyncSQLPyHelper(db_type="mongodb", database="test")

    def test_valid_sqlite_init(self):
        db = AsyncSQLPyHelper(db_type="sqlite", database="test.db")
        assert db.db_type == "sqlite"
        assert db.database == "test.db"
        assert db._connection is None

    def test_valid_postgres_init(self):
        db = AsyncSQLPyHelper(
            db_type="postgres",
            host="localhost",
            user="user",
            password="pass",
            database="testdb",
        )
        assert db.db_type == "postgres"
        # Explicit server settings must be kept as given; the connection
        # and pool code read them back from these attributes.
        assert db.host == "localhost"
        assert db.user == "user"
        assert db.database == "testdb"
        assert db._connection is None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# ---------------------------------------------------------------------------
|
|
98
|
+
# connect / close
|
|
99
|
+
# ---------------------------------------------------------------------------
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class TestConnect:
    """connect()/close() with each driver's connect function patched out."""

    pytestmark = pytest.mark.asyncio

    @staticmethod
    def _fake_connection() -> MagicMock:
        # A bare connection object whose close() can be awaited.
        conn = MagicMock()
        conn.close = AsyncMock()
        return conn

    async def _assert_connect_stores(self, db_type: str, driver_connect: str):
        # Patch the driver's connect coroutine and check the helper keeps
        # the connection object it returned.
        helper = make_async_db(db_type)
        conn = self._fake_connection()
        with patch(driver_connect, new_callable=AsyncMock, return_value=conn):
            await helper.connect()
        assert helper._connection is conn

    async def test_sqlite_connect(self):
        await self._assert_connect_stores("sqlite", "aiosqlite.connect")

    async def test_postgres_connect(self):
        await self._assert_connect_stores("postgres", "asyncpg.connect")

    async def test_mysql_connect(self):
        await self._assert_connect_stores("mysql", "aiomysql.connect")

    async def test_connect_failure_raises_async_connection_error(self):
        helper = make_async_db("sqlite")
        boom = Exception("disk full")
        with patch("aiosqlite.connect", side_effect=boom), pytest.raises(
            AsyncConnectionError, match="Failed to connect"
        ):
            await helper.connect()

    async def test_close(self):
        helper = make_async_db("sqlite")
        conn = attach_mock_connection(helper)
        await helper.close()
        conn.close.assert_called_once()
        assert helper._connection is None

    async def test_close_raises_on_failure(self):
        helper = make_async_db("sqlite")
        conn = attach_mock_connection(helper)
        conn.close.side_effect = Exception("already closed")
        with pytest.raises(AsyncConnectionError, match="Failed to close"):
            await helper.close()
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# ---------------------------------------------------------------------------
|
|
155
|
+
# Context manager
|
|
156
|
+
# ---------------------------------------------------------------------------
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class TestContextManager:
    """`async with` support: the helper connects on enter and closes on exit."""

    pytestmark = pytest.mark.asyncio

    async def test_aenter_returns_self(self):
        helper = make_async_db("sqlite")
        fake = MagicMock(close=AsyncMock())
        with patch("aiosqlite.connect", new_callable=AsyncMock, return_value=fake):
            entered = await helper.__aenter__()
        assert entered is helper

    async def test_aexit_closes_connection(self):
        helper = make_async_db("sqlite")
        conn = attach_mock_connection(helper)
        await helper.__aexit__(None, None, None)
        conn.close.assert_called_once()

    async def test_aexit_returns_false(self):
        helper = make_async_db("sqlite")
        attach_mock_connection(helper)
        # False means exceptions raised inside the block are not suppressed.
        outcome = await helper.__aexit__(None, None, None)
        assert outcome is False

    async def test_async_with_statement(self):
        fake = MagicMock(close=AsyncMock())
        with patch("aiosqlite.connect", new_callable=AsyncMock, return_value=fake):
            async with AsyncSQLPyHelper(db_type="sqlite", database="test.db") as helper:
                assert helper._connection is fake
        fake.close.assert_called_once()
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# ---------------------------------------------------------------------------
|
|
194
|
+
# _check_connection
|
|
195
|
+
# ---------------------------------------------------------------------------
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class TestCheckConnection:
    """The synchronous _check_connection() guard."""

    def test_raises_when_no_connection(self):
        unconnected = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="No active connection"):
            unconnected._check_connection()

    def test_passes_when_connected(self):
        connected = make_async_db("sqlite")
        attach_mock_connection(connected)
        # Reaching the end of the test without an exception is the assertion.
        connected._check_connection()
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# ---------------------------------------------------------------------------
|
|
211
|
+
# _adapt_query
|
|
212
|
+
# ---------------------------------------------------------------------------
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class TestAdaptQuery:
    """Rewriting of `$n` placeholders into each driver's parameter style."""

    TWO_PARAMS = "SELECT $1, $2"

    def test_postgres_passes_through(self):
        adapted, _ = make_async_db("postgres")._adapt_query("SELECT $1", (1,))
        assert adapted == "SELECT $1"

    def test_sqlite_replaces_with_question_mark(self):
        adapted, _ = make_async_db("sqlite")._adapt_query(self.TWO_PARAMS, (1, 2))
        assert adapted == "SELECT ?, ?"

    def test_mysql_replaces_with_percent_s(self):
        adapted, _ = make_async_db("mysql")._adapt_query(self.TWO_PARAMS, (1, 2))
        assert adapted == "SELECT %s, %s"

    def test_oracle_replaces_with_colon(self):
        oracle = AsyncSQLPyHelper(
            db_type="oracle",
            host="localhost",
            user="u",
            password="p",
            database="d",
            oracle_sid="XE",
        )
        adapted, _ = oracle._adapt_query(self.TWO_PARAMS, (1, 2))
        assert adapted == "SELECT :1, :2"

    def test_empty_args_returns_unchanged(self):
        query, params = make_async_db("sqlite")._adapt_query("SELECT 1", ())
        assert (query, params) == ("SELECT 1", ())
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# ---------------------------------------------------------------------------
|
|
251
|
+
# execute
|
|
252
|
+
# ---------------------------------------------------------------------------
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class TestExecute:
    """execute(): dispatch to the driver, error wrapping, connection guard."""

    pytestmark = pytest.mark.asyncio

    async def test_sqlite_execute(self):
        helper = make_async_db("sqlite")
        conn = attach_mock_connection(helper)
        await helper.execute("CREATE TABLE users (id INTEGER)")
        conn.execute.assert_called_once()

    async def test_postgres_execute(self):
        helper = make_async_db("postgres")
        conn = attach_mock_connection(helper)
        await helper.execute("INSERT INTO users VALUES ($1)", 1)
        assert conn.execute.call_count == 1

    async def test_raises_async_query_error_on_failure(self):
        helper = make_async_db("sqlite")
        attach_mock_connection(helper).execute.side_effect = Exception("disk full")
        with pytest.raises(AsyncQueryError, match="Query failed"):
            await helper.execute("INSERT INTO users VALUES ($1)", 1)

    async def test_raises_when_not_connected(self):
        helper = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="No active connection"):
            await helper.execute("SELECT 1")
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# ---------------------------------------------------------------------------
|
|
286
|
+
# fetch_one
|
|
287
|
+
# ---------------------------------------------------------------------------
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
class TestFetchOne:
    """fetch_one(): a single row, or an error."""

    pytestmark = pytest.mark.asyncio

    async def test_postgres_fetch_one(self):
        helper = make_async_db("postgres")
        row = {"id": 1, "name": "Alice"}
        attach_mock_connection(helper).fetchrow.return_value = row
        fetched = await helper.fetch_one("SELECT * FROM users WHERE id = $1", 1)
        assert fetched == {"id": 1, "name": "Alice"}

    async def test_raises_on_failure(self):
        helper = make_async_db("postgres")
        attach_mock_connection(helper).fetchrow.side_effect = Exception(
            "connection lost"
        )
        with pytest.raises(AsyncQueryError, match="fetch_one failed"):
            await helper.fetch_one("SELECT * FROM users")

    async def test_raises_when_not_connected(self):
        helper = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError):
            await helper.fetch_one("SELECT 1")
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
# ---------------------------------------------------------------------------
|
|
315
|
+
# fetch_all
|
|
316
|
+
# ---------------------------------------------------------------------------
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
class TestFetchAll:
    """fetch_all(): every row, or an error."""

    pytestmark = pytest.mark.asyncio

    async def test_postgres_fetch_all(self):
        helper = make_async_db("postgres")
        conn = attach_mock_connection(helper)
        conn.fetch.return_value = [{"id": 1}, {"id": 2}]
        rows = await helper.fetch_all("SELECT * FROM users")
        assert len(rows) == 2

    async def test_raises_on_failure(self):
        helper = make_async_db("postgres")
        conn = attach_mock_connection(helper)
        conn.fetch.side_effect = Exception("timeout")
        with pytest.raises(AsyncQueryError, match="fetch_all failed"):
            await helper.fetch_all("SELECT * FROM users")

    async def test_raises_when_not_connected(self):
        helper = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError):
            await helper.fetch_all("SELECT 1")
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
# ---------------------------------------------------------------------------
|
|
344
|
+
# fetch_val
|
|
345
|
+
# ---------------------------------------------------------------------------
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
class TestFetchVal:
    """fetch_val(): a single scalar value, or an error."""

    pytestmark = pytest.mark.asyncio

    COUNT_QUERY = "SELECT COUNT(*) FROM users"

    async def test_postgres_fetch_val(self):
        helper = make_async_db("postgres")
        attach_mock_connection(helper).fetchval.return_value = 42
        assert await helper.fetch_val(self.COUNT_QUERY) == 42

    async def test_raises_on_failure(self):
        helper = make_async_db("postgres")
        attach_mock_connection(helper).fetchval.side_effect = Exception("error")
        with pytest.raises(AsyncQueryError, match="fetch_val failed"):
            await helper.fetch_val(self.COUNT_QUERY)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
# ---------------------------------------------------------------------------
|
|
367
|
+
# execute_many
|
|
368
|
+
# ---------------------------------------------------------------------------
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
class TestExecuteMany:
    """execute_many(): batched execution over a list of parameter tuples."""

    pytestmark = pytest.mark.asyncio

    async def test_postgres_execute_many(self):
        helper = make_async_db("postgres")
        # attach_mock_connection already installs a fresh AsyncMock executemany.
        conn = attach_mock_connection(helper)
        people = [(1, "Alice"), (2, "Bob")]
        await helper.execute_many("INSERT INTO users VALUES ($1, $2)", people)
        conn.executemany.assert_called_once()

    async def test_empty_list_does_nothing(self):
        helper = make_async_db("postgres")
        conn = attach_mock_connection(helper)
        await helper.execute_many("INSERT INTO users VALUES ($1)", [])
        conn.executemany.assert_not_called()

    async def test_raises_on_failure(self):
        helper = make_async_db("postgres")
        attach_mock_connection(helper).executemany.side_effect = Exception("error")
        with pytest.raises(AsyncQueryError, match="execute_many failed"):
            await helper.execute_many("INSERT INTO users VALUES ($1)", [(1,)])
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
# ---------------------------------------------------------------------------
|
|
399
|
+
# Transactions
|
|
400
|
+
# ---------------------------------------------------------------------------
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class TestTransactions:
    """begin/commit/rollback delegation to the driver connection (MySQL)."""

    pytestmark = pytest.mark.asyncio

    @staticmethod
    def _mysql_with_connection():
        # A MySQL helper paired with its freshly attached fake connection.
        helper = make_async_db("mysql")
        return helper, attach_mock_connection(helper)

    async def test_mysql_begin_calls_begin(self):
        helper, conn = self._mysql_with_connection()
        conn.begin = AsyncMock()
        await helper.begin_transaction()
        conn.begin.assert_called_once()

    async def test_commit(self):
        helper, conn = self._mysql_with_connection()
        await helper.commit_transaction()
        conn.commit.assert_called_once()

    async def test_rollback(self):
        helper, conn = self._mysql_with_connection()
        await helper.rollback_transaction()
        conn.rollback.assert_called_once()

    async def test_rollback_raises_on_failure(self):
        helper, conn = self._mysql_with_connection()
        conn.rollback.side_effect = Exception("rollback error")
        with pytest.raises(AsyncQueryError, match="Failed to rollback"):
            await helper.rollback_transaction()
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
# ---------------------------------------------------------------------------
|
|
436
|
+
# Connection pool
|
|
437
|
+
# ---------------------------------------------------------------------------
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class TestPool:
    """Connection pool creation (setup_pool) and shutdown (close_pool)."""

    @pytest.mark.asyncio
    async def test_postgres_pool_setup(self):
        db = make_async_db("postgres")
        mock_pool = MagicMock()
        mock_pool.close = AsyncMock()
        with patch(
            "asyncpg.create_pool", new_callable=AsyncMock, return_value=mock_pool
        ):
            await db.setup_pool(min_size=1, max_size=5)
            assert db._pool is mock_pool

    @pytest.mark.asyncio
    async def test_mysql_pool_setup(self):
        db = make_async_db("mysql")
        mock_pool = MagicMock()
        mock_pool.close = MagicMock()
        mock_pool.wait_closed = AsyncMock()
        with patch(
            "aiomysql.create_pool", new_callable=AsyncMock, return_value=mock_pool
        ):
            await db.setup_pool(min_size=1, max_size=5)
            assert db._pool is mock_pool

    @pytest.mark.asyncio
    async def test_unsupported_db_raises(self):
        db = make_async_db("sqlite")
        with pytest.raises(AsyncConnectionError, match="not supported"):
            await db.setup_pool()

    @pytest.mark.asyncio
    async def test_close_pool_postgres(self):
        db = make_async_db("postgres")
        mock_pool = MagicMock()
        mock_pool.close = AsyncMock()
        db._pool = mock_pool
        await db.close_pool()
        mock_pool.close.assert_called_once()
        assert db._pool is None

    @pytest.mark.asyncio
    async def test_close_pool_mysql(self):
        # MySQL pools take a different shutdown path: close() is called
        # synchronously, then wait_closed() is awaited.
        db = make_async_db("mysql")
        mock_pool = MagicMock()
        mock_pool.close = MagicMock()
        mock_pool.wait_closed = AsyncMock()
        db._pool = mock_pool
        await db.close_pool()
        mock_pool.close.assert_called_once()
        mock_pool.wait_closed.assert_awaited_once()
        assert db._pool is None

    @pytest.mark.asyncio
    async def test_close_pool_without_pool_is_noop(self):
        # Closing when no pool was ever set up must not raise.
        db = make_async_db("postgres")
        db._pool = None
        await db.close_pool()
        assert db._pool is None
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|