DIRAC 9.0.9__py3-none-any.whl → 9.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. DIRAC/Core/Utilities/ElasticSearchDB.py +1 -2
  2. DIRAC/DataManagementSystem/Client/DataManager.py +6 -7
  3. DIRAC/FrameworkSystem/DB/InstalledComponentsDB.py +1 -1
  4. DIRAC/FrameworkSystem/Utilities/MonitoringUtilities.py +1 -0
  5. DIRAC/Interfaces/Utilities/DConfigCache.py +2 -0
  6. DIRAC/Resources/Computing/BatchSystems/Condor.py +0 -3
  7. DIRAC/Resources/Computing/BatchSystems/executeBatch.py +15 -7
  8. DIRAC/Resources/Computing/LocalComputingElement.py +0 -2
  9. DIRAC/Resources/Computing/SSHComputingElement.py +61 -38
  10. DIRAC/TransformationSystem/Agent/InputDataAgent.py +4 -1
  11. DIRAC/TransformationSystem/Agent/MCExtensionAgent.py +5 -2
  12. DIRAC/TransformationSystem/Agent/TaskManagerAgentBase.py +3 -1
  13. DIRAC/TransformationSystem/Agent/TransformationCleaningAgent.py +44 -9
  14. DIRAC/TransformationSystem/Agent/ValidateOutputDataAgent.py +4 -2
  15. DIRAC/TransformationSystem/Client/TransformationClient.py +9 -1
  16. DIRAC/TransformationSystem/DB/TransformationDB.py +105 -43
  17. DIRAC/WorkloadManagementSystem/Agent/StalledJobAgent.py +39 -7
  18. DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_StalledJobAgent.py +24 -4
  19. DIRAC/WorkloadManagementSystem/DB/StatusUtils.py +48 -21
  20. DIRAC/WorkloadManagementSystem/DB/tests/Test_StatusUtils.py +19 -4
  21. DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py +25 -2
  22. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/METADATA +2 -2
  23. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/RECORD +27 -27
  24. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/WHEEL +0 -0
  25. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/entry_points.txt +0 -0
  26. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/licenses/LICENSE +0 -0
  27. {dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/top_level.txt +0 -0
DIRAC/TransformationSystem/DB/TransformationDB.py

@@ -6,6 +6,9 @@ This class is typically used as a base class for more specific data processing
     databases
 """
 
+# Disable it because pylint does not understand decorator (convertToReturnValue)
+
+# pylint: disable=invalid-sequence-index
 import re
 import time
 import threading
@@ -15,6 +18,7 @@ from errno import ENOENT
 from DIRAC import gLogger, S_OK, S_ERROR
 from DIRAC.Core.Base.DB import DB
 from DIRAC.Core.Utilities.DErrno import cmpError
+from DIRAC.Core.Utilities.ReturnValues import convertToReturnValue, returnValueOrRaise
 from DIRAC.Resources.Catalog.FileCatalog import FileCatalog
 from DIRAC.Core.Security.ProxyInfo import getProxyInfo
 from DIRAC.Core.Utilities.List import stringListToString, intListToString, breakListIntoChunks
@@ -25,6 +29,7 @@ from DIRAC.DataManagementSystem.Client.MetaQuery import MetaQuery
 
 MAX_ERROR_COUNT = 10
 
+TMP_TABLE_JOIN_LIMIT = 100
 #############################################################################
 
 
@@ -270,6 +275,7 @@ class TransformationDB(DB):
         self.__updateTransformationLogging(transID, message, author, connection=connection)
         return S_OK(transID)
 
+    @convertToReturnValue
     def getTransformations(
         self,
         condDict=None,
@@ -289,32 +295,54 @@ class TransformationDB(DB):
             columns = self.TRANSPARAMS
         else:
             columns = [c for c in columns if c in self.TRANSPARAMS]
-        req = "SELECT {} FROM Transformations {}".format(
-            intListToString(columns),
-            self.buildCondition(condDict, older, newer, timeStamp, orderAttribute, limit, offset=offset),
-        )
-        res = self._query(req, conn=connection)
-        if not res["OK"]:
-            return res
-        if condDict is None:
-            condDict = {}
-        webList = []
+
+        join_query = ""
+
+        try:
+            # If we request multiple TransformationIDs, and they are more than TMP_TABLE_JOIN_LIMIT,
+            # we create a temporary table to speed up the query
+            if (
+                "TransformationID" in condDict
+                and isinstance(condDict["TransformationID"], list)
+                and len(condDict["TransformationID"]) > TMP_TABLE_JOIN_LIMIT
+            ):
+                # Create temporary table for TransformationIDs
+                transIDs = condDict.pop("TransformationID")
+                sqlCmd = "CREATE TEMPORARY TABLE to_query_TransformationIDs (TransID INTEGER NOT NULL, PRIMARY KEY (TransID)) ENGINE=MEMORY;"
+                returnValueOrRaise(self._update(sqlCmd, conn=connection))
+                join_query = " JOIN to_query_TransformationIDs t ON TransformationID = t.TransID"
+
+                # Insert TransformationIDs into temporary table
+                sqlCmd = "INSERT INTO to_query_TransformationIDs (TransID) VALUES ( %s )"
+                returnValueOrRaise(self._updatemany(sqlCmd, [(transID,) for transID in transIDs], conn=connection))
+
+            req = "SELECT {} FROM Transformations {} {}".format(
+                intListToString(columns),
+                join_query,
+                self.buildCondition(condDict, older, newer, timeStamp, orderAttribute, limit, offset=offset),
+            )
+            matching_transformations = returnValueOrRaise(self._query(req, conn=connection))
+
+        finally:
+            # Clean up temporary table
+            if join_query:
+                sqlCmd = "DROP TEMPORARY TABLE to_query_TransformationIDs"
+                self._update(sqlCmd, conn=connection)
+
+        # TODO: optimize by getting all the extra params at once
         resultList = []
-        for row in res["Value"]:
+        for row in matching_transformations:
             # Prepare the structure for the web
-            rList = [str(item) if not isinstance(item, int) else item for item in row]
             transDict = dict(zip(columns, row))
-            webList.append(rList)
             if extraParams:
-                res = self.__getAdditionalParameters(transDict["TransformationID"], connection=connection)
-                if not res["OK"]:
-                    return res
-                transDict.update(res["Value"])
+                trans_extra_param = returnValueOrRaise(
+                    self.__getAdditionalParameters(transDict["TransformationID"], connection=connection)
+                )
+
+                transDict.update(trans_extra_param)
             resultList.append(transDict)
-        result = S_OK(resultList)
-        result["Records"] = webList
-        result["ParameterNames"] = columns
-        return result
+
+        return resultList
 
     def getTransformation(self, transName, extraParams=False, connection=False):
         """Get Transformation definition and parameters of Transformation identified by TransformationID"""
@@ -710,21 +738,38 @@ class TransformationDB(DB):
         countDict["Total"] = sum(countDict.values())
         return S_OK(countDict)
 
+    @convertToReturnValue
     def __addFilesToTransformation(self, transID, fileIDs, connection=False):
-        req = "SELECT FileID from TransformationFiles"
-        req = req + " WHERE TransformationID = %d AND FileID IN (%s);" % (transID, intListToString(fileIDs))
-        res = self._query(req, conn=connection)
-        if not res["OK"]:
-            return res
-        for tupleIn in res["Value"]:
-            fileIDs.remove(tupleIn[0])
-        if not fileIDs:
-            return S_OK([])
-        values = [(transID, fileID) for fileID in fileIDs]
-        req = "INSERT INTO TransformationFiles (TransformationID,FileID,LastUpdate,InsertedTime) VALUES (%s, %s, UTC_TIMESTAMP(), UTC_TIMESTAMP())"
-        if not (res := self._updatemany(req, values, conn=connection))["OK"]:
-            return res
-        return S_OK(fileIDs)
+        # Create temporary table for FileIDs
+        sqlCmd = "CREATE TEMPORARY TABLE to_query_FileIDs (FileID INT(11) UNSIGNED NOT NULL, PRIMARY KEY (FileID)) ENGINE=MEMORY;"
+        returnValueOrRaise(self._update(sqlCmd, conn=connection))
+
+        try:
+            # Insert FileIDs into temporary table
+            sqlCmd = "INSERT INTO to_query_FileIDs (FileID) VALUES ( %s )"
+            returnValueOrRaise(self._updatemany(sqlCmd, [(fileID,) for fileID in fileIDs], conn=connection))
+
+            # Query existing files using JOIN
+            req = (
+                "SELECT tf.FileID FROM TransformationFiles tf JOIN to_query_FileIDs t ON tf.FileID = t.FileID WHERE tf.TransformationID = %d;"
+                % transID
+            )
+            res = returnValueOrRaise(self._query(req, conn=connection))
+
+            # Remove already existing fileIDs using set difference for efficiency
+            existingFileIDs = {tupleIn[0] for tupleIn in res}
+            fileIDs = list(set(fileIDs) - existingFileIDs)
+            if not fileIDs:
+                return []
+
+            values = [(transID, fileID) for fileID in fileIDs]
+            req = "INSERT INTO TransformationFiles (TransformationID,FileID,LastUpdate,InsertedTime) VALUES (%s, %s, UTC_TIMESTAMP(), UTC_TIMESTAMP())"
+            returnValueOrRaise(self._updatemany(req, values, conn=connection))
+            return fileIDs
+        finally:
+            # Clean up temporary table
+            sqlCmd = "DROP TEMPORARY TABLE to_query_FileIDs"
+            returnValueOrRaise(self._update(sqlCmd, conn=connection))
 
     def __insertExistingTransformationFiles(self, transID, fileTuplesList, connection=False):
         """Inserting already transformation files in TransformationFiles table (e.g. for deriving transformations)"""
@@ -1271,18 +1316,35 @@ class TransformationDB(DB):
     # These methods manipulate the DataFiles table
     #
 
+    @convertToReturnValue
    def __getFileIDsForLfns(self, lfns, connection=False):
        """Get file IDs for the given list of lfns
        warning: if the file is not present, we'll see no errors
        """
-        req = f"SELECT LFN,FileID FROM DataFiles WHERE LFN in ({stringListToString(lfns)});"
-        res = self._query(req, conn=connection)
-        if not res["OK"]:
-            return res
-        lfns = dict(res["Value"])
-        # Reverse dictionary
-        fids = {fileID: lfn for lfn, fileID in lfns.items()}
-        return S_OK((fids, lfns))
+
+        if not lfns:
+            return ({}, {})
+        # Create temporary table for LFNs
+        sqlCmd = "CREATE TEMPORARY TABLE to_query_LFNs (LFN VARCHAR(255) NOT NULL, PRIMARY KEY (LFN)) ENGINE=MEMORY;"
+        returnValueOrRaise(self._update(sqlCmd, conn=connection))
+
+        try:
+            # Insert LFNs into temporary table
+            sqlCmd = "INSERT INTO to_query_LFNs (LFN) VALUES ( %s )"
+            returnValueOrRaise(self._updatemany(sqlCmd, [(lfn,) for lfn in lfns], conn=connection))
+
+            # Query using JOIN with temporary table
+            req = "SELECT df.LFN, df.FileID FROM DataFiles df JOIN to_query_LFNs t ON df.LFN = t.LFN;"
+            res = returnValueOrRaise(self._query(req, conn=connection))
+
+            lfns = dict(res)
+            # Reverse dictionary
+            fids = {fileID: lfn for lfn, fileID in lfns.items()}
+            return (fids, lfns)
+        finally:
+            # Clean up temporary table
+            sqlCmd = "DROP TEMPORARY TABLE to_query_LFNs"
+            self._update(sqlCmd, conn=connection)
 
     def __getLfnsForFileIDs(self, fileIDs, connection=False):
         """Get lfns for the given list of fileIDs"""
DIRAC/WorkloadManagementSystem/Agent/StalledJobAgent.py

@@ -17,11 +17,9 @@ from DIRAC.ConfigurationSystem.Client.Helpers import cfgPath
 from DIRAC.Core.Base.AgentModule import AgentModule
 from DIRAC.Core.Utilities import DErrno
 from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd
+from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
 from DIRAC.Core.Utilities.TimeUtilities import fromString, second, toEpoch
 from DIRAC.WorkloadManagementSystem.Client import JobMinorStatus, JobStatus
-from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
-from DIRAC.WorkloadManagementSystem.DB.JobLoggingDB import JobLoggingDB
-from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
 from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_KILL
 from DIRAC.WorkloadManagementSystem.DB.StatusUtils import kill_delete_jobs
 from DIRAC.WorkloadManagementSystem.Utilities.JobParameters import getJobParameters
@@ -40,6 +38,9 @@ class StalledJobAgent(AgentModule):
 
         self.jobDB = None
         self.logDB = None
+        self.taskQueueDB = None
+        self.pilotAgentsDB = None
+        self.storageManagementDB = None
         self.matchedTime = 7200
         self.rescheduledTime = 600
         self.submittingTime = 300
@@ -51,8 +52,30 @@ class StalledJobAgent(AgentModule):
     #############################################################################
     def initialize(self):
         """Sets default parameters."""
-        self.jobDB = JobDB()
-        self.logDB = JobLoggingDB()
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.JobDB", "JobDB")
+        if not result["OK"]:
+            return result
+        self.jobDB = result["Value"]()
+
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.JobLoggingDB", "JobLoggingDB")
+        if not result["OK"]:
+            return result
+        self.logDB = result["Value"]()
+
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.TaskQueueDB", "TaskQueueDB")
+        if not result["OK"]:
+            return result
+        self.taskQueueDB = result["Value"]()
+
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.PilotAgentsDB", "PilotAgentsDB")
+        if not result["OK"]:
+            return result
+        self.pilotAgentsDB = result["Value"]()
+
+        result = ObjectLoader().loadObject("StorageManagementSystem.DB.StorageManagementDB", "StorageManagementDB")
+        if not result["OK"]:
+            return result
+        self.storageManagementDB = result["Value"]()
 
         # getting parameters
 
@@ -235,7 +258,16 @@ class StalledJobAgent(AgentModule):
         # Set the jobs Failed, send them a kill signal in case they are not really dead
         # and send accounting info
         if setFailed:
-            res = kill_delete_jobs(RIGHT_KILL, [jobID], nonauthJobList=[], force=True)
+            res = kill_delete_jobs(
+                RIGHT_KILL,
+                [jobID],
+                nonauthJobList=[],
+                force=True,
+                jobdb=self.jobDB,
+                taskqueuedb=self.taskQueueDB,
+                pilotagentsdb=self.pilotAgentsDB,
+                storagemanagementdb=self.storageManagementDB,
+            )
             if not res["OK"]:
                 self.log.error("Failed to kill job", jobID)
 
@@ -262,7 +294,7 @@ class StalledJobAgent(AgentModule):
             # There is no pilot reference, hence its status is unknown
             return S_OK("NoPilot")
 
-        result = PilotAgentsDB().getPilotInfo(pilotReference)
+        result = self.pilotAgentsDB.getPilotInfo(pilotReference)
         if not result["OK"]:
             if DErrno.cmpError(result, DErrno.EWMSNOPILOT):
                 self.log.warn("No pilot found", f"for job {jobID}: {result['Message']}")
DIRAC/WorkloadManagementSystem/Agent/test/Test_Agent_StalledJobAgent.py

@@ -23,10 +23,31 @@ def sja(mocker):
         side_effect=lambda x, y=None: y,
         create=True,
     )
-    mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.JobDB")
-    mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.JobLoggingDB")
+
+    # Mock ObjectLoader to return mock DB instances
+    mockJobDB = MagicMock()
+    mockJobDB.log = gLogger
+    mockJobLoggingDB = MagicMock()
+    mockTaskQueueDB = MagicMock()
+    mockPilotAgentsDB = MagicMock()
+    mockStorageManagementDB = MagicMock()
+
+    def mock_load_object(module_path, class_name):
+        mocks = {
+            "JobDB": mockJobDB,
+            "JobLoggingDB": mockJobLoggingDB,
+            "TaskQueueDB": mockTaskQueueDB,
+            "PilotAgentsDB": mockPilotAgentsDB,
+            "StorageManagementDB": mockStorageManagementDB,
+        }
+        return {"OK": True, "Value": lambda: mocks[class_name]}
+
+    mocker.patch(
+        "DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.ObjectLoader.loadObject",
+        side_effect=mock_load_object,
+    )
+
     mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.rescheduleJobs", return_value=MagicMock())
-    mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.PilotAgentsDB", return_value=MagicMock())
     mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.getJobParameters", return_value=MagicMock())
     mocker.patch("DIRAC.WorkloadManagementSystem.Agent.StalledJobAgent.kill_delete_jobs", return_value=MagicMock())
 
@@ -34,7 +55,6 @@ def sja(mocker):
     stalledJobAgent._AgentModule__configDefaults = mockAM
     stalledJobAgent.log = gLogger
     stalledJobAgent.initialize()
-    stalledJobAgent.jobDB.log = gLogger
     stalledJobAgent.log.setLevel("DEBUG")
     stalledJobAgent.stalledTime = 120
 
DIRAC/WorkloadManagementSystem/DB/StatusUtils.py

@@ -1,43 +1,40 @@
 from DIRAC import S_ERROR, S_OK, gLogger
-from DIRAC.StorageManagementSystem.DB.StorageManagementDB import StorageManagementDB
 from DIRAC.WorkloadManagementSystem.Client import JobStatus
-from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
-from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
-from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB
+from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
 from DIRAC.WorkloadManagementSystem.Service.JobPolicy import RIGHT_DELETE, RIGHT_KILL
 from DIRAC.WorkloadManagementSystem.Utilities.jobAdministration import _filterJobStateTransition
 
 
-def _deleteJob(jobID, force=False):
+def _deleteJob(jobID, force=False, *, jobdb, taskqueuedb, pilotagentsdb):
     """Set the job status to "Deleted"
     and remove the pilot that ran and its logging info if the pilot is finished.
 
     :param int jobID: job ID
     :return: S_OK()/S_ERROR()
     """
-    if not (result := JobDB().setJobStatus(jobID, JobStatus.DELETED, "Checking accounting", force=force))["OK"]:
+    if not (result := jobdb.setJobStatus(jobID, JobStatus.DELETED, "Checking accounting", force=force))["OK"]:
         gLogger.warn("Failed to set job Deleted status", result["Message"])
         return result
 
-    if not (result := TaskQueueDB().deleteJob(jobID))["OK"]:
+    if not (result := taskqueuedb.deleteJob(jobID))["OK"]:
         gLogger.warn("Failed to delete job from the TaskQueue")
 
     # if it was the last job for the pilot
-    result = PilotAgentsDB().getPilotsForJobID(jobID)
+    result = pilotagentsdb.getPilotsForJobID(jobID)
     if not result["OK"]:
         gLogger.error("Failed to get Pilots for JobID", result["Message"])
         return result
     for pilot in result["Value"]:
-        res = PilotAgentsDB().getJobsForPilot(pilot)
+        res = pilotagentsdb.getJobsForPilot(pilot)
         if not res["OK"]:
             gLogger.error("Failed to get jobs for pilot", res["Message"])
             return res
         if not res["Value"]:  # if list of jobs for pilot is empty, delete pilot
-            result = PilotAgentsDB().getPilotInfo(pilotID=pilot)
+            result = pilotagentsdb.getPilotInfo(pilotID=pilot)
             if not result["OK"]:
                 gLogger.error("Failed to get pilot info", result["Message"])
                 return result
-            ret = PilotAgentsDB().deletePilot(result["Value"]["PilotJobReference"])
+            ret = pilotagentsdb.deletePilot(result["Value"]["PilotJobReference"])
             if not ret["OK"]:
                 gLogger.error("Failed to delete pilot from PilotAgentsDB", ret["Message"])
                 return ret
@@ -45,7 +42,7 @@ def _deleteJob(jobID, force=False):
     return S_OK()
 
 
-def _killJob(jobID, sendKillCommand=True, force=False):
+def _killJob(jobID, sendKillCommand=True, force=False, *, jobdb, taskqueuedb):
     """Kill one job
 
     :param int jobID: job ID
@@ -54,32 +51,63 @@ def _killJob(jobID, sendKillCommand=True, force=False):
     :return: S_OK()/S_ERROR()
     """
     if sendKillCommand:
-        if not (result := JobDB().setJobCommand(jobID, "Kill"))["OK"]:
+        if not (result := jobdb.setJobCommand(jobID, "Kill"))["OK"]:
             gLogger.warn("Failed to set job Kill command", result["Message"])
             return result
 
     gLogger.info("Job marked for termination", jobID)
-    if not (result := JobDB().setJobStatus(jobID, JobStatus.KILLED, "Marked for termination", force=force))["OK"]:
+    if not (result := jobdb.setJobStatus(jobID, JobStatus.KILLED, "Marked for termination", force=force))["OK"]:
         gLogger.warn("Failed to set job Killed status", result["Message"])
-    if not (result := TaskQueueDB().deleteJob(jobID))["OK"]:
+    if not (result := taskqueuedb.deleteJob(jobID))["OK"]:
         gLogger.warn("Failed to delete job from the TaskQueue", result["Message"])
 
     return S_OK()
 
 
-def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
+def kill_delete_jobs(
+    right,
+    validJobList,
+    nonauthJobList=[],
+    force=False,
+    *,
+    jobdb=None,
+    taskqueuedb=None,
+    pilotagentsdb=None,
+    storagemanagementdb=None,
+):
     """Kill (== set the status to "KILLED") or delete (== set the status to "DELETED") jobs as necessary
 
     :param str right: RIGHT_KILL or RIGHT_DELETE
 
     :return: S_OK()/S_ERROR()
     """
+    if jobdb is None:
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.JobDB", "JobDB")
+        if not result["OK"]:
+            return result
+        jobdb = result["Value"]()
+    if taskqueuedb is None:
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.TaskQueueDB", "TaskQueueDB")
+        if not result["OK"]:
+            return result
+        taskqueuedb = result["Value"]()
+    if pilotagentsdb is None:
+        result = ObjectLoader().loadObject("WorkloadManagementSystem.DB.PilotAgentsDB", "PilotAgentsDB")
+        if not result["OK"]:
+            return result
+        pilotagentsdb = result["Value"]()
+    if storagemanagementdb is None:
+        result = ObjectLoader().loadObject("StorageManagementSystem.DB.StorageManagementDB", "StorageManagementDB")
+        if not result["OK"]:
+            return result
+        storagemanagementdb = result["Value"]()
+
     badIDs = []
 
     killJobList = []
     deleteJobList = []
     if validJobList:
-        result = JobDB().getJobsAttributes(validJobList, ["Status"])
+        result = jobdb.getJobsAttributes(validJobList, ["Status"])
         if not result["OK"]:
             return result
         jobStates = result["Value"]
@@ -92,12 +120,12 @@ def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
         deleteJobList.extend(_filterJobStateTransition(jobStates, JobStatus.DELETED))
 
     for jobID in killJobList:
-        result = _killJob(jobID, force=force)
+        result = _killJob(jobID, force=force, jobdb=jobdb, taskqueuedb=taskqueuedb)
         if not result["OK"]:
             badIDs.append(jobID)
 
     for jobID in deleteJobList:
-        result = _deleteJob(jobID, force=force)
+        result = _deleteJob(jobID, force=force, jobdb=jobdb, taskqueuedb=taskqueuedb, pilotagentsdb=pilotagentsdb)
         if not result["OK"]:
             badIDs.append(jobID)
 
@@ -105,9 +133,8 @@ def kill_delete_jobs(right, validJobList, nonauthJobList=[], force=False):
         stagingJobList = [jobID for jobID, sDict in jobStates.items() if sDict["Status"] == JobStatus.STAGING]
 
         if stagingJobList:
-            stagerDB = StorageManagementDB()
             gLogger.info("Going to send killing signal to stager as well!")
-            result = stagerDB.killTasksBySourceTaskID(stagingJobList)
+            result = storagemanagementdb.killTasksBySourceTaskID(stagingJobList)
             if not result["OK"]:
                 gLogger.warn("Failed to kill some Stager tasks", result["Message"])
 
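
kill_delete_jobs (and the private _killJob/_deleteJob helpers) now take their database clients as keyword-only arguments and fall back to ObjectLoader only when a handle is missing, so services and agents can pass the instances they already hold. A hedged usage sketch, mirroring the call sites in StalledJobAgent above and JobManagerHandler below (the DB attributes are assumed to have been initialised by the caller):

    # Caller that reuses pre-initialised DB handles; any keyword left out (or None)
    # is loaded on demand via ObjectLoader inside kill_delete_jobs itself.
    res = kill_delete_jobs(
        RIGHT_KILL,
        validJobList,            # jobs the caller is authorised to act on
        nonauthJobList,          # jobs reported back as non-authorised
        force=False,
        jobdb=self.jobDB,
        taskqueuedb=self.taskQueueDB,
        pilotagentsdb=self.pilotAgentsDB,
        storagemanagementdb=self.storageManagementDB,
    )
    if not res["OK"]:
        gLogger.error("kill_delete_jobs failed", res["Message"])
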
DIRAC/WorkloadManagementSystem/DB/tests/Test_StatusUtils.py

@@ -19,10 +19,25 @@ from DIRAC.WorkloadManagementSystem.DB.StatusUtils import kill_delete_jobs
     ],
 )
 def test___kill_delete_jobs(mocker, jobIDs_list, right):
-    mocker.patch("DIRAC.WorkloadManagementSystem.DB.StatusUtils.JobDB", MagicMock())
-    mocker.patch("DIRAC.WorkloadManagementSystem.DB.StatusUtils.TaskQueueDB", MagicMock())
-    mocker.patch("DIRAC.WorkloadManagementSystem.DB.StatusUtils.PilotAgentsDB", MagicMock())
-    mocker.patch("DIRAC.WorkloadManagementSystem.DB.StatusUtils.StorageManagementDB", MagicMock())
+    # Mock ObjectLoader to return mock DB instances
+    mockJobDB = MagicMock()
+    mockTaskQueueDB = MagicMock()
+    mockPilotAgentsDB = MagicMock()
+    mockStorageManagementDB = MagicMock()
+
+    def mock_load_object(module_path, class_name):
+        mocks = {
+            "JobDB": mockJobDB,
+            "TaskQueueDB": mockTaskQueueDB,
+            "PilotAgentsDB": mockPilotAgentsDB,
+            "StorageManagementDB": mockStorageManagementDB,
+        }
+        return {"OK": True, "Value": lambda: mocks[class_name]}
+
+    mocker.patch(
+        "DIRAC.WorkloadManagementSystem.DB.StatusUtils.ObjectLoader.loadObject",
+        side_effect=mock_load_object,
+    )
 
     res = kill_delete_jobs(right, jobIDs_list)
     assert res["OK"]
DIRAC/WorkloadManagementSystem/Service/JobManagerHandler.py

@@ -65,6 +65,11 @@ class JobManagerHandlerMixin:
                return result
            cls.pilotAgentsDB = result["Value"](parentLogger=cls.log)
 
+            result = ObjectLoader().loadObject("StorageManagementSystem.DB.StorageManagementDB", "StorageManagementDB")
+            if not result["OK"]:
+                return result
+            cls.storageManagementDB = result["Value"](parentLogger=cls.log)
+
         except RuntimeError as excp:
             return S_ERROR(f"Can't connect to DB: {excp!r}")
 
@@ -449,7 +454,16 @@ class JobManagerHandlerMixin:
             jobList, RIGHT_DELETE
         )
 
-        result = kill_delete_jobs(RIGHT_DELETE, validJobList, nonauthJobList, force=force)
+        result = kill_delete_jobs(
+            RIGHT_DELETE,
+            validJobList,
+            nonauthJobList,
+            force=force,
+            jobdb=self.jobDB,
+            taskqueuedb=self.taskQueueDB,
+            pilotagentsdb=self.pilotAgentsDB,
+            storagemanagementdb=self.storageManagementDB,
+        )
 
         result["requireProxyUpload"] = len(ownerJobList) > 0 and self.__checkIfProxyUploadIsRequired()
 
@@ -478,7 +492,16 @@ class JobManagerHandlerMixin:
             jobList, RIGHT_KILL
         )
 
-        result = kill_delete_jobs(RIGHT_KILL, validJobList, nonauthJobList, force=force)
+        result = kill_delete_jobs(
+            RIGHT_KILL,
+            validJobList,
+            nonauthJobList,
+            force=force,
+            jobdb=self.jobDB,
+            taskqueuedb=self.taskQueueDB,
+            pilotagentsdb=self.pilotAgentsDB,
+            storagemanagementdb=self.storageManagementDB,
+        )
 
         result["requireProxyUpload"] = len(ownerJobList) > 0 and self.__checkIfProxyUploadIsRequired()
 
{dirac-9.0.9.dist-info → dirac-9.0.11.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: DIRAC
-Version: 9.0.9
+Version: 9.0.11
 Summary: DIRAC is an interware, meaning a software framework for distributed computing.
 Home-page: https://github.com/DIRACGrid/DIRAC/
 License: GPL-3.0-only
@@ -19,7 +19,7 @@ Requires-Dist: cachetools
 Requires-Dist: certifi
 Requires-Dist: cwltool
 Requires-Dist: diraccfg
-Requires-Dist: DIRACCommon==v9.0.9
+Requires-Dist: DIRACCommon==v9.0.11
 Requires-Dist: diracx-client>=v0.0.1
 Requires-Dist: diracx-core>=v0.0.1
 Requires-Dist: diracx-cli>=v0.0.1