diracx-db 0.0.1a45__tar.gz → 0.0.1a47__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/PKG-INFO +1 -1
  2. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/__main__.py +0 -1
  3. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/os/job_parameters.py +9 -3
  4. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/dummy/db.py +3 -14
  5. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job/db.py +15 -54
  6. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job_logging/db.py +0 -1
  7. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/pilot_agents/db.py +0 -1
  8. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/sandbox_metadata/db.py +5 -1
  9. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/utils/__init__.py +2 -0
  10. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/utils/base.py +92 -3
  11. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/jobs/test_job_db.py +49 -0
  12. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/opensearch/test_index_template.py +4 -2
  13. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/pilot_agents/test_pilot_agents_db.py +0 -1
  14. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/test_dummy_db.py +1 -1
  15. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/.gitignore +0 -0
  16. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/README.md +0 -0
  17. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/pyproject.toml +0 -0
  18. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/__init__.py +0 -0
  19. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/exceptions.py +0 -0
  20. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/os/__init__.py +0 -0
  21. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/os/utils.py +0 -0
  22. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/py.typed +0 -0
  23. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/__init__.py +0 -0
  24. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/auth/__init__.py +0 -0
  25. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/auth/db.py +0 -0
  26. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/auth/schema.py +0 -0
  27. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/dummy/__init__.py +0 -0
  28. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/dummy/schema.py +0 -0
  29. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job/__init__.py +0 -0
  30. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job/schema.py +0 -0
  31. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job_logging/__init__.py +0 -0
  32. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/job_logging/schema.py +0 -0
  33. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/pilot_agents/__init__.py +0 -0
  34. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/pilot_agents/schema.py +0 -0
  35. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/sandbox_metadata/__init__.py +0 -0
  36. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/sandbox_metadata/schema.py +0 -0
  37. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/task_queue/__init__.py +0 -0
  38. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/task_queue/db.py +0 -0
  39. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/task_queue/schema.py +0 -0
  40. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/utils/functions.py +0 -0
  41. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/src/diracx/db/sql/utils/types.py +0 -0
  42. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/auth/test_authorization_flow.py +0 -0
  43. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/auth/test_device_flow.py +0 -0
  44. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/auth/test_refresh_token.py +0 -0
  45. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/jobs/test_job_logging_db.py +0 -0
  46. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/jobs/test_sandbox_metadata.py +0 -0
  47. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/opensearch/test_connection.py +0 -0
  48. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/opensearch/test_search.py +0 -0
  49. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/pilot_agents/__init__.py +0 -0
  50. {diracx_db-0.0.1a45 → diracx_db-0.0.1a47}/tests/test_freeze_time.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diracx-db
3
- Version: 0.0.1a45
3
+ Version: 0.0.1a47
4
4
  Summary: TODO
5
5
  License: GPL-3.0-only
6
6
  Classifier: Intended Audience :: Science/Research
@@ -31,7 +31,6 @@ async def init_sql():
31
31
  from diracx.db.sql.utils import BaseSQLDB
32
32
 
33
33
  for db_name, db_url in BaseSQLDB.available_urls().items():
34
-
35
34
  logger.info("Initialising %s", db_name)
36
35
  db = BaseSQLDB.available_implementations(db_name)[0](db_url)
37
36
  async with db.engine_context():
@@ -9,13 +9,19 @@ class JobParametersDB(BaseOSDB):
9
9
  fields = {
10
10
  "JobID": {"type": "long"},
11
11
  "timestamp": {"type": "date"},
12
+ "PilotAgent": {"type": "keyword"},
13
+ "Pilot_Reference": {"type": "keyword"},
14
+ "JobGroup": {"type": "keyword"},
12
15
  "CPUNormalizationFactor": {"type": "long"},
13
16
  "NormCPUTime(s)": {"type": "long"},
14
- "Memory(kB)": {"type": "long"},
17
+ "Memory(MB)": {"type": "long"},
18
+ "LocalAccount": {"type": "keyword"},
15
19
  "TotalCPUTime(s)": {"type": "long"},
16
- "MemoryUsed(kb)": {"type": "long"},
17
- "HostName": {"type": "keyword"},
20
+ "PayloadPID": {"type": "long"},
21
+ "HostName": {"type": "text"},
18
22
  "GridCE": {"type": "keyword"},
23
+ "CEQueue": {"type": "keyword"},
24
+ "BatchSystem": {"type": "keyword"},
19
25
  "ModelName": {"type": "keyword"},
20
26
  "Status": {"type": "keyword"},
21
27
  "JobType": {"type": "keyword"},
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
- from sqlalchemy import func, insert, select
3
+ from sqlalchemy import insert
4
4
  from uuid_utils import UUID
5
5
 
6
- from diracx.db.sql.utils import BaseSQLDB, apply_search_filters
6
+ from diracx.db.sql.utils import BaseSQLDB
7
7
 
8
8
  from .schema import Base as DummyDBBase
9
9
  from .schema import Cars, Owners
@@ -22,18 +22,7 @@ class DummyDB(BaseSQLDB):
22
22
  metadata = DummyDBBase.metadata
23
23
 
24
24
  async def summary(self, group_by, search) -> list[dict[str, str | int]]:
25
- columns = [Cars.__table__.columns[x] for x in group_by]
26
-
27
- stmt = select(*columns, func.count(Cars.license_plate).label("count"))
28
- stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search)
29
- stmt = stmt.group_by(*columns)
30
-
31
- # Execute the query
32
- return [
33
- dict(row._mapping)
34
- async for row in (await self.conn.stream(stmt))
35
- if row.count > 0 # type: ignore
36
- ]
25
+ return await self._summary(Cars, group_by, search)
37
26
 
38
27
  async def insert_owner(self, name: str) -> int:
39
28
  stmt = insert(Owners).values(name=name)
@@ -5,7 +5,7 @@ __all__ = ["JobDB"]
5
5
  from datetime import datetime, timezone
6
6
  from typing import TYPE_CHECKING, Any, Iterable
7
7
 
8
- from sqlalchemy import bindparam, case, delete, func, insert, select, update
8
+ from sqlalchemy import bindparam, case, delete, insert, select, update
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from sqlalchemy.sql.elements import BindParameter
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
13
13
  from diracx.core.exceptions import InvalidQueryError
14
14
  from diracx.core.models import JobCommand, SearchSpec, SortSpec
15
15
 
16
- from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints
16
+ from ..utils import BaseSQLDB, _get_columns
17
17
  from ..utils.functions import utcnow
18
18
  from .schema import (
19
19
  HeartBeatLoggingInfo,
@@ -25,17 +25,6 @@ from .schema import (
25
25
  )
26
26
 
27
27
 
28
- def _get_columns(table, parameters):
29
- columns = [x for x in table.columns]
30
- if parameters:
31
- if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
32
- raise InvalidQueryError(
33
- f"Unrecognised parameters requested {unrecognised_parameters}"
34
- )
35
- columns = [c for c in columns if c.name in parameters]
36
- return columns
37
-
38
-
39
28
  class JobDB(BaseSQLDB):
40
29
  metadata = JobDBBase.metadata
41
30
 
@@ -54,20 +43,11 @@ class JobDB(BaseSQLDB):
54
43
  # to find a way to make it dynamic
55
44
  jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]
56
45
 
57
- async def summary(self, group_by, search) -> list[dict[str, str | int]]:
46
+ async def summary(
47
+ self, group_by: list[str], search: list[SearchSpec]
48
+ ) -> list[dict[str, str | int]]:
58
49
  """Get a summary of the jobs."""
59
- columns = _get_columns(Jobs.__table__, group_by)
60
-
61
- stmt = select(*columns, func.count(Jobs.job_id).label("count"))
62
- stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
63
- stmt = stmt.group_by(*columns)
64
-
65
- # Execute the query
66
- return [
67
- dict(row._mapping)
68
- async for row in (await self.conn.stream(stmt))
69
- if row.count > 0 # type: ignore
70
- ]
50
+ return await self._summary(table=Jobs, group_by=group_by, search=search)
71
51
 
72
52
  async def search(
73
53
  self,
@@ -80,34 +60,15 @@ class JobDB(BaseSQLDB):
80
60
  page: int | None = None,
81
61
  ) -> tuple[int, list[dict[Any, Any]]]:
82
62
  """Search for jobs in the database."""
83
- # Find which columns to select
84
- columns = _get_columns(Jobs.__table__, parameters)
85
-
86
- stmt = select(*columns)
87
-
88
- stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
89
- stmt = apply_sort_constraints(Jobs.__table__.columns.__getitem__, stmt, sorts)
90
-
91
- if distinct:
92
- stmt = stmt.distinct()
93
-
94
- # Calculate total count before applying pagination
95
- total_count_subquery = stmt.alias()
96
- total_count_stmt = select(func.count()).select_from(total_count_subquery)
97
- total = (await self.conn.execute(total_count_stmt)).scalar_one()
98
-
99
- # Apply pagination
100
- if page is not None:
101
- if page < 1:
102
- raise InvalidQueryError("Page must be a positive integer")
103
- if per_page < 1:
104
- raise InvalidQueryError("Per page must be a positive integer")
105
- stmt = stmt.offset((page - 1) * per_page).limit(per_page)
106
-
107
- # Execute the query
108
- return total, [
109
- dict(row._mapping) async for row in (await self.conn.stream(stmt))
110
- ]
63
+ return await self._search(
64
+ table=Jobs,
65
+ parameters=parameters,
66
+ search=search,
67
+ sorts=sorts,
68
+ distinct=distinct,
69
+ per_page=per_page,
70
+ page=page,
71
+ )
111
72
 
112
73
  async def create_job(self, compressed_original_jdl: str):
113
74
  """Used to insert a new job with original JDL. Returns inserted job id."""
@@ -89,7 +89,6 @@ class JobLoggingDB(BaseSQLDB):
89
89
  status_time,
90
90
  status_source,
91
91
  ) in rows:
92
-
93
92
  values[job_id].append(
94
93
  [
95
94
  status,
@@ -20,7 +20,6 @@ class PilotAgentsDB(BaseSQLDB):
20
20
  grid_type: str = "DIRAC",
21
21
  pilot_stamps: dict | None = None,
22
22
  ) -> None:
23
-
24
23
  if pilot_stamps is None:
25
24
  pilot_stamps = {}
26
25
 
@@ -13,6 +13,7 @@ from sqlalchemy import (
13
13
  Table,
14
14
  and_,
15
15
  delete,
16
+ exists,
16
17
  insert,
17
18
  literal,
18
19
  or_,
@@ -236,7 +237,10 @@ class SandboxMetadataDB(BaseSQLDB):
236
237
  """
237
238
  conditions = [
238
239
  # If it has assigned to a job but is no longer mapped it can be removed
239
- # and_(SandBoxes.Assigned, ~exists(SandBoxes.SBId == SBEntityMapping.SBId)),
240
+ and_(
241
+ SandBoxes.Assigned,
242
+ ~exists().where(SBEntityMapping.SBId == SandBoxes.SBId),
243
+ ),
240
244
  # If the sandbox is still unassigned after 15 days, remove it
241
245
  and_(~SandBoxes.Assigned, days_since(SandBoxes.LastAccessTime) >= 15),
242
246
  ]
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  from .base import (
4
4
  BaseSQLDB,
5
5
  SQLDBUnavailableError,
6
+ _get_columns,
6
7
  apply_search_filters,
7
8
  apply_sort_constraints,
8
9
  )
@@ -10,6 +11,7 @@ from .functions import hash, substract_date, utcnow
10
11
  from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
11
12
 
12
13
  __all__ = (
14
+ "_get_columns",
13
15
  "utcnow",
14
16
  "Column",
15
17
  "NullColumn",
@@ -8,16 +8,20 @@ from abc import ABCMeta
8
8
  from collections.abc import AsyncIterator
9
9
  from contextvars import ContextVar
10
10
  from datetime import datetime
11
- from typing import Self, cast
11
+ from typing import Any, Self, cast
12
12
 
13
13
  from pydantic import TypeAdapter
14
- from sqlalchemy import DateTime, MetaData, select
14
+ from sqlalchemy import DateTime, MetaData, func, select
15
15
  from sqlalchemy.exc import OperationalError
16
16
  from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
17
17
 
18
18
  from diracx.core.exceptions import InvalidQueryError
19
19
  from diracx.core.extensions import select_from_extension
20
- from diracx.core.models import SortDirection
20
+ from diracx.core.models import (
21
+ SearchSpec,
22
+ SortDirection,
23
+ SortSpec,
24
+ )
21
25
  from diracx.core.settings import SqlalchemyDsn
22
26
  from diracx.db.exceptions import DBUnavailableError
23
27
 
@@ -227,6 +231,71 @@ class BaseSQLDB(metaclass=ABCMeta):
227
231
  except OperationalError as e:
228
232
  raise SQLDBUnavailableError("Cannot ping the DB") from e
229
233
 
234
+ async def _search(
235
+ self,
236
+ table: Any,
237
+ parameters: list[str] | None,
238
+ search: list[SearchSpec],
239
+ sorts: list[SortSpec],
240
+ *,
241
+ distinct: bool = False,
242
+ per_page: int = 100,
243
+ page: int | None = None,
244
+ ) -> tuple[int, list[dict[str, Any]]]:
245
+ """Search for elements in a table."""
246
+ # Find which columns to select
247
+ columns = _get_columns(table.__table__, parameters)
248
+
249
+ stmt = select(*columns)
250
+
251
+ stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
252
+ stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)
253
+
254
+ if distinct:
255
+ stmt = stmt.distinct()
256
+
257
+ # Calculate total count before applying pagination
258
+ total_count_subquery = stmt.alias()
259
+ total_count_stmt = select(func.count()).select_from(total_count_subquery)
260
+ total = (await self.conn.execute(total_count_stmt)).scalar_one()
261
+
262
+ # Apply pagination
263
+ if page is not None:
264
+ if page < 1:
265
+ raise InvalidQueryError("Page must be a positive integer")
266
+ if per_page < 1:
267
+ raise InvalidQueryError("Per page must be a positive integer")
268
+ stmt = stmt.offset((page - 1) * per_page).limit(per_page)
269
+
270
+ # Execute the query
271
+ return total, [
272
+ dict(row._mapping) async for row in (await self.conn.stream(stmt))
273
+ ]
274
+
275
+ async def _summary(
276
+ self, table: Any, group_by: list[str], search: list[SearchSpec]
277
+ ) -> list[dict[str, str | int]]:
278
+ """Get a summary of the elements of a table."""
279
+ columns = _get_columns(table.__table__, group_by)
280
+
281
+ pk_columns = list(table.__table__.primary_key.columns)
282
+ if not pk_columns:
283
+ raise ValueError(
284
+ "Model has no primary key and no count_column was provided."
285
+ )
286
+ count_col = pk_columns[0]
287
+
288
+ stmt = select(*columns, func.count(count_col).label("count"))
289
+ stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
290
+ stmt = stmt.group_by(*columns)
291
+
292
+ # Execute the query
293
+ return [
294
+ dict(row._mapping)
295
+ async for row in (await self.conn.stream(stmt))
296
+ if row.count > 0 # type: ignore
297
+ ]
298
+
230
299
 
231
300
  def find_time_resolution(value):
232
301
  if isinstance(value, datetime):
@@ -258,6 +327,17 @@ def find_time_resolution(value):
258
327
  raise InvalidQueryError(f"Cannot parse {value=}")
259
328
 
260
329
 
330
+ def _get_columns(table, parameters):
331
+ columns = [x for x in table.columns]
332
+ if parameters:
333
+ if unrecognised_parameters := set(parameters) - set(table.columns.keys()):
334
+ raise InvalidQueryError(
335
+ f"Unrecognised parameters requested {unrecognised_parameters}"
336
+ )
337
+ columns = [c for c in columns if c.name in parameters]
338
+ return columns
339
+
340
+
261
341
  def apply_search_filters(column_mapping, stmt, search):
262
342
  for query in search:
263
343
  try:
@@ -300,6 +380,15 @@ def apply_search_filters(column_mapping, stmt, search):
300
380
  expr = column.like(query["value"])
301
381
  elif query["operator"] in "ilike":
302
382
  expr = column.ilike(query["value"])
383
+ elif query["operator"] == "not like":
384
+ expr = column.not_like(query["value"])
385
+ elif query["operator"] == "regex":
386
+ # We check the regex validity here
387
+ try:
388
+ re.compile(query["value"])
389
+ except re.error as e:
390
+ raise InvalidQueryError(f"Invalid regex {query['value']}") from e
391
+ expr = column.regexp_match(query["value"])
303
392
  else:
304
393
  raise InvalidQueryError(f"Unknown filter {query=}")
305
394
  stmt = stmt.where(expr)
@@ -187,6 +187,55 @@ async def test_search_conditions(populated_job_db):
187
187
  assert total == 0
188
188
  assert not result
189
189
 
190
+ # Search for a specific scalar condition: Owner not like 'owner1%'
191
+ condition = ScalarSearchSpec(
192
+ parameter="Owner", operator=ScalarSearchOperator.NOT_LIKE, value="owner1%"
193
+ )
194
+ total, result = await job_db.search([], [condition], [])
195
+ assert total == 100 - 11
196
+ assert result
197
+ assert len(result) == 100 - 11
198
+ assert all(not r["Owner"].startswith("owner1") for r in result)
199
+
200
+ # Search for a specific scalar condition: OwnerGroup not like 'owner_group2'
201
+ condition = ScalarSearchSpec(
202
+ parameter="OwnerGroup",
203
+ operator=ScalarSearchOperator.NOT_LIKE,
204
+ value="owner_group2",
205
+ )
206
+ total, result = await job_db.search([], [condition], [])
207
+ assert total == 100 - 50
208
+ assert result
209
+ assert len(result) == 100 - 50
210
+ assert all(not r["OwnerGroup"] == "owner_group2" for r in result)
211
+
212
+ # Search for a specific scalar condition: Owner regex '^owner\d+$'
213
+ condition = ScalarSearchSpec(
214
+ parameter="Owner", operator=ScalarSearchOperator.REGEX, value="^owner\\d+$"
215
+ )
216
+ total, result = await job_db.search([], [condition], [])
217
+ assert total == 100
218
+ assert result
219
+ assert len(result) == 100
220
+
221
+ # Search for a specific scalar condition: JobID regex 'owner[0-3]+'
222
+ # owner0, owner1, owner2, owner3 (4 jobs)
223
+ # owner11 -> owner39 (30 jobs)
224
+ condition = ScalarSearchSpec(
225
+ parameter="Owner", operator=ScalarSearchOperator.REGEX, value="owner[0-3]+"
226
+ )
227
+ total, result = await job_db.search([], [condition], [])
228
+ assert total == 34
229
+ assert result
230
+ assert len(result) == 34
231
+
232
+ # Search for a specific scalar condition: JobID regex 'owner[1-'
233
+ condition = ScalarSearchSpec(
234
+ parameter="Owner", operator=ScalarSearchOperator.REGEX, value="owner[1-"
235
+ )
236
+ with pytest.raises(InvalidQueryError):
237
+ await job_db.search([], [condition], [])
238
+
190
239
 
191
240
  async def test_search_sorts(populated_job_db):
192
241
  """Test that we can search for jobs in the database and sort the results."""
@@ -37,13 +37,15 @@ async def _get_test_index_mappings(dummy_opensearch_db: DummyOSDB):
37
37
 
38
38
  # At this point the index should not exist yet
39
39
  with pytest.raises(opensearchpy.exceptions.NotFoundError):
40
- await dummy_opensearch_db.client.indices.get_mapping(index_name)
40
+ await dummy_opensearch_db.client.indices.get_mapping(index=index_name)
41
41
 
42
42
  # Insert document which will automatically create the index based on the template
43
43
  await dummy_opensearch_db.upsert(vo, document_id, DUMMY_DOCUMENT)
44
44
 
45
45
  # Ensure the result looks as expected and return the mappings
46
- index_mapping = await dummy_opensearch_db.client.indices.get_mapping(index_name)
46
+ index_mapping = await dummy_opensearch_db.client.indices.get_mapping(
47
+ index=index_name
48
+ )
47
49
  assert list(index_mapping) == [index_name]
48
50
  assert list(index_mapping[index_name]) == ["mappings"]
49
51
  return index_mapping[index_name]["mappings"]
@@ -15,7 +15,6 @@ async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
15
15
 
16
16
 
17
17
  async def test_insert_and_select(pilot_agents_db: PilotAgentsDB):
18
-
19
18
  async with pilot_agents_db as pilot_agents_db:
20
19
  # Add a pilot reference
21
20
  refs = [f"ref_{i}" for i in range(10)]
@@ -129,7 +129,7 @@ async def test_failed_transaction(dummy_db):
129
129
 
130
130
  # The connection is created when the context manager is entered
131
131
  # This is our transaction
132
- with pytest.raises(KeyError):
132
+ with pytest.raises(InvalidQueryError):
133
133
  async with dummy_db as dummy_db:
134
134
  assert dummy_db.conn
135
135
 
File without changes
File without changes