cornflow 2.0.0a10__py3-none-any.whl → 2.0.0a12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. airflow_config/airflow_local_settings.py +1 -1
  2. cornflow/app.py +8 -3
  3. cornflow/cli/migrations.py +23 -3
  4. cornflow/cli/service.py +18 -18
  5. cornflow/cli/utils.py +16 -1
  6. cornflow/commands/dag.py +1 -1
  7. cornflow/config.py +13 -8
  8. cornflow/endpoints/__init__.py +8 -2
  9. cornflow/endpoints/alarms.py +66 -2
  10. cornflow/endpoints/data_check.py +53 -26
  11. cornflow/endpoints/execution.py +387 -132
  12. cornflow/endpoints/login.py +81 -63
  13. cornflow/endpoints/meta_resource.py +11 -3
  14. cornflow/migrations/versions/999b98e24225.py +34 -0
  15. cornflow/models/base_data_model.py +4 -32
  16. cornflow/models/execution.py +2 -3
  17. cornflow/models/meta_models.py +28 -22
  18. cornflow/models/user.py +7 -10
  19. cornflow/schemas/alarms.py +8 -0
  20. cornflow/schemas/execution.py +1 -1
  21. cornflow/schemas/query.py +2 -1
  22. cornflow/schemas/user.py +5 -20
  23. cornflow/shared/authentication/auth.py +201 -264
  24. cornflow/shared/const.py +3 -14
  25. cornflow/shared/databricks.py +5 -1
  26. cornflow/tests/const.py +1 -0
  27. cornflow/tests/custom_test_case.py +77 -26
  28. cornflow/tests/unit/test_actions.py +2 -2
  29. cornflow/tests/unit/test_alarms.py +55 -1
  30. cornflow/tests/unit/test_apiview.py +108 -3
  31. cornflow/tests/unit/test_cases.py +20 -29
  32. cornflow/tests/unit/test_cli.py +6 -5
  33. cornflow/tests/unit/test_commands.py +3 -3
  34. cornflow/tests/unit/test_dags.py +5 -6
  35. cornflow/tests/unit/test_executions.py +443 -123
  36. cornflow/tests/unit/test_instances.py +14 -2
  37. cornflow/tests/unit/test_instances_file.py +1 -1
  38. cornflow/tests/unit/test_licenses.py +1 -1
  39. cornflow/tests/unit/test_log_in.py +230 -207
  40. cornflow/tests/unit/test_permissions.py +8 -8
  41. cornflow/tests/unit/test_roles.py +48 -10
  42. cornflow/tests/unit/test_schemas.py +1 -1
  43. cornflow/tests/unit/test_tables.py +7 -7
  44. cornflow/tests/unit/test_token.py +19 -5
  45. cornflow/tests/unit/test_users.py +22 -6
  46. cornflow/tests/unit/tools.py +75 -10
  47. {cornflow-2.0.0a10.dist-info → cornflow-2.0.0a12.dist-info}/METADATA +16 -15
  48. {cornflow-2.0.0a10.dist-info → cornflow-2.0.0a12.dist-info}/RECORD +51 -51
  49. {cornflow-2.0.0a10.dist-info → cornflow-2.0.0a12.dist-info}/WHEEL +1 -1
  50. cornflow/endpoints/execution_databricks.py +0 -808
  51. {cornflow-2.0.0a10.dist-info → cornflow-2.0.0a12.dist-info}/entry_points.txt +0 -0
  52. {cornflow-2.0.0a10.dist-info → cornflow-2.0.0a12.dist-info}/top_level.txt +0 -0
cornflow/schemas/user.py CHANGED
@@ -27,7 +27,7 @@ class UserEndpointResponse(Schema):
27
27
  last_name = fields.Str()
28
28
  email = fields.Str()
29
29
  created_at = fields.Str()
30
- pwd_last_change = fields.Str()
30
+ pwd_last_change = fields.DateTime()
31
31
 
32
32
 
33
33
  class UserDetailsEndpointResponse(Schema):
@@ -36,7 +36,7 @@ class UserDetailsEndpointResponse(Schema):
36
36
  last_name = fields.Str()
37
37
  username = fields.Str()
38
38
  email = fields.Str()
39
- pwd_last_change = fields.Str()
39
+ pwd_last_change = fields.DateTime()
40
40
 
41
41
 
42
42
  class TokenEndpointResponse(Schema):
@@ -53,7 +53,6 @@ class UserEditRequest(Schema):
53
53
  last_name = fields.Str(required=False)
54
54
  email = fields.Str(required=False)
55
55
  password = fields.Str(required=False)
56
- pwd_last_change = fields.DateTime(required=False)
57
56
 
58
57
 
59
58
  class LoginEndpointRequest(Schema):
@@ -67,24 +66,10 @@ class LoginEndpointRequest(Schema):
67
66
 
68
67
  class LoginOpenAuthRequest(Schema):
69
68
  """
70
- This is the schema used by the login endpoint with Open ID protocol
71
- Validates that either a token is provided, or both username and password are present
69
+ Schema for the login request with OpenID authentication
72
70
  """
73
-
74
- token = fields.Str(required=False)
75
- username = fields.Str(required=False)
76
- password = fields.Str(required=False)
77
-
78
- @validates_schema
79
- def validate_fields(self, data, **kwargs):
80
- if data.get("token") is None:
81
- if not data.get("username") or not data.get("password"):
82
- raise ValidationError(
83
- "A token needs to be provided when using Open ID authentication"
84
- )
85
- else:
86
- if data.get("username") or data.get("password"):
87
- raise ValidationError("The login needs to be done with a token only")
71
+ username = fields.String(required=False)
72
+ password = fields.String(required=False)
88
73
 
89
74
 
90
75
  class SignupRequest(Schema):
@@ -3,19 +3,14 @@ This file contains the auth class that can be used for authentication on the req
3
3
  """
4
4
 
5
5
  # Imports from external libraries
6
- import base64
7
6
  import jwt
8
7
  import requests
9
-
10
- from cryptography.hazmat.backends import default_backend
11
- from cryptography.hazmat.primitives import serialization
12
- from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
13
- from datetime import datetime, timedelta
8
+ from jwt.algorithms import RSAAlgorithm
9
+ from datetime import datetime, timedelta, timezone
14
10
  from flask import request, g, current_app, Request
15
11
  from functools import wraps
16
- from typing import Union, Tuple
17
-
18
- from jwt import DecodeError
12
+ from typing import Tuple
13
+ from cachetools import TTLCache
19
14
  from werkzeug.datastructures import Headers
20
15
 
21
16
  # Imports from internal modules
@@ -26,15 +21,12 @@ from cornflow.models import (
26
21
  ViewModel,
27
22
  )
28
23
  from cornflow.shared.const import (
29
- OID_AZURE,
30
- OID_AZURE_DISCOVERY_TENANT_URL,
31
- OID_AZURE_DISCOVERY_COMMON_URL,
32
- OID_GOOGLE,
24
+ AUTH_OID,
33
25
  PERMISSION_METHOD_MAP,
26
+ INTERNAL_TOKEN_ISSUER,
34
27
  )
35
28
  from cornflow.shared.exceptions import (
36
29
  CommunicationError,
37
- EndpointNotImplemented,
38
30
  InvalidCredentials,
39
31
  InvalidData,
40
32
  InvalidUsage,
@@ -42,6 +34,9 @@ from cornflow.shared.exceptions import (
42
34
  ObjectDoesNotExist,
43
35
  )
44
36
 
37
+ # Cache for storing public keys with 1 hour TTL
38
+ public_keys_cache = TTLCache(maxsize=10, ttl=3600)
39
+
45
40
 
46
41
  class Auth:
47
42
  def __init__(self, user_model=UserModel):
@@ -92,9 +87,9 @@ class Auth:
92
87
  @staticmethod
93
88
  def generate_token(user_id: int = None) -> str:
94
89
  """
95
- Generates a token given a user_id with a duration of one day
90
+ Generates a token given a user_id. The token will contain the username in the sub claim.
96
91
 
97
- :param int user_id: user code to be encoded in the token to identify the user afterwards
92
+ :param int user_id: user id to generate the token for
98
93
  :return: the generated token
99
94
  :rtype: str
100
95
  """
@@ -104,11 +99,19 @@ class Auth:
104
99
  err, log_txt="Error while trying to generate token. " + err
105
100
  )
106
101
 
102
+ user = UserModel.get_one_user(user_id)
103
+ if user is None:
104
+ err = "User does not exist"
105
+ raise InvalidUsage(
106
+ err, log_txt="Error while trying to generate token. " + err
107
+ )
108
+
107
109
  payload = {
108
- "exp": datetime.utcnow()
110
+ "exp": datetime.now(timezone.utc)
109
111
  + timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
110
- "iat": datetime.utcnow(),
111
- "sub": user_id,
112
+ "iat": datetime.now(timezone.utc),
113
+ "sub": user.username,
114
+ "iss": INTERNAL_TOKEN_ISSUER,
112
115
  }
113
116
 
114
117
  return jwt.encode(
@@ -118,73 +121,68 @@ class Auth:
118
121
  @staticmethod
119
122
  def decode_token(token: str = None) -> dict:
120
123
  """
121
- Decodes a given JSON Web token and extracts the sub from it to give it back.
124
+ Decodes a given JSON Web token and extracts the username from the sub claim.
125
+ Works with both internal tokens and OpenID tokens.
122
126
 
123
127
  :param str token: the given JSON Web Token
124
- :return: the sub field of the token as the user_id
128
+ :return: dictionary containing the username from the token's sub claim
125
129
  :rtype: dict
126
130
  """
131
+
127
132
  if token is None:
128
- err = "The provided token is not valid."
129
- raise InvalidUsage(
130
- err, log_txt="Error while trying to decode token. " + err
131
- )
132
- try:
133
- payload = jwt.decode(
134
- token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
135
- )
136
- return {"user_id": payload["sub"]}
137
- except jwt.ExpiredSignatureError:
138
- raise InvalidCredentials(
139
- "The token has expired, please login again",
140
- log_txt="Error while trying to decode token. The token has expired.",
141
- )
142
- except jwt.InvalidTokenError:
133
+
143
134
  raise InvalidCredentials(
144
- "Invalid token, please try again with a new token",
145
- log_txt="Error while trying to decode token. The token is invalid.",
135
+ "Must provide a token in Authorization header",
136
+ log_txt="Error while trying to decode token. Token is missing.",
137
+ status_code=400,
146
138
  )
147
139
 
148
- def validate_oid_token(
149
- self, token: str, client_id: str, tenant_id: str, issuer: str, provider: int
150
- ) -> dict:
151
- """
152
- This method takes a token issued by an OID provider, the relevant information about the OID provider
153
- and validates that the token was generated by such source, is valid and extracts the information
154
- in the token for its use during the login process
155
-
156
- :param str token: the received token
157
- :param str client_id: the identifier from the client
158
- :param str tenant_id: the identifier for the tenant
159
- :param str issuer: the identifier for the issuer of the token
160
- :param int provider: the identifier for the provider of the token
161
- :return: the decoded token as a dictionary
162
- :rtype: dict
163
- """
164
- public_key = self._get_public_key(token, tenant_id, provider)
165
140
  try:
166
- decoded = jwt.decode(
167
- token,
168
- public_key,
169
- verify=True,
170
- algorithms=["RS256"],
171
- audience=[client_id],
172
- issuer=issuer,
141
+ # First try to decode header to validate basic token structure
142
+
143
+ unverified_payload = jwt.decode(token, options={"verify_signature": False})
144
+ issuer = unverified_payload.get("iss")
145
+
146
+ # For internal tokens
147
+ if issuer == INTERNAL_TOKEN_ISSUER:
148
+
149
+ return jwt.decode(
150
+ token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
151
+ )
152
+
153
+ # For OpenID tokens
154
+ if current_app.config["AUTH_TYPE"] == AUTH_OID:
155
+
156
+ return Auth().verify_token(
157
+ token,
158
+ current_app.config["OID_PROVIDER"],
159
+ current_app.config["OID_EXPECTED_AUDIENCE"],
160
+ )
161
+
162
+ # If we get here, the issuer is not valid
163
+
164
+ raise InvalidCredentials(
165
+ "Invalid token issuer. Token must be issued by a valid provider",
166
+ log_txt="Error while trying to decode token. Invalid issuer.",
167
+ status_code=400,
173
168
  )
174
- return decoded
169
+
175
170
  except jwt.ExpiredSignatureError:
171
+
176
172
  raise InvalidCredentials(
177
173
  "The token has expired, please login again",
178
- log_txt="Error while trying to validate a token. The token has expired.",
174
+ log_txt="Error while trying to decode token. The token has expired.",
175
+ status_code=400,
179
176
  )
180
- except jwt.InvalidTokenError:
177
+ except jwt.InvalidTokenError as e:
178
+
181
179
  raise InvalidCredentials(
182
- "Invalid token, please try again with a new token",
183
- log_txt="Error while trying to validate a token. The token is not valid.",
180
+ "Invalid token format or signature",
181
+ log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
182
+ status_code=400,
184
183
  )
185
184
 
186
- @staticmethod
187
- def get_token_from_header(headers: Headers = None) -> str:
185
+ def get_token_from_header(self, headers: Headers = None) -> str:
188
186
  """
189
187
  Extracts the token given on the request from the Authorization headers.
190
188
 
@@ -195,65 +193,157 @@ class Auth:
195
193
  """
196
194
  if headers is None:
197
195
  raise InvalidUsage(
198
- log_txt="Error while trying to get a token from header. The header is invalid."
196
+ "Request headers are missing",
197
+ log_txt="Error while trying to get a token from header. The header is invalid.",
198
+ status_code=400,
199
199
  )
200
200
 
201
201
  if "Authorization" not in headers:
202
202
  raise InvalidCredentials(
203
- "Auth token is not available",
203
+ "Authorization header is missing",
204
204
  log_txt="Error while trying to get a token from header. The auth token is not available.",
205
+ status_code=400,
205
206
  )
207
+
206
208
  auth_header = headers.get("Authorization")
209
+
207
210
  if not auth_header:
208
211
  return ""
212
+
213
+ if not auth_header.startswith("Bearer "):
214
+ err = "Invalid Authorization header format. Must be 'Bearer <token>'"
215
+ raise InvalidCredentials(
216
+ err,
217
+ log_txt=f"Error while trying to get a token from header. " + err,
218
+ status_code=400,
219
+ )
220
+
209
221
  try:
210
- return auth_header.split(" ")[1]
222
+ token = auth_header.split(" ")[1]
223
+ return token
211
224
  except Exception as e:
212
- err = f"The authorization header has a bad syntax: {e}"
225
+ err = "Invalid Authorization header format. Must be 'Bearer <token>'"
213
226
  raise InvalidCredentials(
214
- err, log_txt=f"Error while trying to get a token from header. " + err
227
+ err,
228
+ log_txt=f"Error while trying to get a token from header. " + err,
229
+ status_code=400,
215
230
  )
216
231
 
217
232
  def get_user_from_header(self, headers: Headers = None) -> UserModel:
218
233
  """
219
- Gets the user represented by the token that has to be in the request headers.
234
+ Extracts the user from the Authorization headers.
220
235
 
221
236
  :param headers: the request headers
222
237
  :type headers: `Headers`
223
238
  :return: the user object
224
- :rtype: `UserBaseModel`
239
+ :rtype: :class:`UserModel`
225
240
  """
226
241
  if headers is None:
227
- err = "Headers are missing from the request. Authentication was not possible to perform."
242
+ err = "Request headers are missing"
228
243
  raise InvalidUsage(
229
- err, log_txt="Error while trying to get user from header. " + err
244
+ err,
245
+ log_txt="Error while trying to get user from header. " + err,
246
+ status_code=400,
230
247
  )
231
248
  token = self.get_token_from_header(headers)
232
249
  data = self.decode_token(token)
233
- user_id = data["user_id"]
234
- user = self.user_model.get_one_user(user_id)
250
+
251
+ user = self.user_model.get_one_object(username=data["sub"])
252
+
235
253
  if user is None:
236
- err = "User does not exist, invalid token."
237
- raise ObjectDoesNotExist(
238
- err, log_txt="Error while trying to get user from header. " + err
254
+ err = "User not found. Please ensure you are using valid credentials"
255
+ raise InvalidCredentials(
256
+ err,
257
+ log_txt="Error while trying to get user from header. User does not exist.",
258
+ status_code=400,
239
259
  )
240
260
  return user
241
261
 
242
262
  @staticmethod
243
- def return_user_from_token(token):
263
+ def get_public_keys(provider_url: str) -> dict:
264
+ """
265
+ Gets the public keys from the OIDC provider and caches them
266
+
267
+ :param str provider_url: The base URL of the OIDC provider
268
+ :return: Dictionary of kid to public key mappings
269
+ :rtype: dict
244
270
  """
245
- Function used for internal testing. Given a token gives back the user_id encoded in it.
271
+ # Fetch keys from provider
272
+ jwks_url = f"{provider_url.rstrip('/')}/.well-known/jwks.json"
273
+ try:
274
+ response = requests.get(jwks_url)
275
+ response.raise_for_status()
276
+
277
+ # Convert JWK to RSA public keys using PyJWT's built-in method
278
+ public_keys = {
279
+ key["kid"]: RSAAlgorithm.from_jwk(key)
280
+ for key in response.json()["keys"]
281
+ }
282
+
283
+ # Store in cache
284
+ public_keys_cache[provider_url] = public_keys
285
+ return public_keys
246
286
 
247
- :param str token: the given token
248
- :return: the user id code.
249
- :rtype: int
287
+ except requests.exceptions.RequestException as e:
288
+ raise CommunicationError(
289
+ "Failed to fetch public keys from authentication provider",
290
+ log_txt=f"Error while fetching public keys from {jwks_url}: {str(e)}",
291
+ status_code=400,
292
+ )
293
+
294
+ def verify_token(
295
+ self, token: str, provider_url: str, expected_audience: str
296
+ ) -> dict:
250
297
  """
251
- user_id = Auth.decode_token(token)["user_id"]
252
- return user_id
298
+ Verifies an OpenID Connect token
299
+
300
+ :param str token: The token to verify
301
+ :param str provider_url: The base URL of the OIDC provider
302
+ :param str expected_audience: The expected audience claim
303
+ :return: The decoded token claims
304
+ :rtype: dict
305
+ """
306
+
307
+ # Get unverified header - this will raise jwt.InvalidTokenError if token format is invalid
308
+ unverified_header = jwt.get_unverified_header(token)
309
+
310
+ # Check for kid in header
311
+ if "kid" not in unverified_header:
312
+
313
+ raise InvalidCredentials(
314
+ "Invalid token: Missing key identifier (kid) in token header",
315
+ log_txt="Error while verifying token. Token header is missing 'kid'.",
316
+ status_code=400,
317
+ )
253
318
 
254
- """
255
- START OF INTERNAL PROTECTED METHODS
256
- """
319
+ kid = unverified_header["kid"]
320
+
321
+ # Check if we have the keys in cache and if the kid exists
322
+ public_key = None
323
+ if provider_url in public_keys_cache:
324
+ cached_keys = public_keys_cache[provider_url]
325
+ if kid in cached_keys:
326
+ public_key = cached_keys[kid]
327
+
328
+ # If kid not in cache, fetch fresh keys
329
+ if public_key is None:
330
+ public_keys = self.get_public_keys(provider_url)
331
+ if kid not in public_keys:
332
+ raise InvalidCredentials(
333
+ "Invalid token: Unknown key identifier (kid)",
334
+ log_txt="Error while verifying token. Key ID not found in public keys.",
335
+ status_code=400,
336
+ )
337
+ public_key = public_keys[kid]
338
+
339
+ # Verify token - this will raise appropriate jwt exceptions that will be caught in decode_token
340
+ return jwt.decode(
341
+ token,
342
+ public_key,
343
+ algorithms=["RS256"],
344
+ audience=[expected_audience],
345
+ issuer=provider_url,
346
+ )
257
347
 
258
348
  @staticmethod
259
349
  def _get_permission_for_request(req, user_id):
@@ -308,164 +398,6 @@ class Auth:
308
398
  """
309
399
  return getattr(req, "environ")["REQUEST_METHOD"], getattr(req, "url_rule").rule
310
400
 
311
- @staticmethod
312
- def _get_key_id(token: str) -> str:
313
- """
314
- Function to get the Key ID from the token
315
-
316
- :param str token: the given token
317
- :return: the key identifier
318
- :rtype: str
319
- """
320
- try:
321
- headers = jwt.get_unverified_header(token)
322
- except DecodeError as err:
323
- raise InvalidCredentials("Token is not valid")
324
- if not headers:
325
- raise InvalidCredentials("Token is missing the headers")
326
- try:
327
- return headers["kid"]
328
- except KeyError:
329
- raise InvalidCredentials("Token is missing the key identifier")
330
-
331
- @staticmethod
332
- def _fetch_discovery_meta(tenant_id: str, provider: int) -> dict:
333
- """
334
- Function to return a dictionary with the discovery URL of the provider
335
-
336
- :param str tenant_id: the tenant id
337
- :param int provider: the provider information
338
- :return: the different urls to be discovered on the provider
339
- :rtype: dict
340
- """
341
- if provider == OID_AZURE:
342
- oid_tenant_url = OID_AZURE_DISCOVERY_TENANT_URL
343
- oid_common_url = OID_AZURE_DISCOVERY_COMMON_URL
344
- elif provider == OID_GOOGLE:
345
- raise EndpointNotImplemented("The OID provider configuration is not valid")
346
- else:
347
- raise EndpointNotImplemented("The OID provider configuration is not valid")
348
-
349
- discovery_url = (
350
- oid_tenant_url.format(tenant_id=tenant_id) if tenant_id else oid_common_url
351
- )
352
- try:
353
- response = requests.get(discovery_url)
354
- response.raise_for_status()
355
- except requests.exceptions.HTTPError:
356
- raise CommunicationError(
357
- f"Error getting issuer discovery meta from {discovery_url}"
358
- )
359
- return response.json()
360
-
361
- def _get_json_web_keys_uri(self, tenant_id: str, provider: int) -> str:
362
- """
363
- Returns the JSON Web Keys URI
364
-
365
- :param str tenant_id: the tenant id
366
- :param int provider: the provider information
367
- :return: the URI from where to get the JSON Web Keys
368
- :rtype: str
369
- """
370
- meta = self._fetch_discovery_meta(tenant_id, provider)
371
- if "jwks_uri" in meta:
372
- return meta["jwks_uri"]
373
- else:
374
- raise CommunicationError("jwks_uri not found in the issuer meta")
375
-
376
- def _get_json_web_keys(self, tenant_id: str, provider: int) -> dict:
377
- """
378
- Function to get the json web keys from the tenant id and the provider
379
-
380
- :param str tenant_id: the tenant id
381
- :param int provider: the provider information
382
- :return: the JSON Web Keys dict
383
- :rtype: dict
384
- """
385
- json_web_keys_uri = self._get_json_web_keys_uri(tenant_id, provider)
386
- try:
387
- response = requests.get(json_web_keys_uri)
388
- response.raise_for_status()
389
- except requests.exceptions.HTTPError as error:
390
- raise CommunicationError(
391
- f"Error getting issuer jwks from {json_web_keys_uri}", error
392
- )
393
- return response.json()
394
-
395
- def _get_jwk(self, kid: str, tenant_id: str, provider: int) -> dict:
396
- """
397
- Function to get the JSON Web Key from the key identifier, the tenant id and the provider information
398
-
399
- :param str kid: the key identifier
400
- :param str tenant_id: the tenant information
401
- :param int provider: the provider information
402
- :return: the JSON Web Key
403
- :rtype: dict
404
- """
405
- for jwk in self._get_json_web_keys(tenant_id, provider).get("keys"):
406
- if jwk.get("kid") == kid:
407
- return jwk
408
- raise InvalidCredentials("Token has an unknown key identifier")
409
-
410
- @staticmethod
411
- def _ensure_bytes(key: Union[str, bytes]) -> bytes:
412
- """
413
- Function that ensures that the key is in bytes format
414
-
415
- :param str | bytes key:
416
- :return: the key on bytes format
417
- :rtype: bytes
418
- """
419
- if isinstance(key, str):
420
- key = key.encode("utf-8")
421
- return key
422
-
423
- def _decode_value(self, val: Union[str, bytes]) -> int:
424
- """
425
- Function that ensures that the value is decoded as a big int
426
-
427
- :param str | bytes val: the value that has to be decoded
428
- :return: the decoded value as a big int
429
- :rtype: int
430
- """
431
- decoded = base64.urlsafe_b64decode(self._ensure_bytes(val) + b"==")
432
- return int.from_bytes(decoded, "big")
433
-
434
- def _rsa_pem_from_jwk(self, jwk: dict) -> bytes:
435
- """
436
- Returns the private key from the JSON Web Key encoded as PEM
437
-
438
- :param dict jwk: the JSON Web Key
439
- :return: the RSA PEM key serialized as bytes
440
- :rtype: bytes
441
- """
442
- return (
443
- RSAPublicNumbers(
444
- n=self._decode_value(jwk["n"]),
445
- e=self._decode_value(jwk["e"]),
446
- )
447
- .public_key(default_backend())
448
- .public_bytes(
449
- encoding=serialization.Encoding.PEM,
450
- format=serialization.PublicFormat.SubjectPublicKeyInfo,
451
- )
452
- )
453
-
454
- def _get_public_key(self, token: str, tenant_id: str, provider: int):
455
- """
456
- This method returns the public key from the given token, ensuring that
457
- the tenant information and provider are correct
458
-
459
- :param str token: the given token
460
- :param str tenant_id: the tenant information
461
- :param int provider: the token provider information
462
- :return: the public key in the token or it raises an error
463
- :rtype: str
464
- """
465
- kid = self._get_key_id(token)
466
- jwk = self._get_jwk(kid, tenant_id, provider)
467
- return self._rsa_pem_from_jwk(jwk)
468
-
469
401
 
470
402
  class BIAuth(Auth):
471
403
  def __init__(self, user_model=UserModel):
@@ -474,34 +406,31 @@ class BIAuth(Auth):
474
406
  @staticmethod
475
407
  def decode_token(token: str = None) -> dict:
476
408
  """
477
- Decodes a given JSON Web token and extracts the sub from it to give it back.
409
+ Decodes a given JSON Web token and extracts the username from the sub claim.
478
410
 
479
411
  :param str token: the given JSON Web Token
480
- :return: the sub field of the token as the user_id
412
+ :return: dictionary containing the username from the token's sub claim
481
413
  :rtype: dict
482
414
  """
483
- if token is None:
484
- err = "The provided token is not valid."
485
- raise InvalidUsage(
486
- err, log_txt="Error while trying to decode token. " + err
487
- )
488
415
  try:
489
- payload = jwt.decode(
416
+ return jwt.decode(
490
417
  token, current_app.config["SECRET_BI_KEY"], algorithms="HS256"
491
418
  )
492
- return {"user_id": payload["sub"]}
419
+
493
420
  except jwt.InvalidTokenError:
494
421
  raise InvalidCredentials(
495
422
  "Invalid token, please try again with a new token",
496
423
  log_txt="Error while trying to decode token. The token is invalid.",
424
+ status_code=400,
497
425
  )
498
426
 
499
427
  @staticmethod
500
428
  def generate_token(user_id: int = None) -> str:
501
429
  """
502
- Generates a token given a user_id with a duration of one day
430
+ Generates a token given a user_id. The token will contain the username in the sub claim.
431
+ BI tokens do not include expiration time.
503
432
 
504
- :param int user_id: user code to be encoded in the token to identify the user afterward.
433
+ :param int user_id: user id to generate the token for
505
434
  :return: the generated token
506
435
  :rtype: str
507
436
  """
@@ -511,9 +440,17 @@ class BIAuth(Auth):
511
440
  err, log_txt="Error while trying to generate token. " + err
512
441
  )
513
442
 
443
+ user = UserModel.get_one_user(user_id)
444
+ if user is None:
445
+ err = "User does not exist"
446
+ raise InvalidUsage(
447
+ err, log_txt="Error while trying to generate token. " + err
448
+ )
449
+
514
450
  payload = {
515
- "iat": datetime.utcnow(),
516
- "sub": user_id,
451
+ "iat": datetime.now(timezone.utc),
452
+ "sub": user.username,
453
+ "iss": INTERNAL_TOKEN_ISSUER,
517
454
  }
518
455
 
519
456
  return jwt.encode(