diracx-db 0.0.1a17__py3-none-any.whl → 0.0.1a19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/db/__main__.py +1 -0
- diracx/db/os/utils.py +60 -11
- diracx/db/sql/__init__.py +12 -2
- diracx/db/sql/auth/db.py +10 -19
- diracx/db/sql/auth/schema.py +5 -7
- diracx/db/sql/dummy/db.py +2 -3
- diracx/db/sql/{jobs → job}/db.py +12 -452
- diracx/db/sql/job/schema.py +129 -0
- diracx/db/sql/job_logging/__init__.py +0 -0
- diracx/db/sql/job_logging/db.py +161 -0
- diracx/db/sql/job_logging/schema.py +25 -0
- diracx/db/sql/pilot_agents/__init__.py +0 -0
- diracx/db/sql/pilot_agents/db.py +46 -0
- diracx/db/sql/pilot_agents/schema.py +58 -0
- diracx/db/sql/sandbox_metadata/db.py +12 -10
- diracx/db/sql/task_queue/__init__.py +0 -0
- diracx/db/sql/task_queue/db.py +261 -0
- diracx/db/sql/task_queue/schema.py +109 -0
- diracx/db/sql/utils/__init__.py +445 -0
- diracx/db/sql/{jobs/status_utility.py → utils/job_status.py} +11 -18
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/METADATA +5 -5
- diracx_db-0.0.1a19.dist-info/RECORD +36 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/WHEEL +1 -1
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/entry_points.txt +1 -0
- diracx/db/sql/jobs/schema.py +0 -290
- diracx/db/sql/utils.py +0 -236
- diracx_db-0.0.1a17.dist-info/RECORD +0 -27
- /diracx/db/sql/{jobs → job}/__init__.py +0 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/top_level.txt +0 -0
diracx/db/__main__.py
CHANGED
@@ -31,6 +31,7 @@ async def init_sql():
|
|
31
31
|
from diracx.db.sql.utils import BaseSQLDB
|
32
32
|
|
33
33
|
for db_name, db_url in BaseSQLDB.available_urls().items():
|
34
|
+
|
34
35
|
logger.info("Initialising %s", db_name)
|
35
36
|
db = BaseSQLDB.available_implementations(db_name)[0](db_url)
|
36
37
|
async with db.engine_context():
|
diracx/db/os/utils.py
CHANGED
@@ -7,9 +7,10 @@ import json
|
|
7
7
|
import logging
|
8
8
|
import os
|
9
9
|
from abc import ABCMeta, abstractmethod
|
10
|
+
from collections.abc import AsyncIterator
|
10
11
|
from contextvars import ContextVar
|
11
12
|
from datetime import datetime
|
12
|
-
from typing import Any,
|
13
|
+
from typing import Any, Self
|
13
14
|
|
14
15
|
from opensearchpy import AsyncOpenSearch
|
15
16
|
|
@@ -29,6 +30,48 @@ class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
|
|
29
30
|
|
30
31
|
|
31
32
|
class BaseOSDB(metaclass=ABCMeta):
|
33
|
+
"""This should be the base class of all the OpenSearch DiracX DBs.
|
34
|
+
|
35
|
+
The details covered here should be handled automatically by the service and
|
36
|
+
task machinery of DiracX and this documentation exists for informational
|
37
|
+
purposes.
|
38
|
+
|
39
|
+
The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
|
40
|
+
This method returns a dictionary of database names to connection parameters.
|
41
|
+
The available databases are determined by the `diracx.db.os` entrypoint in
|
42
|
+
the `pyproject.toml` file and the connection parameters are taken from the
|
43
|
+
environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
|
44
|
+
|
45
|
+
If extensions to DiracX are being used, there can be multiple implementations
|
46
|
+
of the same database. To list the available implementations use
|
47
|
+
`BaseOSDB.available_implementations(db_name)`. The first entry in this list
|
48
|
+
will be the preferred implementation and it can be initialized by calling
|
49
|
+
its `__init__` function with the connection parameters previously obtained
|
50
|
+
from `BaseOSDB.available_urls`.
|
51
|
+
|
52
|
+
To control the lifetime of the OpenSearch client, the `BaseOSDB.client_context`
|
53
|
+
asynchronous context manager should be entered. When inside this context
|
54
|
+
manager, the client can be accessed with `BaseOSDB.client`.
|
55
|
+
|
56
|
+
Upon entering, the DB class can then be used as an asynchronous context
|
57
|
+
manager to perform operations. Currently this context manager has no effect;
|
58
|
+
however, it must still be used as it may be needed in future. When inside this
|
59
|
+
context manager, the DB connection can be accessed with `BaseOSDB.client`.
|
60
|
+
|
61
|
+
For example:
|
62
|
+
|
63
|
+
```python
|
64
|
+
db_name = ...
|
65
|
+
conn_params = BaseOSDB.available_urls()[db_name]
|
66
|
+
MyDBClass = BaseOSDB.available_implementations(db_name)[0]
|
67
|
+
|
68
|
+
db = MyDBClass(conn_params)
|
69
|
+
async with db.client_context():
|
70
|
+
async with db:
|
71
|
+
# Do something with the OpenSearch client
|
72
|
+
```
|
73
|
+
"""
|
74
|
+
|
32
75
|
# TODO: Make metadata an abstract property
|
33
76
|
fields: dict
|
34
77
|
index_prefix: str
|
@@ -77,13 +120,15 @@ class BaseOSDB(metaclass=ABCMeta):
|
|
77
120
|
@classmethod
|
78
121
|
def session(cls) -> Self:
|
79
122
|
"""This is just a fake method such that the Dependency overwrite has
|
80
|
-
a hash to use
|
123
|
+
a hash to use.
|
124
|
+
"""
|
81
125
|
raise NotImplementedError("This should never be called")
|
82
126
|
|
83
127
|
@property
|
84
128
|
def client(self) -> AsyncOpenSearch:
|
85
129
|
"""Just a getter for _client, making sure we entered
|
86
|
-
the context manager
|
130
|
+
the context manager.
|
131
|
+
"""
|
87
132
|
if self._client is None:
|
88
133
|
raise RuntimeError(f"{self.__class__} was used before entering")
|
89
134
|
return self._client
|
@@ -91,17 +136,18 @@ class BaseOSDB(metaclass=ABCMeta):
|
|
91
136
|
@contextlib.asynccontextmanager
|
92
137
|
async def client_context(self) -> AsyncIterator[None]:
|
93
138
|
"""Context manager to manage the client lifecycle.
|
94
|
-
This is called when starting fastapi
|
139
|
+
This is called when starting fastapi.
|
95
140
|
|
96
141
|
"""
|
97
142
|
assert self._client is None, "client_context cannot be nested"
|
98
143
|
async with AsyncOpenSearch(**self._connection_kwargs) as self._client:
|
99
|
-
|
100
|
-
|
144
|
+
try:
|
145
|
+
yield
|
146
|
+
finally:
|
147
|
+
self._client = None
|
101
148
|
|
102
149
|
async def ping(self):
|
103
|
-
"""
|
104
|
-
Check whether the connection to the DB is still working.
|
150
|
+
"""Check whether the connection to the DB is still working.
|
105
151
|
We could enable the ``pre_ping`` in the engine, but this would
|
106
152
|
be run at every query.
|
107
153
|
"""
|
@@ -113,7 +159,7 @@ class BaseOSDB(metaclass=ABCMeta):
|
|
113
159
|
async def __aenter__(self):
|
114
160
|
"""This is entered on every request.
|
115
161
|
At the moment it does nothing, however, we keep it here
|
116
|
-
in case we ever want to use OpenSearch equivalent of a transaction
|
162
|
+
in case we ever want to use OpenSearch equivalent of a transaction.
|
117
163
|
"""
|
118
164
|
assert not self._conn.get(), "BaseOSDB context cannot be nested"
|
119
165
|
assert self._client is not None, "client_context hasn't been entered"
|
@@ -122,9 +168,7 @@ class BaseOSDB(metaclass=ABCMeta):
|
|
122
168
|
|
123
169
|
async def __aexit__(self, exc_type, exc, tb):
|
124
170
|
assert self._conn.get()
|
125
|
-
self._client = None
|
126
171
|
self._conn.set(False)
|
127
|
-
return
|
128
172
|
|
129
173
|
async def create_index_template(self) -> None:
|
130
174
|
template_body = {
|
@@ -237,6 +281,11 @@ def apply_search_filters(db_fields, search):
|
|
237
281
|
operator, field_name, field_type, {"keyword", "long", "date"}
|
238
282
|
)
|
239
283
|
result["must"].append({"terms": {field_name: query["values"]}})
|
284
|
+
case "not in":
|
285
|
+
require_type(
|
286
|
+
operator, field_name, field_type, {"keyword", "long", "date"}
|
287
|
+
)
|
288
|
+
result["must_not"].append({"terms": {field_name: query["values"]}})
|
240
289
|
# TODO: Implement like and ilike
|
241
290
|
# If the pattern is a simple "col like 'abc%'", we can use a prefix query
|
242
291
|
# Else we need to use a wildcard query where we replace % with * and _ with ?
|
diracx/db/sql/__init__.py
CHANGED
@@ -1,7 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
__all__ = (
|
3
|
+
__all__ = (
|
4
|
+
"AuthDB",
|
5
|
+
"JobDB",
|
6
|
+
"JobLoggingDB",
|
7
|
+
"PilotAgentsDB",
|
8
|
+
"SandboxMetadataDB",
|
9
|
+
"TaskQueueDB",
|
10
|
+
)
|
4
11
|
|
5
12
|
from .auth.db import AuthDB
|
6
|
-
from .
|
13
|
+
from .job.db import JobDB
|
14
|
+
from .job_logging.db import JobLoggingDB
|
15
|
+
from .pilot_agents.db import PilotAgentsDB
|
7
16
|
from .sandbox_metadata.db import SandboxMetadataDB
|
17
|
+
from .task_queue.db import TaskQueueDB
|
diracx/db/sql/auth/db.py
CHANGED
@@ -35,7 +35,7 @@ class AuthDB(BaseSQLDB):
|
|
35
35
|
async def device_flow_validate_user_code(
|
36
36
|
self, user_code: str, max_validity: int
|
37
37
|
) -> str:
|
38
|
-
"""Validate that the user_code can be used (Pending status, not expired)
|
38
|
+
"""Validate that the user_code can be used (Pending status, not expired).
|
39
39
|
|
40
40
|
Returns the scope field for the given user_code
|
41
41
|
|
@@ -51,9 +51,7 @@ class AuthDB(BaseSQLDB):
|
|
51
51
|
return (await self.conn.execute(stmt)).scalar_one()
|
52
52
|
|
53
53
|
async def get_device_flow(self, device_code: str, max_validity: int):
|
54
|
-
"""
|
55
|
-
:raises: NoResultFound
|
56
|
-
"""
|
54
|
+
""":raises: NoResultFound"""
|
57
55
|
# The with_for_update
|
58
56
|
# prevents that the token is retrieved
|
59
57
|
# multiple times concurrently
|
@@ -94,9 +92,7 @@ class AuthDB(BaseSQLDB):
|
|
94
92
|
async def device_flow_insert_id_token(
|
95
93
|
self, user_code: str, id_token: dict[str, str], max_validity: int
|
96
94
|
) -> None:
|
97
|
-
"""
|
98
|
-
:raises: AuthorizationError if no such code or status not pending
|
99
|
-
"""
|
95
|
+
""":raises: AuthorizationError if no such code or status not pending"""
|
100
96
|
stmt = update(DeviceFlows)
|
101
97
|
stmt = stmt.where(
|
102
98
|
DeviceFlows.user_code == user_code,
|
@@ -170,11 +166,9 @@ class AuthDB(BaseSQLDB):
|
|
170
166
|
async def authorization_flow_insert_id_token(
|
171
167
|
self, uuid: str, id_token: dict[str, str], max_validity: int
|
172
168
|
) -> tuple[str, str]:
|
169
|
+
"""Returns code, redirect_uri
|
170
|
+
:raises: AuthorizationError if no such uuid or status not pending.
|
173
171
|
"""
|
174
|
-
returns code, redirect_uri
|
175
|
-
:raises: AuthorizationError if no such uuid or status not pending
|
176
|
-
"""
|
177
|
-
|
178
172
|
# Hash the code to avoid leaking information
|
179
173
|
code = secrets.token_urlsafe()
|
180
174
|
hashed_code = hashlib.sha256(code.encode()).hexdigest()
|
@@ -232,8 +226,7 @@ class AuthDB(BaseSQLDB):
|
|
232
226
|
preferred_username: str,
|
233
227
|
scope: str,
|
234
228
|
) -> tuple[str, datetime]:
|
235
|
-
"""
|
236
|
-
Insert a refresh token in the DB as well as user attributes
|
229
|
+
"""Insert a refresh token in the DB as well as user attributes
|
237
230
|
required to generate access tokens.
|
238
231
|
"""
|
239
232
|
# Generate a JWT ID
|
@@ -257,9 +250,7 @@ class AuthDB(BaseSQLDB):
|
|
257
250
|
return jti, row.creation_time
|
258
251
|
|
259
252
|
async def get_refresh_token(self, jti: str) -> dict:
|
260
|
-
"""
|
261
|
-
Get refresh token details bound to a given JWT ID
|
262
|
-
"""
|
253
|
+
"""Get refresh token details bound to a given JWT ID."""
|
263
254
|
# The with_for_update
|
264
255
|
# prevents that the token is retrieved
|
265
256
|
# multiple times concurrently
|
@@ -275,7 +266,7 @@ class AuthDB(BaseSQLDB):
|
|
275
266
|
return res
|
276
267
|
|
277
268
|
async def get_user_refresh_tokens(self, subject: str | None = None) -> list[dict]:
|
278
|
-
"""Get a list of refresh token details based on a subject ID (not revoked)"""
|
269
|
+
"""Get a list of refresh token details based on a subject ID (not revoked)."""
|
279
270
|
# Get a list of refresh tokens
|
280
271
|
stmt = select(RefreshTokens).with_for_update()
|
281
272
|
|
@@ -295,7 +286,7 @@ class AuthDB(BaseSQLDB):
|
|
295
286
|
return refresh_tokens
|
296
287
|
|
297
288
|
async def revoke_refresh_token(self, jti: str):
|
298
|
-
"""Revoke a token given by its JWT ID"""
|
289
|
+
"""Revoke a token given by its JWT ID."""
|
299
290
|
await self.conn.execute(
|
300
291
|
update(RefreshTokens)
|
301
292
|
.where(RefreshTokens.jti == jti)
|
@@ -303,7 +294,7 @@ class AuthDB(BaseSQLDB):
|
|
303
294
|
)
|
304
295
|
|
305
296
|
async def revoke_user_refresh_tokens(self, subject):
|
306
|
-
"""Revoke all the refresh tokens belonging to a user (subject ID)"""
|
297
|
+
"""Revoke all the refresh tokens belonging to a user (subject ID)."""
|
307
298
|
await self.conn.execute(
|
308
299
|
update(RefreshTokens)
|
309
300
|
.where(RefreshTokens.sub == subject)
|
diracx/db/sql/auth/schema.py
CHANGED
@@ -15,12 +15,11 @@ Base = declarative_base()
|
|
15
15
|
|
16
16
|
|
17
17
|
class FlowStatus(Enum):
|
18
|
-
"""
|
19
|
-
The normal flow is
|
18
|
+
"""The normal flow is
|
20
19
|
PENDING -> READY -> DONE
|
21
20
|
Pending is upon insertion
|
22
21
|
Ready/Error is set in response to IdP
|
23
|
-
Done means the user has been issued the dirac token
|
22
|
+
Done means the user has been issued the dirac token.
|
24
23
|
"""
|
25
24
|
|
26
25
|
# The flow is ongoing
|
@@ -64,9 +63,8 @@ class AuthorizationFlows(Base):
|
|
64
63
|
|
65
64
|
|
66
65
|
class RefreshTokenStatus(Enum):
|
67
|
-
"""
|
68
|
-
|
69
|
-
CREATED -> REVOKED
|
66
|
+
"""The normal flow is
|
67
|
+
CREATED -> REVOKED.
|
70
68
|
|
71
69
|
Note1: There is no EXPIRED status as it can be calculated from a creation time
|
72
70
|
Note2: As part of the refresh token rotation mechanism, the revoked token should be retained
|
@@ -82,7 +80,7 @@ class RefreshTokenStatus(Enum):
|
|
82
80
|
|
83
81
|
class RefreshTokens(Base):
|
84
82
|
"""Store attributes bound to a refresh token, as well as specific user attributes
|
85
|
-
that might be then used to generate access tokens
|
83
|
+
that might be then used to generate access tokens.
|
86
84
|
"""
|
87
85
|
|
88
86
|
__tablename__ = "RefreshTokens"
|
diracx/db/sql/dummy/db.py
CHANGED
@@ -11,8 +11,7 @@ from .schema import Cars, Owners
|
|
11
11
|
|
12
12
|
|
13
13
|
class DummyDB(BaseSQLDB):
|
14
|
-
"""
|
15
|
-
This DummyDB is just to illustrate some important aspect of writing
|
14
|
+
"""This DummyDB is just to illustrate some important aspects of writing
|
16
15
|
DB classes in DiracX.
|
17
16
|
|
18
17
|
It is mostly pure SQLAlchemy, with a few conventions
|
@@ -27,7 +26,7 @@ class DummyDB(BaseSQLDB):
|
|
27
26
|
columns = [Cars.__table__.columns[x] for x in group_by]
|
28
27
|
|
29
28
|
stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
|
30
|
-
stmt = apply_search_filters(Cars.__table__, stmt, search)
|
29
|
+
stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
|
31
30
|
stmt = stmt.group_by(*columns)
|
32
31
|
|
33
32
|
# Execute the query
|