diracx-db 0.0.1a46__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

diracx/db/__main__.py CHANGED
@@ -39,6 +39,7 @@ async def init_sql():
         if db._db_url.startswith("sqlite"):
             await conn.exec_driver_sql("PRAGMA foreign_keys=ON")
         await conn.run_sync(db.metadata.create_all)
+        await db.post_create(conn)


 async def init_os():
diracx/db/os/job_parameters.py CHANGED
@@ -9,13 +9,19 @@ class JobParametersDB(BaseOSDB):
     fields = {
         "JobID": {"type": "long"},
         "timestamp": {"type": "date"},
+        "PilotAgent": {"type": "keyword"},
+        "Pilot_Reference": {"type": "keyword"},
+        "JobGroup": {"type": "keyword"},
         "CPUNormalizationFactor": {"type": "long"},
         "NormCPUTime(s)": {"type": "long"},
-        "Memory(kB)": {"type": "long"},
+        "Memory(MB)": {"type": "long"},
+        "LocalAccount": {"type": "keyword"},
         "TotalCPUTime(s)": {"type": "long"},
-        "MemoryUsed(kb)": {"type": "long"},
-        "HostName": {"type": "keyword"},
+        "PayloadPID": {"type": "long"},
+        "HostName": {"type": "text"},
         "GridCE": {"type": "keyword"},
+        "CEQueue": {"type": "keyword"},
+        "BatchSystem": {"type": "keyword"},
         "ModelName": {"type": "keyword"},
         "Status": {"type": "keyword"},
         "JobType": {"type": "keyword"},
diracx/db/sql/auth/db.py CHANGED
@@ -1,16 +1,21 @@
 from __future__ import annotations

+import logging
 import secrets
+from datetime import UTC, datetime
+from itertools import pairwise

-from sqlalchemy import insert, select, update
+from dateutil.rrule import MONTHLY, rrule
+from sqlalchemy import insert, select, text, update
 from sqlalchemy.exc import IntegrityError, NoResultFound
+from sqlalchemy.ext.asyncio import AsyncConnection
 from uuid_utils import UUID, uuid7

 from diracx.core.exceptions import (
     AuthorizationError,
     TokenNotFoundError,
 )
-from diracx.db.sql.utils import BaseSQLDB, hash, substract_date
+from diracx.db.sql.utils import BaseSQLDB, hash, substract_date, uuid7_from_datetime

 from .schema import (
     AuthorizationFlows,
@@ -25,10 +30,72 @@ from .schema import Base as AuthDBBase
 USER_CODE_ALPHABET = "BCDFGHJKLMNPQRSTVWXZ"
 MAX_RETRY = 5

+logger = logging.getLogger(__name__)
+

 class AuthDB(BaseSQLDB):
     metadata = AuthDBBase.metadata

+    @classmethod
+    async def post_create(cls, conn: AsyncConnection) -> None:
+        """Create the RefreshTokens partitions if this is a MySQL DB that
+        does not have them yet and the table does not contain any data yet.
+        We do this as a post_create step because sqlalchemy does not support
+        partitions very well.
+        """
+        if conn.dialect.name == "mysql":
+            check_partition_query = text(
+                "SELECT PARTITION_NAME FROM information_schema.partitions "
+                "WHERE TABLE_NAME = 'RefreshTokens' AND PARTITION_NAME is not NULL"
+            )
+            partition_names = (await conn.execute(check_partition_query)).all()
+
+            if not partition_names:
+                # Create a monthly partition from today until 2 years
+                # The partitions are named p_<year>_<month>
+                start_date = datetime.now(tz=UTC).replace(
+                    day=1, hour=0, minute=0, second=0, microsecond=0
+                )
+                end_date = start_date.replace(year=start_date.year + 2)
+
+                dates = [
+                    dt for dt in rrule(MONTHLY, dtstart=start_date, until=end_date)
+                ]
+
+                partition_list = []
+                for name, limit in pairwise(dates):
+                    partition_list.append(
+                        f"PARTITION p_{name.year}_{name.month} "
+                        f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
+                    )
+                partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")
+
+                alter_query = text(
+                    f"ALTER TABLE RefreshTokens PARTITION BY RANGE COLUMNS (JTI) ({','.join(partition_list)})"
+                )
+
+                check_table_empty_query = text("SELECT * FROM RefreshTokens LIMIT 1")
+                refresh_table_content = (
+                    await conn.execute(check_table_empty_query)
+                ).all()
+                if refresh_table_content:
+                    logger.warning(
+                        "RefreshTokens table not empty. Run the following query yourself"
+                    )
+                    logger.warning(alter_query)
+                    return
+
+                await conn.execute(alter_query)
+
+                partition_names = (
+                    await conn.execute(
+                        check_partition_query, {"table_name": "RefreshTokens"}
+                    )
+                ).all()
+                assert partition_names, (
+                    f"There should be partitions now {partition_names}"
+                )
+
     async def device_flow_validate_user_code(
         self, user_code: str, max_validity: int
     ) -> str:
@@ -111,6 +178,10 @@ class AuthDB(BaseSQLDB):
                 await self.conn.execute(stmt)

             except IntegrityError:
+                logger.warning(
+                    "Device flow code collision detected, retrying (user_code=%s)",
+                    user_code,
+                )
                 continue

         return user_code, device_code
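The partition bootstrap above pairs consecutive month boundaries and uses the lowest possible UUIDv7 for each boundary as the partition's upper bound, which works because UUIDv7 values sort by their embedded timestamp. A standalone sketch of the same name/boundary generation, reusing the `uuid7_from_datetime` helper added to `diracx/db/sql/utils` further down:

    from datetime import UTC, datetime
    from itertools import pairwise

    from dateutil.rrule import MONTHLY, rrule

    from diracx.db.sql.utils import uuid7_from_datetime

    # First day of the current month, then one boundary per month for 2 years.
    start = datetime.now(tz=UTC).replace(day=1, hour=0, minute=0, second=0, microsecond=0)
    end = start.replace(year=start.year + 2)

    partitions = []
    for name, limit in pairwise(rrule(MONTHLY, dtstart=start, until=end)):
        # Upper bound: the smallest UUIDv7 for the boundary, hex without dashes.
        bound = str(uuid7_from_datetime(limit, randomize=False)).replace("-", "")
        partitions.append(f"PARTITION p_{name.year}_{name.month} VALUES LESS THAN ('{bound}')")
    partitions.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")

    print(",\n".join(partitions))  # the body of the ALTER TABLE ... PARTITION BY clause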
diracx/db/sql/auth/schema.py CHANGED
@@ -10,7 +10,12 @@ from sqlalchemy import (
 )
 from sqlalchemy.orm import declarative_base

-from diracx.db.sql.utils import Column, DateNowColumn, EnumColumn, NullColumn
+from diracx.db.sql.utils import (
+    Column,
+    DateNowColumn,
+    EnumColumn,
+    NullColumn,
+)

 USER_CODE_LENGTH = 8

@@ -92,7 +97,6 @@ class RefreshTokens(Base):
     status = EnumColumn(
         "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
     )
-    creation_time = DateNowColumn("CreationTime", index=True)
     scope = Column("Scope", String(1024))

     # User attributes bound to the refresh token
diracx/db/sql/dummy/db.py CHANGED
@@ -1,9 +1,9 @@
 from __future__ import annotations

-from sqlalchemy import func, insert, select
+from sqlalchemy import insert
 from uuid_utils import UUID

-from diracx.db.sql.utils import BaseSQLDB, apply_search_filters
+from diracx.db.sql.utils import BaseSQLDB

 from .schema import Base as DummyDBBase
 from .schema import Cars, Owners
@@ -22,18 +22,7 @@ class DummyDB(BaseSQLDB):
     metadata = DummyDBBase.metadata

     async def summary(self, group_by, search) -> list[dict[str, str | int]]:
-        columns = [Cars.__table__.columns[x] for x in group_by]
-
-        stmt = select(*columns, func.count(Cars.license_plate).label("count"))
-        stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
-        stmt = stmt.group_by(*columns)
-
-        # Execute the query
-        return [
-            dict(row._mapping)
-            async for row in (await self.conn.stream(stmt))
-            if row.count > 0  # type: ignore
-        ]
+        return await self._summary(Cars, group_by, search)

     async def insert_owner(self, name: str) -> int:
         stmt = insert(Owners).values(name=name)
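`DummyDB.summary` is now a one-liner over the shared `BaseSQLDB._summary` helper introduced in `diracx/db/sql/utils/base.py` below. A hedged usage sketch; the search-spec keys match `apply_search_filters`, but the column name is illustrative since the `Cars` schema is not part of this diff:

    async def count_cars_by_model(db: DummyDB) -> list[dict[str, str | int]]:
        # db: a DummyDB instance with an open transaction.
        # Returns e.g. [{"model": "Tesla Model 3", "count": 2}, ...]
        return await db.summary(
            ["model"],  # illustrative group-by column
            [{"parameter": "model", "operator": "like", "value": "Tesla%"}],
        )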
diracx/db/sql/job/db.py CHANGED
@@ -5,15 +5,16 @@ __all__ = ["JobDB"]
 from datetime import datetime, timezone
 from typing import TYPE_CHECKING, Any, Iterable

-from sqlalchemy import bindparam, case, delete, func, insert, select, update
+from sqlalchemy import bindparam, case, delete, literal, select, update

 if TYPE_CHECKING:
     from sqlalchemy.sql.elements import BindParameter
+    from sqlalchemy.sql import expression

 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.models import JobCommand, SearchSpec, SortSpec

-from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints
+from ..utils import BaseSQLDB, _get_columns
 from ..utils.functions import utcnow
 from .schema import (
     HeartBeatLoggingInfo,
@@ -25,17 +26,6 @@ from .schema import (
 )


-def _get_columns(table, parameters):
-    columns = [x for x in table.columns]
-    if parameters:
-        if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
-            raise InvalidQueryError(
-                f"Unrecognised parameters requested {unrecognised_parameters}"
-            )
-        columns = [c for c in columns if c.name in parameters]
-    return columns
-
-
 class JobDB(BaseSQLDB):
     metadata = JobDBBase.metadata

@@ -54,20 +44,11 @@ class JobDB(BaseSQLDB):
     # to find a way to make it dynamic
     jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]

-    async def summary(self, group_by, search) -> list[dict[str, str | int]]:
+    async def summary(
+        self, group_by: list[str], search: list[SearchSpec]
+    ) -> list[dict[str, str | int]]:
         """Get a summary of the jobs."""
-        columns = _get_columns(Jobs.__table__, group_by)
-
-        stmt = select(*columns, func.count(Jobs.job_id).label("count"))
-        stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
-        stmt = stmt.group_by(*columns)
-
-        # Execute the query
-        return [
-            dict(row._mapping)
-            async for row in (await self.conn.stream(stmt))
-            if row.count > 0  # type: ignore
-        ]
+        return await self._summary(table=Jobs, group_by=group_by, search=search)

     async def search(
         self,
@@ -80,34 +61,15 @@ class JobDB(BaseSQLDB):
         page: int | None = None,
     ) -> tuple[int, list[dict[Any, Any]]]:
         """Search for jobs in the database."""
-        # Find which columns to select
-        columns = _get_columns(Jobs.__table__, parameters)
-
-        stmt = select(*columns)
-
-        stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
-        stmt = apply_sort_constraints(Jobs.__table__.columns.__getitem__, stmt, sorts)
-
-        if distinct:
-            stmt = stmt.distinct()
-
-        # Calculate total count before applying pagination
-        total_count_subquery = stmt.alias()
-        total_count_stmt = select(func.count()).select_from(total_count_subquery)
-        total = (await self.conn.execute(total_count_stmt)).scalar_one()
-
-        # Apply pagination
-        if page is not None:
-            if page < 1:
-                raise InvalidQueryError("Page must be a positive integer")
-            if per_page < 1:
-                raise InvalidQueryError("Per page must be a positive integer")
-            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
-
-        # Execute the query
-        return total, [
-            dict(row._mapping) async for row in (await self.conn.stream(stmt))
-        ]
+        return await self._search(
+            table=Jobs,
+            parameters=parameters,
+            search=search,
+            sorts=sorts,
+            distinct=distinct,
+            per_page=per_page,
+            page=page,
+        )

     async def create_job(self, compressed_original_jdl: str):
         """Used to insert a new job with original JDL. Returns inserted job id."""
@@ -167,27 +129,14 @@ class JobDB(BaseSQLDB):
             ],
         )

-    @staticmethod
-    def _set_job_attributes_fix_value(column, value):
-        """Apply corrections to the values before inserting them into the database.
-
-        TODO: Move this logic into the sqlalchemy model.
-        """
-        if column == "VerifiedFlag":
-            value_str = str(value)
-            if value_str in ("True", "False"):
-                return value_str
-        if column == "AccountedFlag":
-            value_str = str(value)
-            if value_str in ("True", "False", "Failed"):
-                return value_str
-        else:
-            return value
-        raise NotImplementedError(f"Unrecognized value for column {column}: {value}")
-
     async def set_job_attributes(self, job_data):
         """Update the parameters of the given jobs."""
         # TODO: add myDate and force parameters.
+
+        if not job_data:
+            # nothing to do!
+            raise ValueError("job_data is empty")
+
         for job_id in job_data.keys():
             if "Status" in job_data[job_id]:
                 job_data[job_id].update(
@@ -199,7 +148,11 @@ class JobDB(BaseSQLDB):
                 *[
                     (
                         Jobs.__table__.c.JobID == job_id,
-                        self._set_job_attributes_fix_value(column, attrs[column]),
+                        # Since the setting of the new column value is obscured by the CASE statement,
+                        # ensure that SQLAlchemy renders the new column value with the correct type
+                        literal(attrs[column], type_=Jobs.__table__.c[column].type)
+                        if not isinstance(attrs[column], expression.FunctionElement)
+                        else attrs[column],
                     )
                     for job_id, attrs in job_data.items()
                     if column in attrs
@@ -232,7 +185,7 @@ class JobDB(BaseSQLDB):
     async def set_job_commands(self, commands: list[tuple[int, str, str]]) -> None:
         """Store a command to be passed to the job together with the next heart beat."""
         await self.conn.execute(
-            insert(JobCommands),
+            JobCommands.__table__.insert(),
             [
                 {
                     "JobID": job_id,
diracx/db/sql/job/schema.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations

 import sqlalchemy.types as types
 from sqlalchemy import (
-    DateTime,
     ForeignKey,
     Index,
     Integer,
@@ -11,6 +10,8 @@ from sqlalchemy import (
 )
 from sqlalchemy.orm import declarative_base

+from diracx.db.sql.utils.types import SmarterDateTime
+
 from ..utils import Column, EnumBackedBool, NullColumn

 JobDBBase = declarative_base()
@@ -19,11 +20,8 @@ JobDBBase = declarative_base()
 class AccountedFlagEnum(types.TypeDecorator):
     """Maps a ``AccountedFlagEnum()`` column to True/False in Python."""

-    impl = types.Enum
-    cache_ok: bool = True
-
-    def __init__(self) -> None:
-        super().__init__("True", "False", "Failed")
+    impl = types.Enum("True", "False", "Failed", name="accounted_flag_enum")
+    cache_ok = True

     def process_bind_param(self, value, dialect) -> str:
         if value is True:
@@ -63,12 +61,30 @@ class Jobs(JobDBBase):
     owner = Column("Owner", String(64), default="Unknown")
     owner_group = Column("OwnerGroup", String(128), default="Unknown")
     vo = Column("VO", String(32))
-    submission_time = NullColumn("SubmissionTime", DateTime)
-    reschedule_time = NullColumn("RescheduleTime", DateTime)
-    last_update_time = NullColumn("LastUpdateTime", DateTime)
-    start_exec_time = NullColumn("StartExecTime", DateTime)
-    heart_beat_time = NullColumn("HeartBeatTime", DateTime)
-    end_exec_time = NullColumn("EndExecTime", DateTime)
+    submission_time = NullColumn(
+        "SubmissionTime",
+        SmarterDateTime(),
+    )
+    reschedule_time = NullColumn(
+        "RescheduleTime",
+        SmarterDateTime(),
+    )
+    last_update_time = NullColumn(
+        "LastUpdateTime",
+        SmarterDateTime(),
+    )
+    start_exec_time = NullColumn(
+        "StartExecTime",
+        SmarterDateTime(),
+    )
+    heart_beat_time = NullColumn(
+        "HeartBeatTime",
+        SmarterDateTime(),
+    )
+    end_exec_time = NullColumn(
+        "EndExecTime",
+        SmarterDateTime(),
+    )
     status = Column("Status", String(32), default="Received")
     minor_status = Column("MinorStatus", String(128), default="Unknown")
     application_status = Column("ApplicationStatus", String(255), default="Unknown")
@@ -143,7 +159,11 @@ class HeartBeatLoggingInfo(JobDBBase):
     )
     name = Column("Name", String(100), primary_key=True)
     value = Column("Value", Text)
-    heart_beat_time = Column("HeartBeatTime", DateTime, primary_key=True)
+    heart_beat_time = Column(
+        "HeartBeatTime",
+        SmarterDateTime(),
+        primary_key=True,
+    )


 class JobCommands(JobDBBase):
@@ -154,5 +174,12 @@ class JobCommands(JobDBBase):
     command = Column("Command", String(100))
     arguments = Column("Arguments", String(100))
     status = Column("Status", String(64), default="Received")
-    reception_time = Column("ReceptionTime", DateTime, primary_key=True)
-    execution_time = NullColumn("ExecutionTime", DateTime)
+    reception_time = Column(
+        "ReceptionTime",
+        SmarterDateTime(),
+        primary_key=True,
+    )
+    execution_time = NullColumn(
+        "ExecutionTime",
+        SmarterDateTime(),
+    )
diracx/db/sql/pilot_agents/schema.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations

 from sqlalchemy import (
-    DateTime,
     Double,
     Index,
     Integer,
@@ -10,6 +9,8 @@ from sqlalchemy import (
 )
 from sqlalchemy.orm import declarative_base

+from diracx.db.sql.utils.types import SmarterDateTime
+
 from ..utils import Column, EnumBackedBool, NullColumn

 PilotAgentsDBBase = declarative_base()
@@ -29,8 +30,8 @@ class PilotAgents(PilotAgentsDBBase):
     vo = Column("VO", String(128))
     grid_type = Column("GridType", String(32), default="LCG")
     benchmark = Column("BenchMark", Double, default=0.0)
-    submission_time = NullColumn("SubmissionTime", DateTime)
-    last_update_time = NullColumn("LastUpdateTime", DateTime)
+    submission_time = NullColumn("SubmissionTime", SmarterDateTime)
+    last_update_time = NullColumn("LastUpdateTime", SmarterDateTime)
     status = Column("Status", String(32), default="Unknown")
     status_reason = Column("StatusReason", String(255), default="Unknown")
     accounting_sent = Column("AccountingSent", EnumBackedBool(), default=False)
@@ -47,7 +48,7 @@ class JobToPilotMapping(PilotAgentsDBBase):

     pilot_id = Column("PilotID", Integer, primary_key=True)
     job_id = Column("JobID", Integer, primary_key=True)
-    start_time = Column("StartTime", DateTime)
+    start_time = Column("StartTime", SmarterDateTime)

     __table_args__ = (Index("JobID", "JobID"), Index("PilotID", "PilotID"))

diracx/db/sql/utils/__init__.py CHANGED
@@ -1,15 +1,7 @@
 from __future__ import annotations

-from .base import (
-    BaseSQLDB,
-    SQLDBUnavailableError,
-    apply_search_filters,
-    apply_sort_constraints,
-)
-from .functions import hash, substract_date, utcnow
-from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
-
-__all__ = (
+__all__ = [
+    "_get_columns",
     "utcnow",
     "Column",
     "NullColumn",
@@ -22,4 +14,18 @@ __all__ = (
     "substract_date",
     "hash",
     "SQLDBUnavailableError",
+    "uuid7_from_datetime",
+    "uuid7_to_datetime",
+]
+
+from .base import (
+    BaseSQLDB,
+    SQLDBUnavailableError,
+    _get_columns,
+    apply_search_filters,
+    apply_sort_constraints,
+    uuid7_from_datetime,
+    uuid7_to_datetime,
 )
+from .functions import hash, substract_date, utcnow
+from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
diracx/db/sql/utils/base.py CHANGED
@@ -7,19 +7,26 @@ import re
 from abc import ABCMeta
 from collections.abc import AsyncIterator
 from contextvars import ContextVar
-from datetime import datetime
-from typing import Self, cast
+from datetime import datetime, timezone
+from typing import Any, Self, cast
+from uuid import UUID as StdUUID  # noqa: N811

 from pydantic import TypeAdapter
-from sqlalchemy import DateTime, MetaData, select
+from sqlalchemy import DateTime, MetaData, func, select
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
+from uuid_utils import UUID, uuid7

 from diracx.core.exceptions import InvalidQueryError
 from diracx.core.extensions import select_from_extension
-from diracx.core.models import SortDirection
+from diracx.core.models import (
+    SearchSpec,
+    SortDirection,
+    SortSpec,
+)
 from diracx.core.settings import SqlalchemyDsn
 from diracx.db.exceptions import DBUnavailableError
+from diracx.db.sql.utils.types import SmarterDateTime

 from .functions import date_trunc

@@ -148,6 +155,11 @@ class BaseSQLDB(metaclass=ABCMeta):
                 raise
         return db_urls

+    @classmethod
+    async def post_create(cls, conn: AsyncConnection) -> None:
+        """Execute actions after the schema has been created."""
+        return
+
     @classmethod
     def transaction(cls) -> Self:
         raise NotImplementedError("This should never be called")
@@ -199,6 +211,12 @@ class BaseSQLDB(metaclass=ABCMeta):
         try:
             self._conn.set(await self.engine.connect().__aenter__())
         except Exception as e:
+            logger.warning(
+                "Database connection failed for %s: %s",
+                self.__class__.__name__,
+                e,
+                exc_info=True,
+            )
             raise SQLDBUnavailableError(
                 f"Cannot connect to {self.__class__.__name__}"
             ) from e
@@ -227,6 +245,71 @@ class BaseSQLDB(metaclass=ABCMeta):
         except OperationalError as e:
             raise SQLDBUnavailableError("Cannot ping the DB") from e

+    async def _search(
+        self,
+        table: Any,
+        parameters: list[str] | None,
+        search: list[SearchSpec],
+        sorts: list[SortSpec],
+        *,
+        distinct: bool = False,
+        per_page: int = 100,
+        page: int | None = None,
+    ) -> tuple[int, list[dict[str, Any]]]:
+        """Search for elements in a table."""
+        # Find which columns to select
+        columns = _get_columns(table.__table__, parameters)
+
+        stmt = select(*columns)
+
+        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
+        stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)
+
+        if distinct:
+            stmt = stmt.distinct()
+
+        # Calculate total count before applying pagination
+        total_count_subquery = stmt.alias()
+        total_count_stmt = select(func.count()).select_from(total_count_subquery)
+        total = (await self.conn.execute(total_count_stmt)).scalar_one()
+
+        # Apply pagination
+        if page is not None:
+            if page < 1:
+                raise InvalidQueryError("Page must be a positive integer")
+            if per_page < 1:
+                raise InvalidQueryError("Per page must be a positive integer")
+            stmt = stmt.offset((page - 1) * per_page).limit(per_page)
+
+        # Execute the query
+        return total, [
+            dict(row._mapping) async for row in (await self.conn.stream(stmt))
+        ]
+
+    async def _summary(
+        self, table: Any, group_by: list[str], search: list[SearchSpec]
+    ) -> list[dict[str, str | int]]:
+        """Get a summary of the elements of a table."""
+        columns = _get_columns(table.__table__, group_by)
+
+        pk_columns = list(table.__table__.primary_key.columns)
+        if not pk_columns:
+            raise ValueError(
+                "Model has no primary key and no count_column was provided."
+            )
+        count_col = pk_columns[0]
+
+        stmt = select(*columns, func.count(count_col).label("count"))
+        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
+        stmt = stmt.group_by(*columns)
+
+        # Execute the query
+        return [
+            dict(row._mapping)
+            async for row in (await self.conn.stream(stmt))
+            if row.count > 0  # type: ignore
+        ]
+

 def find_time_resolution(value):
     if isinstance(value, datetime):
@@ -258,6 +341,17 @@ def find_time_resolution(value):
     raise InvalidQueryError(f"Cannot parse {value=}")


+def _get_columns(table, parameters):
+    columns = [x for x in table.columns]
+    if parameters:
+        if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
+            raise InvalidQueryError(
+                f"Unrecognised parameters requested {unrecognised_parameters}"
+            )
+        columns = [c for c in columns if c.name in parameters]
+    return columns
+
+
 def apply_search_filters(column_mapping, stmt, search):
     for query in search:
         try:
@@ -265,7 +359,7 @@ def apply_search_filters(column_mapping, stmt, search):
         except KeyError as e:
             raise InvalidQueryError(f"Unknown column {query['parameter']}") from e

-        if isinstance(column.type, DateTime):
+        if isinstance(column.type, (DateTime, SmarterDateTime)):
             if "value" in query and isinstance(query["value"], str):
                 resolution, value = find_time_resolution(query["value"])
                 if resolution:
@@ -300,6 +394,15 @@ def apply_search_filters(column_mapping, stmt, search):
             expr = column.like(query["value"])
         elif query["operator"] in "ilike":
             expr = column.ilike(query["value"])
+        elif query["operator"] == "not like":
+            expr = column.not_like(query["value"])
+        elif query["operator"] == "regex":
+            # We check the regex validity here
+            try:
+                re.compile(query["value"])
+            except re.error as e:
+                raise InvalidQueryError(f"Invalid regex {query['value']}") from e
+            expr = column.regexp_match(query["value"])
         else:
             raise InvalidQueryError(f"Unknown filter {query=}")
         stmt = stmt.where(expr)
@@ -326,3 +429,33 @@ def apply_sort_constraints(column_mapping, stmt, sorts):
     if sort_columns:
         stmt = stmt.order_by(*sort_columns)
     return stmt
+
+
+def uuid7_to_datetime(uuid: UUID | StdUUID | str) -> datetime:
+    """Convert a UUIDv7 to a datetime."""
+    if isinstance(uuid, StdUUID):
+        # Convert stdlib UUID to uuid_utils.UUID
+        uuid = UUID(str(uuid))
+    elif not isinstance(uuid, UUID):
+        # Convert string or other types to uuid_utils.UUID
+        uuid = UUID(uuid)
+    if uuid.version != 7:
+        raise ValueError(f"UUID {uuid} is not a UUIDv7")
+    return datetime.fromtimestamp(uuid.timestamp / 1000.0, tz=timezone.utc)
+
+
+def uuid7_from_datetime(dt: datetime, *, randomize: bool = True) -> UUID:
+    """Generate a UUIDv7 corresponding to the given datetime.
+
+    If randomize is True, the standard uuid7 function is used, resulting in the
+    lowest 62 bits being random. If randomize is False, the UUIDv7 will be the
+    lowest possible UUIDv7 for the given datetime.
+    """
+    timestamp = dt.timestamp()
+    if randomize:
+        uuid = uuid7(int(timestamp), int((timestamp % 1) * 1e9))
+    else:
+        time_high = int(timestamp * 1000) >> 16
+        time_low = int(timestamp * 1000) & 0xFFFF
+        uuid = UUID.from_fields((time_high, time_low, 0x7000, 0x80, 0, 0))
+    return uuid
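Both helpers rely on the UUIDv7 layout, whose first 48 bits are a millisecond Unix timestamp, so values sort by creation time and the timestamp can be recovered. A quick round-trip sketch using them:

    from datetime import UTC, datetime

    from diracx.db.sql.utils import uuid7_from_datetime, uuid7_to_datetime

    dt = datetime(2025, 1, 1, tzinfo=UTC)

    low = uuid7_from_datetime(dt, randomize=False)  # lowest UUIDv7 for this instant
    rnd = uuid7_from_datetime(dt)                   # random low bits, same timestamp

    assert uuid7_to_datetime(low) == dt
    assert uuid7_to_datetime(rnd) == dt
    assert str(low) <= str(rnd)  # the shared time prefix drives the ordering
    print(low, rnd)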
diracx/db/sql/utils/types.py CHANGED
@@ -1,6 +1,8 @@
 from __future__ import annotations

+from datetime import datetime
 from functools import partial
+from zoneinfo import ZoneInfo

 import sqlalchemy.types as types
 from sqlalchemy import Column as RawColumn
@@ -20,11 +22,8 @@ def EnumColumn(name, enum_type, **kwargs):  # noqa: N802
 class EnumBackedBool(types.TypeDecorator):
     """Maps a ``EnumBackedBool()`` column to True/False in Python."""

-    impl = types.Enum
-    cache_ok: bool = True
-
-    def __init__(self) -> None:
-        super().__init__("True", "False")
+    impl = types.Enum("True", "False", name="enum_backed_bool")
+    cache_ok = True

     def process_bind_param(self, value, dialect) -> str:
         if value is True:
@@ -41,3 +40,98 @@ class EnumBackedBool(types.TypeDecorator):
             return False
         else:
             raise NotImplementedError(f"Unknown {value=}")
+
+
+class SmarterDateTime(types.TypeDecorator):
+    """A DateTime type that also accepts ISO8601 strings.
+
+    Takes into account converting timezone-aware datetime objects into
+    naive form and back when needed.
+    """
+
+    impl = DateTime()
+    cache_ok = True
+
+    def __init__(
+        self,
+        stored_tz: ZoneInfo | None = ZoneInfo("UTC"),
+        returned_tz: ZoneInfo = ZoneInfo("UTC"),
+        stored_naive_sqlite=True,
+        stored_naive_mysql=True,
+        stored_naive_postgres=False,  # Forces timezone-awareness
+    ):
+        self._stored_naive_dialect = {
+            "sqlite": stored_naive_sqlite,
+            "mysql": stored_naive_mysql,
+            "postgres": stored_naive_postgres,
+        }
+        self._stored_tz: ZoneInfo | None = stored_tz  # None = local timezone
+        self._returned_tz: ZoneInfo = returned_tz
+
+    def _stored_naive(self, dialect):
+        if dialect.name not in self._stored_naive_dialect:
+            raise NotImplementedError(dialect.name)
+        return self._stored_naive_dialect.get(dialect.name)
+
+    def process_bind_param(self, value, dialect):
+        if value is None:
+            return None
+
+        if isinstance(value, str):
+            try:
+                value: datetime = datetime.fromisoformat(value)
+            except ValueError as err:
+                raise ValueError(f"Unable to parse datetime string: {value}") from err
+
+        if not isinstance(value, datetime):
+            raise ValueError(f"Expected datetime or ISO8601 string, but got {value!r}")
+
+        if not value.tzinfo:
+            raise ValueError(
+                f"Provided timestamp {value=} has no tzinfo -"
+                " this is problematic and may cause inconsistencies in stored timestamps.\n"
+                " Please always work with tz-aware datetimes / attach tzinfo to your datetime objects:"
+                " e.g. datetime.now(tz=timezone.utc) or use datetime_obj.astimezone() with no arguments if you need to "
+                "attach the local timezone to a local naive timestamp."
+            )
+
+        # Check whether we need to convert the timezone to match the self._stored_tz timezone:
+        if self._stored_naive(dialect):
+            # if self._stored_tz is None, we use our local/system timezone.
+            stored_tz = self._stored_tz
+
+            # astimezone converts to the stored timezone (local timezone if None)
+            # replace strips the TZ info --> naive datetime object
+            value = value.astimezone(tz=stored_tz).replace(tzinfo=None)
+
+        return value
+
+    def process_result_value(self, value, dialect):
+        if value is None:
+            return None
+        if not isinstance(value, datetime):
+            raise NotImplementedError(f"{value=} not a datetime object")
+
+        if self._stored_naive(dialect):
+            # Here we add back the tzinfo to the naive timestamp
+            # from the DB to make it aware again.
+            if value.tzinfo is None:
+                # we are definitely given a naive timestamp, so handle it.
+                # add back the timezone info if stored_tz is set
+                if self._stored_tz:
+                    value = value.replace(tzinfo=self._stored_tz)
+                else:
+                    # if stored as a local time, add back the system timezone info...
+                    value = value.astimezone()
+            else:
+                raise ValueError(
+                    f"stored_naive is True for {dialect.name=}, but the database engine returned "
+                    "a tz-aware datetime. You need to check that the SQLAlchemy model is consistent with the DB schema."
+                )

+        # finally, convert the datetime according to the "returned_tz"
+        value = value.astimezone(self._returned_tz)
+
+        # phew...
+        return value
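A minimal end-to-end sketch of how `SmarterDateTime` behaves on SQLite: the timestamp is stored naive in UTC and comes back tz-aware in UTC, while tz-naive input is rejected at bind time. The table and (synchronous) engine are illustrative, not part of diracx-db:

    from datetime import datetime, timezone

    from sqlalchemy import Column, Integer, create_engine, insert, select
    from sqlalchemy.orm import declarative_base

    from diracx.db.sql.utils.types import SmarterDateTime

    Base = declarative_base()


    class Event(Base):  # illustrative table, not part of diracx-db
        __tablename__ = "Event"
        id = Column(Integer, primary_key=True)
        created = Column(SmarterDateTime())


    engine = create_engine("sqlite://")  # synchronous driver, for brevity
    Base.metadata.create_all(engine)

    aware = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
    with engine.begin() as conn:
        conn.execute(insert(Event).values(created=aware))
        stored = conn.execute(select(Event.created)).scalar_one()

    # Stored as a naive UTC timestamp, returned tz-aware in UTC.
    assert stored == aware and stored.tzinfo is not None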
{diracx_db-0.0.1a46.dist-info → diracx_db-0.0.6.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: diracx-db
-Version: 0.0.1a46
+Version: 0.0.6
 Summary: TODO
 License: GPL-3.0-only
 Classifier: Intended Audience :: Science/Research
@@ -12,8 +12,11 @@ Requires-Python: >=3.11
 Requires-Dist: diracx-core
 Requires-Dist: opensearch-py[async]
 Requires-Dist: pydantic>=2.10
+Requires-Dist: python-dateutil
 Requires-Dist: sqlalchemy[aiomysql,aiosqlite]>=2
 Requires-Dist: uuid-utils
 Provides-Extra: testing
 Requires-Dist: diracx-testing; extra == 'testing'
 Requires-Dist: freezegun; extra == 'testing'
+Provides-Extra: types
+Requires-Dist: types-python-dateutil; extra == 'types'
{diracx_db-0.0.1a46.dist-info → diracx_db-0.0.6.dist-info}/RECORD RENAMED
@@ -1,37 +1,37 @@
 diracx/db/__init__.py,sha256=2oeUeVwZq53bo_ZOflEYZsBn7tcR5Tzb2AIu0TAWELM,109
-diracx/db/__main__.py,sha256=3yaUP1ig-yaPSQM4wy6CtSXXHivQg-hIz2FeBt7joBc,1714
+diracx/db/__main__.py,sha256=6YlmpiU1cLLHjKLy1DfdEOQUyvSla-MbJsJ7aQwAOVs,1757
 diracx/db/exceptions.py,sha256=1nn-SZLG-nQwkxbvHjZqXhE5ouzWj1f3qhSda2B4ZEg,83
 diracx/db/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diracx/db/os/__init__.py,sha256=IZr6z6SefrRvuC8sTC4RmB3_wwOyEt1GzpDuwSMH8O4,112
-diracx/db/os/job_parameters.py,sha256=loAc-bo3u-RMAp_H1g8VRt8T-rCCsXp_d9aCvg5OS-A,1225
+diracx/db/os/job_parameters.py,sha256=3w_CeA2z-cY5pWwXkGu-Fod27FobbUXuwVKK-jN037U,1479
 diracx/db/os/utils.py,sha256=V4T-taos64SFNcorfIr7mq5l5y88K6TzyCj1YqWk8VI,11562
 diracx/db/sql/__init__.py,sha256=JYu0b0IVhoXy3lX2m2r2dmAjsRS7IbECBUMEDvX0Te4,391
 diracx/db/sql/auth/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diracx/db/sql/auth/db.py,sha256=QJtBqMrhOf97UvMG0WpyjsgIRiu19v04FoDzXAyXtT0,8952
-diracx/db/sql/auth/schema.py,sha256=x2PEbmM_bNPdZUN5BMGMrdSmX8zkDeJ3P9XfhLBGBTs,3173
+diracx/db/sql/auth/db.py,sha256=F9s05K-9C6kL2nUZY7P8zD79fiuo2biREMhfI7oCjh4,11979
+diracx/db/sql/auth/schema.py,sha256=9fUV7taDPnoAcoiwRAmQraOmF2Ytoizjs2TFvN7zsVs,3132
 diracx/db/sql/dummy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diracx/db/sql/dummy/db.py,sha256=IW4FzG7ERKbhZvC32KL7Rodu2u-zKAf8BryO4VAdJew,1650
+diracx/db/sql/dummy/db.py,sha256=MKSUSJI1BlRgK08tjCfkCkOz02asvJAeBw60pAdiGV8,1212
 diracx/db/sql/dummy/schema.py,sha256=9zI53pKlzc6qBezsyjkatOQrNZdGCjwgjQ8Iz_pyAXs,789
 diracx/db/sql/job/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-diracx/db/sql/job/db.py,sha256=TnEc0fckiuMJAZg2v1_Pbwfn7kDPDam6TXp9ySuiddk,11910
-diracx/db/sql/job/schema.py,sha256=eFgZshe6NEzOM2qI0HI9Y3abrqDMoQIwa9L0vZugHcU,5431
+diracx/db/sql/job/db.py,sha256=bX-4OMyW4h9tqeTE3OvonxTXlL6j_Qvv9uEtK5SthN8,10120
+diracx/db/sql/job/schema.py,sha256=fJdmiLp6psdAjo_CoBfSAGSYk2NJkSBwvik9tznESD0,5740
 diracx/db/sql/job_logging/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diracx/db/sql/job_logging/db.py,sha256=hyklARuEj3R1sSJ8UaObRprmsRx7RjbKAcbfgT9BwRg,5496
 diracx/db/sql/job_logging/schema.py,sha256=k6uBw-RHAcJ5GEleNpiWoXEJBhCiNG-y4xAgBKHZjjM,2524
 diracx/db/sql/pilot_agents/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diracx/db/sql/pilot_agents/db.py,sha256=6CQ0QGV4NhsGKVCygEtE4kmIjT89xJwrIMuYZTslWFE,1231
-diracx/db/sql/pilot_agents/schema.py,sha256=KeWnFSpYOTrT3-_rOCFjbjNnPNXKnUZiJVsu4vv5U2U,2149
+diracx/db/sql/pilot_agents/schema.py,sha256=BTFLuiwcxAvAtTvTP9C7DbGtXoM-IHVDG9k7HMx62AA,2211
 diracx/db/sql/sandbox_metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diracx/db/sql/sandbox_metadata/db.py,sha256=FtyPx6GAGJAH-lmuw8PQj6_KGHG6t3AC3-E9uWf-JNs,10236
 diracx/db/sql/sandbox_metadata/schema.py,sha256=V5gV2PHwzTbBz_th9ribLfE7Lqk8YGemDmvqq4jWQJ4,1530
 diracx/db/sql/task_queue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 diracx/db/sql/task_queue/db.py,sha256=2qul1D2tX2uCI92N591WK5xWHakG0pNibzDwKQ7W-I8,6246
 diracx/db/sql/task_queue/schema.py,sha256=5efAgvNYRkLlaJ2NzRInRfmVa3tyIzQu2l0oRPy4Kzw,3258
-diracx/db/sql/utils/__init__.py,sha256=QkvpqBuIAgkAOywAssYzdxSzUQVZlSUumK7mPxotXfM,547
-diracx/db/sql/utils/base.py,sha256=HYQuX16mgg9LAMtAEmbTmJFIN0OSMe1Hcb57dtl7LCc,12367
+diracx/db/sql/utils/__init__.py,sha256=k1DI4Idlqv36pXn2BhQysb947Peio9DnYaePslkTpUQ,685
+diracx/db/sql/utils/base.py,sha256=DqW-JYgjqvqkwLFqou5uzg73lZ83C0jHCgkt9qR1NTg,17255
 diracx/db/sql/utils/functions.py,sha256=_E4tc9Gti6LuSh7QEyoqPJSvCuByVqvRenOXCzxsulE,4014
-diracx/db/sql/utils/types.py,sha256=yU-tXsu6hFGPsr9ba1n3ZjGPnHQI_06lbpkTeDCWJtg,1287
-diracx_db-0.0.1a46.dist-info/METADATA,sha256=hu6BtBuYz30tuMO8Z40gyZhSf53QzqTKO45ryw6tTL4,675
-diracx_db-0.0.1a46.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-diracx_db-0.0.1a46.dist-info/entry_points.txt,sha256=UPqhLvb9gui0kOyWeI_edtefcrHToZmQt1p76vIwujo,317
-diracx_db-0.0.1a46.dist-info/RECORD,,
+diracx/db/sql/utils/types.py,sha256=KNZWJfpvHTjfIPg6Nn7zY-rS0q3ybnirHcTcLAYSYbE,5118
+diracx_db-0.0.6.dist-info/METADATA,sha256=Lu8x2pR3BfnKGgHYz4w5Z4CTCf7tPi9p9tlldVFJiLo,780
+diracx_db-0.0.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+diracx_db-0.0.6.dist-info/entry_points.txt,sha256=UPqhLvb9gui0kOyWeI_edtefcrHToZmQt1p76vIwujo,317
+diracx_db-0.0.6.dist-info/RECORD,,
{diracx_db-0.0.1a46.dist-info → diracx_db-0.0.6.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.27.0
+Generator: hatchling 1.28.0
 Root-Is-Purelib: true
 Tag: py3-none-any