diracx-db 0.0.1a23__py3-none-any.whl → 0.0.1a24__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
 
13
13
  from collections import defaultdict
14
14
 
15
- from diracx.core.exceptions import JobNotFound
15
+ from diracx.core.exceptions import JobNotFoundError
16
16
  from diracx.core.models import (
17
17
  JobStatus,
18
18
  JobStatusReturn,
@@ -57,9 +57,9 @@ class JobLoggingDB(BaseSQLDB):
57
57
  as datetime.datetime object. If the time stamp is not provided the current
58
58
  UTC time is used.
59
59
  """
60
- # First, fetch the maximum SeqNum for the given job_id
61
- seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)).where(
62
- LoggingInfo.JobID == job_id
60
+ # First, fetch the maximum seq_num for the given job_id
61
+ seqnum_stmt = select(func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)).where(
62
+ LoggingInfo.job_id == job_id
63
63
  )
64
64
  seqnum = await self.conn.scalar(seqnum_stmt)
65
65
 
@@ -70,14 +70,14 @@ class JobLoggingDB(BaseSQLDB):
70
70
  )
71
71
 
72
72
  stmt = insert(LoggingInfo).values(
73
- JobID=int(job_id),
74
- SeqNum=seqnum,
75
- Status=status,
76
- MinorStatus=minor_status,
77
- ApplicationStatus=application_status[:255],
78
- StatusTime=date,
79
- StatusTimeOrder=epoc,
80
- Source=source[:32],
73
+ job_id=int(job_id),
74
+ seq_num=seqnum,
75
+ status=status,
76
+ minor_status=minor_status,
77
+ application_status=application_status[:255],
78
+ status_time=date,
79
+ status_time_order=epoc,
80
+ source=source[:32],
81
81
  )
82
82
  await self.conn.execute(stmt)
83
83
 
@@ -97,10 +97,10 @@ class JobLoggingDB(BaseSQLDB):
97
97
  # First, fetch the maximum SeqNums for the given job_ids
98
98
  seqnum_stmt = (
99
99
  select(
100
- LoggingInfo.JobID, func.coalesce(func.max(LoggingInfo.SeqNum) + 1, 1)
100
+ LoggingInfo.job_id, func.coalesce(func.max(LoggingInfo.seq_num) + 1, 1)
101
101
  )
102
- .where(LoggingInfo.JobID.in_([record.job_id for record in records]))
103
- .group_by(LoggingInfo.JobID)
102
+ .where(LoggingInfo.job_id.in_([record.job_id for record in records]))
103
+ .group_by(LoggingInfo.job_id)
104
104
  )
105
105
 
106
106
  seqnum = {jid: seqnum for jid, seqnum in (await self.conn.execute(seqnum_stmt))}
@@ -131,15 +131,15 @@ class JobLoggingDB(BaseSQLDB):
131
131
  # results later.
132
132
  stmt = (
133
133
  select(
134
- LoggingInfo.JobID,
135
- LoggingInfo.Status,
136
- LoggingInfo.MinorStatus,
137
- LoggingInfo.ApplicationStatus,
138
- LoggingInfo.StatusTime,
139
- LoggingInfo.Source,
134
+ LoggingInfo.job_id,
135
+ LoggingInfo.status,
136
+ LoggingInfo.minor_status,
137
+ LoggingInfo.application_status,
138
+ LoggingInfo.status_time,
139
+ LoggingInfo.source,
140
140
  )
141
- .where(LoggingInfo.JobID.in_(job_ids))
142
- .order_by(LoggingInfo.StatusTimeOrder, LoggingInfo.StatusTime)
141
+ .where(LoggingInfo.job_id.in_(job_ids))
142
+ .order_by(LoggingInfo.status_time_order, LoggingInfo.status_time)
143
143
  )
144
144
  rows = await self.conn.execute(stmt)
145
145
 
@@ -198,7 +198,7 @@ class JobLoggingDB(BaseSQLDB):
198
198
 
199
199
  async def delete_records(self, job_ids: list[int]):
200
200
  """Delete logging records for given jobs."""
201
- stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids))
201
+ stmt = delete(LoggingInfo).where(LoggingInfo.job_id.in_(job_ids))
202
202
  await self.conn.execute(stmt)
203
203
 
204
204
  async def get_wms_time_stamps(self, job_id):
@@ -207,12 +207,12 @@ class JobLoggingDB(BaseSQLDB):
207
207
  """
208
208
  result = {}
209
209
  stmt = select(
210
- LoggingInfo.Status,
211
- LoggingInfo.StatusTimeOrder,
212
- ).where(LoggingInfo.JobID == job_id)
210
+ LoggingInfo.status,
211
+ LoggingInfo.status_time_order,
212
+ ).where(LoggingInfo.job_id == job_id)
213
213
  rows = await self.conn.execute(stmt)
214
214
  if not rows.rowcount:
215
- raise JobNotFound(job_id) from None
215
+ raise JobNotFoundError(job_id) from None
216
216
 
217
217
  for event, etime in rows:
218
218
  result[event] = str(etime + MAGIC_EPOC_NUMBER)
@@ -225,10 +225,10 @@ class JobLoggingDB(BaseSQLDB):
225
225
  """
226
226
  result = defaultdict(dict)
227
227
  stmt = select(
228
- LoggingInfo.JobID,
229
- LoggingInfo.Status,
230
- LoggingInfo.StatusTimeOrder,
231
- ).where(LoggingInfo.JobID.in_(job_ids))
228
+ LoggingInfo.job_id,
229
+ LoggingInfo.status,
230
+ LoggingInfo.status_time_order,
231
+ ).where(LoggingInfo.job_id.in_(job_ids))
232
232
  rows = await self.conn.execute(stmt)
233
233
  if not rows.rowcount:
234
234
  return {}
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from sqlalchemy import (
2
4
  Integer,
3
5
  Numeric,
@@ -13,13 +15,15 @@ JobLoggingDBBase = declarative_base()
13
15
 
14
16
  class LoggingInfo(JobLoggingDBBase):
15
17
  __tablename__ = "LoggingInfo"
16
- JobID = Column(Integer)
17
- SeqNum = Column(Integer)
18
- Status = Column(String(32), default="")
19
- MinorStatus = Column(String(128), default="")
20
- ApplicationStatus = Column(String(255), default="")
21
- StatusTime = DateNowColumn()
18
+ job_id = Column("JobID", Integer)
19
+ seq_num = Column("SeqNum", Integer)
20
+ status = Column("Status", String(32), default="")
21
+ minor_status = Column("MinorStatus", String(128), default="")
22
+ application_status = Column("ApplicationStatus", String(255), default="")
23
+ status_time = DateNowColumn("StatusTime")
22
24
  # TODO: Check that this corresponds to the DOUBLE(12,3) type in MySQL
23
- StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0)
24
- Source = Column(String(32), default="Unknown", name="StatusSource")
25
+ status_time_order = Column(
26
+ "StatusTimeOrder", Numeric(precision=12, scale=3), default=0
27
+ )
28
+ source = Column("StatusSource", String(32), default="Unknown")
25
29
  __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from sqlalchemy import (
2
4
  DateTime,
3
5
  Double,
@@ -16,22 +18,22 @@ PilotAgentsDBBase = declarative_base()
16
18
  class PilotAgents(PilotAgentsDBBase):
17
19
  __tablename__ = "PilotAgents"
18
20
 
19
- PilotID = Column("PilotID", Integer, autoincrement=True, primary_key=True)
20
- InitialJobID = Column("InitialJobID", Integer, default=0)
21
- CurrentJobID = Column("CurrentJobID", Integer, default=0)
22
- PilotJobReference = Column("PilotJobReference", String(255), default="Unknown")
23
- PilotStamp = Column("PilotStamp", String(32), default="")
24
- DestinationSite = Column("DestinationSite", String(128), default="NotAssigned")
25
- Queue = Column("Queue", String(128), default="Unknown")
26
- GridSite = Column("GridSite", String(128), default="Unknown")
27
- VO = Column("VO", String(128))
28
- GridType = Column("GridType", String(32), default="LCG")
29
- BenchMark = Column("BenchMark", Double, default=0.0)
30
- SubmissionTime = NullColumn("SubmissionTime", DateTime)
31
- LastUpdateTime = NullColumn("LastUpdateTime", DateTime)
32
- Status = Column("Status", String(32), default="Unknown")
33
- StatusReason = Column("StatusReason", String(255), default="Unknown")
34
- AccountingSent = Column("AccountingSent", EnumBackedBool(), default=False)
21
+ pilot_id = Column("PilotID", Integer, autoincrement=True, primary_key=True)
22
+ initial_job_id = Column("InitialJobID", Integer, default=0)
23
+ current_job_id = Column("CurrentJobID", Integer, default=0)
24
+ pilot_job_reference = Column("PilotJobReference", String(255), default="Unknown")
25
+ pilot_stamp = Column("PilotStamp", String(32), default="")
26
+ destination_site = Column("DestinationSite", String(128), default="NotAssigned")
27
+ queue = Column("Queue", String(128), default="Unknown")
28
+ grid_site = Column("GridSite", String(128), default="Unknown")
29
+ vo = Column("VO", String(128))
30
+ grid_type = Column("GridType", String(32), default="LCG")
31
+ benchmark = Column("BenchMark", Double, default=0.0)
32
+ submission_time = NullColumn("SubmissionTime", DateTime)
33
+ last_update_time = NullColumn("LastUpdateTime", DateTime)
34
+ status = Column("Status", String(32), default="Unknown")
35
+ status_reason = Column("StatusReason", String(255), default="Unknown")
36
+ accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False)
35
37
 
36
38
  __table_args__ = (
37
39
  Index("PilotJobReference", "PilotJobReference"),
@@ -43,9 +45,9 @@ class PilotAgents(PilotAgentsDBBase):
43
45
  class JobToPilotMapping(PilotAgentsDBBase):
44
46
  __tablename__ = "JobToPilotMapping"
45
47
 
46
- PilotID = Column("PilotID", Integer, primary_key=True)
47
- JobID = Column("JobID", Integer, primary_key=True)
48
- StartTime = Column("StartTime", DateTime)
48
+ pilot_id = Column("PilotID", Integer, primary_key=True)
49
+ job_id = Column("JobID", Integer, primary_key=True)
50
+ start_time = Column("StartTime", DateTime)
49
51
 
50
52
  __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))
51
53
 
@@ -53,6 +55,6 @@ class JobToPilotMapping(PilotAgentsDBBase):
53
55
  class PilotOutput(PilotAgentsDBBase):
54
56
  __tablename__ = "PilotOutput"
55
57
 
56
- PilotID = Column("PilotID", Integer, primary_key=True)
57
- StdOutput = Column("StdOutput", Text)
58
- StdError = Column("StdError", Text)
58
+ pilot_id = Column("PilotID", Integer, primary_key=True)
59
+ std_output = Column("StdOutput", Text)
60
+ std_error = Column("StdError", Text)
@@ -2,13 +2,15 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any
4
4
 
5
- import sqlalchemy
5
+ from sqlalchemy import Executable, delete, insert, literal, select, update
6
+ from sqlalchemy.exc import IntegrityError, NoResultFound
6
7
 
8
+ from diracx.core.exceptions import SandboxNotFoundError
7
9
  from diracx.core.models import SandboxInfo, SandboxType, UserInfo
8
10
  from diracx.db.sql.utils import BaseSQLDB, utcnow
9
11
 
10
12
  from .schema import Base as SandboxMetadataDBBase
11
- from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes
13
+ from .schema import SandBoxes, SBEntityMapping, SBOwners
12
14
 
13
15
 
14
16
  class SandboxMetadataDB(BaseSQLDB):
@@ -17,16 +19,16 @@ class SandboxMetadataDB(BaseSQLDB):
17
19
  async def upsert_owner(self, user: UserInfo) -> int:
18
20
  """Get the id of the owner from the database."""
19
21
  # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
20
- stmt = sqlalchemy.select(sb_Owners.OwnerID).where(
21
- sb_Owners.Owner == user.preferred_username,
22
- sb_Owners.OwnerGroup == user.dirac_group,
23
- sb_Owners.VO == user.vo,
22
+ stmt = select(SBOwners.OwnerID).where(
23
+ SBOwners.Owner == user.preferred_username,
24
+ SBOwners.OwnerGroup == user.dirac_group,
25
+ SBOwners.VO == user.vo,
24
26
  )
25
27
  result = await self.conn.execute(stmt)
26
28
  if owner_id := result.scalar_one_or_none():
27
29
  return owner_id
28
30
 
29
- stmt = sqlalchemy.insert(sb_Owners).values(
31
+ stmt = insert(SBOwners).values(
30
32
  Owner=user.preferred_username,
31
33
  OwnerGroup=user.dirac_group,
32
34
  VO=user.vo,
@@ -53,7 +55,7 @@ class SandboxMetadataDB(BaseSQLDB):
53
55
  """Add a new sandbox in SandboxMetadataDB."""
54
56
  # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
55
57
  owner_id = await self.upsert_owner(user)
56
- stmt = sqlalchemy.insert(sb_SandBoxes).values(
58
+ stmt = insert(SandBoxes).values(
57
59
  OwnerId=owner_id,
58
60
  SEName=se_name,
59
61
  SEPFN=pfn,
@@ -63,27 +65,31 @@ class SandboxMetadataDB(BaseSQLDB):
63
65
  )
64
66
  try:
65
67
  result = await self.conn.execute(stmt)
66
- except sqlalchemy.exc.IntegrityError:
68
+ except IntegrityError:
67
69
  await self.update_sandbox_last_access_time(se_name, pfn)
68
70
  else:
69
71
  assert result.rowcount == 1
70
72
 
71
73
  async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
72
74
  stmt = (
73
- sqlalchemy.update(sb_SandBoxes)
74
- .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn)
75
+ update(SandBoxes)
76
+ .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn)
75
77
  .values(LastAccessTime=utcnow())
76
78
  )
77
79
  result = await self.conn.execute(stmt)
78
80
  assert result.rowcount == 1
79
81
 
80
- async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
82
+ async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool | None:
81
83
  """Checks if a sandbox exists and has been assigned."""
82
- stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
83
- sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
84
+ stmt: Executable = select(SandBoxes.Assigned).where(
85
+ SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn
84
86
  )
85
87
  result = await self.conn.execute(stmt)
86
- is_assigned = result.scalar_one()
88
+ try:
89
+ is_assigned = result.scalar_one()
90
+ except NoResultFound as e:
91
+ raise SandboxNotFoundError(pfn, se_name) from e
92
+
87
93
  return is_assigned
88
94
 
89
95
  @staticmethod
@@ -97,11 +103,11 @@ class SandboxMetadataDB(BaseSQLDB):
97
103
  """Get the sandbox assign to job."""
98
104
  entity_id = self.jobid_to_entity_id(job_id)
99
105
  stmt = (
100
- sqlalchemy.select(sb_SandBoxes.SEPFN)
101
- .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId)
106
+ select(SandBoxes.SEPFN)
107
+ .where(SandBoxes.SBId == SBEntityMapping.SBId)
102
108
  .where(
103
- sb_EntityMapping.EntityId == entity_id,
104
- sb_EntityMapping.Type == sb_type,
109
+ SBEntityMapping.EntityId == entity_id,
110
+ SBEntityMapping.Type == sb_type,
105
111
  )
106
112
  )
107
113
  result = await self.conn.execute(stmt)
@@ -118,24 +124,20 @@ class SandboxMetadataDB(BaseSQLDB):
118
124
  for job_id in jobs_ids:
119
125
  # Define the entity id as 'Entity:entity_id' due to the DB definition:
120
126
  entity_id = self.jobid_to_entity_id(job_id)
121
- select_sb_id = sqlalchemy.select(
122
- sb_SandBoxes.SBId,
123
- sqlalchemy.literal(entity_id).label("EntityId"),
124
- sqlalchemy.literal(sb_type).label("Type"),
127
+ select_sb_id = select(
128
+ SandBoxes.SBId,
129
+ literal(entity_id).label("EntityId"),
130
+ literal(sb_type).label("Type"),
125
131
  ).where(
126
- sb_SandBoxes.SEName == se_name,
127
- sb_SandBoxes.SEPFN == pfn,
132
+ SandBoxes.SEName == se_name,
133
+ SandBoxes.SEPFN == pfn,
128
134
  )
129
- stmt = sqlalchemy.insert(sb_EntityMapping).from_select(
135
+ stmt = insert(SBEntityMapping).from_select(
130
136
  ["SBId", "EntityId", "Type"], select_sb_id
131
137
  )
132
138
  await self.conn.execute(stmt)
133
139
 
134
- stmt = (
135
- sqlalchemy.update(sb_SandBoxes)
136
- .where(sb_SandBoxes.SEPFN == pfn)
137
- .values(Assigned=True)
138
- )
140
+ stmt = update(SandBoxes).where(SandBoxes.SEPFN == pfn).values(Assigned=True)
139
141
  result = await self.conn.execute(stmt)
140
142
  assert result.rowcount == 1
141
143
 
@@ -143,29 +145,29 @@ class SandboxMetadataDB(BaseSQLDB):
143
145
  """Delete mapping between jobs and sandboxes."""
144
146
  for job_id in jobs_ids:
145
147
  entity_id = self.jobid_to_entity_id(job_id)
146
- sb_sel_stmt = sqlalchemy.select(sb_SandBoxes.SBId)
148
+ sb_sel_stmt = select(SandBoxes.SBId)
147
149
  sb_sel_stmt = sb_sel_stmt.join(
148
- sb_EntityMapping, sb_EntityMapping.SBId == sb_SandBoxes.SBId
150
+ SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId
149
151
  )
150
- sb_sel_stmt = sb_sel_stmt.where(sb_EntityMapping.EntityId == entity_id)
152
+ sb_sel_stmt = sb_sel_stmt.where(SBEntityMapping.EntityId == entity_id)
151
153
 
152
154
  result = await self.conn.execute(sb_sel_stmt)
153
155
  sb_ids = [row.SBId for row in result]
154
156
 
155
- del_stmt = sqlalchemy.delete(sb_EntityMapping).where(
156
- sb_EntityMapping.EntityId == entity_id
157
+ del_stmt = delete(SBEntityMapping).where(
158
+ SBEntityMapping.EntityId == entity_id
157
159
  )
158
160
  await self.conn.execute(del_stmt)
159
161
 
160
- sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
161
- sb_EntityMapping.SBId.in_(sb_ids)
162
+ sb_entity_sel_stmt = select(SBEntityMapping.SBId).where(
163
+ SBEntityMapping.SBId.in_(sb_ids)
162
164
  )
163
165
  result = await self.conn.execute(sb_entity_sel_stmt)
164
166
  remaining_sb_ids = [row.SBId for row in result]
165
167
  if not remaining_sb_ids:
166
168
  unassign_stmt = (
167
- sqlalchemy.update(sb_SandBoxes)
168
- .where(sb_SandBoxes.SBId.in_(sb_ids))
169
+ update(SandBoxes)
170
+ .where(SandBoxes.SBId.in_(sb_ids))
169
171
  .values(Assigned=False)
170
172
  )
171
173
  await self.conn.execute(unassign_stmt)
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from sqlalchemy import (
2
4
  BigInteger,
3
5
  Boolean,
@@ -14,7 +16,7 @@ from diracx.db.sql.utils import Column, DateNowColumn
14
16
  Base = declarative_base()
15
17
 
16
18
 
17
- class sb_Owners(Base):
19
+ class SBOwners(Base):
18
20
  __tablename__ = "sb_Owners"
19
21
  OwnerID = Column(Integer, autoincrement=True)
20
22
  Owner = Column(String(32))
@@ -23,7 +25,7 @@ class sb_Owners(Base):
23
25
  __table_args__ = (PrimaryKeyConstraint("OwnerID"),)
24
26
 
25
27
 
26
- class sb_SandBoxes(Base):
28
+ class SandBoxes(Base):
27
29
  __tablename__ = "sb_SandBoxes"
28
30
  SBId = Column(Integer, autoincrement=True)
29
31
  OwnerId = Column(Integer)
@@ -40,7 +42,7 @@ class sb_SandBoxes(Base):
40
42
  )
41
43
 
42
44
 
43
- class sb_EntityMapping(Base):
45
+ class SBEntityMapping(Base):
44
46
  __tablename__ = "sb_EntityMapping"
45
47
  SBId = Column(Integer)
46
48
  EntityId = Column(String(128))
@@ -121,12 +121,12 @@ class TaskQueueDB(BaseSQLDB):
121
121
  # TODO: I guess the rows are already a list of tupes
122
122
  # maybe refactor
123
123
  data = [(r[0], r[1]) for r in rows if r]
124
- numOwners = len(data)
124
+ num_owners = len(data)
125
125
  # If there are no owners do now
126
- if numOwners == 0:
126
+ if num_owners == 0:
127
127
  return
128
128
  # Split the share amongst the number of owners
129
- entities_shares = {row[0]: job_share / numOwners for row in data}
129
+ entities_shares = {row[0]: job_share / num_owners for row in data}
130
130
 
131
131
  # TODO: implement the following
132
132
  # If corrector is enabled let it work it's magic
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from sqlalchemy import (
2
4
  BigInteger,
3
5
  Boolean,