diracx-db 0.0.1a22__py3-none-any.whl → 0.0.1a24__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diracx/db/exceptions.py +4 -1
- diracx/db/os/utils.py +3 -3
- diracx/db/sql/auth/db.py +9 -9
- diracx/db/sql/auth/schema.py +25 -23
- diracx/db/sql/dummy/db.py +2 -2
- diracx/db/sql/dummy/schema.py +8 -6
- diracx/db/sql/job/db.py +57 -54
- diracx/db/sql/job/schema.py +56 -54
- diracx/db/sql/job_logging/db.py +50 -46
- diracx/db/sql/job_logging/schema.py +12 -8
- diracx/db/sql/pilot_agents/schema.py +24 -22
- diracx/db/sql/sandbox_metadata/db.py +42 -40
- diracx/db/sql/sandbox_metadata/schema.py +5 -3
- diracx/db/sql/task_queue/db.py +3 -3
- diracx/db/sql/task_queue/schema.py +2 -0
- diracx/db/sql/utils/__init__.py +22 -451
- diracx/db/sql/utils/base.py +328 -0
- diracx/db/sql/utils/functions.py +105 -0
- diracx/db/sql/utils/job.py +59 -55
- diracx/db/sql/utils/types.py +43 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/METADATA +2 -2
- diracx_db-0.0.1a24.dist-info/RECORD +39 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/WHEEL +1 -1
- diracx_db-0.0.1a22.dist-info/RECORD +0 -36
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/entry_points.txt +0 -0
- {diracx_db-0.0.1a22.dist-info → diracx_db-0.0.1a24.dist-info}/top_level.txt +0 -0
diracx/db/sql/job_logging/db.py
CHANGED
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
12
12
|
|
13
13
|
from collections import defaultdict
|
14
14
|
|
15
|
-
from diracx.core.exceptions import
|
15
|
+
from diracx.core.exceptions import JobNotFoundError
|
16
16
|
from diracx.core.models import (
|
17
17
|
JobStatus,
|
18
18
|
JobStatusReturn,
|
@@ -57,9 +57,9 @@ class JobLoggingDB(BaseSQLDB):
|
|
57
57
|
as datetime.datetime object. If the time stamp is not provided the current
|
58
58
|
UTC time is used.
|
59
59
|
"""
|
60
|
-
# First, fetch the maximum
|
61
|
-
seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.
|
62
|
-
LoggingInfo.
|
60
|
+
# First, fetch the maximum seq_num for the given job_id
|
61
|
+
seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where(
|
62
|
+
LoggingInfo.job_id == job_id
|
63
63
|
)
|
64
64
|
seqnum = await self.conn.scalar(seqnum_stmt)
|
65
65
|
|
@@ -70,14 +70,14 @@ class JobLoggingDB(BaseSQLDB):
|
|
70
70
|
)
|
71
71
|
|
72
72
|
stmt = insert(LoggingInfo).values(
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
73
|
+
job_id=int(job_id),
|
74
|
+
seq_num=seqnum,
|
75
|
+
status=status,
|
76
|
+
minor_status=minor_status,
|
77
|
+
application_status=application_status[:255],
|
78
|
+
status_time=date,
|
79
|
+
status_time_order=epoc,
|
80
|
+
source=source[:32],
|
81
81
|
)
|
82
82
|
await self.conn.execute(stmt)
|
83
83
|
|
@@ -97,18 +97,17 @@ class JobLoggingDB(BaseSQLDB):
|
|
97
97
|
# First, fetch the maximum SeqNums for the given job_ids
|
98
98
|
seqnum_stmt = (
|
99
99
|
select(
|
100
|
-
LoggingInfo.
|
100
|
+
LoggingInfo.job_id, func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)
|
101
101
|
)
|
102
|
-
.where(LoggingInfo.
|
103
|
-
.group_by(LoggingInfo.
|
102
|
+
.where(LoggingInfo.job_id.in_([record.job_id for record in records]))
|
103
|
+
.group_by(LoggingInfo.job_id)
|
104
104
|
)
|
105
105
|
|
106
106
|
seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
|
107
107
|
# IF a seqnum is not found, then assume it does not exist and the first sequence number is 1.
|
108
|
-
|
109
108
|
# https://docs.sqlalchemy.org/en/20/orm/queryguide/dml.html#orm-bulk-insert-statements
|
110
109
|
await self.conn.execute(
|
111
|
-
insert(
|
110
|
+
LoggingInfo.__table__.insert(),
|
112
111
|
[
|
113
112
|
{
|
114
113
|
"JobID": record.job_id,
|
@@ -118,38 +117,43 @@ class JobLoggingDB(BaseSQLDB):
|
|
118
117
|
"ApplicationStatus": record.application_status[:255],
|
119
118
|
"StatusTime": record.date,
|
120
119
|
"StatusTimeOrder": get_epoc(record.date),
|
121
|
-
"
|
120
|
+
"StatusSource": record.source[:32],
|
122
121
|
}
|
123
122
|
for record in records
|
124
123
|
],
|
125
124
|
)
|
126
125
|
|
127
|
-
async def get_records(self,
|
126
|
+
async def get_records(self, job_ids: list[int]) -> dict[int, JobStatusReturn]:
|
128
127
|
"""Returns a Status,MinorStatus,ApplicationStatus,StatusTime,Source tuple
|
129
128
|
for each record found for job specified by its jobID in historical order.
|
130
129
|
"""
|
130
|
+
# We could potentially use a group_by here, but we need to post-process the
|
131
|
+
# results later.
|
131
132
|
stmt = (
|
132
133
|
select(
|
133
|
-
LoggingInfo.
|
134
|
-
LoggingInfo.
|
135
|
-
LoggingInfo.
|
136
|
-
LoggingInfo.
|
137
|
-
LoggingInfo.
|
134
|
+
LoggingInfo.job_id,
|
135
|
+
LoggingInfo.status,
|
136
|
+
LoggingInfo.minor_status,
|
137
|
+
LoggingInfo.application_status,
|
138
|
+
LoggingInfo.status_time,
|
139
|
+
LoggingInfo.source,
|
138
140
|
)
|
139
|
-
.where(LoggingInfo.
|
140
|
-
.order_by(LoggingInfo.
|
141
|
+
.where(LoggingInfo.job_id.in_(job_ids))
|
142
|
+
.order_by(LoggingInfo.status_time_order, LoggingInfo.status_time)
|
141
143
|
)
|
142
144
|
rows = await self.conn.execute(stmt)
|
143
145
|
|
144
|
-
values =
|
146
|
+
values = defaultdict(list)
|
145
147
|
for (
|
148
|
+
job_id,
|
146
149
|
status,
|
147
150
|
minor_status,
|
148
151
|
application_status,
|
149
152
|
status_time,
|
150
153
|
status_source,
|
151
154
|
) in rows:
|
152
|
-
|
155
|
+
|
156
|
+
values[job_id].append(
|
153
157
|
[
|
154
158
|
status,
|
155
159
|
minor_status,
|
@@ -161,16 +165,16 @@ class JobLoggingDB(BaseSQLDB):
|
|
161
165
|
|
162
166
|
# If no value has been set for the application status in the first place,
|
163
167
|
# We put this status to unknown
|
164
|
-
res =
|
165
|
-
|
166
|
-
if
|
167
|
-
|
168
|
+
res: dict = defaultdict(list)
|
169
|
+
for job_id, history in values.items():
|
170
|
+
if history[0][2] == "idem":
|
171
|
+
history[0][2] = "Unknown"
|
168
172
|
|
169
173
|
# We replace "idem" values by the value previously stated
|
170
|
-
for i in range(1, len(
|
174
|
+
for i in range(1, len(history)):
|
171
175
|
for j in range(3):
|
172
|
-
if
|
173
|
-
|
176
|
+
if history[i][j] == "idem":
|
177
|
+
history[i][j] = history[i - 1][j]
|
174
178
|
|
175
179
|
# And we replace arrays with tuples
|
176
180
|
for (
|
@@ -179,8 +183,8 @@ class JobLoggingDB(BaseSQLDB):
|
|
179
183
|
application_status,
|
180
184
|
status_time,
|
181
185
|
status_source,
|
182
|
-
) in
|
183
|
-
res.append(
|
186
|
+
) in history:
|
187
|
+
res[job_id].append(
|
184
188
|
JobStatusReturn(
|
185
189
|
Status=status,
|
186
190
|
MinorStatus=minor_status,
|
@@ -194,7 +198,7 @@ class JobLoggingDB(BaseSQLDB):
|
|
194
198
|
|
195
199
|
async def delete_records(self, job_ids: list[int]):
|
196
200
|
"""Delete logging records for given jobs."""
|
197
|
-
stmt = delete(LoggingInfo).where(LoggingInfo.
|
201
|
+
stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids))
|
198
202
|
await self.conn.execute(stmt)
|
199
203
|
|
200
204
|
async def get_wms_time_stamps(self, job_id):
|
@@ -203,12 +207,12 @@ class JobLoggingDB(BaseSQLDB):
|
|
203
207
|
"""
|
204
208
|
result = {}
|
205
209
|
stmt = select(
|
206
|
-
LoggingInfo.
|
207
|
-
LoggingInfo.
|
208
|
-
).where(LoggingInfo.
|
210
|
+
LoggingInfo.status,
|
211
|
+
LoggingInfo.status_time_order,
|
212
|
+
).where(LoggingInfo.job_id == job_id)
|
209
213
|
rows = await self.conn.execute(stmt)
|
210
214
|
if not rows.rowcount:
|
211
|
-
raise
|
215
|
+
raise JobNotFoundError(job_id) from None
|
212
216
|
|
213
217
|
for event, etime in rows:
|
214
218
|
result[event] = str(etime + MAGIC_EPOC_NUMBER)
|
@@ -221,10 +225,10 @@ class JobLoggingDB(BaseSQLDB):
|
|
221
225
|
"""
|
222
226
|
result = defaultdict(dict)
|
223
227
|
stmt = select(
|
224
|
-
LoggingInfo.
|
225
|
-
LoggingInfo.
|
226
|
-
LoggingInfo.
|
227
|
-
).where(LoggingInfo.
|
228
|
+
LoggingInfo.job_id,
|
229
|
+
LoggingInfo.status,
|
230
|
+
LoggingInfo.status_time_order,
|
231
|
+
).where(LoggingInfo.job_id.in_(job_ids))
|
228
232
|
rows = await self.conn.execute(stmt)
|
229
233
|
if not rows.rowcount:
|
230
234
|
return {}
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
Integer,
|
3
5
|
Numeric,
|
@@ -13,13 +15,15 @@ JobLoggingDBBase = declarative_base()
|
|
13
15
|
|
14
16
|
class LoggingInfo(JobLoggingDBBase):
|
15
17
|
__tablename__ = "LoggingInfo"
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
18
|
+
job_id = Column("JobID", Integer)
|
19
|
+
seq_num = Column("SeqNum", Integer)
|
20
|
+
status = Column("Status", String(32), default="")
|
21
|
+
minor_status = Column("MinorStatus", String(128), default="")
|
22
|
+
application_status = Column("ApplicationStatus", String(255), default="")
|
23
|
+
status_time = DateNowColumn("StatusTime")
|
22
24
|
# TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL
|
23
|
-
|
24
|
-
|
25
|
+
status_time_order = Column(
|
26
|
+
"StatusTimeOrder", Numeric(precision=12, scale=3), default=0
|
27
|
+
)
|
28
|
+
source = Column("StatusSource", String(32), default="Unknown")
|
25
29
|
__table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),)
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
DateTime,
|
3
5
|
Double,
|
@@ -16,22 +18,22 @@ PilotAgentsDBBase = declarative_base()
|
|
16
18
|
class PilotAgents(PilotAgentsDBBase):
|
17
19
|
__tablename__ = "PilotAgents"
|
18
20
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
21
|
+
pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True)
|
22
|
+
initial_job_id = Column("InitialJobID", Integer, default=0)
|
23
|
+
current_job_id = Column("CurrentJobID", Integer, default=0)
|
24
|
+
pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown")
|
25
|
+
pilot_stamp = Column("PilotStamp", String(32), default="")
|
26
|
+
destination_site = Column("DestinationSite", String(128), default="NotAssigned")
|
27
|
+
queue = Column("Queue", String(128), default="Unknown")
|
28
|
+
grid_site = Column("GridSite", String(128), default="Unknown")
|
29
|
+
vo = Column("VO", String(128))
|
30
|
+
grid_type = Column("GridType", String(32), default="LCG")
|
31
|
+
benchmark = Column("BenchMark", Double, default=0.0)
|
32
|
+
submission_time = NullColumn("SubmissionTime", DateTime)
|
33
|
+
last_update_time = NullColumn("LastUpdateTime", DateTime)
|
34
|
+
status = Column("Status", String(32), default="Unknown")
|
35
|
+
status_reason = Column("StatusReason", String(255), default="Unknown")
|
36
|
+
accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False)
|
35
37
|
|
36
38
|
__table_args__ = (
|
37
39
|
Index("PilotJobReference", "PilotJobReference"),
|
@@ -43,9 +45,9 @@ class PilotAgents(PilotAgentsDBBase):
|
|
43
45
|
class JobToPilotMapping(PilotAgentsDBBase):
|
44
46
|
__tablename__ = "JobToPilotMapping"
|
45
47
|
|
46
|
-
|
47
|
-
|
48
|
-
|
48
|
+
pilot_id = Column("PilotID", Integer, primary_key=True)
|
49
|
+
job_id = Column("JobID", Integer, primary_key=True)
|
50
|
+
start_time = Column("StartTime", DateTime)
|
49
51
|
|
50
52
|
__table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))
|
51
53
|
|
@@ -53,6 +55,6 @@ class JobToPilotMapping(PilotAgentsDBBase):
|
|
53
55
|
class PilotOutput(PilotAgentsDBBase):
|
54
56
|
__tablename__ = "PilotOutput"
|
55
57
|
|
56
|
-
|
57
|
-
|
58
|
-
|
58
|
+
pilot_id = Column("PilotID", Integer, primary_key=True)
|
59
|
+
std_output = Column("StdOutput", Text)
|
60
|
+
std_error = Column("StdError", Text)
|
@@ -2,13 +2,15 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from typing import Any
|
4
4
|
|
5
|
-
import
|
5
|
+
from sqlalchemy import Executable, delete, insert, literal, select, update
|
6
|
+
from sqlalchemy.exc import IntegrityError, NoResultFound
|
6
7
|
|
8
|
+
from diracx.core.exceptions import SandboxNotFoundError
|
7
9
|
from diracx.core.models import SandboxInfo, SandboxType, UserInfo
|
8
10
|
from diracx.db.sql.utils import BaseSQLDB, utcnow
|
9
11
|
|
10
12
|
from .schema import Base as SandboxMetadataDBBase
|
11
|
-
from .schema import
|
13
|
+
from .schema import SandBoxes, SBEntityMapping, SBOwners
|
12
14
|
|
13
15
|
|
14
16
|
class SandboxMetadataDB(BaseSQLDB):
|
@@ -17,16 +19,16 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
17
19
|
async def upsert_owner(self, user: UserInfo) -> int:
|
18
20
|
"""Get the id of the owner from the database."""
|
19
21
|
# TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
|
20
|
-
stmt =
|
21
|
-
|
22
|
-
|
23
|
-
|
22
|
+
stmt = select(SBOwners.OwnerID).where(
|
23
|
+
SBOwners.Owner == user.preferred_username,
|
24
|
+
SBOwners.OwnerGroup == user.dirac_group,
|
25
|
+
SBOwners.VO == user.vo,
|
24
26
|
)
|
25
27
|
result = await self.conn.execute(stmt)
|
26
28
|
if owner_id := result.scalar_one_or_none():
|
27
29
|
return owner_id
|
28
30
|
|
29
|
-
stmt =
|
31
|
+
stmt = insert(SBOwners).values(
|
30
32
|
Owner=user.preferred_username,
|
31
33
|
OwnerGroup=user.dirac_group,
|
32
34
|
VO=user.vo,
|
@@ -53,7 +55,7 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
53
55
|
"""Add a new sandbox in SandboxMetadataDB."""
|
54
56
|
# TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
|
55
57
|
owner_id = await self.upsert_owner(user)
|
56
|
-
stmt =
|
58
|
+
stmt = insert(SandBoxes).values(
|
57
59
|
OwnerId=owner_id,
|
58
60
|
SEName=se_name,
|
59
61
|
SEPFN=pfn,
|
@@ -63,27 +65,31 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
63
65
|
)
|
64
66
|
try:
|
65
67
|
result = await self.conn.execute(stmt)
|
66
|
-
except
|
68
|
+
except IntegrityError:
|
67
69
|
await self.update_sandbox_last_access_time(se_name, pfn)
|
68
70
|
else:
|
69
71
|
assert result.rowcount == 1
|
70
72
|
|
71
73
|
async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
|
72
74
|
stmt = (
|
73
|
-
|
74
|
-
.where(
|
75
|
+
update(SandBoxes)
|
76
|
+
.where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn)
|
75
77
|
.values(LastAccessTime=utcnow())
|
76
78
|
)
|
77
79
|
result = await self.conn.execute(stmt)
|
78
80
|
assert result.rowcount == 1
|
79
81
|
|
80
|
-
async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
|
82
|
+
async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool | None:
|
81
83
|
"""Checks if a sandbox exists and has been assigned."""
|
82
|
-
stmt:
|
83
|
-
|
84
|
+
stmt: Executable = select(SandBoxes.Assigned).where(
|
85
|
+
SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn
|
84
86
|
)
|
85
87
|
result = await self.conn.execute(stmt)
|
86
|
-
|
88
|
+
try:
|
89
|
+
is_assigned = result.scalar_one()
|
90
|
+
except NoResultFound as e:
|
91
|
+
raise SandboxNotFoundError(pfn, se_name) from e
|
92
|
+
|
87
93
|
return is_assigned
|
88
94
|
|
89
95
|
@staticmethod
|
@@ -97,11 +103,11 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
97
103
|
"""Get the sandbox assign to job."""
|
98
104
|
entity_id = self.jobid_to_entity_id(job_id)
|
99
105
|
stmt = (
|
100
|
-
|
101
|
-
.where(
|
106
|
+
select(SandBoxes.SEPFN)
|
107
|
+
.where(SandBoxes.SBId == SBEntityMapping.SBId)
|
102
108
|
.where(
|
103
|
-
|
104
|
-
|
109
|
+
SBEntityMapping.EntityId == entity_id,
|
110
|
+
SBEntityMapping.Type == sb_type,
|
105
111
|
)
|
106
112
|
)
|
107
113
|
result = await self.conn.execute(stmt)
|
@@ -118,24 +124,20 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
118
124
|
for job_id in jobs_ids:
|
119
125
|
# Define the entity id as 'Entity:entity_id' due to the DB definition:
|
120
126
|
entity_id = self.jobid_to_entity_id(job_id)
|
121
|
-
select_sb_id =
|
122
|
-
|
123
|
-
|
124
|
-
|
127
|
+
select_sb_id = select(
|
128
|
+
SandBoxes.SBId,
|
129
|
+
literal(entity_id).label("EntityId"),
|
130
|
+
literal(sb_type).label("Type"),
|
125
131
|
).where(
|
126
|
-
|
127
|
-
|
132
|
+
SandBoxes.SEName == se_name,
|
133
|
+
SandBoxes.SEPFN == pfn,
|
128
134
|
)
|
129
|
-
stmt =
|
135
|
+
stmt = insert(SBEntityMapping).from_select(
|
130
136
|
["SBId", "EntityId", "Type"], select_sb_id
|
131
137
|
)
|
132
138
|
await self.conn.execute(stmt)
|
133
139
|
|
134
|
-
stmt = (
|
135
|
-
sqlalchemy.update(sb_SandBoxes)
|
136
|
-
.where(sb_SandBoxes.SEPFN == pfn)
|
137
|
-
.values(Assigned=True)
|
138
|
-
)
|
140
|
+
stmt = update(SandBoxes).where(SandBoxes.SEPFN == pfn).values(Assigned=True)
|
139
141
|
result = await self.conn.execute(stmt)
|
140
142
|
assert result.rowcount == 1
|
141
143
|
|
@@ -143,29 +145,29 @@ class SandboxMetadataDB(BaseSQLDB):
|
|
143
145
|
"""Delete mapping between jobs and sandboxes."""
|
144
146
|
for job_id in jobs_ids:
|
145
147
|
entity_id = self.jobid_to_entity_id(job_id)
|
146
|
-
sb_sel_stmt =
|
148
|
+
sb_sel_stmt = select(SandBoxes.SBId)
|
147
149
|
sb_sel_stmt = sb_sel_stmt.join(
|
148
|
-
|
150
|
+
SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId
|
149
151
|
)
|
150
|
-
sb_sel_stmt = sb_sel_stmt.where(
|
152
|
+
sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id)
|
151
153
|
|
152
154
|
result = await self.conn.execute(sb_sel_stmt)
|
153
155
|
sb_ids = [row.SBId for row in result]
|
154
156
|
|
155
|
-
del_stmt =
|
156
|
-
|
157
|
+
del_stmt = delete(SBEntityMapping).where(
|
158
|
+
SBEntityMapping.EntityId == entity_id
|
157
159
|
)
|
158
160
|
await self.conn.execute(del_stmt)
|
159
161
|
|
160
|
-
sb_entity_sel_stmt =
|
161
|
-
|
162
|
+
sb_entity_sel_stmt = select(SBEntityMapping.SBId).where(
|
163
|
+
SBEntityMapping.SBId.in_(sb_ids)
|
162
164
|
)
|
163
165
|
result = await self.conn.execute(sb_entity_sel_stmt)
|
164
166
|
remaining_sb_ids = [row.SBId for row in result]
|
165
167
|
if not remaining_sb_ids:
|
166
168
|
unassign_stmt = (
|
167
|
-
|
168
|
-
.where(
|
169
|
+
update(SandBoxes)
|
170
|
+
.where(SandBoxes.SBId.in_(sb_ids))
|
169
171
|
.values(Assigned=False)
|
170
172
|
)
|
171
173
|
await self.conn.execute(unassign_stmt)
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
from sqlalchemy import (
|
2
4
|
BigInteger,
|
3
5
|
Boolean,
|
@@ -14,7 +16,7 @@ from diracx.db.sql.utils import Column, DateNowColumn
|
|
14
16
|
Base = declarative_base()
|
15
17
|
|
16
18
|
|
17
|
-
class
|
19
|
+
class SBOwners(Base):
|
18
20
|
__tablename__ = "sb_Owners"
|
19
21
|
OwnerID = Column(Integer, autoincrement=True)
|
20
22
|
Owner = Column(String(32))
|
@@ -23,7 +25,7 @@ class sb_Owners(Base):
|
|
23
25
|
__table_args__ = (PrimaryKeyConstraint("OwnerID"),)
|
24
26
|
|
25
27
|
|
26
|
-
class
|
28
|
+
class SandBoxes(Base):
|
27
29
|
__tablename__ = "sb_SandBoxes"
|
28
30
|
SBId = Column(Integer, autoincrement=True)
|
29
31
|
OwnerId = Column(Integer)
|
@@ -40,7 +42,7 @@ class sb_SandBoxes(Base):
|
|
40
42
|
)
|
41
43
|
|
42
44
|
|
43
|
-
class
|
45
|
+
class SBEntityMapping(Base):
|
44
46
|
__tablename__ = "sb_EntityMapping"
|
45
47
|
SBId = Column(Integer)
|
46
48
|
EntityId = Column(String(128))
|
diracx/db/sql/task_queue/db.py
CHANGED
@@ -121,12 +121,12 @@ class TaskQueueDB(BaseSQLDB):
|
|
121
121
|
# TODO: I guess the rows are already a list of tupes
|
122
122
|
# maybe refactor
|
123
123
|
data = [(r[0], r[1]) for r in rows if r]
|
124
|
-
|
124
|
+
num_owners = len(data)
|
125
125
|
# If there are no owners do now
|
126
|
-
if
|
126
|
+
if num_owners == 0:
|
127
127
|
return
|
128
128
|
# Split the share amongst the number of owners
|
129
|
-
entities_shares = {row[0]: job_share /
|
129
|
+
entities_shares = {row[0]: job_share / num_owners for row in data}
|
130
130
|
|
131
131
|
# TODO: implement the following
|
132
132
|
# If corrector is enabled let it work it's magic
|