sqla-fancy-core 1.0.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqla-fancy-core might be problematic. Click here for more details.

@@ -0,0 +1,439 @@
1
+ """Some wrappers for fun times with SQLAlchemy core."""
2
+
3
+ from contextlib import asynccontextmanager, contextmanager
4
+ from contextvars import ContextVar
5
+ from typing import Any, Optional, TypeVar, overload
6
+
7
+ from sqlalchemy import Connection, CursorResult, Engine, Executable
8
+ from sqlalchemy.engine.interfaces import (
9
+ CoreExecuteOptionsParameter,
10
+ _CoreAnyExecuteParams,
11
+ )
12
+ from sqlalchemy.ext.asyncio import (
13
+ AsyncConnection,
14
+ AsyncEngine,
15
+ )
16
+ from sqlalchemy.sql.selectable import TypedReturnsRows
17
+
18
+ _T = TypeVar("_T", bound=Any)  # row type carried from TypedReturnsRows[_T] to CursorResult[_T] in the overloads below
19
+
20
+
21
class FancyError(Exception):
    """Base exception for errors raised by this module's engine wrappers."""
25
+
26
+
27
class AtomicContextError(FancyError):
    """Signals that ``ax()`` was invoked with no ``atomic()`` block active."""

    def __init__(self) -> None:
        # The message is fixed; callers construct this error without arguments.
        message = "ax() must be called within the atomic() context manager"
        super().__init__(message)
32
+
33
+
34
class FancyEngineWrapper:
    """Convenience wrapper around a synchronous SQLAlchemy ``Engine``.

    Adds an ``atomic()`` context manager that publishes its transactional
    connection through a ``ContextVar``, plus four short execution helpers:

    * ``x``   - run on the given connection, or on a freshly opened one.
    * ``tx``  - run inside a transaction on the given connection, the atomic
      connection, or a freshly opened one.
    * ``ax``  - run on the atomic connection only; raise if there is none.
    * ``atx`` - run on the atomic connection, or in a fresh transaction.

    NOTE(review): ``_ATOMIC_TX_CONN`` is a class attribute, so every instance
    shares it. An ``atomic()`` block opened through one wrapper is therefore
    picked up by ``atomic()``, ``ax()``, ``tx()`` and ``atx()`` on any other
    wrapper running in the same context, including a wrapper built around a
    different engine. Confirm this is intended before using several engines
    side by side.

    NOTE(review): helpers that open their own connection return the result
    after that connection's ``with`` block has exited. Whether rows can still
    be fetched from such a result depends on the driver; verify for your
    backend.
    """

    # Connection owned by the outermost active ``atomic()`` block in the
    # current context, or None when no such block is active.
    _ATOMIC_TX_CONN: ContextVar[Optional[Connection]] = ContextVar(  # type: ignore
        "fancy_global_transaction", default=None
    )

    def __init__(self, engine: Engine) -> None:
        """Store the engine that the helpers open connections from."""
        self.engine = engine

    @contextmanager
    def atomic(self):
        """Yield a transactional connection and register it for the block.

        If an atomic connection is already registered in the current context
        it is yielded unchanged, so nested blocks join the outer one and do
        not commit or close anything themselves. Otherwise a connection is
        opened with ``Engine.begin()``, registered while the block runs and
        unregistered on exit.
        """
        active = self._ATOMIC_TX_CONN.get()
        if active is not None:
            # Nested use: join the transaction owned by the outer block.
            yield active
            return

        with self.engine.begin() as connection:
            token = self._ATOMIC_TX_CONN.set(connection)
            try:
                yield connection
            finally:
                # Put the ContextVar back to what it held before this block.
                self._ATOMIC_TX_CONN.reset(token)

    @overload
    def ax(
        self,
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    def ax(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    def ax(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on the connection of the active ``atomic()`` block.

        Raises:
            AtomicContextError: if no atomic connection is registered in the
                current context.
        """
        connection = self._ATOMIC_TX_CONN.get()
        if connection is None:
            raise AtomicContextError()
        return connection.execute(
            statement, parameters, execution_options=execution_options
        )

    @overload
    def x(
        self,
        connection: Optional[Connection],
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    def x(
        self,
        connection: Optional[Connection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    def x(
        self,
        connection: Optional[Connection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on ``connection``, or on a new connection.

        When ``connection`` is None a connection is opened with
        ``Engine.connect()`` for this single call. The atomic connection is
        not consulted by this helper.
        """
        if connection is not None:
            return connection.execute(
                statement, parameters, execution_options=execution_options
            )

        with self.engine.connect() as new_connection:
            return new_connection.execute(
                statement, parameters, execution_options=execution_options
            )

    @overload
    def tx(
        self,
        connection: Optional[Connection],
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    def tx(
        self,
        connection: Optional[Connection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    def tx(
        self,
        connection: Optional[Connection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` inside a transaction.

        The connection is chosen in this order: the ``connection`` argument,
        the atomic connection of the current context, a new connection opened
        with ``Engine.begin()``. On an existing connection a transaction is
        started with ``Connection.begin()`` only when
        ``Connection.in_transaction()`` reports that none is active.
        """
        if connection is None:
            connection = self._ATOMIC_TX_CONN.get()

        if connection is None:
            # Nothing to reuse: run in a one-off transaction.
            with self.engine.begin() as new_connection:
                return new_connection.execute(
                    statement, parameters, execution_options=execution_options
                )

        if connection.in_transaction():
            # Join the transaction that is already active.
            return connection.execute(
                statement, parameters, execution_options=execution_options
            )

        with connection.begin():
            return connection.execute(
                statement, parameters, execution_options=execution_options
            )

    @overload
    def atx(
        self,
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    def atx(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    def atx(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on the atomic connection, else in a new transaction.

        Unlike ``ax()`` this never raises for a missing atomic context: with
        no atomic connection registered, a connection is opened with
        ``Engine.begin()`` for this single call.
        """
        connection = self._ATOMIC_TX_CONN.get()
        if connection is not None:
            return connection.execute(
                statement, parameters, execution_options=execution_options
            )

        with self.engine.begin() as new_connection:
            return new_connection.execute(
                statement, parameters, execution_options=execution_options
            )
220
+
221
+
222
class AsyncFancyEngineWrapper:
    """Convenience wrapper around a SQLAlchemy ``AsyncEngine``.

    Adds an async ``atomic()`` context manager that publishes its
    transactional connection through a ``ContextVar``, plus four short
    awaitable execution helpers:

    * ``x``   - run on the given connection, or on a freshly opened one.
    * ``tx``  - run inside a transaction on the given connection, the atomic
      connection, or a freshly opened one.
    * ``ax``  - run on the atomic connection only; raise if there is none.
    * ``atx`` - run on the atomic connection, or in a fresh transaction.

    NOTE(review): ``_ATOMIC_TX_CONN`` is a class attribute, so every instance
    shares it. An ``atomic()`` block opened through one wrapper is therefore
    picked up by ``atomic()``, ``ax()``, ``tx()`` and ``atx()`` on any other
    wrapper running in the same context, including a wrapper built around a
    different engine. Confirm this is intended before using several engines
    side by side.

    NOTE(review): helpers that open their own connection return the result
    after that connection's ``async with`` block has exited; verify that the
    result is still usable for your backend.
    """

    # Connection owned by the outermost active ``atomic()`` block in the
    # current context, or None when no such block is active.
    _ATOMIC_TX_CONN: ContextVar[Optional[AsyncConnection]] = ContextVar(  # type: ignore
        "fancy_global_transaction", default=None
    )

    def __init__(self, engine: AsyncEngine) -> None:
        """Store the async engine that the helpers open connections from."""
        self.engine = engine

    @asynccontextmanager
    async def atomic(self):
        """Yield a transactional connection and register it for the block.

        If an atomic connection is already registered in the current context
        it is yielded unchanged, so nested blocks join the outer one and do
        not commit or close anything themselves. Otherwise a connection is
        opened with ``AsyncEngine.begin()``, registered while the block runs
        and unregistered on exit.
        """
        active = self._ATOMIC_TX_CONN.get()
        if active is not None:
            # Nested use: join the transaction owned by the outer block.
            yield active
            return

        async with self.engine.begin() as connection:
            token = self._ATOMIC_TX_CONN.set(connection)
            try:
                yield connection
            finally:
                # Put the ContextVar back to what it held before this block.
                self._ATOMIC_TX_CONN.reset(token)

    @overload
    async def ax(
        self,
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    async def ax(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    async def ax(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on the connection of the active ``atomic()`` block.

        Raises:
            AtomicContextError: if no atomic connection is registered in the
                current context.
        """
        connection = self._ATOMIC_TX_CONN.get()
        if connection is None:
            raise AtomicContextError()
        return await connection.execute(
            statement, parameters, execution_options=execution_options
        )

    @overload
    async def x(
        self,
        connection: Optional[AsyncConnection],
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    async def x(
        self,
        connection: Optional[AsyncConnection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    async def x(
        self,
        connection: Optional[AsyncConnection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on ``connection``, or on a new connection.

        When ``connection`` is None a connection is opened with
        ``AsyncEngine.connect()`` for this single call. The atomic connection
        is not consulted by this helper.
        """
        if connection is not None:
            return await connection.execute(
                statement, parameters, execution_options=execution_options
            )

        async with self.engine.connect() as new_connection:
            return await new_connection.execute(
                statement, parameters, execution_options=execution_options
            )

    @overload
    async def tx(
        self,
        connection: Optional[AsyncConnection],
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    async def tx(
        self,
        connection: Optional[AsyncConnection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    async def tx(
        self,
        connection: Optional[AsyncConnection],
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` inside a transaction.

        The connection is chosen in this order: the ``connection`` argument,
        the atomic connection of the current context, a new connection opened
        with ``AsyncEngine.begin()``. On an existing connection a transaction
        is started with ``AsyncConnection.begin()`` only when
        ``AsyncConnection.in_transaction()`` reports that none is active.
        """
        if connection is None:
            connection = self._ATOMIC_TX_CONN.get()

        if connection is None:
            # Nothing to reuse: run in a one-off transaction.
            async with self.engine.begin() as new_connection:
                return await new_connection.execute(
                    statement, parameters, execution_options=execution_options
                )

        if connection.in_transaction():
            # Join the transaction that is already active.
            return await connection.execute(
                statement, parameters, execution_options=execution_options
            )

        async with connection.begin():
            return await connection.execute(
                statement, parameters, execution_options=execution_options
            )

    @overload
    async def atx(
        self,
        statement: TypedReturnsRows[_T],
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[_T]: ...
    @overload
    async def atx(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]: ...
    async def atx(
        self,
        statement: Executable,
        parameters: Optional[_CoreAnyExecuteParams] = None,
        *,
        execution_options: Optional[CoreExecuteOptionsParameter] = None,
    ) -> CursorResult[Any]:
        """Execute ``statement`` on the atomic connection, else in a new transaction.

        Unlike ``ax()`` this never raises for a missing atomic context: with
        no atomic connection registered, a connection is opened with
        ``AsyncEngine.begin()`` for this single call.
        """
        connection = self._ATOMIC_TX_CONN.get()
        if connection is not None:
            return await connection.execute(
                statement, parameters, execution_options=execution_options
            )

        async with self.engine.begin() as new_connection:
            return await new_connection.execute(
                statement, parameters, execution_options=execution_options
            )
404
+
405
+
406
@overload
def fancy(obj: Engine, /) -> FancyEngineWrapper: ...
@overload
def fancy(obj: AsyncEngine, /) -> AsyncFancyEngineWrapper: ...
def fancy(obj, /):
    """Wrap an engine in the matching fancy wrapper.

    An ``Engine`` gives a ``FancyEngineWrapper`` and an ``AsyncEngine`` gives
    an ``AsyncFancyEngineWrapper``; any other argument raises ``TypeError``.

    Example: ::

        import sqlalchemy as sa

        fancy_engine = fancy(sa.create_engine("sqlite:///:memory:"))

        def handler(conn: sa.Connection | None = None):
            # Execute without an explicit transaction
            result = fancy_engine.x(conn, sa.select(...))

            # Execute inside a transaction
            result = fancy_engine.tx(conn, sa.insert(...))

        # Pass a connection explicitly:
        with fancy_engine.engine.connect() as conn:
            handler(conn=conn)

        # Let a dependency injection system supply the connection:
        handler(conn=dependency(transaction))

        # Or call it with no connection at all (e.g. in an IPython shell):
        handler()
    """
    # Checked in order: the sync engine type first, then the async one.
    wrappers = (
        (Engine, FancyEngineWrapper),
        (AsyncEngine, AsyncFancyEngineWrapper),
    )
    for engine_type, wrapper_cls in wrappers:
        if isinstance(obj, engine_type):
            return wrapper_cls(obj)
    raise TypeError("Unsupported input type for fancy()")