diracx-db 0.0.1a11__tar.gz → 0.0.1a13__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44) hide show
  1. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/PKG-INFO +1 -1
  2. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/auth/db.py +19 -12
  3. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/auth/schema.py +2 -4
  4. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/jobs/status_utility.py +1 -1
  5. diracx-db-0.0.1a13/src/diracx/db/sql/sandbox_metadata/db.py +169 -0
  6. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/PKG-INFO +1 -1
  7. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/auth/test_authorization_flow.py +2 -3
  8. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/auth/test_device_flow.py +11 -6
  9. diracx-db-0.0.1a13/tests/test_sandbox_metadata.py +173 -0
  10. diracx-db-0.0.1a11/src/diracx/db/sql/sandbox_metadata/db.py +0 -96
  11. diracx-db-0.0.1a11/tests/test_sandbox_metadata.py +0 -92
  12. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/README.md +0 -0
  13. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/pyproject.toml +0 -0
  14. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/setup.cfg +0 -0
  15. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/__init__.py +0 -0
  16. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/__main__.py +0 -0
  17. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/exceptions.py +0 -0
  18. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/os/__init__.py +0 -0
  19. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/os/job_parameters.py +0 -0
  20. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/os/utils.py +0 -0
  21. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/py.typed +0 -0
  22. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/__init__.py +0 -0
  23. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/auth/__init__.py +0 -0
  24. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/dummy/__init__.py +0 -0
  25. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/dummy/db.py +0 -0
  26. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/dummy/schema.py +0 -0
  27. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/jobs/__init__.py +0 -0
  28. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/jobs/db.py +0 -0
  29. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/jobs/schema.py +0 -0
  30. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/sandbox_metadata/__init__.py +0 -0
  31. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/sandbox_metadata/schema.py +0 -0
  32. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx/db/sql/utils.py +0 -0
  33. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/SOURCES.txt +0 -0
  34. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/dependency_links.txt +0 -0
  35. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/entry_points.txt +0 -0
  36. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/requires.txt +0 -0
  37. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/src/diracx_db.egg-info/top_level.txt +0 -0
  38. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/auth/test_refresh_token.py +0 -0
  39. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/jobs/test_jobDB.py +0 -0
  40. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/jobs/test_jobLoggingDB.py +0 -0
  41. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/opensearch/test_connection.py +0 -0
  42. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/opensearch/test_index_template.py +0 -0
  43. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/opensearch/test_search.py +0 -0
  44. {diracx-db-0.0.1a11 → diracx-db-0.0.1a13}/tests/test_dummyDB.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diracx-db
3
- Version: 0.0.1a11
3
+ Version: 0.0.1a13
4
4
  Summary: TODO
5
5
  License: GPL-3.0-only
6
6
  Classifier: Intended Audience :: Science/Research
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import hashlib
3
4
  import secrets
4
5
  from datetime import datetime
5
6
  from uuid import uuid4
@@ -63,7 +64,7 @@ class AuthDB(BaseSQLDB):
63
64
  ),
64
65
  ).with_for_update()
65
66
  stmt = stmt.where(
66
- DeviceFlows.device_code == device_code,
67
+ DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
67
68
  )
68
69
  res = dict((await self.conn.execute(stmt)).one()._mapping)
69
70
 
@@ -74,7 +75,10 @@ class AuthDB(BaseSQLDB):
74
75
  # Update the status to Done before returning
75
76
  await self.conn.execute(
76
77
  update(DeviceFlows)
77
- .where(DeviceFlows.device_code == device_code)
78
+ .where(
79
+ DeviceFlows.device_code
80
+ == hashlib.sha256(device_code.encode()).hexdigest()
81
+ )
78
82
  .values(status=FlowStatus.DONE)
79
83
  )
80
84
  return res
@@ -110,7 +114,6 @@ class AuthDB(BaseSQLDB):
110
114
  self,
111
115
  client_id: str,
112
116
  scope: str,
113
- audience: str,
114
117
  ) -> tuple[str, str]:
115
118
  # Because the user_code might be short, there is a risk of conflicts
116
119
  # This is why we retry multiple times
@@ -119,14 +122,16 @@ class AuthDB(BaseSQLDB):
119
122
  secrets.choice(USER_CODE_ALPHABET)
120
123
  for _ in range(DeviceFlows.user_code.type.length) # type: ignore
121
124
  )
122
- # user_code = "2QRKPY"
123
125
  device_code = secrets.token_urlsafe()
126
+
127
+ # Hash the the device_code to avoid leaking information
128
+ hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()
129
+
124
130
  stmt = insert(DeviceFlows).values(
125
131
  client_id=client_id,
126
132
  scope=scope,
127
- audience=audience,
128
133
  user_code=user_code,
129
- device_code=device_code,
134
+ device_code=hashed_device_code,
130
135
  )
131
136
  try:
132
137
  await self.conn.execute(stmt)
@@ -143,7 +148,6 @@ class AuthDB(BaseSQLDB):
143
148
  self,
144
149
  client_id: str,
145
150
  scope: str,
146
- audience: str,
147
151
  code_challenge: str,
148
152
  code_challenge_method: str,
149
153
  redirect_uri: str,
@@ -154,7 +158,6 @@ class AuthDB(BaseSQLDB):
154
158
  uuid=uuid,
155
159
  client_id=client_id,
156
160
  scope=scope,
157
- audience=audience,
158
161
  code_challenge=code_challenge,
159
162
  code_challenge_method=code_challenge_method,
160
163
  redirect_uri=redirect_uri,
@@ -172,7 +175,10 @@ class AuthDB(BaseSQLDB):
172
175
  :raises: AuthorizationError if no such uuid or status not pending
173
176
  """
174
177
 
178
+ # Hash the code to avoid leaking information
175
179
  code = secrets.token_urlsafe()
180
+ hashed_code = hashlib.sha256(code.encode()).hexdigest()
181
+
176
182
  stmt = update(AuthorizationFlows)
177
183
 
178
184
  stmt = stmt.where(
@@ -181,7 +187,7 @@ class AuthDB(BaseSQLDB):
181
187
  AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
182
188
  )
183
189
 
184
- stmt = stmt.values(id_token=id_token, code=code, status=FlowStatus.READY)
190
+ stmt = stmt.values(id_token=id_token, code=hashed_code, status=FlowStatus.READY)
185
191
  res = await self.conn.execute(stmt)
186
192
 
187
193
  if res.rowcount != 1:
@@ -190,15 +196,16 @@ class AuthDB(BaseSQLDB):
190
196
  stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
191
197
  stmt = stmt.where(AuthorizationFlows.uuid == uuid)
192
198
  row = (await self.conn.execute(stmt)).one()
193
- return row.code, row.redirect_uri
199
+ return code, row.redirect_uri
194
200
 
195
201
  async def get_authorization_flow(self, code: str, max_validity: int):
202
+ hashed_code = hashlib.sha256(code.encode()).hexdigest()
196
203
  # The with_for_update
197
204
  # prevents that the token is retrieved
198
205
  # multiple time concurrently
199
206
  stmt = select(AuthorizationFlows).with_for_update()
200
207
  stmt = stmt.where(
201
- AuthorizationFlows.code == code,
208
+ AuthorizationFlows.code == hashed_code,
202
209
  AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
203
210
  )
204
211
 
@@ -208,7 +215,7 @@ class AuthDB(BaseSQLDB):
208
215
  # Update the status to Done before returning
209
216
  await self.conn.execute(
210
217
  update(AuthorizationFlows)
211
- .where(AuthorizationFlows.code == code)
218
+ .where(AuthorizationFlows.code == hashed_code)
212
219
  .values(status=FlowStatus.DONE)
213
220
  )
214
221
 
@@ -45,8 +45,7 @@ class DeviceFlows(Base):
45
45
  creation_time = DateNowColumn()
46
46
  client_id = Column(String(255))
47
47
  scope = Column(String(1024))
48
- audience = Column(String(255))
49
- device_code = Column(String(128), unique=True) # hash it ?
48
+ device_code = Column(String(128), unique=True) # Should be a hash
50
49
  id_token = NullColumn(JSON())
51
50
 
52
51
 
@@ -57,11 +56,10 @@ class AuthorizationFlows(Base):
57
56
  client_id = Column(String(255))
58
57
  creation_time = DateNowColumn()
59
58
  scope = Column(String(1024))
60
- audience = Column(String(255))
61
59
  code_challenge = Column(String(255))
62
60
  code_challenge_method = Column(String(8))
63
61
  redirect_uri = Column(String(255))
64
- code = NullColumn(String(255)) # hash it ?
62
+ code = NullColumn(String(255)) # Should be a hash
65
63
  id_token = NullColumn(JSON())
66
64
 
67
65
 
@@ -272,7 +272,7 @@ async def remove_jobs(
272
272
 
273
273
  # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent
274
274
  # I think it should be done here as well
275
- await sandbox_metadata_db.unassign_sandbox_from_jobs(job_ids)
275
+ await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids)
276
276
 
277
277
  # Remove the job from TaskQueueDB
278
278
  await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task)
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import sqlalchemy
6
+
7
+ from diracx.core.models import SandboxInfo, SandboxType, UserInfo
8
+ from diracx.db.sql.utils import BaseSQLDB, utcnow
9
+
10
+ from .schema import Base as SandboxMetadataDBBase
11
+ from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes
12
+
13
+
14
+ class SandboxMetadataDB(BaseSQLDB):
15
+ metadata = SandboxMetadataDBBase.metadata
16
+
17
+ async def upsert_owner(self, user: UserInfo) -> int:
18
+ """Get the id of the owner from the database"""
19
+ # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
20
+ stmt = sqlalchemy.select(sb_Owners.OwnerID).where(
21
+ sb_Owners.Owner == user.preferred_username,
22
+ sb_Owners.OwnerGroup == user.dirac_group,
23
+ sb_Owners.VO == user.vo,
24
+ )
25
+ result = await self.conn.execute(stmt)
26
+ if owner_id := result.scalar_one_or_none():
27
+ return owner_id
28
+
29
+ stmt = sqlalchemy.insert(sb_Owners).values(
30
+ Owner=user.preferred_username,
31
+ OwnerGroup=user.dirac_group,
32
+ VO=user.vo,
33
+ )
34
+ result = await self.conn.execute(stmt)
35
+ return result.lastrowid
36
+
37
+ @staticmethod
38
+ def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: SandboxInfo) -> str:
39
+ """Get the sandbox's user namespaced and content addressed PFN"""
40
+ parts = [
41
+ "S3",
42
+ bucket_name,
43
+ user.vo,
44
+ user.dirac_group,
45
+ user.preferred_username,
46
+ f"{sandbox_info.checksum_algorithm}:{sandbox_info.checksum}.{sandbox_info.format}",
47
+ ]
48
+ return "/" + "/".join(parts)
49
+
50
+ async def insert_sandbox(
51
+ self, se_name: str, user: UserInfo, pfn: str, size: int
52
+ ) -> None:
53
+ """Add a new sandbox in SandboxMetadataDB"""
54
+ # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
55
+ owner_id = await self.upsert_owner(user)
56
+ stmt = sqlalchemy.insert(sb_SandBoxes).values(
57
+ OwnerId=owner_id,
58
+ SEName=se_name,
59
+ SEPFN=pfn,
60
+ Bytes=size,
61
+ RegistrationTime=utcnow(),
62
+ LastAccessTime=utcnow(),
63
+ )
64
+ try:
65
+ result = await self.conn.execute(stmt)
66
+ except sqlalchemy.exc.IntegrityError:
67
+ await self.update_sandbox_last_access_time(se_name, pfn)
68
+ else:
69
+ assert result.rowcount == 1
70
+
71
+ async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
72
+ stmt = (
73
+ sqlalchemy.update(sb_SandBoxes)
74
+ .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn)
75
+ .values(LastAccessTime=utcnow())
76
+ )
77
+ result = await self.conn.execute(stmt)
78
+ assert result.rowcount == 1
79
+
80
+ async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
81
+ """Checks if a sandbox exists and has been assigned."""
82
+ stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
83
+ sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
84
+ )
85
+ result = await self.conn.execute(stmt)
86
+ is_assigned = result.scalar_one()
87
+ return is_assigned
88
+
89
+ @staticmethod
90
+ def jobid_to_entity_id(job_id: int) -> str:
91
+ """Define the entity id as 'Entity:entity_id' due to the DB definition"""
92
+ return f"Job:{job_id}"
93
+
94
+ async def get_sandbox_assigned_to_job(
95
+ self, job_id: int, sb_type: SandboxType
96
+ ) -> list[Any]:
97
+ """Get the sandbox assign to job"""
98
+ entity_id = self.jobid_to_entity_id(job_id)
99
+ stmt = (
100
+ sqlalchemy.select(sb_SandBoxes.SEPFN)
101
+ .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId)
102
+ .where(
103
+ sb_EntityMapping.EntityId == entity_id,
104
+ sb_EntityMapping.Type == sb_type,
105
+ )
106
+ )
107
+ result = await self.conn.execute(stmt)
108
+ return [result.scalar()]
109
+
110
+ async def assign_sandbox_to_jobs(
111
+ self,
112
+ jobs_ids: list[int],
113
+ pfn: str,
114
+ sb_type: SandboxType,
115
+ se_name: str,
116
+ ) -> None:
117
+ """Mapp sandbox and jobs"""
118
+ for job_id in jobs_ids:
119
+ # Define the entity id as 'Entity:entity_id' due to the DB definition:
120
+ entity_id = self.jobid_to_entity_id(job_id)
121
+ select_sb_id = sqlalchemy.select(
122
+ sb_SandBoxes.SBId,
123
+ sqlalchemy.literal(entity_id).label("EntityId"),
124
+ sqlalchemy.literal(sb_type).label("Type"),
125
+ ).where(
126
+ sb_SandBoxes.SEName == se_name,
127
+ sb_SandBoxes.SEPFN == pfn,
128
+ )
129
+ stmt = sqlalchemy.insert(sb_EntityMapping).from_select(
130
+ ["SBId", "EntityId", "Type"], select_sb_id
131
+ )
132
+ await self.conn.execute(stmt)
133
+
134
+ stmt = (
135
+ sqlalchemy.update(sb_SandBoxes)
136
+ .where(sb_SandBoxes.SEPFN == pfn)
137
+ .values(Assigned=True)
138
+ )
139
+ result = await self.conn.execute(stmt)
140
+ assert result.rowcount == 1
141
+
142
+ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
143
+ """Delete mapping between jobs and sandboxes"""
144
+ for job_id in jobs_ids:
145
+ entity_id = self.jobid_to_entity_id(job_id)
146
+ sb_sel_stmt = sqlalchemy.select(
147
+ sb_SandBoxes.SBId,
148
+ ).where(sb_EntityMapping.EntityId == entity_id)
149
+
150
+ result = await self.conn.execute(sb_sel_stmt)
151
+ sb_ids = [row.SBId for row in result]
152
+
153
+ del_stmt = sqlalchemy.delete(sb_EntityMapping).where(
154
+ sb_EntityMapping.EntityId == entity_id
155
+ )
156
+ await self.conn.execute(del_stmt)
157
+
158
+ sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
159
+ sb_EntityMapping.SBId.in_(sb_ids)
160
+ )
161
+ result = await self.conn.execute(sb_entity_sel_stmt)
162
+ remaining_sb_ids = [row.SBId for row in result]
163
+ if not remaining_sb_ids:
164
+ unassign_stmt = (
165
+ sqlalchemy.update(sb_SandBoxes)
166
+ .where(sb_SandBoxes.SBId.in_(sb_ids))
167
+ .values(Assigned=False)
168
+ )
169
+ await self.conn.execute(unassign_stmt)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: diracx-db
3
- Version: 0.0.1a11
3
+ Version: 0.0.1a13
4
4
  Summary: TODO
5
5
  License: GPL-3.0-only
6
6
  Classifier: Intended Audience :: Science/Research
@@ -23,7 +23,7 @@ async def test_insert_id_token(auth_db: AuthDB):
23
23
  # First insert
24
24
  async with auth_db as auth_db:
25
25
  uuid = await auth_db.insert_authorization_flow(
26
- "client_id", "scope", "audience", "code_challenge", "S256", "redirect_uri"
26
+ "client_id", "scope", "code_challenge", "S256", "redirect_uri"
27
27
  )
28
28
 
29
29
  id_token = {"sub": "myIdToken"}
@@ -68,12 +68,11 @@ async def test_insert(auth_db: AuthDB):
68
68
  # First insert
69
69
  async with auth_db as auth_db:
70
70
  uuid1 = await auth_db.insert_authorization_flow(
71
- "client_id", "scope", "audience", "code_challenge", "S256", "redirect_uri"
71
+ "client_id", "scope", "code_challenge", "S256", "redirect_uri"
72
72
  )
73
73
  uuid2 = await auth_db.insert_authorization_flow(
74
74
  "client_id2",
75
75
  "scope2",
76
- "audience2",
77
76
  "code_challenge2",
78
77
  "S256",
79
78
  "redirect_uri2",
@@ -28,20 +28,22 @@ async def test_device_user_code_collision(auth_db: AuthDB, monkeypatch):
28
28
  # First insert should work
29
29
  async with auth_db as auth_db:
30
30
  code, device = await auth_db.insert_device_flow(
31
- "client_id", "scope", "audience"
31
+ "client_id",
32
+ "scope",
32
33
  )
33
34
  assert code == "A" * USER_CODE_LENGTH
34
35
  assert device
35
36
 
36
37
  async with auth_db as auth_db:
37
38
  with pytest.raises(NotImplementedError, match="insert new device flow"):
38
- await auth_db.insert_device_flow("client_id", "scope", "audience")
39
+ await auth_db.insert_device_flow("client_id", "scope")
39
40
 
40
41
  monkeypatch.setattr(secrets, "choice", lambda _: "B")
41
42
 
42
43
  async with auth_db as auth_db:
43
44
  code, device = await auth_db.insert_device_flow(
44
- "client_id", "scope", "audience"
45
+ "client_id",
46
+ "scope",
45
47
  )
46
48
  assert code == "B" * USER_CODE_LENGTH
47
49
  assert device
@@ -59,10 +61,12 @@ async def test_device_flow_lookup(auth_db: AuthDB, monkeypatch):
59
61
  # First insert
60
62
  async with auth_db as auth_db:
61
63
  user_code1, device_code1 = await auth_db.insert_device_flow(
62
- "client_id1", "scope1", "audience1"
64
+ "client_id1",
65
+ "scope1",
63
66
  )
64
67
  user_code2, device_code2 = await auth_db.insert_device_flow(
65
- "client_id2", "scope2", "audience2"
68
+ "client_id2",
69
+ "scope2",
66
70
  )
67
71
 
68
72
  assert user_code1 != user_code2
@@ -123,7 +127,8 @@ async def test_device_flow_insert_id_token(auth_db: AuthDB):
123
127
  # First insert
124
128
  async with auth_db as auth_db:
125
129
  user_code, device_code = await auth_db.insert_device_flow(
126
- "client_id", "scope", "audience"
130
+ "client_id",
131
+ "scope",
127
132
  )
128
133
 
129
134
  # Make sure it exists, and is Pending
@@ -0,0 +1,173 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import secrets
5
+ from datetime import datetime
6
+
7
+ import pytest
8
+ import sqlalchemy
9
+
10
+ from diracx.core.models import SandboxInfo, UserInfo
11
+ from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
12
+ from diracx.db.sql.sandbox_metadata.schema import sb_EntityMapping, sb_SandBoxes
13
+
14
+
15
+ @pytest.fixture
16
+ async def sandbox_metadata_db(tmp_path):
17
+ sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:")
18
+ async with sandbox_metadata_db.engine_context():
19
+ async with sandbox_metadata_db.engine.begin() as conn:
20
+ await conn.run_sync(sandbox_metadata_db.metadata.create_all)
21
+ yield sandbox_metadata_db
22
+
23
+
24
+ def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB):
25
+ user_info = UserInfo(
26
+ sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
27
+ )
28
+ sandbox_info = SandboxInfo(
29
+ checksum="checksum",
30
+ checksum_algorithm="sha256",
31
+ format="tar.bz2",
32
+ size=100,
33
+ )
34
+ pfn = sandbox_metadata_db.get_pfn("bucket1", user_info, sandbox_info)
35
+ assert pfn == "/S3/bucket1/vo/group1/user1/sha256:checksum.tar.bz2"
36
+
37
+
38
+ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
39
+ user_info = UserInfo(
40
+ sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
41
+ )
42
+ pfn1 = secrets.token_hex()
43
+
44
+ # Make sure the sandbox doesn't already exist
45
+ db_contents = await _dump_db(sandbox_metadata_db)
46
+ assert pfn1 not in db_contents
47
+ async with sandbox_metadata_db:
48
+ with pytest.raises(sqlalchemy.exc.NoResultFound):
49
+ await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE")
50
+
51
+ # Insert the sandbox
52
+ async with sandbox_metadata_db:
53
+ await sandbox_metadata_db.insert_sandbox("SandboxSE", user_info, pfn1, 100)
54
+ db_contents = await _dump_db(sandbox_metadata_db)
55
+ owner_id1, last_access_time1 = db_contents[pfn1]
56
+
57
+ # Inserting again should update the last access time
58
+ await asyncio.sleep(1) # The timestamp only has second precision
59
+ async with sandbox_metadata_db:
60
+ await sandbox_metadata_db.insert_sandbox("SandboxSE", user_info, pfn1, 100)
61
+ db_contents = await _dump_db(sandbox_metadata_db)
62
+ owner_id2, last_access_time2 = db_contents[pfn1]
63
+ assert owner_id1 == owner_id2
64
+ assert last_access_time2 > last_access_time1
65
+
66
+ # The sandbox still hasn't been assigned
67
+ async with sandbox_metadata_db:
68
+ assert not await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE")
69
+
70
+ # Inserting again should update the last access time
71
+ await asyncio.sleep(1) # The timestamp only has second precision
72
+ last_access_time3 = (await _dump_db(sandbox_metadata_db))[pfn1][1]
73
+ assert last_access_time2 == last_access_time3
74
+ async with sandbox_metadata_db:
75
+ await sandbox_metadata_db.update_sandbox_last_access_time("SandboxSE", pfn1)
76
+ last_access_time4 = (await _dump_db(sandbox_metadata_db))[pfn1][1]
77
+ assert last_access_time2 < last_access_time4
78
+
79
+
80
+ async def _dump_db(
81
+ sandbox_metadata_db: SandboxMetadataDB,
82
+ ) -> dict[str, tuple[int, datetime]]:
83
+ """Dump the contents of the sandbox metadata database
84
+
85
+ Returns a dict[pfn: str, (owner_id: int, last_access_time: datetime)]
86
+ """
87
+ async with sandbox_metadata_db:
88
+ stmt = sqlalchemy.select(
89
+ sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime
90
+ )
91
+ res = await sandbox_metadata_db.conn.execute(stmt)
92
+ return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res}
93
+
94
+
95
+ async def test_assign_and_unsassign_sandbox_to_jobs(
96
+ sandbox_metadata_db: SandboxMetadataDB,
97
+ ):
98
+ pfn = secrets.token_hex()
99
+ user_info = UserInfo(
100
+ sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
101
+ )
102
+ dummy_jobid = 666
103
+ sandbox_se = "SandboxSE"
104
+ # Insert the sandbox
105
+ async with sandbox_metadata_db:
106
+ await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100)
107
+
108
+ async with sandbox_metadata_db:
109
+ stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
110
+ res = await sandbox_metadata_db.conn.execute(stmt)
111
+ db_contents = {row.SEPFN: row.SBId for row in res}
112
+ sb_id_1 = db_contents[pfn]
113
+ # The sandbox still hasn't been assigned
114
+ async with sandbox_metadata_db:
115
+ assert not await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se)
116
+
117
+ # Check there is no mapping
118
+ async with sandbox_metadata_db:
119
+ stmt = sqlalchemy.select(
120
+ sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
121
+ )
122
+ res = await sandbox_metadata_db.conn.execute(stmt)
123
+ db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
124
+ assert db_contents == {}
125
+
126
+ # Assign sandbox with dummy jobid
127
+ async with sandbox_metadata_db:
128
+ await sandbox_metadata_db.assign_sandbox_to_jobs(
129
+ jobs_ids=[dummy_jobid], pfn=pfn, sb_type="Output", se_name=sandbox_se
130
+ )
131
+ # Check if sandbox and job are mapped
132
+ async with sandbox_metadata_db:
133
+ stmt = sqlalchemy.select(
134
+ sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
135
+ )
136
+ res = await sandbox_metadata_db.conn.execute(stmt)
137
+ db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
138
+
139
+ entity_id_1, sb_type = db_contents[sb_id_1]
140
+ assert entity_id_1 == f"Job:{dummy_jobid}"
141
+ assert sb_type == "Output"
142
+
143
+ async with sandbox_metadata_db:
144
+ stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
145
+ res = await sandbox_metadata_db.conn.execute(stmt)
146
+ db_contents = {row.SEPFN: row.SBId for row in res}
147
+ sb_id_1 = db_contents[pfn]
148
+ # The sandbox should be assigned
149
+ async with sandbox_metadata_db:
150
+ assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se)
151
+
152
+ # Unassign the sandbox to job
153
+ async with sandbox_metadata_db:
154
+ await sandbox_metadata_db.unassign_sandboxes_to_jobs([dummy_jobid])
155
+
156
+ # Entity should not exists anymore
157
+ async with sandbox_metadata_db:
158
+ stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
159
+ sb_EntityMapping.EntityId == entity_id_1
160
+ )
161
+ res = await sandbox_metadata_db.conn.execute(stmt)
162
+ entity_sb_id = [row.SBId for row in res]
163
+ assert entity_sb_id == []
164
+
165
+ # Should not be assigned anymore
166
+ async with sandbox_metadata_db:
167
+ assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se) is False
168
+ # Check the mapping has been deleted
169
+ async with sandbox_metadata_db:
170
+ stmt = sqlalchemy.select(sb_EntityMapping.SBId)
171
+ res = await sandbox_metadata_db.conn.execute(stmt)
172
+ res_sb_id = [row.SBId for row in res]
173
+ assert sb_id_1 not in res_sb_id
@@ -1,96 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import sqlalchemy
4
- from sqlalchemy import delete
5
-
6
- from diracx.core.models import SandboxInfo, UserInfo
7
- from diracx.db.sql.utils import BaseSQLDB, utcnow
8
-
9
- from .schema import Base as SandboxMetadataDBBase
10
- from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes
11
-
12
-
13
- class SandboxMetadataDB(BaseSQLDB):
14
- metadata = SandboxMetadataDBBase.metadata
15
-
16
- async def upsert_owner(self, user: UserInfo) -> int:
17
- """Get the id of the owner from the database"""
18
- # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
19
- stmt = sqlalchemy.select(sb_Owners.OwnerID).where(
20
- sb_Owners.Owner == user.preferred_username,
21
- sb_Owners.OwnerGroup == user.dirac_group,
22
- sb_Owners.VO == user.vo,
23
- )
24
- result = await self.conn.execute(stmt)
25
- if owner_id := result.scalar_one_or_none():
26
- return owner_id
27
-
28
- stmt = sqlalchemy.insert(sb_Owners).values(
29
- Owner=user.preferred_username,
30
- OwnerGroup=user.dirac_group,
31
- VO=user.vo,
32
- )
33
- result = await self.conn.execute(stmt)
34
- return result.lastrowid
35
-
36
- @staticmethod
37
- def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: SandboxInfo) -> str:
38
- """Get the sandbox's user namespaced and content addressed PFN"""
39
- parts = [
40
- "S3",
41
- bucket_name,
42
- user.vo,
43
- user.dirac_group,
44
- user.preferred_username,
45
- f"{sandbox_info.checksum_algorithm}:{sandbox_info.checksum}.{sandbox_info.format}",
46
- ]
47
- return "/" + "/".join(parts)
48
-
49
- async def insert_sandbox(
50
- self, se_name: str, user: UserInfo, pfn: str, size: int
51
- ) -> None:
52
- """Add a new sandbox in SandboxMetadataDB"""
53
- # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
54
- owner_id = await self.upsert_owner(user)
55
- stmt = sqlalchemy.insert(sb_SandBoxes).values(
56
- OwnerId=owner_id,
57
- SEName=se_name,
58
- SEPFN=pfn,
59
- Bytes=size,
60
- RegistrationTime=utcnow(),
61
- LastAccessTime=utcnow(),
62
- )
63
- try:
64
- result = await self.conn.execute(stmt)
65
- except sqlalchemy.exc.IntegrityError:
66
- await self.update_sandbox_last_access_time(se_name, pfn)
67
- else:
68
- assert result.rowcount == 1
69
-
70
- async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
71
- stmt = (
72
- sqlalchemy.update(sb_SandBoxes)
73
- .where(sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn)
74
- .values(LastAccessTime=utcnow())
75
- )
76
- result = await self.conn.execute(stmt)
77
- assert result.rowcount == 1
78
-
79
- async def sandbox_is_assigned(self, se_name: str, pfn: str) -> bool:
80
- """Checks if a sandbox exists and has been assigned."""
81
- stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
82
- sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
83
- )
84
- result = await self.conn.execute(stmt)
85
- is_assigned = result.scalar_one()
86
- return is_assigned
87
- return True
88
-
89
- async def unassign_sandbox_from_jobs(self, job_ids: list[int]):
90
- """
91
- Unassign sandbox from jobs
92
- """
93
- stmt = delete(sb_EntityMapping).where(
94
- sb_EntityMapping.EntityId.in_(f"Job:{job_id}" for job_id in job_ids)
95
- )
96
- await self.conn.execute(stmt)
@@ -1,92 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import secrets
5
- from datetime import datetime
6
-
7
- import pytest
8
- import sqlalchemy
9
-
10
- from diracx.core.models import SandboxInfo, UserInfo
11
- from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
12
- from diracx.db.sql.sandbox_metadata.schema import sb_SandBoxes
13
-
14
-
15
- @pytest.fixture
16
- async def sandbox_metadata_db(tmp_path):
17
- sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:")
18
- async with sandbox_metadata_db.engine_context():
19
- async with sandbox_metadata_db.engine.begin() as conn:
20
- await conn.run_sync(sandbox_metadata_db.metadata.create_all)
21
- yield sandbox_metadata_db
22
-
23
-
24
- def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB):
25
- user_info = UserInfo(
26
- sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
27
- )
28
- sandbox_info = SandboxInfo(
29
- checksum="checksum",
30
- checksum_algorithm="sha256",
31
- format="tar.bz2",
32
- size=100,
33
- )
34
- pfn = sandbox_metadata_db.get_pfn("bucket1", user_info, sandbox_info)
35
- assert pfn == "/S3/bucket1/vo/group1/user1/sha256:checksum.tar.bz2"
36
-
37
-
38
- async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
39
- user_info = UserInfo(
40
- sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
41
- )
42
- pfn1 = secrets.token_hex()
43
-
44
- # Make sure the sandbox doesn't already exist
45
- db_contents = await _dump_db(sandbox_metadata_db)
46
- assert pfn1 not in db_contents
47
- async with sandbox_metadata_db:
48
- with pytest.raises(sqlalchemy.exc.NoResultFound):
49
- await sandbox_metadata_db.sandbox_is_assigned("SandboxSE", pfn1)
50
-
51
- # Insert the sandbox
52
- async with sandbox_metadata_db:
53
- await sandbox_metadata_db.insert_sandbox("SandboxSE", user_info, pfn1, 100)
54
- db_contents = await _dump_db(sandbox_metadata_db)
55
- owner_id1, last_access_time1 = db_contents[pfn1]
56
-
57
- # Inserting again should update the last access time
58
- await asyncio.sleep(1) # The timestamp only has second precision
59
- async with sandbox_metadata_db:
60
- await sandbox_metadata_db.insert_sandbox("SandboxSE", user_info, pfn1, 100)
61
- db_contents = await _dump_db(sandbox_metadata_db)
62
- owner_id2, last_access_time2 = db_contents[pfn1]
63
- assert owner_id1 == owner_id2
64
- assert last_access_time2 > last_access_time1
65
-
66
- # The sandbox still hasn't been assigned
67
- async with sandbox_metadata_db:
68
- assert not await sandbox_metadata_db.sandbox_is_assigned("SandboxSE", pfn1)
69
-
70
- # Inserting again should update the last access time
71
- await asyncio.sleep(1) # The timestamp only has second precision
72
- last_access_time3 = (await _dump_db(sandbox_metadata_db))[pfn1][1]
73
- assert last_access_time2 == last_access_time3
74
- async with sandbox_metadata_db:
75
- await sandbox_metadata_db.update_sandbox_last_access_time("SandboxSE", pfn1)
76
- last_access_time4 = (await _dump_db(sandbox_metadata_db))[pfn1][1]
77
- assert last_access_time2 < last_access_time4
78
-
79
-
80
- async def _dump_db(
81
- sandbox_metadata_db: SandboxMetadataDB,
82
- ) -> dict[str, tuple[int, datetime]]:
83
- """Dump the contents of the sandbox metadata database
84
-
85
- Returns a dict[pfn: str, (owner_id: int, last_access_time: datetime)]
86
- """
87
- async with sandbox_metadata_db:
88
- stmt = sqlalchemy.select(
89
- sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime
90
- )
91
- res = await sandbox_metadata_db.conn.execute(stmt)
92
- return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res}
File without changes
File without changes