diracx-db 0.0.1a22__py3-none-any.whl → 0.0.1a24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diracx/db/exceptions.py +4 -1
- diracx/db/os/utils.py +3 -3
- diracx/db/sql/auth/db.py +9 -9
- diracx/db/sql/auth/schema.py +25 -23
- diracx/db/sql/dummy/db.py +2 -2
- diracx/db/sql/dummy/schema.py +8 -6
- diracx/db/sql/job/db.py +57 -54
- diracx/db/sql/job/schema.py +56 -54
- diracx/db/sql/job_logging/db.py +50 -46
- diracx/db/sql/job_logging/schema.py +12 -8
- diracx/db/sql/pilot_agents/schema.py +24 -22
- diracx/db/sql/sandbox_metadata/db.py +42 -40
- diracx/db/sql/sandbox_metadata/schema.py +5 -3
- diracx/db/sql/task_queue/db.py +3 -3
- diracx/db/sql/task_queue/schema.py +2 -0
- diracx/db/sql/utils/__init__.py +22 -451
- diracx/db/sql/utils/base.py +328 -0
- diracx/db/sql/utils/functions.py +105 -0
- diracx/db/sql/utils/job.py +59 -55
- diracx/db/sql/utils/types.py +43 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/METADATA +2 -2
- diracx_db-0.0.1a24.dist-info/RECORD +39 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/WHEEL +1 -1
- diracx_db-0.0.1a22.dist-info/RECORD +0 -36
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/entry_points.txt +0 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/top_level.txt +0 -0
diracx/db/sql/job_logging/db.py
CHANGED
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
12
12
|
|
13
13
|
from collections import defaultdict
|
14
14
|
|
15
|
-
from diracx.core.exceptions import
|
15
|
+
from diracx.core.exceptions import JobNotFoundError
|
16
16
|
from diracx.core.models import (
|
17
17
|
JobStatus,
|
18
18
|
JobStatusReturn,
|
@@ -57,9 +57,9 @@ class JobLoggingDB(BaseSQLDB):
|
|
57
57
|
as datetime.datetime object. If the time stamp is not provided the current
|
58
58
|
UTC time is used.
|
59
59
|
"""
|
60
|
-
# First, fetch the maximum
|
61
|
-
seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.
|
62
|
-
LoggingInfo.
|
60
|
+
# First, fetch the maximum seq_num for the given job_id
|
61
|
+
seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where(
|
62
|
+
LoggingInfo.job_id == job_id
|
63
63
|
)
|
64
64
|
seqnum = await self.conn.scalar(seqnum_stmt)
|
65
65
|
|
@@ -70,14 +70,14 @@ class JobLoggingDB(BaseSQLDB):
|
|
70
70
|
)
|
71
71
|
|
72
72
|
stmt = insert(LoggingInfo).values(
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
73
|
+
job_id=int(job_id),
|
74
|
+
seq_num=seqnum,
|
75
|
+
status=status,
|
76
|
+
minor_status=minor_status,
|
77
|
+
application_status=application_status[:255],
|
78
|
+
status_time=date,
|
79
|
+
status_time_order=epoc,
|
80
|
+
source=source[:32],
|
81
81
|
)
|
82
82
|
await self.conn.execute(stmt)
|
83
83
|
|
@@ -97,18 +97,17 @@ class JobLoggingDB(BaseSQLDB):
|
|
97
97
|
# First, fetch the maximum SeqNums for the given job_ids
|
98
98
|
seqnum_stmt = (
|
99
99
|
select(
|
100
|
-
LoggingInfo.
|
100
|
+
LoggingInfo.job_id, func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)
|
101
101
|
)
|
102
|
-
.where(LoggingInfo.
|
103
|
-
.group_by(LoggingInfo.
|
102
|
+
.where(LoggingInfo.job_id.in_([record.job_id for record in records]))
|
103
|
+
.group_by(LoggingInfo.job_id)
|
104
104
|
)
|
105
105
|
|
106
106
|
seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
|
107
107
|
# IF a seqnum is not found, then assume it does not exist and the first sequence number is 1.
|
108
|
-
|
109
108
|
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements
|
110
109
|
await self.conn.execute(
|
111
|
-
insert(
|
110
|
+
LoggingInfo.__table__.insert(),
|
112
111
|
[
|
113
112
|
{
|
114
113
|
"JobID": record.job_id,
|
@@ -118,38 +117,43 @@ class JobLoggingDB(BaseSQLDB):
|
|
118
117
|
"ApplicationStatus": record.application_status[:255],
|
119
118
|
"StatusTime": record.date,
|
120
119
|
"StatusTimeOrder": get_epoc(record.date),
|
121
|
-
"
|
120
|
+
"StatusSource": record.source[:32],
|
122
121
|
}
|
123
122
|
for record in records
|
124
123
|
],
|
125
124
|
)
|
126
125
|
|
127
|
-
async def get_records(self,
|
126
|
+
async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]:
|
128
127
|
"""Returns a Status,MinorStatus,ApplicationStatus,StatusTime,Source tuple
|
129
128
|
for each record found for job specified by its jobID in historical order.
|
130
129
|
"""
|
130
|
+
# We could potentially use a group_by here, but we need to post-process the
|
131
|
+
# results later.
|
131
132
|
stmt = (
|
132
133
|
select(
|
133
|
-
LoggingInfo.
|
134
|
-
LoggingInfo.
|
135
|
-
LoggingInfo.
|
136
|
-
LoggingInfo.
|
137
|
-
LoggingInfo.
|
134
|
+
LoggingInfo.job_id,
|
135
|
+
LoggingInfo.status,
|
136
|
+
LoggingInfo.minor_status,
|
137
|
+
LoggingInfo.application_status,
|
138
|
+
LoggingInfo.status_time,
|
139
|
+
LoggingInfo.source,
|
138
140
|
)
|
139
|
-
.where(LoggingInfo.
|
140
|
-
.order_by(LoggingInfo.
|
141
|
+
.where(LoggingInfo.job_id.in_(job_ids))
|
142
|
+
.order_by(LoggingInfo.status_time_order, LoggingInfo.status_time)
|
141
143
|
)
|
142
144
|
rows = await self.conn.execute(stmt)
|
143
145
|
|
144
|
-
values =
|
146
|
+
values = defaultdict(list)
|
145
147
|
for (
|
148
|
+
job_id,
|
146
149
|
status,
|
147
150
|
minor_status,
|
148
151
|
application_status,
|
149
152
|
status_time,
|
150
153
|
status_source,
|
151
154
|
) in rows:
|
152
|
-
|
155
|
+
|
156
|
+
values[job_id].append(
|
153
157
|
[
|
154
158
|
status,
|
155
159
|
minor_status,
|
@@ -161,16 +165,16 @@ class JobLoggingDB(BaseSQLDB):
|
|
161
165
|
|
162
166
|
# If no value has been set for the application status in the first place,
|
163
167
|
# We put this status to unknown
|
164
|
-
res =
|
165
|
-
|
166
|
-
if
|
167
|
-
|
168
|
+
res: dict = defaultdict(list)
|
169
|
+
for job_id, history in values.items():
|
170
|
+
if history[0][2] == "idem":
|
171
|
+
history[0][2] = "Unknown"
|
168
172
|
|
169
173
|
# We replace "idem" values by the value previously stated
|
170
|
-
for i in range(1, len(
|
174
|
+
for i in range(1, len(history)):
|
171
175
|
for j in range(3):
|
172
|
-
if
|
173
|
-
|
176
|
+
if history[i][j] == "idem":
|
177
|
+
history[i][j] = history[i - 1][j]
|
174
178
|
|
175
179
|
# And we replace arrays with tuples
|
176
180
|
for (
|
@@ -179,8 +183,8 @@ class JobLoggingDB(BaseSQLDB):
|
|
179
183
|
application_status,
|
180
184
|
status_time,
|
181
185
|
status_source,
|
182
|
-
) in
|
183
|
-
res.append(
|
186
|
+
) in history:
|
187
|
+
res[job_id].append(
|
184
188
|
JobStatusReturn(
|
185
189
|
Status=status,
|
186
190
|
MinorStatus=minor_status,
|
@@ -194,7 +198,7 @@ class JobLoggingDB(BaseSQLDB):
|
|
194
198
|
|
195
199
|
async def delete_records(self, job_ids: list[int]):
|
196
200
|
"""Delete logging records for given jobs."""
|
197
|
-
stmt = delete(LoggingInfo).where(LoggingInfo.
|
201
|
+
stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids))
|
198
202
|
await self.conn.execute(stmt)
|
199
203
|
|
200
204
|
async def get_wms_time_stamps(self, job_id):
|
@@ -203,12 +207,12 @@ class JobLoggingDB(BaseSQLDB):
|
|
203
207
|
"""
|
204
208
|
result = {}
|
205
209
|
stmt = select(
|
206
|
-
LoggingInfo.
|
207
|
-
LoggingInfo.
|
208
|
-
).where(LoggingInfo.
|
210
|
+
LoggingInfo.status,
|
211
|
+
LoggingInfo.status_time_order,
|
212
|
+
).where(LoggingInfo.job_id == job_id)
|
209
213
|
rows = await self.conn.execute(stmt)
|
210
214
|
if not rows.rowcount:
|
211
|
-
raise
|
215
|
+
raise JobNotFoundError(job_id) from None
|
212
216
|
|
213
217
|
for event, etime in rows:
|
214
218
|
result[event] = str(etime + MAGIC_EPOC_NUMBER)
|
@@ -221,10 +225,10 @@ class JobLoggingDB(BaseSQLDB):
|
|
221
225
|
"""
|
222
226
|
result = defaultdict(dict)
|
223
227
|
stmt = select(
|
224
|
-
LoggingInfo.
|
225
|
-
LoggingInfo.
|
226
|
-
LoggingInfo.
|
227
|
-
).where(LoggingInfo.
|
228
|
+
LoggingInfo.job_id,
|
229
|
+
LoggingInfo.status,
|
230
|
+
LoggingInfo.status_time_order,
|
231
|
+
).where(LoggingInfo.job_id.in_(job_ids))
|
228
232
|
rows = await self.conn.execute(stmt)
|
229
233
|
if not rows.rowcount:
|
230
234
|
return {}
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
Integer,
|
3
5
|
Numeric,
|
@@ -13,13 +15,15 @@ JobLoggingDBBase = declarative_base()
|
|
13
15
|
|
14
16
|
class LoggingInfo(JobLoggingDBBase):
|
15
17
|
__tablename__ = "LoggingInfo"
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
18
|
+
job_id = Column("JobID", Integer)
|
19
|
+
seq_num = Column("SeqNum", Integer)
|
20
|
+
status = Column("Status", String(32), default="")
|
21
|
+
minor_status = Column("MinorStatus", String(128), default="")
|
22
|
+
application_status = Column("ApplicationStatus", String(255), default="")
|
23
|
+
status_time = DateNowColumn("StatusTime")
|
22
24
|
# TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL
|
23
|
-
|
24
|
-
|
25
|
+
status_time_order = Column(
|
26
|
+
"StatusTimeOrder", Numeric(precision=12, scale=3), default=0
|
27
|
+
)
|
28
|
+
source = Column("StatusSource", String(32), default="Unknown")
|
25
29
|
__table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),)
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
DateTime,
|
3
5
|
Double,
|
@@ -16,22 +18,22 @@ PilotAgentsDBBase = declarative_base()
|
|
16
18
|
class PilotAgents(PilotAgentsDBBase):
|
17
19
|
__tablename__ = "PilotAgents"
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
21
|
+
pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True)
|
22
|
+
initial_job_id = Column("InitialJobID", Integer, default=0)
|
23
|
+
current_job_id = Column("CurrentJobID", Integer, default=0)
|
24
|
+
pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown")
|
25
|
+
pilot_stamp = Column("PilotStamp", String(32), default="")
|
26
|
+
destination_site = Column("DestinationSite", String(128), default="NotAssigned")
|
27
|
+
queue = Column("Queue", String(128), default="Unknown")
|
28
|
+
grid_site = Column("GridSite", String(128), default="Unknown")
|
29
|
+
vo = Column("VO", String(128))
|
30
|
+
grid_type = Column("GridType", String(32), default="LCG")
|
31
|
+
benchmark = Column("BenchMark", Double, default=0.0)
|
32
|
+
submission_time = NullColumn("SubmissionTime", DateTime)
|
33
|
+
last_update_time = NullColumn("LastUpdateTime", DateTime)
|
34
|
+
status = Column("Status", String(32), default="Unknown")
|
35
|
+
status_reason = Column("StatusReason", String(255), default="Unknown")
|
36
|
+
accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False)
|
35
37
|
|
36
38
|
__table_args__ = (
|
37
39
|
Index("PilotJobReference", "PilotJobReference"),
|
@@ -43,9 +45,9 @@ class PilotAgents(PilotAgentsDBBase):
|
|
43
45
|
class JobToPilotMapping(PilotAgentsDBBase):
|
44
46
|
__tablename__ = "JobToPilotMapping"
|
45
47
|
|
46
|
-
|
47
|
-
|
48
|
-
|
48
|
+
pilot_id = Column("PilotID", Integer, primary_key=True)
|
49
|
+
job_id = Column("JobID", Integer, primary_key=True)
|
50
|
+
start_time = Column("StartTime", DateTime)
|
49
51
|
|
50
52
|
__table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))
|
51
53
|
|
@@ -53,6 +55,6 @@ class JobToPilotMapping(PilotAgentsDBBase):
|
|
53
55
|
class PilotOutput(PilotAgentsDBBase):
|
54
56
|
__tablename__ = "PilotOutput"
|
55
57
|
|
56
|
-
|
57
|
-
|
58
|
-
|
58
|
+
pilot_id = Column("PilotID", Integer, primary_key=True)
|
59
|
+
std_output = Column("StdOutput", Text)
|
60
|
+
std_error = Column("StdError", Text)
|
@@ -2,13 +2,15 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
-
import
|
5
|
+
from sqlalchemy import Executable, delete, insert, literal, select, update
|
6
|
+
from sqlalchemy.exc import IntegrityError, NoResultFound
|
6
7
|
|
8
|
+
from diracx.core.exceptions import SandboxNotFoundError
|
7
9
|
from diracx.core.models import SandboxInfo, SandboxType, UserInfo
|
8
10
|
from diracx.db.sql.utils import BaseSQLDB, utcnow
|
9
11
|
|
10
12
|
from .schema import Base as SandboxMetadataDBBase
|
11
|
-
from .schema import
|
13
|
+
from .schema import SandBoxes, SBEntityMapping, SBOwners
|
12
14
|
|
13
15
|
|
14
16
|
class SandboxMetadataDB(BaseSQLDB):
|
@@ -17,16 +19,16 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
17
19
|
async def upsert_owner(self, user: UserInfo) -> int:
|
18
20
|
"""Get the id of the owner from the database."""
|
19
21
|
# TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
|
20
|
-
stmt =
|
21
|
-
|
22
|
-
|
23
|
-
|
22
|
+
stmt = select(SBOwners.OwnerID).where(
|
23
|
+
SBOwners.Owner == user.preferred_username,
|
24
|
+
SBOwners.OwnerGroup == user.dirac_group,
|
25
|
+
SBOwners.VO == user.vo,
|
24
26
|
)
|
25
27
|
result = await self.conn.execute(stmt)
|
26
28
|
if owner_id := result.scalar_one_or_none():
|
27
29
|
return owner_id
|
28
30
|
|
29
|
-
stmt =
|
31
|
+
stmt = insert(SBOwners).values(
|
30
32
|
Owner=user.preferred_username,
|
31
33
|
OwnerGroup=user.dirac_group,
|
32
34
|
VO=user.vo,
|
@@ -53,7 +55,7 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
53
55
|
"""Add a new sandbox in SandboxMetadataDB."""
|
54
56
|
# TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
|
55
57
|
owner_id = await self.upsert_owner(user)
|
56
|
-
stmt =
|
58
|
+
stmt = insert(SandBoxes).values(
|
57
59
|
OwnerId=owner_id,
|
58
60
|
SEName=se_name,
|
59
61
|
SEPFN=pfn,
|
@@ -63,27 +65,31 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
63
65
|
)
|
64
66
|
try:
|
65
67
|
result = await self.conn.execute(stmt)
|
66
|
-
except
|
68
|
+
except IntegrityError:
|
67
69
|
await self.update_sandbox_last_access_time(se_name, pfn)
|
68
70
|
else:
|
69
71
|
assert result.rowcount == 1
|
70
72
|
|
71
73
|
async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
|
72
74
|
stmt = (
|
73
|
-
|
74
|
-
.where(
|
75
|
+
update(SandBoxes)
|
76
|
+
.where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn)
|
75
77
|
.values(LastAccessTime=utcnow())
|
76
78
|
)
|
77
79
|
result = await self.conn.execute(stmt)
|
78
80
|
assert result.rowcount == 1
|
79
81
|
|
80
|
-
async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
|
82
|
+
async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool | None:
|
81
83
|
"""Checks if a sandbox exists and has been assigned."""
|
82
|
-
stmt:
|
83
|
-
|
84
|
+
stmt: Executable = select(SandBoxes.Assigned).where(
|
85
|
+
SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn
|
84
86
|
)
|
85
87
|
result = await self.conn.execute(stmt)
|
86
|
-
|
88
|
+
try:
|
89
|
+
is_assigned = result.scalar_one()
|
90
|
+
except NoResultFound as e:
|
91
|
+
raise SandboxNotFoundError(pfn, se_name) from e
|
92
|
+
|
87
93
|
return is_assigned
|
88
94
|
|
89
95
|
@staticmethod
|
@@ -97,11 +103,11 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
97
103
|
"""Get the sandbox assign to job."""
|
98
104
|
entity_id = self.jobid_to_entity_id(job_id)
|
99
105
|
stmt = (
|
100
|
-
|
101
|
-
.where(
|
106
|
+
select(SandBoxes.SEPFN)
|
107
|
+
.where(SandBoxes.SBId == SBEntityMapping.SBId)
|
102
108
|
.where(
|
103
|
-
|
104
|
-
|
109
|
+
SBEntityMapping.EntityId == entity_id,
|
110
|
+
SBEntityMapping.Type == sb_type,
|
105
111
|
)
|
106
112
|
)
|
107
113
|
result = await self.conn.execute(stmt)
|
@@ -118,24 +124,20 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
118
124
|
for job_id in jobs_ids:
|
119
125
|
# Define the entity id as 'Entity:entity_id' due to the DB definition:
|
120
126
|
entity_id = self.jobid_to_entity_id(job_id)
|
121
|
-
select_sb_id =
|
122
|
-
|
123
|
-
|
124
|
-
|
127
|
+
select_sb_id = select(
|
128
|
+
SandBoxes.SBId,
|
129
|
+
literal(entity_id).label("EntityId"),
|
130
|
+
literal(sb_type).label("Type"),
|
125
131
|
).where(
|
126
|
-
|
127
|
-
|
132
|
+
SandBoxes.SEName == se_name,
|
133
|
+
SandBoxes.SEPFN == pfn,
|
128
134
|
)
|
129
|
-
stmt =
|
135
|
+
stmt = insert(SBEntityMapping).from_select(
|
130
136
|
["SBId", "EntityId", "Type"], select_sb_id
|
131
137
|
)
|
132
138
|
await self.conn.execute(stmt)
|
133
139
|
|
134
|
-
stmt = (
|
135
|
-
sqlalchemy.update(sb_SandBoxes)
|
136
|
-
.where(sb_SandBoxes.SEPFN == pfn)
|
137
|
-
.values(Assigned=True)
|
138
|
-
)
|
140
|
+
stmt = update(SandBoxes).where(SandBoxes.SEPFN == pfn).values(Assigned=True)
|
139
141
|
result = await self.conn.execute(stmt)
|
140
142
|
assert result.rowcount == 1
|
141
143
|
|
@@ -143,29 +145,29 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
143
145
|
"""Delete mapping between jobs and sandboxes."""
|
144
146
|
for job_id in jobs_ids:
|
145
147
|
entity_id = self.jobid_to_entity_id(job_id)
|
146
|
-
sb_sel_stmt =
|
148
|
+
sb_sel_stmt = select(SandBoxes.SBId)
|
147
149
|
sb_sel_stmt = sb_sel_stmt.join(
|
148
|
-
|
150
|
+
SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId
|
149
151
|
)
|
150
|
-
sb_sel_stmt = sb_sel_stmt.where(
|
152
|
+
sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id)
|
151
153
|
|
152
154
|
result = await self.conn.execute(sb_sel_stmt)
|
153
155
|
sb_ids = [row.SBId for row in result]
|
154
156
|
|
155
|
-
del_stmt =
|
156
|
-
|
157
|
+
del_stmt = delete(SBEntityMapping).where(
|
158
|
+
SBEntityMapping.EntityId == entity_id
|
157
159
|
)
|
158
160
|
await self.conn.execute(del_stmt)
|
159
161
|
|
160
|
-
sb_entity_sel_stmt =
|
161
|
-
|
162
|
+
sb_entity_sel_stmt = select(SBEntityMapping.SBId).where(
|
163
|
+
SBEntityMapping.SBId.in_(sb_ids)
|
162
164
|
)
|
163
165
|
result = await self.conn.execute(sb_entity_sel_stmt)
|
164
166
|
remaining_sb_ids = [row.SBId for row in result]
|
165
167
|
if not remaining_sb_ids:
|
166
168
|
unassign_stmt = (
|
167
|
-
|
168
|
-
.where(
|
169
|
+
update(SandBoxes)
|
170
|
+
.where(SandBoxes.SBId.in_(sb_ids))
|
169
171
|
.values(Assigned=False)
|
170
172
|
)
|
171
173
|
await self.conn.execute(unassign_stmt)
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
BigInteger,
|
3
5
|
Boolean,
|
@@ -14,7 +16,7 @@ from diracx.db.sql.utils import Column, DateNowColumn
|
|
14
16
|
Base = declarative_base()
|
15
17
|
|
16
18
|
|
17
|
-
class
|
19
|
+
class SBOwners(Base):
|
18
20
|
__tablename__ = "sb_Owners"
|
19
21
|
OwnerID = Column(Integer, autoincrement=True)
|
20
22
|
Owner = Column(String(32))
|
@@ -23,7 +25,7 @@ class sb_Owners(Base):
|
|
23
25
|
__table_args__ = (PrimaryKeyConstraint("OwnerID"),)
|
24
26
|
|
25
27
|
|
26
|
-
class
|
28
|
+
class SandBoxes(Base):
|
27
29
|
__tablename__ = "sb_SandBoxes"
|
28
30
|
SBId = Column(Integer, autoincrement=True)
|
29
31
|
OwnerId = Column(Integer)
|
@@ -40,7 +42,7 @@ class sb_SandBoxes(Base):
|
|
40
42
|
)
|
41
43
|
|
42
44
|
|
43
|
-
class
|
45
|
+
class SBEntityMapping(Base):
|
44
46
|
__tablename__ = "sb_EntityMapping"
|
45
47
|
SBId = Column(Integer)
|
46
48
|
EntityId = Column(String(128))
|
diracx/db/sql/task_queue/db.py
CHANGED
@@ -121,12 +121,12 @@ class TaskQueueDB(BaseSQLDB):
|
|
121
121
|
# TODO: I guess the rows are already a list of tupes
|
122
122
|
# maybe refactor
|
123
123
|
data = [(r[0], r[1]) for r in rows if r]
|
124
|
-
|
124
|
+
num_owners = len(data)
|
125
125
|
# If there are no owners do now
|
126
|
-
if
|
126
|
+
if num_owners == 0:
|
127
127
|
return
|
128
128
|
# Split the share amongst the number of owners
|
129
|
-
entities_shares = {row[0]: job_share /
|
129
|
+
entities_shares = {row[0]: job_share / num_owners for row in data}
|
130
130
|
|
131
131
|
# TODO: implement the following
|
132
132
|
# If corrector is enabled let it work it's magic
|