diracx-db 0.0.1a17__py3-none-any.whl → 0.0.1a18__py3-none-any.whl
- diracx/db/os/utils.py +60 -11
- diracx/db/sql/__init__.py +3 -1
- diracx/db/sql/auth/db.py +10 -19
- diracx/db/sql/auth/schema.py +5 -7
- diracx/db/sql/dummy/db.py +2 -3
- diracx/db/sql/{jobs → job}/db.py +12 -452
- diracx/db/sql/{jobs → job}/schema.py +2 -118
- diracx/db/sql/job_logging/__init__.py +0 -0
- diracx/db/sql/job_logging/db.py +161 -0
- diracx/db/sql/job_logging/schema.py +25 -0
- diracx/db/sql/sandbox_metadata/db.py +12 -10
- diracx/db/sql/task_queue/__init__.py +0 -0
- diracx/db/sql/task_queue/db.py +261 -0
- diracx/db/sql/task_queue/schema.py +109 -0
- diracx/db/sql/utils/__init__.py +418 -0
- diracx/db/sql/{jobs/status_utility.py → utils/job_status.py} +11 -18
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a18.dist-info}/METADATA +5 -5
- diracx_db-0.0.1a18.dist-info/RECORD +33 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a18.dist-info}/WHEEL +1 -1
- diracx/db/sql/utils.py +0 -236
- diracx_db-0.0.1a17.dist-info/RECORD +0 -27
- /diracx/db/sql/{jobs → job}/__init__.py +0 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a18.dist-info}/entry_points.txt +0 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a18.dist-info}/top_level.txt +0 -0
diracx/db/os/utils.py
CHANGED
@@ -7,9 +7,10 @@ import json
 import logging
 import os
 from abc import ABCMeta, abstractmethod
+from collections.abc import AsyncIterator
 from contextvars import ContextVar
 from datetime import datetime
-from typing import Any,
+from typing import Any, Self

 from opensearchpy import AsyncOpenSearch

@@ -29,6 +30,48 @@ class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):


 class BaseOSDB(metaclass=ABCMeta):
+    """This should be the base class of all the OpenSearch DiracX DBs.
+
+    The details covered here should be handled automatically by the service and
+    task machinery of DiracX and this documentation exists for informational
+    purposes.
+
+    The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
+    This method returns a dictionary of database names to connection parameters.
+    The available databases are determined by the `diracx.db.os` entrypoint in
+    the `pyproject.toml` file and the connection parameters are taken from the
+    environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
+
+    If extensions to DiracX are being used, there can be multiple implementations
+    of the same database. To list the available implementations use
+    `BaseOSDB.available_implementations(db_name)`. The first entry in this list
+    will be the preferred implementation and it can be initialized by calling
+    its `__init__` function with the connection parameters previously obtained
+    from `BaseOSDB.available_urls`.
+
+    To control the lifetime of the OpenSearch client, the `BaseOSDB.client_context`
+    asynchronous context manager should be entered. When inside this context
+    manager, the client can be accessed with `BaseOSDB.client`.
+
+    Upon entering, the DB class can then be used as an asynchronous context
+    manager to perform operations. Currently this context manager has no effect
+    however it must be used as it may be used in future. When inside this
+    context manager, the DB connection can be accessed with `BaseOSDB.client`.
+
+    For example:
+
+    ```python
+    db_name = ...
+    conn_params = BaseOSDB.available_urls()[db_name]
+    MyDBClass = BaseOSDB.available_implementations(db_name)[0]
+
+    db = MyDBClass(conn_params)
+    async with db.client_context:
+        async with db:
+            # Do something with the OpenSearch client
+    ```
+    """
+
     # TODO: Make metadata an abstract property
     fields: dict
     index_prefix: str
@@ -77,13 +120,15 @@ class BaseOSDB(metaclass=ABCMeta):
     @classmethod
     def session(cls) -> Self:
         """This is just a fake method such that the Dependency overwrite has
-        a hash to use"""
+        a hash to use.
+        """
         raise NotImplementedError("This should never be called")

     @property
     def client(self) -> AsyncOpenSearch:
         """Just a getter for _client, making sure we entered
-        the context manager"""
+        the context manager.
+        """
         if self._client is None:
             raise RuntimeError(f"{self.__class__} was used before entering")
         return self._client
@@ -91,17 +136,18 @@ class BaseOSDB(metaclass=ABCMeta):
     @contextlib.asynccontextmanager
     async def client_context(self) -> AsyncIterator[None]:
         """Context manage to manage the client lifecycle.
-        This is called when starting fastapi
+        This is called when starting fastapi.

         """
         assert self._client is None, "client_context cannot be nested"
         async with AsyncOpenSearch(**self._connection_kwargs) as self._client:
-            yield
-        self._client = None
+            try:
+                yield
+            finally:
+                self._client = None

     async def ping(self):
-        """
-        Check whether the connection to the DB is still working.
+        """Check whether the connection to the DB is still working.
         We could enable the ``pre_ping`` in the engine, but this would
         be ran at every query.
         """
@@ -113,7 +159,7 @@ class BaseOSDB(metaclass=ABCMeta):
     async def __aenter__(self):
         """This is entered on every request.
         At the moment it does nothing, however, we keep it here
-        in case we ever want to use OpenSearch equivalent of a transaction
+        in case we ever want to use OpenSearch equivalent of a transaction.
         """
         assert not self._conn.get(), "BaseOSDB context cannot be nested"
         assert self._client is not None, "client_context hasn't been entered"
@@ -122,9 +168,7 @@ class BaseOSDB(metaclass=ABCMeta):

     async def __aexit__(self, exc_type, exc, tb):
         assert self._conn.get()
-        self._client = None
         self._conn.set(False)
-        return

     async def create_index_template(self) -> None:
         template_body = {
@@ -237,6 +281,11 @@ def apply_search_filters(db_fields, search):
                    operator, field_name, field_type, {"keyword", "long", "date"}
                )
                result["must"].append({"terms": {field_name: query["values"]}})
+            case "not in":
+                require_type(
+                    operator, field_name, field_type, {"keyword", "long", "date"}
+                )
+                result["must_not"].append({"terms": {field_name: query["values"]}})
            # TODO: Implement like and ilike
            # If the pattern is a simple "col like 'abc%'", we can use a prefix query
            # Else we need to use a wildcard query where we replace % with * and _ with ?
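The only behavioural change in `apply_search_filters` is the new `not in` operator, which mirrors the existing `in` branch but pushes the `terms` clause into `must_not` instead of `must`. Below is a minimal sketch of the resulting query fragment, assuming the filter dictionaries use the `parameter`/`operator`/`values` keys used elsewhere in the DiracX search API; the field and status names are purely illustrative.

```python
# Sketch only: reproduces the logic of the new "not in" branch outside DiracX.
# The filter shape and the "Status" field are assumptions for illustration.
search = [
    {"parameter": "Status", "operator": "not in", "values": ["Done", "Failed"]},
]

result = {"must": [], "must_not": []}
for query in search:
    if query["operator"] == "not in":
        # Exclude documents whose field matches any of the given values
        result["must_not"].append({"terms": {query["parameter"]: query["values"]}})

print(result["must_not"])  # [{'terms': {'Status': ['Done', 'Failed']}}]
```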
diracx/db/sql/__init__.py
CHANGED
@@ -3,5 +3,7 @@ from __future__ import annotations
 __all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")

 from .auth.db import AuthDB
-from .
+from .job.db import JobDB
+from .job_logging.db import JobLoggingDB
 from .sandbox_metadata.db import SandboxMetadataDB
+from .task_queue.db import TaskQueueDB
diracx/db/sql/auth/db.py
CHANGED
@@ -35,7 +35,7 @@ class AuthDB(BaseSQLDB):
     async def device_flow_validate_user_code(
         self, user_code: str, max_validity: int
     ) -> str:
-        """Validate that the user_code can be used (Pending status, not expired)
+        """Validate that the user_code can be used (Pending status, not expired).

         Returns the scope field for the given user_code

@@ -51,9 +51,7 @@ class AuthDB(BaseSQLDB):
         return (await self.conn.execute(stmt)).scalar_one()

     async def get_device_flow(self, device_code: str, max_validity: int):
-        """
-        :raises: NoResultFound
-        """
+        """:raises: NoResultFound"""
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -94,9 +92,7 @@ class AuthDB(BaseSQLDB):
     async def device_flow_insert_id_token(
         self, user_code: str, id_token: dict[str, str], max_validity: int
     ) -> None:
-        """
-        :raises: AuthorizationError if no such code or status not pending
-        """
+        """:raises: AuthorizationError if no such code or status not pending"""
         stmt = update(DeviceFlows)
         stmt = stmt.where(
             DeviceFlows.user_code == user_code,
@@ -170,11 +166,9 @@ class AuthDB(BaseSQLDB):
     async def authorization_flow_insert_id_token(
         self, uuid: str, id_token: dict[str, str], max_validity: int
     ) -> tuple[str, str]:
+        """Returns code, redirect_uri
+        :raises: AuthorizationError if no such uuid or status not pending.
         """
-        returns code, redirect_uri
-        :raises: AuthorizationError if no such uuid or status not pending
-        """
-
         # Hash the code to avoid leaking information
         code = secrets.token_urlsafe()
         hashed_code = hashlib.sha256(code.encode()).hexdigest()
@@ -232,8 +226,7 @@ class AuthDB(BaseSQLDB):
         preferred_username: str,
         scope: str,
     ) -> tuple[str, datetime]:
-        """
-        Insert a refresh token in the DB as well as user attributes
+        """Insert a refresh token in the DB as well as user attributes
         required to generate access tokens.
         """
         # Generate a JWT ID
@@ -257,9 +250,7 @@ class AuthDB(BaseSQLDB):
         return jti, row.creation_time

     async def get_refresh_token(self, jti: str) -> dict:
-        """
-        Get refresh token details bound to a given JWT ID
-        """
+        """Get refresh token details bound to a given JWT ID."""
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -275,7 +266,7 @@ class AuthDB(BaseSQLDB):
         return res

     async def get_user_refresh_tokens(self, subject: str | None = None) -> list[dict]:
-        """Get a list of refresh token details based on a subject ID (not revoked)"""
+        """Get a list of refresh token details based on a subject ID (not revoked)."""
         # Get a list of refresh tokens
         stmt = select(RefreshTokens).with_for_update()

@@ -295,7 +286,7 @@ class AuthDB(BaseSQLDB):
         return refresh_tokens

     async def revoke_refresh_token(self, jti: str):
-        """Revoke a token given by its JWT ID"""
+        """Revoke a token given by its JWT ID."""
         await self.conn.execute(
             update(RefreshTokens)
             .where(RefreshTokens.jti == jti)
@@ -303,7 +294,7 @@ class AuthDB(BaseSQLDB):
         )

     async def revoke_user_refresh_tokens(self, subject):
-        """Revoke all the refresh tokens belonging to a user (subject ID)"""
+        """Revoke all the refresh tokens belonging to a user (subject ID)."""
         await self.conn.execute(
             update(RefreshTokens)
             .where(RefreshTokens.sub == subject)
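Beyond the docstring cleanup, the context lines above show the pattern `authorization_flow_insert_id_token` relies on: the code handed back to the client comes from `secrets.token_urlsafe()` and only its SHA-256 hash is stored. A self-contained sketch of that pattern; the two hashing lines match the diff, everything else is illustrative:

```python
import hashlib
import secrets

# Generate the authorization code that is handed back to the client...
code = secrets.token_urlsafe()
# ...but store only its hash, so a leaked DB row does not leak the code itself.
hashed_code = hashlib.sha256(code.encode()).hexdigest()

# When the client later presents the code, it is verified by re-hashing it.
assert hashlib.sha256(code.encode()).hexdigest() == hashed_code
```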
diracx/db/sql/auth/schema.py
CHANGED
@@ -15,12 +15,11 @@ Base = declarative_base()


 class FlowStatus(Enum):
-    """
-    The normal flow is
+    """The normal flow is
     PENDING -> READY -> DONE
     Pending is upon insertion
     Ready/Error is set in response to IdP
-    Done means the user has been issued the dirac token
+    Done means the user has been issued the dirac token.
     """

     # The flow is ongoing
@@ -64,9 +63,8 @@ class AuthorizationFlows(Base):


 class RefreshTokenStatus(Enum):
-    """
-
-    CREATED -> REVOKED
+    """The normal flow is
+    CREATED -> REVOKED.

     Note1: There is no EXPIRED status as it can be calculated from a creation time
     Note2: As part of the refresh token rotation mechanism, the revoked token should be retained
@@ -82,7 +80,7 @@ class RefreshTokenStatus(Enum):

 class RefreshTokens(Base):
     """Store attributes bound to a refresh token, as well as specific user attributes
-    that might be then used to generate access tokens
+    that might be then used to generate access tokens.
     """

     __tablename__ = "RefreshTokens"
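The reworked docstrings describe two small state machines: device/authorization flows move PENDING -> READY -> DONE (with an error branch), and refresh tokens move CREATED -> REVOKED. The sketch below only visualises the first of these transitions as stated in the docstring; the member values and the ERROR member's exact semantics are assumptions, not taken from the schema:

```python
from enum import Enum, auto

class FlowStatus(Enum):
    # Illustrative sketch of the states named in the docstring; values are assumed.
    PENDING = auto()  # set on insertion, the flow is ongoing
    READY = auto()    # set in response to the IdP
    ERROR = auto()    # set in response to the IdP on failure
    DONE = auto()     # the user has been issued the dirac token
```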
diracx/db/sql/dummy/db.py
CHANGED
@@ -11,8 +11,7 @@ from .schema import Cars, Owners


 class DummyDB(BaseSQLDB):
-    """
-    This DummyDB is just to illustrate some important aspect of writing
+    """This DummyDB is just to illustrate some important aspect of writing
     DB classes in DiracX.

     It is mostly pure SQLAlchemy, with a few convention
@@ -27,7 +26,7 @@ class DummyDB(BaseSQLDB):
         columns = [Cars.__table__.columns[x] for x in group_by]

         stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
-        stmt = apply_search_filters(Cars.__table__, stmt, search)
+        stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)

         # Execute the query
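The one functional change here is the first argument of `apply_search_filters`: callers now pass a callable that resolves a field name to a column (`Cars.__table__.columns.__getitem__`) rather than the `Table` object itself, matching the reworked helper in `diracx/db/sql/utils/__init__.py`. A standalone illustration of why that callable works as a resolver; the table below is a stand-in for the dummy schema, not the real one:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
# Stand-in for the dummy Cars table; column names are illustrative.
cars = Table(
    "Cars",
    metadata,
    Column("licensePlate", String(30), primary_key=True),
    Column("model", String(30)),
    Column("ownerID", Integer),
)

# Table.columns.__getitem__ maps a field name to its Column object,
# which is all the search-filter helper needs to build its WHERE clauses.
get_column = cars.columns.__getitem__
print(get_column("model"))  # -> Cars.model
```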