cornflow 2.0.0a11__py3-none-any.whl → 2.0.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow_config/airflow_local_settings.py +1 -1
- cornflow/app.py +8 -3
- cornflow/cli/migrations.py +23 -3
- cornflow/cli/service.py +18 -18
- cornflow/cli/utils.py +16 -1
- cornflow/commands/dag.py +1 -1
- cornflow/config.py +13 -8
- cornflow/endpoints/__init__.py +8 -2
- cornflow/endpoints/alarms.py +66 -2
- cornflow/endpoints/data_check.py +53 -26
- cornflow/endpoints/execution.py +387 -132
- cornflow/endpoints/login.py +81 -63
- cornflow/endpoints/meta_resource.py +11 -3
- cornflow/migrations/versions/999b98e24225.py +34 -0
- cornflow/models/base_data_model.py +4 -32
- cornflow/models/execution.py +2 -3
- cornflow/models/meta_models.py +28 -22
- cornflow/models/user.py +7 -10
- cornflow/schemas/alarms.py +8 -0
- cornflow/schemas/execution.py +2 -1
- cornflow/schemas/query.py +2 -1
- cornflow/schemas/user.py +5 -20
- cornflow/shared/authentication/auth.py +201 -264
- cornflow/shared/const.py +3 -14
- cornflow/shared/databricks.py +5 -1
- cornflow/tests/const.py +2 -0
- cornflow/tests/custom_test_case.py +77 -26
- cornflow/tests/unit/test_actions.py +2 -2
- cornflow/tests/unit/test_alarms.py +55 -1
- cornflow/tests/unit/test_apiview.py +108 -3
- cornflow/tests/unit/test_cases.py +20 -29
- cornflow/tests/unit/test_cli.py +6 -5
- cornflow/tests/unit/test_commands.py +3 -3
- cornflow/tests/unit/test_dags.py +5 -6
- cornflow/tests/unit/test_executions.py +491 -118
- cornflow/tests/unit/test_instances.py +14 -2
- cornflow/tests/unit/test_instances_file.py +1 -1
- cornflow/tests/unit/test_licenses.py +1 -1
- cornflow/tests/unit/test_log_in.py +230 -207
- cornflow/tests/unit/test_permissions.py +8 -8
- cornflow/tests/unit/test_roles.py +48 -10
- cornflow/tests/unit/test_schemas.py +1 -1
- cornflow/tests/unit/test_tables.py +7 -7
- cornflow/tests/unit/test_token.py +19 -5
- cornflow/tests/unit/test_users.py +22 -6
- cornflow/tests/unit/tools.py +75 -10
- {cornflow-2.0.0a11.dist-info → cornflow-2.0.0a13.dist-info}/METADATA +16 -15
- {cornflow-2.0.0a11.dist-info → cornflow-2.0.0a13.dist-info}/RECORD +51 -51
- {cornflow-2.0.0a11.dist-info → cornflow-2.0.0a13.dist-info}/WHEEL +1 -1
- cornflow/endpoints/execution_databricks.py +0 -808
- {cornflow-2.0.0a11.dist-info → cornflow-2.0.0a13.dist-info}/entry_points.txt +0 -0
- {cornflow-2.0.0a11.dist-info → cornflow-2.0.0a13.dist-info}/top_level.txt +0 -0
cornflow/schemas/user.py
CHANGED
@@ -27,7 +27,7 @@ class UserEndpointResponse(Schema):
|
|
27
27
|
last_name = fields.Str()
|
28
28
|
email = fields.Str()
|
29
29
|
created_at = fields.Str()
|
30
|
-
pwd_last_change = fields.
|
30
|
+
pwd_last_change = fields.DateTime()
|
31
31
|
|
32
32
|
|
33
33
|
class UserDetailsEndpointResponse(Schema):
|
@@ -36,7 +36,7 @@ class UserDetailsEndpointResponse(Schema):
|
|
36
36
|
last_name = fields.Str()
|
37
37
|
username = fields.Str()
|
38
38
|
email = fields.Str()
|
39
|
-
pwd_last_change = fields.
|
39
|
+
pwd_last_change = fields.DateTime()
|
40
40
|
|
41
41
|
|
42
42
|
class TokenEndpointResponse(Schema):
|
@@ -53,7 +53,6 @@ class UserEditRequest(Schema):
|
|
53
53
|
last_name = fields.Str(required=False)
|
54
54
|
email = fields.Str(required=False)
|
55
55
|
password = fields.Str(required=False)
|
56
|
-
pwd_last_change = fields.DateTime(required=False)
|
57
56
|
|
58
57
|
|
59
58
|
class LoginEndpointRequest(Schema):
|
@@ -67,24 +66,10 @@ class LoginEndpointRequest(Schema):
|
|
67
66
|
|
68
67
|
class LoginOpenAuthRequest(Schema):
|
69
68
|
"""
|
70
|
-
|
71
|
-
Validates that either a token is provided, or both username and password are present
|
69
|
+
Schema for the login request with OpenID authentication
|
72
70
|
"""
|
73
|
-
|
74
|
-
|
75
|
-
username = fields.Str(required=False)
|
76
|
-
password = fields.Str(required=False)
|
77
|
-
|
78
|
-
@validates_schema
|
79
|
-
def validate_fields(self, data, **kwargs):
|
80
|
-
if data.get("token") is None:
|
81
|
-
if not data.get("username") or not data.get("password"):
|
82
|
-
raise ValidationError(
|
83
|
-
"A token needs to be provided when using Open ID authentication"
|
84
|
-
)
|
85
|
-
else:
|
86
|
-
if data.get("username") or data.get("password"):
|
87
|
-
raise ValidationError("The login needs to be done with a token only")
|
71
|
+
username = fields.String(required=False)
|
72
|
+
password = fields.String(required=False)
|
88
73
|
|
89
74
|
|
90
75
|
class SignupRequest(Schema):
|
@@ -3,19 +3,14 @@ This file contains the auth class that can be used for authentication on the req
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
# Imports from external libraries
|
6
|
-
import base64
|
7
6
|
import jwt
|
8
7
|
import requests
|
9
|
-
|
10
|
-
from
|
11
|
-
from cryptography.hazmat.primitives import serialization
|
12
|
-
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
13
|
-
from datetime import datetime, timedelta
|
8
|
+
from jwt.algorithms import RSAAlgorithm
|
9
|
+
from datetime import datetime, timedelta, timezone
|
14
10
|
from flask import request, g, current_app, Request
|
15
11
|
from functools import wraps
|
16
|
-
from typing import
|
17
|
-
|
18
|
-
from jwt import DecodeError
|
12
|
+
from typing import Tuple
|
13
|
+
from cachetools import TTLCache
|
19
14
|
from werkzeug.datastructures import Headers
|
20
15
|
|
21
16
|
# Imports from internal modules
|
@@ -26,15 +21,12 @@ from cornflow.models import (
|
|
26
21
|
ViewModel,
|
27
22
|
)
|
28
23
|
from cornflow.shared.const import (
|
29
|
-
|
30
|
-
OID_AZURE_DISCOVERY_TENANT_URL,
|
31
|
-
OID_AZURE_DISCOVERY_COMMON_URL,
|
32
|
-
OID_GOOGLE,
|
24
|
+
AUTH_OID,
|
33
25
|
PERMISSION_METHOD_MAP,
|
26
|
+
INTERNAL_TOKEN_ISSUER,
|
34
27
|
)
|
35
28
|
from cornflow.shared.exceptions import (
|
36
29
|
CommunicationError,
|
37
|
-
EndpointNotImplemented,
|
38
30
|
InvalidCredentials,
|
39
31
|
InvalidData,
|
40
32
|
InvalidUsage,
|
@@ -42,6 +34,9 @@ from cornflow.shared.exceptions import (
|
|
42
34
|
ObjectDoesNotExist,
|
43
35
|
)
|
44
36
|
|
37
|
+
# Cache for storing public keys with 1 hour TTL
|
38
|
+
public_keys_cache = TTLCache(maxsize=10, ttl=3600)
|
39
|
+
|
45
40
|
|
46
41
|
class Auth:
|
47
42
|
def __init__(self, user_model=UserModel):
|
@@ -92,9 +87,9 @@ class Auth:
|
|
92
87
|
@staticmethod
|
93
88
|
def generate_token(user_id: int = None) -> str:
|
94
89
|
"""
|
95
|
-
Generates a token given a user_id
|
90
|
+
Generates a token given a user_id. The token will contain the username in the sub claim.
|
96
91
|
|
97
|
-
:param int user_id: user
|
92
|
+
:param int user_id: user id to generate the token for
|
98
93
|
:return: the generated token
|
99
94
|
:rtype: str
|
100
95
|
"""
|
@@ -104,11 +99,19 @@ class Auth:
|
|
104
99
|
err, log_txt="Error while trying to generate token. " + err
|
105
100
|
)
|
106
101
|
|
102
|
+
user = UserModel.get_one_user(user_id)
|
103
|
+
if user is None:
|
104
|
+
err = "User does not exist"
|
105
|
+
raise InvalidUsage(
|
106
|
+
err, log_txt="Error while trying to generate token. " + err
|
107
|
+
)
|
108
|
+
|
107
109
|
payload = {
|
108
|
-
"exp": datetime.
|
110
|
+
"exp": datetime.now(timezone.utc)
|
109
111
|
+ timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
|
110
|
-
"iat": datetime.
|
111
|
-
"sub":
|
112
|
+
"iat": datetime.now(timezone.utc),
|
113
|
+
"sub": user.username,
|
114
|
+
"iss": INTERNAL_TOKEN_ISSUER,
|
112
115
|
}
|
113
116
|
|
114
117
|
return jwt.encode(
|
@@ -118,73 +121,68 @@ class Auth:
|
|
118
121
|
@staticmethod
|
119
122
|
def decode_token(token: str = None) -> dict:
|
120
123
|
"""
|
121
|
-
Decodes a given JSON Web token and extracts the
|
124
|
+
Decodes a given JSON Web token and extracts the username from the sub claim.
|
125
|
+
Works with both internal tokens and OpenID tokens.
|
122
126
|
|
123
127
|
:param str token: the given JSON Web Token
|
124
|
-
:return: the
|
128
|
+
:return: dictionary containing the username from the token's sub claim
|
125
129
|
:rtype: dict
|
126
130
|
"""
|
131
|
+
|
127
132
|
if token is None:
|
128
|
-
|
129
|
-
raise InvalidUsage(
|
130
|
-
err, log_txt="Error while trying to decode token. " + err
|
131
|
-
)
|
132
|
-
try:
|
133
|
-
payload = jwt.decode(
|
134
|
-
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
135
|
-
)
|
136
|
-
return {"user_id": payload["sub"]}
|
137
|
-
except jwt.ExpiredSignatureError:
|
138
|
-
raise InvalidCredentials(
|
139
|
-
"The token has expired, please login again",
|
140
|
-
log_txt="Error while trying to decode token. The token has expired.",
|
141
|
-
)
|
142
|
-
except jwt.InvalidTokenError:
|
133
|
+
|
143
134
|
raise InvalidCredentials(
|
144
|
-
"
|
145
|
-
log_txt="Error while trying to decode token.
|
135
|
+
"Must provide a token in Authorization header",
|
136
|
+
log_txt="Error while trying to decode token. Token is missing.",
|
137
|
+
status_code=400,
|
146
138
|
)
|
147
139
|
|
148
|
-
def validate_oid_token(
|
149
|
-
self, token: str, client_id: str, tenant_id: str, issuer: str, provider: int
|
150
|
-
) -> dict:
|
151
|
-
"""
|
152
|
-
This method takes a token issued by an OID provider, the relevant information about the OID provider
|
153
|
-
and validates that the token was generated by such source, is valid and extracts the information
|
154
|
-
in the token for its use during the login process
|
155
|
-
|
156
|
-
:param str token: the received token
|
157
|
-
:param str client_id: the identifier from the client
|
158
|
-
:param str tenant_id: the identifier for the tenant
|
159
|
-
:param str issuer: the identifier for the issuer of the token
|
160
|
-
:param int provider: the identifier for the provider of the token
|
161
|
-
:return: the decoded token as a dictionary
|
162
|
-
:rtype: dict
|
163
|
-
"""
|
164
|
-
public_key = self._get_public_key(token, tenant_id, provider)
|
165
140
|
try:
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
141
|
+
# First try to decode header to validate basic token structure
|
142
|
+
|
143
|
+
unverified_payload = jwt.decode(token, options={"verify_signature": False})
|
144
|
+
issuer = unverified_payload.get("iss")
|
145
|
+
|
146
|
+
# For internal tokens
|
147
|
+
if issuer == INTERNAL_TOKEN_ISSUER:
|
148
|
+
|
149
|
+
return jwt.decode(
|
150
|
+
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
151
|
+
)
|
152
|
+
|
153
|
+
# For OpenID tokens
|
154
|
+
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
155
|
+
|
156
|
+
return Auth().verify_token(
|
157
|
+
token,
|
158
|
+
current_app.config["OID_PROVIDER"],
|
159
|
+
current_app.config["OID_EXPECTED_AUDIENCE"],
|
160
|
+
)
|
161
|
+
|
162
|
+
# If we get here, the issuer is not valid
|
163
|
+
|
164
|
+
raise InvalidCredentials(
|
165
|
+
"Invalid token issuer. Token must be issued by a valid provider",
|
166
|
+
log_txt="Error while trying to decode token. Invalid issuer.",
|
167
|
+
status_code=400,
|
173
168
|
)
|
174
|
-
|
169
|
+
|
175
170
|
except jwt.ExpiredSignatureError:
|
171
|
+
|
176
172
|
raise InvalidCredentials(
|
177
173
|
"The token has expired, please login again",
|
178
|
-
log_txt="Error while trying to
|
174
|
+
log_txt="Error while trying to decode token. The token has expired.",
|
175
|
+
status_code=400,
|
179
176
|
)
|
180
|
-
except jwt.InvalidTokenError:
|
177
|
+
except jwt.InvalidTokenError as e:
|
178
|
+
|
181
179
|
raise InvalidCredentials(
|
182
|
-
"Invalid token
|
183
|
-
log_txt="Error while trying to
|
180
|
+
"Invalid token format or signature",
|
181
|
+
log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
|
182
|
+
status_code=400,
|
184
183
|
)
|
185
184
|
|
186
|
-
|
187
|
-
def get_token_from_header(headers: Headers = None) -> str:
|
185
|
+
def get_token_from_header(self, headers: Headers = None) -> str:
|
188
186
|
"""
|
189
187
|
Extracts the token given on the request from the Authorization headers.
|
190
188
|
|
@@ -195,65 +193,157 @@ class Auth:
|
|
195
193
|
"""
|
196
194
|
if headers is None:
|
197
195
|
raise InvalidUsage(
|
198
|
-
|
196
|
+
"Request headers are missing",
|
197
|
+
log_txt="Error while trying to get a token from header. The header is invalid.",
|
198
|
+
status_code=400,
|
199
199
|
)
|
200
200
|
|
201
201
|
if "Authorization" not in headers:
|
202
202
|
raise InvalidCredentials(
|
203
|
-
"
|
203
|
+
"Authorization header is missing",
|
204
204
|
log_txt="Error while trying to get a token from header. The auth token is not available.",
|
205
|
+
status_code=400,
|
205
206
|
)
|
207
|
+
|
206
208
|
auth_header = headers.get("Authorization")
|
209
|
+
|
207
210
|
if not auth_header:
|
208
211
|
return ""
|
212
|
+
|
213
|
+
if not auth_header.startswith("Bearer "):
|
214
|
+
err = "Invalid Authorization header format. Must be 'Bearer <token>'"
|
215
|
+
raise InvalidCredentials(
|
216
|
+
err,
|
217
|
+
log_txt=f"Error while trying to get a token from header. " + err,
|
218
|
+
status_code=400,
|
219
|
+
)
|
220
|
+
|
209
221
|
try:
|
210
|
-
|
222
|
+
token = auth_header.split(" ")[1]
|
223
|
+
return token
|
211
224
|
except Exception as e:
|
212
|
-
err =
|
225
|
+
err = "Invalid Authorization header format. Must be 'Bearer <token>'"
|
213
226
|
raise InvalidCredentials(
|
214
|
-
err,
|
227
|
+
err,
|
228
|
+
log_txt=f"Error while trying to get a token from header. " + err,
|
229
|
+
status_code=400,
|
215
230
|
)
|
216
231
|
|
217
232
|
def get_user_from_header(self, headers: Headers = None) -> UserModel:
|
218
233
|
"""
|
219
|
-
|
234
|
+
Extracts the user from the Authorization headers.
|
220
235
|
|
221
236
|
:param headers: the request headers
|
222
237
|
:type headers: `Headers`
|
223
238
|
:return: the user object
|
224
|
-
:rtype: `
|
239
|
+
:rtype: :class:`UserModel`
|
225
240
|
"""
|
226
241
|
if headers is None:
|
227
|
-
err = "
|
242
|
+
err = "Request headers are missing"
|
228
243
|
raise InvalidUsage(
|
229
|
-
err,
|
244
|
+
err,
|
245
|
+
log_txt="Error while trying to get user from header. " + err,
|
246
|
+
status_code=400,
|
230
247
|
)
|
231
248
|
token = self.get_token_from_header(headers)
|
232
249
|
data = self.decode_token(token)
|
233
|
-
|
234
|
-
user = self.user_model.
|
250
|
+
|
251
|
+
user = self.user_model.get_one_object(username=data["sub"])
|
252
|
+
|
235
253
|
if user is None:
|
236
|
-
err = "User
|
237
|
-
raise
|
238
|
-
err,
|
254
|
+
err = "User not found. Please ensure you are using valid credentials"
|
255
|
+
raise InvalidCredentials(
|
256
|
+
err,
|
257
|
+
log_txt="Error while trying to get user from header. User does not exist.",
|
258
|
+
status_code=400,
|
239
259
|
)
|
240
260
|
return user
|
241
261
|
|
242
262
|
@staticmethod
|
243
|
-
def
|
263
|
+
def get_public_keys(provider_url: str) -> dict:
|
264
|
+
"""
|
265
|
+
Gets the public keys from the OIDC provider and caches them
|
266
|
+
|
267
|
+
:param str provider_url: The base URL of the OIDC provider
|
268
|
+
:return: Dictionary of kid to public key mappings
|
269
|
+
:rtype: dict
|
244
270
|
"""
|
245
|
-
|
271
|
+
# Fetch keys from provider
|
272
|
+
jwks_url = f"{provider_url.rstrip('/')}/.well-known/jwks.json"
|
273
|
+
try:
|
274
|
+
response = requests.get(jwks_url)
|
275
|
+
response.raise_for_status()
|
276
|
+
|
277
|
+
# Convert JWK to RSA public keys using PyJWT's built-in method
|
278
|
+
public_keys = {
|
279
|
+
key["kid"]: RSAAlgorithm.from_jwk(key)
|
280
|
+
for key in response.json()["keys"]
|
281
|
+
}
|
282
|
+
|
283
|
+
# Store in cache
|
284
|
+
public_keys_cache[provider_url] = public_keys
|
285
|
+
return public_keys
|
246
286
|
|
247
|
-
|
248
|
-
|
249
|
-
|
287
|
+
except requests.exceptions.RequestException as e:
|
288
|
+
raise CommunicationError(
|
289
|
+
"Failed to fetch public keys from authentication provider",
|
290
|
+
log_txt=f"Error while fetching public keys from {jwks_url}: {str(e)}",
|
291
|
+
status_code=400,
|
292
|
+
)
|
293
|
+
|
294
|
+
def verify_token(
|
295
|
+
self, token: str, provider_url: str, expected_audience: str
|
296
|
+
) -> dict:
|
250
297
|
"""
|
251
|
-
|
252
|
-
|
298
|
+
Verifies an OpenID Connect token
|
299
|
+
|
300
|
+
:param str token: The token to verify
|
301
|
+
:param str provider_url: The base URL of the OIDC provider
|
302
|
+
:param str expected_audience: The expected audience claim
|
303
|
+
:return: The decoded token claims
|
304
|
+
:rtype: dict
|
305
|
+
"""
|
306
|
+
|
307
|
+
# Get unverified header - this will raise jwt.InvalidTokenError if token format is invalid
|
308
|
+
unverified_header = jwt.get_unverified_header(token)
|
309
|
+
|
310
|
+
# Check for kid in header
|
311
|
+
if "kid" not in unverified_header:
|
312
|
+
|
313
|
+
raise InvalidCredentials(
|
314
|
+
"Invalid token: Missing key identifier (kid) in token header",
|
315
|
+
log_txt="Error while verifying token. Token header is missing 'kid'.",
|
316
|
+
status_code=400,
|
317
|
+
)
|
253
318
|
|
254
|
-
|
255
|
-
|
256
|
-
|
319
|
+
kid = unverified_header["kid"]
|
320
|
+
|
321
|
+
# Check if we have the keys in cache and if the kid exists
|
322
|
+
public_key = None
|
323
|
+
if provider_url in public_keys_cache:
|
324
|
+
cached_keys = public_keys_cache[provider_url]
|
325
|
+
if kid in cached_keys:
|
326
|
+
public_key = cached_keys[kid]
|
327
|
+
|
328
|
+
# If kid not in cache, fetch fresh keys
|
329
|
+
if public_key is None:
|
330
|
+
public_keys = self.get_public_keys(provider_url)
|
331
|
+
if kid not in public_keys:
|
332
|
+
raise InvalidCredentials(
|
333
|
+
"Invalid token: Unknown key identifier (kid)",
|
334
|
+
log_txt="Error while verifying token. Key ID not found in public keys.",
|
335
|
+
status_code=400,
|
336
|
+
)
|
337
|
+
public_key = public_keys[kid]
|
338
|
+
|
339
|
+
# Verify token - this will raise appropriate jwt exceptions that will be caught in decode_token
|
340
|
+
return jwt.decode(
|
341
|
+
token,
|
342
|
+
public_key,
|
343
|
+
algorithms=["RS256"],
|
344
|
+
audience=[expected_audience],
|
345
|
+
issuer=provider_url,
|
346
|
+
)
|
257
347
|
|
258
348
|
@staticmethod
|
259
349
|
def _get_permission_for_request(req, user_id):
|
@@ -308,164 +398,6 @@ class Auth:
|
|
308
398
|
"""
|
309
399
|
return getattr(req, "environ")["REQUEST_METHOD"], getattr(req, "url_rule").rule
|
310
400
|
|
311
|
-
@staticmethod
|
312
|
-
def _get_key_id(token: str) -> str:
|
313
|
-
"""
|
314
|
-
Function to get the Key ID from the token
|
315
|
-
|
316
|
-
:param str token: the given token
|
317
|
-
:return: the key identifier
|
318
|
-
:rtype: str
|
319
|
-
"""
|
320
|
-
try:
|
321
|
-
headers = jwt.get_unverified_header(token)
|
322
|
-
except DecodeError as err:
|
323
|
-
raise InvalidCredentials("Token is not valid")
|
324
|
-
if not headers:
|
325
|
-
raise InvalidCredentials("Token is missing the headers")
|
326
|
-
try:
|
327
|
-
return headers["kid"]
|
328
|
-
except KeyError:
|
329
|
-
raise InvalidCredentials("Token is missing the key identifier")
|
330
|
-
|
331
|
-
@staticmethod
|
332
|
-
def _fetch_discovery_meta(tenant_id: str, provider: int) -> dict:
|
333
|
-
"""
|
334
|
-
Function to return a dictionary with the discovery URL of the provider
|
335
|
-
|
336
|
-
:param str tenant_id: the tenant id
|
337
|
-
:param int provider: the provider information
|
338
|
-
:return: the different urls to be discovered on the provider
|
339
|
-
:rtype: dict
|
340
|
-
"""
|
341
|
-
if provider == OID_AZURE:
|
342
|
-
oid_tenant_url = OID_AZURE_DISCOVERY_TENANT_URL
|
343
|
-
oid_common_url = OID_AZURE_DISCOVERY_COMMON_URL
|
344
|
-
elif provider == OID_GOOGLE:
|
345
|
-
raise EndpointNotImplemented("The OID provider configuration is not valid")
|
346
|
-
else:
|
347
|
-
raise EndpointNotImplemented("The OID provider configuration is not valid")
|
348
|
-
|
349
|
-
discovery_url = (
|
350
|
-
oid_tenant_url.format(tenant_id=tenant_id) if tenant_id else oid_common_url
|
351
|
-
)
|
352
|
-
try:
|
353
|
-
response = requests.get(discovery_url)
|
354
|
-
response.raise_for_status()
|
355
|
-
except requests.exceptions.HTTPError:
|
356
|
-
raise CommunicationError(
|
357
|
-
f"Error getting issuer discovery meta from {discovery_url}"
|
358
|
-
)
|
359
|
-
return response.json()
|
360
|
-
|
361
|
-
def _get_json_web_keys_uri(self, tenant_id: str, provider: int) -> str:
|
362
|
-
"""
|
363
|
-
Returns the JSON Web Keys URI
|
364
|
-
|
365
|
-
:param str tenant_id: the tenant id
|
366
|
-
:param int provider: the provider information
|
367
|
-
:return: the URI from where to get the JSON Web Keys
|
368
|
-
:rtype: str
|
369
|
-
"""
|
370
|
-
meta = self._fetch_discovery_meta(tenant_id, provider)
|
371
|
-
if "jwks_uri" in meta:
|
372
|
-
return meta["jwks_uri"]
|
373
|
-
else:
|
374
|
-
raise CommunicationError("jwks_uri not found in the issuer meta")
|
375
|
-
|
376
|
-
def _get_json_web_keys(self, tenant_id: str, provider: int) -> dict:
|
377
|
-
"""
|
378
|
-
Function to get the json web keys from the tenant id and the provider
|
379
|
-
|
380
|
-
:param str tenant_id: the tenant id
|
381
|
-
:param int provider: the provider information
|
382
|
-
:return: the JSON Web Keys dict
|
383
|
-
:rtype: dict
|
384
|
-
"""
|
385
|
-
json_web_keys_uri = self._get_json_web_keys_uri(tenant_id, provider)
|
386
|
-
try:
|
387
|
-
response = requests.get(json_web_keys_uri)
|
388
|
-
response.raise_for_status()
|
389
|
-
except requests.exceptions.HTTPError as error:
|
390
|
-
raise CommunicationError(
|
391
|
-
f"Error getting issuer jwks from {json_web_keys_uri}", error
|
392
|
-
)
|
393
|
-
return response.json()
|
394
|
-
|
395
|
-
def _get_jwk(self, kid: str, tenant_id: str, provider: int) -> dict:
|
396
|
-
"""
|
397
|
-
Function to get the JSON Web Key from the key identifier, the tenant id and the provider information
|
398
|
-
|
399
|
-
:param str kid: the key identifier
|
400
|
-
:param str tenant_id: the tenant information
|
401
|
-
:param int provider: the provider information
|
402
|
-
:return: the JSON Web Key
|
403
|
-
:rtype: dict
|
404
|
-
"""
|
405
|
-
for jwk in self._get_json_web_keys(tenant_id, provider).get("keys"):
|
406
|
-
if jwk.get("kid") == kid:
|
407
|
-
return jwk
|
408
|
-
raise InvalidCredentials("Token has an unknown key identifier")
|
409
|
-
|
410
|
-
@staticmethod
|
411
|
-
def _ensure_bytes(key: Union[str, bytes]) -> bytes:
|
412
|
-
"""
|
413
|
-
Function that ensures that the key is in bytes format
|
414
|
-
|
415
|
-
:param str | bytes key:
|
416
|
-
:return: the key on bytes format
|
417
|
-
:rtype: bytes
|
418
|
-
"""
|
419
|
-
if isinstance(key, str):
|
420
|
-
key = key.encode("utf-8")
|
421
|
-
return key
|
422
|
-
|
423
|
-
def _decode_value(self, val: Union[str, bytes]) -> int:
|
424
|
-
"""
|
425
|
-
Function that ensures that the value is decoded as a big int
|
426
|
-
|
427
|
-
:param str | bytes val: the value that has to be decoded
|
428
|
-
:return: the decoded value as a big int
|
429
|
-
:rtype: int
|
430
|
-
"""
|
431
|
-
decoded = base64.urlsafe_b64decode(self._ensure_bytes(val) + b"==")
|
432
|
-
return int.from_bytes(decoded, "big")
|
433
|
-
|
434
|
-
def _rsa_pem_from_jwk(self, jwk: dict) -> bytes:
|
435
|
-
"""
|
436
|
-
Returns the private key from the JSON Web Key encoded as PEM
|
437
|
-
|
438
|
-
:param dict jwk: the JSON Web Key
|
439
|
-
:return: the RSA PEM key serialized as bytes
|
440
|
-
:rtype: bytes
|
441
|
-
"""
|
442
|
-
return (
|
443
|
-
RSAPublicNumbers(
|
444
|
-
n=self._decode_value(jwk["n"]),
|
445
|
-
e=self._decode_value(jwk["e"]),
|
446
|
-
)
|
447
|
-
.public_key(default_backend())
|
448
|
-
.public_bytes(
|
449
|
-
encoding=serialization.Encoding.PEM,
|
450
|
-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
451
|
-
)
|
452
|
-
)
|
453
|
-
|
454
|
-
def _get_public_key(self, token: str, tenant_id: str, provider: int):
|
455
|
-
"""
|
456
|
-
This method returns the public key from the given token, ensuring that
|
457
|
-
the tenant information and provider are correct
|
458
|
-
|
459
|
-
:param str token: the given token
|
460
|
-
:param str tenant_id: the tenant information
|
461
|
-
:param int provider: the token provider information
|
462
|
-
:return: the public key in the token or it raises an error
|
463
|
-
:rtype: str
|
464
|
-
"""
|
465
|
-
kid = self._get_key_id(token)
|
466
|
-
jwk = self._get_jwk(kid, tenant_id, provider)
|
467
|
-
return self._rsa_pem_from_jwk(jwk)
|
468
|
-
|
469
401
|
|
470
402
|
class BIAuth(Auth):
|
471
403
|
def __init__(self, user_model=UserModel):
|
@@ -474,34 +406,31 @@ class BIAuth(Auth):
|
|
474
406
|
@staticmethod
|
475
407
|
def decode_token(token: str = None) -> dict:
|
476
408
|
"""
|
477
|
-
Decodes a given JSON Web token and extracts the
|
409
|
+
Decodes a given JSON Web token and extracts the username from the sub claim.
|
478
410
|
|
479
411
|
:param str token: the given JSON Web Token
|
480
|
-
:return: the
|
412
|
+
:return: dictionary containing the username from the token's sub claim
|
481
413
|
:rtype: dict
|
482
414
|
"""
|
483
|
-
if token is None:
|
484
|
-
err = "The provided token is not valid."
|
485
|
-
raise InvalidUsage(
|
486
|
-
err, log_txt="Error while trying to decode token. " + err
|
487
|
-
)
|
488
415
|
try:
|
489
|
-
|
416
|
+
return jwt.decode(
|
490
417
|
token, current_app.config["SECRET_BI_KEY"], algorithms="HS256"
|
491
418
|
)
|
492
|
-
|
419
|
+
|
493
420
|
except jwt.InvalidTokenError:
|
494
421
|
raise InvalidCredentials(
|
495
422
|
"Invalid token, please try again with a new token",
|
496
423
|
log_txt="Error while trying to decode token. The token is invalid.",
|
424
|
+
status_code=400,
|
497
425
|
)
|
498
426
|
|
499
427
|
@staticmethod
|
500
428
|
def generate_token(user_id: int = None) -> str:
|
501
429
|
"""
|
502
|
-
Generates a token given a user_id
|
430
|
+
Generates a token given a user_id. The token will contain the username in the sub claim.
|
431
|
+
BI tokens do not include expiration time.
|
503
432
|
|
504
|
-
:param int user_id: user
|
433
|
+
:param int user_id: user id to generate the token for
|
505
434
|
:return: the generated token
|
506
435
|
:rtype: str
|
507
436
|
"""
|
@@ -511,9 +440,17 @@ class BIAuth(Auth):
|
|
511
440
|
err, log_txt="Error while trying to generate token. " + err
|
512
441
|
)
|
513
442
|
|
443
|
+
user = UserModel.get_one_user(user_id)
|
444
|
+
if user is None:
|
445
|
+
err = "User does not exist"
|
446
|
+
raise InvalidUsage(
|
447
|
+
err, log_txt="Error while trying to generate token. " + err
|
448
|
+
)
|
449
|
+
|
514
450
|
payload = {
|
515
|
-
"iat": datetime.
|
516
|
-
"sub":
|
451
|
+
"iat": datetime.now(timezone.utc),
|
452
|
+
"sub": user.username,
|
453
|
+
"iss": INTERNAL_TOKEN_ISSUER,
|
517
454
|
}
|
518
455
|
|
519
456
|
return jwt.encode(
|