diracx-db 0.0.1a21__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of diracx-db might be problematic. Click here for more details.

diracx/db/__main__.py CHANGED
@@ -31,7 +31,6 @@ async def init_sql():
31
31
  from diracx.db.sql.utils import BaseSQLDB
32
32
 
33
33
  for db_name, db_url in BaseSQLDB.available_urls().items():
34
-
35
34
  logger.info("Initialising %s", db_name)
36
35
  db = BaseSQLDB.available_implementations(db_name)[0](db_url)
37
36
  async with db.engine_context():
@@ -40,6 +39,7 @@ async def init_sql():
40
39
  if db._db_url.startswith("sqlite"):
41
40
  await conn.exec_driver_sql("PRAGMA foreign_keys=ON")
42
41
  await conn.run_sync(db.metadata.create_all)
42
+ await db.post_create(conn)
43
43
 
44
44
 
45
45
  async def init_os():
diracx/db/exceptions.py CHANGED
@@ -1,2 +1,5 @@
1
- class DBUnavailable(Exception):
1
+ from __future__ import annotations
2
+
3
+
4
+ class DBUnavailableError(Exception):
2
5
  pass
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from datetime import UTC, datetime
4
+
3
5
  from diracx.db.os.utils import BaseOSDB
4
6
 
5
7
 
@@ -7,19 +9,35 @@ class JobParametersDB(BaseOSDB):
7
9
  fields = {
8
10
  "JobID": {"type": "long"},
9
11
  "timestamp": {"type": "date"},
12
+ "PilotAgent": {"type": "keyword"},
13
+ "Pilot_Reference": {"type": "keyword"},
14
+ "JobGroup": {"type": "keyword"},
10
15
  "CPUNormalizationFactor": {"type": "long"},
11
16
  "NormCPUTime(s)": {"type": "long"},
12
- "Memory(kB)": {"type": "long"},
17
+ "Memory(MB)": {"type": "long"},
18
+ "LocalAccount": {"type": "keyword"},
13
19
  "TotalCPUTime(s)": {"type": "long"},
14
- "MemoryUsed(kb)": {"type": "long"},
15
- "HostName": {"type": "keyword"},
20
+ "PayloadPID": {"type": "long"},
21
+ "HostName": {"type": "text"},
16
22
  "GridCE": {"type": "keyword"},
23
+ "CEQueue": {"type": "keyword"},
24
+ "BatchSystem": {"type": "keyword"},
17
25
  "ModelName": {"type": "keyword"},
18
26
  "Status": {"type": "keyword"},
19
27
  "JobType": {"type": "keyword"},
20
28
  }
21
- index_prefix = "mysetup_elasticjobparameters_index_"
29
+ # TODO: Does this need to be configurable?
30
+ index_prefix = "job_parameters"
31
+
32
+ def index_name(self, vo, doc_id: int) -> str:
33
+ split = int(int(doc_id) // 1e6)
34
+ # The index name must be lowercase or opensearchpy will throw.
35
+ return f"{self.index_prefix}_{vo.lower()}_{split}m"
22
36
 
23
- def index_name(self, doc_id: int) -> str:
24
- # TODO: Remove setup and replace "123.0m" with "120m"?
25
- return f"{self.index_prefix}_{doc_id // 1e6:.1f}m"
37
+ def upsert(self, vo, doc_id, document):
38
+ document = {
39
+ "JobID": doc_id,
40
+ "timestamp": int(datetime.now(tz=UTC).timestamp() * 1000),
41
+ **document,
42
+ }
43
+ return super().upsert(vo, doc_id, document)
diracx/db/os/utils.py CHANGED
@@ -16,7 +16,7 @@ from opensearchpy import AsyncOpenSearch
16
16
 
17
17
  from diracx.core.exceptions import InvalidQueryError
18
18
  from diracx.core.extensions import select_from_extension
19
- from diracx.db.exceptions import DBUnavailable
19
+ from diracx.db.exceptions import DBUnavailableError
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
22
 
@@ -25,7 +25,7 @@ class OpenSearchDBError(Exception):
25
25
  pass
26
26
 
27
27
 
28
- class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
28
+ class OpenSearchDBUnavailableError(DBUnavailableError, OpenSearchDBError):
29
29
  pass
30
30
 
31
31
 
@@ -38,7 +38,7 @@ class BaseOSDB(metaclass=ABCMeta):
38
38
 
39
39
  The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
40
40
  This method returns a dictionary of database names to connection parameters.
41
- The available databases are determined by the `diracx.db.os` entrypoint in
41
+ The available databases are determined by the `diracx.dbs.os` entrypoint in
42
42
  the `pyproject.toml` file and the connection parameters are taken from the
43
43
  environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
44
44
 
@@ -77,7 +77,7 @@ class BaseOSDB(metaclass=ABCMeta):
77
77
  index_prefix: str
78
78
 
79
79
  @abstractmethod
80
- def index_name(self, doc_id: int) -> str: ...
80
+ def index_name(self, vo: str, doc_id: int) -> str: ...
81
81
 
82
82
  def __init__(self, connection_kwargs: dict[str, Any]) -> None:
83
83
  self._client: AsyncOpenSearch | None = None
@@ -92,7 +92,9 @@ class BaseOSDB(metaclass=ABCMeta):
92
92
  """Return the available implementations of the DB in reverse priority order."""
93
93
  db_classes: list[type[BaseOSDB]] = [
94
94
  entry_point.load()
95
- for entry_point in select_from_extension(group="diracx.db.os", name=db_name)
95
+ for entry_point in select_from_extension(
96
+ group="diracx.dbs.os", name=db_name
97
+ )
96
98
  ]
97
99
  if not db_classes:
98
100
  raise NotImplementedError(f"Could not find any matches for {db_name=}")
@@ -106,7 +108,7 @@ class BaseOSDB(metaclass=ABCMeta):
106
108
  prefixed with ``DIRACX_OS_DB_{DB_NAME}``.
107
109
  """
108
110
  conn_kwargs: dict[str, dict[str, Any]] = {}
109
- for entry_point in select_from_extension(group="diracx.db.os"):
111
+ for entry_point in select_from_extension(group="diracx.dbs.os"):
110
112
  db_name = entry_point.name
111
113
  var_name = f"DIRACX_OS_DB_{entry_point.name.upper()}"
112
114
  if var_name in os.environ:
@@ -152,7 +154,7 @@ class BaseOSDB(metaclass=ABCMeta):
152
154
  be ran at every query.
153
155
  """
154
156
  if not await self.client.ping():
155
- raise OpenSearchDBUnavailable(
157
+ raise OpenSearchDBUnavailableError(
156
158
  f"Failed to connect to {self.__class__.__qualname__}"
157
159
  )
158
160
 
@@ -180,15 +182,20 @@ class BaseOSDB(metaclass=ABCMeta):
180
182
  )
181
183
  assert result["acknowledged"]
182
184
 
183
- async def upsert(self, doc_id, document) -> None:
184
- # TODO: Implement properly
185
+ async def upsert(self, vo: str, doc_id: int, document: Any) -> None:
186
+ index_name = self.index_name(vo, doc_id)
185
187
  response = await self.client.update(
186
- index=self.index_name(doc_id),
188
+ index=index_name,
187
189
  id=doc_id,
188
190
  body={"doc": document, "doc_as_upsert": True},
189
191
  params=dict(retry_on_conflict=10),
190
192
  )
191
- print(f"{response=}")
193
+ logger.debug(
194
+ "Upserted document %s in index %s with response: %s",
195
+ doc_id,
196
+ index_name,
197
+ response,
198
+ )
192
199
 
193
200
  async def search(
194
201
  self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None
diracx/db/sql/auth/db.py CHANGED
@@ -1,19 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- import hashlib
3
+ import logging
4
4
  import secrets
5
- from datetime import datetime
6
- from uuid import uuid4
5
+ from datetime import UTC, datetime
6
+ from itertools import pairwise
7
7
 
8
- from sqlalchemy import insert, select, update
8
+ from dateutil.rrule import MONTHLY, rrule
9
+ from sqlalchemy import insert, select, text, update
9
10
  from sqlalchemy.exc import IntegrityError, NoResultFound
11
+ from sqlalchemy.ext.asyncio import AsyncConnection
12
+ from uuid_utils import UUID, uuid7
10
13
 
11
14
  from diracx.core.exceptions import (
12
15
  AuthorizationError,
13
- ExpiredFlowError,
14
- PendingAuthorizationError,
16
+ TokenNotFoundError,
15
17
  )
16
- from diracx.db.sql.utils import BaseSQLDB, substract_date
18
+ from diracx.db.sql.utils import BaseSQLDB, hash, substract_date, uuid7_from_datetime
17
19
 
18
20
  from .schema import (
19
21
  AuthorizationFlows,
@@ -28,10 +30,72 @@ from .schema import Base as AuthDBBase
28
30
  USER_CODE_ALPHABET = "BCDFGHJKLMNPQRSTVWXZ"
29
31
  MAX_RETRY = 5
30
32
 
33
+ logger = logging.getLogger(__name__)
34
+
31
35
 
32
36
  class AuthDB(BaseSQLDB):
33
37
  metadata = AuthDBBase.metadata
34
38
 
39
+ @classmethod
40
+ async def post_create(cls, conn: AsyncConnection) -> None:
41
+ """Create partitions if it is a MySQL DB and it does not have
42
+ it yet and the table does not have any data yet.
43
+ We do this as a post_create step as sqlalchemy does not support
44
+ partition so well.
45
+ """
46
+ if conn.dialect.name == "mysql":
47
+ check_partition_query = text(
48
+ "SELECT PARTITION_NAME FROM information_schema.partitions "
49
+ "WHERE TABLE_NAME = 'RefreshTokens' AND PARTITION_NAME is not NULL"
50
+ )
51
+ partition_names = (await conn.execute(check_partition_query)).all()
52
+
53
+ if not partition_names:
54
+ # Create a monthly partition from today until 2 years
55
+ # The partition are named p_<year>_<month>
56
+ start_date = datetime.now(tz=UTC).replace(
57
+ day=1, hour=0, minute=0, second=0, microsecond=0
58
+ )
59
+ end_date = start_date.replace(year=start_date.year + 2)
60
+
61
+ dates = [
62
+ dt for dt in rrule(MONTHLY, dtstart=start_date, until=end_date)
63
+ ]
64
+
65
+ partition_list = []
66
+ for name, limit in pairwise(dates):
67
+ partition_list.append(
68
+ f"PARTITION p_{name.year}_{name.month} "
69
+ f"VALUES LESS THAN ('{str(uuid7_from_datetime(limit, randomize=False)).replace('-', '')}')"
70
+ )
71
+ partition_list.append("PARTITION p_future VALUES LESS THAN (MAXVALUE)")
72
+
73
+ alter_query = text(
74
+ f"ALTER TABLE RefreshTokens PARTITION BY RANGE COLUMNS (JTI) ({','.join(partition_list)})"
75
+ )
76
+
77
+ check_table_empty_query = text("SELECT * FROM RefreshTokens LIMIT 1")
78
+ refresh_table_content = (
79
+ await conn.execute(check_table_empty_query)
80
+ ).all()
81
+ if refresh_table_content:
82
+ logger.warning(
83
+ "RefreshTokens table not empty. Run the following query yourself"
84
+ )
85
+ logger.warning(alter_query)
86
+ return
87
+
88
+ await conn.execute(alter_query)
89
+
90
+ partition_names = (
91
+ await conn.execute(
92
+ check_partition_query, {"table_name": "RefreshTokens"}
93
+ )
94
+ ).all()
95
+ assert partition_names, (
96
+ f"There should be partitions now {partition_names}"
97
+ )
98
+
35
99
  async def device_flow_validate_user_code(
36
100
  self, user_code: str, max_validity: int
37
101
  ) -> str:
@@ -50,44 +114,25 @@ class AuthDB(BaseSQLDB):
50
114
 
51
115
  return (await self.conn.execute(stmt)).scalar_one()
52
116
 
53
- async def get_device_flow(self, device_code: str, max_validity: int):
117
+ async def get_device_flow(self, device_code: str):
54
118
  """:raises: NoResultFound"""
55
119
  # The with_for_update
56
120
  # prevents that the token is retrieved
57
121
  # multiple time concurrently
58
- stmt = select(
59
- DeviceFlows,
60
- (DeviceFlows.creation_time < substract_date(seconds=max_validity)).label(
61
- "is_expired"
62
- ),
63
- ).with_for_update()
122
+ stmt = select(DeviceFlows).with_for_update()
64
123
  stmt = stmt.where(
65
- DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
124
+ DeviceFlows.device_code == hash(device_code),
66
125
  )
67
- res = dict((await self.conn.execute(stmt)).one()._mapping)
68
-
69
- if res["is_expired"]:
70
- raise ExpiredFlowError()
71
-
72
- if res["status"] == FlowStatus.READY:
73
- # Update the status to Done before returning
74
- await self.conn.execute(
75
- update(DeviceFlows)
76
- .where(
77
- DeviceFlows.device_code
78
- == hashlib.sha256(device_code.encode()).hexdigest()
79
- )
80
- .values(status=FlowStatus.DONE)
81
- )
82
- return res
83
-
84
- if res["status"] == FlowStatus.DONE:
85
- raise AuthorizationError("Code was already used")
126
+ return dict((await self.conn.execute(stmt)).one()._mapping)
86
127
 
87
- if res["status"] == FlowStatus.PENDING:
88
- raise PendingAuthorizationError()
89
-
90
- raise AuthorizationError("Bad state in device flow")
128
+ async def update_device_flow_status(
129
+ self, device_code: str, status: FlowStatus
130
+ ) -> None:
131
+ stmt = update(DeviceFlows).where(
132
+ DeviceFlows.device_code == hash(device_code),
133
+ )
134
+ stmt = stmt.values(status=status)
135
+ await self.conn.execute(stmt)
91
136
 
92
137
  async def device_flow_insert_id_token(
93
138
  self, user_code: str, id_token: dict[str, str], max_validity: int
@@ -121,7 +166,7 @@ class AuthDB(BaseSQLDB):
121
166
  device_code = secrets.token_urlsafe()
122
167
 
123
168
  # Hash the the device_code to avoid leaking information
124
- hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()
169
+ hashed_device_code = hash(device_code)
125
170
 
126
171
  stmt = insert(DeviceFlows).values(
127
172
  client_id=client_id,
@@ -133,6 +178,10 @@ class AuthDB(BaseSQLDB):
133
178
  await self.conn.execute(stmt)
134
179
 
135
180
  except IntegrityError:
181
+ logger.warning(
182
+ "Device flow code collision detected, retrying (user_code=%s)",
183
+ user_code,
184
+ )
136
185
  continue
137
186
 
138
187
  return user_code, device_code
@@ -148,7 +197,7 @@ class AuthDB(BaseSQLDB):
148
197
  code_challenge_method: str,
149
198
  redirect_uri: str,
150
199
  ) -> str:
151
- uuid = str(uuid4())
200
+ uuid = str(uuid7())
152
201
 
153
202
  stmt = insert(AuthorizationFlows).values(
154
203
  uuid=uuid,
@@ -171,7 +220,7 @@ class AuthDB(BaseSQLDB):
171
220
  """
172
221
  # Hash the code to avoid leaking information
173
222
  code = secrets.token_urlsafe()
174
- hashed_code = hashlib.sha256(code.encode()).hexdigest()
223
+ hashed_code = hash(code)
175
224
 
176
225
  stmt = update(AuthorizationFlows)
177
226
 
@@ -190,10 +239,11 @@ class AuthDB(BaseSQLDB):
190
239
  stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
191
240
  stmt = stmt.where(AuthorizationFlows.uuid == uuid)
192
241
  row = (await self.conn.execute(stmt)).one()
193
- return code, row.redirect_uri
242
+ return code, row.RedirectURI
194
243
 
195
244
  async def get_authorization_flow(self, code: str, max_validity: int):
196
- hashed_code = hashlib.sha256(code.encode()).hexdigest()
245
+ """Get the authorization flow details based on the code."""
246
+ hashed_code = hash(code)
197
247
  # The with_for_update
198
248
  # prevents that the token is retrieved
199
249
  # multiple time concurrently
@@ -203,54 +253,39 @@ class AuthDB(BaseSQLDB):
203
253
  AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
204
254
  )
205
255
 
206
- res = dict((await self.conn.execute(stmt)).one()._mapping)
207
-
208
- if res["status"] == FlowStatus.READY:
209
- # Update the status to Done before returning
210
- await self.conn.execute(
211
- update(AuthorizationFlows)
212
- .where(AuthorizationFlows.code == hashed_code)
213
- .values(status=FlowStatus.DONE)
214
- )
256
+ return dict((await self.conn.execute(stmt)).one()._mapping)
215
257
 
216
- return res
217
-
218
- if res["status"] == FlowStatus.DONE:
219
- raise AuthorizationError("Code was already used")
220
-
221
- raise AuthorizationError("Bad state in authorization flow")
258
+ async def update_authorization_flow_status(
259
+ self, code: str, status: FlowStatus
260
+ ) -> None:
261
+ """Update the status of an authorization flow based on the code."""
262
+ hashed_code = hash(code)
263
+ await self.conn.execute(
264
+ update(AuthorizationFlows)
265
+ .where(AuthorizationFlows.code == hashed_code)
266
+ .values(status=status)
267
+ )
222
268
 
223
269
  async def insert_refresh_token(
224
270
  self,
271
+ jti: UUID,
225
272
  subject: str,
226
- preferred_username: str,
227
273
  scope: str,
228
- ) -> tuple[str, datetime]:
274
+ ) -> None:
229
275
  """Insert a refresh token in the DB as well as user attributes
230
276
  required to generate access tokens.
231
277
  """
232
- # Generate a JWT ID
233
- jti = str(uuid4())
234
-
235
278
  # Insert values into the DB
236
279
  stmt = insert(RefreshTokens).values(
237
- jti=jti,
280
+ jti=str(jti),
238
281
  sub=subject,
239
- preferred_username=preferred_username,
240
282
  scope=scope,
241
283
  )
242
284
  await self.conn.execute(stmt)
243
285
 
244
- # Get the creation time of the new tuple: generated by the insert operation
245
- stmt = select(RefreshTokens.creation_time)
246
- stmt = stmt.where(RefreshTokens.jti == jti)
247
- row = (await self.conn.execute(stmt)).one()
248
-
249
- # Return the JWT ID and the creation time
250
- return jti, row.creation_time
251
-
252
- async def get_refresh_token(self, jti: str) -> dict:
286
+ async def get_refresh_token(self, jti: UUID) -> dict:
253
287
  """Get refresh token details bound to a given JWT ID."""
288
+ jti = str(jti)
254
289
  # The with_for_update
255
290
  # prevents that the token is retrieved
256
291
  # multiple time concurrently
@@ -260,8 +295,8 @@ class AuthDB(BaseSQLDB):
260
295
  )
261
296
  try:
262
297
  res = dict((await self.conn.execute(stmt)).one()._mapping)
263
- except NoResultFound:
264
- return {}
298
+ except NoResultFound as e:
299
+ raise TokenNotFoundError(jti) from e
265
300
 
266
301
  return res
267
302
 
@@ -285,11 +320,11 @@ class AuthDB(BaseSQLDB):
285
320
 
286
321
  return refresh_tokens
287
322
 
288
- async def revoke_refresh_token(self, jti: str):
323
+ async def revoke_refresh_token(self, jti: UUID):
289
324
  """Revoke a token given by its JWT ID."""
290
325
  await self.conn.execute(
291
326
  update(RefreshTokens)
292
- .where(RefreshTokens.jti == jti)
327
+ .where(RefreshTokens.jti == str(jti))
293
328
  .values(status=RefreshTokenStatus.REVOKED)
294
329
  )
295
330
 
@@ -1,13 +1,21 @@
1
+ from __future__ import annotations
2
+
1
3
  from enum import Enum, auto
2
4
 
3
5
  from sqlalchemy import (
4
6
  JSON,
7
+ Index,
5
8
  String,
6
9
  Uuid,
7
10
  )
8
11
  from sqlalchemy.orm import declarative_base
9
12
 
10
- from diracx.db.sql.utils import Column, DateNowColumn, EnumColumn, NullColumn
13
+ from diracx.db.sql.utils import (
14
+ Column,
15
+ DateNowColumn,
16
+ EnumColumn,
17
+ NullColumn,
18
+ )
11
19
 
12
20
  USER_CODE_LENGTH = 8
13
21
 
@@ -39,27 +47,27 @@ class FlowStatus(Enum):
39
47
 
40
48
  class DeviceFlows(Base):
41
49
  __tablename__ = "DeviceFlows"
42
- user_code = Column(String(USER_CODE_LENGTH), primary_key=True)
43
- status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
44
- creation_time = DateNowColumn()
45
- client_id = Column(String(255))
46
- scope = Column(String(1024))
47
- device_code = Column(String(128), unique=True) # Should be a hash
48
- id_token = NullColumn(JSON())
50
+ user_code = Column("UserCode", String(USER_CODE_LENGTH), primary_key=True)
51
+ status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
52
+ creation_time = DateNowColumn("CreationTime")
53
+ client_id = Column("ClientID", String(255))
54
+ scope = Column("Scope", String(1024))
55
+ device_code = Column("DeviceCode", String(128), unique=True) # Should be a hash
56
+ id_token = NullColumn("IDToken", JSON())
49
57
 
50
58
 
51
59
  class AuthorizationFlows(Base):
52
60
  __tablename__ = "AuthorizationFlows"
53
- uuid = Column(Uuid(as_uuid=False), primary_key=True)
54
- status = EnumColumn(FlowStatus, server_default=FlowStatus.PENDING.name)
55
- client_id = Column(String(255))
56
- creation_time = DateNowColumn()
57
- scope = Column(String(1024))
58
- code_challenge = Column(String(255))
59
- code_challenge_method = Column(String(8))
60
- redirect_uri = Column(String(255))
61
- code = NullColumn(String(255)) # Should be a hash
62
- id_token = NullColumn(JSON())
61
+ uuid = Column("UUID", Uuid(as_uuid=False), primary_key=True)
62
+ status = EnumColumn("Status", FlowStatus, server_default=FlowStatus.PENDING.name)
63
+ client_id = Column("ClientID", String(255))
64
+ creation_time = DateNowColumn("CreationTime")
65
+ scope = Column("Scope", String(1024))
66
+ code_challenge = Column("CodeChallenge", String(255))
67
+ code_challenge_method = Column("CodeChallengeMethod", String(8))
68
+ redirect_uri = Column("RedirectURI", String(255))
69
+ code = NullColumn("Code", String(255)) # Should be a hash
70
+ id_token = NullColumn("IDToken", JSON())
63
71
 
64
72
 
65
73
  class RefreshTokenStatus(Enum):
@@ -85,13 +93,13 @@ class RefreshTokens(Base):
85
93
 
86
94
  __tablename__ = "RefreshTokens"
87
95
  # Refresh token attributes
88
- jti = Column(Uuid(as_uuid=False), primary_key=True)
96
+ jti = Column("JTI", Uuid(as_uuid=False), primary_key=True)
89
97
  status = EnumColumn(
90
- RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
98
+ "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
91
99
  )
92
- creation_time = DateNowColumn()
93
- scope = Column(String(1024))
100
+ scope = Column("Scope", String(1024))
94
101
 
95
102
  # User attributes bound to the refresh token
96
- sub = Column(String(1024))
97
- preferred_username = Column(String(255))
103
+ sub = Column("Sub", String(256), index=True)
104
+
105
+ __table_args__ = (Index("index_status_sub", status, sub),)
diracx/db/sql/dummy/db.py CHANGED
@@ -1,10 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from uuid import UUID
3
+ from sqlalchemy import insert
4
+ from uuid_utils import UUID
4
5
 
5
- from sqlalchemy import func, insert, select
6
-
7
- from diracx.db.sql.utils import BaseSQLDB, apply_search_filters
6
+ from diracx.db.sql.utils import BaseSQLDB
8
7
 
9
8
  from .schema import Base as DummyDBBase
10
9
  from .schema import Cars, Owners
@@ -23,18 +22,7 @@ class DummyDB(BaseSQLDB):
23
22
  metadata = DummyDBBase.metadata
24
23
 
25
24
  async def summary(self, group_by, search) -> list[dict[str, str | int]]:
26
- columns = [Cars.__table__.columns[x] for x in group_by]
27
-
28
- stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
29
- stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
30
- stmt = stmt.group_by(*columns)
31
-
32
- # Execute the query
33
- return [
34
- dict(row._mapping)
35
- async for row in (await self.conn.stream(stmt))
36
- if row.count > 0 # type: ignore
37
- ]
25
+ return await self._summary(Cars, group_by, search)
38
26
 
39
27
  async def insert_owner(self, name: str) -> int:
40
28
  stmt = insert(Owners).values(name=name)
@@ -44,7 +32,7 @@ class DummyDB(BaseSQLDB):
44
32
 
45
33
  async def insert_car(self, license_plate: UUID, model: str, owner_id: int) -> int:
46
34
  stmt = insert(Cars).values(
47
- licensePlate=license_plate, model=model, ownerID=owner_id
35
+ license_plate=license_plate, model=model, owner_id=owner_id
48
36
  )
49
37
 
50
38
  result = await self.conn.execute(stmt)
@@ -1,5 +1,7 @@
1
1
  # The utils class define some boilerplate types that should be used
2
2
  # in place of the SQLAlchemy one. Have a look at them
3
+ from __future__ import annotations
4
+
3
5
  from sqlalchemy import ForeignKey, Integer, String, Uuid
4
6
  from sqlalchemy.orm import declarative_base
5
7
 
@@ -10,13 +12,13 @@ Base = declarative_base()
10
12
 
11
13
  class Owners(Base):
12
14
  __tablename__ = "Owners"
13
- ownerID = Column(Integer, primary_key=True, autoincrement=True)
14
- creation_time = DateNowColumn()
15
- name = Column(String(255))
15
+ owner_id = Column("OwnerID", Integer, primary_key=True, autoincrement=True)
16
+ creation_time = DateNowColumn("CreationTime")
17
+ name = Column("Name", String(255))
16
18
 
17
19
 
18
20
  class Cars(Base):
19
21
  __tablename__ = "Cars"
20
- licensePlate = Column(Uuid(), primary_key=True)
21
- model = Column(String(255))
22
- ownerID = Column(Integer, ForeignKey(Owners.ownerID))
22
+ license_plate = Column("LicensePlate", Uuid(), primary_key=True)
23
+ model = Column("Model", String(255))
24
+ owner_id = Column("OwnerID", Integer, ForeignKey(Owners.owner_id))