diracx-db 0.0.1a23__tar.gz → 0.0.1a25__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (60) hide show
  1. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/PKG-INFO +2 -2
  2. diracx_db-0.0.1a25/src/diracx/db/exceptions.py +5 -0
  3. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/os/utils.py +3 -3
  4. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/auth/db.py +9 -9
  5. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/auth/schema.py +25 -23
  6. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/dummy/db.py +2 -2
  7. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/dummy/schema.py +8 -6
  8. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/job/db.py +57 -54
  9. diracx_db-0.0.1a25/src/diracx/db/sql/job/schema.py +131 -0
  10. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/job_logging/db.py +32 -32
  11. diracx_db-0.0.1a25/src/diracx/db/sql/job_logging/schema.py +29 -0
  12. diracx_db-0.0.1a25/src/diracx/db/sql/pilot_agents/schema.py +60 -0
  13. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/sandbox_metadata/db.py +42 -40
  14. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/sandbox_metadata/schema.py +5 -3
  15. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/task_queue/db.py +3 -3
  16. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/task_queue/schema.py +2 -0
  17. diracx_db-0.0.1a25/src/diracx/db/sql/utils/__init__.py +24 -0
  18. diracx_db-0.0.1a23/src/diracx/db/sql/utils/__init__.py → diracx_db-0.0.1a25/src/diracx/db/sql/utils/base.py +22 -147
  19. diracx_db-0.0.1a25/src/diracx/db/sql/utils/functions.py +105 -0
  20. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/utils/job.py +58 -55
  21. diracx_db-0.0.1a25/src/diracx/db/sql/utils/types.py +43 -0
  22. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/PKG-INFO +2 -2
  23. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/SOURCES.txt +7 -4
  24. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/auth/test_authorization_flow.py +1 -1
  25. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/auth/test_device_flow.py +3 -3
  26. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/auth/test_refresh_token.py +15 -10
  27. diracx_db-0.0.1a23/tests/jobs/test_jobDB.py → diracx_db-0.0.1a25/tests/jobs/test_job_db.py +2 -2
  28. diracx_db-0.0.1a23/tests/jobs/test_jobLoggingDB.py → diracx_db-0.0.1a25/tests/jobs/test_job_logging_db.py +2 -0
  29. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/jobs/test_sandbox_metadata.py +11 -10
  30. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/opensearch/test_connection.py +2 -2
  31. diracx_db-0.0.1a23/tests/test_dummyDB.py → diracx_db-0.0.1a25/tests/test_dummy_db.py +17 -17
  32. diracx_db-0.0.1a23/src/diracx/db/exceptions.py +0 -2
  33. diracx_db-0.0.1a23/src/diracx/db/sql/job/schema.py +0 -129
  34. diracx_db-0.0.1a23/src/diracx/db/sql/job_logging/schema.py +0 -25
  35. diracx_db-0.0.1a23/src/diracx/db/sql/pilot_agents/schema.py +0 -58
  36. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/README.md +0 -0
  37. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/pyproject.toml +0 -0
  38. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/setup.cfg +0 -0
  39. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/__init__.py +0 -0
  40. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/__main__.py +0 -0
  41. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/os/__init__.py +0 -0
  42. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/os/job_parameters.py +0 -0
  43. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/py.typed +0 -0
  44. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/__init__.py +0 -0
  45. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/auth/__init__.py +0 -0
  46. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/dummy/__init__.py +0 -0
  47. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/job/__init__.py +0 -0
  48. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/job_logging/__init__.py +0 -0
  49. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/pilot_agents/__init__.py +0 -0
  50. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/pilot_agents/db.py +0 -0
  51. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/sandbox_metadata/__init__.py +0 -0
  52. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx/db/sql/task_queue/__init__.py +0 -0
  53. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/dependency_links.txt +0 -0
  54. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/entry_points.txt +0 -0
  55. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/requires.txt +0 -0
  56. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/src/diracx_db.egg-info/top_level.txt +0 -0
  57. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/opensearch/test_index_template.py +0 -0
  58. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/opensearch/test_search.py +0 -0
  59. {diracx_db-0.0.1a23 → diracx_db-0.0.1a25}/tests/pilot_agents/__init__.py +0 -0
  60. /diracx_db-0.0.1a23/tests/pilot_agents/test_pilotAgentsDB.py → /diracx_db-0.0.1a25/tests/pilot_agents/test_pilot_agents_db.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: diracx-db
3
- Version: 0.0.1a23
3
+ Version: 0.0.1a25
4
4
  Summary: TODO
5
5
  License: GPL-3.0-only
6
6
  Classifier: Intended Audience :: Science/Research
@@ -0,0 +1,5 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ class DBUnavailableError(Exception):
5
+ pass
@@ -16,7 +16,7 @@ from opensearchpy import AsyncOpenSearch
16
16
 
17
17
  from diracx.core.exceptions import InvalidQueryError
18
18
  from diracx.core.extensions import select_from_extension
19
- from diracx.db.exceptions import DBUnavailable
19
+ from diracx.db.exceptions import DBUnavailableError
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -25,7 +25,7 @@ class OpenSearchDBError(Exception):
25
25
  pass
26
26
 
27
27
 
28
- class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
28
+ class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError):
29
29
  pass
30
30
 
31
31
 
@@ -152,7 +152,7 @@ class BaseOSDB(metaclass=ABCMeta):
152
152
  be ran at every query.
153
153
  """
154
154
  if not await self.client.ping():
155
- raise OpenSearchDBUnavailable(
155
+ raise OpenSearchDBUnavailableError(
156
156
  f"Failed to connect to {self.__class__.__qualname__}"
157
157
  )
158
158
 
@@ -58,7 +58,7 @@ class AuthDB(BaseSQLDB):
58
58
  stmt = select(
59
59
  DeviceFlows,
60
60
  (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
61
- "is_expired"
61
+ "IsExpired"
62
62
  ),
63
63
  ).with_for_update()
64
64
  stmt = stmt.where(
@@ -66,10 +66,10 @@ class AuthDB(BaseSQLDB):
66
66
  )
67
67
  res = dict((await self.conn.execute(stmt)).one()._mapping)
68
68
 
69
- if res["is_expired"]:
69
+ if res["IsExpired"]:
70
70
  raise ExpiredFlowError()
71
71
 
72
- if res["status"] == FlowStatus.READY:
72
+ if res["Status"] == FlowStatus.READY:
73
73
  # Update the status to Done before returning
74
74
  await self.conn.execute(
75
75
  update(DeviceFlows)
@@ -81,10 +81,10 @@ class AuthDB(BaseSQLDB):
81
81
  )
82
82
  return res
83
83
 
84
- if res["status"] == FlowStatus.DONE:
84
+ if res["Status"] == FlowStatus.DONE:
85
85
  raise AuthorizationError("Code was already used")
86
86
 
87
- if res["status"] == FlowStatus.PENDING:
87
+ if res["Status"] == FlowStatus.PENDING:
88
88
  raise PendingAuthorizationError()
89
89
 
90
90
  raise AuthorizationError("Bad state in device flow")
@@ -190,7 +190,7 @@ class AuthDB(BaseSQLDB):
190
190
  stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
191
191
  stmt = stmt.where(AuthorizationFlows.uuid == uuid)
192
192
  row = (await self.conn.execute(stmt)).one()
193
- return code, row.redirect_uri
193
+ return code, row.RedirectURI
194
194
 
195
195
  async def get_authorization_flow(self, code: str, max_validity: int):
196
196
  hashed_code = hashlib.sha256(code.encode()).hexdigest()
@@ -205,7 +205,7 @@ class AuthDB(BaseSQLDB):
205
205
 
206
206
  res = dict((await self.conn.execute(stmt)).one()._mapping)
207
207
 
208
- if res["status"] == FlowStatus.READY:
208
+ if res["Status"] == FlowStatus.READY:
209
209
  # Update the status to Done before returning
210
210
  await self.conn.execute(
211
211
  update(AuthorizationFlows)
@@ -215,7 +215,7 @@ class AuthDB(BaseSQLDB):
215
215
 
216
216
  return res
217
217
 
218
- if res["status"] == FlowStatus.DONE:
218
+ if res["Status"] == FlowStatus.DONE:
219
219
  raise AuthorizationError("Code was already used")
220
220
 
221
221
  raise AuthorizationError("Bad state in authorization flow")
@@ -247,7 +247,7 @@ class AuthDB(BaseSQLDB):
247
247
  row = (await self.conn.execute(stmt)).one()
248
248
 
249
249
  # Return the JWT ID and the creation time
250
- return jti, row.creation_time
250
+ return jti, row.CreationTime
251
251
 
252
252
  async def get_refresh_token(self, jti: str) -> dict:
253
253
  """Get refresh token details bound to a given JWT ID."""
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from enum import Enum, auto
2
4
 
3
5
  from sqlalchemy import (
@@ -39,27 +41,27 @@ class FlowStatus(Enum):
39
41
 
40
42
  class DeviceFlows(Base):
41
43
  __tablename__ = "DeviceFlows"
42
- user_code = Column(String(USER_CODE_LENGTH), primary_key=True)
43
- status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
44
- creation_time = DateNowColumn()
45
- client_id = Column(String(255))
46
- scope = Column(String(1024))
47
- device_code = Column(String(128), unique=True) # Should be a hash
48
- id_token = NullColumn(JSON())
44
+ user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
45
+ status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
46
+ creation_time = DateNowColumn("CreationTime")
47
+ client_id = Column("ClientID", String(255))
48
+ scope = Column("Scope", String(1024))
49
+ device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash
50
+ id_token = NullColumn("IDToken", JSON())
49
51
 
50
52
 
51
53
  class AuthorizationFlows(Base):
52
54
  __tablename__ = "AuthorizationFlows"
53
- uuid = Column(Uuid(as_uuid=False), primary_key=True)
54
- status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
55
- client_id = Column(String(255))
56
- creation_time = DateNowColumn()
57
- scope = Column(String(1024))
58
- code_challenge = Column(String(255))
59
- code_challenge_method = Column(String(8))
60
- redirect_uri = Column(String(255))
61
- code = NullColumn(String(255)) # Should be a hash
62
- id_token = NullColumn(JSON())
55
+ uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
56
+ status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
57
+ client_id = Column("ClientID", String(255))
58
+ creation_time = DateNowColumn("CreationTime")
59
+ scope = Column("Scope", String(1024))
60
+ code_challenge = Column("CodeChallenge", String(255))
61
+ code_challenge_method = Column("CodeChallengeMethod", String(8))
62
+ redirect_uri = Column("RedirectURI", String(255))
63
+ code = NullColumn("Code", String(255)) # Should be a hash
64
+ id_token = NullColumn("IDToken", JSON())
63
65
 
64
66
 
65
67
  class RefreshTokenStatus(Enum):
@@ -85,13 +87,13 @@ class RefreshTokens(Base):
85
87
 
86
88
  __tablename__ = "RefreshTokens"
87
89
  # Refresh token attributes
88
- jti = Column(Uuid(as_uuid=False), primary_key=True)
90
+ jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
89
91
  status = EnumColumn(
90
- RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
92
+ "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
91
93
  )
92
- creation_time = DateNowColumn()
93
- scope = Column(String(1024))
94
+ creation_time = DateNowColumn("CreationTime")
95
+ scope = Column("Scope", String(1024))
94
96
 
95
97
  # User attributes bound to the refresh token
96
- sub = Column(String(1024))
97
- preferred_username = Column(String(255))
98
+ sub = Column("Sub", String(1024))
99
+ preferred_username = Column("PreferredUsername", String(255))
@@ -25,7 +25,7 @@ class DummyDB(BaseSQLDB):
25
25
  async def summary(self, group_by, search) -> list[dict[str, str | int]]:
26
26
  columns = [Cars.__table__.columns[x] for x in group_by]
27
27
 
28
- stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
28
+ stmt = select(*columns, func.count(Cars.license_plate).label("count"))
29
29
  stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
30
30
  stmt = stmt.group_by(*columns)
31
31
 
@@ -44,7 +44,7 @@ class DummyDB(BaseSQLDB):
44
44
 
45
45
  async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
46
46
  stmt = insert(Cars).values(
47
- licensePlate=license_plate, model=model, ownerID=owner_id
47
+ license_plate=license_plate, model=model, owner_id=owner_id
48
48
  )
49
49
 
50
50
  result = await self.conn.execute(stmt)
@@ -1,5 +1,7 @@
1
1
  # The utils class define some boilerplate types that should be used
2
2
  # in place of the SQLAlchemy one. Have a look at them
3
+ from __future__ import annotations
4
+
3
5
  from sqlalchemy import ForeignKey, Integer, String, Uuid
4
6
  from sqlalchemy.orm import declarative_base
5
7
 
@@ -10,13 +12,13 @@ Base = declarative_base()
10
12
 
11
13
  class Owners(Base):
12
14
  __tablename__ = "Owners"
13
- ownerID = Column(Integer, primary_key=True, autoincrement=True)
14
- creation_time = DateNowColumn()
15
- name = Column(String(255))
15
+ owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
16
+ creation_time = DateNowColumn("CreationTime")
17
+ name = Column("Name", String(255))
16
18
 
17
19
 
18
20
  class Cars(Base):
19
21
  __tablename__ = "Cars"
20
- licensePlate = Column(Uuid(), primary_key=True)
21
- model = Column(String(255))
22
- ownerID = Column(Integer, ForeignKey(Owners.ownerID))
22
+ license_plate = Column("LicensePlate", Uuid(), primary_key=True)
23
+ model = Column("Model", String(255))
24
+ owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))
@@ -3,12 +3,13 @@ from __future__ import annotations
3
3
  from datetime import datetime, timezone
4
4
  from typing import TYPE_CHECKING, Any
5
5
 
6
- from sqlalchemy import bindparam, delete, func, insert, select, update
6
+ from sqlalchemy import bindparam, case, delete, func, insert, select, update
7
7
  from sqlalchemy.exc import IntegrityError, NoResultFound
8
8
 
9
9
  if TYPE_CHECKING:
10
10
  from sqlalchemy.sql.elements import BindParameter
11
- from diracx.core.exceptions import InvalidQueryError, JobNotFound
11
+
12
+ from diracx.core.exceptions import InvalidQueryError, JobNotFoundError
12
13
  from diracx.core.models import (
13
14
  LimitedJobStatusReturn,
14
15
  SearchSpec,
@@ -42,12 +43,12 @@ class JobDB(BaseSQLDB):
42
43
  # TODO: this is copied from the DIRAC JobDB
43
44
  # but is overwriten in LHCbDIRAC, so we need
44
45
  # to find a way to make it dynamic
45
- jdl2DBParameters = ["JobName", "JobType", "JobGroup"]
46
+ jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]
46
47
 
47
48
  async def summary(self, group_by, search) -> list[dict[str, str | int]]:
48
49
  columns = _get_columns(Jobs.__table__, group_by)
49
50
 
50
- stmt = select(*columns, func.count(Jobs.JobID).label("count"))
51
+ stmt = select(*columns, func.count(Jobs.job_id).label("count"))
51
52
  stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
52
53
  stmt = stmt.group_by(*columns)
53
54
 
@@ -110,11 +111,11 @@ class JobDB(BaseSQLDB):
110
111
  ],
111
112
  )
112
113
 
113
- async def setJobAttributes(self, job_id, jobData):
114
+ async def set_job_attributes(self, job_id, job_data):
114
115
  """TODO: add myDate and force parameters."""
115
- if "Status" in jobData:
116
- jobData = jobData | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
117
- stmt = update(Jobs).where(Jobs.JobID == job_id).values(jobData)
116
+ if "Status" in job_data:
117
+ job_data = job_data | {"LastUpdateTime": datetime.now(tz=timezone.utc)}
118
+ stmt = update(Jobs).where(Jobs.job_id == job_id).values(job_data)
118
119
  await self.conn.execute(stmt)
119
120
 
120
121
  async def create_job(self, original_jdl):
@@ -159,9 +160,9 @@ class JobDB(BaseSQLDB):
159
160
  ],
160
161
  )
161
162
 
162
- async def checkAndPrepareJob(
163
+ async def check_and_prepare_job(
163
164
  self,
164
- jobID,
165
+ job_id,
165
166
  class_ad_job,
166
167
  class_ad_req,
167
168
  owner,
@@ -178,8 +179,8 @@ class JobDB(BaseSQLDB):
178
179
  checkAndPrepareJob,
179
180
  )
180
181
 
181
- retVal = checkAndPrepareJob(
182
- jobID,
182
+ ret_val = checkAndPrepareJob(
183
+ job_id,
183
184
  class_ad_job,
184
185
  class_ad_req,
185
186
  owner,
@@ -188,21 +189,21 @@ class JobDB(BaseSQLDB):
188
189
  vo,
189
190
  )
190
191
 
191
- if not retVal["OK"]:
192
- if cmpError(retVal, EWMSSUBM):
193
- await self.setJobAttributes(jobID, job_attrs)
192
+ if not ret_val["OK"]:
193
+ if cmpError(ret_val, EWMSSUBM):
194
+ await self.set_job_attributes(job_id, job_attrs)
194
195
 
195
- returnValueOrRaise(retVal)
196
+ returnValueOrRaise(ret_val)
196
197
 
197
- async def setJobJDL(self, job_id, jdl):
198
+ async def set_job_jdl(self, job_id, jdl):
198
199
  from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
199
200
 
200
201
  stmt = (
201
- update(JobJDLs).where(JobJDLs.JobID == job_id).values(JDL=compressJDL(jdl))
202
+ update(JobJDLs).where(JobJDLs.job_id == job_id).values(JDL=compressJDL(jdl))
202
203
  )
203
204
  await self.conn.execute(stmt)
204
205
 
205
- async def setJobJDLsBulk(self, jdls):
206
+ async def set_job_jdl_bulk(self, jdls):
206
207
  from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import compressJDL
207
208
 
208
209
  await self.conn.execute(
@@ -212,44 +213,46 @@ class JobDB(BaseSQLDB):
212
213
  [{"b_JobID": jid, "JDL": compressJDL(jdl)} for jid, jdl in jdls.items()],
213
214
  )
214
215
 
215
- async def setJobAttributesBulk(self, jobData):
216
+ async def set_job_attributes_bulk(self, job_data):
216
217
  """TODO: add myDate and force parameters."""
217
- for job_id in jobData.keys():
218
- if "Status" in jobData[job_id]:
219
- jobData[job_id].update(
218
+ for job_id in job_data.keys():
219
+ if "Status" in job_data[job_id]:
220
+ job_data[job_id].update(
220
221
  {"LastUpdateTime": datetime.now(tz=timezone.utc)}
221
222
  )
223
+ columns = set(key for attrs in job_data.values() for key in attrs.keys())
224
+ case_expressions = {
225
+ column: case(
226
+ *[
227
+ (Jobs.__table__.c.JobID == job_id, attrs[column])
228
+ for job_id, attrs in job_data.items()
229
+ if column in attrs
230
+ ],
231
+ else_=getattr(Jobs.__table__.c, column), # Retain original value
232
+ )
233
+ for column in columns
234
+ }
222
235
 
223
- await self.conn.execute(
224
- Jobs.__table__.update().where(
225
- Jobs.__table__.c.JobID == bindparam("b_JobID")
226
- ),
227
- [{"b_JobID": job_id, **attrs} for job_id, attrs in jobData.items()],
236
+ stmt = (
237
+ Jobs.__table__.update()
238
+ .values(**case_expressions)
239
+ .where(Jobs.__table__.c.JobID.in_(job_data.keys()))
228
240
  )
241
+ await self.conn.execute(stmt)
229
242
 
230
- async def getJobJDL(self, job_id: int, original: bool = False) -> str:
231
- from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
232
-
233
- if original:
234
- stmt = select(JobJDLs.OriginalJDL).where(JobJDLs.JobID == job_id)
235
- else:
236
- stmt = select(JobJDLs.JDL).where(JobJDLs.JobID == job_id)
237
-
238
- jdl = (await self.conn.execute(stmt)).scalar_one()
239
- if jdl:
240
- jdl = extractJDL(jdl)
241
-
242
- return jdl
243
-
244
- async def getJobJDLs(self, job_ids, original: bool = False) -> dict[int | str, str]:
243
+ async def get_job_jdls(
244
+ self, job_ids, original: bool = False
245
+ ) -> dict[int | str, str]:
245
246
  from DIRAC.WorkloadManagementSystem.DB.JobDBUtils import extractJDL
246
247
 
247
248
  if original:
248
- stmt = select(JobJDLs.JobID, JobJDLs.OriginalJDL).where(
249
- JobJDLs.JobID.in_(job_ids)
249
+ stmt = select(JobJDLs.job_id, JobJDLs.original_jdl).where(
250
+ JobJDLs.job_id.in_(job_ids)
250
251
  )
251
252
  else:
252
- stmt = select(JobJDLs.JobID, JobJDLs.JDL).where(JobJDLs.JobID.in_(job_ids))
253
+ stmt = select(JobJDLs.job_id, JobJDLs.jdl).where(
254
+ JobJDLs.job_id.in_(job_ids)
255
+ )
253
256
 
254
257
  return {
255
258
  jobid: extractJDL(jdl)
@@ -259,14 +262,14 @@ class JobDB(BaseSQLDB):
259
262
 
260
263
  async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn:
261
264
  try:
262
- stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where(
263
- Jobs.JobID == job_id
264
- )
265
+ stmt = select(
266
+ Jobs.status, Jobs.minor_status, Jobs.application_status
267
+ ).where(Jobs.job_id == job_id)
265
268
  return LimitedJobStatusReturn(
266
269
  **dict((await self.conn.execute(stmt)).one()._mapping)
267
270
  )
268
271
  except NoResultFound as e:
269
- raise JobNotFound(job_id) from e
272
+ raise JobNotFoundError(job_id) from e
270
273
 
271
274
  async def set_job_command(self, job_id: int, command: str, arguments: str = ""):
272
275
  """Store a command to be passed to the job together with the next heart beat."""
@@ -279,11 +282,11 @@ class JobDB(BaseSQLDB):
279
282
  )
280
283
  await self.conn.execute(stmt)
281
284
  except IntegrityError as e:
282
- raise JobNotFound(job_id) from e
285
+ raise JobNotFoundError(job_id) from e
283
286
 
284
287
  async def set_job_command_bulk(self, commands):
285
288
  """Store a command to be passed to the job together with the next heart beat."""
286
- self.conn.execute(
289
+ await self.conn.execute(
287
290
  insert(JobCommands),
288
291
  [
289
292
  {
@@ -299,7 +302,7 @@ class JobDB(BaseSQLDB):
299
302
 
300
303
  async def delete_jobs(self, job_ids: list[int]):
301
304
  """Delete jobs from the database."""
302
- stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids))
305
+ stmt = delete(JobJDLs).where(JobJDLs.job_id.in_(job_ids))
303
306
  await self.conn.execute(stmt)
304
307
 
305
308
  async def set_properties(
@@ -332,7 +335,7 @@ class JobDB(BaseSQLDB):
332
335
  if update_timestamp:
333
336
  values["LastUpdateTime"] = datetime.now(tz=timezone.utc)
334
337
 
335
- stmt = update(Jobs).where(Jobs.JobID == bindparam("job_id")).values(**values)
338
+ stmt = update(Jobs).where(Jobs.job_id == bindparam("job_id")).values(**values)
336
339
  rows = await self.conn.execute(stmt, update_parameters)
337
340
 
338
341
  return rows.rowcount
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ from sqlalchemy import (
4
+ DateTime,
5
+ Enum,
6
+ ForeignKey,
7
+ Index,
8
+ Integer,
9
+ String,
10
+ Text,
11
+ )
12
+ from sqlalchemy.orm import declarative_base
13
+
14
+ from ..utils import Column, EnumBackedBool, NullColumn
15
+
16
+ JobDBBase = declarative_base()
17
+
18
+
19
+ class Jobs(JobDBBase):
20
+ __tablename__ = "Jobs"
21
+
22
+ job_id = Column(
23
+ "JobID",
24
+ Integer,
25
+ ForeignKey("JobJDLs.JobID", ondelete="CASCADE"),
26
+ primary_key=True,
27
+ default=0,
28
+ )
29
+ job_type = Column("JobType", String(32), default="user")
30
+ job_group = Column("JobGroup", String(32), default="00000000")
31
+ site = Column("Site", String(100), default="ANY")
32
+ job_name = Column("JobName", String(128), default="Unknown")
33
+ owner = Column("Owner", String(64), default="Unknown")
34
+ owner_group = Column("OwnerGroup", String(128), default="Unknown")
35
+ vo = Column("VO", String(32))
36
+ submission_time = NullColumn("SubmissionTime", DateTime)
37
+ reschedule_time = NullColumn("RescheduleTime", DateTime)
38
+ last_update_time = NullColumn("LastUpdateTime", DateTime)
39
+ start_exec_time = NullColumn("StartExecTime", DateTime)
40
+ heart_beat_time = NullColumn("HeartBeatTime", DateTime)
41
+ end_exec_time = NullColumn("EndExecTime", DateTime)
42
+ status = Column("Status", String(32), default="Received")
43
+ minor_status = Column("MinorStatus", String(128), default="Unknown")
44
+ application_status = Column("ApplicationStatus", String(255), default="Unknown")
45
+ user_priority = Column("UserPriority", Integer, default=0)
46
+ reschedule_counter = Column("RescheduleCounter", Integer, default=0)
47
+ verified_flag = Column("VerifiedFlag", EnumBackedBool(), default=False)
48
+ # TODO: Should this be True/False/"Failed"? Or True/False/Null?
49
+ accounted_flag = Column(
50
+ "AccountedFlag", Enum("True", "False", "Failed"), default="False"
51
+ )
52
+
53
+ __table_args__ = (
54
+ Index("JobType", "JobType"),
55
+ Index("JobGroup", "JobGroup"),
56
+ Index("Site", "Site"),
57
+ Index("Owner", "Owner"),
58
+ Index("OwnerGroup", "OwnerGroup"),
59
+ Index("Status", "Status"),
60
+ Index("MinorStatus", "MinorStatus"),
61
+ Index("ApplicationStatus", "ApplicationStatus"),
62
+ Index("StatusSite", "Status", "Site"),
63
+ Index("LastUpdateTime", "LastUpdateTime"),
64
+ )
65
+
66
+
67
+ class JobJDLs(JobDBBase):
68
+ __tablename__ = "JobJDLs"
69
+ job_id = Column("JobID", Integer, autoincrement=True, primary_key=True)
70
+ jdl = Column("JDL", Text)
71
+ job_requirements = Column("JobRequirements", Text)
72
+ original_jdl = Column("OriginalJDL", Text)
73
+
74
+
75
+ class InputData(JobDBBase):
76
+ __tablename__ = "InputData"
77
+ job_id = Column(
78
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
79
+ )
80
+ lfn = Column("LFN", String(255), default="", primary_key=True)
81
+ status = Column("Status", String(32), default="AprioriGood")
82
+
83
+
84
+ class JobParameters(JobDBBase):
85
+ __tablename__ = "JobParameters"
86
+ job_id = Column(
87
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
88
+ )
89
+ name = Column("Name", String(100), primary_key=True)
90
+ value = Column("Value", Text)
91
+
92
+
93
+ class OptimizerParameters(JobDBBase):
94
+ __tablename__ = "OptimizerParameters"
95
+ job_id = Column(
96
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
97
+ )
98
+ name = Column("Name", String(100), primary_key=True)
99
+ value = Column("Value", Text)
100
+
101
+
102
+ class AtticJobParameters(JobDBBase):
103
+ __tablename__ = "AtticJobParameters"
104
+ job_id = Column(
105
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
106
+ )
107
+ name = Column("Name", String(100), primary_key=True)
108
+ value = Column("Value", Text)
109
+ reschedule_cycle = Column("RescheduleCycle", Integer)
110
+
111
+
112
+ class HeartBeatLoggingInfo(JobDBBase):
113
+ __tablename__ = "HeartBeatLoggingInfo"
114
+ job_id = Column(
115
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
116
+ )
117
+ name = Column("Name", String(100), primary_key=True)
118
+ value = Column("Value", Text)
119
+ heart_beat_time = Column("HeartBeatTime", DateTime, primary_key=True)
120
+
121
+
122
+ class JobCommands(JobDBBase):
123
+ __tablename__ = "JobCommands"
124
+ job_id = Column(
125
+ "JobID", Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True
126
+ )
127
+ command = Column("Command", String(100))
128
+ arguments = Column("Arguments", String(100))
129
+ status = Column("Status", String(64), default="Received")
130
+ reception_time = Column("ReceptionTime", DateTime, primary_key=True)
131
+ execution_time = NullColumn("ExecutionTime", DateTime)