diracx-db 0.0.1a23__py3-none-any.whl → 0.0.1a24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,328 @@
+ from __future__ import annotations
+
+ import contextlib
+ import logging
+ import os
+ import re
+ from abc import ABCMeta
+ from collections.abc import AsyncIterator
+ from contextvars import ContextVar
+ from datetime import datetime
+ from typing import Self, cast
+
+ from pydantic import TypeAdapter
+ from sqlalchemy import DateTime, MetaData, select
+ from sqlalchemy.exc import OperationalError
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+
+ from diracx.core.exceptions import InvalidQueryError
+ from diracx.core.extensions import select_from_extension
+ from diracx.core.models import SortDirection
+ from diracx.core.settings import SqlalchemyDsn
+ from diracx.db.exceptions import DBUnavailableError
+
+ from .functions import date_trunc
+
+ logger = logging.getLogger(__name__)
+
+
+ class SQLDBError(Exception):
+     pass
+
+
+ class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
+     """Used whenever we encounter a problem with the DB connection."""
+
+
+ class BaseSQLDB(metaclass=ABCMeta):
+     """This should be the base class of all the SQL DiracX DBs.
+
+     The details covered here should be handled automatically by the service and
+     task machinery of DiracX and this documentation exists for informational
+     purposes.
+
+     The available databases are discovered by calling `BaseSQLDB.available_urls`.
+     This method returns a mapping of database names to connection URLs. The
+     available databases are determined by the `diracx.db.sql` entrypoint in the
+     `pyproject.toml` file and the connection URLs are taken from the environment
+     variables of the form `DIRACX_DB_URL_<db-name>`.
+
+     If extensions to DiracX are being used, there can be multiple implementations
+     of the same database. To list the available implementations use
+     `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
+     will be the preferred implementation and it can be initialized by calling
+     its `__init__` function with a URL previously obtained from
+     `BaseSQLDB.available_urls`.
+
+     To control the lifetime of the SQLAlchemy engine used for connecting to the
+     database, which includes the connection pool, the `BaseSQLDB.engine_context`
+     asynchronous context manager should be entered. When inside this context
+     manager, the engine can be accessed with `BaseSQLDB.engine`.
+
+     Upon entering, the DB class can then be used as an asynchronous context
+     manager to enter transactions. If an exception is raised the transaction is
+     rolled back automatically. If the inner context exits peacefully, the
+     transaction is committed automatically. When inside this context manager,
+     the DB connection can be accessed with `BaseSQLDB.conn`.
+
+     For example:
+
+     ```python
+     db_name = ...
+     url = BaseSQLDB.available_urls()[db_name]
+     MyDBClass = BaseSQLDB.available_implementations(db_name)[0]
+
+     db = MyDBClass(url)
+     async with db.engine_context():
+         async with db:
+             # Do something in the first transaction
+             # Commit will be called automatically
+
+         async with db:
+             # This transaction will be rolled back due to the exception
+             raise Exception(...)
+     ```
+     """
+
+     # engine: AsyncEngine
+     # TODO: Make metadata an abstract property
+     metadata: MetaData
+
+     def __init__(self, db_url: str) -> None:
+         # We use a ContextVar to make sure that self._conn
+         # is specific to each context, so that parallel
+         # route executions do not overlap
+         self._conn: ContextVar[AsyncConnection | None] = ContextVar(
+             "_conn", default=None
+         )
+         self._db_url = db_url
+         self._engine: AsyncEngine | None = None
+
+     @classmethod
+     def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
+         """Return the available implementations of the DB in reverse priority order."""
+         db_classes: list[type[BaseSQLDB]] = [
+             entry_point.load()
+             for entry_point in select_from_extension(
+                 group="diracx.db.sql", name=db_name
+             )
+         ]
+         if not db_classes:
+             raise NotImplementedError(f"Could not find any matches for {db_name=}")
+         return db_classes
+
+     @classmethod
+     def available_urls(cls) -> dict[str, str]:
+         """Return a dict of available database urls.
+
+         The list of available URLs is determined by environment variables
+         named ``DIRACX_DB_URL_{DB_NAME}``.
+         """
+         db_urls: dict[str, str] = {}
+         for entry_point in select_from_extension(group="diracx.db.sql"):
+             db_name = entry_point.name
+             var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
+             if var_name in os.environ:
+                 try:
+                     db_url = os.environ[var_name]
+                     if db_url == "sqlite+aiosqlite:///:memory:":
+                         db_urls[db_name] = db_url
+                     # pydantic does not allow underscores in the scheme,
+                     # so we special-case such URLs
+                     elif "_" in db_url.split(":")[0]:
+                         # Validate the URL with a fake scheme, then store
+                         # the original one
+                         scheme_id = db_url.find(":")
+                         fake_url = (
+                             db_url[:scheme_id].replace("_", "-") + db_url[scheme_id:]
+                         )
+                         TypeAdapter(SqlalchemyDsn).validate_python(fake_url)
+                         db_urls[db_name] = db_url
+
+                     else:
+                         db_urls[db_name] = str(
+                             TypeAdapter(SqlalchemyDsn).validate_python(db_url)
+                         )
+                 except Exception:
+                     logger.error("Error loading URL for %s", db_name)
+                     raise
+         return db_urls
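
For illustration, a minimal sketch of the URL discovery flow above (the `JobDB` name is hypothetical; the real names come from the `diracx.db.sql` entrypoints of the installed packages):

```python
import os

# Assumption: an entrypoint named "JobDB" exists in the diracx.db.sql group.
os.environ["DIRACX_DB_URL_JOBDB"] = "sqlite+aiosqlite:///:memory:"

urls = BaseSQLDB.available_urls()
# The in-memory SQLite URL is accepted verbatim; other URLs are validated
# as a SqlalchemyDsn first. Entrypoints without a matching variable are skipped.
```
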
+
+     @classmethod
+     def transaction(cls) -> Self:
+         raise NotImplementedError("This should never be called")
+
+     @property
+     def engine(self) -> AsyncEngine:
+         """The engine to use for database operations.
+
+         It is normally not necessary to use the engine directly, unless you are
+         doing something special, like writing a test fixture that gives you a db.
+
+         Requires that the engine_context has been entered.
+         """
+         assert self._engine is not None, "engine_context must be entered"
+         return self._engine
+
+     @contextlib.asynccontextmanager
+     async def engine_context(self) -> AsyncIterator[None]:
+         """Context manager handling the engine lifecycle.
+
+         This is called once at application startup (see ``lifetime_functions``).
+         """
+         assert self._engine is None, "engine_context cannot be nested"
+
+         # Set the pool_recycle to 30 minutes
+         # That should prevent the problem of MySQL expiring connections
+         # after 60 minutes by default
+         engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
+         self._engine = engine
+         try:
+             yield
+         finally:
+             self._engine = None
+             await engine.dispose()
+
+     @property
+     def conn(self) -> AsyncConnection:
+         if self._conn.get() is None:
+             raise RuntimeError(f"{self.__class__} was used before entering")
+         return cast(AsyncConnection, self._conn.get())
+
+     async def __aenter__(self) -> Self:
+         """Create a connection.
+
+         This is called by the Dependency mechanism (see ``db_transaction``);
+         it will create a new connection/transaction for each route call.
+         """
+         assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
+         try:
+             self._conn.set(await self.engine.connect().__aenter__())
+         except Exception as e:
+             raise SQLDBUnavailableError(
+                 f"Cannot connect to {self.__class__.__name__}"
+             ) from e
+
+         return self
+
+     async def __aexit__(self, exc_type, exc, tb):
+         """This is called when exiting a route.
+
+         If there was no exception, the changes in the DB are committed.
+         Otherwise, they are rolled back.
+         """
+         if exc_type is None:
+             await self._conn.get().commit()
+         await self._conn.get().__aexit__(exc_type, exc, tb)
+         self._conn.set(None)
+
+     async def ping(self):
+         """Check whether the connection to the DB is still working.
+
+         We could enable ``pre_ping`` on the engine, but that would run for
+         every query.
+         """
+         try:
+             await self.conn.scalar(select(1))
+         except OperationalError as e:
+             raise SQLDBUnavailableError("Cannot ping the DB") from e
+
+
+ def find_time_resolution(value):
+     if isinstance(value, datetime):
+         return None, value
+     if match := re.fullmatch(
+         r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value
+     ):
+         if match.group(6):
+             precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
+         elif match.group(5):
+             precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
+         elif match.group(3):
+             precision, pattern = "HOUR", r"\1-\2-\3 \4"
+         elif match.group(2):
+             precision, pattern = "DAY", r"\1-\2-\3"
+         elif match.group(1):
+             precision, pattern = "MONTH", r"\1-\2"
+         else:
+             precision, pattern = "YEAR", r"\1"
+         return (
+             precision,
+             re.sub(
+                 r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$",
+                 pattern,
+                 value,
+             ),
+         )
+
+     raise InvalidQueryError(f"Cannot parse {value=}")
+
+
+ def apply_search_filters(column_mapping, stmt, search):
+     for query in search:
+         try:
+             column = column_mapping(query["parameter"])
+         except KeyError as e:
+             raise InvalidQueryError(f"Unknown column {query['parameter']}") from e
+
+         if isinstance(column.type, DateTime):
+             if "value" in query and isinstance(query["value"], str):
+                 resolution, value = find_time_resolution(query["value"])
+                 if resolution:
+                     column = date_trunc(column, time_resolution=resolution)
+                 query["value"] = value
+
+             if query.get("values"):
+                 resolutions, values = zip(
+                     *map(find_time_resolution, query.get("values"))
+                 )
+                 if len(set(resolutions)) != 1:
+                     raise InvalidQueryError(
+                         f"Cannot mix different time resolutions in {query=}"
+                     )
+                 if resolution := resolutions[0]:
+                     column = date_trunc(column, time_resolution=resolution)
+                 query["values"] = values
+
+         if query["operator"] == "eq":
+             expr = column == query["value"]
+         elif query["operator"] == "neq":
+             expr = column != query["value"]
+         elif query["operator"] == "gt":
+             expr = column > query["value"]
+         elif query["operator"] == "lt":
+             expr = column < query["value"]
+         elif query["operator"] == "in":
+             expr = column.in_(query["values"])
+         elif query["operator"] == "not in":
+             expr = column.notin_(query["values"])
+         elif query["operator"] == "like":
+             expr = column.like(query["value"])
+         elif query["operator"] == "ilike":
+             expr = column.ilike(query["value"])
+         else:
+             raise InvalidQueryError(f"Unknown filter {query=}")
+         stmt = stmt.where(expr)
+     return stmt
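
A self-contained sketch of the expected call shape (the table and column names are illustrative only; `column_mapping` is any callable that raises `KeyError` for unknown names):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select

md = MetaData()
jobs = Table("Jobs", md, Column("JobID", Integer), Column("Status", String))

stmt = apply_search_filters(
    jobs.columns.__getitem__,  # KeyError here becomes InvalidQueryError
    select(jobs),
    [
        {"parameter": "Status", "operator": "eq", "value": "Running"},
        {"parameter": "JobID", "operator": "in", "values": [1, 2, 3]},
    ],
)
# stmt now carries, roughly:
# WHERE Jobs."Status" = :Status_1 AND Jobs."JobID" IN (1, 2, 3)
```
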
+
+
+ def apply_sort_constraints(column_mapping, stmt, sorts):
+     sort_columns = []
+     for sort in sorts or []:
+         try:
+             column = column_mapping(sort["parameter"])
+         except KeyError as e:
+             raise InvalidQueryError(
+                 f"Cannot sort by {sort['parameter']}: unknown column"
+             ) from e
+         sorted_column = None
+         if sort["direction"] == SortDirection.ASC:
+             sorted_column = column.asc()
+         elif sort["direction"] == SortDirection.DESC:
+             sorted_column = column.desc()
+         else:
+             raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
+         sort_columns.append(sorted_column)
+     if sort_columns:
+         stmt = stmt.order_by(*sort_columns)
+     return stmt
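
Continuing the same illustrative table from the sketch above, sorting works analogously (`SortDirection` comes from `diracx.core.models`):

```python
from diracx.core.models import SortDirection

stmt = apply_sort_constraints(
    jobs.columns.__getitem__,
    stmt,
    [{"parameter": "JobID", "direction": SortDirection.DESC}],
)
# equivalent to appending ORDER BY Jobs."JobID" DESC
```
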
@@ -0,0 +1,105 @@
+ from __future__ import annotations
+
+ from datetime import datetime, timedelta, timezone
+ from typing import TYPE_CHECKING
+
+ from sqlalchemy import DateTime, func
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.sql import expression
+
+ if TYPE_CHECKING:
+     from sqlalchemy.types import TypeEngine
+
+
+ class utcnow(expression.FunctionElement):  # noqa: N801
+     type: TypeEngine = DateTime()
+     inherit_cache: bool = True
+
+
+ @compiles(utcnow, "postgresql")
+ def pg_utcnow(element, compiler, **kw) -> str:
+     return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
+
+
+ @compiles(utcnow, "mssql")
+ def ms_utcnow(element, compiler, **kw) -> str:
+     return "GETUTCDATE()"
+
+
+ @compiles(utcnow, "mysql")
+ def mysql_utcnow(element, compiler, **kw) -> str:
+     return "(UTC_TIMESTAMP)"
+
+
+ @compiles(utcnow, "sqlite")
+ def sqlite_utcnow(element, compiler, **kw) -> str:
+     return "DATETIME('now')"
+
+
+ class date_trunc(expression.FunctionElement):  # noqa: N801
+     """SQLAlchemy function to truncate a date to a given resolution.
+
+     Primarily used to query for a specific resolution of a date, e.g.
+
+         select * from table where date_trunc('day', date_column) = '2021-01-01'
+         select * from table where date_trunc('year', date_column) = '2021'
+         select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
+     """
+
+     type = DateTime()
+     inherit_cache = True
+
+     def __init__(self, *args, time_resolution, **kwargs) -> None:
+         super().__init__(*args, **kwargs)
+         self._time_resolution = time_resolution
+
+
+ @compiles(date_trunc, "postgresql")
+ def pg_date_trunc(element, compiler, **kw):
+     res = {
+         "SECOND": "second",
+         "MINUTE": "minute",
+         "HOUR": "hour",
+         "DAY": "day",
+         "MONTH": "month",
+         "YEAR": "year",
+     }[element._time_resolution]
+     return f"date_trunc('{res}', {compiler.process(element.clauses)})"
+
+
+ @compiles(date_trunc, "mysql")
+ def mysql_date_trunc(element, compiler, **kw):
+     pattern = {
+         "SECOND": "%Y-%m-%d %H:%i:%S",
+         "MINUTE": "%Y-%m-%d %H:%i",
+         "HOUR": "%Y-%m-%d %H",
+         "DAY": "%Y-%m-%d",
+         "MONTH": "%Y-%m",
+         "YEAR": "%Y",
+     }[element._time_resolution]
+
+     (dt_col,) = list(element.clauses)
+     return compiler.process(func.date_format(dt_col, pattern))
+
+
+ @compiles(date_trunc, "sqlite")
+ def sqlite_date_trunc(element, compiler, **kw):
+     pattern = {
+         "SECOND": "%Y-%m-%d %H:%M:%S",
+         "MINUTE": "%Y-%m-%d %H:%M",
+         "HOUR": "%Y-%m-%d %H",
+         "DAY": "%Y-%m-%d",
+         "MONTH": "%Y-%m",
+         "YEAR": "%Y",
+     }[element._time_resolution]
+     (dt_col,) = list(element.clauses)
+     return compiler.process(
+         func.strftime(
+             pattern,
+             dt_col,
+         )
+     )
+
+
+ def substract_date(**kwargs: float) -> datetime:
+     return datetime.now(tz=timezone.utc) - timedelta(**kwargs)
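
A sketch of how `date_trunc` and `substract_date` combine in queries (table and column names are again illustrative):

```python
from sqlalchemy import Column, DateTime, Integer, MetaData, Table, select

md = MetaData()
jobs = Table("Jobs", md, Column("JobID", Integer), Column("LastUpdateTime", DateTime))

# "jobs last updated on 2021-01-01", whatever the time of day:
stmt = select(jobs).where(
    date_trunc(jobs.c.LastUpdateTime, time_resolution="DAY") == "2021-01-01"
)

# "jobs updated within the last two hours":
stmt = select(jobs).where(jobs.c.LastUpdateTime > substract_date(hours=2))
```
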
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import asyncio
  from collections import defaultdict
  from copy import deepcopy
@@ -49,7 +51,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
      async with asyncio.TaskGroup() as tg:
          for job in jobs:
              original_jdl = deepcopy(job.jdl)
-             jobManifest = returnValueOrRaise(
+             job_manifest = returnValueOrRaise(
                  checkAndAddOwner(original_jdl, job.owner, job.owner_group)
              )

@@ -60,13 +62,13 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
              original_jdls.append(
                  (
                      original_jdl,
-                     jobManifest,
+                     job_manifest,
                      tg.create_task(job_db.create_job(original_jdl)),
                  )
              )

      async with asyncio.TaskGroup() as tg:
-         for job, (original_jdl, jobManifest_, job_id_task) in zip(jobs, original_jdls):
+         for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls):
              job_id = job_id_task.result()
              job_attrs = {
                  "JobID": job_id,
@@ -77,16 +79,16 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
                  "VO": job.vo,
              }

-             jobManifest_.setOption("JobID", job_id)
+             job_manifest_.setOption("JobID", job_id)

              # 2.- Check JDL and Prepare DIRAC JDL
-             jobJDL = jobManifest_.dumpAsJDL()
+             job_jdl = job_manifest_.dumpAsJDL()

              # Replace the JobID placeholder if any
-             if jobJDL.find("%j") != -1:
-                 jobJDL = jobJDL.replace("%j", str(job_id))
+             if job_jdl.find("%j") != -1:
+                 job_jdl = job_jdl.replace("%j", str(job_id))

-             class_ad_job = ClassAd(jobJDL)
+             class_ad_job = ClassAd(job_jdl)

              class_ad_req = ClassAd("[]")
              if not class_ad_job.isOK():
@@ -99,7 +101,7 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
              # TODO is this even needed?
              class_ad_job.insertAttributeInt("JobID", job_id)

-             await job_db.checkAndPrepareJob(
+             await job_db.check_and_prepare_job(
                  job_id,
                  class_ad_job,
                  class_ad_req,
@@ -108,10 +110,10 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
                  job_attrs,
                  job.vo,
              )
-             jobJDL = createJDLWithInitialStatus(
+             job_jdl = createJDLWithInitialStatus(
                  class_ad_job,
                  class_ad_req,
-                 job_db.jdl2DBParameters,
+                 job_db.jdl_2_db_parameters,
                  job_attrs,
                  job.initial_status,
                  job.initial_minor_status,
@@ -119,11 +121,11 @@ async def submit_jobs_jdl(jobs: list[JobSubmissionSpec], job_db: JobDB):
              )

              jobs_to_insert[job_id] = job_attrs
-             jdls_to_update[job_id] = jobJDL
+             jdls_to_update[job_id] = job_jdl

              if class_ad_job.lookupAttribute("InputData"):
-                 inputData = class_ad_job.getListFromExpression("InputData")
-                 inputdata_to_insert[job_id] = [lfn for lfn in inputData if lfn]
+                 input_data = class_ad_job.getListFromExpression("InputData")
+                 inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn]

          tg.create_task(job_db.update_job_jdls(jdls_to_update))
          tg.create_task(job_db.insert_job_attributes(jobs_to_insert))
@@ -243,7 +245,7 @@ async def reschedule_jobs_bulk(
      job_jdls = {
          jobid: parse_jdl(jobid, jdl)
          for jobid, jdl in (
-             (await job_db.getJobJDLs(surviving_job_ids, original=True)).items()
+             (await job_db.get_job_jdls(surviving_job_ids, original=True)).items()
          )
      }

@@ -251,7 +253,7 @@
          class_ad_job = job_jdls[job_id]
          class_ad_req = ClassAd("[]")
          try:
-             await job_db.checkAndPrepareJob(
+             await job_db.check_and_prepare_job(
                  job_id,
                  class_ad_job,
                  class_ad_req,
@@ -277,11 +279,11 @@
          else:
              site = site_list[0]

-         reqJDL = class_ad_req.asJDL()
-         class_ad_job.insertAttributeInt("JobRequirements", reqJDL)
-         jobJDL = class_ad_job.asJDL()
+         req_jdl = class_ad_req.asJDL()
+         class_ad_job.insertAttributeInt("JobRequirements", req_jdl)
+         job_jdl = class_ad_job.asJDL()
          # Replace the JobID placeholder if any
-         jobJDL = jobJDL.replace("%j", str(job_id))
+         job_jdl = job_jdl.replace("%j", str(job_id))

          additional_attrs = {
              "Site": site,
@@ -291,7 +293,7 @@
          }

          # set new JDL
-         jdl_changes[job_id] = jobJDL
+         jdl_changes[job_id] = job_jdl

          # set new status
          status_changes[job_id] = {
@@ -319,17 +321,18 @@

      # BULK JDL UPDATE
      # DATABASE OPERATION
-     await job_db.setJobJDLsBulk(jdl_changes)
+     await job_db.set_job_jdl_bulk(jdl_changes)

      return {
          "failed": failed,
          "success": {
              job_id: {
-                 "InputData": job_jdls[job_id],
+                 "InputData": job_jdls.get(job_id, None),
                  **attribute_changes[job_id],
                  **set_status_result.model_dump(),
              }
              for job_id, set_status_result in set_job_status_result.success.items()
+             if job_id not in failed
          },
      }

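Two behavioural fixes land in the hunk above: jobs that failed rescheduling are now excluded from the `success` mapping, and a job ID with no parsed JDL yields `None` instead of raising `KeyError`. In isolation (with illustrative values):

```python
job_jdls = {1: "jdl-1"}              # job 2 has no parsed JDL
failed = {2: "rescheduling failed"}  # job 2 failed earlier
success_ids = [1, 2]

success = {
    job_id: {"InputData": job_jdls.get(job_id, None)}
    for job_id in success_ids
    if job_id not in failed
}
assert success == {1: {"InputData": "jdl-1"}}
```
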
@@ -411,40 +414,40 @@ async def set_job_status_bulk(

      for res in results:
          job_id = int(res["JobID"])
-         currentStatus = res["Status"]
-         startTime = res["StartExecTime"]
-         endTime = res["EndExecTime"]
+         current_status = res["Status"]
+         start_time = res["StartExecTime"]
+         end_time = res["EndExecTime"]

          # If the current status is Stalled and we get an update, it should probably be "Running"
-         if currentStatus == JobStatus.STALLED:
-             currentStatus = JobStatus.RUNNING
+         if current_status == JobStatus.STALLED:
+             current_status = JobStatus.RUNNING

          #####################################################################################################
-         statusDict = status_dicts[job_id]
-         # This is more precise than "LastTime". timeStamps is a sorted list of tuples...
-         timeStamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items())
-         lastTime = TimeUtilities.fromEpoch(timeStamps[-1][0]).replace(
+         status_dict = status_dicts[job_id]
+         # This is more precise than "LastTime". time_stamps is a sorted list of tuples...
+         time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items())
+         last_time = TimeUtilities.fromEpoch(time_stamps[-1][0]).replace(
              tzinfo=timezone.utc
          )

          # Get chronological order of new updates
-         updateTimes = sorted(statusDict)
+         update_times = sorted(status_dict)

-         newStartTime, newEndTime = getStartAndEndTime(
-             startTime, endTime, updateTimes, timeStamps, statusDict
+         new_start_time, new_end_time = getStartAndEndTime(
+             start_time, end_time, update_times, time_stamps, status_dict
          )

          job_data: dict[str, str] = {}
          new_status: str | None = None
-         if updateTimes[-1] >= lastTime:
+         if update_times[-1] >= last_time:
              new_status, new_minor, new_application = (
                  returnValueOrRaise(  # TODO: Catch this
                      getNewStatus(
                          job_id,
-                         updateTimes,
-                         lastTime,
-                         statusDict,
-                         currentStatus,
+                         update_times,
+                         last_time,
+                         status_dict,
+                         current_status,
                          force,
                          MagicMock(),  # FIXME
                      )
@@ -466,15 +469,15 @@
          # if not result["OK"]:
          #     return result

-         for updTime in updateTimes:
-             if statusDict[updTime]["Source"].startswith("Job"):
-                 job_data["HeartBeatTime"] = str(updTime)
+         for upd_time in update_times:
+             if status_dict[upd_time]["Source"].startswith("Job"):
+                 job_data["HeartBeatTime"] = str(upd_time)

-         if not startTime and newStartTime:
-             job_data["StartExecTime"] = newStartTime
+         if not start_time and new_start_time:
+             job_data["StartExecTime"] = new_start_time

-         if not endTime and newEndTime:
-             job_data["EndExecTime"] = newEndTime
+         if not end_time and new_end_time:
+             job_data["EndExecTime"] = new_end_time

          #####################################################################################################
          # delete or kill job, if we transition to DELETED or KILLED state
@@ -485,20 +488,20 @@
          if job_data:
              job_attribute_updates[job_id] = job_data

-         for updTime in updateTimes:
-             sDict = statusDict[updTime]
+         for upd_time in update_times:
+             s_dict = status_dict[upd_time]
              job_logging_updates.append(
                  JobLoggingRecord(
                      job_id=job_id,
-                     status=sDict.get("Status", "idem"),
-                     minor_status=sDict.get("MinorStatus", "idem"),
-                     application_status=sDict.get("ApplicationStatus", "idem"),
-                     date=updTime,
-                     source=sDict.get("Source", "Unknown"),
+                     status=s_dict.get("Status", "idem"),
+                     minor_status=s_dict.get("MinorStatus", "idem"),
+                     application_status=s_dict.get("ApplicationStatus", "idem"),
+                     date=upd_time,
+                     source=s_dict.get("Source", "Unknown"),
                  )
              )

-     await job_db.setJobAttributesBulk(job_attribute_updates)
+     await job_db.set_job_attributes_bulk(job_attribute_updates)

      await remove_jobs_from_task_queue(
          list(deletable_killable_jobs), config, task_queue_db, background_task