diracx-db 0.0.1a23__py3-none-any.whl → 0.0.1a25__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diracx/db/exceptions.py +4 -1
- diracx/db/os/utils.py +3 -3
- diracx/db/sql/auth/db.py +9 -9
- diracx/db/sql/auth/schema.py +25 -23
- diracx/db/sql/dummy/db.py +2 -2
- diracx/db/sql/dummy/schema.py +8 -6
- diracx/db/sql/job/db.py +57 -54
- diracx/db/sql/job/schema.py +56 -54
- diracx/db/sql/job_logging/db.py +32 -32
- diracx/db/sql/job_logging/schema.py +12 -8
- diracx/db/sql/pilot_agents/schema.py +24 -22
- diracx/db/sql/sandbox_metadata/db.py +42 -40
- diracx/db/sql/sandbox_metadata/schema.py +5 -3
- diracx/db/sql/task_queue/db.py +3 -3
- diracx/db/sql/task_queue/schema.py +2 -0
- diracx/db/sql/utils/__init__.py +22 -451
- diracx/db/sql/utils/base.py +328 -0
- diracx/db/sql/utils/functions.py +105 -0
- diracx/db/sql/utils/job.py +58 -55
- diracx/db/sql/utils/types.py +43 -0
- {diracx_db-0.0.1a23.dist-info → diracx_db-0.0.1a25.dist-info}/METADATA +2 -2
- diracx_db-0.0.1a25.dist-info/RECORD +39 -0
- {diracx_db-0.0.1a23.dist-info → diracx_db-0.0.1a25.dist-info}/WHEEL +1 -1
- diracx_db-0.0.1a23.dist-info/RECORD +0 -36
- {diracx_db-0.0.1a23.dist-info → diracx_db-0.0.1a25.dist-info}/entry_points.txt +0 -0
- {diracx_db-0.0.1a23.dist-info → diracx_db-0.0.1a25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,328 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import contextlib
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
import re
|
7
|
+
from abc import ABCMeta
|
8
|
+
from collections.abc import AsyncIterator
|
9
|
+
from contextvars import ContextVar
|
10
|
+
from datetime import datetime
|
11
|
+
from typing import Self, cast
|
12
|
+
|
13
|
+
from pydantic import TypeAdapter
|
14
|
+
from sqlalchemy import DateTime, MetaData, select
|
15
|
+
from sqlalchemy.exc import OperationalError
|
16
|
+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
17
|
+
|
18
|
+
from diracx.core.exceptions import InvalidQueryError
|
19
|
+
from diracx.core.extensions import select_from_extension
|
20
|
+
from diracx.core.models import SortDirection
|
21
|
+
from diracx.core.settings import SqlalchemyDsn
|
22
|
+
from diracx.db.exceptions import DBUnavailableError
|
23
|
+
|
24
|
+
from .functions import date_trunc
|
25
|
+
|
26
|
+
# Module-level logger; ``BaseSQLDB.available_urls`` uses it to report which
# database URL failed validation before re-raising.
logger = logging.getLogger(__name__)
|
27
|
+
|
28
|
+
|
29
|
+
class SQLDBError(Exception):
    """Root of the exception hierarchy for the SQL database helpers."""
|
31
|
+
|
32
|
+
|
33
|
+
class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
    """Used whenever we encounter a problem with the DB connection."""
|
35
|
+
|
36
|
+
|
37
|
+
class BaseSQLDB(metaclass=ABCMeta):
    """This should be the base class of all the SQL DiracX DBs.

    The details covered here should be handled automatically by the service and
    task machinery of DiracX and this documentation exists for informational
    purposes.

    The available databases are discovered by calling `BaseSQLDB.available_urls`.
    This method returns a mapping of database names to connection URLs. The
    available databases are determined by the `diracx.db.sql` entrypoint in the
    `pyproject.toml` file and the connection URLs are taken from the environment
    variables of the form `DIRACX_DB_URL_<DB-NAME>` (the entry point name is
    upper-cased).

    If extensions to DiracX are being used, there can be multiple implementations
    of the same database. To list the available implementations use
    `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
    will be the preferred implementation and it can be initialized by calling
    its `__init__` function with a URL previously obtained from
    `BaseSQLDB.available_urls`.

    To control the lifetime of the SQLAlchemy engine used for connecting to the
    database, which includes the connection pool, the `BaseSQLDB.engine_context`
    asynchronous context manager should be entered. When inside this context
    manager, the engine can be accessed with `BaseSQLDB.engine`.

    Upon entering, the DB class can then be used as an asynchronous context
    manager to enter transactions. If an exception is raised the transaction is
    rolled back automatically. If the inner context exits peacefully, the
    transaction is committed automatically. When inside this context manager,
    the DB connection can be accessed with `BaseSQLDB.conn`.

    For example:

    ```python
    db_name = ...
    url = BaseSQLDB.available_urls()[db_name]
    MyDBClass = BaseSQLDB.available_implementations(db_name)[0]

    db = MyDBClass(url)
    async with db.engine_context():
        async with db:
            # Do something in the first transaction
            # Commit will be called automatically

        async with db:
            # This transaction will be rolled back due to the exception
            raise Exception(...)
    ```
    """

    # engine: AsyncEngine
    # TODO: Make metadata an abstract property
    metadata: MetaData

    def __init__(self, db_url: str) -> None:
        # We use a ContextVar to make sure that self._conn
        # is specific to each context, and avoid parallel
        # route executions to overlap
        self._conn: ContextVar[AsyncConnection | None] = ContextVar(
            "_conn", default=None
        )
        self._db_url = db_url
        # Set only while ``engine_context`` is active.
        self._engine: AsyncEngine | None = None

    @classmethod
    def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
        """Return the classes registered for *db_name* in ``diracx.db.sql``.

        The list keeps the order returned by ``select_from_extension``; callers
        (see the class docstring) use the first entry as the preferred
        implementation. NOTE(review): confirm the priority ordering against
        ``select_from_extension``.

        Raises:
            NotImplementedError: if no entry point matches *db_name*.
        """
        db_classes: list[type[BaseSQLDB]] = [
            entry_point.load()
            for entry_point in select_from_extension(
                group="diracx.db.sql", name=db_name
            )
        ]
        if not db_classes:
            raise NotImplementedError(f"Could not find any matches for {db_name=}")
        return db_classes

    @classmethod
    def available_urls(cls) -> dict[str, str]:
        """Return a dict of available database urls.

        The list of available URLs is determined by environment variables
        prefixed with ``DIRACX_DB_URL_{DB_NAME}``. Databases whose variable is
        not set are omitted. An invalid URL is logged and the validation error
        is re-raised.
        """
        db_urls: dict[str, str] = {}
        for entry_point in select_from_extension(group="diracx.db.sql"):
            db_name = entry_point.name
            var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
            if var_name in os.environ:
                try:
                    db_url = os.environ[var_name]
                    if db_url == "sqlite+aiosqlite:///:memory:":
                        db_urls[db_name] = db_url
                    # pydantic does not allow for underscore in scheme
                    # so we do a special case
                    elif "_" in db_url.split(":")[0]:
                        # Validate the URL with a fake schema, and then store
                        # the original one
                        scheme_id = db_url.find(":")
                        fake_url = (
                            db_url[:scheme_id].replace("_", "-") + db_url[scheme_id:]
                        )
                        TypeAdapter(SqlalchemyDsn).validate_python(fake_url)
                        db_urls[db_name] = db_url

                    else:
                        db_urls[db_name] = str(
                            TypeAdapter(SqlalchemyDsn).validate_python(db_url)
                        )
                except Exception:
                    logger.error("Error loading URL for %s", db_name)
                    raise
        return db_urls

    @classmethod
    def transaction(cls) -> Self:
        raise NotImplementedError("This should never be called")

    @property
    def engine(self) -> AsyncEngine:
        """The engine to use for database operations.

        It is normally not necessary to use the engine directly, unless you are
        doing something special, like writing a test fixture that gives you a db.

        Requires that the engine_context has been entered.
        """
        assert self._engine is not None, "engine_context must be entered"
        return self._engine

    @contextlib.asynccontextmanager
    async def engine_context(self) -> AsyncIterator[None]:
        """Context manager to manage the engine lifecycle.

        This is called once at the application startup (see ``lifetime_functions``).
        The engine is disposed (closing its connection pool) on exit.
        """
        assert self._engine is None, "engine_context cannot be nested"

        # Set the pool_recycle to 30mn
        # That should prevent the problem of MySQL expiring connection
        # after 60mn by default
        engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
        self._engine = engine
        try:
            yield
        finally:
            self._engine = None
            await engine.dispose()

    @property
    def conn(self) -> AsyncConnection:
        """The connection of the current transaction context.

        Raises:
            RuntimeError: if the DB context manager has not been entered.
        """
        if self._conn.get() is None:
            raise RuntimeError(f"{self.__class__} was used before entering")
        return cast(AsyncConnection, self._conn.get())

    async def __aenter__(self) -> Self:
        """Create a connection.

        This is called by the Dependency mechanism (see ``db_transaction``).
        It will create a new connection/transaction for each route call.

        Raises:
            SQLDBUnavailableError: if the connection cannot be opened.
        """
        assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
        try:
            self._conn.set(await self.engine.connect().__aenter__())
        except Exception as e:
            raise SQLDBUnavailableError(
                f"Cannot connect to {self.__class__.__name__}"
            ) from e

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """This is called when exiting a route.

        If there was no exception, the changes in the DB are committed.
        Otherwise, they are rolled back when the connection is closed.
        """
        conn = self.conn
        try:
            if exc_type is None:
                await conn.commit()
        finally:
            # Always release the connection and reset the context variable,
            # even if the commit failed. Otherwise the connection leaks and
            # any later ``async with db`` in this context would trip the
            # "cannot be nested" assertion in ``__aenter__``.
            try:
                await conn.__aexit__(exc_type, exc, tb)
            finally:
                self._conn.set(None)

    async def ping(self):
        """Check whether the connection to the DB is still working.

        We could enable the ``pre_ping`` in the engine, but this would be ran at
        every query.

        Raises:
            SQLDBUnavailableError: if a trivial ``SELECT 1`` fails.
        """
        try:
            await self.conn.scalar(select(1))
        except OperationalError as e:
            raise SQLDBUnavailableError("Cannot ping the DB") from e
|
229
|
+
|
230
|
+
|
231
|
+
def find_time_resolution(value):
    """Work out the time resolution of a date/time search value.

    ``datetime`` instances are returned unchanged with no resolution.
    Strings must match a partial timestamp layout: ``YYYY``, ``YYYY-MM``,
    ``YYYY-MM-DD``, ``YYYY-MM-DD[ T]HH``, ``...:MM`` or ``...:SS``, optionally
    followed by ``.ffffff`` and an optional ``Z``. The most precise component
    present gives the resolution. The value is normalised to
    ``YYYY-MM-DD HH:MM:SS`` truncated at that resolution: a ``T`` separator
    becomes a space, and fractional seconds and ``Z`` are dropped.

    Returns:
        A ``(resolution, value)`` tuple. ``resolution`` is one of ``"YEAR"``,
        ``"MONTH"``, ``"DAY"``, ``"HOUR"``, ``"MINUTE"`` or ``"SECOND"``, or
        ``None`` for ``datetime`` input.

    Raises:
        InvalidQueryError: if ``value`` is neither a ``datetime`` nor a string
            in one of the supported layouts.
    """
    if isinstance(value, datetime):
        return None, value
    # Any other non-string (e.g. an int inside a ``values`` list) would make
    # ``re.fullmatch`` raise TypeError; report it like any unparseable value.
    if not isinstance(value, str):
        raise InvalidQueryError(f"Cannot parse {value=}")

    match = re.fullmatch(
        r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value
    )
    if match is None:
        raise InvalidQueryError(f"Cannot parse {value=}")

    # The optional groups nest from coarsest (1: month) to finest (6: second).
    # The deepest group that matched determines the resolution.
    if match.group(6):
        precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
    elif match.group(5):
        precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
    elif match.group(3):
        precision, pattern = "HOUR", r"\1-\2-\3 \4"
    elif match.group(2):
        precision, pattern = "DAY", r"\1-\2-\3"
    elif match.group(1):
        precision, pattern = "MONTH", r"\1-\2"
    else:
        precision, pattern = "YEAR", r"\1"

    normalised = re.sub(
        r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$",
        pattern,
        value,
    )
    return precision, normalised
|
259
|
+
|
260
|
+
|
261
|
+
def apply_search_filters(column_mapping, stmt, search):
    """Add a WHERE condition to *stmt* for every filter in *search*.

    Args:
        column_mapping: Callable mapping a parameter name to a SQLAlchemy
            column; it is expected to raise ``KeyError`` for unknown names.
        stmt: The statement to filter.
        search: Iterable of filter mappings with keys ``parameter`` and
            ``operator``, plus ``value`` (for ``eq``, ``neq``, ``gt``, ``lt``,
            ``like`` and ``ilike``) or ``values`` (for ``in`` and ``not in``).

    For ``DateTime`` columns, string values may be partial timestamps (see
    ``find_time_resolution``). The column is then wrapped in ``date_trunc`` so
    the comparison happens at the granularity of the given value. The filter
    mappings in *search* are not modified.

    Returns:
        The statement with one ``.where()`` applied per filter.

    Raises:
        InvalidQueryError: for an unknown column or operator, an unparseable
            date value, or mixed time resolutions within ``values``.
    """
    for query in search:
        # Work on a shallow copy so the normalised date values written below
        # do not leak back into the caller's search specification.
        query = dict(query)
        try:
            column = column_mapping(query["parameter"])
        except KeyError as e:
            raise InvalidQueryError(f"Unknown column {query['parameter']}") from e

        if isinstance(column.type, DateTime):
            if "value" in query and isinstance(query["value"], str):
                resolution, value = find_time_resolution(query["value"])
                if resolution:
                    column = date_trunc(column, time_resolution=resolution)
                query["value"] = value

            if query.get("values"):
                resolutions, values = zip(
                    *map(find_time_resolution, query.get("values"))
                )
                if len(set(resolutions)) != 1:
                    raise InvalidQueryError(
                        f"Cannot mix different time resolutions in {query=}"
                    )
                if resolution := resolutions[0]:
                    column = date_trunc(column, time_resolution=resolution)
                query["values"] = values

        # Exact comparisons: the previous ``op in "like"`` form also accepted
        # any substring of "like"/"ilike" (e.g. "i", "ke", "") as LIKE.
        op = query["operator"]
        if op == "eq":
            expr = column == query["value"]
        elif op == "neq":
            expr = column != query["value"]
        elif op == "gt":
            expr = column > query["value"]
        elif op == "lt":
            expr = column < query["value"]
        elif op == "in":
            expr = column.in_(query["values"])
        elif op == "not in":
            expr = column.notin_(query["values"])
        elif op == "like":
            expr = column.like(query["value"])
        elif op == "ilike":
            expr = column.ilike(query["value"])
        else:
            raise InvalidQueryError(f"Unknown filter {query=}")
        stmt = stmt.where(expr)
    return stmt
|
307
|
+
|
308
|
+
|
309
|
+
def apply_sort_constraints(column_mapping, stmt, sorts):
|
310
|
+
sort_columns = []
|
311
|
+
for sort in sorts or []:
|
312
|
+
try:
|
313
|
+
column = column_mapping(sort["parameter"])
|
314
|
+
except KeyError as e:
|
315
|
+
raise InvalidQueryError(
|
316
|
+
f"Cannot sort by {sort['parameter']}: unknown column"
|
317
|
+
) from e
|
318
|
+
sorted_column = None
|
319
|
+
if sort["direction"] == SortDirection.ASC:
|
320
|
+
sorted_column = column.asc()
|
321
|
+
elif sort["direction"] == SortDirection.DESC:
|
322
|
+
sorted_column = column.desc()
|
323
|
+
else:
|
324
|
+
raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
|
325
|
+
sort_columns.append(sorted_column)
|
326
|
+
if sort_columns:
|
327
|
+
stmt = stmt.order_by(*sort_columns)
|
328
|
+
return stmt
|
@@ -0,0 +1,105 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from datetime import datetime, timedelta, timezone
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
from sqlalchemy import DateTime, func
|
7
|
+
from sqlalchemy.ext.compiler import compiles
|
8
|
+
from sqlalchemy.sql import expression
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from sqlalchemy.types import TypeEngine
|
12
|
+
|
13
|
+
|
14
|
+
class utcnow(expression.FunctionElement):  # noqa: N801
    """SQL expression evaluating to the current UTC timestamp.

    The per-dialect SQL is provided by the ``@compiles`` hooks for PostgreSQL,
    MSSQL, MySQL and SQLite; the expression is typed as ``DateTime``.
    """

    # No state beyond the parent class, so the parent's cache key suffices.
    inherit_cache: bool = True
    type: TypeEngine = DateTime()
|
17
|
+
|
18
|
+
|
19
|
+
@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` for PostgreSQL: current timestamp converted to UTC."""
    rendered = "TIMEZONE('utc', CURRENT_TIMESTAMP)"
    return rendered
|
22
|
+
|
23
|
+
|
24
|
+
@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` for SQL Server via its built-in ``GETUTCDATE()``."""
    rendered = "GETUTCDATE()"
    return rendered
|
27
|
+
|
28
|
+
|
29
|
+
@compiles(utcnow, "mysql")
def mysql_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` for MySQL via ``UTC_TIMESTAMP``.

    NOTE(review): the surrounding parentheses presumably let the expression be
    used as a column default - confirm before removing them.
    """
    rendered = "(UTC_TIMESTAMP)"
    return rendered
|
32
|
+
|
33
|
+
|
34
|
+
@compiles(utcnow, "sqlite")
def sqlite_utcnow(element, compiler, **kw) -> str:
    """Render ``utcnow`` for SQLite; ``DATETIME('now')`` is expressed in UTC."""
    rendered = "DATETIME('now')"
    return rendered
|
37
|
+
|
38
|
+
|
39
|
+
class date_trunc(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to truncate a date to a given resolution.

    Primarily used to be able to query for a specific resolution of a date e.g.

        select * from table where date_trunc('day', date_column) = '2021-01-01'
        select * from table where date_trunc('year', date_column) = '2021'
        select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'

    Args:
        *args: The column/expression to truncate.
        time_resolution: One of ``"SECOND"``, ``"MINUTE"``, ``"HOUR"``,
            ``"DAY"``, ``"MONTH"`` or ``"YEAR"`` (the keys understood by the
            dialect compilers).
    """

    type = DateTime()
    # The rendered SQL depends on ``_time_resolution``, which is local to this
    # class and therefore NOT part of the parent's cache key. With
    # ``inherit_cache = True``, two statements differing only in resolution
    # would share a compiled-SQL cache entry and reuse the wrong truncation.
    # Opt out of caching for this construct instead.
    inherit_cache = False

    def __init__(self, *args, time_resolution, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Read back by the dialect-specific ``@compiles`` functions.
        self._time_resolution = time_resolution
|
55
|
+
|
56
|
+
|
57
|
+
@compiles(date_trunc, "postgresql")
def pg_date_trunc(element, compiler, **kw):
    """Render ``date_trunc`` using PostgreSQL's native ``date_trunc``.

    The upper-case resolution is mapped to the lower-case field name that
    PostgreSQL expects. An unknown resolution raises ``KeyError``.
    """
    field_names = {
        resolution: resolution.lower()
        for resolution in ("SECOND", "MINUTE", "HOUR", "DAY", "MONTH", "YEAR")
    }
    field = field_names[element._time_resolution]
    source_sql = compiler.process(element.clauses)
    return f"date_trunc('{field}', {source_sql})"
|
68
|
+
|
69
|
+
|
70
|
+
@compiles(date_trunc, "mysql")
def mysql_date_trunc(element, compiler, **kw):
    """Emulate ``date_trunc`` on MySQL with ``DATE_FORMAT``.

    The column is formatted as a string that keeps only the components up to
    the requested resolution (``%i`` is MySQL's minutes specifier). An unknown
    resolution raises ``KeyError``.
    """
    formats = {
        "YEAR": "%Y",
        "MONTH": "%Y-%m",
        "DAY": "%Y-%m-%d",
        "HOUR": "%Y-%m-%d %H",
        "MINUTE": "%Y-%m-%d %H:%i",
        "SECOND": "%Y-%m-%d %H:%i:%S",
    }
    fmt = formats[element._time_resolution]

    # Exactly one argument is expected: the column to truncate.
    [column] = list(element.clauses)
    formatted = func.date_format(column, fmt)
    return compiler.process(formatted)
|
83
|
+
|
84
|
+
|
85
|
+
@compiles(date_trunc, "sqlite")
def sqlite_date_trunc(element, compiler, **kw):
    """Emulate ``date_trunc`` on SQLite with ``strftime``.

    The column is formatted as a string that keeps only the components up to
    the requested resolution. An unknown resolution raises ``KeyError``.
    """
    formats = {
        "YEAR": "%Y",
        "MONTH": "%Y-%m",
        "DAY": "%Y-%m-%d",
        "HOUR": "%Y-%m-%d %H",
        "MINUTE": "%Y-%m-%d %H:%M",
        "SECOND": "%Y-%m-%d %H:%M:%S",
    }
    fmt = formats[element._time_resolution]

    # Exactly one argument is expected: the column to truncate.
    [column] = list(element.clauses)
    # SQLite's strftime takes the format first, then the time value.
    formatted = func.strftime(fmt, column)
    return compiler.process(formatted)
|
102
|
+
|
103
|
+
|
104
|
+
def substract_date(**kwargs: float) -> datetime:
    """Return the current UTC time shifted back by the given offset.

    Keyword arguments are forwarded to ``timedelta`` (e.g. ``days=1``,
    ``hours=2``). The result is timezone-aware (UTC).
    """
    now = datetime.now(tz=timezone.utc)
    offset = timedelta(**kwargs)
    return now - offset
|
diracx/db/sql/utils/job.py
CHANGED
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import asyncio
|
2
4
|
from collections import defaultdict
|
3
5
|
from copy import deepcopy
|
@@ -49,7 +51,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
49
51
|
async with asyncio.TaskGroup() as tg:
|
50
52
|
for job in jobs:
|
51
53
|
original_jdl = deepcopy(job.jdl)
|
52
|
-
|
54
|
+
job_manifest = returnValueOrRaise(
|
53
55
|
checkAndAddOwner(original_jdl, job.owner, job.owner_group)
|
54
56
|
)
|
55
57
|
|
@@ -60,13 +62,13 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
60
62
|
original_jdls.append(
|
61
63
|
(
|
62
64
|
original_jdl,
|
63
|
-
|
65
|
+
job_manifest,
|
64
66
|
tg.create_task(job_db.create_job(original_jdl)),
|
65
67
|
)
|
66
68
|
)
|
67
69
|
|
68
70
|
async with asyncio.TaskGroup() as tg:
|
69
|
-
for job, (original_jdl,
|
71
|
+
for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls):
|
70
72
|
job_id = job_id_task.result()
|
71
73
|
job_attrs = {
|
72
74
|
"JobID": job_id,
|
@@ -77,16 +79,16 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
77
79
|
"VO": job.vo,
|
78
80
|
}
|
79
81
|
|
80
|
-
|
82
|
+
job_manifest_.setOption("JobID", job_id)
|
81
83
|
|
82
84
|
# 2.- Check JDL and Prepare DIRAC JDL
|
83
|
-
|
85
|
+
job_jdl = job_manifest_.dumpAsJDL()
|
84
86
|
|
85
87
|
# Replace the JobID placeholder if any
|
86
|
-
if
|
87
|
-
|
88
|
+
if job_jdl.find("%j") != -1:
|
89
|
+
job_jdl = job_jdl.replace("%j", str(job_id))
|
88
90
|
|
89
|
-
class_ad_job = ClassAd(
|
91
|
+
class_ad_job = ClassAd(job_jdl)
|
90
92
|
|
91
93
|
class_ad_req = ClassAd("[]")
|
92
94
|
if not class_ad_job.isOK():
|
@@ -99,7 +101,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
99
101
|
# TODO is this even needed?
|
100
102
|
class_ad_job.insertAttributeInt("JobID", job_id)
|
101
103
|
|
102
|
-
await job_db.
|
104
|
+
await job_db.check_and_prepare_job(
|
103
105
|
job_id,
|
104
106
|
class_ad_job,
|
105
107
|
class_ad_req,
|
@@ -108,10 +110,10 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
108
110
|
job_attrs,
|
109
111
|
job.vo,
|
110
112
|
)
|
111
|
-
|
113
|
+
job_jdl = createJDLWithInitialStatus(
|
112
114
|
class_ad_job,
|
113
115
|
class_ad_req,
|
114
|
-
job_db.
|
116
|
+
job_db.jdl_2_db_parameters,
|
115
117
|
job_attrs,
|
116
118
|
job.initial_status,
|
117
119
|
job.initial_minor_status,
|
@@ -119,11 +121,11 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
|
|
119
121
|
)
|
120
122
|
|
121
123
|
jobs_to_insert[job_id] = job_attrs
|
122
|
-
jdls_to_update[job_id] =
|
124
|
+
jdls_to_update[job_id] = job_jdl
|
123
125
|
|
124
126
|
if class_ad_job.lookupAttribute("InputData"):
|
125
|
-
|
126
|
-
inputdata_to_insert[job_id] = [lfn for lfn in
|
127
|
+
input_data = class_ad_job.getListFromExpression("InputData")
|
128
|
+
inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn]
|
127
129
|
|
128
130
|
tg.create_task(job_db.update_job_jdls(jdls_to_update))
|
129
131
|
tg.create_task(job_db.insert_job_attributes(jobs_to_insert))
|
@@ -243,7 +245,7 @@ async def reschedule_jobs_bulk(
|
|
243
245
|
job_jdls = {
|
244
246
|
jobid: parse_jdl(jobid, jdl)
|
245
247
|
for jobid, jdl in (
|
246
|
-
(await job_db.
|
248
|
+
(await job_db.get_job_jdls(surviving_job_ids, original=True)).items()
|
247
249
|
)
|
248
250
|
}
|
249
251
|
|
@@ -251,7 +253,7 @@ async def reschedule_jobs_bulk(
|
|
251
253
|
class_ad_job = job_jdls[job_id]
|
252
254
|
class_ad_req = ClassAd("[]")
|
253
255
|
try:
|
254
|
-
await job_db.
|
256
|
+
await job_db.check_and_prepare_job(
|
255
257
|
job_id,
|
256
258
|
class_ad_job,
|
257
259
|
class_ad_req,
|
@@ -277,11 +279,11 @@ async def reschedule_jobs_bulk(
|
|
277
279
|
else:
|
278
280
|
site = site_list[0]
|
279
281
|
|
280
|
-
|
281
|
-
class_ad_job.insertAttributeInt("JobRequirements",
|
282
|
-
|
282
|
+
req_jdl = class_ad_req.asJDL()
|
283
|
+
class_ad_job.insertAttributeInt("JobRequirements", req_jdl)
|
284
|
+
job_jdl = class_ad_job.asJDL()
|
283
285
|
# Replace the JobID placeholder if any
|
284
|
-
|
286
|
+
job_jdl = job_jdl.replace("%j", str(job_id))
|
285
287
|
|
286
288
|
additional_attrs = {
|
287
289
|
"Site": site,
|
@@ -291,7 +293,7 @@ async def reschedule_jobs_bulk(
|
|
291
293
|
}
|
292
294
|
|
293
295
|
# set new JDL
|
294
|
-
jdl_changes[job_id] =
|
296
|
+
jdl_changes[job_id] = job_jdl
|
295
297
|
|
296
298
|
# set new status
|
297
299
|
status_changes[job_id] = {
|
@@ -319,17 +321,18 @@ async def reschedule_jobs_bulk(
|
|
319
321
|
|
320
322
|
# BULK JDL UPDATE
|
321
323
|
# DATABASE OPERATION
|
322
|
-
await job_db.
|
324
|
+
await job_db.set_job_jdl_bulk(jdl_changes)
|
323
325
|
|
324
326
|
return {
|
325
327
|
"failed": failed,
|
326
328
|
"success": {
|
327
329
|
job_id: {
|
328
|
-
"InputData": job_jdls
|
330
|
+
"InputData": job_jdls.get(job_id, None),
|
329
331
|
**attribute_changes[job_id],
|
330
332
|
**set_status_result.model_dump(),
|
331
333
|
}
|
332
334
|
for job_id, set_status_result in set_job_status_result.success.items()
|
335
|
+
if job_id not in failed
|
333
336
|
},
|
334
337
|
}
|
335
338
|
|
@@ -411,40 +414,40 @@ async def set_job_status_bulk(
|
|
411
414
|
|
412
415
|
for res in results:
|
413
416
|
job_id = int(res["JobID"])
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
+
current_status = res["Status"]
|
418
|
+
start_time = res["StartExecTime"]
|
419
|
+
end_time = res["EndExecTime"]
|
417
420
|
|
418
421
|
# If the current status is Stalled and we get an update, it should probably be "Running"
|
419
|
-
if
|
420
|
-
|
422
|
+
if current_status == JobStatus.STALLED:
|
423
|
+
current_status = JobStatus.RUNNING
|
421
424
|
|
422
425
|
#####################################################################################################
|
423
|
-
|
424
|
-
# This is more precise than "LastTime".
|
425
|
-
|
426
|
-
|
426
|
+
status_dict = status_dicts[job_id]
|
427
|
+
# This is more precise than "LastTime". time_stamps is a sorted list of tuples...
|
428
|
+
time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items())
|
429
|
+
last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace(
|
427
430
|
tzinfo=timezone.utc
|
428
431
|
)
|
429
432
|
|
430
433
|
# Get chronological order of new updates
|
431
|
-
|
434
|
+
update_times = sorted(status_dict)
|
432
435
|
|
433
|
-
|
434
|
-
|
436
|
+
new_start_time, new_end_time = getStartAndEndTime(
|
437
|
+
start_time, end_time, update_times, time_stamps, status_dict
|
435
438
|
)
|
436
439
|
|
437
440
|
job_data: dict[str, str] = {}
|
438
441
|
new_status: str | None = None
|
439
|
-
if
|
442
|
+
if update_times[-1] >= last_time:
|
440
443
|
new_status, new_minor, new_application = (
|
441
444
|
returnValueOrRaise( # TODO: Catch this
|
442
445
|
getNewStatus(
|
443
446
|
job_id,
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
447
|
+
update_times,
|
448
|
+
last_time,
|
449
|
+
status_dict,
|
450
|
+
current_status,
|
448
451
|
force,
|
449
452
|
MagicMock(), # FIXME
|
450
453
|
)
|
@@ -466,15 +469,15 @@ async def set_job_status_bulk(
|
|
466
469
|
# if not result["OK"]:
|
467
470
|
# return result
|
468
471
|
|
469
|
-
for
|
470
|
-
if
|
471
|
-
job_data["HeartBeatTime"] = str(
|
472
|
+
for upd_time in update_times:
|
473
|
+
if status_dict[upd_time]["Source"].startswith("Job"):
|
474
|
+
job_data["HeartBeatTime"] = str(upd_time)
|
472
475
|
|
473
|
-
if not
|
474
|
-
job_data["StartExecTime"] =
|
476
|
+
if not start_time and new_start_time:
|
477
|
+
job_data["StartExecTime"] = new_start_time
|
475
478
|
|
476
|
-
if not
|
477
|
-
job_data["EndExecTime"] =
|
479
|
+
if not end_time and new_end_time:
|
480
|
+
job_data["EndExecTime"] = new_end_time
|
478
481
|
|
479
482
|
#####################################################################################################
|
480
483
|
# delete or kill job, if we transition to DELETED or KILLED state
|
@@ -485,20 +488,20 @@ async def set_job_status_bulk(
|
|
485
488
|
if job_data:
|
486
489
|
job_attribute_updates[job_id] = job_data
|
487
490
|
|
488
|
-
for
|
489
|
-
|
491
|
+
for upd_time in update_times:
|
492
|
+
s_dict = status_dict[upd_time]
|
490
493
|
job_logging_updates.append(
|
491
494
|
JobLoggingRecord(
|
492
495
|
job_id=job_id,
|
493
|
-
status=
|
494
|
-
minor_status=
|
495
|
-
application_status=
|
496
|
-
date=
|
497
|
-
source=
|
496
|
+
status=s_dict.get("Status", "idem"),
|
497
|
+
minor_status=s_dict.get("MinorStatus", "idem"),
|
498
|
+
application_status=s_dict.get("ApplicationStatus", "idem"),
|
499
|
+
date=upd_time,
|
500
|
+
source=s_dict.get("Source", "Unknown"),
|
498
501
|
)
|
499
502
|
)
|
500
503
|
|
501
|
-
await job_db.
|
504
|
+
await job_db.set_job_attributes_bulk(job_attribute_updates)
|
502
505
|
|
503
506
|
await remove_jobs_from_task_queue(
|
504
507
|
list(deletable_killable_jobs), config, task_queue_db, background_task
|