diracx-db 0.0.1a17__py3-none-any.whl → 0.0.1a19__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
diracx/db/__main__.py CHANGED
@@ -31,6 +31,7 @@ async def init_sql():
     from diracx.db.sql.utils import BaseSQLDB

     for db_name, db_url in BaseSQLDB.available_urls().items():
+
         logger.info("Initialising %s", db_name)
         db = BaseSQLDB.available_implementations(db_name)[0](db_url)
         async with db.engine_context():
diracx/db/os/utils.py CHANGED
@@ -7,9 +7,10 @@ import json
 import logging
 import os
 from abc import ABCMeta, abstractmethod
+from collections.abc import AsyncIterator
 from contextvars import ContextVar
 from datetime import datetime
-from typing import Any, AsyncIterator, Self
+from typing import Any, Self

 from opensearchpy import AsyncOpenSearch

@@ -29,6 +30,48 @@ class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):


 class BaseOSDB(metaclass=ABCMeta):
+    """This should be the base class of all the OpenSearch DiracX DBs.
+
+    The details covered here should be handled automatically by the service and
+    task machinery of DiracX; this documentation exists for informational
+    purposes.
+
+    The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
+    This method returns a dictionary of database names to connection parameters.
+    The available databases are determined by the `diracx.db.os` entrypoint in
+    the `pyproject.toml` file and the connection parameters are taken from the
+    environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
+
+    If extensions to DiracX are being used, there can be multiple implementations
+    of the same database. To list the available implementations use
+    `BaseOSDB.available_implementations(db_name)`. The first entry in this list
+    will be the preferred implementation and it can be initialized by calling
+    its `__init__` function with the connection parameters previously obtained
+    from `BaseOSDB.available_urls`.
+
+    To control the lifetime of the OpenSearch client, the `BaseOSDB.client_context`
+    asynchronous context manager should be entered. When inside this context
+    manager, the client can be accessed with `BaseOSDB.client`.
+
+    Upon entering, the DB class can then be used as an asynchronous context
+    manager to perform operations. Currently this context manager has no effect;
+    however, it must still be used as it may do so in the future. When inside this
+    context manager, the DB connection can be accessed with `BaseOSDB.client`.
+
+    For example:
+
+    ```python
+    db_name = ...
+    conn_params = BaseOSDB.available_urls()[db_name]
+    MyDBClass = BaseOSDB.available_implementations(db_name)[0]
+
+    db = MyDBClass(conn_params)
+    async with db.client_context():
+        async with db:
+            # Do something with the OpenSearch client
+    ```
+    """
+
     # TODO: Make metadata an abstract property
     fields: dict
     index_prefix: str
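
The lifecycle documented above can be exercised end to end as a small script. The sketch below is illustrative only: it assumes the `diracx.db.os` entrypoints and `DIRACX_OS_DB_{DB_NAME}` environment variables are configured and that the connection parameters point at a reachable OpenSearch instance; `check_os_dbs` is a hypothetical helper, not part of the package.

```python
import asyncio

from diracx.db.os.utils import BaseOSDB


async def check_os_dbs() -> None:
    # Discover the configured OpenSearch databases (name -> connection parameters).
    for db_name, conn_params in BaseOSDB.available_urls().items():
        # The first implementation in the list is the preferred one.
        db_class = BaseOSDB.available_implementations(db_name)[0]
        db = db_class(conn_params)

        # client_context() owns the lifetime of the AsyncOpenSearch client.
        async with db.client_context():
            # The per-request context manager is currently a no-op but must
            # still be entered before using the DB.
            async with db:
                await db.ping()


asyncio.run(check_os_dbs())
```
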
@@ -77,13 +120,15 @@ class BaseOSDB(metaclass=ABCMeta):
     @classmethod
     def session(cls) -> Self:
         """This is just a fake method such that the Dependency overwrite has
-        a hash to use"""
+        a hash to use.
+        """
         raise NotImplementedError("This should never be called")

     @property
     def client(self) -> AsyncOpenSearch:
         """Just a getter for _client, making sure we entered
-        the context manager"""
+        the context manager.
+        """
         if self._client is None:
             raise RuntimeError(f"{self.__class__} was used before entering")
         return self._client
@@ -91,17 +136,18 @@ class BaseOSDB(metaclass=ABCMeta):
     @contextlib.asynccontextmanager
     async def client_context(self) -> AsyncIterator[None]:
         """Context manager to manage the client lifecycle.
-        This is called when starting fastapi
+        This is called when starting fastapi.

         """
         assert self._client is None, "client_context cannot be nested"
         async with AsyncOpenSearch(**self._connection_kwargs) as self._client:
-            yield
-            self._client = None
+            try:
+                yield
+            finally:
+                self._client = None

     async def ping(self):
-        """
-        Check whether the connection to the DB is still working.
+        """Check whether the connection to the DB is still working.
         We could enable the ``pre_ping`` in the engine, but this would
         be run at every query.
         """
@@ -113,7 +159,7 @@ class BaseOSDB(metaclass=ABCMeta):
     async def __aenter__(self):
         """This is entered on every request.
         At the moment it does nothing, however, we keep it here
-        in case we ever want to use OpenSearch equivalent of a transaction
+        in case we ever want to use the OpenSearch equivalent of a transaction.
         """
         assert not self._conn.get(), "BaseOSDB context cannot be nested"
         assert self._client is not None, "client_context hasn't been entered"
@@ -122,9 +168,7 @@ class BaseOSDB(metaclass=ABCMeta):

     async def __aexit__(self, exc_type, exc, tb):
         assert self._conn.get()
-        self._client = None
         self._conn.set(False)
-        return

     async def create_index_template(self) -> None:
         template_body = {
@@ -237,6 +281,11 @@ def apply_search_filters(db_fields, search):
                     operator, field_name, field_type, {"keyword", "long", "date"}
                 )
                 result["must"].append({"terms": {field_name: query["values"]}})
+            case "not in":
+                require_type(
+                    operator, field_name, field_type, {"keyword", "long", "date"}
+                )
+                result["must_not"].append({"terms": {field_name: query["values"]}})
             # TODO: Implement like and ilike
             # If the pattern is a simple "col like 'abc%'", we can use a prefix query
             # Else we need to use a wildcard query where we replace % with * and _ with ?
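
The new `"not in"` branch mirrors the existing `"in"` operator: the values still become a `terms` query with the same keyword/long/date type restriction, but the clause is appended to the `must_not` list of the OpenSearch boolean query instead of `must`. As a rough illustration (the filter key names `"parameter"` and `"values"` are assumptions based on the surrounding code):

```python
# Hypothetical search filter excluding two statuses.
search = [
    {"parameter": "Status", "operator": "not in", "values": ["Done", "Failed"]},
]

# apply_search_filters(db_fields, search) is expected to contribute a clause
# of roughly this shape to the boolean query body it builds:
expected_clause = {"must_not": [{"terms": {"Status": ["Done", "Failed"]}}]}
```
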
diracx/db/sql/__init__.py CHANGED
@@ -1,7 +1,17 @@
 from __future__ import annotations

-__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")
+__all__ = (
+    "AuthDB",
+    "JobDB",
+    "JobLoggingDB",
+    "PilotAgentsDB",
+    "SandboxMetadataDB",
+    "TaskQueueDB",
+)

 from .auth.db import AuthDB
-from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB
+from .job.db import JobDB
+from .job_logging.db import JobLoggingDB
+from .pilot_agents.db import PilotAgentsDB
 from .sandbox_metadata.db import SandboxMetadataDB
+from .task_queue.db import TaskQueueDB
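
This restructuring splits the former `jobs.db` module into per-database packages (`job`, `job_logging`, `task_queue`) and adds the new `PilotAgentsDB`. Because the names are re-exported from `diracx.db.sql`, top-level imports should be unaffected; only imports that reach into the old module path need updating. A sketch, assuming nothing beyond the names visible in this diff:

```python
# Stable, re-exported import path (unchanged for callers).
from diracx.db.sql import JobDB, JobLoggingDB, PilotAgentsDB, TaskQueueDB

# Deep imports now follow the new per-database modules, for example
# `from diracx.db.sql.job.db import JobDB` instead of the old
# `from diracx.db.sql.jobs.db import JobDB`.
```
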
diracx/db/sql/auth/db.py CHANGED
@@ -35,7 +35,7 @@ class AuthDB(BaseSQLDB):
     async def device_flow_validate_user_code(
         self, user_code: str, max_validity: int
     ) -> str:
-        """Validate that the user_code can be used (Pending status, not expired)
+        """Validate that the user_code can be used (Pending status, not expired).

         Returns the scope field for the given user_code

@@ -51,9 +51,7 @@ class AuthDB(BaseSQLDB):
         return (await self.conn.execute(stmt)).scalar_one()

     async def get_device_flow(self, device_code: str, max_validity: int):
-        """
-        :raises: NoResultFound
-        """
+        """:raises: NoResultFound"""
         # The with_for_update
         # prevents the token from being retrieved
         # multiple times concurrently
@@ -94,9 +92,7 @@ class AuthDB(BaseSQLDB):
     async def device_flow_insert_id_token(
         self, user_code: str, id_token: dict[str, str], max_validity: int
     ) -> None:
-        """
-        :raises: AuthorizationError if no such code or status not pending
-        """
+        """:raises: AuthorizationError if no such code or status not pending"""
         stmt = update(DeviceFlows)
         stmt = stmt.where(
             DeviceFlows.user_code == user_code,
@@ -170,11 +166,9 @@ class AuthDB(BaseSQLDB):
     async def authorization_flow_insert_id_token(
         self, uuid: str, id_token: dict[str, str], max_validity: int
     ) -> tuple[str, str]:
+        """Returns code, redirect_uri
+        :raises: AuthorizationError if no such uuid or status not pending.
         """
-        returns code, redirect_uri
-        :raises: AuthorizationError if no such uuid or status not pending
-        """
-
         # Hash the code to avoid leaking information
         code = secrets.token_urlsafe()
         hashed_code = hashlib.sha256(code.encode()).hexdigest()
@@ -232,8 +226,7 @@ class AuthDB(BaseSQLDB):
         preferred_username: str,
         scope: str,
     ) -> tuple[str, datetime]:
-        """
-        Insert a refresh token in the DB as well as user attributes
+        """Insert a refresh token in the DB as well as user attributes
         required to generate access tokens.
         """
         # Generate a JWT ID
@@ -257,9 +250,7 @@ class AuthDB(BaseSQLDB):
         return jti, row.creation_time

     async def get_refresh_token(self, jti: str) -> dict:
-        """
-        Get refresh token details bound to a given JWT ID
-        """
+        """Get refresh token details bound to a given JWT ID."""
         # The with_for_update
         # prevents the token from being retrieved
         # multiple times concurrently
@@ -275,7 +266,7 @@ class AuthDB(BaseSQLDB):
         return res

     async def get_user_refresh_tokens(self, subject: str | None = None) -> list[dict]:
-        """Get a list of refresh token details based on a subject ID (not revoked)"""
+        """Get a list of refresh token details based on a subject ID (not revoked)."""
         # Get a list of refresh tokens
         stmt = select(RefreshTokens).with_for_update()

@@ -295,7 +286,7 @@ class AuthDB(BaseSQLDB):
         return refresh_tokens

     async def revoke_refresh_token(self, jti: str):
-        """Revoke a token given by its JWT ID"""
+        """Revoke a token given by its JWT ID."""
         await self.conn.execute(
             update(RefreshTokens)
             .where(RefreshTokens.jti == jti)
@@ -303,7 +294,7 @@ class AuthDB(BaseSQLDB):
         )

     async def revoke_user_refresh_tokens(self, subject):
-        """Revoke all the refresh tokens belonging to a user (subject ID)"""
+        """Revoke all the refresh tokens belonging to a user (subject ID)."""
         await self.conn.execute(
             update(RefreshTokens)
             .where(RefreshTokens.sub == subject)
diracx/db/sql/auth/schema.py CHANGED
@@ -15,12 +15,11 @@ Base = declarative_base()


 class FlowStatus(Enum):
-    """
-    The normal flow is
+    """The normal flow is
     PENDING -> READY -> DONE
     Pending is upon insertion
     Ready/Error is set in response to IdP
-    Done means the user has been issued the dirac token
+    Done means the user has been issued the dirac token.
     """

     # The flow is ongoing
@@ -64,9 +63,8 @@ class AuthorizationFlows(Base):


 class RefreshTokenStatus(Enum):
-    """
-    The normal flow is
-    CREATED -> REVOKED
+    """The normal flow is
+    CREATED -> REVOKED.

     Note1: There is no EXPIRED status as it can be calculated from a creation time
     Note2: As part of the refresh token rotation mechanism, the revoked token should be retained
@@ -82,7 +80,7 @@ class RefreshTokenStatus(Enum):


 class RefreshTokens(Base):
     """Store attributes bound to a refresh token, as well as specific user attributes
-    that might be then used to generate access tokens
+    that might then be used to generate access tokens.
     """
     __tablename__ = "RefreshTokens"
diracx/db/sql/dummy/db.py CHANGED
@@ -11,8 +11,7 @@ from .schema import Cars, Owners


 class DummyDB(BaseSQLDB):
-    """
-    This DummyDB is just to illustrate some important aspect of writing
+    """This DummyDB is just to illustrate some important aspects of writing
     DB classes in DiracX.

     It is mostly pure SQLAlchemy, with a few convention
@@ -27,7 +26,7 @@ class DummyDB(BaseSQLDB):
         columns = [Cars.__table__.columns[x] for x in group_by]

         stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
-        stmt = apply_search_filters(Cars.__table__, stmt, search)
+        stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
         stmt = stmt.group_by(*columns)

         # Execute the query
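
The call-site change in `DummyDB` reflects a new first argument for the SQL `apply_search_filters`: instead of the table itself, it now receives a callable that resolves a field name to a SQLAlchemy column, and `Cars.__table__.columns.__getitem__` is exactly such a callable. A self-contained sketch of what that resolver is, using a toy table as a stand-in for the dummy `Cars` schema:

```python
from sqlalchemy import Column, MetaData, String, Table

metadata = MetaData()

# Toy stand-in for the Cars table used by DummyDB.
cars = Table(
    "Cars",
    metadata,
    Column("licensePlate", String(30), primary_key=True),
    Column("model", String(30)),
)


# The new first argument is any callable mapping a field name to its column;
# `cars.columns.__getitem__` is the terse spelling of this function.
def get_column(name: str) -> Column:
    return cars.columns[name]


assert get_column("model") is cars.columns["model"]
```
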