cornflow 1.1.5a1__py3-none-any.whl → 1.2.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/app.py +8 -3
- cornflow/cli/service.py +14 -6
- cornflow/config.py +9 -10
- cornflow/endpoints/login.py +54 -57
- cornflow/schemas/user.py +3 -17
- cornflow/shared/authentication/auth.py +211 -254
- cornflow/shared/const.py +3 -14
- cornflow/tests/custom_test_case.py +42 -21
- cornflow/tests/unit/test_actions.py +2 -2
- cornflow/tests/unit/test_apiview.py +2 -2
- cornflow/tests/unit/test_cases.py +20 -29
- cornflow/tests/unit/test_dags.py +5 -5
- cornflow/tests/unit/test_instances.py +2 -2
- cornflow/tests/unit/test_instances_file.py +1 -1
- cornflow/tests/unit/test_licenses.py +1 -1
- cornflow/tests/unit/test_log_in.py +227 -207
- cornflow/tests/unit/test_permissions.py +8 -8
- cornflow/tests/unit/test_roles.py +10 -10
- cornflow/tests/unit/test_tables.py +7 -7
- cornflow/tests/unit/test_token.py +12 -6
- {cornflow-1.1.5a1.dist-info → cornflow-1.2.0a1.dist-info}/METADATA +3 -2
- {cornflow-1.1.5a1.dist-info → cornflow-1.2.0a1.dist-info}/RECORD +25 -25
- {cornflow-1.1.5a1.dist-info → cornflow-1.2.0a1.dist-info}/WHEEL +1 -1
- {cornflow-1.1.5a1.dist-info → cornflow-1.2.0a1.dist-info}/entry_points.txt +0 -0
- {cornflow-1.1.5a1.dist-info → cornflow-1.2.0a1.dist-info}/top_level.txt +0 -0
@@ -3,19 +3,14 @@ This file contains the auth class that can be used for authentication on the req
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
# Imports from external libraries
|
6
|
-
import base64
|
7
6
|
import jwt
|
8
7
|
import requests
|
9
|
-
|
10
|
-
from cryptography.hazmat.backends import default_backend
|
11
|
-
from cryptography.hazmat.primitives import serialization
|
12
|
-
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
8
|
+
from jwt.algorithms import RSAAlgorithm
|
13
9
|
from datetime import datetime, timedelta
|
14
10
|
from flask import request, g, current_app, Request
|
15
11
|
from functools import wraps
|
16
|
-
from typing import
|
17
|
-
|
18
|
-
from jwt import DecodeError
|
12
|
+
from typing import Tuple
|
13
|
+
from cachetools import TTLCache
|
19
14
|
from werkzeug.datastructures import Headers
|
20
15
|
|
21
16
|
# Imports from internal modules
|
@@ -26,15 +21,12 @@ from cornflow.models import (
|
|
26
21
|
ViewModel,
|
27
22
|
)
|
28
23
|
from cornflow.shared.const import (
|
29
|
-
|
30
|
-
OID_AZURE_DISCOVERY_TENANT_URL,
|
31
|
-
OID_AZURE_DISCOVERY_COMMON_URL,
|
32
|
-
OID_GOOGLE,
|
24
|
+
AUTH_OID,
|
33
25
|
PERMISSION_METHOD_MAP,
|
26
|
+
INTERNAL_TOKEN_ISSUER,
|
34
27
|
)
|
35
28
|
from cornflow.shared.exceptions import (
|
36
29
|
CommunicationError,
|
37
|
-
EndpointNotImplemented,
|
38
30
|
InvalidCredentials,
|
39
31
|
InvalidData,
|
40
32
|
InvalidUsage,
|
@@ -42,6 +34,9 @@ from cornflow.shared.exceptions import (
|
|
42
34
|
ObjectDoesNotExist,
|
43
35
|
)
|
44
36
|
|
37
|
+
# Cache for storing public keys with 1 hour TTL
|
38
|
+
public_keys_cache = TTLCache(maxsize=10, ttl=3600)
|
39
|
+
|
45
40
|
|
46
41
|
class Auth:
|
47
42
|
def __init__(self, user_model=UserModel):
|
@@ -92,9 +87,9 @@ class Auth:
|
|
92
87
|
@staticmethod
|
93
88
|
def generate_token(user_id: int = None) -> str:
|
94
89
|
"""
|
95
|
-
Generates a token given a user_id
|
90
|
+
Generates a token given a user_id. The token will contain the username in the sub claim.
|
96
91
|
|
97
|
-
:param int user_id: user
|
92
|
+
:param int user_id: user id to generate the token for
|
98
93
|
:return: the generated token
|
99
94
|
:rtype: str
|
100
95
|
"""
|
@@ -104,11 +99,19 @@ class Auth:
|
|
104
99
|
err, log_txt="Error while trying to generate token. " + err
|
105
100
|
)
|
106
101
|
|
102
|
+
user = UserModel.get_one_user(user_id)
|
103
|
+
if user is None:
|
104
|
+
err = "User does not exist"
|
105
|
+
raise InvalidUsage(
|
106
|
+
err, log_txt="Error while trying to generate token. " + err
|
107
|
+
)
|
108
|
+
|
107
109
|
payload = {
|
108
110
|
"exp": datetime.utcnow()
|
109
111
|
+ timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
|
110
112
|
"iat": datetime.utcnow(),
|
111
|
-
"sub":
|
113
|
+
"sub": user.username,
|
114
|
+
"iss": INTERNAL_TOKEN_ISSUER,
|
112
115
|
}
|
113
116
|
|
114
117
|
return jwt.encode(
|
@@ -118,73 +121,71 @@ class Auth:
|
|
118
121
|
@staticmethod
|
119
122
|
def decode_token(token: str = None) -> dict:
|
120
123
|
"""
|
121
|
-
Decodes a given JSON Web token and extracts the
|
124
|
+
Decodes a given JSON Web token and extracts the username from the sub claim.
|
125
|
+
Works with both internal tokens and OpenID tokens.
|
122
126
|
|
123
127
|
:param str token: the given JSON Web Token
|
124
|
-
:return: the
|
128
|
+
:return: dictionary containing the username from the token's sub claim
|
125
129
|
:rtype: dict
|
126
130
|
"""
|
131
|
+
print("[decode_token] Starting token decode")
|
127
132
|
if token is None:
|
128
|
-
|
129
|
-
raise InvalidUsage(
|
130
|
-
err, log_txt="Error while trying to decode token. " + err
|
131
|
-
)
|
132
|
-
try:
|
133
|
-
payload = jwt.decode(
|
134
|
-
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
135
|
-
)
|
136
|
-
return {"user_id": payload["sub"]}
|
137
|
-
except jwt.ExpiredSignatureError:
|
133
|
+
print("[decode_token] Token is None")
|
138
134
|
raise InvalidCredentials(
|
139
|
-
"
|
140
|
-
log_txt="Error while trying to decode token.
|
141
|
-
|
142
|
-
except jwt.InvalidTokenError:
|
143
|
-
raise InvalidCredentials(
|
144
|
-
"Invalid token, please try again with a new token",
|
145
|
-
log_txt="Error while trying to decode token. The token is invalid.",
|
135
|
+
"Must provide a token in Authorization header",
|
136
|
+
log_txt="Error while trying to decode token. Token is missing.",
|
137
|
+
status_code=400,
|
146
138
|
)
|
147
139
|
|
148
|
-
def validate_oid_token(
|
149
|
-
self, token: str, client_id: str, tenant_id: str, issuer: str, provider: int
|
150
|
-
) -> dict:
|
151
|
-
"""
|
152
|
-
This method takes a token issued by an OID provider, the relevant information about the OID provider
|
153
|
-
and validates that the token was generated by such source, is valid and extracts the information
|
154
|
-
in the token for its use during the login process
|
155
|
-
|
156
|
-
:param str token: the received token
|
157
|
-
:param str client_id: the identifier from the client
|
158
|
-
:param str tenant_id: the identifier for the tenant
|
159
|
-
:param str issuer: the identifier for the issuer of the token
|
160
|
-
:param int provider: the identifier for the provider of the token
|
161
|
-
:return: the decoded token as a dictionary
|
162
|
-
:rtype: dict
|
163
|
-
"""
|
164
|
-
public_key = self._get_public_key(token, tenant_id, provider)
|
165
140
|
try:
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
141
|
+
# First try to decode header to validate basic token structure
|
142
|
+
print("[decode_token] Decoding unverified header")
|
143
|
+
unverified_payload = jwt.decode(token, options={"verify_signature": False})
|
144
|
+
issuer = unverified_payload.get("iss")
|
145
|
+
print(f"[decode_token] Token issuer: {issuer}")
|
146
|
+
|
147
|
+
# For internal tokens
|
148
|
+
if issuer == INTERNAL_TOKEN_ISSUER:
|
149
|
+
print("[decode_token] Processing internal token")
|
150
|
+
payload = jwt.decode(
|
151
|
+
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
152
|
+
)
|
153
|
+
return {"username": payload["sub"]}
|
154
|
+
|
155
|
+
# For OpenID tokens
|
156
|
+
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
157
|
+
print("[decode_token] Processing OpenID token")
|
158
|
+
decoded = Auth().verify_token(
|
159
|
+
token,
|
160
|
+
current_app.config["OID_PROVIDER"],
|
161
|
+
current_app.config["OID_EXPECTED_AUDIENCE"],
|
162
|
+
)
|
163
|
+
return {"username": decoded["sub"]}
|
164
|
+
|
165
|
+
# If we get here, the issuer is not valid
|
166
|
+
print(f"[decode_token] Invalid issuer: {issuer}")
|
167
|
+
raise InvalidCredentials(
|
168
|
+
"Invalid token issuer. Token must be issued by a valid provider",
|
169
|
+
log_txt="Error while trying to decode token. Invalid issuer.",
|
170
|
+
status_code=400,
|
173
171
|
)
|
174
|
-
|
172
|
+
|
175
173
|
except jwt.ExpiredSignatureError:
|
174
|
+
print("[decode_token] Caught ExpiredSignatureError")
|
176
175
|
raise InvalidCredentials(
|
177
176
|
"The token has expired, please login again",
|
178
|
-
log_txt="Error while trying to
|
177
|
+
log_txt="Error while trying to decode token. The token has expired.",
|
178
|
+
status_code=400,
|
179
179
|
)
|
180
|
-
except jwt.InvalidTokenError:
|
180
|
+
except jwt.InvalidTokenError as e:
|
181
|
+
print("[decode_token] Caught InvalidTokenError")
|
181
182
|
raise InvalidCredentials(
|
182
|
-
"Invalid token
|
183
|
-
log_txt="Error while trying to
|
183
|
+
"Invalid token format or signature",
|
184
|
+
log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
|
185
|
+
status_code=400,
|
184
186
|
)
|
185
187
|
|
186
|
-
|
187
|
-
def get_token_from_header(headers: Headers = None) -> str:
|
188
|
+
def get_token_from_header(self, headers: Headers = None) -> str:
|
188
189
|
"""
|
189
190
|
Extracts the token given on the request from the Authorization headers.
|
190
191
|
|
@@ -195,65 +196,167 @@ class Auth:
|
|
195
196
|
"""
|
196
197
|
if headers is None:
|
197
198
|
raise InvalidUsage(
|
198
|
-
|
199
|
+
"Request headers are missing",
|
200
|
+
log_txt="Error while trying to get a token from header. The header is invalid.",
|
201
|
+
status_code=400,
|
199
202
|
)
|
200
203
|
|
201
204
|
if "Authorization" not in headers:
|
202
205
|
raise InvalidCredentials(
|
203
|
-
"
|
206
|
+
"Authorization header is missing",
|
204
207
|
log_txt="Error while trying to get a token from header. The auth token is not available.",
|
208
|
+
status_code=400,
|
205
209
|
)
|
210
|
+
|
206
211
|
auth_header = headers.get("Authorization")
|
212
|
+
|
207
213
|
if not auth_header:
|
208
214
|
return ""
|
215
|
+
|
216
|
+
if not auth_header.startswith("Bearer "):
|
217
|
+
err = "Invalid Authorization header format. Must be 'Bearer <token>'"
|
218
|
+
raise InvalidCredentials(
|
219
|
+
err,
|
220
|
+
log_txt=f"Error while trying to get a token from header. " + err,
|
221
|
+
status_code=400,
|
222
|
+
)
|
223
|
+
|
209
224
|
try:
|
210
|
-
|
225
|
+
token = auth_header.split(" ")[1]
|
226
|
+
return token
|
211
227
|
except Exception as e:
|
212
|
-
err =
|
228
|
+
err = "Invalid Authorization header format. Must be 'Bearer <token>'"
|
213
229
|
raise InvalidCredentials(
|
214
|
-
err,
|
230
|
+
err,
|
231
|
+
log_txt=f"Error while trying to get a token from header. " + err,
|
232
|
+
status_code=400,
|
215
233
|
)
|
216
234
|
|
217
235
|
def get_user_from_header(self, headers: Headers = None) -> UserModel:
|
218
236
|
"""
|
219
|
-
|
237
|
+
Extracts the user from the Authorization headers.
|
220
238
|
|
221
239
|
:param headers: the request headers
|
222
240
|
:type headers: `Headers`
|
223
241
|
:return: the user object
|
224
|
-
:rtype: `
|
242
|
+
:rtype: :class:`UserModel`
|
225
243
|
"""
|
226
244
|
if headers is None:
|
227
|
-
err = "
|
245
|
+
err = "Request headers are missing"
|
228
246
|
raise InvalidUsage(
|
229
|
-
err,
|
247
|
+
err,
|
248
|
+
log_txt="Error while trying to get user from header. " + err,
|
249
|
+
status_code=400,
|
230
250
|
)
|
231
251
|
token = self.get_token_from_header(headers)
|
232
252
|
data = self.decode_token(token)
|
233
|
-
|
234
|
-
user = self.user_model.
|
253
|
+
|
254
|
+
user = self.user_model.get_one_object(username=data["username"])
|
255
|
+
|
235
256
|
if user is None:
|
236
|
-
err = "User
|
237
|
-
raise
|
238
|
-
err,
|
257
|
+
err = "User not found. Please ensure you are using valid credentials"
|
258
|
+
raise InvalidCredentials(
|
259
|
+
err,
|
260
|
+
log_txt="Error while trying to get user from header. User does not exist.",
|
261
|
+
status_code=400,
|
239
262
|
)
|
240
263
|
return user
|
241
264
|
|
242
265
|
@staticmethod
|
243
|
-
def
|
266
|
+
def get_public_keys(provider_url: str) -> dict:
|
244
267
|
"""
|
245
|
-
|
268
|
+
Gets the public keys from the OIDC provider and caches them
|
246
269
|
|
247
|
-
:param str
|
248
|
-
:return:
|
249
|
-
:rtype:
|
270
|
+
:param str provider_url: The base URL of the OIDC provider
|
271
|
+
:return: Dictionary of kid to public key mappings
|
272
|
+
:rtype: dict
|
250
273
|
"""
|
251
|
-
|
252
|
-
|
274
|
+
# Fetch keys from provider
|
275
|
+
jwks_url = f"{provider_url.rstrip('/')}/.well-known/jwks.json"
|
276
|
+
try:
|
277
|
+
response = requests.get(jwks_url)
|
278
|
+
response.raise_for_status()
|
279
|
+
|
280
|
+
# Convert JWK to RSA public keys using PyJWT's built-in method
|
281
|
+
public_keys = {
|
282
|
+
key["kid"]: RSAAlgorithm.from_jwk(key)
|
283
|
+
for key in response.json()["keys"]
|
284
|
+
}
|
253
285
|
|
254
|
-
|
255
|
-
|
256
|
-
|
286
|
+
# Store in cache
|
287
|
+
public_keys_cache[provider_url] = public_keys
|
288
|
+
return public_keys
|
289
|
+
|
290
|
+
except requests.exceptions.RequestException as e:
|
291
|
+
raise CommunicationError(
|
292
|
+
"Failed to fetch public keys from authentication provider",
|
293
|
+
log_txt=f"Error while fetching public keys from {jwks_url}: {str(e)}",
|
294
|
+
status_code=400,
|
295
|
+
)
|
296
|
+
|
297
|
+
def verify_token(
|
298
|
+
self, token: str, provider_url: str, expected_audience: str
|
299
|
+
) -> dict:
|
300
|
+
"""
|
301
|
+
Verifies an OpenID Connect token
|
302
|
+
|
303
|
+
:param str token: The token to verify
|
304
|
+
:param str provider_url: The base URL of the OIDC provider
|
305
|
+
:param str expected_audience: The expected audience claim
|
306
|
+
:return: The decoded token claims
|
307
|
+
:rtype: dict
|
308
|
+
"""
|
309
|
+
print("[verify_token] Starting token verification")
|
310
|
+
# Get unverified header - this will raise jwt.InvalidTokenError if token format is invalid
|
311
|
+
print("[verify_token] Getting unverified header")
|
312
|
+
unverified_header = jwt.get_unverified_header(token)
|
313
|
+
print(f"[verify_token] Unverified header: {unverified_header}")
|
314
|
+
|
315
|
+
# Check for kid in header
|
316
|
+
if "kid" not in unverified_header:
|
317
|
+
print("[verify_token] Missing kid in header")
|
318
|
+
raise InvalidCredentials(
|
319
|
+
"Invalid token: Missing key identifier (kid) in token header",
|
320
|
+
log_txt="Error while verifying token. Token header is missing 'kid'.",
|
321
|
+
status_code=400,
|
322
|
+
)
|
323
|
+
|
324
|
+
kid = unverified_header["kid"]
|
325
|
+
print(f"[verify_token] Found kid: {kid}")
|
326
|
+
|
327
|
+
# Check if we have the keys in cache and if the kid exists
|
328
|
+
print("[verify_token] Checking cache for public keys")
|
329
|
+
public_key = None
|
330
|
+
if provider_url in public_keys_cache:
|
331
|
+
print(f"[verify_token] Found keys in cache for {provider_url}")
|
332
|
+
cached_keys = public_keys_cache[provider_url]
|
333
|
+
if kid in cached_keys:
|
334
|
+
print("[verify_token] Found matching kid in cached keys")
|
335
|
+
public_key = cached_keys[kid]
|
336
|
+
|
337
|
+
# If kid not in cache, fetch fresh keys
|
338
|
+
if public_key is None:
|
339
|
+
print("[verify_token] Public key not found in cache, fetching fresh keys")
|
340
|
+
public_keys = self.get_public_keys(provider_url)
|
341
|
+
if kid not in public_keys:
|
342
|
+
print(f"[verify_token] Kid {kid} not found in fresh public keys")
|
343
|
+
print("[verify_token] About to raise InvalidCredentials")
|
344
|
+
raise InvalidCredentials(
|
345
|
+
"Invalid token: Unknown key identifier (kid)",
|
346
|
+
log_txt="Error while verifying token. Key ID not found in public keys.",
|
347
|
+
status_code=400,
|
348
|
+
)
|
349
|
+
public_key = public_keys[kid]
|
350
|
+
print("[verify_token] Found public key in fresh keys")
|
351
|
+
|
352
|
+
# Verify token - this will raise appropriate jwt exceptions that will be caught in decode_token
|
353
|
+
return jwt.decode(
|
354
|
+
token,
|
355
|
+
public_key,
|
356
|
+
algorithms=["RS256"],
|
357
|
+
audience=[expected_audience],
|
358
|
+
issuer=provider_url,
|
359
|
+
)
|
257
360
|
|
258
361
|
@staticmethod
|
259
362
|
def _get_permission_for_request(req, user_id):
|
@@ -308,164 +411,6 @@ class Auth:
|
|
308
411
|
"""
|
309
412
|
return getattr(req, "environ")["REQUEST_METHOD"], getattr(req, "url_rule").rule
|
310
413
|
|
311
|
-
@staticmethod
|
312
|
-
def _get_key_id(token: str) -> str:
|
313
|
-
"""
|
314
|
-
Function to get the Key ID from the token
|
315
|
-
|
316
|
-
:param str token: the given token
|
317
|
-
:return: the key identifier
|
318
|
-
:rtype: str
|
319
|
-
"""
|
320
|
-
try:
|
321
|
-
headers = jwt.get_unverified_header(token)
|
322
|
-
except DecodeError as err:
|
323
|
-
raise InvalidCredentials("Token is not valid")
|
324
|
-
if not headers:
|
325
|
-
raise InvalidCredentials("Token is missing the headers")
|
326
|
-
try:
|
327
|
-
return headers["kid"]
|
328
|
-
except KeyError:
|
329
|
-
raise InvalidCredentials("Token is missing the key identifier")
|
330
|
-
|
331
|
-
@staticmethod
|
332
|
-
def _fetch_discovery_meta(tenant_id: str, provider: int) -> dict:
|
333
|
-
"""
|
334
|
-
Function to return a dictionary with the discovery URL of the provider
|
335
|
-
|
336
|
-
:param str tenant_id: the tenant id
|
337
|
-
:param int provider: the provider information
|
338
|
-
:return: the different urls to be discovered on the provider
|
339
|
-
:rtype: dict
|
340
|
-
"""
|
341
|
-
if provider == OID_AZURE:
|
342
|
-
oid_tenant_url = OID_AZURE_DISCOVERY_TENANT_URL
|
343
|
-
oid_common_url = OID_AZURE_DISCOVERY_COMMON_URL
|
344
|
-
elif provider == OID_GOOGLE:
|
345
|
-
raise EndpointNotImplemented("The OID provider configuration is not valid")
|
346
|
-
else:
|
347
|
-
raise EndpointNotImplemented("The OID provider configuration is not valid")
|
348
|
-
|
349
|
-
discovery_url = (
|
350
|
-
oid_tenant_url.format(tenant_id=tenant_id) if tenant_id else oid_common_url
|
351
|
-
)
|
352
|
-
try:
|
353
|
-
response = requests.get(discovery_url)
|
354
|
-
response.raise_for_status()
|
355
|
-
except requests.exceptions.HTTPError:
|
356
|
-
raise CommunicationError(
|
357
|
-
f"Error getting issuer discovery meta from {discovery_url}"
|
358
|
-
)
|
359
|
-
return response.json()
|
360
|
-
|
361
|
-
def _get_json_web_keys_uri(self, tenant_id: str, provider: int) -> str:
|
362
|
-
"""
|
363
|
-
Returns the JSON Web Keys URI
|
364
|
-
|
365
|
-
:param str tenant_id: the tenant id
|
366
|
-
:param int provider: the provider information
|
367
|
-
:return: the URI from where to get the JSON Web Keys
|
368
|
-
:rtype: str
|
369
|
-
"""
|
370
|
-
meta = self._fetch_discovery_meta(tenant_id, provider)
|
371
|
-
if "jwks_uri" in meta:
|
372
|
-
return meta["jwks_uri"]
|
373
|
-
else:
|
374
|
-
raise CommunicationError("jwks_uri not found in the issuer meta")
|
375
|
-
|
376
|
-
def _get_json_web_keys(self, tenant_id: str, provider: int) -> dict:
|
377
|
-
"""
|
378
|
-
Function to get the json web keys from the tenant id and the provider
|
379
|
-
|
380
|
-
:param str tenant_id: the tenant id
|
381
|
-
:param int provider: the provider information
|
382
|
-
:return: the JSON Web Keys dict
|
383
|
-
:rtype: dict
|
384
|
-
"""
|
385
|
-
json_web_keys_uri = self._get_json_web_keys_uri(tenant_id, provider)
|
386
|
-
try:
|
387
|
-
response = requests.get(json_web_keys_uri)
|
388
|
-
response.raise_for_status()
|
389
|
-
except requests.exceptions.HTTPError as error:
|
390
|
-
raise CommunicationError(
|
391
|
-
f"Error getting issuer jwks from {json_web_keys_uri}", error
|
392
|
-
)
|
393
|
-
return response.json()
|
394
|
-
|
395
|
-
def _get_jwk(self, kid: str, tenant_id: str, provider: int) -> dict:
|
396
|
-
"""
|
397
|
-
Function to get the JSON Web Key from the key identifier, the tenant id and the provider information
|
398
|
-
|
399
|
-
:param str kid: the key identifier
|
400
|
-
:param str tenant_id: the tenant information
|
401
|
-
:param int provider: the provider information
|
402
|
-
:return: the JSON Web Key
|
403
|
-
:rtype: dict
|
404
|
-
"""
|
405
|
-
for jwk in self._get_json_web_keys(tenant_id, provider).get("keys"):
|
406
|
-
if jwk.get("kid") == kid:
|
407
|
-
return jwk
|
408
|
-
raise InvalidCredentials("Token has an unknown key identifier")
|
409
|
-
|
410
|
-
@staticmethod
|
411
|
-
def _ensure_bytes(key: Union[str, bytes]) -> bytes:
|
412
|
-
"""
|
413
|
-
Function that ensures that the key is in bytes format
|
414
|
-
|
415
|
-
:param str | bytes key:
|
416
|
-
:return: the key on bytes format
|
417
|
-
:rtype: bytes
|
418
|
-
"""
|
419
|
-
if isinstance(key, str):
|
420
|
-
key = key.encode("utf-8")
|
421
|
-
return key
|
422
|
-
|
423
|
-
def _decode_value(self, val: Union[str, bytes]) -> int:
|
424
|
-
"""
|
425
|
-
Function that ensures that the value is decoded as a big int
|
426
|
-
|
427
|
-
:param str | bytes val: the value that has to be decoded
|
428
|
-
:return: the decoded value as a big int
|
429
|
-
:rtype: int
|
430
|
-
"""
|
431
|
-
decoded = base64.urlsafe_b64decode(self._ensure_bytes(val) + b"==")
|
432
|
-
return int.from_bytes(decoded, "big")
|
433
|
-
|
434
|
-
def _rsa_pem_from_jwk(self, jwk: dict) -> bytes:
|
435
|
-
"""
|
436
|
-
Returns the private key from the JSON Web Key encoded as PEM
|
437
|
-
|
438
|
-
:param dict jwk: the JSON Web Key
|
439
|
-
:return: the RSA PEM key serialized as bytes
|
440
|
-
:rtype: bytes
|
441
|
-
"""
|
442
|
-
return (
|
443
|
-
RSAPublicNumbers(
|
444
|
-
n=self._decode_value(jwk["n"]),
|
445
|
-
e=self._decode_value(jwk["e"]),
|
446
|
-
)
|
447
|
-
.public_key(default_backend())
|
448
|
-
.public_bytes(
|
449
|
-
encoding=serialization.Encoding.PEM,
|
450
|
-
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
451
|
-
)
|
452
|
-
)
|
453
|
-
|
454
|
-
def _get_public_key(self, token: str, tenant_id: str, provider: int):
|
455
|
-
"""
|
456
|
-
This method returns the public key from the given token, ensuring that
|
457
|
-
the tenant information and provider are correct
|
458
|
-
|
459
|
-
:param str token: the given token
|
460
|
-
:param str tenant_id: the tenant information
|
461
|
-
:param int provider: the token provider information
|
462
|
-
:return: the public key in the token or it raises an error
|
463
|
-
:rtype: str
|
464
|
-
"""
|
465
|
-
kid = self._get_key_id(token)
|
466
|
-
jwk = self._get_jwk(kid, tenant_id, provider)
|
467
|
-
return self._rsa_pem_from_jwk(jwk)
|
468
|
-
|
469
414
|
|
470
415
|
class BIAuth(Auth):
|
471
416
|
def __init__(self, user_model=UserModel):
|
@@ -474,10 +419,10 @@ class BIAuth(Auth):
|
|
474
419
|
@staticmethod
|
475
420
|
def decode_token(token: str = None) -> dict:
|
476
421
|
"""
|
477
|
-
Decodes a given JSON Web token and extracts the
|
422
|
+
Decodes a given JSON Web token and extracts the username from the sub claim.
|
478
423
|
|
479
424
|
:param str token: the given JSON Web Token
|
480
|
-
:return: the
|
425
|
+
:return: dictionary containing the username from the token's sub claim
|
481
426
|
:rtype: dict
|
482
427
|
"""
|
483
428
|
if token is None:
|
@@ -489,19 +434,21 @@ class BIAuth(Auth):
|
|
489
434
|
payload = jwt.decode(
|
490
435
|
token, current_app.config["SECRET_BI_KEY"], algorithms="HS256"
|
491
436
|
)
|
492
|
-
return {"
|
437
|
+
return {"username": payload["sub"]}
|
493
438
|
except jwt.InvalidTokenError:
|
494
439
|
raise InvalidCredentials(
|
495
440
|
"Invalid token, please try again with a new token",
|
496
441
|
log_txt="Error while trying to decode token. The token is invalid.",
|
442
|
+
status_code=400,
|
497
443
|
)
|
498
444
|
|
499
445
|
@staticmethod
|
500
446
|
def generate_token(user_id: int = None) -> str:
|
501
447
|
"""
|
502
|
-
Generates a token given a user_id
|
448
|
+
Generates a token given a user_id. The token will contain the username in the sub claim.
|
449
|
+
BI tokens do not include expiration time.
|
503
450
|
|
504
|
-
:param int user_id: user
|
451
|
+
:param int user_id: user id to generate the token for
|
505
452
|
:return: the generated token
|
506
453
|
:rtype: str
|
507
454
|
"""
|
@@ -511,9 +458,19 @@ class BIAuth(Auth):
|
|
511
458
|
err, log_txt="Error while trying to generate token. " + err
|
512
459
|
)
|
513
460
|
|
461
|
+
user = UserModel.get_one_user(user_id)
|
462
|
+
if user is None:
|
463
|
+
err = "User does not exist"
|
464
|
+
raise InvalidUsage(
|
465
|
+
err, log_txt="Error while trying to generate token. " + err
|
466
|
+
)
|
467
|
+
|
514
468
|
payload = {
|
469
|
+
"exp": datetime.utcnow()
|
470
|
+
+ timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
|
515
471
|
"iat": datetime.utcnow(),
|
516
|
-
"sub":
|
472
|
+
"sub": user.username,
|
473
|
+
"iss": INTERNAL_TOKEN_ISSUER,
|
517
474
|
}
|
518
475
|
|
519
476
|
return jwt.encode(
|
cornflow/shared/const.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
"""
|
2
|
-
In this
|
2
|
+
In this file we import the values for different constants on cornflow server
|
3
3
|
"""
|
4
4
|
|
5
|
+
INTERNAL_TOKEN_ISSUER = "cornflow"
|
6
|
+
|
5
7
|
# endpoints responses for health check
|
6
8
|
STATUS_HEALTHY = "healthy"
|
7
9
|
STATUS_UNHEALTHY = "unhealthy"
|
@@ -50,19 +52,6 @@ AUTH_LDAP = 2
|
|
50
52
|
AUTH_OAUTH = 4
|
51
53
|
AUTH_OID = 0
|
52
54
|
|
53
|
-
# Providers of open ID:
|
54
|
-
OID_NONE = 0
|
55
|
-
OID_AZURE = 1
|
56
|
-
OID_GOOGLE = 2
|
57
|
-
|
58
|
-
# AZURE OPEN ID URLS
|
59
|
-
OID_AZURE_DISCOVERY_COMMON_URL = (
|
60
|
-
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
|
61
|
-
)
|
62
|
-
OID_AZURE_DISCOVERY_TENANT_URL = (
|
63
|
-
"https://login.microsoftonline.com/{tenant_id}/.well-known/openid-configuration"
|
64
|
-
)
|
65
|
-
|
66
55
|
GET_ACTION = 1
|
67
56
|
PATCH_ACTION = 2
|
68
57
|
POST_ACTION = 3
|