diracx-db 0.0.1a19__py3-none-any.whl → 0.0.1a21__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
diracx/db/sql/job/db.py CHANGED
@@ -1,6 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
3
  from datetime import datetime, timezone
5
4
  from typing import TYPE_CHECKING, Any
6
5
 
@@ -9,14 +8,9 @@ from sqlalchemy.exc import IntegrityError, NoResultFound
9
8
 
10
9
  if TYPE_CHECKING:
11
10
  from sqlalchemy.sql.elements import BindParameter
12
-
13
11
  from diracx.core.exceptions import InvalidQueryError, JobNotFound
14
12
  from diracx.core.models import (
15
- JobMinorStatus,
16
- JobStatus,
17
13
  LimitedJobStatusReturn,
18
- ScalarSearchOperator,
19
- ScalarSearchSpec,
20
14
  SearchSpec,
21
15
  SortSpec,
22
16
  )
@@ -50,11 +44,6 @@ class JobDB(BaseSQLDB):
50
44
  # to find a way to make it dynamic
51
45
  jdl2DBParameters = ["JobName", "JobType", "JobGroup"]
52
46
 
53
- # TODO: set maxRescheduling value from CS
54
- # maxRescheduling = self.getCSOption("MaxRescheduling", 3)
55
- # For now:
56
- maxRescheduling = 3
57
-
58
47
  async def summary(self, group_by, search) -> list[dict[str, str | int]]:
59
48
  columns = _get_columns(Jobs.__table__, group_by)
60
49
 
@@ -81,6 +70,7 @@ class JobDB(BaseSQLDB):
81
70
  ) -> tuple[int, list[dict[Any, Any]]]:
82
71
  # Find which columns to select
83
72
  columns = _get_columns(Jobs.__table__, parameters)
73
+
84
74
  stmt = select(*columns)
85
75
 
86
76
  stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
@@ -107,23 +97,18 @@ class JobDB(BaseSQLDB):
107
97
  dict(row._mapping) async for row in (await self.conn.stream(stmt))
108
98
  ]
109
99
 
110
async def insert_input_data(self, lfns: dict[int, list[str]]):
    """Insert the input data (LFNs) of several jobs into the InputData table.

    :param lfns: mapping of JobID -> list of LFNs; one ``(JobID, LFN)`` row is
        inserted for every LFN of every job.
    """
    rows = [
        {"JobID": job_id, "LFN": lfn}
        for job_id, job_lfns in lfns.items()
        for lfn in job_lfns
    ]
    if not rows:
        # Nothing to insert. Passing an empty parameter list to execute() is
        # not a no-op: SQLAlchemy would issue a single INSERT without values.
        return
    await self.conn.execute(InputData.__table__.insert(), rows)
127
112
 
128
113
  async def setJobAttributes(self, job_id, jobData):
129
114
  """TODO: add myDate and force parameters."""
@@ -132,7 +117,49 @@ class JobDB(BaseSQLDB):
132
117
  stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData)
133
118
  await self.conn.execute(stmt)
134
119
 
135
- async def _checkAndPrepareJob(
120
async def create_job(self, original_jdl):
    """Create the JobJDLs row of a new job and return the inserted job id.

    Only the compressed original JDL is stored at this point; the JDL and
    JobRequirements columns are written as empty strings.
    """
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

    new_row = {
        "JDL": "",
        "JobRequirements": "",
        "OriginalJDL": compressJDL(original_jdl),
    }
    insert_stmt = JobJDLs.__table__.insert().values(**new_row)
    outcome = await self.conn.execute(insert_stmt)
    return outcome.lastrowid
132
+
133
async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
    """Insert one row into the Jobs table per job.

    :param jobs_to_update: mapping of JobID -> {column name: value}. The JobID
        is added to each job's values to form the row.
    """
    # An executemany statement is compiled from the keys of its first
    # parameter set, so rows with a different set of columns must be sent in
    # a separate call (otherwise they fail or silently lose columns).
    rows_by_columns: dict[tuple[str, ...], list[dict]] = {}
    for job_id, attrs in jobs_to_update.items():
        row = {"JobID": job_id, **attrs}
        rows_by_columns.setdefault(tuple(sorted(row)), []).append(row)

    # An empty mapping yields no call at all: execute() with an empty
    # parameter list would still issue one INSERT without values.
    for rows in rows_by_columns.values():
        await self.conn.execute(Jobs.__table__.insert(), rows)
144
+
145
async def update_job_jdls(self, jdls_to_update: dict[int, str]):
    """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example.

    :param jdls_to_update: mapping of JobID -> JDL. Each JDL is compressed with
        compressJDL and written to the JDL column of the matching JobJDLs row.
    """
    if not jdls_to_update:
        # Nothing to update. An execute() with no parameter sets would not be
        # a no-op: it would run the UPDATE with b_JobID left unset.
        return

    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL

    await self.conn.execute(
        JobJDLs.__table__.update().where(
            JobJDLs.__table__.c.JobID == bindparam("b_JobID")
        ),
        [
            {"b_JobID": job_id, "JDL": compressJDL(jdl)}
            for job_id, jdl in jdls_to_update.items()
        ],
    )
161
+
162
+ async def checkAndPrepareJob(
136
163
  self,
137
164
  jobID,
138
165
  class_ad_job,
@@ -175,6 +202,31 @@ class JobDB(BaseSQLDB):
175
202
  )
176
203
  await self.conn.execute(stmt)
177
204
 
205
async def setJobJDLsBulk(self, jdls):
    """Set the JDL of several jobs at once.

    :param jdls: mapping of JobID -> JDL.

    The UPDATE needed here (compress every JDL with compressJDL and match rows
    on JobID) is exactly the one performed by update_job_jdls, so this method
    delegates to it instead of duplicating the statement.
    """
    if not jdls:
        # Nothing to update; avoids an UPDATE with no parameter sets.
        return
    await self.update_job_jdls(jdls)
214
+
215
async def setJobAttributesBulk(self, jobData):
    """Update the Jobs attributes of several jobs at once.

    TODO: add myDate and force parameters.

    :param jobData: mapping of JobID -> {column name: new value}. The caller's
        dictionaries are not modified.

    Jobs whose update contains "Status" additionally get LastUpdateTime set
    to the current UTC time.
    """
    if not jobData:
        # Nothing to update; avoids an UPDATE with no parameter sets.
        return

    now = datetime.now(tz=timezone.utc)

    # Build fresh row dicts (instead of updating the caller's ones) and group
    # them by column set: an executemany statement is compiled from the keys
    # of its first parameter set, and only jobs with a "Status" get the extra
    # LastUpdateTime column, so rows can legitimately differ.
    rows_by_columns: dict[tuple[str, ...], list[dict]] = {}
    for job_id, attrs in jobData.items():
        row = {"b_JobID": job_id, **attrs}
        if "Status" in attrs:
            row["LastUpdateTime"] = now
        rows_by_columns.setdefault(tuple(sorted(row)), []).append(row)

    stmt = Jobs.__table__.update().where(
        Jobs.__table__.c.JobID == bindparam("b_JobID")
    )
    for rows in rows_by_columns.values():
        await self.conn.execute(stmt, rows)
229
+
178
230
  async def getJobJDL(self, job_id: int, original: bool = False) -> str:
179
231
  from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
180
232
 
@@ -189,243 +241,21 @@ class JobDB(BaseSQLDB):
189
241
 
190
242
  return jdl
191
243
 
192
async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]:
    """Return the decoded JDLs of several jobs, keyed by JobID.

    :param job_ids: JobIDs to look up in the JobJDLs table.
    :param original: read the OriginalJDL column instead of the JDL column.

    Every stored value is decoded with extractJDL. Jobs that are not found,
    or whose selected column holds an empty/NULL value, are left out.
    """
    from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL

    jdl_column = JobJDLs.OriginalJDL if original else JobJDLs.JDL
    query = select(JobJDLs.JobID, jdl_column).where(JobJDLs.JobID.in_(job_ids))

    decoded: dict[int | str, str] = {}
    for job_id, stored_jdl in await self.conn.execute(query):
        if not stored_jdl:
            continue
        decoded[job_id] = extractJDL(stored_jdl)
    return decoded
429
259
 
430
260
  async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
431
261
  try:
@@ -451,6 +281,22 @@ class JobDB(BaseSQLDB):
451
281
  except IntegrityError as e:
452
282
  raise JobNotFound(job_id) from e
453
283
 
284
async def set_job_command_bulk(self, commands):
    """Store commands to be passed to jobs together with their next heart beat.

    :param commands: iterable of ``(job_id, command, arguments)`` tuples; one
        JobCommands row is inserted per tuple, all with the same ReceptionTime.
    """
    reception_time = datetime.now(tz=timezone.utc)
    rows = [
        {
            "JobID": job_id,
            "Command": command,
            "Arguments": arguments,
            "ReceptionTime": reception_time,
        }
        for job_id, command, arguments in commands
    ]
    if not rows:
        # Nothing to store; execute() with an empty parameter list would
        # still issue one INSERT without values.
        return
    # AsyncConnection.execute is a coroutine: without ``await`` the INSERT is
    # never run.
    await self.conn.execute(insert(JobCommands), rows)
    # FIXME handle IntegrityError
+
454
300
  async def delete_jobs(self, job_ids: list[int]):
455
301
  """Delete jobs from the database."""
456
302
  stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids))
@@ -4,11 +4,14 @@ import time
4
4
  from datetime import datetime, timezone
5
5
  from typing import TYPE_CHECKING
6
6
 
7
+ from pydantic import BaseModel
7
8
  from sqlalchemy import delete, func, insert, select
8
9
 
9
10
  if TYPE_CHECKING:
10
11
  pass
11
12
 
13
+ from collections import defaultdict
14
+
12
15
  from diracx.core.exceptions import JobNotFound
13
16
  from diracx.core.models import (
14
17
  JobStatus,
@@ -24,6 +27,15 @@ from .schema import (
24
27
  MAGIC_EPOC_NUMBER = 1270000000
25
28
 
26
29
 
30
class JobLoggingRecord(BaseModel):
    """One job status entry, as consumed by JobLoggingDB.bulk_insert_record."""

    # Stored in the JobID column.
    job_id: int
    # Stored in the Status column.
    status: JobStatus
    # Stored in the MinorStatus column.
    minor_status: str
    # Stored in ApplicationStatus, truncated to 255 characters.
    application_status: str
    # Stored both as StatusTime and, converted to an epoch offset, as StatusTimeOrder.
    date: datetime
    # Stored in Source, truncated to 32 characters.
    source: str
37
+
38
+
27
39
  class JobLoggingDB(BaseSQLDB):
28
40
  """Frontend for the JobLoggingDB. Provides the ability to store changes with timestamps."""
29
41
 
@@ -69,6 +81,49 @@ class JobLoggingDB(BaseSQLDB):
69
81
  )
70
82
  await self.conn.execute(stmt)
71
83
 
84
async def bulk_insert_record(
    self,
    records: list[JobLoggingRecord],
):
    """Bulk insert entries to the JobLoggingDB table.

    The new records of each job get consecutive SeqNum values in the order in
    which they appear in *records*. Numbering starts right after the highest
    SeqNum already stored for that job, or at 1 when the job has no record
    yet. ApplicationStatus is truncated to 255 characters and Source to 32.
    """
    if not records:
        # Nothing to insert; execute() with an empty parameter list would
        # still issue one INSERT without values.
        return

    def get_epoc(date):
        return (
            time.mktime(date.timetuple())
            + date.microsecond / 1000000.0
            - MAGIC_EPOC_NUMBER
        )

    # First, fetch the next free SeqNum of every job that already has records
    job_ids = sorted({record.job_id for record in records})
    seqnum_stmt = (
        select(
            LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)
        )
        .where(LoggingInfo.JobID.in_(job_ids))
        .group_by(LoggingInfo.JobID)
    )
    next_seqnum = {
        jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))
    }

    rows = []
    for record in records:
        # A job missing from next_seqnum has no record yet: start at 1.
        seqnum = next_seqnum.get(record.job_id, 1)
        # Advance the counter so that another record for the same job in this
        # batch does not reuse the same SeqNum.
        next_seqnum[record.job_id] = seqnum + 1
        rows.append(
            {
                "JobID": record.job_id,
                "SeqNum": seqnum,
                "Status": record.status,
                "MinorStatus": record.minor_status,
                "ApplicationStatus": record.application_status[:255],
                "StatusTime": record.date,
                "StatusTimeOrder": get_epoc(record.date),
                "Source": record.source[:32],
            }
        )

    # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements
    await self.conn.execute(insert(LoggingInfo), rows)
126
+
72
127
  async def get_records(self, job_id: int) -> list[JobStatusReturn]:
73
128
  """Returns a Status,MinorStatus,ApplicationStatus,StatusTime,Source tuple
74
129
  for each record found for job specified by its jobID in historical order.
@@ -159,3 +214,22 @@ class JobLoggingDB(BaseSQLDB):
159
214
  result[event] = str(etime + MAGIC_EPOC_NUMBER)
160
215
 
161
216
  return result
217
+
218
async def get_wms_time_stamps_bulk(self, job_ids):
    """Get TimeStamps for job MajorState transitions for multiple jobs at once.

    Return a plain ``{JobID: {State: timestamp}}`` dictionary, where each
    timestamp is ``str(StatusTimeOrder + MAGIC_EPOC_NUMBER)``. Jobs without any
    logging record do not appear in the result. If a job has several records
    with the same Status, the row read last wins (the query has no ORDER BY).
    """
    stmt = select(
        LoggingInfo.JobID,
        LoggingInfo.Status,
        LoggingInfo.StatusTimeOrder,
    ).where(LoggingInfo.JobID.in_(job_ids))

    # Do not consult CursorResult.rowcount: for a SELECT it is not reliable
    # across DBAPIs. Iterating an empty result simply leaves the dict empty,
    # so the empty and non-empty cases return the same plain-dict type.
    result: dict = {}
    for job_id, event, etime in await self.conn.execute(stmt):
        result.setdefault(job_id, {})[event] = str(etime + MAGIC_EPOC_NUMBER)
    return result
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Self, cast
16
16
  import sqlalchemy.types as types
17
17
  from pydantic import TypeAdapter
18
18
  from sqlalchemy import Column as RawColumn
19
- from sqlalchemy import DateTime, Enum, MetaData, select
19
+ from sqlalchemy import DateTime, Enum, MetaData, func, select
20
20
  from sqlalchemy.exc import OperationalError
21
21
  from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
22
22
  from sqlalchemy.ext.compiler import compiles
@@ -100,7 +100,9 @@ def mysql_date_trunc(element, compiler, **kw):
100
100
  "MONTH": "%Y-%m",
101
101
  "YEAR": "%Y",
102
102
  }[element._time_resolution]
103
- return f"DATE_FORMAT({compiler.process(element.clauses)}, '{pattern}')"
103
+
104
+ (dt_col,) = list(element.clauses)
105
+ return compiler.process(func.date_format(dt_col, pattern))
104
106
 
105
107
 
106
108
  @compiles(date_trunc, "sqlite")
@@ -113,7 +115,13 @@ def sqlite_date_trunc(element, compiler, **kw):
113
115
  "MONTH": "%Y-%m",
114
116
  "YEAR": "%Y",
115
117
  }[element._time_resolution]
116
- return f"strftime('{pattern}', {compiler.process(element.clauses)})"
118
+ (dt_col,) = list(element.clauses)
119
+ return compiler.process(
120
+ func.strftime(
121
+ pattern,
122
+ dt_col,
123
+ )
124
+ )
117
125
 
118
126
 
119
127
  def substract_date(**kwargs: float) -> datetime: