sqlite-anyio 0.2.7__tar.gz → 0.2.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sqlite-anyio
3
- Version: 0.2.7
3
+ Version: 0.2.9
4
4
  Summary: Asynchronous client for SQLite using AnyIO
5
5
  Author: Alex Grönholm, David Brochart
6
6
  Author-email: Alex Grönholm <alex.gronholm@nextday.fi>, David Brochart <david.brochart@gmail.com>
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.11
35
35
  Classifier: Programming Language :: Python :: 3.12
36
36
  Classifier: Programming Language :: Python :: 3.13
37
37
  Classifier: Programming Language :: Python :: 3.14
38
- Requires-Dist: anyio>=4.0,<5.0
38
+ Requires-Dist: anyio>=4.14,<5.0
39
39
  Requires-Dist: typing-extensions>=4.15.0
40
40
  Requires-Python: >=3.10
41
41
  Project-URL: Source, https://github.com/davidbrochart/sqlite-anyio
@@ -6,7 +6,7 @@ build-backend = "uv_build"
6
6
  name = "sqlite-anyio"
7
7
  description = "Asynchronous client for SQLite using AnyIO"
8
8
  readme = "README.md"
9
- version = "0.2.7"
9
+ version = "0.2.9"
10
10
  authors = [
11
11
  {name = "Alex Grönholm", email = "alex.gronholm@nextday.fi"},
12
12
  {name = "David Brochart", email = "david.brochart@gmail.com"},
@@ -28,7 +28,7 @@ classifiers = [
28
28
  ]
29
29
  requires-python = ">= 3.10"
30
30
  dependencies = [
31
- "anyio >=4.0,<5.0",
31
+ "anyio >=4.14,<5.0",
32
32
  "typing-extensions >=4.15.0",
33
33
  ]
34
34
 
@@ -4,18 +4,86 @@ __all__ = ["connect", "Connection", "Cursor"]
4
4
 
5
5
  import sqlite3
6
6
  import sys
7
+ import threading
7
8
  from collections.abc import Callable, Sequence
8
- from functools import partial, update_wrapper
9
+ from functools import partial
9
10
  from logging import Logger, getLogger
10
11
  from types import TracebackType
11
- from typing import Any
12
+ from typing import Any, TypeVar
12
13
 
13
- from anyio import CapacityLimiter, to_thread
14
+ import anyio
15
+ from anyio import to_thread, from_thread
14
16
 
15
17
  if sys.version_info >= (3, 11):
16
- from typing import Self
17
- else:
18
- from typing_extensions import Self # pragma: nocover
18
+ from typing import Self, TypeVarTuple, Unpack
19
+ else: # pragma: nocover
20
+ from exceptiongroup import BaseExceptionGroup
21
+ from typing_extensions import Self, TypeVarTuple, Unpack
22
+
23
+ T_Retval = TypeVar("T_Retval")
24
+ PosArgsT = TypeVarTuple("PosArgsT")
25
+
26
+
27
+ async def _interruptible_dispatch(
28
+ self: Connection | Cursor,
29
+ func: Callable[[Unpack[PosArgsT]], T_Retval],
30
+ *args: Unpack[PosArgsT]
31
+ ) -> T_Retval:
32
+ if isinstance(self, Connection):
33
+ real_connection = self._real_connection
34
+ elif isinstance(self, Cursor):
35
+ real_connection = self._real_cursor.connection
36
+ else: # pragma: nocover
37
+ raise AssertionError("Unknown type:", self)
38
+
39
+ ev = anyio.Event()
40
+ lock = threading.Lock()
41
+ need_interrupt = False
42
+
43
+ async def cancel_detector() -> None:
44
+ try:
45
+ await ev.wait()
46
+ except anyio.get_cancelled_exc_class():
47
+ # Block progress in the thread while checking this flag.
48
+ # Our guard_interrupt thread only ever holds the lock briefly,
49
+ # so there's no risk of blocking the event loop.
50
+ with lock:
51
+ # Due to race conditions, the first calls to interrupt may be
52
+ # ignored. This race is quick so this loop should not cycle much.
53
+ while need_interrupt:
54
+ real_connection.interrupt()
55
+ await anyio.lowlevel.cancel_shielded_checkpoint()
56
+ # we do NOT re-raise the cancellation so that the task group
57
+ # does not swallow our retval. If a Cancelled is to propagate,
58
+ # it should come out of to_thread.run_sync
59
+
60
+ def guard_interrupt() -> T_Retval:
61
+ nonlocal need_interrupt
62
+
63
+ with lock:
64
+ from_thread.check_cancelled()
65
+ need_interrupt = True
66
+ try:
67
+ return func(*args)
68
+ except sqlite3.OperationalError as e:
69
+ if str(e) == "interrupted":
70
+ from_thread.check_cancelled()
71
+ raise
72
+ finally:
73
+ need_interrupt = False
74
+
75
+ try:
76
+ async with anyio.create_task_group() as g:
77
+ g.start_soon(cancel_detector)
78
+ retval = await to_thread.run_sync(guard_interrupt, limiter=self._limiter)
79
+ ev.set()
80
+ except BaseExceptionGroup as eg:
81
+ if len(eg.exceptions) == 1:
82
+ if isinstance(eg.exceptions[0], Exception):
83
+ raise eg.exceptions[0]
84
+ raise # pragma: nocover (would be an internal error that should fail other tests)
85
+
86
+ return retval
19
87
 
20
88
 
21
89
  class Connection:
@@ -28,7 +96,7 @@ class Connection:
28
96
  self._real_connection = _real_connection
29
97
  self._exception_handler = _exception_handler
30
98
  self._log = _log or getLogger(__name__)
31
- self._limiter = CapacityLimiter(1)
99
+ self._limiter = anyio.CapacityLimiter(1)
32
100
 
33
101
  async def __aenter__(self) -> Self:
34
102
  return self
@@ -40,38 +108,32 @@ class Connection:
40
108
  exc_tb: TracebackType | None,
41
109
  ) -> bool | None:
42
110
  if exc_val is None:
43
- await self.commit() # type: ignore[call-arg]
111
+ await self.commit()
44
112
  return None
45
113
 
46
114
  assert exc_type is not None
47
115
  assert exc_val is not None
48
116
  assert exc_tb is not None
49
- await self.rollback() # type: ignore[call-arg]
117
+ await self.rollback()
50
118
  exception_handled = False
51
119
  if self._exception_handler is not None:
52
120
  exception_handled = self._exception_handler(exc_type, exc_val, exc_tb, self._log)
53
121
  return exception_handled
54
122
 
55
123
  async def execute(self, sql: str, parameters: Sequence[Any] = (), /) -> Cursor:
56
- real_cursor = await to_thread.run_sync(self._real_connection.execute, sql, parameters, limiter=self._limiter)
124
+ real_cursor = await _interruptible_dispatch(self, self._real_connection.execute, sql, parameters)
57
125
  return Cursor(real_cursor, self._limiter, self._exception_handler, self._log)
58
126
 
59
- update_wrapper(execute, sqlite3.Connection.execute)
60
-
61
- async def close(self):
62
- return await to_thread.run_sync(self._real_connection.close, limiter=self._limiter)
63
-
64
- update_wrapper(close, sqlite3.Connection.close)
65
-
66
- async def commit(self):
67
- return await to_thread.run_sync(self._real_connection.commit, limiter=self._limiter)
68
-
69
- update_wrapper(commit, sqlite3.Connection.commit)
127
+ async def close(self) -> None:
128
+ with anyio.CancelScope(shield=True):
129
+ await to_thread.run_sync(self._real_connection.close, limiter=self._limiter)
70
130
 
71
- async def rollback(self):
72
- return await to_thread.run_sync(self._real_connection.rollback, limiter=self._limiter)
131
+ async def commit(self) -> None:
132
+ await _interruptible_dispatch(self, self._real_connection.commit)
73
133
 
74
- update_wrapper(rollback, sqlite3.Connection.rollback)
134
+ async def rollback(self) -> None:
135
+ with anyio.CancelScope(shield=True):
136
+ await to_thread.run_sync(self._real_connection.rollback, limiter=self._limiter)
75
137
 
76
138
  async def cursor(self, factory: Callable[[sqlite3.Connection], sqlite3.Cursor] = sqlite3.Cursor) -> Cursor:
77
139
  real_cursor = await to_thread.run_sync(self._real_connection.cursor, factory, limiter=self._limiter)
@@ -82,7 +144,7 @@ class Cursor:
82
144
  def __init__(
83
145
  self,
84
146
  real_cursor: sqlite3.Cursor,
85
- limiter: CapacityLimiter,
147
+ limiter: anyio.CapacityLimiter,
86
148
  _exception_handler: Callable[[type[BaseException], BaseException, TracebackType, Logger], bool] | None,
87
149
  _log: Logger,
88
150
  ) -> None:
@@ -125,42 +187,29 @@ class Cursor:
125
187
  return exception_handled
126
188
 
127
189
  async def close(self) -> None:
128
- await to_thread.run_sync(self._real_cursor.close, limiter=self._limiter)
129
-
130
- update_wrapper(close, sqlite3.Cursor.close)
190
+ with anyio.CancelScope(shield=True):
191
+ await to_thread.run_sync(self._real_cursor.close, limiter=self._limiter)
131
192
 
132
193
  async def execute(self, sql: str, parameters: Sequence[Any] = (), /) -> Cursor:
133
- real_cursor = await to_thread.run_sync(self._real_cursor.execute, sql, parameters, limiter=self._limiter)
134
- return Cursor(real_cursor, self._limiter, self._exception_handler, self._log)
135
-
136
- update_wrapper(execute, sqlite3.Cursor.execute)
194
+ await _interruptible_dispatch(self, self._real_cursor.execute, sql, parameters)
195
+ return self
137
196
 
138
197
  async def executemany(self, sql: str, parameters: Sequence[Any], /) -> Cursor:
139
- real_cursor = await to_thread.run_sync(self._real_cursor.executemany, sql, parameters, limiter=self._limiter)
140
- return Cursor(real_cursor, self._limiter, self._exception_handler, self._log)
141
-
142
- update_wrapper(executemany, sqlite3.Cursor.executemany)
198
+ await _interruptible_dispatch(self, self._real_cursor.executemany, sql, parameters)
199
+ return self
143
200
 
144
201
  async def executescript(self, sql_script: str, /) -> Cursor:
145
- real_cursor = await to_thread.run_sync(self._real_cursor.executescript, sql_script, limiter=self._limiter)
146
- return Cursor(real_cursor, self._limiter, self._exception_handler, self._log)
147
-
148
- update_wrapper(executescript, sqlite3.Cursor.executescript)
202
+ await _interruptible_dispatch(self, self._real_cursor.executescript, sql_script)
203
+ return self
149
204
 
150
205
  async def fetchone(self) -> tuple[Any, ...] | None:
151
- return await to_thread.run_sync(self._real_cursor.fetchone, limiter=self._limiter)
152
-
153
- update_wrapper(fetchone, sqlite3.Cursor.fetchone)
206
+ return await _interruptible_dispatch(self, self._real_cursor.fetchone)
154
207
 
155
208
  async def fetchmany(self, size: int) -> list[tuple[Any, ...]]:
156
- return await to_thread.run_sync(self._real_cursor.fetchmany, size, limiter=self._limiter)
157
-
158
- update_wrapper(fetchmany, sqlite3.Cursor.fetchmany)
209
+ return await _interruptible_dispatch(self, self._real_cursor.fetchmany, size)
159
210
 
160
211
  async def fetchall(self) -> list[tuple[Any, ...]]:
161
- return await to_thread.run_sync(self._real_cursor.fetchall, limiter=self._limiter)
162
-
163
- update_wrapper(fetchall, sqlite3.Cursor.fetchall)
212
+ return await _interruptible_dispatch(self, self._real_cursor.fetchall)
164
213
 
165
214
 
166
215
  async def connect(
File without changes
File without changes