diracx-db 0.0.1a21__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of diracx-db might be problematic; see the advisory details accompanying this report.

@@ -0,0 +1,461 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import logging
5
+ import os
6
+ import re
7
+ from abc import ABCMeta
8
+ from collections.abc import AsyncIterator
9
+ from contextvars import ContextVar
10
+ from datetime import datetime, timezone
11
+ from typing import Any, Self, cast
12
+ from uuid import UUID as StdUUID # noqa: N811
13
+
14
+ from pydantic import TypeAdapter
15
+ from sqlalchemy import DateTime, MetaData, func, select
16
+ from sqlalchemy.exc import OperationalError
17
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
18
+ from uuid_utils import UUID, uuid7
19
+
20
+ from diracx.core.exceptions import InvalidQueryError
21
+ from diracx.core.extensions import select_from_extension
22
+ from diracx.core.models import (
23
+ SearchSpec,
24
+ SortDirection,
25
+ SortSpec,
26
+ )
27
+ from diracx.core.settings import SqlalchemyDsn
28
+ from diracx.db.exceptions import DBUnavailableError
29
+ from diracx.db.sql.utils.types import SmarterDateTime
30
+
31
+ from .functions import date_trunc
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
class SQLDBError(Exception):
    """Base class for all errors raised by the SQL database layer."""
38
+
39
+
40
class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
    """Used whenever we encounter a problem with the DB connection."""
42
+
43
+
44
class BaseSQLDB(metaclass=ABCMeta):
    """This should be the base class of all the SQL DiracX DBs.

    The details covered here should be handled automatically by the service and
    task machinery of DiracX and this documentation exists for informational
    purposes.

    The available databases are discovered by calling `BaseSQLDB.available_urls`.
    This method returns a mapping of database names to connection URLs. The
    available databases are determined by the `diracx.dbs.sql` entrypoint in the
    `pyproject.toml` file and the connection URLs are taken from the environment
    variables of the form `DIRACX_DB_URL_<db-name>`.

    If extensions to DiracX are being used, there can be multiple implementations
    of the same database. To list the available implementations use
    `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
    will be the preferred implementation and it can be initialized by calling
    it's `__init__` function with a URL previously obtained from
    `BaseSQLDB.available_urls`.

    To control the lifetime of the SQLAlchemy engine used for connecting to the
    database, which includes the connection pool, the `BaseSQLDB.engine_context`
    asynchronous context manager should be entered. When inside this context
    manager, the engine can be accessed with `BaseSQLDB.engine`.

    Upon entering, the DB class can then be used as an asynchronous context
    manager to enter transactions. If an exception is raised the transaction is
    rolled back automatically. If the inner context exits peacefully, the
    transaction is committed automatically. When inside this context manager,
    the DB connection can be accessed with `BaseSQLDB.conn`.

    For example:

    ```python
    db_name = ...
    url = BaseSQLDB.available_urls()[db_name]
    MyDBClass = BaseSQLDB.available_implementations(db_name)[0]

    db = MyDBClass(url)
    async with db.engine_context():
        async with db:
            # Do something in the first transaction
            # Commit will be called automatically

        async with db:
            # This transaction will be rolled back due to the exception
            raise Exception(...)
    ```
    """

    # TODO: Make metadata an abstract property
    metadata: MetaData

    def __init__(self, db_url: str) -> None:
        # We use a ContextVar to make sure that self._conn
        # is specific to each context, and avoid parallel
        # route executions to overlap
        self._conn: ContextVar[AsyncConnection | None] = ContextVar(
            "_conn", default=None
        )
        self._db_url = db_url
        self._engine: AsyncEngine | None = None

    @classmethod
    def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
        """Return the available implementations of the DB in reverse priority order."""
        db_classes: list[type[BaseSQLDB]] = [
            entry_point.load()
            for entry_point in select_from_extension(
                group="diracx.dbs.sql", name=db_name
            )
        ]
        if not db_classes:
            raise NotImplementedError(f"Could not find any matches for {db_name=}")
        return db_classes

    @classmethod
    def available_urls(cls) -> dict[str, str]:
        """Return a dict of available database urls.

        The list of available URLs is determined by environment variables
        prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
        """
        db_urls: dict[str, str] = {}
        for entry_point in select_from_extension(group="diracx.dbs.sql"):
            db_name = entry_point.name
            var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
            if var_name in os.environ:
                try:
                    db_url = os.environ[var_name]
                    if db_url == "sqlite+aiosqlite:///:memory:":
                        db_urls[db_name] = db_url
                    # pydantic does not allow for underscore in scheme
                    # so we do a special case
                    elif "_" in db_url.split(":")[0]:
                        # Validate the URL with a fake schema, and then store
                        # the original one
                        scheme_id = db_url.find(":")
                        fake_url = (
                            db_url[:scheme_id].replace("_", "-") + db_url[scheme_id:]
                        )
                        TypeAdapter(SqlalchemyDsn).validate_python(fake_url)
                        db_urls[db_name] = db_url
                    else:
                        db_urls[db_name] = str(
                            TypeAdapter(SqlalchemyDsn).validate_python(db_url)
                        )
                except Exception:
                    logger.error("Error loading URL for %s", db_name)
                    raise
        return db_urls

    @classmethod
    async def post_create(cls, conn: AsyncConnection) -> None:
        """Execute actions after the schema has been created."""
        return

    @classmethod
    def transaction(cls) -> Self:
        # Placeholder that is replaced by the dependency machinery at runtime.
        raise NotImplementedError("This should never be called")

    @property
    def engine(self) -> AsyncEngine:
        """The engine to use for database operations.

        It is normally not necessary to use the engine directly, unless you are
        doing something special, like writing a test fixture that gives you a db.

        Requires that the engine_context has been entered.
        """
        assert self._engine is not None, "engine_context must be entered"
        return self._engine

    @contextlib.asynccontextmanager
    async def engine_context(self) -> AsyncIterator[None]:
        """Context manage to manage the engine lifecycle.

        This is called once at the application startup (see ``lifetime_functions``).
        """
        assert self._engine is None, "engine_context cannot be nested"

        # Set the pool_recycle to 30mn
        # That should prevent the problem of MySQL expiring connection
        # after 60mn by default
        engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
        self._engine = engine
        try:
            yield
        finally:
            self._engine = None
            await engine.dispose()

    @property
    def conn(self) -> AsyncConnection:
        """The active connection; requires the transaction context to be entered."""
        if self._conn.get() is None:
            raise RuntimeError(f"{self.__class__} was used before entering")
        return cast(AsyncConnection, self._conn.get())

    async def __aenter__(self) -> Self:
        """Create a connection.

        This is called by the Dependency mechanism (see ``db_transaction``),
        It will create a new connection/transaction for each route call.
        """
        assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
        try:
            self._conn.set(await self.engine.connect().__aenter__())
        except Exception as e:
            logger.warning(
                "Database connection failed for %s: %s",
                self.__class__.__name__,
                e,
                exc_info=True,
            )
            raise SQLDBUnavailableError(
                f"Cannot connect to {self.__class__.__name__}"
            ) from e

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """This is called when exiting a route.

        If there was no exception, the changes in the DB are committed.
        Otherwise, they are rolled back.
        """
        try:
            if exc_type is None:
                await self._conn.get().commit()
        finally:
            # BUG FIX: previously a failing commit() skipped the connection's
            # __aexit__ and the ContextVar reset, leaking the connection and
            # tripping the "cannot be nested" assertion on the next request.
            await self._conn.get().__aexit__(exc_type, exc, tb)
            self._conn.set(None)

    async def ping(self):
        """Check whether the connection to the DB is still working.

        We could enable the ``pre_ping`` in the engine, but this would be ran at
        every query.
        """
        try:
            await self.conn.scalar(select(1))
        except OperationalError as e:
            raise SQLDBUnavailableError("Cannot ping the DB") from e

    async def _search(
        self,
        table: Any,
        parameters: list[str] | None,
        search: list[SearchSpec],
        sorts: list[SortSpec],
        *,
        distinct: bool = False,
        per_page: int = 100,
        page: int | None = None,
    ) -> tuple[int, list[dict[str, Any]]]:
        """Search for elements in a table.

        Returns ``(total, rows)``: the number of matching rows before
        pagination and the requested page of rows as dicts.

        Raises InvalidQueryError for invalid pagination values, unknown
        columns or malformed filters/sorts.
        """
        # Find which columns to select
        columns = _get_columns(table.__table__, parameters)

        stmt = select(*columns)

        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
        stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)

        if distinct:
            stmt = stmt.distinct()

        # Calculate total count before applying pagination
        total_count_subquery = stmt.alias()
        total_count_stmt = select(func.count()).select_from(total_count_subquery)
        total = (await self.conn.execute(total_count_stmt)).scalar_one()

        # Apply pagination
        if page is not None:
            if page < 1:
                raise InvalidQueryError("Page must be a positive integer")
            if per_page < 1:
                raise InvalidQueryError("Per page must be a positive integer")
            stmt = stmt.offset((page - 1) * per_page).limit(per_page)

        # Execute the query, streaming rows rather than materializing them all
        return total, [
            dict(row._mapping) async for row in (await self.conn.stream(stmt))
        ]

    async def _summary(
        self, table: Any, group_by: list[str], search: list[SearchSpec]
    ) -> list[dict[str, str | int]]:
        """Get a summary of the elements of a table."""
        columns = _get_columns(table.__table__, group_by)

        # Use the first primary-key column as the COUNT target
        pk_columns = list(table.__table__.primary_key.columns)
        if not pk_columns:
            raise ValueError(
                "Model has no primary key and no count_column was provided."
            )
        count_col = pk_columns[0]

        stmt = select(*columns, func.count(count_col).label("count"))
        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
        stmt = stmt.group_by(*columns)

        # Execute the query, keeping only non-empty groups
        return [
            dict(row._mapping)
            async for row in (await self.conn.stream(stmt))
            if row.count > 0  # type: ignore
        ]
312
+
313
+
314
def find_time_resolution(value):
    """Work out the time resolution implied by a (partial) timestamp string.

    Returns a ``(precision, normalized)`` tuple, where ``precision`` is one of
    ``YEAR``/``MONTH``/``DAY``/``HOUR``/``MINUTE``/``SECOND`` (``None`` for a
    ``datetime`` instance, which is passed through untouched) and
    ``normalized`` is the value rewritten in canonical
    ``YYYY-MM-DD HH:MM:SS`` form truncated at that precision.

    Raises InvalidQueryError if the string cannot be parsed.
    """
    # datetime objects already carry full precision; nothing to do.
    if isinstance(value, datetime):
        return None, value

    match = re.fullmatch(
        r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{1,6}Z?)?)?)?)?)?)?", value
    )
    if not match:
        raise InvalidQueryError(f"Cannot parse {value=}")

    # The innermost group that matched tells us the precision; the paired
    # replacement pattern rebuilds the string in canonical form.
    if match.group(6):
        precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
    elif match.group(5):
        precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
    elif match.group(3):
        precision, pattern = "HOUR", r"\1-\2-\3 \4"
    elif match.group(2):
        precision, pattern = "DAY", r"\1-\2-\3"
    elif match.group(1):
        precision, pattern = "MONTH", r"\1-\2"
    else:
        precision, pattern = "YEAR", r"\1"

    normalized = re.sub(
        r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{1,6})?Z?$",
        pattern,
        value,
    )
    return precision, normalized
342
+
343
+
344
+ def _get_columns(table, parameters):
345
+ columns = [x for x in table.columns]
346
+ if parameters:
347
+ if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
348
+ raise InvalidQueryError(
349
+ f"Unrecognised parameters requested {unrecognised_parameters}"
350
+ )
351
+ columns = [c for c in columns if c.name in parameters]
352
+ return columns
353
+
354
+
355
def apply_search_filters(column_mapping, stmt, search):
    """Apply a list of ``SearchSpec`` filters to a select statement.

    ``column_mapping`` maps a parameter name to a SQLAlchemy column and must
    raise ``KeyError`` for unknown names.  For datetime columns, string values
    may be partial timestamps (e.g. ``"2024-05"``); the column is then
    truncated to the matching resolution before comparison.

    Raises InvalidQueryError for unknown columns or operators, invalid
    regexes, and mixed time resolutions in an ``in``-style value list.
    """
    for query in search:
        try:
            column = column_mapping(query["parameter"])
        except KeyError as e:
            raise InvalidQueryError(f"Unknown column {query['parameter']}") from e

        if isinstance(column.type, (DateTime, SmarterDateTime)):
            if "value" in query and isinstance(query["value"], str):
                resolution, value = find_time_resolution(query["value"])
                if resolution:
                    column = date_trunc(column, time_resolution=resolution)
                query["value"] = value

            if query.get("values"):
                resolutions, values = zip(
                    *map(find_time_resolution, query.get("values"))
                )
                # All values must share one resolution or the truncation of
                # the column would be ambiguous.
                if len(set(resolutions)) != 1:
                    raise InvalidQueryError(
                        f"Cannot mix different time resolutions in {query=}"
                    )
                if resolution := resolutions[0]:
                    column = date_trunc(column, time_resolution=resolution)
                query["values"] = values

        if query["operator"] == "eq":
            expr = column == query["value"]
        elif query["operator"] == "neq":
            expr = column != query["value"]
        elif query["operator"] == "gt":
            expr = column > query["value"]
        elif query["operator"] == "lt":
            expr = column < query["value"]
        elif query["operator"] == "in":
            expr = column.in_(query["values"])
        elif query["operator"] == "not in":
            expr = column.notin_(query["values"])
        # BUG FIX: these two branches previously used ``operator in "like"``
        # (a substring test, so "l", "ik", "" all matched); use equality.
        elif query["operator"] == "like":
            expr = column.like(query["value"])
        elif query["operator"] == "ilike":
            expr = column.ilike(query["value"])
        elif query["operator"] == "not like":
            expr = column.not_like(query["value"])
        elif query["operator"] == "regex":
            # We check the regex validity here so the DB never sees a bad one
            try:
                re.compile(query["value"])
            except re.error as e:
                raise InvalidQueryError(f"Invalid regex {query['value']}") from e
            expr = column.regexp_match(query["value"])
        else:
            raise InvalidQueryError(f"Unknown filter {query=}")
        stmt = stmt.where(expr)
    return stmt
410
+
411
+
412
def apply_sort_constraints(column_mapping, stmt, sorts):
    """Apply a list of ``SortSpec`` constraints to a select statement.

    Raises InvalidQueryError for unknown columns or sort directions.
    """
    ordering = []
    for sort in sorts or []:
        try:
            column = column_mapping(sort["parameter"])
        except KeyError as e:
            raise InvalidQueryError(
                f"Cannot sort by {sort['parameter']}: unknown column"
            ) from e
        if sort["direction"] == SortDirection.ASC:
            ordering.append(column.asc())
        elif sort["direction"] == SortDirection.DESC:
            ordering.append(column.desc())
        else:
            raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
    return stmt.order_by(*ordering) if ordering else stmt
432
+
433
+
434
def uuid7_to_datetime(uuid: UUID | StdUUID | str) -> datetime:
    """Convert a UUIDv7 to an aware UTC datetime.

    Accepts a ``uuid_utils.UUID``, a stdlib ``uuid.UUID`` or a string.
    Raises ValueError if the UUID is not version 7.
    """
    # Normalise everything to uuid_utils.UUID, which exposes ``timestamp``.
    if isinstance(uuid, StdUUID):
        uuid = UUID(str(uuid))
    elif not isinstance(uuid, UUID):
        uuid = UUID(uuid)

    if uuid.version != 7:
        raise ValueError(f"UUID {uuid} is not a UUIDv7")

    # The embedded timestamp is treated as milliseconds since the Unix epoch.
    return datetime.fromtimestamp(uuid.timestamp / 1000.0, tz=timezone.utc)
445
+
446
+
447
def uuid7_from_datetime(dt: datetime, *, randomize: bool = True) -> UUID:
    """Generate a UUIDv7 corresponding to the given datetime.

    If randomize is True, the standard uuid7 function is used resulting in the
    lowest 62-bits being random. If randomize is False, the UUIDv7 will be the
    lowest possible UUIDv7 for the given datetime.
    """
    timestamp = dt.timestamp()
    if not randomize:
        # Build the smallest UUIDv7 for this timestamp by hand: the 48-bit
        # millisecond count fills the first two fields, and the remaining
        # fields carry only the version (7) and variant markers.
        time_high = int(timestamp * 1000) >> 16
        time_low = int(timestamp * 1000) & 0xFFFF
        return UUID.from_fields((time_high, time_low, 0x7000, 0x80, 0, 0))
    # uuid7 takes whole seconds plus the sub-second remainder in nanoseconds.
    return uuid7(int(timestamp), int((timestamp % 1) * 1e9))
@@ -0,0 +1,142 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import TYPE_CHECKING
6
+
7
+ from sqlalchemy import DateTime, func
8
+ from sqlalchemy.ext.compiler import compiles
9
+ from sqlalchemy.sql import expression
10
+
11
+ if TYPE_CHECKING:
12
+ from sqlalchemy.types import TypeEngine
13
+
14
+
15
class utcnow(expression.FunctionElement):  # noqa: N801
    """SQLAlchemy function element for the current UTC timestamp.

    The per-dialect SQL is supplied by ``@compiles`` hooks elsewhere in
    this module.
    """

    type: TypeEngine = DateTime()
    inherit_cache: bool = True
18
+
19
+
20
@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` as PostgreSQL SQL."""
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
23
+
24
+
25
@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` as MS SQL Server SQL."""
    return "GETUTCDATE()"
28
+
29
+
30
@compiles(utcnow, "mysql")
def mysql_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` as MySQL SQL."""
    return "(UTC_TIMESTAMP)"
33
+
34
+
35
@compiles(utcnow, "sqlite")
def sqlite_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` as SQLite SQL."""
    return "DATETIME('now')"
38
+
39
+
40
class date_trunc(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to truncate a date to a given resolution.

    Primarily used to be able to query for a specific resolution of a date e.g.

    select * from table where date_trunc('day', date_column) = '2021-01-01'
    select * from table where date_trunc('year', date_column) = '2021'
    select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
    """

    type = DateTime()
    # Cache does not work as intended with time resolution values, so we disable it
    inherit_cache = False

    def __init__(self, *args, time_resolution, **kwargs) -> None:
        """``time_resolution`` is one of the keys used by the dialect
        compilers: SECOND, MINUTE, HOUR, DAY, MONTH or YEAR.
        """
        super().__init__(*args, **kwargs)
        self._time_resolution = time_resolution
57
+
58
+
59
@compiles(date_trunc, "postgresql")
def pg_date_trunc(element, compiler, **kw):
    """Render ``date_trunc`` using PostgreSQL's native date_trunc()."""
    resolution_names = {
        "SECOND": "second",
        "MINUTE": "minute",
        "HOUR": "hour",
        "DAY": "day",
        "MONTH": "month",
        "YEAR": "year",
    }
    resolution = resolution_names[element._time_resolution]
    return f"date_trunc('{resolution}', {compiler.process(element.clauses)})"
70
+
71
+
72
@compiles(date_trunc, "mysql")
def mysql_date_trunc(element, compiler, **kw):
    """Render ``date_trunc`` for MySQL by truncating via DATE_FORMAT."""
    formats = {
        "SECOND": "%Y-%m-%d %H:%i:%S",
        "MINUTE": "%Y-%m-%d %H:%i",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }
    fmt = formats[element._time_resolution]
    # date_trunc is always called with exactly one column argument.
    (dt_col,) = element.clauses
    return compiler.process(func.date_format(dt_col, fmt))
85
+
86
+
87
@compiles(date_trunc, "sqlite")
def sqlite_date_trunc(element, compiler, **kw):
    """Render ``date_trunc`` for SQLite by truncating via strftime."""
    formats = {
        "SECOND": "%Y-%m-%d %H:%M:%S",
        "MINUTE": "%Y-%m-%d %H:%M",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }
    fmt = formats[element._time_resolution]
    # date_trunc is always called with exactly one column argument.
    (dt_col,) = element.clauses
    return compiler.process(func.strftime(fmt, dt_col))
104
+
105
+
106
class days_since(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to get the number of days since a given date.

    Primarily used to be able to query for a specific resolution of a date e.g.

    select * from table where days_since(date_column) = 0
    select * from table where days_since(date_column) = 1
    """

    # NOTE(review): the rendered SQL yields a number of days, so DateTime()
    # is surprising here -- confirm whether a numeric type was intended.
    type = DateTime()
    inherit_cache = False

    # The previous redundant ``__init__`` (which only delegated unchanged
    # arguments to ``super().__init__``) has been removed; behavior is
    # identical.
120
+
121
+
122
@compiles(days_since, "postgresql")
def pg_days_since(element, compiler, **kw):
    """Render ``days_since`` as PostgreSQL SQL."""
    return f"EXTRACT(DAY FROM (now() - {compiler.process(element.clauses)}))"
125
+
126
+
127
@compiles(days_since, "mysql")
def mysql_days_since(element, compiler, **kw):
    """Render ``days_since`` as MySQL SQL."""
    return f"DATEDIFF(NOW(), {compiler.process(element.clauses)})"
130
+
131
+
132
@compiles(days_since, "sqlite")
def sqlite_days_since(element, compiler, **kw):
    """Render ``days_since`` as SQLite SQL (fractional Julian-day difference)."""
    return f"julianday('now') - julianday({compiler.process(element.clauses)})"
135
+
136
+
137
def substract_date(**kwargs: float) -> datetime:
    """Return the current UTC time minus ``timedelta(**kwargs)``.

    e.g. ``substract_date(days=7)`` is "one week ago".  The misspelled name
    is kept for backwards compatibility with existing callers.
    """
    offset = timedelta(**kwargs)
    return datetime.now(tz=timezone.utc) - offset
139
+
140
+
141
def hash(code: str):
    """Return the hex-encoded SHA-256 digest of ``code``.

    NOTE: this shadows the ``hash`` builtin within this module; the name is
    kept for compatibility with existing callers.
    """
    digest = hashlib.sha256(code.encode())
    return digest.hexdigest()