diracx-db 0.0.1a22__py3-none-any.whl → 0.0.1a24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,453 +1,24 @@
1
1
  from __future__ import annotations
2
2
 
3
- __all__ = ("utcnow", "Column", "NullColumn", "DateNowColumn", "BaseSQLDB")
4
-
5
- import contextlib
6
- import logging
7
- import os
8
- import re
9
- from abc import ABCMeta
10
- from collections.abc import AsyncIterator
11
- from contextvars import ContextVar
12
- from datetime import datetime, timedelta, timezone
13
- from functools import partial
14
- from typing import TYPE_CHECKING, Self, cast
15
-
16
- import sqlalchemy.types as types
17
- from pydantic import TypeAdapter
18
- from sqlalchemy import Column as RawColumn
19
- from sqlalchemy import DateTime, Enum, MetaData, func, select
20
- from sqlalchemy.exc import OperationalError
21
- from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
22
- from sqlalchemy.ext.compiler import compiles
23
- from sqlalchemy.sql import expression
24
-
25
- from diracx.core.exceptions import InvalidQueryError
26
- from diracx.core.extensions import select_from_extension
27
- from diracx.core.models import SortDirection
28
- from diracx.core.settings import SqlalchemyDsn
29
- from diracx.db.exceptions import DBUnavailable
30
-
31
- if TYPE_CHECKING:
32
- from sqlalchemy.types import TypeEngine
33
-
34
- logger = logging.getLogger(__name__)
35
-
36
-
37
- class utcnow(expression.FunctionElement):
38
- type: TypeEngine = DateTime()
39
- inherit_cache: bool = True
40
-
41
-
42
- @compiles(utcnow, "postgresql")
43
- def pg_utcnow(element, compiler, **kw) -> str:
44
- return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
45
-
46
-
47
- @compiles(utcnow, "mssql")
48
- def ms_utcnow(element, compiler, **kw) -> str:
49
- return "GETUTCDATE()"
50
-
51
-
52
- @compiles(utcnow, "mysql")
53
- def mysql_utcnow(element, compiler, **kw) -> str:
54
- return "(UTC_TIMESTAMP)"
55
-
56
-
57
- @compiles(utcnow, "sqlite")
58
- def sqlite_utcnow(element, compiler, **kw) -> str:
59
- return "DATETIME('now')"
60
-
61
-
62
- class date_trunc(expression.FunctionElement):
63
- """Sqlalchemy function to truncate a date to a given resolution.
64
-
65
- Primarily used to be able to query for a specific resolution of a date e.g.
66
-
67
- select * from table where date_trunc('day', date_column) = '2021-01-01'
68
- select * from table where date_trunc('year', date_column) = '2021'
69
- select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
70
- """
71
-
72
- type = DateTime()
73
- inherit_cache = True
74
-
75
- def __init__(self, *args, time_resolution, **kwargs) -> None:
76
- super().__init__(*args, **kwargs)
77
- self._time_resolution = time_resolution
78
-
79
-
80
- @compiles(date_trunc, "postgresql")
81
- def pg_date_trunc(element, compiler, **kw):
82
- res = {
83
- "SECOND": "second",
84
- "MINUTE": "minute",
85
- "HOUR": "hour",
86
- "DAY": "day",
87
- "MONTH": "month",
88
- "YEAR": "year",
89
- }[element._time_resolution]
90
- return f"date_trunc('{res}', {compiler.process(element.clauses)})"
91
-
92
-
93
- @compiles(date_trunc, "mysql")
94
- def mysql_date_trunc(element, compiler, **kw):
95
- pattern = {
96
- "SECOND": "%Y-%m-%d %H:%i:%S",
97
- "MINUTE": "%Y-%m-%d %H:%i",
98
- "HOUR": "%Y-%m-%d %H",
99
- "DAY": "%Y-%m-%d",
100
- "MONTH": "%Y-%m",
101
- "YEAR": "%Y",
102
- }[element._time_resolution]
103
-
104
- (dt_col,) = list(element.clauses)
105
- return compiler.process(func.date_format(dt_col, pattern))
106
-
107
-
108
- @compiles(date_trunc, "sqlite")
109
- def sqlite_date_trunc(element, compiler, **kw):
110
- pattern = {
111
- "SECOND": "%Y-%m-%d %H:%M:%S",
112
- "MINUTE": "%Y-%m-%d %H:%M",
113
- "HOUR": "%Y-%m-%d %H",
114
- "DAY": "%Y-%m-%d",
115
- "MONTH": "%Y-%m",
116
- "YEAR": "%Y",
117
- }[element._time_resolution]
118
- (dt_col,) = list(element.clauses)
119
- return compiler.process(
120
- func.strftime(
121
- pattern,
122
- dt_col,
123
- )
124
- )
125
-
126
-
127
- def substract_date(**kwargs: float) -> datetime:
128
- return datetime.now(tz=timezone.utc) - timedelta(**kwargs)
129
-
130
-
131
- Column: partial[RawColumn] = partial(RawColumn, nullable=False)
132
- NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True)
133
- DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow())
134
-
135
-
136
- def EnumColumn(enum_type, **kwargs):
137
- return Column(Enum(enum_type, native_enum=False, length=16), **kwargs)
138
-
139
-
140
- class EnumBackedBool(types.TypeDecorator):
141
- """Maps a ``EnumBackedBool()`` column to True/False in Python."""
142
-
143
- impl = types.Enum
144
- cache_ok: bool = True
145
-
146
- def __init__(self) -> None:
147
- super().__init__("True", "False")
148
-
149
- def process_bind_param(self, value, dialect) -> str:
150
- if value is True:
151
- return "True"
152
- elif value is False:
153
- return "False"
154
- else:
155
- raise NotImplementedError(value, dialect)
156
-
157
- def process_result_value(self, value, dialect) -> bool:
158
- if value == "True":
159
- return True
160
- elif value == "False":
161
- return False
162
- else:
163
- raise NotImplementedError(f"Unknown {value=}")
164
-
165
-
166
- class SQLDBError(Exception):
167
- pass
168
-
169
-
170
- class SQLDBUnavailable(DBUnavailable, SQLDBError):
171
- """Used whenever we encounter a problem with the DB connection."""
172
-
173
-
174
- class BaseSQLDB(metaclass=ABCMeta):
175
- """This should be the base class of all the SQL DiracX DBs.
176
-
177
- The details covered here should be handled automatically by the service and
178
- task machinery of DiracX and this documentation exists for informational
179
- purposes.
180
-
181
- The available databases are discovered by calling `BaseSQLDB.available_urls`.
182
- This method returns a mapping of database names to connection URLs. The
183
- available databases are determined by the `diracx.dbs.sql` entrypoint in the
184
- `pyproject.toml` file and the connection URLs are taken from the environment
185
- variables of the form `DIRACX_DB_URL_<db-name>`.
186
-
187
- If extensions to DiracX are being used, there can be multiple implementations
188
- of the same database. To list the available implementations use
189
- `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
190
- will be the preferred implementation and it can be initialized by calling
191
- its `__init__` function with a URL previously obtained from
192
- `BaseSQLDB.available_urls`.
193
-
194
- To control the lifetime of the SQLAlchemy engine used for connecting to the
195
- database, which includes the connection pool, the `BaseSQLDB.engine_context`
196
- asynchronous context manager should be entered. When inside this context
197
- manager, the engine can be accessed with `BaseSQLDB.engine`.
198
-
199
- Upon entering, the DB class can then be used as an asynchronous context
200
- manager to enter transactions. If an exception is raised the transaction is
201
- rolled back automatically. If the inner context exits peacefully, the
202
- transaction is committed automatically. When inside this context manager,
203
- the DB connection can be accessed with `BaseSQLDB.conn`.
204
-
205
- For example:
206
-
207
- ```python
208
- db_name = ...
209
- url = BaseSQLDB.available_urls()[db_name]
210
- MyDBClass = BaseSQLDB.available_implementations(db_name)[0]
211
-
212
- db = MyDBClass(url)
213
- async with db.engine_context:
214
- async with db:
215
- # Do something in the first transaction
216
- # Commit will be called automatically
217
-
218
- async with db:
219
- # This transaction will be rolled back due to the exception
220
- raise Exception(...)
221
- ```
222
- """
223
-
224
- # engine: AsyncEngine
225
- # TODO: Make metadata an abstract property
226
- metadata: MetaData
227
-
228
- def __init__(self, db_url: str) -> None:
229
- # We use a ContextVar to make sure that self._conn
230
- # is specific to each context, and avoid parallel
231
- # route executions to overlap
232
- self._conn: ContextVar[AsyncConnection | None] = ContextVar(
233
- "_conn", default=None
234
- )
235
- self._db_url = db_url
236
- self._engine: AsyncEngine | None = None
237
-
238
- @classmethod
239
- def available_implementations(cls, db_name: str) -> list[type[BaseSQLDB]]:
240
- """Return the available implementations of the DB in reverse priority order."""
241
- db_classes: list[type[BaseSQLDB]] = [
242
- entry_point.load()
243
- for entry_point in select_from_extension(
244
- group="diracx.db.sql", name=db_name
245
- )
246
- ]
247
- if not db_classes:
248
- raise NotImplementedError(f"Could not find any matches for {db_name=}")
249
- return db_classes
250
-
251
- @classmethod
252
- def available_urls(cls) -> dict[str, str]:
253
- """Return a dict of available database urls.
254
-
255
- The list of available URLs is determined by environment variables
256
- prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
257
- """
258
- db_urls: dict[str, str] = {}
259
- for entry_point in select_from_extension(group="diracx.db.sql"):
260
- db_name = entry_point.name
261
- var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
262
- if var_name in os.environ:
263
- try:
264
- db_url = os.environ[var_name]
265
- if db_url == "sqlite+aiosqlite:///:memory:":
266
- db_urls[db_name] = db_url
267
- else:
268
- db_urls[db_name] = str(
269
- TypeAdapter(SqlalchemyDsn).validate_python(db_url)
270
- )
271
- except Exception:
272
- logger.error("Error loading URL for %s", db_name)
273
- raise
274
- return db_urls
275
-
276
- @classmethod
277
- def transaction(cls) -> Self:
278
- raise NotImplementedError("This should never be called")
279
-
280
- @property
281
- def engine(self) -> AsyncEngine:
282
- """The engine to use for database operations.
283
-
284
- It is normally not necessary to use the engine directly, unless you are
285
- doing something special, like writing a test fixture that gives you a db.
286
-
287
- Requires that the engine_context has been entered.
288
- """
289
- assert self._engine is not None, "engine_context must be entered"
290
- return self._engine
291
-
292
- @contextlib.asynccontextmanager
293
- async def engine_context(self) -> AsyncIterator[None]:
294
- """Context manager to manage the engine lifecycle.
295
-
296
- This is called once at the application startup (see ``lifetime_functions``).
297
- """
298
- assert self._engine is None, "engine_context cannot be nested"
299
-
300
- # Set the pool_recycle to 30mn
301
- # That should prevent the problem of MySQL expiring connection
302
- # after 60mn by default
303
- engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
304
- self._engine = engine
305
- try:
306
- yield
307
- finally:
308
- self._engine = None
309
- await engine.dispose()
310
-
311
- @property
312
- def conn(self) -> AsyncConnection:
313
- if self._conn.get() is None:
314
- raise RuntimeError(f"{self.__class__} was used before entering")
315
- return cast(AsyncConnection, self._conn.get())
316
-
317
- async def __aenter__(self) -> Self:
318
- """Create a connection.
319
-
320
- This is called by the Dependency mechanism (see ``db_transaction``),
321
- It will create a new connection/transaction for each route call.
322
- """
323
- assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
324
- try:
325
- self._conn.set(await self.engine.connect().__aenter__())
326
- except Exception as e:
327
- raise SQLDBUnavailable(
328
- f"Cannot connect to {self.__class__.__name__}"
329
- ) from e
330
-
331
- return self
332
-
333
- async def __aexit__(self, exc_type, exc, tb):
334
- """This is called when exiting a route.
335
-
336
- If there was no exception, the changes in the DB are committed.
337
- Otherwise, they are rolled back.
338
- """
339
- if exc_type is None:
340
- await self._conn.get().commit()
341
- await self._conn.get().__aexit__(exc_type, exc, tb)
342
- self._conn.set(None)
343
-
344
- async def ping(self):
345
- """Check whether the connection to the DB is still working.
346
-
347
- We could enable the ``pre_ping`` in the engine, but this would be run at
348
- every query.
349
- """
350
- try:
351
- await self.conn.scalar(select(1))
352
- except OperationalError as e:
353
- raise SQLDBUnavailable("Cannot ping the DB") from e
354
-
355
-
356
- def find_time_resolution(value):
357
- if isinstance(value, datetime):
358
- return None, value
359
- if match := re.fullmatch(
360
- r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value
361
- ):
362
- if match.group(6):
363
- precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
364
- elif match.group(5):
365
- precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
366
- elif match.group(3):
367
- precision, pattern = "HOUR", r"\1-\2-\3 \4"
368
- elif match.group(2):
369
- precision, pattern = "DAY", r"\1-\2-\3"
370
- elif match.group(1):
371
- precision, pattern = "MONTH", r"\1-\2"
372
- else:
373
- precision, pattern = "YEAR", r"\1"
374
- return (
375
- precision,
376
- re.sub(
377
- r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$",
378
- pattern,
379
- value,
380
- ),
381
- )
382
-
383
- raise InvalidQueryError(f"Cannot parse {value=}")
384
-
385
-
386
- def apply_search_filters(column_mapping, stmt, search):
387
- for query in search:
388
- try:
389
- column = column_mapping(query["parameter"])
390
- except KeyError as e:
391
- raise InvalidQueryError(f"Unknown column {query['parameter']}") from e
392
-
393
- if isinstance(column.type, DateTime):
394
- if "value" in query and isinstance(query["value"], str):
395
- resolution, value = find_time_resolution(query["value"])
396
- if resolution:
397
- column = date_trunc(column, time_resolution=resolution)
398
- query["value"] = value
399
-
400
- if query.get("values"):
401
- resolutions, values = zip(
402
- *map(find_time_resolution, query.get("values"))
403
- )
404
- if len(set(resolutions)) != 1:
405
- raise InvalidQueryError(
406
- f"Cannot mix different time resolutions in {query=}"
407
- )
408
- if resolution := resolutions[0]:
409
- column = date_trunc(column, time_resolution=resolution)
410
- query["values"] = values
411
-
412
- if query["operator"] == "eq":
413
- expr = column == query["value"]
414
- elif query["operator"] == "neq":
415
- expr = column != query["value"]
416
- elif query["operator"] == "gt":
417
- expr = column > query["value"]
418
- elif query["operator"] == "lt":
419
- expr = column < query["value"]
420
- elif query["operator"] == "in":
421
- expr = column.in_(query["values"])
422
- elif query["operator"] == "not in":
423
- expr = column.notin_(query["values"])
424
- elif query["operator"] in "like":
425
- expr = column.like(query["value"])
426
- elif query["operator"] in "ilike":
427
- expr = column.ilike(query["value"])
428
- else:
429
- raise InvalidQueryError(f"Unknown filter {query=}")
430
- stmt = stmt.where(expr)
431
- return stmt
432
-
433
-
434
- def apply_sort_constraints(column_mapping, stmt, sorts):
435
- sort_columns = []
436
- for sort in sorts or []:
437
- try:
438
- column = column_mapping(sort["parameter"])
439
- except KeyError as e:
440
- raise InvalidQueryError(
441
- f"Cannot sort by {sort['parameter']}: unknown column"
442
- ) from e
443
- sorted_column = None
444
- if sort["direction"] == SortDirection.ASC:
445
- sorted_column = column.asc()
446
- elif sort["direction"] == SortDirection.DESC:
447
- sorted_column = column.desc()
448
- else:
449
- raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
450
- sort_columns.append(sorted_column)
451
- if sort_columns:
452
- stmt = stmt.order_by(*sort_columns)
453
- return stmt
3
+ from .base import (
4
+ BaseSQLDB,
5
+ SQLDBUnavailableError,
6
+ apply_search_filters,
7
+ apply_sort_constraints,
8
+ )
9
+ from .functions import substract_date, utcnow
10
+ from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
11
+
12
+ __all__ = (
13
+ "utcnow",
14
+ "Column",
15
+ "NullColumn",
16
+ "DateNowColumn",
17
+ "BaseSQLDB",
18
+ "EnumBackedBool",
19
+ "EnumColumn",
20
+ "apply_search_filters",
21
+ "apply_sort_constraints",
22
+ "substract_date",
23
+ "SQLDBUnavailableError",
24
+ )