sqlspec-0.16.2-py3-none-any.whl → sqlspec-0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

@@ -0,0 +1,492 @@
1
+ """Efficient multi-connection pool for aiosqlite with proper shutdown handling."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import time
6
+ import uuid
7
+ from contextlib import asynccontextmanager, suppress
8
+ from typing import TYPE_CHECKING, Any, Optional, Union
9
+
10
+ import aiosqlite
11
+
12
+ from sqlspec.exceptions import SQLSpecError
13
+
14
+ if TYPE_CHECKING:
15
+ import threading
16
+ from collections.abc import AsyncGenerator
17
+
18
+ from sqlspec.adapters.aiosqlite._types import AiosqliteConnection
19
+
20
+ __all__ = (
21
+ "AiosqliteConnectTimeoutError",
22
+ "AiosqliteConnectionPool",
23
+ "AiosqlitePoolClosedError",
24
+ "AiosqlitePoolConnection",
25
+ )
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
class AiosqlitePoolClosedError(SQLSpecError):
    """Raised when an operation is attempted on a pool that has already been closed."""
32
+
33
+
34
class AiosqliteConnectTimeoutError(SQLSpecError):
    """Raised when no connection could be obtained before the configured timeout elapsed."""
36
+
37
+
38
class AiosqlitePoolConnection:
    """Wrapper for database connections with pool lifecycle management.

    Each wrapper carries a random hex ``id``, the raw ``connection``, the
    wall-clock timestamp at which it last became idle (``idle_since``, ``None``
    while checked out) and a private closed flag.
    """

    __slots__ = ("_closed", "connection", "id", "idle_since")

    def __init__(self, connection: "AiosqliteConnection") -> None:
        """Initialize pool connection wrapper.

        Args:
            connection: The raw aiosqlite connection
        """
        self.id = uuid.uuid4().hex
        self.connection = connection
        self.idle_since: Optional[float] = None
        self._closed = False

    @property
    def idle_time(self) -> float:
        """Get idle time in seconds.

        Returns:
            Idle time in seconds, 0.0 if connection is in use
        """
        if self.idle_since is None:
            return 0.0
        return time.time() - self.idle_since

    @property
    def is_closed(self) -> bool:
        """Check if connection is closed.

        Returns:
            True if connection is closed
        """
        return self._closed

    def mark_as_in_use(self) -> None:
        """Mark connection as in use."""
        self.idle_since = None

    def mark_as_idle(self) -> None:
        """Mark connection as idle."""
        self.idle_since = time.time()

    async def is_alive(self) -> bool:
        """Check if connection is alive and functional.

        Runs ``SELECT 1`` and closes the resulting cursor so the probe does not
        leave an open statement behind on the connection.

        Returns:
            True if connection is healthy, False if it is closed or the probe
            (including closing its cursor) raised.
        """
        if self._closed:
            return False
        try:
            cursor = await self.connection.execute("SELECT 1")
            await cursor.close()
        except Exception:
            return False
        return True

    async def reset(self) -> None:
        """Reset connection to clean state.

        Rolls back any open transaction; rollback errors are deliberately
        ignored (best effort). Does nothing once the wrapper is closed.
        """
        if self._closed:
            return
        with suppress(Exception):
            await self.connection.rollback()

    async def close(self) -> None:
        """Close the connection.

        Since we use daemon threads, the connection will be terminated
        when the process exits even if close fails. The wrapper is marked
        closed regardless of whether the underlying close succeeded.
        """
        if self._closed:
            return
        try:
            # A failed rollback must not prevent the close attempt.
            with suppress(Exception):
                await self.connection.rollback()
            await self.connection.close()
        except Exception:
            # Keep the traceback so close failures can actually be diagnosed.
            logger.debug("Error closing connection %s", self.id, exc_info=True)
        finally:
            self._closed = True
120
+
121
+
122
class AiosqliteConnectionPool:
    """Multi-connection pool for aiosqlite with proper shutdown handling.

    Idle connections wait in an ``asyncio.Queue``; every live connection
    (idle or checked out) is also recorded in a registry keyed by wrapper id.
    Acquisition first drains idle connections, then tries to create a new one
    if the registry is below ``pool_size``, and finally waits for a release.
    """

    __slots__ = (
        "_closed_event",
        "_connect_timeout",
        "_connection_parameters",
        "_connection_registry",
        "_idle_timeout",
        "_lock",
        "_operation_timeout",
        "_pool_size",
        "_queue",
        "_tracked_threads",
        "_wal_initialized",
    )

    def __init__(
        self,
        connection_parameters: "dict[str, Any]",
        pool_size: int = 5,
        connect_timeout: float = 30.0,
        idle_timeout: float = 24 * 60 * 60,  # 24 hours
        operation_timeout: float = 10.0,
    ) -> None:
        """Initialize connection pool.

        Args:
            connection_parameters: SQLite connection parameters
            pool_size: Maximum number of connections in the pool
            connect_timeout: Maximum time to wait for connection acquisition
            idle_timeout: Maximum time a connection can remain idle
            operation_timeout: Maximum time for connection operations
        """
        self._connection_parameters = connection_parameters
        self._pool_size = pool_size
        self._connect_timeout = connect_timeout
        self._idle_timeout = idle_timeout
        self._operation_timeout = operation_timeout

        self._queue: "asyncio.Queue[AiosqlitePoolConnection]" = asyncio.Queue(maxsize=pool_size)
        self._connection_registry: "dict[str, AiosqlitePoolConnection]" = {}
        self._lock = asyncio.Lock()
        self._closed_event = asyncio.Event()
        self._tracked_threads: "set[Union[threading.Thread, AiosqliteConnection]]" = set()
        self._wal_initialized = False

    @property
    def is_closed(self) -> bool:
        """Check if pool is closed.

        Returns:
            True if pool is closed
        """
        return self._closed_event.is_set()

    def size(self) -> int:
        """Get total number of connections in pool.

        Returns:
            Total connection count
        """
        return len(self._connection_registry)

    def checked_out(self) -> int:
        """Get number of checked out connections.

        Returns:
            Number of connections currently in use
        """
        return len(self._connection_registry) - self._queue.qsize()

    def _track_aiosqlite_thread(self, connection: "AiosqliteConnection") -> None:
        """Track the background thread associated with an aiosqlite connection.

        Args:
            connection: The aiosqlite connection whose thread to track
        """
        self._tracked_threads.add(connection)

    async def _create_connection(self) -> "AiosqlitePoolConnection":
        """Create a new connection with SQLite optimizations.

        The connection is closed again if its setup fails or is cancelled, so
        a failed provisioning attempt does not leave a worker thread behind.

        Returns:
            New pool connection instance
        """
        connection = aiosqlite.connect(**self._connection_parameters)
        connection.daemon = True
        connection = await connection

        # Detect database type for appropriate optimization
        database_path = str(self._connection_parameters.get("database", ""))
        is_shared_cache = "cache=shared" in database_path
        is_memory_db = ":memory:" in database_path or "mode=memory" in database_path

        try:
            try:
                if is_memory_db:
                    await connection.execute("PRAGMA journal_mode = MEMORY")
                    await connection.execute("PRAGMA synchronous = OFF")
                    await connection.execute("PRAGMA temp_store = MEMORY")
                    await connection.execute("PRAGMA cache_size = -16000")
                else:
                    await connection.execute("PRAGMA journal_mode = WAL")
                    await connection.execute("PRAGMA synchronous = NORMAL")

                await connection.execute("PRAGMA foreign_keys = ON")
                await connection.execute("PRAGMA busy_timeout = 30000")

                if is_shared_cache and is_memory_db:
                    await connection.execute("PRAGMA read_uncommitted = ON")

                await connection.commit()

                if is_shared_cache:
                    self._wal_initialized = True
                    logger.debug("Database optimized for shared cache (memory: %s)", is_memory_db)

            except Exception as e:
                # Fall back to the minimal settings the pool relies on.
                logger.warning("Failed to optimize connection: %s", e)
                await connection.execute("PRAGMA foreign_keys = ON")
                await connection.execute("PRAGMA busy_timeout = 30000")
                await connection.commit()
        except BaseException:
            # Setup failed (or was cancelled, e.g. by the acquire timeout):
            # release the already-opened connection before propagating.
            with suppress(Exception):
                await connection.close()
            raise

        pool_connection = AiosqlitePoolConnection(connection)
        pool_connection.mark_as_idle()
        self._track_aiosqlite_thread(connection)

        async with self._lock:
            self._connection_registry[pool_connection.id] = pool_connection

        logger.debug("Created new aiosqlite connection: %s", pool_connection.id)
        return pool_connection

    async def _claim_if_healthy(self, connection: "AiosqlitePoolConnection") -> bool:
        """Check if connection is healthy and claim it.

        A connection is retired (and not claimed) when it exceeded the idle
        timeout, when the health probe times out, or when the probe reports
        the connection as not alive.

        Args:
            connection: Connection to check and claim

        Returns:
            True if connection was successfully claimed
        """
        if connection.idle_time > self._idle_timeout:
            logger.debug("Connection %s exceeded idle timeout, retiring", connection.id)
            await self._retire_connection(connection)
            return False

        try:
            alive = await asyncio.wait_for(connection.is_alive(), timeout=self._operation_timeout)
        except asyncio.TimeoutError:
            logger.debug("Connection %s health check timed out, retiring", connection.id)
            await self._retire_connection(connection)
            return False

        # is_alive() reports failure via its return value, not an exception,
        # so the result must be inspected before handing the connection out.
        if not alive:
            logger.debug("Connection %s failed health check, retiring", connection.id)
            await self._retire_connection(connection)
            return False

        connection.mark_as_in_use()
        return True

    async def _retire_connection(self, connection: "AiosqlitePoolConnection") -> None:
        """Retire a connection from the pool.

        Removes the connection from the registry and closes it, giving up on
        the close after ``operation_timeout`` seconds.

        Args:
            connection: Connection to retire
        """
        async with self._lock:
            self._connection_registry.pop(connection.id, None)

        try:
            await asyncio.wait_for(connection.close(), timeout=self._operation_timeout)
        except asyncio.TimeoutError:
            logger.warning("Connection %s close timed out during retirement", connection.id)

    async def _try_provision_new_connection(self) -> "Optional[AiosqlitePoolConnection]":
        """Try to create a new connection if under capacity.

        Returns:
            New connection if successful, None if at capacity or if creation failed
        """
        # NOTE(review): the capacity check and the creation below are not atomic,
        # so concurrent callers can briefly exceed pool_size; release() retires
        # the surplus when the idle queue is already full.
        async with self._lock:
            if len(self._connection_registry) >= self._pool_size:
                return None

        try:
            connection = await self._create_connection()
        except Exception:
            logger.exception("Failed to create new connection")
            return None
        else:
            connection.mark_as_in_use()
            return connection

    async def _wait_for_healthy_connection(self) -> "AiosqlitePoolConnection":
        """Wait for a healthy connection to become available.

        Races a queue ``get`` against the pool-closed event. Both helper tasks
        are always cleaned up, including when this coroutine is cancelled
        (for example by the acquire timeout), so no orphaned ``get`` can later
        swallow a released connection.

        Returns:
            Available healthy connection

        Raises:
            AiosqlitePoolClosedError: If pool is closed while waiting
        """
        while True:
            get_connection_task = asyncio.create_task(self._queue.get())
            pool_closed_task = asyncio.create_task(self._closed_event.wait())
            handed_off = False

            try:
                done, _ = await asyncio.wait(
                    {get_connection_task, pool_closed_task}, return_when=asyncio.FIRST_COMPLETED
                )

                if pool_closed_task in done:
                    msg = "Pool closed during connection acquisition"
                    raise AiosqlitePoolClosedError(msg)

                connection = get_connection_task.result()
                # From here on _claim_if_healthy either claims or retires it.
                handed_off = True
                if await self._claim_if_healthy(connection):
                    return connection

            finally:
                for task in (get_connection_task, pool_closed_task):
                    if not task.done():
                        task.cancel()
                        with suppress(asyncio.CancelledError):
                            await task

                # The get may have completed just as we were cancelled; return
                # the dequeued connection to the idle queue instead of losing it.
                if (
                    not handed_off
                    and not self.is_closed
                    and get_connection_task.done()
                    and not get_connection_task.cancelled()
                    and get_connection_task.exception() is None
                ):
                    with suppress(asyncio.QueueFull):
                        self._queue.put_nowait(get_connection_task.result())

    async def _get_connection(self) -> "AiosqlitePoolConnection":
        """Run the three-phase connection acquisition cycle.

        Phase 1 drains idle connections, phase 2 tries to create a new one,
        phase 3 waits for a connection to be released.

        Returns:
            Available connection

        Raises:
            AiosqlitePoolClosedError: If pool is closed
        """
        if self.is_closed:
            msg = "Cannot acquire connection from closed pool"
            raise AiosqlitePoolClosedError(msg)

        while not self._queue.empty():
            connection = self._queue.get_nowait()
            if await self._claim_if_healthy(connection):
                return connection

        new_connection = await self._try_provision_new_connection()
        if new_connection is not None:
            return new_connection

        return await self._wait_for_healthy_connection()

    async def _wait_for_threads_to_terminate(self, timeout: float = 1.0) -> None:
        """Wait for all tracked aiosqlite connection threads to terminate.

        Since we use daemon threads, this is just a best-effort cleanup.
        The threads will terminate when the process exits regardless.

        Args:
            timeout: Maximum time to wait for thread termination in seconds
        """
        if not self._tracked_threads:
            return

        logger.debug("Waiting for %d aiosqlite connection threads to terminate...", len(self._tracked_threads))
        # Monotonic clock: the elapsed-time check must not be affected by
        # wall-clock adjustments.
        start_time = time.monotonic()

        dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
        self._tracked_threads -= dead_threads

        if not self._tracked_threads:
            logger.debug("All aiosqlite connection threads already terminated")
            return

        while self._tracked_threads and (time.monotonic() - start_time) < timeout:
            await asyncio.sleep(0.05)
            dead_threads = {t for t in self._tracked_threads if not t.is_alive()}
            self._tracked_threads -= dead_threads

        remaining_threads = len(self._tracked_threads)
        elapsed = time.monotonic() - start_time

        if remaining_threads > 0:
            logger.debug(
                "%d aiosqlite threads still running after %.2fs (daemon threads will terminate on exit)",
                remaining_threads,
                elapsed,
            )
        else:
            logger.debug("All aiosqlite connection threads terminated successfully in %.2fs", elapsed)

    async def acquire(self) -> "AiosqlitePoolConnection":
        """Acquire a connection from the pool.

        Returns:
            Available connection

        Raises:
            AiosqliteConnectTimeoutError: If acquisition times out
            AiosqlitePoolClosedError: If the pool is closed
        """
        try:
            connection = await asyncio.wait_for(self._get_connection(), timeout=self._connect_timeout)
            # Brief pause for shared-cache databases whose setup has not been
            # flagged as completed yet.
            if not self._wal_initialized and "cache=shared" in str(self._connection_parameters.get("database", "")):
                await asyncio.sleep(0.01)
        except asyncio.TimeoutError as e:
            msg = f"Connection acquisition timed out after {self._connect_timeout}s"
            raise AiosqliteConnectTimeoutError(msg) from e
        else:
            return connection

    async def release(self, connection: "AiosqlitePoolConnection") -> None:
        """Release a connection back to the pool.

        If the pool is closed the connection is retired instead. Connections
        not present in the registry are ignored with a warning. If resetting
        or re-queueing fails, the connection is retired.

        Args:
            connection: Connection to release
        """
        if self.is_closed:
            await self._retire_connection(connection)
            return

        if connection.id not in self._connection_registry:
            logger.warning("Attempted to release unknown connection: %s", connection.id)
            return

        try:
            await asyncio.wait_for(connection.reset(), timeout=self._operation_timeout)
            connection.mark_as_idle()
            self._queue.put_nowait(connection)
            logger.debug("Released connection back to pool: %s", connection.id)
        except Exception as e:
            logger.warning("Failed to reset connection %s during release: %s", connection.id, e)
            await self._retire_connection(connection)

    @asynccontextmanager
    async def get_connection(self) -> "AsyncGenerator[AiosqliteConnection, None]":
        """Get a connection with automatic release.

        Yields:
            Raw aiosqlite connection

        """
        connection = await self.acquire()
        try:
            yield connection.connection
        finally:
            await self.release(connection)

    async def close(self) -> None:
        """Close the connection pool gracefully.

        Marks the pool closed, empties the idle queue, closes every registered
        connection (idle or checked out) and then waits briefly for the
        tracked connection threads to finish. Calling it again is a no-op.
        """
        if self.is_closed:
            return
        self._closed_event.set()

        # Clear the queue
        while not self._queue.empty():
            self._queue.get_nowait()

        # Get all connections and clear registry
        async with self._lock:
            connections = list(self._connection_registry.values())
            self._connection_registry.clear()

        # Close all connections
        if connections:
            close_tasks = [asyncio.wait_for(conn.close(), timeout=self._operation_timeout) for conn in connections]
            results = await asyncio.gather(*close_tasks, return_exceptions=True)

            # Log any close errors
            for conn, result in zip(connections, results):
                if isinstance(result, Exception):
                    logger.warning("Error closing connection %s: %s", conn.id, result)

        await self._wait_for_threads_to_terminate(timeout=1.0)
        logger.debug("Aiosqlite connection pool closed successfully")
@@ -8,11 +8,13 @@ from sqlspec.adapters.duckdb.config import (
8
8
  DuckDBSecretConfig,
9
9
  )
10
10
  from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, DuckDBExceptionHandler, duckdb_statement_config
11
+ from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool
11
12
 
12
13
  __all__ = (
13
14
  "DuckDBConfig",
14
15
  "DuckDBConnection",
15
16
  "DuckDBConnectionParams",
17
+ "DuckDBConnectionPool",
16
18
  "DuckDBCursor",
17
19
  "DuckDBDriver",
18
20
  "DuckDBExceptionHandler",