diracx-db 0.0.1a22__py3-none-any.whl → 0.0.1a24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/db/exceptions.py +4 -1
- diracx/db/os/utils.py +3 -3
- diracx/db/sql/auth/db.py +9 -9
- diracx/db/sql/auth/schema.py +25 -23
- diracx/db/sql/dummy/db.py +2 -2
- diracx/db/sql/dummy/schema.py +8 -6
- diracx/db/sql/job/db.py +57 -54
- diracx/db/sql/job/schema.py +56 -54
- diracx/db/sql/job_logging/db.py +50 -46
- diracx/db/sql/job_logging/schema.py +12 -8
- diracx/db/sql/pilot_agents/schema.py +24 -22
- diracx/db/sql/sandbox_metadata/db.py +42 -40
- diracx/db/sql/sandbox_metadata/schema.py +5 -3
- diracx/db/sql/task_queue/db.py +3 -3
- diracx/db/sql/task_queue/schema.py +2 -0
- diracx/db/sql/utils/__init__.py +22 -451
- diracx/db/sql/utils/base.py +328 -0
- diracx/db/sql/utils/functions.py +105 -0
- diracx/db/sql/utils/job.py +59 -55
- diracx/db/sql/utils/types.py +43 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/METADATA +2 -2
- diracx_db-0.0.1a24.dist-info/RECORD +39 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/WHEEL +1 -1
- diracx_db-0.0.1a22.dist-info/RECORD +0 -36
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/entry_points.txt +0 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/top_level.txt +0 -0
diracx/db/sql/utils/base.py ADDED

@@ -0,0 +1,328 @@

````python
from __future__ import annotations

import contextlib
import logging
import os
import re
from abc import ABCMeta
from collections.abc import AsyncIterator
from contextvars import ContextVar
from datetime import datetime
from typing import Self, cast

from pydantic import TypeAdapter
from sqlalchemy import DateTime, MetaData, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
from diracx.core.models import SortDirection
from diracx.core.settings import SqlalchemyDsn
from diracx.db.exceptions import DBUnavailableError

from .functions import date_trunc

logger = logging.getLogger(__name__)


class SQLDBError(Exception):
    pass


class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
    """Used whenever we encounter a problem with the DB connection."""


class BaseSQLDB(metaclass=ABCMeta):
    """This should be the base class of all the SQL DiracX DBs.

    The details covered here should be handled automatically by the service and
    task machinery of DiracX and this documentation exists for informational
    purposes.

    The available databases are discovered by calling `BaseSQLDB.available_urls`.
    This method returns a mapping of database names to connection URLs. The
    available databases are determined by the `diracx.db.sql` entrypoint in the
    `pyproject.toml` file and the connection URLs are taken from the environment
    variables of the form `DIRACX_DB_URL_<db-name>`.

    If extensions to DiracX are being used, there can be multiple implementations
    of the same database. To list the available implementations use
    `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
    will be the preferred implementation and it can be initialized by calling
    its `__init__` function with a URL previously obtained from
    `BaseSQLDB.available_urls`.

    To control the lifetime of the SQLAlchemy engine used for connecting to the
    database, which includes the connection pool, the `BaseSQLDB.engine_context`
    asynchronous context manager should be entered. When inside this context
    manager, the engine can be accessed with `BaseSQLDB.engine`.

    Upon entering, the DB class can then be used as an asynchronous context
    manager to enter transactions. If an exception is raised the transaction is
    rolled back automatically. If the inner context exits peacefully, the
    transaction is committed automatically. When inside this context manager,
    the DB connection can be accessed with `BaseSQLDB.conn`.

    For example:

    ```python
    db_name = ...
    url = BaseSQLDB.available_urls()[db_name]
    MyDBClass = BaseSQLDB.available_implementations(db_name)[0]

    db = MyDBClass(url)
    async with db.engine_context:
        async with db:
            # Do something in the first transaction
            # Commit will be called automatically

        async with db:
            # This transaction will be rolled back due to the exception
            raise Exception(...)
    ```
    """

    # engine: AsyncEngine
    # TODO: Make metadata an abstract property
    metadata: MetaData

    def __init__(self, db_url: str) -> None:
        # We use a ContextVar to make sure that self._conn
        # is specific to each context, and avoid parallel
        # route executions overlapping
        self._conn: ContextVar[AsyncConnection | None] = ContextVar(
            "_conn", default=None
        )
        self._db_url = db_url
        self._engine: AsyncEngine | None = None

    @classmethod
    def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
        """Return the available implementations of the DB in reverse priority order."""
        db_classes: list[type[BaseSQLDB]] = [
            entry_point.load()
            for entry_point in select_from_extension(
                group="diracx.db.sql", name=db_name
            )
        ]
        if not db_classes:
            raise NotImplementedError(f"Could not find any matches for {db_name=}")
        return db_classes

    @classmethod
    def available_urls(cls) -> dict[str, str]:
        """Return a dict of available database urls.

        The list of available URLs is determined by environment variables
        prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
        """
        db_urls: dict[str, str] = {}
        for entry_point in select_from_extension(group="diracx.db.sql"):
            db_name = entry_point.name
            var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
            if var_name in os.environ:
                try:
                    db_url = os.environ[var_name]
                    if db_url == "sqlite+aiosqlite:///:memory:":
                        db_urls[db_name] = db_url
                    # pydantic does not allow for underscore in scheme
                    # so we do a special case
                    elif "_" in db_url.split(":")[0]:
                        # Validate the URL with a fake schema, and then store
                        # the original one
                        scheme_id = db_url.find(":")
                        fake_url = (
                            db_url[:scheme_id].replace("_", "-") + db_url[scheme_id:]
                        )
                        TypeAdapter(SqlalchemyDsn).validate_python(fake_url)
                        db_urls[db_name] = db_url

                    else:
                        db_urls[db_name] = str(
                            TypeAdapter(SqlalchemyDsn).validate_python(db_url)
                        )
                except Exception:
                    logger.error("Error loading URL for %s", db_name)
                    raise
        return db_urls

    @classmethod
    def transaction(cls) -> Self:
        raise NotImplementedError("This should never be called")

    @property
    def engine(self) -> AsyncEngine:
        """The engine to use for database operations.

        It is normally not necessary to use the engine directly, unless you are
        doing something special, like writing a test fixture that gives you a db.

        Requires that the engine_context has been entered.
        """
        assert self._engine is not None, "engine_context must be entered"
        return self._engine

    @contextlib.asynccontextmanager
    async def engine_context(self) -> AsyncIterator[None]:
        """Context manager to manage the engine lifecycle.

        This is called once at the application startup (see ``lifetime_functions``).
        """
        assert self._engine is None, "engine_context cannot be nested"

        # Set the pool_recycle to 30mn
        # That should prevent the problem of MySQL expiring connections
        # after 60mn by default
        engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
        self._engine = engine
        try:
            yield
        finally:
            self._engine = None
            await engine.dispose()

    @property
    def conn(self) -> AsyncConnection:
        if self._conn.get() is None:
            raise RuntimeError(f"{self.__class__} was used before entering")
        return cast(AsyncConnection, self._conn.get())

    async def __aenter__(self) -> Self:
        """Create a connection.

        This is called by the Dependency mechanism (see ``db_transaction``),
        It will create a new connection/transaction for each route call.
        """
        assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
        try:
            self._conn.set(await self.engine.connect().__aenter__())
        except Exception as e:
            raise SQLDBUnavailableError(
                f"Cannot connect to {self.__class__.__name__}"
            ) from e

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """This is called when exiting a route.

        If there was no exception, the changes in the DB are committed.
        Otherwise, they are rolled back.
        """
        if exc_type is None:
            await self._conn.get().commit()
        await self._conn.get().__aexit__(exc_type, exc, tb)
        self._conn.set(None)

    async def ping(self):
        """Check whether the connection to the DB is still working.

        We could enable the ``pre_ping`` in the engine, but this would be run at
        every query.
        """
        try:
            await self.conn.scalar(select(1))
        except OperationalError as e:
            raise SQLDBUnavailableError("Cannot ping the DB") from e


def find_time_resolution(value):
    if isinstance(value, datetime):
        return None, value
    if match := re.fullmatch(
        r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value
    ):
        if match.group(6):
            precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
        elif match.group(5):
            precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
        elif match.group(3):
            precision, pattern = "HOUR", r"\1-\2-\3 \4"
        elif match.group(2):
            precision, pattern = "DAY", r"\1-\2-\3"
        elif match.group(1):
            precision, pattern = "MONTH", r"\1-\2"
        else:
            precision, pattern = "YEAR", r"\1"
        return (
            precision,
            re.sub(
                r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$",
                pattern,
                value,
            ),
        )

    raise InvalidQueryError(f"Cannot parse {value=}")


def apply_search_filters(column_mapping, stmt, search):
    for query in search:
        try:
            column = column_mapping(query["parameter"])
        except KeyError as e:
            raise InvalidQueryError(f"Unknown column {query['parameter']}") from e

        if isinstance(column.type, DateTime):
            if "value" in query and isinstance(query["value"], str):
                resolution, value = find_time_resolution(query["value"])
                if resolution:
                    column = date_trunc(column, time_resolution=resolution)
                query["value"] = value

            if query.get("values"):
                resolutions, values = zip(
                    *map(find_time_resolution, query.get("values"))
                )
                if len(set(resolutions)) != 1:
                    raise InvalidQueryError(
                        f"Cannot mix different time resolutions in {query=}"
                    )
                if resolution := resolutions[0]:
                    column = date_trunc(column, time_resolution=resolution)
                query["values"] = values

        if query["operator"] == "eq":
            expr = column == query["value"]
        elif query["operator"] == "neq":
            expr = column != query["value"]
        elif query["operator"] == "gt":
            expr = column > query["value"]
        elif query["operator"] == "lt":
            expr = column < query["value"]
        elif query["operator"] == "in":
            expr = column.in_(query["values"])
        elif query["operator"] == "not in":
            expr = column.notin_(query["values"])
        elif query["operator"] in "like":
            expr = column.like(query["value"])
        elif query["operator"] in "ilike":
            expr = column.ilike(query["value"])
        else:
            raise InvalidQueryError(f"Unknown filter {query=}")
        stmt = stmt.where(expr)
    return stmt


def apply_sort_constraints(column_mapping, stmt, sorts):
    sort_columns = []
    for sort in sorts or []:
        try:
            column = column_mapping(sort["parameter"])
        except KeyError as e:
            raise InvalidQueryError(
                f"Cannot sort by {sort['parameter']}: unknown column"
            ) from e
        sorted_column = None
        if sort["direction"] == SortDirection.ASC:
            sorted_column = column.asc()
        elif sort["direction"] == SortDirection.DESC:
            sorted_column = column.desc()
        else:
            raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
        sort_columns.append(sorted_column)
    if sort_columns:
        stmt = stmt.order_by(*sort_columns)
    return stmt
````
diracx/db/sql/utils/functions.py ADDED

@@ -0,0 +1,105 @@

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

from sqlalchemy import DateTime, func
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import expression

if TYPE_CHECKING:
    from sqlalchemy.types import TypeEngine


class utcnow(expression.FunctionElement):  # noqa: N801
    type: TypeEngine = DateTime()
    inherit_cache: bool = True


@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw) -> str:
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"


@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw) -> str:
    return "GETUTCDATE()"


@compiles(utcnow, "mysql")
def mysql_utcnow(element, compiler, **kw) -> str:
    return "(UTC_TIMESTAMP)"


@compiles(utcnow, "sqlite")
def sqlite_utcnow(element, compiler, **kw) -> str:
    return "DATETIME('now')"


class date_trunc(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to truncate a date to a given resolution.

    Primarily used to be able to query for a specific resolution of a date e.g.

        select * from table where date_trunc('day', date_column) = '2021-01-01'
        select * from table where date_trunc('year', date_column) = '2021'
        select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
    """

    type = DateTime()
    inherit_cache = True

    def __init__(self, *args, time_resolution, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._time_resolution = time_resolution


@compiles(date_trunc, "postgresql")
def pg_date_trunc(element, compiler, **kw):
    res = {
        "SECOND": "second",
        "MINUTE": "minute",
        "HOUR": "hour",
        "DAY": "day",
        "MONTH": "month",
        "YEAR": "year",
    }[element._time_resolution]
    return f"date_trunc('{res}', {compiler.process(element.clauses)})"


@compiles(date_trunc, "mysql")
def mysql_date_trunc(element, compiler, **kw):
    pattern = {
        "SECOND": "%Y-%m-%d %H:%i:%S",
        "MINUTE": "%Y-%m-%d %H:%i",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }[element._time_resolution]

    (dt_col,) = list(element.clauses)
    return compiler.process(func.date_format(dt_col, pattern))


@compiles(date_trunc, "sqlite")
def sqlite_date_trunc(element, compiler, **kw):
    pattern = {
        "SECOND": "%Y-%m-%d %H:%M:%S",
        "MINUTE": "%Y-%m-%d %H:%M",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }[element._time_resolution]
    (dt_col,) = list(element.clauses)
    return compiler.process(
        func.strftime(
            pattern,
            dt_col,
        )
    )


def substract_date(**kwargs: float) -> datetime:
    return datetime.now(tz=timezone.utc) - timedelta(**kwargs)
```
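
These compiled elements give dialect-independent "current UTC time" and "truncate to resolution" expressions. A short sketch of how they are typically combined with a table definition; the `Pings` table and the import path `diracx.db.sql.utils.functions` are assumptions for illustration only.

```python
from sqlalchemy import Column, DateTime, Integer, MetaData, Table, select

# Assumed import path for the elements defined above
from diracx.db.sql.utils.functions import date_trunc, substract_date, utcnow

metadata = MetaData()
# Hypothetical table used only for this sketch
pings = Table(
    "Pings",
    metadata,
    Column("PingID", Integer, primary_key=True),
    # utcnow() compiles to the per-dialect "current UTC time" SQL shown above
    Column("CreationTime", DateTime, server_default=utcnow()),
)

# Match every row created on a given day, regardless of backend
stmt = select(pings).where(
    date_trunc(pings.c.CreationTime, time_resolution="DAY") == "2024-01-01"
)

# substract_date returns "now minus a timedelta" as a timezone-aware datetime,
# e.g. everything from the last 12 hours
recent = select(pings).where(pings.c.CreationTime > substract_date(hours=12))
```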
diracx/db/sql/utils/job.py CHANGED

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 from collections import defaultdict
 from copy import deepcopy
@@ -49,7 +51,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
     async with asyncio.TaskGroup() as tg:
         for job in jobs:
             original_jdl = deepcopy(job.jdl)
-
+            job_manifest = returnValueOrRaise(
                 checkAndAddOwner(original_jdl, job.owner, job.owner_group)
             )
 
@@ -60,13 +62,13 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
             original_jdls.append(
                 (
                     original_jdl,
-
+                    job_manifest,
                     tg.create_task(job_db.create_job(original_jdl)),
                 )
             )
 
     async with asyncio.TaskGroup() as tg:
-        for job, (original_jdl,
+        for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls):
             job_id = job_id_task.result()
             job_attrs = {
                 "JobID": job_id,
@@ -77,16 +79,16 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
                 "VO": job.vo,
             }
 
-
+            job_manifest_.setOption("JobID", job_id)
 
             # 2.- Check JDL and Prepare DIRAC JDL
-
+            job_jdl = job_manifest_.dumpAsJDL()
 
             # Replace the JobID placeholder if any
-            if
-
+            if job_jdl.find("%j") != -1:
+                job_jdl = job_jdl.replace("%j", str(job_id))
 
-            class_ad_job = ClassAd(
+            class_ad_job = ClassAd(job_jdl)
 
             class_ad_req = ClassAd("[]")
             if not class_ad_job.isOK():
@@ -99,7 +101,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
             # TODO is this even needed?
             class_ad_job.insertAttributeInt("JobID", job_id)
 
-            await job_db.
+            await job_db.check_and_prepare_job(
                 job_id,
                 class_ad_job,
                 class_ad_req,
@@ -108,10 +110,10 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
                 job_attrs,
                 job.vo,
             )
-
+            job_jdl = createJDLWithInitialStatus(
                 class_ad_job,
                 class_ad_req,
-                job_db.
+                job_db.jdl_2_db_parameters,
                 job_attrs,
                 job.initial_status,
                 job.initial_minor_status,
@@ -119,11 +121,11 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
             )
 
             jobs_to_insert[job_id] = job_attrs
-            jdls_to_update[job_id] =
+            jdls_to_update[job_id] = job_jdl
 
             if class_ad_job.lookupAttribute("InputData"):
-
-                inputdata_to_insert[job_id] = [lfn for lfn in
+                input_data = class_ad_job.getListFromExpression("InputData")
+                inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn]
 
         tg.create_task(job_db.update_job_jdls(jdls_to_update))
         tg.create_task(job_db.insert_job_attributes(jobs_to_insert))
@@ -243,7 +245,7 @@ async def reschedule_jobs_bulk(
     job_jdls = {
         jobid: parse_jdl(jobid, jdl)
        for jobid, jdl in (
-            (await job_db.
+            (await job_db.get_job_jdls(surviving_job_ids, original=True)).items()
        )
    }
 
@@ -251,7 +253,7 @@ async def reschedule_jobs_bulk(
         class_ad_job = job_jdls[job_id]
         class_ad_req = ClassAd("[]")
         try:
-            await job_db.
+            await job_db.check_and_prepare_job(
                 job_id,
                 class_ad_job,
                 class_ad_req,
@@ -277,11 +279,11 @@ async def reschedule_jobs_bulk(
         else:
             site = site_list[0]
 
-
-        class_ad_job.insertAttributeInt("JobRequirements",
-
+        req_jdl = class_ad_req.asJDL()
+        class_ad_job.insertAttributeInt("JobRequirements", req_jdl)
+        job_jdl = class_ad_job.asJDL()
         # Replace the JobID placeholder if any
-
+        job_jdl = job_jdl.replace("%j", str(job_id))
 
         additional_attrs = {
             "Site": site,
@@ -291,7 +293,7 @@ async def reschedule_jobs_bulk(
         }
 
         # set new JDL
-        jdl_changes[job_id] =
+        jdl_changes[job_id] = job_jdl
 
         # set new status
         status_changes[job_id] = {
@@ -319,17 +321,18 @@ async def reschedule_jobs_bulk(
 
     # BULK JDL UPDATE
     # DATABASE OPERATION
-    await job_db.
+    await job_db.set_job_jdl_bulk(jdl_changes)
 
     return {
         "failed": failed,
         "success": {
             job_id: {
-                "InputData": job_jdls
+                "InputData": job_jdls.get(job_id, None),
                 **attribute_changes[job_id],
                 **set_status_result.model_dump(),
             }
             for job_id, set_status_result in set_job_status_result.success.items()
+            if job_id not in failed
         },
     }
 
@@ -411,39 +414,40 @@ async def set_job_status_bulk(
 
     for res in results:
         job_id = int(res["JobID"])
-
-
-
+        current_status = res["Status"]
+        start_time = res["StartExecTime"]
+        end_time = res["EndExecTime"]
 
         # If the current status is Stalled and we get an update, it should probably be "Running"
-        if
-
+        if current_status == JobStatus.STALLED:
+            current_status = JobStatus.RUNNING
 
         #####################################################################################################
-
-        # This is more precise than "LastTime".
-
-
+        status_dict = status_dicts[job_id]
+        # This is more precise than "LastTime". time_stamps is a sorted list of tuples...
+        time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items())
+        last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace(
            tzinfo=timezone.utc
        )
 
         # Get chronological order of new updates
-
+        update_times = sorted(status_dict)
 
-
-
+        new_start_time, new_end_time = getStartAndEndTime(
+            start_time, end_time, update_times, time_stamps, status_dict
        )
 
         job_data: dict[str, str] = {}
-
+        new_status: str | None = None
+        if update_times[-1] >= last_time:
            new_status, new_minor, new_application = (
                returnValueOrRaise(  # TODO: Catch this
                    getNewStatus(
                        job_id,
-
-
-
-
+                        update_times,
+                        last_time,
+                        status_dict,
+                        current_status,
                        force,
                        MagicMock(),  # FIXME
                    )
@@ -465,15 +469,15 @@ async def set_job_status_bulk(
         # if not result["OK"]:
         #     return result
 
-        for
-        if
-        job_data["HeartBeatTime"] = str(
+        for upd_time in update_times:
+            if status_dict[upd_time]["Source"].startswith("Job"):
+                job_data["HeartBeatTime"] = str(upd_time)
 
-        if not
-        job_data["StartExecTime"] =
+        if not start_time and new_start_time:
+            job_data["StartExecTime"] = new_start_time
 
-        if not
-        job_data["EndExecTime"] =
+        if not end_time and new_end_time:
+            job_data["EndExecTime"] = new_end_time
 
         #####################################################################################################
         # delete or kill job, if we transition to DELETED or KILLED state
@@ -484,20 +488,20 @@ async def set_job_status_bulk(
         if job_data:
             job_attribute_updates[job_id] = job_data
 
-        for
-
+        for upd_time in update_times:
+            s_dict = status_dict[upd_time]
             job_logging_updates.append(
                 JobLoggingRecord(
                     job_id=job_id,
-                    status=
-                    minor_status=
-                    application_status=
-                    date=
-                    source=
+                    status=s_dict.get("Status", "idem"),
+                    minor_status=s_dict.get("MinorStatus", "idem"),
+                    application_status=s_dict.get("ApplicationStatus", "idem"),
+                    date=upd_time,
+                    source=s_dict.get("Source", "Unknown"),
                 )
             )
 
-    await job_db.
+    await job_db.set_job_attributes_bulk(job_attribute_updates)
 
     await remove_jobs_from_task_queue(
         list(deletable_killable_jobs), config, task_queue_db, background_task
```