diracx-db 0.0.1a21__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/db/__main__.py +1 -1
- diracx/db/exceptions.py +4 -1
- diracx/db/os/job_parameters.py +25 -7
- diracx/db/os/utils.py +18 -11
- diracx/db/sql/auth/db.py +113 -78
- diracx/db/sql/auth/schema.py +32 -24
- diracx/db/sql/dummy/db.py +5 -17
- diracx/db/sql/dummy/schema.py +8 -6
- diracx/db/sql/job/db.py +155 -205
- diracx/db/sql/job/schema.py +115 -59
- diracx/db/sql/job_logging/db.py +60 -143
- diracx/db/sql/job_logging/schema.py +54 -15
- diracx/db/sql/pilot_agents/db.py +0 -1
- diracx/db/sql/pilot_agents/schema.py +26 -23
- diracx/db/sql/sandbox_metadata/db.py +164 -57
- diracx/db/sql/sandbox_metadata/schema.py +9 -4
- diracx/db/sql/task_queue/db.py +44 -125
- diracx/db/sql/task_queue/schema.py +2 -0
- diracx/db/sql/utils/__init__.py +29 -451
- diracx/db/sql/utils/base.py +461 -0
- diracx/db/sql/utils/functions.py +142 -0
- diracx/db/sql/utils/types.py +137 -0
- {diracx_db-0.0.1a21.dist-info → diracx_db-0.0.6.dist-info}/METADATA +8 -6
- diracx_db-0.0.6.dist-info/RECORD +37 -0
- {diracx_db-0.0.1a21.dist-info → diracx_db-0.0.6.dist-info}/WHEEL +1 -2
- {diracx_db-0.0.1a21.dist-info → diracx_db-0.0.6.dist-info}/entry_points.txt +2 -2
- diracx/db/sql/utils/job.py +0 -574
- diracx_db-0.0.1a21.dist-info/RECORD +0 -36
- diracx_db-0.0.1a21.dist-info/top_level.txt +0 -1
diracx/db/__main__.py
CHANGED
@@ -31,7 +31,6 @@ async def init_sql():
     from diracx.db.sql.utils import BaseSQLDB

     for db_name, db_url in BaseSQLDB.available_urls().items():
-
         logger.info("Initialising %s", db_name)
         db = BaseSQLDB.available_implementations(db_name)[0](db_url)
         async with db.engine_context():
@@ -40,6 +39,7 @@ async def init_sql():
                 if db._db_url.startswith("sqlite"):
                     await conn.exec_driver_sql("PRAGMA foreign_keys=ON")
                 await conn.run_sync(db.metadata.create_all)
+                await db.post_create(conn)


 async def init_os():
diracx/db/exceptions.py
CHANGED
diracx/db/os/job_parameters.py
CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from datetime import UTC, datetime
+
 from diracx.db.os.utils import BaseOSDB


@@ -7,19 +9,35 @@ class JobParametersDB(BaseOSDB):
     fields = {
         "JobID": {"type": "long"},
         "timestamp": {"type": "date"},
+        "PilotAgent": {"type": "keyword"},
+        "Pilot_Reference": {"type": "keyword"},
+        "JobGroup": {"type": "keyword"},
         "CPUNormalizationFactor": {"type": "long"},
         "NormCPUTime(s)": {"type": "long"},
-        "Memory(…
+        "Memory(MB)": {"type": "long"},
+        "LocalAccount": {"type": "keyword"},
         "TotalCPUTime(s)": {"type": "long"},
-        "…
-        "HostName": {"type": "…
+        "PayloadPID": {"type": "long"},
+        "HostName": {"type": "text"},
         "GridCE": {"type": "keyword"},
+        "CEQueue": {"type": "keyword"},
+        "BatchSystem": {"type": "keyword"},
         "ModelName": {"type": "keyword"},
         "Status": {"type": "keyword"},
         "JobType": {"type": "keyword"},
     }
-
+    # TODO: Does this need to be configurable?
+    index_prefix = "job_parameters"
+
+    def index_name(self, vo, doc_id: int) -> str:
+        split = int(int(doc_id) // 1e6)
+        # The index name must be lowercase or opensearchpy will throw.
+        return f"{self.index_prefix}_{vo.lower()}_{split}m"

-    def …
-
-
+    def upsert(self, vo, doc_id, document):
+        document = {
+            "JobID": doc_id,
+            "timestamp": int(datetime.now(tz=UTC).timestamp() * 1000),
+            **document,
+        }
+        return super().upsert(vo, doc_id, document)
diracx/db/os/utils.py
CHANGED
@@ -16,7 +16,7 @@ from opensearchpy import AsyncOpenSearch

 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.extensions import select_from_extension
-from diracx.db.exceptions import …
+from diracx.db.exceptions import DBUnavailableError

 logger = logging.getLogger(__name__)

@@ -25,7 +25,7 @@ class OpenSearchDBError(Exception):
     pass


-class …
+class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError):
     pass


@@ -38,7 +38,7 @@ class BaseOSDB(metaclass=ABCMeta):

     The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
     This method returns a dictionary of database names to connection parameters.
-    The available databases are determined by the `diracx.…
+    The available databases are determined by the `diracx.dbs.os` entrypoint in
     the `pyproject.toml` file and the connection parameters are taken from the
     environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.

@@ -77,7 +77,7 @@ class BaseOSDB(metaclass=ABCMeta):
     index_prefix: str

     @abstractmethod
-    def index_name(self, doc_id: int) -> str: ...
+    def index_name(self, vo: str, doc_id: int) -> str: ...

     def __init__(self, connection_kwargs: dict[str, Any]) -> None:
         self._client: AsyncOpenSearch | None = None
@@ -92,7 +92,9 @@ class BaseOSDB(metaclass=ABCMeta):
         """Return the available implementations of the DB in reverse priority order."""
         db_classes: list[type[BaseOSDB]] = [
             entry_point.load()
-            for entry_point in select_from_extension(…
+            for entry_point in select_from_extension(
+                group="diracx.dbs.os", name=db_name
+            )
         ]
         if not db_classes:
             raise NotImplementedError(f"Could not find any matches for {db_name=}")
@@ -106,7 +108,7 @@ class BaseOSDB(metaclass=ABCMeta):
         prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
         """
         conn_kwargs: dict[str, dict[str, Any]] = {}
-        for entry_point in select_from_extension(group="diracx.…
+        for entry_point in select_from_extension(group="diracx.dbs.os"):
             db_name = entry_point.name
             var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
             if var_name in os.environ:
@@ -152,7 +154,7 @@ class BaseOSDB(metaclass=ABCMeta):
         be ran at every query.
         """
         if not await self.client.ping():
-            raise …
+            raise OpenSearchDBUnavailableError(
                 f"Failed to connect to {self.__class__.__qualname__}"
             )

@@ -180,15 +182,20 @@ class BaseOSDB(metaclass=ABCMeta):
         )
         assert result["acknowledged"]

-    async def upsert(self, doc_id, document) -> None:
-        …
+    async def upsert(self, vo: str, doc_id: int, document: Any) -> None:
+        index_name = self.index_name(vo, doc_id)
         response = await self.client.update(
-            index=…
+            index=index_name,
             id=doc_id,
             body={"doc": document, "doc_as_upsert": True},
             params=dict(retry_on_conflict=10),
         )
-        …
+        logger.debug(
+            "Upserted document %s in index %s with response: %s",
+            doc_id,
+            index_name,
+            response,
+        )

     async def search(
         self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None
diracx/db/sql/auth/db.py
CHANGED
@@ -1,19 +1,21 @@
 from __future__ import annotations

-import …
+import logging
 import secrets
-from datetime import datetime
-from …
+from datetime import UTC, datetime
+from itertools import pairwise

-from …
+from dateutil.rrule import MONTHLY, rrule
+from sqlalchemy import insert, select, text, update
 from sqlalchemy.exc import IntegrityError, NoResultFound
+from sqlalchemy.ext.asyncio import AsyncConnection
+from uuid_utils import UUID, uuid7

 from diracx.core.exceptions import (
     AuthorizationError,
-
-    PendingAuthorizationError,
+    TokenNotFoundError,
 )
-from diracx.db.sql.utils import BaseSQLDB, substract_date
+from diracx.db.sql.utils import BaseSQLDB, hash, substract_date, uuid7_from_datetime

 from .schema import (
     AuthorizationFlows,
@@ -28,10 +30,72 @@ from .schema import Base as AuthDBBase
 USER_CODE_ALPHABET = "BCDFGHJKLMNPQRSTVWXZ"
 MAX_RETRY = 5

+logger = logging.getLogger(__name__)
+

 class AuthDB(BaseSQLDB):
     metadata = AuthDBBase.metadata

+    @classmethod
+    async def post_create(cls, conn: AsyncConnection) -> None:
+        """Create partitions if it is a MySQL DB and it does not have
+        it yet and the table does not have any data yet.
+        We do this as a post_create step as sqlalchemy does not support
+        partition so well.
+        """
+        if conn.dialect.name == "mysql":
+            check_partition_query = text(
+                "SELECT PARTITION_NAME FROM information_schema.partitions "
+                "WHERE TABLE_NAME = 'RefreshTokens' AND PARTITION_NAME is not NULL"
+            )
+            partition_names = (await conn.execute(check_partition_query)).all()
+
+            if not partition_names:
+                # Create a monthly partition from today until 2 years
+                # The partition are named p_<year>_<month>
+                start_date = datetime.now(tz=UTC).replace(
+                    day=1, hour=0, minute=0, second=0, microsecond=0
+                )
+                end_date = start_date.replace(year=start_date.year + 2)
+
+                dates = [
+                    dt for dt in rrule(MONTHLY, dtstart=start_date, until=end_date)
+                ]
+
+                partition_list = []
+                for name, limit in pairwise(dates):
+                    partition_list.append(
+                        f"PARTITION p_{name.year}_{name.month} "
+                        f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
+                    )
+                partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")
+
+                alter_query = text(
+                    f"ALTER TABLE RefreshTokens PARTITION BY RANGE COLUMNS (JTI) ({','.join(partition_list)})"
+                )
+
+                check_table_empty_query = text("SELECT * FROM RefreshTokens LIMIT 1")
+                refresh_table_content = (
+                    await conn.execute(check_table_empty_query)
+                ).all()
+                if refresh_table_content:
+                    logger.warning(
+                        "RefreshTokens table not empty. Run the following query yourself"
+                    )
+                    logger.warning(alter_query)
+                    return
+
+                await conn.execute(alter_query)
+
+                partition_names = (
+                    await conn.execute(
+                        check_partition_query, {"table_name": "RefreshTokens"}
+                    )
+                ).all()
+                assert partition_names, (
+                    f"There should be partitions now {partition_names}"
+                )
+
     async def device_flow_validate_user_code(
         self, user_code: str, max_validity: int
     ) -> str:
@@ -50,44 +114,25 @@ class AuthDB(BaseSQLDB):

         return (await self.conn.execute(stmt)).scalar_one()

-    async def get_device_flow(self, device_code: str…
+    async def get_device_flow(self, device_code: str):
         """:raises: NoResultFound"""
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
-        stmt = select(
-            DeviceFlows,
-            (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
-                "is_expired"
-            ),
-        ).with_for_update()
+        stmt = select(DeviceFlows).with_for_update()
         stmt = stmt.where(
-            DeviceFlows.device_code == …
+            DeviceFlows.device_code == hash(device_code),
         )
-
-
-        if res["is_expired"]:
-            raise ExpiredFlowError()
-
-        if res["status"] == FlowStatus.READY:
-            # Update the status to Done before returning
-            await self.conn.execute(
-                update(DeviceFlows)
-                .where(
-                    DeviceFlows.device_code
-                    == hashlib.sha256(device_code.encode()).hexdigest()
-                )
-                .values(status=FlowStatus.DONE)
-            )
-            return res
-
-        if res["status"] == FlowStatus.DONE:
-            raise AuthorizationError("Code was already used")
+        return dict((await self.conn.execute(stmt)).one()._mapping)

-
-
-
-
+    async def update_device_flow_status(
+        self, device_code: str, status: FlowStatus
+    ) -> None:
+        stmt = update(DeviceFlows).where(
+            DeviceFlows.device_code == hash(device_code),
+        )
+        stmt = stmt.values(status=status)
+        await self.conn.execute(stmt)

     async def device_flow_insert_id_token(
         self, user_code: str, id_token: dict[str, str], max_validity: int
@@ -121,7 +166,7 @@ class AuthDB(BaseSQLDB):
         device_code = secrets.token_urlsafe()

         # Hash the the device_code to avoid leaking information
-        hashed_device_code = …
+        hashed_device_code = hash(device_code)

         stmt = insert(DeviceFlows).values(
             client_id=client_id,
@@ -133,6 +178,10 @@ class AuthDB(BaseSQLDB):
                 await self.conn.execute(stmt)

             except IntegrityError:
+                logger.warning(
+                    "Device flow code collision detected, retrying (user_code=%s)",
+                    user_code,
+                )
                 continue

         return user_code, device_code
@@ -148,7 +197,7 @@ class AuthDB(BaseSQLDB):
         code_challenge_method: str,
         redirect_uri: str,
     ) -> str:
-        uuid = str(…
+        uuid = str(uuid7())

         stmt = insert(AuthorizationFlows).values(
             uuid=uuid,
@@ -171,7 +220,7 @@ class AuthDB(BaseSQLDB):
         """
         # Hash the code to avoid leaking information
         code = secrets.token_urlsafe()
-        hashed_code = …
+        hashed_code = hash(code)

         stmt = update(AuthorizationFlows)

@@ -190,10 +239,11 @@ class AuthDB(BaseSQLDB):
         stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
         stmt = stmt.where(AuthorizationFlows.uuid == uuid)
         row = (await self.conn.execute(stmt)).one()
-        return code, row.…
+        return code, row.RedirectURI

     async def get_authorization_flow(self, code: str, max_validity: int):
-        …
+        """Get the authorization flow details based on the code."""
+        hashed_code = hash(code)
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -203,54 +253,39 @@ class AuthDB(BaseSQLDB):
             AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
         )

-
-
-        if res["status"] == FlowStatus.READY:
-            # Update the status to Done before returning
-            await self.conn.execute(
-                update(AuthorizationFlows)
-                .where(AuthorizationFlows.code == hashed_code)
-                .values(status=FlowStatus.DONE)
-            )
+        return dict((await self.conn.execute(stmt)).one()._mapping)

-
-
-
-
-
-
+    async def update_authorization_flow_status(
+        self, code: str, status: FlowStatus
+    ) -> None:
+        """Update the status of an authorization flow based on the code."""
+        hashed_code = hash(code)
+        await self.conn.execute(
+            update(AuthorizationFlows)
+            .where(AuthorizationFlows.code == hashed_code)
+            .values(status=status)
+        )

     async def insert_refresh_token(
         self,
+        jti: UUID,
         subject: str,
-        preferred_username: str,
         scope: str,
-    ) -> …
+    ) -> None:
         """Insert a refresh token in the DB as well as user attributes
         required to generate access tokens.
         """
-        # Generate a JWT ID
-        jti = str(uuid4())
-
         # Insert values into the DB
         stmt = insert(RefreshTokens).values(
-            jti=jti,
+            jti=str(jti),
             sub=subject,
-            preferred_username=preferred_username,
             scope=scope,
         )
         await self.conn.execute(stmt)

-
-        stmt = select(RefreshTokens.creation_time)
-        stmt = stmt.where(RefreshTokens.jti == jti)
-        row = (await self.conn.execute(stmt)).one()
-
-        # Return the JWT ID and the creation time
-        return jti, row.creation_time
-
-    async def get_refresh_token(self, jti: str) -> dict:
+    async def get_refresh_token(self, jti: UUID) -> dict:
         """Get refresh token details bound to a given JWT ID."""
+        jti = str(jti)
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -260,8 +295,8 @@ class AuthDB(BaseSQLDB):
         )
         try:
             res = dict((await self.conn.execute(stmt)).one()._mapping)
-        except NoResultFound:
-            …
+        except NoResultFound as e:
+            raise TokenNotFoundError(jti) from e

         return res

@@ -285,11 +320,11 @@ class AuthDB(BaseSQLDB):

         return refresh_tokens

-    async def revoke_refresh_token(self, jti: …
+    async def revoke_refresh_token(self, jti: UUID):
         """Revoke a token given by its JWT ID."""
         await self.conn.execute(
             update(RefreshTokens)
-            .where(RefreshTokens.jti == jti)
+            .where(RefreshTokens.jti == str(jti))
             .values(status=RefreshTokenStatus.REVOKED)
         )
diracx/db/sql/auth/schema.py
CHANGED
@@ -1,13 +1,21 @@
+from __future__ import annotations
+
 from enum import Enum, auto

 from sqlalchemy import (
     JSON,
+    Index,
     String,
     Uuid,
 )
 from sqlalchemy.orm import declarative_base

-from diracx.db.sql.utils import …
+from diracx.db.sql.utils import (
+    Column,
+    DateNowColumn,
+    EnumColumn,
+    NullColumn,
+)

 USER_CODE_LENGTH = 8

@@ -39,27 +47,27 @@ class FlowStatus(Enum):

 class DeviceFlows(Base):
     __tablename__ = "DeviceFlows"
-    user_code = Column(String(USER_CODE_LENGTH), primary_key=True)
-    status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
-    creation_time = DateNowColumn()
-    client_id = Column(String(255))
-    scope = Column(String(1024))
-    device_code = Column(String(128), unique=True)  # Should be a hash
-    id_token = NullColumn(JSON())
+    user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
+    status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
+    creation_time = DateNowColumn("CreationTime")
+    client_id = Column("ClientID", String(255))
+    scope = Column("Scope", String(1024))
+    device_code = Column("DeviceCode", String(128), unique=True)  # Should be a hash
+    id_token = NullColumn("IDToken", JSON())


 class AuthorizationFlows(Base):
     __tablename__ = "AuthorizationFlows"
-    uuid = Column(Uuid(as_uuid=False), primary_key=True)
-    status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
-    client_id = Column(String(255))
-    creation_time = DateNowColumn()
-    scope = Column(String(1024))
-    code_challenge = Column(String(255))
-    code_challenge_method = Column(String(8))
-    redirect_uri = Column(String(255))
-    code = NullColumn(String(255))  # Should be a hash
-    id_token = NullColumn(JSON())
+    uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
+    status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
+    client_id = Column("ClientID", String(255))
+    creation_time = DateNowColumn("CreationTime")
+    scope = Column("Scope", String(1024))
+    code_challenge = Column("CodeChallenge", String(255))
+    code_challenge_method = Column("CodeChallengeMethod", String(8))
+    redirect_uri = Column("RedirectURI", String(255))
+    code = NullColumn("Code", String(255))  # Should be a hash
+    id_token = NullColumn("IDToken", JSON())


 class RefreshTokenStatus(Enum):
@@ -85,13 +93,13 @@ class RefreshTokens(Base):

     __tablename__ = "RefreshTokens"
     # Refresh token attributes
-    jti = Column(Uuid(as_uuid=False), primary_key=True)
+    jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
     status = EnumColumn(
-        RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
+        "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
     )
-
-    scope = Column(String(1024))
+    scope = Column("Scope", String(1024))

     # User attributes bound to the refresh token
-    sub = Column(String(…
-    …
+    sub = Column("Sub", String(256), index=True)
+
+    __table_args__ = (Index("index_status_sub", status, sub),)
diracx/db/sql/dummy/db.py
CHANGED
@@ -1,10 +1,9 @@
 from __future__ import annotations

-from …
+from sqlalchemy import insert
+from uuid_utils import UUID

-from …
-
-from diracx.db.sql.utils import BaseSQLDB, apply_search_filters
+from diracx.db.sql.utils import BaseSQLDB

 from .schema import Base as DummyDBBase
 from .schema import Cars, Owners
@@ -23,18 +22,7 @@ class DummyDB(BaseSQLDB):
     metadata = DummyDBBase.metadata

     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
-
-
-        stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
-        stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
-        stmt = stmt.group_by(*columns)
-
-        # Execute the query
-        return [
-            dict(row._mapping)
-            async for row in (await self.conn.stream(stmt))
-            if row.count > 0  # type: ignore
-        ]
+        return await self._summary(Cars, group_by, search)

     async def insert_owner(self, name: str) -> int:
         stmt = insert(Owners).values(name=name)
@@ -44,7 +32,7 @@ class DummyDB(BaseSQLDB):

     async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
         stmt = insert(Cars).values(
-            …
+            license_plate=license_plate, model=model, owner_id=owner_id
         )

         result = await self.conn.execute(stmt)
diracx/db/sql/dummy/schema.py
CHANGED
@@ -1,5 +1,7 @@
 # The utils class define some boilerplate types that should be used
 # in place of the SQLAlchemy one. Have a look at them
+from __future__ import annotations
+
 from sqlalchemy import ForeignKey, Integer, String, Uuid
 from sqlalchemy.orm import declarative_base

@@ -10,13 +12,13 @@ Base = declarative_base()

 class Owners(Base):
     __tablename__ = "Owners"
-
-    creation_time = DateNowColumn()
-    name = Column(String(255))
+    owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
+    creation_time = DateNowColumn("CreationTime")
+    name = Column("Name", String(255))


 class Cars(Base):
     __tablename__ = "Cars"
-
-    model = Column(String(255))
-
+    license_plate = Column("LicensePlate", Uuid(), primary_key=True)
+    model = Column("Model", String(255))
+    owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))