diracx-db 0.0.1a17__py3-none-any.whl → 0.0.1a18__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
diracx/db/os/utils.py CHANGED
@@ -7,9 +7,10 @@ import json
7
7
  import logging
8
8
  import os
9
9
  from abc import ABCMeta, abstractmethod
10
+ from collections.abc import AsyncIterator
10
11
  from contextvars import ContextVar
11
12
  from datetime import datetime
12
- from typing import Any, AsyncIterator, Self
13
+ from typing import Any, Self
13
14
 
14
15
  from opensearchpy import AsyncOpenSearch
15
16
 
@@ -29,6 +30,48 @@ class OpenSearchDBUnavailable(DBUnavailable, OpenSearchDBError):
29
30
 
30
31
 
31
32
  class BaseOSDB(metaclass=ABCMeta):
33
+ """This should be the base class of all the OpenSearch DiracX DBs.
34
+
35
+ The details covered here should be handled automatically by the service and
36
+ task machinery of DiracX and this documentation exists for informational
37
+ purposes.
38
+
39
+ The available OpenSearch databases are discovered by calling `BaseOSDB.available_urls`.
40
+ This method returns a dictionary of database names to connection parameters.
41
+ The available databases are determined by the `diracx.db.os` entrypoint in
42
+ the `pyproject.toml` file and the connection parameters are taken from the
43
+ environment variables prefixed with `DIRACX_OS_DB_{DB_NAME}`.
44
+
45
+ If extensions to DiracX are being used, there can be multiple implementations
46
+ of the same database. To list the available implementations use
47
+ `BaseOSDB.available_implementations(db_name)`. The first entry in this list
48
+ will be the preferred implementation and it can be initialized by calling
49
+ its `__init__` function with the connection parameters previously obtained
50
+ from `BaseOSDB.available_urls`.
51
+
52
+ To control the lifetime of the OpenSearch client, the `BaseOSDB.client_context`
53
+ asynchronous context manager should be entered. When inside this context
54
+ manager, the client can be accessed with `BaseOSDB.client`.
55
+
56
+ Upon entering, the DB class can then be used as an asynchronous context
57
+ manager to perform operations. Currently this context manager has no effect
58
+ however it must be used as it may be used in future. When inside this
59
+ context manager, the DB connection can be accessed with `BaseOSDB.client`.
60
+
61
+ For example:
62
+
63
+ ```python
64
+ db_name = ...
65
+ conn_params = BaseOSDB.available_urls()[db_name]
66
+ MyDBClass = BaseOSDB.available_implementations(db_name)[0]
67
+
68
+ db = MyDBClass(conn_params)
69
+ async with db.client_context:
70
+ async with db:
71
+ # Do something with the OpenSearch client
72
+ ```
73
+ """
74
+
32
75
  # TODO: Make metadata an abstract property
33
76
  fields: dict
34
77
  index_prefix: str
@@ -77,13 +120,15 @@ class BaseOSDB(metaclass=ABCMeta):
77
120
  @classmethod
78
121
  def session(cls) -> Self:
79
122
  """This is just a fake method such that the Dependency overwrite has
80
- a hash to use"""
123
+ a hash to use.
124
+ """
81
125
  raise NotImplementedError("This should never be called")
82
126
 
83
127
  @property
84
128
  def client(self) -> AsyncOpenSearch:
85
129
  """Just a getter for _client, making sure we entered
86
- the context manager"""
130
+ the context manager.
131
+ """
87
132
  if self._client is None:
88
133
  raise RuntimeError(f"{self.__class__} was used before entering")
89
134
  return self._client
@@ -91,17 +136,18 @@ class BaseOSDB(metaclass=ABCMeta):
91
136
  @contextlib.asynccontextmanager
92
137
  async def client_context(self) -> AsyncIterator[None]:
93
138
  """Context manager to manage the client lifecycle.
94
- This is called when starting fastapi
139
+ This is called when starting fastapi.
95
140
 
96
141
  """
97
142
  assert self._client is None, "client_context cannot be nested"
98
143
  async with AsyncOpenSearch(**self._connection_kwargs) as self._client:
99
- yield
100
- self._client = None
144
+ try:
145
+ yield
146
+ finally:
147
+ self._client = None
101
148
 
102
149
  async def ping(self):
103
- """
104
- Check whether the connection to the DB is still working.
150
+ """Check whether the connection to the DB is still working.
105
151
  We could enable the ``pre_ping`` in the engine, but this would
106
152
  be run at every query.
107
153
  """
@@ -113,7 +159,7 @@ class BaseOSDB(metaclass=ABCMeta):
113
159
  async def __aenter__(self):
114
160
  """This is entered on every request.
115
161
  At the moment it does nothing, however, we keep it here
116
- in case we ever want to use OpenSearch equivalent of a transaction
162
+ in case we ever want to use OpenSearch equivalent of a transaction.
117
163
  """
118
164
  assert not self._conn.get(), "BaseOSDB context cannot be nested"
119
165
  assert self._client is not None, "client_context hasn't been entered"
@@ -122,9 +168,7 @@ class BaseOSDB(metaclass=ABCMeta):
122
168
 
123
169
  async def __aexit__(self, exc_type, exc, tb):
124
170
  assert self._conn.get()
125
- self._client = None
126
171
  self._conn.set(False)
127
- return
128
172
 
129
173
  async def create_index_template(self) -> None:
130
174
  template_body = {
@@ -237,6 +281,11 @@ def apply_search_filters(db_fields, search):
237
281
  operator, field_name, field_type, {"keyword", "long", "date"}
238
282
  )
239
283
  result["must"].append({"terms": {field_name: query["values"]}})
284
+ case "not in":
285
+ require_type(
286
+ operator, field_name, field_type, {"keyword", "long", "date"}
287
+ )
288
+ result["must_not"].append({"terms": {field_name: query["values"]}})
240
289
  # TODO: Implement like and ilike
241
290
  # If the pattern is a simple "col like 'abc%'", we can use a prefix query
242
291
  # Else we need to use a wildcard query where we replace % with * and _ with ?
diracx/db/sql/__init__.py CHANGED
@@ -3,5 +3,7 @@ from __future__ import annotations
3
3
  __all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB")
4
4
 
5
5
  from .auth.db import AuthDB
6
- from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB
6
+ from .job.db import JobDB
7
+ from .job_logging.db import JobLoggingDB
7
8
  from .sandbox_metadata.db import SandboxMetadataDB
9
+ from .task_queue.db import TaskQueueDB
diracx/db/sql/auth/db.py CHANGED
@@ -35,7 +35,7 @@ class AuthDB(BaseSQLDB):
35
35
  async def device_flow_validate_user_code(
36
36
  self, user_code: str, max_validity: int
37
37
  ) -> str:
38
- """Validate that the user_code can be used (Pending status, not expired)
38
+ """Validate that the user_code can be used (Pending status, not expired).
39
39
 
40
40
  Returns the scope field for the given user_code
41
41
 
@@ -51,9 +51,7 @@ class AuthDB(BaseSQLDB):
51
51
  return (await self.conn.execute(stmt)).scalar_one()
52
52
 
53
53
  async def get_device_flow(self, device_code: str, max_validity: int):
54
- """
55
- :raises: NoResultFound
56
- """
54
+ """:raises: NoResultFound"""
57
55
  # The with_for_update
58
56
  # prevents that the token is retrieved
59
57
  # multiple times concurrently
@@ -94,9 +92,7 @@ class AuthDB(BaseSQLDB):
94
92
  async def device_flow_insert_id_token(
95
93
  self, user_code: str, id_token: dict[str, str], max_validity: int
96
94
  ) -> None:
97
- """
98
- :raises: AuthorizationError if no such code or status not pending
99
- """
95
+ """:raises: AuthorizationError if no such code or status not pending"""
100
96
  stmt = update(DeviceFlows)
101
97
  stmt = stmt.where(
102
98
  DeviceFlows.user_code == user_code,
@@ -170,11 +166,9 @@ class AuthDB(BaseSQLDB):
170
166
  async def authorization_flow_insert_id_token(
171
167
  self, uuid: str, id_token: dict[str, str], max_validity: int
172
168
  ) -> tuple[str, str]:
169
+ """Returns code, redirect_uri
170
+ :raises: AuthorizationError if no such uuid or status not pending.
173
171
  """
174
- returns code, redirect_uri
175
- :raises: AuthorizationError if no such uuid or status not pending
176
- """
177
-
178
172
  # Hash the code to avoid leaking information
179
173
  code = secrets.token_urlsafe()
180
174
  hashed_code = hashlib.sha256(code.encode()).hexdigest()
@@ -232,8 +226,7 @@ class AuthDB(BaseSQLDB):
232
226
  preferred_username: str,
233
227
  scope: str,
234
228
  ) -> tuple[str, datetime]:
235
- """
236
- Insert a refresh token in the DB as well as user attributes
229
+ """Insert a refresh token in the DB as well as user attributes
237
230
  required to generate access tokens.
238
231
  """
239
232
  # Generate a JWT ID
@@ -257,9 +250,7 @@ class AuthDB(BaseSQLDB):
257
250
  return jti, row.creation_time
258
251
 
259
252
  async def get_refresh_token(self, jti: str) -> dict:
260
- """
261
- Get refresh token details bound to a given JWT ID
262
- """
253
+ """Get refresh token details bound to a given JWT ID."""
263
254
  # The with_for_update
264
255
  # prevents that the token is retrieved
265
256
  # multiple times concurrently
@@ -275,7 +266,7 @@ class AuthDB(BaseSQLDB):
275
266
  return res
276
267
 
277
268
  async def get_user_refresh_tokens(self, subject: str | None = None) -> list[dict]:
278
- """Get a list of refresh token details based on a subject ID (not revoked)"""
269
+ """Get a list of refresh token details based on a subject ID (not revoked)."""
279
270
  # Get a list of refresh tokens
280
271
  stmt = select(RefreshTokens).with_for_update()
281
272
 
@@ -295,7 +286,7 @@ class AuthDB(BaseSQLDB):
295
286
  return refresh_tokens
296
287
 
297
288
  async def revoke_refresh_token(self, jti: str):
298
- """Revoke a token given by its JWT ID"""
289
+ """Revoke a token given by its JWT ID."""
299
290
  await self.conn.execute(
300
291
  update(RefreshTokens)
301
292
  .where(RefreshTokens.jti == jti)
@@ -303,7 +294,7 @@ class AuthDB(BaseSQLDB):
303
294
  )
304
295
 
305
296
  async def revoke_user_refresh_tokens(self, subject):
306
- """Revoke all the refresh tokens belonging to a user (subject ID)"""
297
+ """Revoke all the refresh tokens belonging to a user (subject ID)."""
307
298
  await self.conn.execute(
308
299
  update(RefreshTokens)
309
300
  .where(RefreshTokens.sub == subject)
@@ -15,12 +15,11 @@ Base = declarative_base()
15
15
 
16
16
 
17
17
  class FlowStatus(Enum):
18
- """
19
- The normal flow is
18
+ """The normal flow is
20
19
  PENDING -> READY -> DONE
21
20
  Pending is upon insertion
22
21
  Ready/Error is set in response to IdP
23
- Done means the user has been issued the dirac token
22
+ Done means the user has been issued the dirac token.
24
23
  """
25
24
 
26
25
  # The flow is ongoing
@@ -64,9 +63,8 @@ class AuthorizationFlows(Base):
64
63
 
65
64
 
66
65
  class RefreshTokenStatus(Enum):
67
- """
68
- The normal flow is
69
- CREATED -> REVOKED
66
+ """The normal flow is
67
+ CREATED -> REVOKED.
70
68
 
71
69
  Note1: There is no EXPIRED status as it can be calculated from a creation time
72
70
  Note2: As part of the refresh token rotation mechanism, the revoked token should be retained
@@ -82,7 +80,7 @@ class RefreshTokenStatus(Enum):
82
80
 
83
81
  class RefreshTokens(Base):
84
82
  """Store attributes bound to a refresh token, as well as specific user attributes
85
- that might be then used to generate access tokens
83
+ that might be then used to generate access tokens.
86
84
  """
87
85
 
88
86
  __tablename__ = "RefreshTokens"
diracx/db/sql/dummy/db.py CHANGED
@@ -11,8 +11,7 @@ from .schema import Cars, Owners
11
11
 
12
12
 
13
13
  class DummyDB(BaseSQLDB):
14
- """
15
- This DummyDB is just to illustrate some important aspect of writing
14
+ """This DummyDB is just to illustrate some important aspect of writing
16
15
  DB classes in DiracX.
17
16
 
18
17
  It is mostly pure SQLAlchemy, with a few conventions
@@ -27,7 +26,7 @@ class DummyDB(BaseSQLDB):
27
26
  columns = [Cars.__table__.columns[x] for x in group_by]
28
27
 
29
28
  stmt = select(*columns, func.count(Cars.licensePlate).label("count"))
30
- stmt = apply_search_filters(Cars.__table__, stmt, search)
29
+ stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
31
30
  stmt = stmt.group_by(*columns)
32
31
 
33
32
  # Execute the query