diracx-db 0.0.1a26__tar.gz → 0.0.1a29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/PKG-INFO +1 -3
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/pyproject.toml +2 -4
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/os/utils.py +5 -3
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/auth/db.py +39 -73
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job/db.py +30 -123
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job_logging/db.py +21 -81
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/sandbox_metadata/db.py +23 -16
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/task_queue/db.py +43 -124
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/utils/__init__.py +2 -1
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/utils/base.py +4 -4
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/utils/functions.py +7 -1
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/PKG-INFO +1 -3
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/SOURCES.txt +0 -1
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/entry_points.txt +2 -2
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/requires.txt +0 -2
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/auth/test_authorization_flow.py +0 -5
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/auth/test_device_flow.py +11 -20
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/auth/test_refresh_token.py +22 -5
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/jobs/test_job_db.py +33 -81
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/jobs/test_job_logging_db.py +30 -22
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/jobs/test_sandbox_metadata.py +16 -5
- diracx_db-0.0.1a26/src/diracx/db/sql/utils/job.py +0 -578
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/README.md +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/setup.cfg +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/__main__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/exceptions.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/os/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/os/job_parameters.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/py.typed +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/auth/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/auth/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/dummy/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/dummy/db.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/dummy/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job_logging/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job_logging/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/pilot_agents/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/pilot_agents/db.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/pilot_agents/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/sandbox_metadata/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/sandbox_metadata/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/task_queue/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/task_queue/schema.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/utils/types.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/dependency_links.txt +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx_db.egg-info/top_level.txt +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/opensearch/test_connection.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/opensearch/test_index_template.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/opensearch/test_search.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/pilot_agents/__init__.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/pilot_agents/test_pilot_agents_db.py +0 -0
- {diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/tests/test_dummy_db.py +0 -0
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: diracx-db
-Version: 0.0.1a26
+Version: 0.0.1a29
 Summary: TODO
 License: GPL-3.0-only
 Classifier: Intended Audience :: Science/Research
@@ -10,9 +10,7 @@ Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.11
 Description-Content-Type: text/markdown
-Requires-Dist: dirac
 Requires-Dist: diracx-core
-Requires-Dist: fastapi
 Requires-Dist: opensearch-py[async]
 Requires-Dist: pydantic>=2.10
 Requires-Dist: sqlalchemy[aiomysql,aiosqlite]>=2
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/pyproject.toml

@@ -13,9 +13,7 @@ classifiers = [
     "Topic :: System :: Distributed Computing",
 ]
 dependencies = [
-    "dirac",
     "diracx-core",
-    "fastapi",
     "opensearch-py[async]",
     "pydantic >=2.10",
     "sqlalchemy[aiomysql,aiosqlite] >= 2",
@@ -27,7 +25,7 @@ testing = [
     "diracx-testing",
 ]
 
-[project.entry-points."diracx.db.sql"]
+[project.entry-points."diracx.dbs.sql"]
 AuthDB = "diracx.db.sql:AuthDB"
 JobDB = "diracx.db.sql:JobDB"
 JobLoggingDB = "diracx.db.sql:JobLoggingDB"
@@ -35,7 +33,7 @@ PilotAgentsDB = "diracx.db.sql:PilotAgentsDB"
 SandboxMetadataDB = "diracx.db.sql:SandboxMetadataDB"
 TaskQueueDB = "diracx.db.sql:TaskQueueDB"
 
-[project.entry-points."diracx.db.os"]
+[project.entry-points."diracx.dbs.os"]
 JobParametersDB = "diracx.db.os:JobParametersDB"
 
 [tool.setuptools.packages.find]
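The entry-point groups exposing the database classes were renamed from `diracx.db.*` to `diracx.dbs.*`. A minimal sketch of resolving a class through the new group, using only `importlib.metadata` (`load_db_class` is a hypothetical helper, not diracx API; diracx itself goes through its `select_from_extension` wrapper):

```python
# Hedged sketch: resolve a DB class via the renamed entry-point group.
from importlib.metadata import entry_points

def load_db_class(db_name: str):
    # "diracx.dbs.sql" and the entry-point names come from pyproject.toml above.
    matches = entry_points(group="diracx.dbs.sql", name=db_name)
    if not matches:
        raise NotImplementedError(f"Could not find any matches for {db_name=}")
    return next(iter(matches)).load()

JobDB = load_db_class("JobDB")  # -> diracx.db.sql:JobDB
```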
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/os/utils.py

@@ -38,7 +38,7 @@ class BaseOSDB(metaclass=ABCMeta):
 
     The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
     This method returns a dictionary of database names to connection parameters.
-    The available databases are determined by the `diracx.db.os` entrypoint in
+    The available databases are determined by the `diracx.dbs.os` entrypoint in
     the `pyproject.toml` file and the connection parameters are taken from the
     environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
 
@@ -92,7 +92,9 @@ class BaseOSDB(metaclass=ABCMeta):
         """Return the available implementations of the DB in reverse priority order."""
         db_classes: list[type[BaseOSDB]] = [
             entry_point.load()
-            for entry_point in select_from_extension(group="diracx.db.os", name=db_name)
+            for entry_point in select_from_extension(
+                group="diracx.dbs.os", name=db_name
+            )
         ]
         if not db_classes:
             raise NotImplementedError(f"Could not find any matches for {db_name=}")
@@ -106,7 +108,7 @@ class BaseOSDB(metaclass=ABCMeta):
         prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
         """
         conn_kwargs: dict[str, dict[str, Any]] = {}
-        for entry_point in select_from_extension(group="diracx.db.os"):
+        for entry_point in select_from_extension(group="diracx.dbs.os"):
             db_name = entry_point.name
             var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
             if var_name in os.environ:
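Connection parameters for each OpenSearch database are still read from `DIRACX_OS_DB_{DB_NAME}` environment variables. A hedged sketch of wiring one up for `JobParametersDB` — the JSON payload format and keys are assumptions; the values are passed through to the opensearch-py client:

```python
# Hypothetical configuration example, not taken from the diff itself.
import json
import os

os.environ["DIRACX_OS_DB_JOBPARAMETERSDB"] = json.dumps(
    {
        "hosts": "localhost:9200",  # OpenSearch endpoint
        "use_ssl": True,
        "verify_certs": False,
    }
)
```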
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/auth/db.py

@@ -1,19 +1,16 @@
 from __future__ import annotations
 
-import hashlib
 import secrets
-from datetime import datetime
-from uuid import uuid4
+from uuid import UUID, uuid4
 
 from sqlalchemy import insert, select, update
 from sqlalchemy.exc import IntegrityError, NoResultFound
 
 from diracx.core.exceptions import (
     AuthorizationError,
-    ExpiredFlowError,
-    PendingAuthorizationError,
+    TokenNotFoundError,
 )
-from diracx.db.sql.utils import BaseSQLDB, substract_date
+from diracx.db.sql.utils import BaseSQLDB, hash, substract_date
 
 from .schema import (
     AuthorizationFlows,
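Throughout this file the inline `hashlib.sha256(...).hexdigest()` calls are replaced by a shared `hash` helper imported from `diracx.db.sql.utils` (added in `utils/functions.py`, per the summary above). A minimal sketch of the equivalent behaviour, assuming the helper keeps the same SHA-256 hex-digest scheme the old inline code used:

```python
# Sketch of what the imported hash() helper plausibly does; the real
# implementation lives in diracx.db.sql.utils.functions.
import hashlib

def hash(code: str) -> str:  # intentionally shadows the builtin, as diracx does
    # Store only a digest so a DB dump cannot leak usable codes/tokens.
    return hashlib.sha256(code.encode()).hexdigest()
```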
@@ -50,44 +47,25 @@ class AuthDB(BaseSQLDB):
 
         return (await self.conn.execute(stmt)).scalar_one()
 
-    async def get_device_flow(self, device_code: str, max_validity: int):
+    async def get_device_flow(self, device_code: str):
         """:raises: NoResultFound"""
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
-        stmt = select(
-            DeviceFlows,
-            (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
-                "IsExpired"
-            ),
-        ).with_for_update()
+        stmt = select(DeviceFlows).with_for_update()
         stmt = stmt.where(
-            DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest()
+            DeviceFlows.device_code == hash(device_code),
         )
-        res = dict((await self.conn.execute(stmt)).one()._mapping)
-
-        if res["IsExpired"]:
-            raise ExpiredFlowError()
-
-        if res["Status"] == FlowStatus.READY:
-            # Update the status to Done before returning
-            await self.conn.execute(
-                update(DeviceFlows)
-                .where(
-                    DeviceFlows.device_code
-                    == hashlib.sha256(device_code.encode()).hexdigest()
-                )
-                .values(status=FlowStatus.DONE)
-            )
-            return res
-
-        if res["Status"] == FlowStatus.DONE:
-            raise AuthorizationError("Code was already used")
-
-        if res["Status"] == FlowStatus.PENDING:
-            raise PendingAuthorizationError()
+        return dict((await self.conn.execute(stmt)).one()._mapping)
 
-
+    async def update_device_flow_status(
+        self, device_code: str, status: FlowStatus
+    ) -> None:
+        stmt = update(DeviceFlows).where(
+            DeviceFlows.device_code == hash(device_code),
+        )
+        stmt = stmt.values(status=status)
+        await self.conn.execute(stmt)
 
     async def device_flow_insert_id_token(
         self, user_code: str, id_token: dict[str, str], max_validity: int
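`get_device_flow` no longer interprets the flow state itself — the expiry check and the `READY`/`DONE`/`PENDING` branching (with `ExpiredFlowError`/`PendingAuthorizationError`) are gone from the DB layer. It just locks and returns the row, and `update_device_flow_status` applies transitions. A hedged sketch of how a caller is now expected to drive the flow; the orchestration really lives in the service layer, and the `FlowStatus` import location is an assumption:

```python
# Sketch only: not diracx service code.
from diracx.core.exceptions import AuthorizationError
from diracx.db.sql.auth.schema import FlowStatus  # assumed location

async def consume_device_flow(auth_db, device_code: str) -> dict:
    flow = await auth_db.get_device_flow(device_code)
    if flow["Status"] == FlowStatus.READY:
        # Mark the flow as consumed before handing out the ID token.
        await auth_db.update_device_flow_status(device_code, FlowStatus.DONE)
        return flow
    if flow["Status"] == FlowStatus.DONE:
        raise AuthorizationError("Code was already used")
    raise AuthorizationError(f"Flow is still {flow['Status']}")
```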
@@ -121,7 +99,7 @@ class AuthDB(BaseSQLDB):
         device_code = secrets.token_urlsafe()
 
         # Hash the the device_code to avoid leaking information
-        hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()
+        hashed_device_code = hash(device_code)
 
         stmt = insert(DeviceFlows).values(
             client_id=client_id,
@@ -171,7 +149,7 @@ class AuthDB(BaseSQLDB):
         """
         # Hash the code to avoid leaking information
         code = secrets.token_urlsafe()
-        hashed_code = hashlib.sha256(code.encode()).hexdigest()
+        hashed_code = hash(code)
 
         stmt = update(AuthorizationFlows)
 
@@ -193,7 +171,8 @@ class AuthDB(BaseSQLDB):
         return code, row.RedirectURI
 
     async def get_authorization_flow(self, code: str, max_validity: int):
-        hashed_code = hashlib.sha256(code.encode()).hexdigest()
+        """Get the authorization flow details based on the code."""
+        hashed_code = hash(code)
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -203,54 +182,41 @@ class AuthDB(BaseSQLDB):
             AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
         )
 
-        res = dict((await self.conn.execute(stmt)).one()._mapping)
-
-        if res["Status"] == FlowStatus.READY:
-            # Update the status to Done before returning
-            await self.conn.execute(
-                update(AuthorizationFlows)
-                .where(AuthorizationFlows.code == hashed_code)
-                .values(status=FlowStatus.DONE)
-            )
-
-            return res
+        return dict((await self.conn.execute(stmt)).one()._mapping)
 
-        if res["Status"] == FlowStatus.DONE:
-            raise AuthorizationError("Code was already used")
-
-
+    async def update_authorization_flow_status(
+        self, code: str, status: FlowStatus
+    ) -> None:
+        """Update the status of an authorization flow based on the code."""
+        hashed_code = hash(code)
+        await self.conn.execute(
+            update(AuthorizationFlows)
+            .where(AuthorizationFlows.code == hashed_code)
+            .values(status=status)
+        )
 
     async def insert_refresh_token(
         self,
+        jti: UUID,
         subject: str,
         preferred_username: str,
         scope: str,
-    ) -> tuple[str, datetime]:
+    ) -> None:
         """Insert a refresh token in the DB as well as user attributes
         required to generate access tokens.
         """
-        # Generate a JWT ID
-        jti = str(uuid4())
-
         # Insert values into the DB
         stmt = insert(RefreshTokens).values(
-            jti=jti,
+            jti=str(jti),
             sub=subject,
             preferred_username=preferred_username,
             scope=scope,
         )
         await self.conn.execute(stmt)
 
-        # Get the creation time of the new token
-        stmt = select(RefreshTokens.creation_time)
-        stmt = stmt.where(RefreshTokens.jti == jti)
-        row = (await self.conn.execute(stmt)).one()
-
-        # Return the JWT ID and the creation time
-        return jti, row.CreationTime
-
-    async def get_refresh_token(self, jti: str) -> dict:
+    async def get_refresh_token(self, jti: UUID) -> dict:
         """Get refresh token details bound to a given JWT ID."""
+        jti = str(jti)
         # The with_for_update
         # prevents that the token is retrieved
         # multiple time concurrently
@@ -260,8 +226,8 @@ class AuthDB(BaseSQLDB):
         )
         try:
             res = dict((await self.conn.execute(stmt)).one()._mapping)
-        except NoResultFound:
-            raise AuthorizationError("Refresh token not found")
+        except NoResultFound as e:
+            raise TokenNotFoundError(jti) from e
 
         return res
 
@@ -285,11 +251,11 @@ class AuthDB(BaseSQLDB):
 
         return refresh_tokens
 
-    async def revoke_refresh_token(self, jti: str):
+    async def revoke_refresh_token(self, jti: UUID):
         """Revoke a token given by its JWT ID."""
         await self.conn.execute(
             update(RefreshTokens)
-            .where(RefreshTokens.jti == jti)
+            .where(RefreshTokens.jti == str(jti))
             .values(status=RefreshTokenStatus.REVOKED)
        )
 
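Refresh-token JWT IDs are now typed `UUID` and supplied by the caller, instead of being generated (and returned together with the creation time) by `insert_refresh_token`. A hedged usage sketch; all surrounding names are illustrative:

```python
# Sketch only: illustrates the new caller-supplied UUID jti.
from uuid import UUID, uuid4

async def issue_and_revoke(auth_db) -> UUID:
    jti = uuid4()  # the caller now mints the JWT ID
    await auth_db.insert_refresh_token(
        jti=jti,
        subject="user-sub",
        preferred_username="user",
        scope="vo:diracx",
    )
    await auth_db.get_refresh_token(jti)  # raises TokenNotFoundError if absent
    await auth_db.revoke_refresh_token(jti)
    return jti
```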
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job/db.py

@@ -4,14 +4,12 @@ from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any
 
 from sqlalchemy import bindparam, case, delete, func, insert, select, update
-from sqlalchemy.exc import IntegrityError, NoResultFound
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import BindParameter
 
-from diracx.core.exceptions import InvalidQueryError, JobNotFoundError
+from diracx.core.exceptions import InvalidQueryError
 from diracx.core.models import (
-    LimitedJobStatusReturn,
     SearchSpec,
     SortSpec,
 )
@@ -46,6 +44,7 @@ class JobDB(BaseSQLDB):
     jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]
 
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
+        """Get a summary of the jobs."""
         columns = _get_columns(Jobs.__table__, group_by)
 
         stmt = select(*columns, func.count(Jobs.job_id).label("count"))
@@ -69,6 +68,7 @@ class JobDB(BaseSQLDB):
         per_page: int = 100,
         page: int | None = None,
     ) -> tuple[int, list[dict[Any, Any]]]:
+        """Search for jobs in the database."""
         # Find which columns to select
         columns = _get_columns(Jobs.__table__, parameters)
 
@@ -98,7 +98,24 @@ class JobDB(BaseSQLDB):
             dict(row._mapping) async for row in (await self.conn.stream(stmt))
         ]
 
+    async def create_job(self, compressed_original_jdl: str):
+        """Used to insert a new job with original JDL. Returns inserted job id."""
+        result = await self.conn.execute(
+            JobJDLs.__table__.insert().values(
+                JDL="",
+                JobRequirements="",
+                OriginalJDL=compressed_original_jdl,
+            )
+        )
+        return result.lastrowid
+
+    async def delete_jobs(self, job_ids: list[int]):
+        """Delete jobs from the database."""
+        stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids))
+        await self.conn.execute(stmt)
+
     async def insert_input_data(self, lfns: dict[int, list[str]]):
+        """Insert input data for jobs."""
         await self.conn.execute(
             InputData.__table__.insert(),
             [
@@ -111,27 +128,8 @@ class JobDB(BaseSQLDB):
             ],
         )
 
-    async def set_job_attributes(self, job_id, job_data):
-        """TODO: add myDate and force parameters."""
-        if "Status" in job_data:
-            job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
-        stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data)
-        await self.conn.execute(stmt)
-
-    async def create_job(self, original_jdl):
-        """Used to insert a new job with original JDL. Returns inserted job id."""
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
-
-        result = await self.conn.execute(
-            JobJDLs.__table__.insert().values(
-                JDL="",
-                JobRequirements="",
-                OriginalJDL=compressJDL(original_jdl),
-            )
-        )
-        return result.lastrowid
-
     async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
+        """Insert the job attributes."""
         await self.conn.execute(
             Jobs.__table__.insert(),
             [
@@ -145,8 +143,6 @@ class JobDB(BaseSQLDB):
 
     async def update_job_jdls(self, jdls_to_update: dict[int, str]):
         """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
-
         await self.conn.execute(
             JobJDLs.__table__.update().where(
                 JobJDLs.__table__.c.JobID == bindparam("b_JobID")
@@ -154,67 +150,15 @@ class JobDB(BaseSQLDB):
             [
                 {
                     "b_JobID": job_id,
-                    "JDL": compressJDL(jdl),
+                    "JDL": compressed_jdl,
                 }
-                for job_id, jdl in jdls_to_update.items()
+                for job_id, compressed_jdl in jdls_to_update.items()
             ],
         )
 
-    async def check_and_prepare_job(
-        self,
-        job_id,
-        class_ad_job,
-        class_ad_req,
-        owner,
-        owner_group,
-        job_attrs,
-        vo,
-    ):
-        """Check Consistency of Submitted JDL and set some defaults
-        Prepare subJDL with Job Requirements.
-        """
-        from DIRAC.Core.Utilities.DErrno import EWMSSUBM, cmpError
-        from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import (
-            checkAndPrepareJob,
-        )
-
-        ret_val = checkAndPrepareJob(
-            job_id,
-            class_ad_job,
-            class_ad_req,
-            owner,
-            owner_group,
-            job_attrs,
-            vo,
-        )
-
-        if not ret_val["OK"]:
-            if cmpError(ret_val, EWMSSUBM):
-                await self.set_job_attributes(job_id, job_attrs)
-
-            returnValueOrRaise(ret_val)
-
-    async def set_job_jdl(self, job_id, jdl):
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
-
-        stmt = (
-            update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl))
-        )
-        await self.conn.execute(stmt)
-
-    async def set_job_jdl_bulk(self, jdls):
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
-
-        await self.conn.execute(
-            JobJDLs.__table__.update().where(
-                JobJDLs.__table__.c.JobID == bindparam("b_JobID")
-            ),
-            [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()],
-        )
-
-    async def set_job_attributes_bulk(self, job_data):
-        """TODO: add myDate and force parameters."""
+    async def set_job_attributes(self, job_data):
+        """Update the parameters of the given jobs."""
+        # TODO: add myDate and force parameters.
         for job_id in job_data.keys():
             if "Status" in job_data[job_id]:
                 job_data[job_id].update(
@@ -240,11 +184,8 @@ class JobDB(BaseSQLDB):
         )
         await self.conn.execute(stmt)
 
-    async def get_job_jdls(
-        self, job_ids, original: bool = False
-    ) -> dict[int | str, str]:
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
-
+    async def get_job_jdls(self, job_ids, original: bool = False) -> dict[int, str]:
+        """Get the JDLs for the given jobs."""
         if original:
             stmt = select(JobJDLs.job_id, JobJDLs.original_jdl).where(
                 JobJDLs.job_id.in_(job_ids)
@@ -254,37 +195,9 @@ class JobDB(BaseSQLDB):
                 JobJDLs.job_id.in_(job_ids)
             )
 
-        return {
-            jobid: extractJDL(jdl)
-            for jobid, jdl in (await self.conn.execute(stmt))
-            if jdl
-        }
-
-    async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
-        try:
-            stmt = select(
-                Jobs.status, Jobs.minor_status, Jobs.application_status
-            ).where(Jobs.job_id == job_id)
-            return LimitedJobStatusReturn(
-                **dict((await self.conn.execute(stmt)).one()._mapping)
-            )
-        except NoResultFound as e:
-            raise JobNotFoundError(job_id) from e
-
-    async def set_job_command(self, job_id: int, command: str, arguments: str = ""):
-        """Store a command to be passed to the job together with the next heart beat."""
-        try:
-            stmt = insert(JobCommands).values(
-                JobID=job_id,
-                Command=command,
-                Arguments=arguments,
-                ReceptionTime=datetime.now(tz=timezone.utc),
-            )
-            await self.conn.execute(stmt)
-        except IntegrityError as e:
-            raise JobNotFoundError(job_id) from e
+        return {jobid: jdl for jobid, jdl in (await self.conn.execute(stmt)) if jdl}
 
-    async def set_job_command_bulk(self, commands):
+    async def set_job_commands(self, commands: list[tuple[int, str, str]]):
         """Store a command to be passed to the job together with the next heart beat."""
         await self.conn.execute(
             insert(JobCommands),
@@ -298,12 +211,6 @@ class JobDB(BaseSQLDB):
                 for job_id, command, arguments in commands
             ],
         )
-        # FIXME handle IntegrityError
-
-    async def delete_jobs(self, job_ids: list[int]):
-        """Delete jobs from the database."""
-        stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids))
-        await self.conn.execute(stmt)
 
     async def set_properties(
         self, properties: dict[int, dict[str, Any]], update_timestamp: bool = False
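`create_job` and `update_job_jdls` now take already-compressed JDLs, so the DIRAC helpers (`compressJDL`/`extractJDL`) and the submission-sanity logic (`checkAndPrepareJob`) move out of the DB layer to the caller. A hedged sketch of the new calling pattern; `compress_jdl` is a stand-in for whatever the caller uses, and the zlib+base64 framing is an assumption, not the diff's definition:

```python
# Illustrative only: compress_jdl is not part of diracx-db.
import base64
import zlib

def compress_jdl(jdl: str) -> str:
    # Assumed framing, in the spirit of DIRAC's compressJDL.
    return base64.b64encode(zlib.compress(jdl.encode(), 9)).decode()

async def submit(job_db, original_jdl: str, final_jdl: str) -> int:
    job_id = await job_db.create_job(compress_jdl(original_jdl))
    await job_db.update_job_jdls({job_id: compress_jdl(final_jdl)})
    return job_id
```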
{diracx_db-0.0.1a26 → diracx_db-0.0.1a29}/src/diracx/db/sql/job_logging/db.py

@@ -1,20 +1,18 @@
 from __future__ import annotations
 
 import time
-from datetime import datetime, timezone
+from datetime import timezone
 from typing import TYPE_CHECKING
 
-from pydantic import BaseModel
-from sqlalchemy import delete, func, insert, select
+from sqlalchemy import delete, func, select
 
 if TYPE_CHECKING:
     pass
 
 from collections import defaultdict
 
-from diracx.core.exceptions import JobNotFoundError
 from diracx.core.models import (
-    JobStatus,
+    JobLoggingRecord,
     JobStatusReturn,
 )
 
@@ -27,61 +25,12 @@ from .schema import (
 MAGIC_EPOC_NUMBER = 1270000000
 
 
-class JobLoggingRecord(BaseModel):
-    job_id: int
-    status: JobStatus
-    minor_status: str
-    application_status: str
-    date: datetime
-    source: str
-
-
 class JobLoggingDB(BaseSQLDB):
     """Frontend for the JobLoggingDB. Provides the ability to store changes with timestamps."""
 
     metadata = JobLoggingDBBase.metadata
 
-    async def insert_record(
-        self,
-        job_id: int,
-        status: JobStatus,
-        minor_status: str,
-        application_status: str,
-        date: datetime,
-        source: str,
-    ):
-        """Add a new entry to the JobLoggingDB table. One, two or all the three status
-        components (status, minorStatus, applicationStatus) can be specified.
-        Optionally the time stamp of the status can
-        be provided in a form of a string in a format '%Y-%m-%d %H:%M:%S' or
-        as datetime.datetime object. If the time stamp is not provided the current
-        UTC time is used.
-        """
-        # First, fetch the maximum seq_num for the given job_id
-        seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where(
-            LoggingInfo.job_id == job_id
-        )
-        seqnum = await self.conn.scalar(seqnum_stmt)
-
-        epoc = (
-            time.mktime(date.timetuple())
-            + date.microsecond / 1000000.0
-            - MAGIC_EPOC_NUMBER
-        )
-
-        stmt = insert(LoggingInfo).values(
-            job_id=int(job_id),
-            seq_num=seqnum,
-            status=status,
-            minor_status=minor_status,
-            application_status=application_status[:255],
-            status_time=date,
-            status_time_order=epoc,
-            source=source[:32],
-        )
-        await self.conn.execute(stmt)
-
-    async def bulk_insert_record(
+    async def insert_records(
         self,
         records: list[JobLoggingRecord],
     ):
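The single-record `insert_record` is gone; only the bulk path survives (renamed `insert_records`), and `JobLoggingRecord` now lives in `diracx.core.models`. The removed code also documents the `status_time_order` encoding the table keeps using: seconds since a magic epoch. A small sketch of that round-trip; `get_epoc` in the new code presumably computes the same value the removed body computed inline:

```python
# Sketch of the status_time_order encoding used by the LoggingInfo table.
import time
from datetime import datetime

MAGIC_EPOC_NUMBER = 1270000000

def get_epoc(date: datetime) -> float:
    # Seconds (including microseconds) since the magic epoch, mirroring
    # the removed insert_record body.
    return (
        time.mktime(date.timetuple())
        + date.microsecond / 1_000_000
        - MAGIC_EPOC_NUMBER
    )

# Wall-clock seconds are recovered by adding the constant back, as
# get_wms_time_stamps does below with `etime + MAGIC_EPOC_NUMBER`.
```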
@@ -103,15 +52,20 @@ class JobLoggingDB(BaseSQLDB):
             .group_by(LoggingInfo.job_id)
         )
 
-        seqnums = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
+        seqnums = {
+            jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))
+        }
         # IF a seqnum is not found, then assume it does not exist and the first sequence number is 1.
         # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements
-        await self.conn.execute(
-            LoggingInfo.__table__.insert(),
-            [
+        values = []
+        for record in records:
+            if record.job_id not in seqnums:
+                seqnums[record.job_id] = 1
+
+            values.append(
                 {
                     "JobID": record.job_id,
-                    "SeqNum": seqnums.get(record.job_id, 1),
+                    "SeqNum": seqnums[record.job_id],
                     "Status": record.status,
                     "MinorStatus": record.minor_status,
                     "ApplicationStatus": record.application_status[:255],
@@ -119,8 +73,12 @@ class JobLoggingDB(BaseSQLDB):
                     "StatusTimeOrder": get_epoc(record.date),
                     "StatusSource": record.source[:32],
                 }
-                for record in records
-            ]
+            )
+            seqnums[record.job_id] = seqnums[record.job_id] + 1
+
+        await self.conn.execute(
+            LoggingInfo.__table__.insert(),
+            values,
         )
 
     async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]:
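A hedged usage sketch of the consolidated bulk API, with `JobLoggingRecord` imported from its new home; the field values and the assumption that `JobStatus` sits alongside it in `diracx.core.models` are illustrative:

```python
# Sketch only, not taken from the diff.
from datetime import datetime, timezone

from diracx.core.models import JobLoggingRecord, JobStatus

async def log_received(job_logging_db, job_id: int) -> None:
    await job_logging_db.insert_records(
        [
            JobLoggingRecord(
                job_id=job_id,
                status=JobStatus.RECEIVED,
                minor_status="Job accepted",
                application_status="Unknown",
                date=datetime.now(timezone.utc),
                source="JobManager",
            )
        ]
    )
```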
@@ -201,25 +159,7 @@ class JobLoggingDB(BaseSQLDB):
         stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids))
         await self.conn.execute(stmt)
 
-    async def get_wms_time_stamps(self, job_id):
-        """Get TimeStamps for job MajorState transitions
-        return a {State:timestamp} dictionary.
-        """
-        result = {}
-        stmt = select(
-            LoggingInfo.status,
-            LoggingInfo.status_time_order,
-        ).where(LoggingInfo.job_id == job_id)
-        rows = await self.conn.execute(stmt)
-        if not rows.rowcount:
-            raise JobNotFoundError(job_id) from None
-
-        for event, etime in rows:
-            result[event] = str(etime + MAGIC_EPOC_NUMBER)
-
-        return result
-
-    async def get_wms_time_stamps_bulk(self, job_ids):
+    async def get_wms_time_stamps(self, job_ids):
         """Get TimeStamps for job MajorState transitions for multiple jobs at once
         return a {JobID: {State:timestamp}} dictionary.
         """