diracx-db 0.0.1a19__py3-none-any.whl → 0.0.1a21__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
diracx/db/sql/job/db.py CHANGED
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import logging
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any
 
@@ -9,14 +8,9 @@ from sqlalchemy.exc import IntegrityError, NoResultFound
 
 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import BindParameter
-
 from diracx.core.exceptions import InvalidQueryError, JobNotFound
 from diracx.core.models import (
-    JobMinorStatus,
-    JobStatus,
     LimitedJobStatusReturn,
-    ScalarSearchOperator,
-    ScalarSearchSpec,
     SearchSpec,
     SortSpec,
 )
@@ -50,11 +44,6 @@ class JobDB(BaseSQLDB):
     # to find a way to make it dynamic
     jdl2DBParameters = ["JobName", "JobType", "JobGroup"]
 
-    # TODO: set maxRescheduling value from CS
-    # maxRescheduling = self.getCSOption("MaxRescheduling", 3)
-    # For now:
-    maxRescheduling = 3
-
     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
         columns = _get_columns(Jobs.__table__, group_by)
 
@@ -81,6 +70,7 @@ class JobDB(BaseSQLDB):
     ) -> tuple[int, list[dict[Any, Any]]]:
         # Find which columns to select
         columns = _get_columns(Jobs.__table__, parameters)
+
         stmt = select(*columns)
 
         stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
@@ -107,23 +97,18 @@ class JobDB(BaseSQLDB):
             dict(row._mapping) async for row in (await self.conn.stream(stmt))
         ]
 
-    async def _insertNewJDL(self, jdl) -> int:
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
-
-        stmt = insert(JobJDLs).values(
-            JDL="", JobRequirements="", OriginalJDL=compressJDL(jdl)
+    async def insert_input_data(self, lfns: dict[int, list[str]]):
+        await self.conn.execute(
+            InputData.__table__.insert(),
+            [
+                {
+                    "JobID": job_id,
+                    "LFN": lfn,
+                }
+                for job_id, lfns_ in lfns.items()
+                for lfn in lfns_
+            ],
         )
-        result = await self.conn.execute(stmt)
-        # await self.engine.commit()
-        return result.lastrowid
-
-    async def _insertJob(self, jobData: dict[str, Any]):
-        stmt = insert(Jobs).values(jobData)
-        await self.conn.execute(stmt)
-
-    async def _insertInputData(self, job_id: int, lfns: list[str]):
-        stmt = insert(InputData).values([{"JobID": job_id, "LFN": lfn} for lfn in lfns])
-        await self.conn.execute(stmt)
 
     async def setJobAttributes(self, job_id, jobData):
         """TODO: add myDate and force parameters."""
@@ -132,7 +117,49 @@ class JobDB(BaseSQLDB):
         stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData)
         await self.conn.execute(stmt)
 
-    async def _checkAndPrepareJob(
+    async def create_job(self, original_jdl):
+        """Used to insert a new job with original JDL. Returns inserted job id."""
+        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
+
+        result = await self.conn.execute(
+            JobJDLs.__table__.insert().values(
+                JDL="",
+                JobRequirements="",
+                OriginalJDL=compressJDL(original_jdl),
+            )
+        )
+        return result.lastrowid
+
+    async def insert_job_attributes(self, jobs_to_update: dict[int, dict]):
+        await self.conn.execute(
+            Jobs.__table__.insert(),
+            [
+                {
+                    "JobID": job_id,
+                    **attrs,
+                }
+                for job_id, attrs in jobs_to_update.items()
+            ],
+        )
+
+    async def update_job_jdls(self, jdls_to_update: dict[int, str]):
+        """Used to update the JDL, typically just after inserting the original JDL, or rescheduling, for example."""
+        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
+
+        await self.conn.execute(
+            JobJDLs.__table__.update().where(
+                JobJDLs.__table__.c.JobID == bindparam("b_JobID")
+            ),
+            [
+                {
+                    "b_JobID": job_id,
+                    "JDL": compressJDL(jdl),
+                }
+                for job_id, jdl in jdls_to_update.items()
+            ],
+        )
+
+    async def checkAndPrepareJob(
         self,
         jobID,
         class_ad_job,
@@ -175,6 +202,31 @@ class JobDB(BaseSQLDB):
         )
         await self.conn.execute(stmt)
 
+    async def setJobJDLsBulk(self, jdls):
+        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
+
+        await self.conn.execute(
+            JobJDLs.__table__.update().where(
+                JobJDLs.__table__.c.JobID == bindparam("b_JobID")
+            ),
+            [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()],
+        )
+
+    async def setJobAttributesBulk(self, jobData):
+        """TODO: add myDate and force parameters."""
+        for job_id in jobData.keys():
+            if "Status" in jobData[job_id]:
+                jobData[job_id].update(
+                    {"LastUpdateTime": datetime.now(tz=timezone.utc)}
+                )
+
+        await self.conn.execute(
+            Jobs.__table__.update().where(
+                Jobs.__table__.c.JobID == bindparam("b_JobID")
+            ),
+            [{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()],
+        )
+
     async def getJobJDL(self, job_id: int, original: bool = False) -> str:
         from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
 
@@ -189,243 +241,21 @@ class JobDB(BaseSQLDB):
 
         return jdl
 
-    async def insert(
-        self,
-        jdl,
-        owner,
-        owner_group,
-        initial_status,
-        initial_minor_status,
-        vo,
-    ):
-        from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
-        from DIRAC.Core.Utilities.ReturnValues import returnValueOrRaise
-        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import (
-            checkAndAddOwner,
-            createJDLWithInitialStatus,
-            fixJDL,
-        )
-
-        job_attrs = {
-            "LastUpdateTime": datetime.now(tz=timezone.utc),
-            "SubmissionTime": datetime.now(tz=timezone.utc),
-            "Owner": owner,
-            "OwnerGroup": owner_group,
-            "VO": vo,
-        }
-
-        jobManifest = returnValueOrRaise(checkAndAddOwner(jdl, owner, owner_group))
-
-        jdl = fixJDL(jdl)
-
-        job_id = await self._insertNewJDL(jdl)
-
-        jobManifest.setOption("JobID", job_id)
-
-        job_attrs["JobID"] = job_id
-
-        # 2.- Check JDL and Prepare DIRAC JDL
-        jobJDL = jobManifest.dumpAsJDL()
-
-        # Replace the JobID placeholder if any
-        if jobJDL.find("%j") != -1:
-            jobJDL = jobJDL.replace("%j", str(job_id))
-
-        class_ad_job = ClassAd(jobJDL)
-        class_ad_req = ClassAd("[]")
-        if not class_ad_job.isOK():
-            job_attrs["Status"] = JobStatus.FAILED
-
-            job_attrs["MinorStatus"] = "Error in JDL syntax"
-
-            await self._insertJob(job_attrs)
-
-            return {
-                "JobID": job_id,
-                "Status": JobStatus.FAILED,
-                "MinorStatus": "Error in JDL syntax",
-            }
-
-        class_ad_job.insertAttributeInt("JobID", job_id)
-
-        await self._checkAndPrepareJob(
-            job_id,
-            class_ad_job,
-            class_ad_req,
-            owner,
-            owner_group,
-            job_attrs,
-            vo,
-        )
-
-        jobJDL = createJDLWithInitialStatus(
-            class_ad_job,
-            class_ad_req,
-            self.jdl2DBParameters,
-            job_attrs,
-            initial_status,
-            initial_minor_status,
-            modern=True,
-        )
-
-        await self.setJobJDL(job_id, jobJDL)
-
-        # Adding the job in the Jobs table
-        await self._insertJob(job_attrs)
-
-        # TODO: check if that is actually true
-        if class_ad_job.lookupAttribute("Parameters"):
-            raise NotImplementedError("Parameters in the JDL are not supported")
-
-        # Looking for the Input Data
-        inputData = []
-        if class_ad_job.lookupAttribute("InputData"):
-            inputData = class_ad_job.getListFromExpression("InputData")
-        lfns = [lfn for lfn in inputData if lfn]
-        if lfns:
-            await self._insertInputData(job_id, lfns)
-
-        return {
-            "JobID": job_id,
-            "Status": initial_status,
-            "MinorStatus": initial_minor_status,
-            "TimeStamp": datetime.now(tz=timezone.utc),
-        }
-
-    async def rescheduleJob(self, job_id) -> dict[str, Any]:
-        """Reschedule given job."""
-        from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
-        from DIRAC.Core.Utilities.ReturnValues import SErrorException
-
-        _, result = await self.search(
-            parameters=[
-                "Status",
-                "MinorStatus",
-                "VerifiedFlag",
-                "RescheduleCounter",
-                "Owner",
-                "OwnerGroup",
-            ],
-            search=[
-                ScalarSearchSpec(
-                    parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=job_id
-                )
-            ],
-            sorts=[],
-        )
-        if not result:
-            raise ValueError(f"Job {job_id} not found.")
-
-        jobAttrs = result[0]
-
-        if "VerifiedFlag" not in jobAttrs:
-            raise ValueError(f"Job {job_id} not found in the system")
-
-        if not jobAttrs["VerifiedFlag"]:
-            raise ValueError(
-                f"Job {job_id} not Verified: Status {jobAttrs['Status']}, Minor Status: {jobAttrs['MinorStatus']}"
-            )
-
-        reschedule_counter = int(jobAttrs["RescheduleCounter"]) + 1
-
-        # TODO: update maxRescheduling:
-        # self.maxRescheduling = self.getCSOption("MaxRescheduling", self.maxRescheduling)
-
-        if reschedule_counter > self.maxRescheduling:
-            logging.warn(f"Job {job_id}: Maximum number of reschedulings is reached.")
-            self.setJobAttributes(
-                job_id,
-                {
-                    "Status": JobStatus.FAILED,
-                    "MinorStatus": JobMinorStatus.MAX_RESCHEDULING,
-                },
-            )
-            raise ValueError(
-                f"Maximum number of reschedulings is reached: {self.maxRescheduling}"
-            )
-
-        new_job_attributes = {"RescheduleCounter": reschedule_counter}
-
-        # TODO: get the job parameters from JobMonitoringClient
-        # result = JobMonitoringClient().getJobParameters(jobID)
-        # if result["OK"]:
-        #     parDict = result["Value"]
-        #     for key, value in parDict.get(jobID, {}).items():
-        #         result = self.setAtticJobParameter(jobID, key, value, rescheduleCounter - 1)
-        #         if not result["OK"]:
-        #             break
-
-        # TODO: IF we keep JobParameters and OptimizerParameters: Delete job in those tables.
-        # await self.delete_job_parameters(job_id)
-        # await self.delete_job_optimizer_parameters(job_id)
-
-        job_jdl = await self.getJobJDL(job_id, original=True)
-        if not job_jdl.strip().startswith("["):
-            job_jdl = f"[{job_jdl}]"
-
-        classAdJob = ClassAd(job_jdl)
-        classAdReq = ClassAd("[]")
-        retVal = {}
-        retVal["JobID"] = job_id
-
-        classAdJob.insertAttributeInt("JobID", job_id)
+    async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]:
+        from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
 
-        try:
-            result = await self._checkAndPrepareJob(
-                job_id,
-                classAdJob,
-                classAdReq,
-                jobAttrs["Owner"],
-                jobAttrs["OwnerGroup"],
-                new_job_attributes,
-                classAdJob.getAttributeString("VirtualOrganization"),
+        if original:
+            stmt = select(JobJDLs.JobID, JobJDLs.OriginalJDL).where(
+                JobJDLs.JobID.in_(job_ids)
             )
-        except SErrorException as e:
-            raise ValueError(e) from e
-
-        priority = classAdJob.getAttributeInt("Priority")
-        if priority is None:
-            priority = 0
-        jobAttrs["UserPriority"] = priority
-
-        siteList = classAdJob.getListFromExpression("Site")
-        if not siteList:
-            site = "ANY"
-        elif len(siteList) > 1:
-            site = "Multiple"
         else:
-            site = siteList[0]
-
-        jobAttrs["Site"] = site
-
-        jobAttrs["Status"] = JobStatus.RECEIVED
-
-        jobAttrs["MinorStatus"] = JobMinorStatus.RESCHEDULED
-
-        jobAttrs["ApplicationStatus"] = "Unknown"
+            stmt = select(JobJDLs.JobID, JobJDLs.JDL).where(JobJDLs.JobID.in_(job_ids))
 
-        jobAttrs["LastUpdateTime"] = datetime.now(tz=timezone.utc)
-
-        jobAttrs["RescheduleTime"] = datetime.now(tz=timezone.utc)
-
-        reqJDL = classAdReq.asJDL()
-        classAdJob.insertAttributeInt("JobRequirements", reqJDL)
-
-        jobJDL = classAdJob.asJDL()
-
-        # Replace the JobID placeholder if any
-        jobJDL = jobJDL.replace("%j", str(job_id))
-
-        result = await self.setJobJDL(job_id, jobJDL)
-
-        result = await self.setJobAttributes(job_id, jobAttrs)
-
-        retVal["InputData"] = classAdJob.lookupAttribute("InputData")
-        retVal["RescheduleCounter"] = reschedule_counter
-        retVal["Status"] = JobStatus.RECEIVED
-        retVal["MinorStatus"] = JobMinorStatus.RESCHEDULED
-
-        return retVal
+        return {
+            jobid: extractJDL(jdl)
+            for jobid, jdl in (await self.conn.execute(stmt))
+            if jdl
+        }
 
     async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
         try:
@@ -451,6 +281,22 @@ class JobDB(BaseSQLDB):
         except IntegrityError as e:
             raise JobNotFound(job_id) from e
 
+    async def set_job_command_bulk(self, commands):
+        """Store a command to be passed to the job together with the next heart beat."""
+        self.conn.execute(
+            insert(JobCommands),
+            [
+                {
+                    "JobID": job_id,
+                    "Command": command,
+                    "Arguments": arguments,
+                    "ReceptionTime": datetime.now(tz=timezone.utc),
+                }
+                for job_id, command, arguments in commands
+            ],
+        )
+        # FIXME handle IntegrityError
+
     async def delete_jobs(self, job_ids: list[int]):
         """Delete jobs from the database."""
         stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids))
@@ -4,11 +4,14 @@ import time
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING
 
+from pydantic import BaseModel
 from sqlalchemy import delete, func, insert, select
 
 if TYPE_CHECKING:
     pass
 
+from collections import defaultdict
+
 from diracx.core.exceptions import JobNotFound
 from diracx.core.models import (
     JobStatus,
@@ -24,6 +27,15 @@ from .schema import (
 MAGIC_EPOC_NUMBER = 1270000000
 
 
+class JobLoggingRecord(BaseModel):
+    job_id: int
+    status: JobStatus
+    minor_status: str
+    application_status: str
+    date: datetime
+    source: str
+
+
 class JobLoggingDB(BaseSQLDB):
     """Frontend for the JobLoggingDB. Provides the ability to store changes with timestamps."""
 
@@ -69,6 +81,49 @@ class JobLoggingDB(BaseSQLDB):
         )
         await self.conn.execute(stmt)
 
+    async def bulk_insert_record(
+        self,
+        records: list[JobLoggingRecord],
+    ):
+        """Bulk insert entries to the JobLoggingDB table."""
+
+        def get_epoc(date):
+            return (
+                time.mktime(date.timetuple())
+                + date.microsecond / 1000000.0
+                - MAGIC_EPOC_NUMBER
+            )
+
+        # First, fetch the maximum SeqNums for the given job_ids
+        seqnum_stmt = (
+            select(
+                LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)
+            )
+            .where(LoggingInfo.JobID.in_([record.job_id for record in records]))
+            .group_by(LoggingInfo.JobID)
+        )
+
+        seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
+        # IF a seqnum is not found, then assume it does not exist and the first sequence number is 1.
+
+        # https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements
+        await self.conn.execute(
+            insert(LoggingInfo),
+            [
+                {
+                    "JobID": record.job_id,
+                    "SeqNum": seqnum.get(record.job_id, 1),
+                    "Status": record.status,
+                    "MinorStatus": record.minor_status,
+                    "ApplicationStatus": record.application_status[:255],
+                    "StatusTime": record.date,
+                    "StatusTimeOrder": get_epoc(record.date),
+                    "Source": record.source[:32],
+                }
+                for record in records
+            ],
+        )
+
     async def get_records(self, job_id: int) -> list[JobStatusReturn]:
         """Returns a Status,MinorStatus,ApplicationStatus,StatusTime,Source tuple
         for each record found for job specified by its jobID in historical order.
@@ -159,3 +214,22 @@ class JobLoggingDB(BaseSQLDB):
             result[event] = str(etime + MAGIC_EPOC_NUMBER)
 
         return result
+
+    async def get_wms_time_stamps_bulk(self, job_ids):
+        """Get TimeStamps for job MajorState transitions for multiple jobs at once
+        return a {JobID: {State:timestamp}} dictionary.
+        """
+        result = defaultdict(dict)
+        stmt = select(
+            LoggingInfo.JobID,
+            LoggingInfo.Status,
+            LoggingInfo.StatusTimeOrder,
+        ).where(LoggingInfo.JobID.in_(job_ids))
+        rows = await self.conn.execute(stmt)
+        if not rows.rowcount:
+            return {}
+
+        for job_id, event, etime in rows:
+            result[job_id][event] = str(etime + MAGIC_EPOC_NUMBER)
+
+        return result
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Self, cast
 import sqlalchemy.types as types
 from pydantic import TypeAdapter
 from sqlalchemy import Column as RawColumn
-from sqlalchemy import DateTime, Enum, MetaData, select
+from sqlalchemy import DateTime, Enum, MetaData, func, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.ext.compiler import compiles
@@ -100,7 +100,9 @@ def mysql_date_trunc(element, compiler, **kw):
         "MONTH": "%Y-%m",
         "YEAR": "%Y",
     }[element._time_resolution]
-    return f"DATE_FORMAT({compiler.process(element.clauses)}, '{pattern}')"
+
+    (dt_col,) = list(element.clauses)
+    return compiler.process(func.date_format(dt_col, pattern))
 
 
 @compiles(date_trunc, "sqlite")
@@ -113,7 +115,13 @@ def sqlite_date_trunc(element, compiler, **kw):
         "MONTH": "%Y-%m",
         "YEAR": "%Y",
     }[element._time_resolution]
-    return f"strftime('{pattern}', {compiler.process(element.clauses)})"
+    (dt_col,) = list(element.clauses)
+    return compiler.process(
+        func.strftime(
+            pattern,
+            dt_col,
+        )
+    )
 
 
 def substract_date(**kwargs: float) -> datetime: