sqlobjects 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlobjects/session.py ADDED
@@ -0,0 +1,389 @@
1
+ import contextvars
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import AsyncGenerator
4
+ from contextlib import asynccontextmanager
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from sqlalchemy.exc import SQLAlchemyError
8
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
9
+
10
+ from .exceptions import convert_sqlalchemy_error
11
+
12
+
13
+ if TYPE_CHECKING:
14
+ from .database import Database # noqa
15
+
16
+
17
__all__ = ["AsyncSession", "SessionContextManager", "ctx_session", "ctx_sessions"]


# Sessions registered explicitly for the current context (by ctx_session /
# ctx_sessions / SessionContextManager.set_session), keyed by database name.
# They take precedence over newly created sessions in get_session().
# The variable has no default, so readers call .get({}) to obtain an empty mapping.
_explicit_sessions: contextvars.ContextVar[dict[str, "AsyncSession"]] = contextvars.ContextVar(
    "explicit_sessions"
)
27
+
28
+
29
class AsyncSession(ABC):
    """Interface implemented by every database session in this module.

    Never instantiate this class yourself: obtain sessions through ctx_session()
    or SessionContextManager.get_session().
    """

    @property
    @abstractmethod
    def db_name(self) -> str:
        """Name of the database the session is bound to."""
        ...

    @abstractmethod
    async def execute(self, statement: Any, parameters: Any = None) -> Any:
        """Run *statement* (with optional *parameters*) and return its result."""
        ...

    @abstractmethod
    async def stream(self, statement: Any, parameters: Any = None) -> Any:
        """Run *statement* and return a streaming (server-side cursor) result."""
        ...

    @abstractmethod
    async def commit(self) -> None:
        """Commit the current transaction, if there is one and the session is writable."""
        ...

    @abstractmethod
    async def rollback(self) -> None:
        """Roll back the current transaction, if there is one and the session is writable."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Release the session's transaction and connection."""
        ...

    @abstractmethod
    async def __aenter__(self) -> "AsyncSession":
        """Enter the ``async with`` block."""
        ...

    @abstractmethod
    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Leave the ``async with`` block, cleaning up the session."""
        ...
76
+
77
+
78
class _AsyncSession(AsyncSession):
    """Internal AsyncSession implementation with lazy connection and transaction management.

    Two main usage patterns:
    1. Explicit sessions (ctx_session): auto_commit=False, manual transaction control.
    2. Implicit sessions (get_session): auto_commit=True, so each execute() commits
       and then closes the connection.

    NOTE: auto_commit is forced to False for readonly sessions. Such sessions
    therefore keep their connection open until close() is called, for example
    by leaving an ``async with`` block.
    """

    def __init__(self, db_name: str, readonly: bool = True, auto_commit: bool = True):
        """Initialize the session; the connection is opened lazily on first use.

        Args:
            db_name: Key into SessionContextManager.engines.
            readonly: True for readonly (no explicit transaction), False for transactional.
            auto_commit: True to commit and close after each execute()
                (ignored if readonly=True).
        """
        self._db_name = db_name
        self.readonly = readonly
        # Readonly sessions never auto-commit, and therefore are never auto-closed.
        self.auto_commit = auto_commit and not readonly
        self._conn: AsyncConnection | None = None
        self._trans: Any = None

    @property
    def db_name(self) -> str:
        """Database name for this session."""
        return self._db_name

    async def _begin_if_needed(self) -> None:
        """Start a transaction on writable sessions that do not have one yet."""
        if not self.readonly and self._trans is None:
            self._trans = await self._conn.begin()  # type: ignore[union-attr]

    async def execute(self, statement: Any, parameters: Any = None) -> Any:
        """Execute *statement* with automatic transaction management.

        SQLAlchemy errors are converted with convert_sqlalchemy_error after a rollback.
        Implicit (auto_commit) sessions commit on success and close the connection
        afterwards, whether or not the statement succeeded.
        """
        await self._ensure_connection()
        try:
            # BEGIN is inside the try so that a failing BEGIN still gets
            # rollback, error conversion and (for implicit sessions) closing.
            await self._begin_if_needed()
            result = await self._conn.execute(statement, parameters)  # type: ignore[union-attr]

            if self.auto_commit:
                await self.commit()

            return result
        except SQLAlchemyError as e:
            await self.rollback()
            raise convert_sqlalchemy_error(e) from e
        except Exception:
            await self.rollback()
            raise
        finally:
            # Implicit sessions release their connection after every operation.
            if self.auto_commit:
                await self.close()

    async def stream(self, statement: Any, parameters: Any = None) -> Any:
        """Execute *statement* and return a streaming result.

        The connection must stay open while the result is consumed. Nothing in this
        class closes it automatically when the stream is exhausted, so the caller is
        responsible for closing the session (e.g. via ``async with``). Implicit
        sessions are only closed here when an error occurs.
        """
        await self._ensure_connection()
        try:
            await self._begin_if_needed()
            return await self._conn.stream(statement, parameters)  # type: ignore[union-attr]
        except SQLAlchemyError as e:
            await self._discard_after_error()
            raise convert_sqlalchemy_error(e) from e
        except Exception:
            await self._discard_after_error()
            raise

    async def _discard_after_error(self) -> None:
        """Roll back after a failure; implicit sessions also release the connection."""
        try:
            await self.rollback()
        finally:
            if self.auto_commit:
                await self.close()

    async def commit(self) -> None:
        """Commit transaction if exists and not readonly."""
        if self._trans and not self.readonly:
            await self._trans.commit()
            self._trans = None

    async def rollback(self) -> None:
        """Roll back the transaction if one exists and the session is not readonly.

        The transaction reference is dropped even if the rollback itself fails, so a
        later close() does not retry it.
        """
        if self._trans and not self.readonly:
            trans, self._trans = self._trans, None
            await trans.rollback()

    async def close(self) -> None:
        """Roll back any open transaction and close the connection.

        The connection is closed even if the transaction rollback raises.
        """
        trans, self._trans = self._trans, None
        conn, self._conn = self._conn, None
        try:
            if trans:
                await trans.rollback()
        finally:
            if conn:
                await conn.close()

    async def _ensure_connection(self) -> None:
        """Open a connection from the registered engine if none is open (lazy initialization)."""
        if self._conn is None:
            engine = SessionContextManager.engines[self._db_name]
            self._conn = await engine.connect()

    async def __aenter__(self) -> "AsyncSession":
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit: always close the session."""
        await self.close()

    def __getattr__(self, name: str) -> Any:
        """Proxy unknown attributes to the open AsyncConnection, if any."""
        # Read _conn through __dict__: plain attribute access would re-enter
        # __getattr__ (infinite recursion) when _conn has not been set yet,
        # e.g. on instances created via __new__ by copy/pickle.
        conn = self.__dict__.get("_conn")
        if conn and hasattr(conn, name):
            return getattr(conn, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
205
+
206
+
207
class SessionContextManager:
    """Multi-database session context manager with readonly optimization.

    Keeps a class-level registry of AsyncEngine objects keyed by database name and
    hands out sessions built on SQLAlchemy Core. Sessions registered explicitly in
    the current context (ctx_session, ctx_sessions, set_session) take priority over
    newly created ones.

    Examples:
        >>> # Set database engine
        >>> SessionContextManager.set_engine(engine, "main", is_default=True)
        >>> # Get readonly session (optimized for SELECT)
        >>> async with SessionContextManager.get_session(readonly=True) as session:
        ...     result = await session.execute(text("SELECT 1"))
        >>> # Get transactional session with auto-commit
        >>> async with SessionContextManager.get_session(readonly=False) as session:
        ...     await session.execute(text("UPDATE users SET status='active'"))
    """

    # Class-level registry: database name -> engine.
    engines: dict[str, AsyncEngine] = {}
    # Name used when callers pass db_name=None; the first set_engine() call fills it.
    default_db: str | None = None

    @classmethod
    def _resolve_name(cls, db_name: str | None) -> str:
        """Return *db_name*, else the default database, else the literal "default"."""
        return db_name or cls.default_db or "default"

    @classmethod
    def set_engine(cls, engine: AsyncEngine, db_name: str = "default", is_default: bool = False) -> None:
        """Register *engine* under *db_name*; the first engine registered becomes the default."""
        cls.engines[db_name] = engine
        if is_default or cls.default_db is None:
            cls.default_db = db_name

    @classmethod
    def get_session(cls, db_name: str | None = None, readonly: bool = True, auto_commit: bool = True) -> AsyncSession:
        """Get database session with readonly optimization.

        Args:
            db_name: Database name (uses default database if None)
            readonly: True for readonly (no transaction), False for transactional
            auto_commit: True to auto-commit after each operation (ignored if readonly=True)

        Returns:
            AsyncSession instance

        Priority:
            1. Explicitly set sessions (ctx_session, ctx_sessions)
            2. Create new AsyncSession with specified parameters
        """
        name = cls._resolve_name(db_name)

        # Priority 1: a session registered explicitly in the current context.
        # ContextVar.get(default) never raises LookupError, so no try/except is needed.
        explicit_sessions = _explicit_sessions.get({})
        if name in explicit_sessions:
            return explicit_sessions[name]

        # Priority 2: a brand-new session.
        return _AsyncSession(name, readonly, auto_commit)

    @classmethod
    def set_session(cls, session: AsyncSession, db_name: str | None = None) -> None:
        """Register *session* for *db_name* (or the default database) in the current context."""
        name = cls._resolve_name(db_name)
        # Build a new mapping instead of mutating the current one, so the dict seen
        # by other contexts (which may share the same object) stays untouched.
        _explicit_sessions.set({**_explicit_sessions.get({}), name: session})

    @classmethod
    def set_default(cls, db_name: str) -> None:
        """Make *db_name* the default database.

        Raises:
            RuntimeError: if no engine is registered under *db_name*.
        """
        if db_name not in cls.engines:
            raise RuntimeError(f"Database '{db_name}' is not initialized")
        cls.default_db = db_name

    @classmethod
    def clear_session(cls, db_name: str | None = None) -> None:
        """Remove the session for *db_name* from the current context; a falsy name clears all."""
        if not db_name:
            _explicit_sessions.set({})
            return

        current_sessions = _explicit_sessions.get({})
        if db_name in current_sessions:
            new_sessions = current_sessions.copy()
            del new_sessions[db_name]
            _explicit_sessions.set(new_sessions)

    # DatabaseObserver protocol implementation
    @classmethod
    def on_database_added(cls, name: str, database: "Database", is_default: bool) -> None:
        """Register the database's engine when a database is added."""
        cls.set_engine(database.engine, name, is_default)

    @classmethod
    def on_database_closed(cls, name: str) -> None:
        """Forget the engine of a closed database.

        NOTE(review): default_db is left unchanged here; presumably the owner calls
        on_default_changed() when the closed database was the default — confirm.
        """
        cls.engines.pop(name, None)

    @classmethod
    def on_default_changed(cls, old_default: str | None, new_default: str | None) -> None:
        """Follow the owner's change of default database."""
        cls.default_db = new_default
313
+
314
+
315
@asynccontextmanager
async def ctx_session(db_name: str | None = None) -> AsyncGenerator[AsyncSession, None]:
    """Get async context manager for a single-database transactional session.

    Creates a transactional session with manual commit control (auto_commit=False)
    and registers it in the current context, so get_session() returns it inside the
    block. The transaction is committed on successful exit and rolled back on
    exception.

    Nesting is supported: any session that an enclosing ctx_session/ctx_sessions
    registered for the same database is restored on exit, rather than being
    removed from the context.

    Args:
        db_name: Database name (uses default database if None)

    Yields:
        AsyncSession: Transactional session with manual commit control
    """
    name = db_name or SessionContextManager.default_db or "default"
    # Remember the session an enclosing context registered for this database.
    previous = _explicit_sessions.get({}).get(name)
    session = _AsyncSession(name, readonly=False, auto_commit=False)

    # Set as explicit session in context
    SessionContextManager.set_session(session, name)

    try:
        yield session
        # Auto-commit on successful exit
        await session.commit()
    except Exception:
        # Auto-rollback on exception
        await session.rollback()
        raise
    finally:
        try:
            await session.close()
        finally:
            # Restore the enclosing context even if close() failed.
            if previous is not None:
                SessionContextManager.set_session(previous, name)
            else:
                SessionContextManager.clear_session(name)
346
+
347
+
348
@asynccontextmanager
async def ctx_sessions(*db_names: str) -> AsyncGenerator[dict[str, AsyncSession], None]:
    """Get async context manager for multiple-database transactional sessions.

    Creates one transactional session (manual commit control) per database name and
    registers each in the current context. On successful exit every session is
    committed in order; on exception every session is rolled back. This is not a
    two-phase commit: if a later commit fails, earlier ones stay committed.

    Sessions that enclosing contexts registered for the same databases are restored
    on exit. Every session is closed even if closing another one fails; the first
    close() error is re-raised afterwards.

    Args:
        *db_names: Database names (duplicates are ignored)

    Yields:
        dict[str, AsyncSession]: Dictionary mapping database names to sessions

    Raises:
        ValueError: if no database name is given.
    """
    if not db_names:
        raise ValueError("At least one database name must be provided")

    # Snapshot sessions registered by enclosing contexts so they can be restored.
    previous = dict(_explicit_sessions.get({}))
    sessions: dict[str, AsyncSession] = {}

    try:
        # dict.fromkeys de-duplicates while keeping order, so a repeated name
        # does not create a session that is immediately orphaned.
        for db_name in dict.fromkeys(db_names):
            session = _AsyncSession(db_name, readonly=False, auto_commit=False)
            sessions[db_name] = session
            SessionContextManager.set_session(session, db_name)

        yield sessions

        # Auto-commit all sessions on successful exit
        for session in sessions.values():
            await session.commit()

    except Exception:
        # Auto-rollback all sessions on exception
        for session in sessions.values():
            await session.rollback()
        raise

    finally:
        # Close every session and restore the context, even if some close() fails.
        first_error: Exception | None = None
        for db_name, session in sessions.items():
            try:
                await session.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
            finally:
                if db_name in previous:
                    SessionContextManager.set_session(previous[db_name], db_name)
                else:
                    SessionContextManager.clear_session(db_name)
        if first_error is not None:
            raise first_error