diracx-db 0.0.1a22__py3-none-any.whl → 0.0.1a24__py3-none-any.whl

@@ -1,453 +1,24 @@
  from __future__ import annotations
 
- __all__ = ("utcnow", "Column", "NullColumn", "DateNowColumn", "BaseSQLDB")
-
- import contextlib
- import logging
- import os
- import re
- from abc import ABCMeta
- from collections.abc import AsyncIterator
- from contextvars import ContextVar
- from datetime import datetime, timedelta, timezone
- from functools import partial
- from typing import TYPE_CHECKING, Self, cast
-
- import sqlalchemy.types as types
- from pydantic import TypeAdapter
- from sqlalchemy import Column as RawColumn
- from sqlalchemy import DateTime, Enum, MetaData, func, select
- from sqlalchemy.exc import OperationalError
- from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
- from sqlalchemy.ext.compiler import compiles
- from sqlalchemy.sql import expression
-
- from diracx.core.exceptions import InvalidQueryError
- from diracx.core.extensions import select_from_extension
- from diracx.core.models import SortDirection
- from diracx.core.settings import SqlalchemyDsn
- from diracx.db.exceptions import DBUnavailable
-
- if TYPE_CHECKING:
-     from sqlalchemy.types import TypeEngine
-
- logger = logging.getLogger(__name__)
-
-
- class utcnow(expression.FunctionElement):
-     type: TypeEngine = DateTime()
-     inherit_cache: bool = True
-
-
- @compiles(utcnow, "postgresql")
- def pg_utcnow(element, compiler, **kw) -> str:
-     return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
-
-
- @compiles(utcnow, "mssql")
- def ms_utcnow(element, compiler, **kw) -> str:
-     return "GETUTCDATE()"
-
-
- @compiles(utcnow, "mysql")
- def mysql_utcnow(element, compiler, **kw) -> str:
-     return "(UTC_TIMESTAMP)"
-
-
- @compiles(utcnow, "sqlite")
- def sqlite_utcnow(element, compiler, **kw) -> str:
-     return "DATETIME('now')"
-
-
- class date_trunc(expression.FunctionElement):
-     """Sqlalchemy function to truncate a date to a given resolution.
-
-     Primarily used to be able to query for a specific resolution of a date e.g.
-
-         select * from table where date_trunc('day', date_column) = '2021-01-01'
-         select * from table where date_trunc('year', date_column) = '2021'
-         select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
-     """
-
-     type = DateTime()
-     inherit_cache = True
-
-     def __init__(self, *args, time_resolution, **kwargs) -> None:
-         super().__init__(*args, **kwargs)
-         self._time_resolution = time_resolution
-
-
- @compiles(date_trunc, "postgresql")
- def pg_date_trunc(element, compiler, **kw):
-     res = {
-         "SECOND": "second",
-         "MINUTE": "minute",
-         "HOUR": "hour",
-         "DAY": "day",
-         "MONTH": "month",
-         "YEAR": "year",
-     }[element._time_resolution]
-     return f"date_trunc('{res}', {compiler.process(element.clauses)})"
-
-
- @compiles(date_trunc, "mysql")
- def mysql_date_trunc(element, compiler, **kw):
-     pattern = {
-         "SECOND": "%Y-%m-%d %H:%i:%S",
-         "MINUTE": "%Y-%m-%d %H:%i",
-         "HOUR": "%Y-%m-%d %H",
-         "DAY": "%Y-%m-%d",
-         "MONTH": "%Y-%m",
-         "YEAR": "%Y",
-     }[element._time_resolution]
-
-     (dt_col,) = list(element.clauses)
-     return compiler.process(func.date_format(dt_col, pattern))
-
-
- @compiles(date_trunc, "sqlite")
- def sqlite_date_trunc(element, compiler, **kw):
-     pattern = {
-         "SECOND": "%Y-%m-%d %H:%M:%S",
-         "MINUTE": "%Y-%m-%d %H:%M",
-         "HOUR": "%Y-%m-%d %H",
-         "DAY": "%Y-%m-%d",
-         "MONTH": "%Y-%m",
-         "YEAR": "%Y",
-     }[element._time_resolution]
-     (dt_col,) = list(element.clauses)
-     return compiler.process(
-         func.strftime(
-             pattern,
-             dt_col,
-         )
-     )
-
-
- def substract_date(**kwargs: float) -> datetime:
-     return datetime.now(tz=timezone.utc) - timedelta(**kwargs)
-
-
- Column: partial[RawColumn] = partial(RawColumn, nullable=False)
- NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True)
- DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow())
-
-
- def EnumColumn(enum_type, **kwargs):
-     return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)
-
-
- class EnumBackedBool(types.TypeDecorator):
-     """Maps an ``EnumBackedBool()`` column to True/False in Python."""
-
-     impl = types.Enum
-     cache_ok: bool = True
-
-     def __init__(self) -> None:
-         super().__init__("True", "False")
-
-     def process_bind_param(self, value, dialect) -> str:
-         if value is True:
-             return "True"
-         elif value is False:
-             return "False"
-         else:
-             raise NotImplementedError(value, dialect)
-
-     def process_result_value(self, value, dialect) -> bool:
-         if value == "True":
-             return True
-         elif value == "False":
-             return False
-         else:
-             raise NotImplementedError(f"Unknown {value=}")
-
-
- class SQLDBError(Exception):
-     pass
-
-
- class SQLDBUnavailable(DBUnavailable, SQLDBError):
-     """Used whenever we encounter a problem with the DB connection."""
-
-
- class BaseSQLDB(metaclass=ABCMeta):
-     """This should be the base class of all the SQL DiracX DBs.
-
-     The details covered here should be handled automatically by the service and
-     task machinery of DiracX; this documentation exists for informational
-     purposes.
-
-     The available databases are discovered by calling `BaseSQLDB.available_urls`.
-     This method returns a mapping of database names to connection URLs. The
-     available databases are determined by the `diracx.db.sql` entrypoint in the
-     `pyproject.toml` file and the connection URLs are taken from the environment
-     variables of the form `DIRACX_DB_URL_<db-name>`.
-
-     If extensions to DiracX are being used, there can be multiple implementations
-     of the same database. To list the available implementations use
-     `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
-     will be the preferred implementation and it can be initialized by calling
-     its `__init__` function with a URL previously obtained from
-     `BaseSQLDB.available_urls`.
-
-     To control the lifetime of the SQLAlchemy engine used for connecting to the
-     database, which includes the connection pool, the `BaseSQLDB.engine_context`
-     asynchronous context manager should be entered. When inside this context
-     manager, the engine can be accessed with `BaseSQLDB.engine`.
-
-     Upon entering, the DB class can then be used as an asynchronous context
-     manager to enter transactions. If an exception is raised the transaction is
-     rolled back automatically. If the inner context exits peacefully, the
-     transaction is committed automatically. When inside this context manager,
-     the DB connection can be accessed with `BaseSQLDB.conn`.
-
-     For example:
-
-     ```python
-     db_name = ...
-     url = BaseSQLDB.available_urls()[db_name]
-     MyDBClass = BaseSQLDB.available_implementations(db_name)[0]
-
-     db = MyDBClass(url)
-     async with db.engine_context():
-         async with db:
-             # Do something in the first transaction
-             # Commit will be called automatically
-
-         async with db:
-             # This transaction will be rolled back due to the exception
-             raise Exception(...)
-     ```
-     """
-
-     # engine: AsyncEngine
-     # TODO: Make metadata an abstract property
-     metadata: MetaData
-
-     def __init__(self, db_url: str) -> None:
-         # We use a ContextVar to make sure that self._conn
-         # is specific to each context, and to prevent parallel
-         # route executions from overlapping
-         self._conn: ContextVar[AsyncConnection | None] = ContextVar(
-             "_conn", default=None
-         )
-         self._db_url = db_url
-         self._engine: AsyncEngine | None = None
-
-     @classmethod
-     def available_implementations(cls, db_name: str) -> list[type[BaseSQLDB]]:
-         """Return the available implementations of the DB in reverse priority order."""
-         db_classes: list[type[BaseSQLDB]] = [
-             entry_point.load()
-             for entry_point in select_from_extension(
-                 group="diracx.db.sql", name=db_name
-             )
-         ]
-         if not db_classes:
-             raise NotImplementedError(f"Could not find any matches for {db_name=}")
-         return db_classes
-
-     @classmethod
-     def available_urls(cls) -> dict[str, str]:
-         """Return a dict of available database urls.
-
-         The list of available URLs is determined by environment variables
-         prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
-         """
-         db_urls: dict[str, str] = {}
-         for entry_point in select_from_extension(group="diracx.db.sql"):
-             db_name = entry_point.name
-             var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
-             if var_name in os.environ:
-                 try:
-                     db_url = os.environ[var_name]
-                     if db_url == "sqlite+aiosqlite:///:memory:":
-                         db_urls[db_name] = db_url
-                     else:
-                         db_urls[db_name] = str(
-                             TypeAdapter(SqlalchemyDsn).validate_python(db_url)
-                         )
-                 except Exception:
-                     logger.error("Error loading URL for %s", db_name)
-                     raise
-         return db_urls
-
-     @classmethod
-     def transaction(cls) -> Self:
-         raise NotImplementedError("This should never be called")
-
-     @property
-     def engine(self) -> AsyncEngine:
-         """The engine to use for database operations.
-
-         It is normally not necessary to use the engine directly, unless you are
-         doing something special, like writing a test fixture that gives you a db.
-
-         Requires that the engine_context has been entered.
-         """
-         assert self._engine is not None, "engine_context must be entered"
-         return self._engine
-
-     @contextlib.asynccontextmanager
-     async def engine_context(self) -> AsyncIterator[None]:
-         """Context manager to manage the engine lifecycle.
-
-         This is called once at application startup (see ``lifetime_functions``).
-         """
-         assert self._engine is None, "engine_context cannot be nested"
-
-         # Set the pool_recycle to 30 minutes.
-         # That should prevent the problem of MySQL expiring connections
-         # after 60 minutes by default.
-         engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
-         self._engine = engine
-         try:
-             yield
-         finally:
-             self._engine = None
-             await engine.dispose()
-
-     @property
-     def conn(self) -> AsyncConnection:
-         if self._conn.get() is None:
-             raise RuntimeError(f"{self.__class__} was used before entering")
-         return cast(AsyncConnection, self._conn.get())
-
-     async def __aenter__(self) -> Self:
-         """Create a connection.
-
-         This is called by the Dependency mechanism (see ``db_transaction``);
-         it will create a new connection/transaction for each route call.
-         """
-         assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
-         try:
-             self._conn.set(await self.engine.connect().__aenter__())
-         except Exception as e:
-             raise SQLDBUnavailable(
-                 f"Cannot connect to {self.__class__.__name__}"
-             ) from e
-
-         return self
-
-     async def __aexit__(self, exc_type, exc, tb):
-         """This is called when exiting a route.
-
-         If there was no exception, the changes in the DB are committed.
-         Otherwise, they are rolled back.
-         """
-         if exc_type is None:
-             await self._conn.get().commit()
-         await self._conn.get().__aexit__(exc_type, exc, tb)
-         self._conn.set(None)
-
-     async def ping(self):
-         """Check whether the connection to the DB is still working.
-
-         We could enable ``pre_ping`` on the engine, but that would be run on
-         every query.
-         """
-         try:
-             await self.conn.scalar(select(1))
-         except OperationalError as e:
-             raise SQLDBUnavailable("Cannot ping the DB") from e
-
-
- def find_time_resolution(value):
-     if isinstance(value, datetime):
-         return None, value
-     if match := re.fullmatch(
-         r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value
-     ):
-         if match.group(6):
-             precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
-         elif match.group(5):
-             precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
-         elif match.group(3):
-             precision, pattern = "HOUR", r"\1-\2-\3 \4"
-         elif match.group(2):
-             precision, pattern = "DAY", r"\1-\2-\3"
-         elif match.group(1):
-             precision, pattern = "MONTH", r"\1-\2"
-         else:
-             precision, pattern = "YEAR", r"\1"
-         return (
-             precision,
-             re.sub(
-                 r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$",
-                 pattern,
-                 value,
-             ),
-         )
-
-     raise InvalidQueryError(f"Cannot parse {value=}")
-
-
- def apply_search_filters(column_mapping, stmt, search):
-     for query in search:
-         try:
-             column = column_mapping(query["parameter"])
-         except KeyError as e:
-             raise InvalidQueryError(f"Unknown column {query['parameter']}") from e
-
-         if isinstance(column.type, DateTime):
-             if "value" in query and isinstance(query["value"], str):
-                 resolution, value = find_time_resolution(query["value"])
-                 if resolution:
-                     column = date_trunc(column, time_resolution=resolution)
-                 query["value"] = value
-
-             if query.get("values"):
-                 resolutions, values = zip(
-                     *map(find_time_resolution, query.get("values"))
-                 )
-                 if len(set(resolutions)) != 1:
-                     raise InvalidQueryError(
-                         f"Cannot mix different time resolutions in {query=}"
-                     )
-                 if resolution := resolutions[0]:
-                     column = date_trunc(column, time_resolution=resolution)
-                 query["values"] = values
-
-         if query["operator"] == "eq":
-             expr = column == query["value"]
-         elif query["operator"] == "neq":
-             expr = column != query["value"]
-         elif query["operator"] == "gt":
-             expr = column > query["value"]
-         elif query["operator"] == "lt":
-             expr = column < query["value"]
-         elif query["operator"] == "in":
-             expr = column.in_(query["values"])
-         elif query["operator"] == "not in":
-             expr = column.notin_(query["values"])
-         elif query["operator"] in "like":
-             expr = column.like(query["value"])
-         elif query["operator"] in "ilike":
-             expr = column.ilike(query["value"])
-         else:
-             raise InvalidQueryError(f"Unknown filter {query=}")
-         stmt = stmt.where(expr)
-     return stmt
-
-
- def apply_sort_constraints(column_mapping, stmt, sorts):
-     sort_columns = []
-     for sort in sorts or []:
-         try:
-             column = column_mapping(sort["parameter"])
-         except KeyError as e:
-             raise InvalidQueryError(
-                 f"Cannot sort by {sort['parameter']}: unknown column"
-             ) from e
-         sorted_column = None
-         if sort["direction"] == SortDirection.ASC:
-             sorted_column = column.asc()
-         elif sort["direction"] == SortDirection.DESC:
-             sorted_column = column.desc()
-         else:
-             raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
-         sort_columns.append(sorted_column)
-     if sort_columns:
-         stmt = stmt.order_by(*sort_columns)
-     return stmt
+ from .base import (
+     BaseSQLDB,
+     SQLDBUnavailableError,
+     apply_search_filters,
+     apply_sort_constraints,
+ )
+ from .functions import substract_date, utcnow
+ from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
+
+ __all__ = (
+     "utcnow",
+     "Column",
+     "NullColumn",
+     "DateNowColumn",
+     "BaseSQLDB",
+     "EnumBackedBool",
+     "EnumColumn",
+     "apply_search_filters",
+     "apply_sort_constraints",
+     "substract_date",
+     "SQLDBUnavailableError",
+ )
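
Every name in the old `__all__` is still exported by the new, much shorter `__init__`, together with several helpers that previously lived in this file; the exception class is renamed from `SQLDBUnavailable` to `SQLDBUnavailableError`. Below is a minimal sketch of what downstream imports can continue to look like, assuming this hunk is the package's `utils/__init__.py` and the dotted path is `diracx.db.sql.utils` (the file path is not shown in this hunk).

```python
# Hedged sketch: the dotted path diracx.db.sql.utils is assumed from the
# relative imports above; it is not shown in this hunk.
from diracx.db.sql.utils import (
    BaseSQLDB,
    Column,
    DateNowColumn,
    NullColumn,
    substract_date,
    utcnow,
)

# The re-exported helpers keep their previous behaviour, e.g. substract_date()
# still returns "now minus the given offset" as a timezone-aware UTC datetime.
print(substract_date(hours=1).isoformat())
```

Code that caught the old `SQLDBUnavailable` exception from this module would presumably need to switch to the new `SQLDBUnavailableError` name.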