diracx-db 0.0.1a11__py3-none-any.whl → 0.0.1a13__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
diracx/db/sql/auth/db.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import hashlib
  import secrets
  from datetime import datetime
  from uuid import uuid4
@@ -63,7 +64,7 @@ class AuthDB(BaseSQLDB):
  ),
  ).with_for_update()
  stmt = stmt.where(
- DeviceFlows.device_code == device_code,
+ DeviceFlows.device_code == hashlib.sha256(device_code.encode()).hexdigest(),
  )
  res = dict((await self.conn.execute(stmt)).one()._mapping)

@@ -74,7 +75,10 @@ class AuthDB(BaseSQLDB):
  # Update the status to Done before returning
  await self.conn.execute(
  update(DeviceFlows)
- .where(DeviceFlows.device_code == device_code)
+ .where(
+ DeviceFlows.device_code
+ == hashlib.sha256(device_code.encode()).hexdigest()
+ )
  .values(status=FlowStatus.DONE)
  )
  return res
@@ -110,7 +114,6 @@ class AuthDB(BaseSQLDB):
  self,
  client_id: str,
  scope: str,
- audience: str,
  ) -> tuple[str, str]:
  # Because the user_code might be short, there is a risk of conflicts
  # This is why we retry multiple times
@@ -119,14 +122,16 @@ class AuthDB(BaseSQLDB):
  secrets.choice(USER_CODE_ALPHABET)
  for _ in range(DeviceFlows.user_code.type.length) # type: ignore
  )
- # user_code = "2QRKPY"
  device_code = secrets.token_urlsafe()
+
+ # Hash the the device_code to avoid leaking information
+ hashed_device_code = hashlib.sha256(device_code.encode()).hexdigest()
+
  stmt = insert(DeviceFlows).values(
  client_id=client_id,
  scope=scope,
- audience=audience,
  user_code=user_code,
- device_code=device_code,
+ device_code=hashed_device_code,
  )
  try:
  await self.conn.execute(stmt)
@@ -143,7 +148,6 @@ class AuthDB(BaseSQLDB):
  self,
  client_id: str,
  scope: str,
- audience: str,
  code_challenge: str,
  code_challenge_method: str,
  redirect_uri: str,
@@ -154,7 +158,6 @@ class AuthDB(BaseSQLDB):
  uuid=uuid,
  client_id=client_id,
  scope=scope,
- audience=audience,
  code_challenge=code_challenge,
  code_challenge_method=code_challenge_method,
  redirect_uri=redirect_uri,
@@ -172,7 +175,10 @@ class AuthDB(BaseSQLDB):
  :raises: AuthorizationError if no such uuid or status not pending
  """

+ # Hash the code to avoid leaking information
  code = secrets.token_urlsafe()
+ hashed_code = hashlib.sha256(code.encode()).hexdigest()
+
  stmt = update(AuthorizationFlows)

  stmt = stmt.where(
@@ -181,7 +187,7 @@ class AuthDB(BaseSQLDB):
  AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
  )

- stmt = stmt.values(id_token=id_token, code=code, status=FlowStatus.READY)
+ stmt = stmt.values(id_token=id_token, code=hashed_code, status=FlowStatus.READY)
  res = await self.conn.execute(stmt)

  if res.rowcount != 1:
@@ -190,15 +196,16 @@ class AuthDB(BaseSQLDB):
  stmt = select(AuthorizationFlows.code, AuthorizationFlows.redirect_uri)
  stmt = stmt.where(AuthorizationFlows.uuid == uuid)
  row = (await self.conn.execute(stmt)).one()
- return row.code, row.redirect_uri
+ return code, row.redirect_uri

  async def get_authorization_flow(self, code: str, max_validity: int):
+ hashed_code = hashlib.sha256(code.encode()).hexdigest()
  # The with_for_update
  # prevents that the token is retrieved
  # multiple time concurrently
  stmt = select(AuthorizationFlows).with_for_update()
  stmt = stmt.where(
- AuthorizationFlows.code == code,
+ AuthorizationFlows.code == hashed_code,
  AuthorizationFlows.creation_time > substract_date(seconds=max_validity),
  )

@@ -208,7 +215,7 @@ class AuthDB(BaseSQLDB):
  # Update the status to Done before returning
  await self.conn.execute(
  update(AuthorizationFlows)
- .where(AuthorizationFlows.code == code)
+ .where(AuthorizationFlows.code == hashed_code)
  .values(status=FlowStatus.DONE)
  )

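Net effect of this file's changes: the device code and the authorization code are no longer stored in clear text. Only their SHA-256 hex digests are written to the DeviceFlows and AuthorizationFlows tables, the plain value is handed to the caller once, and every later lookup hashes the presented value before comparing (which is also why the flow now returns the locally generated code rather than row.code). A minimal sketch of the pattern, using a plain dict in place of the real table; the names below are illustrative only, not part of diracx:

import hashlib
import secrets

flows: dict[str, str] = {}  # SHA-256 digest of the code -> flow status

def start_flow() -> str:
    code = secrets.token_urlsafe()  # plain value, returned to the client once
    digest = hashlib.sha256(code.encode()).hexdigest()
    flows[digest] = "PENDING"       # only the digest is persisted
    return code

def complete_flow(presented_code: str) -> None:
    # Look up by hashing the presented value; the plain text is never stored
    digest = hashlib.sha256(presented_code.encode()).hexdigest()
    flows[digest] = "DONE"

code = start_flow()
complete_flow(code)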
diracx/db/sql/auth/schema.py CHANGED
@@ -45,8 +45,7 @@ class DeviceFlows(Base):
  creation_time = DateNowColumn()
  client_id = Column(String(255))
  scope = Column(String(1024))
- audience = Column(String(255))
- device_code = Column(String(128), unique=True) # hash it ?
+ device_code = Column(String(128), unique=True) # Should be a hash
  id_token = NullColumn(JSON())


@@ -57,11 +56,10 @@ class AuthorizationFlows(Base):
  client_id = Column(String(255))
  creation_time = DateNowColumn()
  scope = Column(String(1024))
- audience = Column(String(255))
  code_challenge = Column(String(255))
  code_challenge_method = Column(String(8))
  redirect_uri = Column(String(255))
- code = NullColumn(String(255)) # hash it ?
+ code = NullColumn(String(255)) # Should be a hash
  id_token = NullColumn(JSON())

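A SHA-256 hex digest is always 64 characters, so the hashed values still fit in the existing device_code (String(128)) and code (String(255)) columns. A quick sanity check, as a sketch:

import hashlib

digest = hashlib.sha256(b"any device_code or authorization code").hexdigest()
assert len(digest) == 64  # well within String(128) / String(255)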
diracx/db/sql/jobs/status_utility.py CHANGED
@@ -272,7 +272,7 @@ async def remove_jobs(

  # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent
  # I think it should be done here as well
- await sandbox_metadata_db.unassign_sandbox_from_jobs(job_ids)
+ await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids)

  # Remove the job from TaskQueueDB
  await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task)
diracx/db/sql/sandbox_metadata/db.py CHANGED
@@ -1,9 +1,10 @@
  from __future__ import annotations

+ from typing import Any
+
  import sqlalchemy
- from sqlalchemy import delete

- from diracx.core.models import SandboxInfo, UserInfo
+ from diracx.core.models import SandboxInfo, SandboxType, UserInfo
  from diracx.db.sql.utils import BaseSQLDB, utcnow

  from .schema import Base as SandboxMetadataDBBase
@@ -76,7 +77,7 @@ class SandboxMetadataDB(BaseSQLDB):
  result = await self.conn.execute(stmt)
  assert result.rowcount == 1

- async def sandbox_is_assigned(self, se_name: str, pfn: str) -> bool:
+ async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
  """Checks if a sandbox exists and has been assigned."""
  stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
  sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
@@ -84,13 +85,85 @@ class SandboxMetadataDB(BaseSQLDB):
  result = await self.conn.execute(stmt)
  is_assigned = result.scalar_one()
  return is_assigned
- return True
-
- async def unassign_sandbox_from_jobs(self, job_ids: list[int]):
- """
- Unassign sandbox from jobs
- """
- stmt = delete(sb_EntityMapping).where(
- sb_EntityMapping.EntityId.in_(f"Job:{job_id}" for job_id in job_ids)
+
+ @staticmethod
+ def jobid_to_entity_id(job_id: int) -> str:
+ """Define the entity id as 'Entity:entity_id' due to the DB definition"""
+ return f"Job:{job_id}"
+
+ async def get_sandbox_assigned_to_job(
+ self, job_id: int, sb_type: SandboxType
+ ) -> list[Any]:
+ """Get the sandbox assign to job"""
+ entity_id = self.jobid_to_entity_id(job_id)
+ stmt = (
+ sqlalchemy.select(sb_SandBoxes.SEPFN)
+ .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId)
+ .where(
+ sb_EntityMapping.EntityId == entity_id,
+ sb_EntityMapping.Type == sb_type,
+ )
  )
- await self.conn.execute(stmt)
+ result = await self.conn.execute(stmt)
+ return [result.scalar()]
+
+ async def assign_sandbox_to_jobs(
+ self,
+ jobs_ids: list[int],
+ pfn: str,
+ sb_type: SandboxType,
+ se_name: str,
+ ) -> None:
+ """Mapp sandbox and jobs"""
+ for job_id in jobs_ids:
+ # Define the entity id as 'Entity:entity_id' due to the DB definition:
+ entity_id = self.jobid_to_entity_id(job_id)
+ select_sb_id = sqlalchemy.select(
+ sb_SandBoxes.SBId,
+ sqlalchemy.literal(entity_id).label("EntityId"),
+ sqlalchemy.literal(sb_type).label("Type"),
+ ).where(
+ sb_SandBoxes.SEName == se_name,
+ sb_SandBoxes.SEPFN == pfn,
+ )
+ stmt = sqlalchemy.insert(sb_EntityMapping).from_select(
+ ["SBId", "EntityId", "Type"], select_sb_id
+ )
+ await self.conn.execute(stmt)
+
+ stmt = (
+ sqlalchemy.update(sb_SandBoxes)
+ .where(sb_SandBoxes.SEPFN == pfn)
+ .values(Assigned=True)
+ )
+ result = await self.conn.execute(stmt)
+ assert result.rowcount == 1
+
+ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
+ """Delete mapping between jobs and sandboxes"""
+ for job_id in jobs_ids:
+ entity_id = self.jobid_to_entity_id(job_id)
+ sb_sel_stmt = sqlalchemy.select(
+ sb_SandBoxes.SBId,
+ ).where(sb_EntityMapping.EntityId == entity_id)
+
+ result = await self.conn.execute(sb_sel_stmt)
+ sb_ids = [row.SBId for row in result]
+
+ del_stmt = sqlalchemy.delete(sb_EntityMapping).where(
+ sb_EntityMapping.EntityId == entity_id
+ )
+ await self.conn.execute(del_stmt)
+
+ sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
+ sb_EntityMapping.SBId.in_(sb_ids)
+ )
+ result = await self.conn.execute(sb_entity_sel_stmt)
+ remaining_sb_ids = [row.SBId for row in result]
+ if not remaining_sb_ids:
+ unassign_stmt = (
+ sqlalchemy.update(sb_SandBoxes)
+ .where(sb_SandBoxes.SBId.in_(sb_ids))
+ .values(Assigned=False)
+ )
+ await self.conn.execute(unassign_stmt)
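The new helpers key the sb_EntityMapping table on an entity id of the form "Job:<job_id>" (built by jobid_to_entity_id), and assign_sandbox_to_jobs writes that mapping with an INSERT ... FROM SELECT so the sandbox's SBId never has to be read back into Python first. A minimal sketch of the same SQLAlchemy Core pattern against simplified stand-in tables (not the actual diracx schema):

import sqlalchemy as sa

metadata = sa.MetaData()

# Simplified stand-ins for sb_SandBoxes and sb_EntityMapping
sandboxes = sa.Table(
    "sandboxes",
    metadata,
    sa.Column("SBId", sa.Integer, primary_key=True),
    sa.Column("SEName", sa.String(64)),
    sa.Column("SEPFN", sa.String(256)),
    sa.Column("Assigned", sa.Boolean, default=False),
)
entity_mapping = sa.Table(
    "entity_mapping",
    metadata,
    sa.Column("SBId", sa.Integer),
    sa.Column("EntityId", sa.String(128)),
    sa.Column("Type", sa.String(64)),
)

def jobid_to_entity_id(job_id: int) -> str:
    # Same convention as in the diff: the mapping key is "Job:<job_id>"
    return f"Job:{job_id}"

def assign_stmt(job_id: int, se_name: str, pfn: str, sb_type: str):
    # SELECT the sandbox id by (SEName, SEPFN) together with the literal
    # EntityId/Type values, and INSERT the mapping row straight from that SELECT.
    select_sb_id = sa.select(
        sandboxes.c.SBId,
        sa.literal(jobid_to_entity_id(job_id)).label("EntityId"),
        sa.literal(sb_type).label("Type"),
    ).where(sandboxes.c.SEName == se_name, sandboxes.c.SEPFN == pfn)
    return sa.insert(entity_mapping).from_select(
        ["SBId", "EntityId", "Type"], select_sb_id
    )

unassign_sandboxes_to_jobs reverses this: it deletes the mapping rows for each job and, once no mappings remain for the affected sandboxes, marks them as unassigned again.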
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: diracx-db
- Version: 0.0.1a11
+ Version: 0.0.1a13
  Summary: TODO
  License: GPL-3.0-only
  Classifier: Intended Audience :: Science/Research
@@ -8,20 +8,20 @@ diracx/db/os/utils.py,sha256=mau0_2uRi-I3geefmKQRWFKo4JcIkIUADvnwBiQX700,9129
  diracx/db/sql/__init__.py,sha256=R6tk5lo1EHbt8joGDesesYHcc1swIq9T4AaSixhh7lA,252
  diracx/db/sql/utils.py,sha256=BuXjIuXN-_v8YkCoMoMhw2tHVUqG6lTBx-e4VEYWE8o,7857
  diracx/db/sql/auth/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- diracx/db/sql/auth/db.py,sha256=7GAdwD7g2YR86yQkP9ecuIbFd1h_wxVKnF_GfXfZqLA,9915
- diracx/db/sql/auth/schema.py,sha256=wutCjZ_uz21J0HHZjwoOXq3cLdlNY2lCR390yIJ_T60,2891
+ diracx/db/sql/auth/db.py,sha256=mKjy5B8orw0yu6nOwxyzbBqyeE-J9iYq6fKjuELmr9g,10273
+ diracx/db/sql/auth/schema.py,sha256=JCkSa2IRzqMHTpaSc9aB9h33XsFyEM_Ohsenex6xagY,2835
  diracx/db/sql/dummy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diracx/db/sql/dummy/db.py,sha256=5PIPv6aKY7CGIwmvnGKowjVr9ZQWpbjFSd2PIX7YOUw,1627
  diracx/db/sql/dummy/schema.py,sha256=uEkGDNVZbmJecytkHY1CO-M1MiKxe5w1_h0joJMPC9E,680
  diracx/db/sql/jobs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  diracx/db/sql/jobs/db.py,sha256=Y_2mx5kPTeuz6nxXVwGLzTssKsIH6nfnoTvWvilSgxA,29876
  diracx/db/sql/jobs/schema.py,sha256=YkxIdjTkvLlEZ9IQt86nj80eMvOPbcrfk9aisjmNpqY,9275
- diracx/db/sql/jobs/status_utility.py,sha256=0kAt623nh1O5wgsgktctdCmHEynO1nU0vn-7zakNeOA,10525
+ diracx/db/sql/jobs/status_utility.py,sha256=_3Wdd11ShA4Z6HKr0_D_o8-zPZhdzgFpZSYAyYkH4Q0,10525
  diracx/db/sql/sandbox_metadata/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- diracx/db/sql/sandbox_metadata/db.py,sha256=HjlbnsT4cRMuFAcTL_sK3IqCehA7zISzR_d7xIGZoNk,3498
+ diracx/db/sql/sandbox_metadata/db.py,sha256=0EDFMfOW_O3pEPTShqBCME9z4j-JKpyYM6-BBccr27E,6303
  diracx/db/sql/sandbox_metadata/schema.py,sha256=rngYYkJxBhjETBHGLD1CTipDGe44mRYR0wdaFoAJwp0,1400
- diracx_db-0.0.1a11.dist-info/METADATA,sha256=xRS26odR-83dZ1cj6lfGBEuzufrq59vKglTFvVFvVd4,681
- diracx_db-0.0.1a11.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- diracx_db-0.0.1a11.dist-info/entry_points.txt,sha256=xEFGu_zgmPgQPlUeFtdahQfQIboJ1ugFOK8eMio9gtw,271
- diracx_db-0.0.1a11.dist-info/top_level.txt,sha256=vJx10tdRlBX3rF2Psgk5jlwVGZNcL3m_7iQWwgPXt-U,7
- diracx_db-0.0.1a11.dist-info/RECORD,,
+ diracx_db-0.0.1a13.dist-info/METADATA,sha256=jmbXQvJykcvn3vGnxvO8GUGP3D3yjL-cXZwqXXJkzP4,681
+ diracx_db-0.0.1a13.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ diracx_db-0.0.1a13.dist-info/entry_points.txt,sha256=xEFGu_zgmPgQPlUeFtdahQfQIboJ1ugFOK8eMio9gtw,271
+ diracx_db-0.0.1a13.dist-info/top_level.txt,sha256=vJx10tdRlBX3rF2Psgk5jlwVGZNcL3m_7iQWwgPXt-U,7
+ diracx_db-0.0.1a13.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.42.0)
+ Generator: bdist_wheel (0.43.0)
  Root-Is-Purelib: true
  Tag: py3-none-any