diracx-db 0.0.1a17__py3-none-any.whl → 0.0.1a19__py3-none-any.whl
- diracx/db/__main__.py +1 -0
- diracx/db/os/utils.py +60 -11
- diracx/db/sql/__init__.py +12 -2
- diracx/db/sql/auth/db.py +10 -19
- diracx/db/sql/auth/schema.py +5 -7
- diracx/db/sql/dummy/db.py +2 -3
- diracx/db/sql/{jobs → job}/db.py +12 -452
- diracx/db/sql/job/schema.py +129 -0
- diracx/db/sql/job_logging/__init__.py +0 -0
- diracx/db/sql/job_logging/db.py +161 -0
- diracx/db/sql/job_logging/schema.py +25 -0
- diracx/db/sql/pilot_agents/__init__.py +0 -0
- diracx/db/sql/pilot_agents/db.py +46 -0
- diracx/db/sql/pilot_agents/schema.py +58 -0
- diracx/db/sql/sandbox_metadata/db.py +12 -10
- diracx/db/sql/task_queue/__init__.py +0 -0
- diracx/db/sql/task_queue/db.py +261 -0
- diracx/db/sql/task_queue/schema.py +109 -0
- diracx/db/sql/utils/__init__.py +445 -0
- diracx/db/sql/{jobs/status_utility.py → utils/job_status.py} +11 -18
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/METADATA +5 -5
- diracx_db-0.0.1a19.dist-info/RECORD +36 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/WHEEL +1 -1
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/entry_points.txt +1 -0
- diracx/db/sql/jobs/schema.py +0 -290
- diracx/db/sql/utils.py +0 -236
- diracx_db-0.0.1a17.dist-info/RECORD +0 -27
- /diracx/db/sql/{jobs → job}/__init__.py +0 -0
- {diracx_db-0.0.1a17.dist-info → diracx_db-0.0.1a19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import time
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING
+
+from sqlalchemy import delete, func, insert, select
+
+if TYPE_CHECKING:
+    pass
+
+from diracx.core.exceptions import JobNotFound
+from diracx.core.models import (
+    JobStatus,
+    JobStatusReturn,
+)
+
+from ..utils import BaseSQLDB
+from .schema import (
+    JobLoggingDBBase,
+    LoggingInfo,
+)
+
+MAGIC_EPOC_NUMBER = 1270000000
+
+
+class JobLoggingDB(BaseSQLDB):
+    """Frontend for the JobLoggingDB. Provides the ability to store changes with timestamps."""
+
+    metadata = JobLoggingDBBase.metadata
+
+    async def insert_record(
+        self,
+        job_id: int,
+        status: JobStatus,
+        minor_status: str,
+        application_status: str,
+        date: datetime,
+        source: str,
+    ):
+        """Add a new entry to the JobLoggingDB table. One, two or all three status
+        components (status, minorStatus, applicationStatus) can be specified.
+        Optionally the timestamp of the status can
+        be provided as a string in the format '%Y-%m-%d %H:%M:%S' or
+        as a datetime.datetime object. If the timestamp is not provided, the current
+        UTC time is used.
+        """
+        # First, fetch the maximum SeqNum for the given job_id
+        seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where(
+            LoggingInfo.JobID == job_id
+        )
+        seqnum = await self.conn.scalar(seqnum_stmt)
+
+        epoc = (
+            time.mktime(date.timetuple())
+            + date.microsecond / 1000000.0
+            - MAGIC_EPOC_NUMBER
+        )
+
+        stmt = insert(LoggingInfo).values(
+            JobID=int(job_id),
+            SeqNum=seqnum,
+            Status=status,
+            MinorStatus=minor_status,
+            ApplicationStatus=application_status[:255],
+            StatusTime=date,
+            StatusTimeOrder=epoc,
+            Source=source[:32],
+        )
+        await self.conn.execute(stmt)
+
+    async def get_records(self, job_id: int) -> list[JobStatusReturn]:
+        """Return a (Status, MinorStatus, ApplicationStatus, StatusTime, Source) tuple
+        for each record found for the job specified by its jobID, in historical order.
+        """
+        stmt = (
+            select(
+                LoggingInfo.Status,
+                LoggingInfo.MinorStatus,
+                LoggingInfo.ApplicationStatus,
+                LoggingInfo.StatusTime,
+                LoggingInfo.Source,
+            )
+            .where(LoggingInfo.JobID == int(job_id))
+            .order_by(LoggingInfo.StatusTimeOrder, LoggingInfo.StatusTime)
+        )
+        rows = await self.conn.execute(stmt)
+
+        values = []
+        for (
+            status,
+            minor_status,
+            application_status,
+            status_time,
+            status_source,
+        ) in rows:
+            values.append(
+                [
+                    status,
+                    minor_status,
+                    application_status,
+                    status_time.replace(tzinfo=timezone.utc),
+                    status_source,
+                ]
+            )
+
+        # If no value has been set for the application status in the first place,
+        # we set this status to "Unknown"
+        res = []
+        if values:
+            if values[0][2] == "idem":
+                values[0][2] = "Unknown"
+
+            # We replace "idem" values by the value previously stated
+            for i in range(1, len(values)):
+                for j in range(3):
+                    if values[i][j] == "idem":
+                        values[i][j] = values[i - 1][j]
+
+            # And we replace arrays with tuples
+            for (
+                status,
+                minor_status,
+                application_status,
+                status_time,
+                status_source,
+            ) in values:
+                res.append(
+                    JobStatusReturn(
+                        Status=status,
+                        MinorStatus=minor_status,
+                        ApplicationStatus=application_status,
+                        StatusTime=status_time,
+                        Source=status_source,
+                    )
+                )
+
+        return res
+
+    async def delete_records(self, job_ids: list[int]):
+        """Delete logging records for given jobs."""
+        stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids))
+        await self.conn.execute(stmt)
+
+    async def get_wms_time_stamps(self, job_id):
+        """Get timestamps for job MajorState transitions;
+        return a {State: timestamp} dictionary.
+        """
+        result = {}
+        stmt = select(
+            LoggingInfo.Status,
+            LoggingInfo.StatusTimeOrder,
+        ).where(LoggingInfo.JobID == job_id)
+        rows = await self.conn.execute(stmt)
+        if not rows.rowcount:
+            raise JobNotFound(job_id) from None
+
+        for event, etime in rows:
+            result[event] = str(etime + MAGIC_EPOC_NUMBER)
+
+        return result
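For orientation, a minimal usage sketch of the class above (hypothetical driver code: the `db` handle and the transaction/connection management come from the surrounding diracx service layer, not from this file):

    from datetime import datetime, timezone

    from diracx.core.models import JobStatus

    async def record_transition(db, job_id: int):
        # db is an open JobLoggingDB instance with an active connection
        await db.insert_record(
            job_id,
            status=JobStatus.RUNNING,
            minor_status="Job executing",
            application_status="",
            date=datetime.now(timezone.utc),
            source="JobWrapper",
        )
        # Records come back in historical order, with "idem" entries resolved
        return await db.get_records(job_id)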
@@ -0,0 +1,25 @@
+from sqlalchemy import (
+    Integer,
+    Numeric,
+    PrimaryKeyConstraint,
+    String,
+)
+from sqlalchemy.orm import declarative_base
+
+from ..utils import Column, DateNowColumn
+
+JobLoggingDBBase = declarative_base()
+
+
+class LoggingInfo(JobLoggingDBBase):
+    __tablename__ = "LoggingInfo"
+    JobID = Column(Integer)
+    SeqNum = Column(Integer)
+    Status = Column(String(32), default="")
+    MinorStatus = Column(String(128), default="")
+    ApplicationStatus = Column(String(255), default="")
+    StatusTime = DateNowColumn()
+    # TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL
+    StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0)
+    Source = Column(String(32), default="Unknown", name="StatusSource")
+    __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),)
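Note on the encoding shared by the two files above: insert_record stores StatusTimeOrder as seconds relative to MAGIC_EPOC_NUMBER (1270000000, i.e. 31 March 2010), presumably so the value fits the Numeric(12, 3) column, and get_wms_time_stamps adds the offset back. A sketch of the round trip (assuming the process runs with local time set to UTC, since time.mktime interprets the tuple in local time):

    import time
    from datetime import datetime, timezone

    MAGIC_EPOC_NUMBER = 1270000000

    def encode_status_time_order(date: datetime) -> float:
        # Same arithmetic as JobLoggingDB.insert_record
        return time.mktime(date.timetuple()) + date.microsecond / 1e6 - MAGIC_EPOC_NUMBER

    def decode_status_time_order(order: float) -> datetime:
        # Inverse, matching get_wms_time_stamps' str(etime + MAGIC_EPOC_NUMBER)
        return datetime.fromtimestamp(order + MAGIC_EPOC_NUMBER, tz=timezone.utc)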
File without changes
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+from sqlalchemy import insert
+
+from ..utils import BaseSQLDB
+from .schema import PilotAgents, PilotAgentsDBBase
+
+
+class PilotAgentsDB(BaseSQLDB):
+    """PilotAgentsDB class is a front-end to the PilotAgents Database."""
+
+    metadata = PilotAgentsDBBase.metadata
+
+    async def add_pilot_references(
+        self,
+        pilot_ref: list[str],
+        vo: str,
+        grid_type: str = "DIRAC",
+        pilot_stamps: dict | None = None,
+    ) -> None:
+
+        if pilot_stamps is None:
+            pilot_stamps = {}
+
+        now = datetime.now(tz=timezone.utc)
+
+        # Prepare the list of dictionaries for bulk insertion
+        values = [
+            {
+                "PilotJobReference": ref,
+                "VO": vo,
+                "GridType": grid_type,
+                "SubmissionTime": now,
+                "LastUpdateTime": now,
+                "Status": "Submitted",
+                "PilotStamp": pilot_stamps.get(ref, ""),
+            }
+            for ref in pilot_ref
+        ]
+
+        # Insert multiple rows in a single execute call
+        stmt = insert(PilotAgents).values(values)
+        await self.conn.execute(stmt)
+        return
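A usage sketch for the bulk insert above (the reference and stamp values are made up; `db` is an open PilotAgentsDB with an active connection):

    async def register_pilots(db):
        refs = ["pilot-001", "pilot-002"]   # hypothetical pilot references
        stamps = {"pilot-001": "a1b2c3"}    # refs absent here get an empty stamp
        await db.add_pilot_references(
            pilot_ref=refs,
            vo="lhcb",
            grid_type="DIRAC",
            pilot_stamps=stamps,
        )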
@@ -0,0 +1,58 @@
+from sqlalchemy import (
+    DateTime,
+    Double,
+    Index,
+    Integer,
+    String,
+    Text,
+)
+from sqlalchemy.orm import declarative_base
+
+from ..utils import Column, EnumBackedBool, NullColumn
+
+PilotAgentsDBBase = declarative_base()
+
+
+class PilotAgents(PilotAgentsDBBase):
+    __tablename__ = "PilotAgents"
+
+    PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True)
+    InitialJobID = Column("InitialJobID", Integer, default=0)
+    CurrentJobID = Column("CurrentJobID", Integer, default=0)
+    PilotJobReference = Column("PilotJobReference", String(255), default="Unknown")
+    PilotStamp = Column("PilotStamp", String(32), default="")
+    DestinationSite = Column("DestinationSite", String(128), default="NotAssigned")
+    Queue = Column("Queue", String(128), default="Unknown")
+    GridSite = Column("GridSite", String(128), default="Unknown")
+    VO = Column("VO", String(128))
+    GridType = Column("GridType", String(32), default="LCG")
+    BenchMark = Column("BenchMark", Double, default=0.0)
+    SubmissionTime = NullColumn("SubmissionTime", DateTime)
+    LastUpdateTime = NullColumn("LastUpdateTime", DateTime)
+    Status = Column("Status", String(32), default="Unknown")
+    StatusReason = Column("StatusReason", String(255), default="Unknown")
+    AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False)
+
+    __table_args__ = (
+        Index("PilotJobReference", "PilotJobReference"),
+        Index("Status", "Status"),
+        Index("Statuskey", "GridSite", "DestinationSite", "Status"),
+    )
+
+
+class JobToPilotMapping(PilotAgentsDBBase):
+    __tablename__ = "JobToPilotMapping"
+
+    PilotID = Column("PilotID", Integer, primary_key=True)
+    JobID = Column("JobID", Integer, primary_key=True)
+    StartTime = Column("StartTime", DateTime)
+
+    __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))
+
+
+class PilotOutput(PilotAgentsDBBase):
+    __tablename__ = "PilotOutput"
+
+    PilotID = Column("PilotID", Integer, primary_key=True)
+    StdOutput = Column("StdOutput", Text)
+    StdError = Column("StdError", Text)
@@ -15,7 +15,7 @@ class SandboxMetadataDB(BaseSQLDB):
     metadata = SandboxMetadataDBBase.metadata
 
     async def upsert_owner(self, user: UserInfo) -> int:
-        """Get the id of the owner from the database"""
+        """Get the id of the owner from the database."""
         # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
         stmt = sqlalchemy.select(sb_Owners.OwnerID).where(
             sb_Owners.Owner == user.preferred_username,
@@ -36,7 +36,7 @@ class SandboxMetadataDB(BaseSQLDB):
 
     @staticmethod
     def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: SandboxInfo) -> str:
-        """Get the sandbox's user namespaced and content addressed PFN"""
+        """Get the sandbox's user namespaced and content addressed PFN."""
         parts = [
             "S3",
             bucket_name,
@@ -50,7 +50,7 @@ class SandboxMetadataDB(BaseSQLDB):
     async def insert_sandbox(
         self, se_name: str, user: UserInfo, pfn: str, size: int
     ) -> None:
-        """Add a new sandbox in SandboxMetadataDB"""
+        """Add a new sandbox in SandboxMetadataDB."""
         # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
         owner_id = await self.upsert_owner(user)
         stmt = sqlalchemy.insert(sb_SandBoxes).values(
@@ -88,13 +88,13 @@ class SandboxMetadataDB(BaseSQLDB):
 
     @staticmethod
     def jobid_to_entity_id(job_id: int) -> str:
-        """Define the entity id as 'Entity:entity_id' due to the DB definition"""
+        """Define the entity id as 'Entity:entity_id' due to the DB definition."""
         return f"Job:{job_id}"
 
     async def get_sandbox_assigned_to_job(
         self, job_id: int, sb_type: SandboxType
     ) -> list[Any]:
-        """Get the sandbox assign to job"""
+        """Get the sandbox assign to job."""
         entity_id = self.jobid_to_entity_id(job_id)
         stmt = (
             sqlalchemy.select(sb_SandBoxes.SEPFN)
@@ -114,7 +114,7 @@ class SandboxMetadataDB(BaseSQLDB):
         sb_type: SandboxType,
         se_name: str,
     ) -> None:
-        """Mapp sandbox and jobs"""
+        """Mapp sandbox and jobs."""
         for job_id in jobs_ids:
             # Define the entity id as 'Entity:entity_id' due to the DB definition:
             entity_id = self.jobid_to_entity_id(job_id)
@@ -140,12 +140,14 @@ class SandboxMetadataDB(BaseSQLDB):
         assert result.rowcount == 1
 
     async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
-        """Delete mapping between jobs and sandboxes"""
+        """Delete mapping between jobs and sandboxes."""
        for job_id in jobs_ids:
             entity_id = self.jobid_to_entity_id(job_id)
-            sb_sel_stmt = sqlalchemy.select(
-
-
+            sb_sel_stmt = sqlalchemy.select(sb_SandBoxes.SBId)
+            sb_sel_stmt = sb_sel_stmt.join(
+                sb_EntityMapping, sb_EntityMapping.SBId == sb_SandBoxes.SBId
+            )
+            sb_sel_stmt = sb_sel_stmt.where(sb_EntityMapping.EntityId == entity_id)
 
             result = await self.conn.execute(sb_sel_stmt)
             sb_ids = [row.SBId for row in result]
File without changes
@@ -0,0 +1,261 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from sqlalchemy import delete, func, select, update
+
+if TYPE_CHECKING:
+    pass
+
+from diracx.core.properties import JOB_SHARING, SecurityProperty
+
+from ..utils import BaseSQLDB
+from .schema import (
+    BannedSitesQueue,
+    GridCEsQueue,
+    JobsQueue,
+    JobTypesQueue,
+    PlatformsQueue,
+    SitesQueue,
+    TagsQueue,
+    TaskQueueDBBase,
+    TaskQueues,
+)
+
+
+class TaskQueueDB(BaseSQLDB):
+    metadata = TaskQueueDBBase.metadata
+
+    async def get_tq_infos_for_jobs(
+        self, job_ids: list[int]
+    ) -> set[tuple[int, str, str, str]]:
+        """Get the task queue info for given jobs."""
+        stmt = (
+            select(
+                TaskQueues.TQId, TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO
+            )
+            .join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)
+            .where(JobsQueue.JobId.in_(job_ids))
+        )
+        return set(
+            (int(row[0]), str(row[1]), str(row[2]), str(row[3]))
+            for row in (await self.conn.execute(stmt)).all()
+        )
+
+    async def get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]:
+        """Get the owner and owner group for a task queue."""
+        stmt = select(TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO).where(
+            TaskQueues.TQId == tq_id
+        )
+        return dict((await self.conn.execute(stmt)).one()._mapping)
+
+    async def remove_job(self, job_id: int):
+        """Remove a job from the task queues."""
+        stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id)
+        await self.conn.execute(stmt)
+
+    async def remove_jobs(self, job_ids: list[int]):
+        """Remove jobs from the task queues."""
+        stmt = delete(JobsQueue).where(JobsQueue.JobId.in_(job_ids))
+        await self.conn.execute(stmt)
+
+    async def delete_task_queue_if_empty(
+        self,
+        tq_id: int,
+        tq_owner: str,
+        tq_group: str,
+        job_share: int,
+        group_properties: set[SecurityProperty],
+        enable_shares_correction: bool,
+        allow_background_tqs: bool,
+    ):
+        """Try to delete a task queue if it's empty."""
+        # Check if the task queue is empty
+        stmt = (
+            select(TaskQueues.TQId)
+            .where(TaskQueues.Enabled >= 1)
+            .where(TaskQueues.TQId == tq_id)
+            .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId)))
+        )
+        rows = await self.conn.execute(stmt)
+        if not rows.rowcount:
+            return
+
+        # Delete the task queue (the other tables will be deleted in cascade)
+        stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id)
+        await self.conn.execute(stmt)
+
+        await self.recalculate_tq_shares_for_entity(
+            tq_owner,
+            tq_group,
+            job_share,
+            group_properties,
+            enable_shares_correction,
+            allow_background_tqs,
+        )
+
+    async def recalculate_tq_shares_for_entity(
+        self,
+        owner: str,
+        group: str,
+        job_share: int,
+        group_properties: set[SecurityProperty],
+        enable_shares_correction: bool,
+        allow_background_tqs: bool,
+    ):
+        """Recalculate the shares for a user/userGroup combo."""
+        if JOB_SHARING in group_properties:
+            # If the group has JobSharing, just set the priority for that entry; the user is irrelevant
+            return await self.__set_priorities_for_entity(
+                owner, group, job_share, group_properties, allow_background_tqs
+            )
+
+        stmt = (
+            select(TaskQueues.Owner, func.count(TaskQueues.Owner))
+            .where(TaskQueues.OwnerGroup == group)
+            .group_by(TaskQueues.Owner)
+        )
+        rows = await self.conn.execute(stmt)
+        # Make the rows a list of tuples:
+        # get the owners in this group and the number of times each appears.
+        # TODO: the rows are probably already a list of tuples;
+        # maybe refactor
+        data = [(r[0], r[1]) for r in rows if r]
+        numOwners = len(data)
+        # If there are no owners, there is nothing to do
+        if numOwners == 0:
+            return
+        # Split the share amongst the number of owners
+        entities_shares = {row[0]: job_share / numOwners for row in data}
+
+        # TODO: implement the following
+        # If the corrector is enabled, let it work its magic
+        # if enable_shares_correction:
+        #     entities_shares = await self.__shares_corrector.correct_shares(
+        #         entitiesShares, group=group
+        #     )
+
+        # Keep updating
+        owners = dict(data)
+        # If the user is already known and has more than one task queue, the other users don't need to be modified
+        # (the number of owners didn't change)
+        if owner in owners and owners[owner] > 1:
+            await self.__set_priorities_for_entity(
+                owner,
+                group,
+                entities_shares[owner],
+                group_properties,
+                allow_background_tqs,
+            )
+            return
+        # The number of owners may have changed, so we recalculate the priority for all owners in the group
+        for owner in owners:
+            await self.__set_priorities_for_entity(
+                owner,
+                group,
+                entities_shares[owner],
+                group_properties,
+                allow_background_tqs,
+            )
+
+    async def __set_priorities_for_entity(
+        self,
+        owner: str,
+        group: str,
+        share,
+        properties: set[SecurityProperty],
+        allow_background_tqs: bool,
+    ):
+        """Set the priority for a user/userGroup combo given a split share."""
+        from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import calculate_priority
+
+        stmt = (
+            select(
+                TaskQueues.TQId,
+                func.sum(JobsQueue.RealPriority) / func.count(JobsQueue.RealPriority),
+            )
+            .join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)
+            .where(TaskQueues.OwnerGroup == group)
+            .group_by(TaskQueues.TQId)
+        )
+        if JOB_SHARING not in properties:
+            stmt = stmt.where(TaskQueues.Owner == owner)
+        rows = await self.conn.execute(stmt)
+        tq_dict: dict[int, float] = {tq_id: priority for tq_id, priority in rows}
+
+        if not tq_dict:
+            return
+
+        rows = await self.retrieve_task_queues(list(tq_dict))
+
+        prio_dict = calculate_priority(tq_dict, rows, share, allow_background_tqs)
+
+        # Execute the updates
+        for prio, tqs in prio_dict.items():
+            update_stmt = (
+                update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio)
+            )
+            await self.conn.execute(update_stmt)
+
+    async def retrieve_task_queues(self, tq_id_list=None):
+        """Get all the task queues."""
+        if tq_id_list is not None and not tq_id_list:
+            # Empty list => fast-track no matches
+            return {}
+
+        stmt = (
+            select(
+                TaskQueues.TQId,
+                TaskQueues.Priority,
+                func.count(JobsQueue.TQId).label("Jobs"),
+                TaskQueues.Owner,
+                TaskQueues.OwnerGroup,
+                TaskQueues.VO,
+                TaskQueues.CPUTime,
+            )
+            .join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId)
+            .join(SitesQueue, TaskQueues.TQId == SitesQueue.TQId)
+            .join(GridCEsQueue, TaskQueues.TQId == GridCEsQueue.TQId)
+            .group_by(
+                TaskQueues.TQId,
+                TaskQueues.Priority,
+                TaskQueues.Owner,
+                TaskQueues.OwnerGroup,
+                TaskQueues.VO,
+                TaskQueues.CPUTime,
+            )
+        )
+        if tq_id_list is not None:
+            stmt = stmt.where(TaskQueues.TQId.in_(tq_id_list))
+
+        tq_data: dict[int, dict[str, list[str]]] = dict(
+            dict(row._mapping) for row in await self.conn.execute(stmt)
+        )
+        # TODO: the line above should be equivalent to the following commented code; check that this is the case
+        # for record in rows:
+        #     tqId = record[0]
+        #     tqData[tqId] = {
+        #         "Priority": record[1],
+        #         "Jobs": record[2],
+        #         "Owner": record[3],
+        #         "OwnerGroup": record[4],
+        #         "VO": record[5],
+        #         "CPUTime": record[6],
+        #     }
+
+        for tq_id in tq_data:
+            # TODO: maybe factorize this handy tuple list
+            for table, field in {
+                (SitesQueue, "Sites"),
+                (GridCEsQueue, "GridCEs"),
+                (BannedSitesQueue, "BannedSites"),
+                (PlatformsQueue, "Platforms"),
+                (JobTypesQueue, "JobTypes"),
+                (TagsQueue, "Tags"),
+            }:
+                stmt = select(table.Value).where(table.TQId == tq_id)
+                tq_data[tq_id][field] = list(
+                    row[0] for row in await self.conn.execute(stmt)
+                )
+
+        return tq_data
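A sketch of how the share recalculation above might be driven (argument values are illustrative; `tq_db` is an open TaskQueueDB, and JOB_SHARING is the real property from diracx.core.properties):

    async def drop_tq_and_rebalance(tq_db, tq_id: int):
        await tq_db.delete_task_queue_if_empty(
            tq_id,
            tq_owner="some_user",            # hypothetical owner
            tq_group="some_group",           # hypothetical group
            job_share=1000,
            group_properties=set(),          # no JOB_SHARING -> per-owner shares
            enable_shares_correction=False,  # the corrector is still a TODO above
            allow_background_tqs=False,
        )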