cornflow 1.2.1__py3-none-any.whl → 1.2.3a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/app.py +3 -1
- cornflow/cli/__init__.py +4 -0
- cornflow/cli/actions.py +4 -0
- cornflow/cli/config.py +4 -0
- cornflow/cli/migrations.py +13 -8
- cornflow/cli/permissions.py +4 -0
- cornflow/cli/roles.py +4 -0
- cornflow/cli/schemas.py +5 -0
- cornflow/cli/service.py +235 -131
- cornflow/cli/tools/api_generator.py +13 -10
- cornflow/cli/tools/endpoint_tools.py +191 -196
- cornflow/cli/tools/models_tools.py +87 -60
- cornflow/cli/tools/schema_generator.py +161 -67
- cornflow/cli/tools/schemas_tools.py +4 -5
- cornflow/cli/users.py +8 -0
- cornflow/cli/views.py +4 -0
- cornflow/commands/dag.py +3 -2
- cornflow/commands/permissions.py +3 -3
- cornflow/commands/schemas.py +6 -4
- cornflow/commands/users.py +12 -17
- cornflow/endpoints/dag.py +27 -25
- cornflow/endpoints/data_check.py +128 -165
- cornflow/endpoints/example_data.py +9 -3
- cornflow/endpoints/execution.py +40 -34
- cornflow/endpoints/health.py +7 -7
- cornflow/endpoints/instance.py +39 -12
- cornflow/endpoints/meta_resource.py +4 -5
- cornflow/schemas/execution.py +9 -1
- cornflow/schemas/health.py +1 -0
- cornflow/shared/authentication/auth.py +73 -42
- cornflow/shared/const.py +10 -1
- cornflow/shared/exceptions.py +3 -1
- cornflow/shared/utils_tables.py +36 -8
- cornflow/shared/validators.py +1 -1
- cornflow/tests/const.py +1 -0
- cornflow/tests/custom_test_case.py +4 -4
- cornflow/tests/unit/test_alarms.py +1 -2
- cornflow/tests/unit/test_cases.py +4 -7
- cornflow/tests/unit/test_commands.py +90 -0
- cornflow/tests/unit/test_executions.py +22 -1
- cornflow/tests/unit/test_health.py +4 -1
- cornflow/tests/unit/test_log_in.py +46 -9
- cornflow/tests/unit/test_tables.py +3 -3
- {cornflow-1.2.1.dist-info → cornflow-1.2.3a1.dist-info}/METADATA +1 -1
- {cornflow-1.2.1.dist-info → cornflow-1.2.3a1.dist-info}/RECORD +48 -48
- {cornflow-1.2.1.dist-info → cornflow-1.2.3a1.dist-info}/WHEEL +1 -1
- {cornflow-1.2.1.dist-info → cornflow-1.2.3a1.dist-info}/entry_points.txt +0 -0
- {cornflow-1.2.1.dist-info → cornflow-1.2.3a1.dist-info}/top_level.txt +0 -0
cornflow/endpoints/execution.py
CHANGED
@@ -24,11 +24,14 @@ from cornflow.schemas.execution import (
     ExecutionEditRequest,
     QueryFiltersExecution,
     ReLaunchExecutionRequest,
-    ExecutionDetailsWithIndicatorsAndLogResponse
+    ExecutionDetailsWithIndicatorsAndLogResponse,
 )
 from cornflow.shared.authentication import Auth, authenticate
 from cornflow.shared.compress import compressed
 from cornflow.shared.const import (
+    AIRFLOW_ERROR_MSG,
+    AIRFLOW_NOT_REACHABLE_MSG,
+    DAG_PAUSED_MSG,
     EXEC_STATE_RUNNING,
     EXEC_STATE_ERROR,
     EXEC_STATE_ERROR_START,
@@ -100,8 +103,8 @@ class ExecutionEndpoint(BaseMetaResource):
             af_client = Airflow.from_config(current_app.config)
             if not af_client.is_alive():
                 current_app.logger.warning(
-                    "Error while the app tried to update the status of all running executions."
-                    "
+                    f"Error while the app tried to update the status of all running executions."
+                    f"{AIRFLOW_NOT_REACHABLE_MSG}"
                 )
                 continue

@@ -112,7 +115,7 @@ class ExecutionEndpoint(BaseMetaResource):
             except AirflowError as err:
                 current_app.logger.warning(
                     "Error while the app tried to update the status of all running executions."
-                    f"
+                    f"{AIRFLOW_ERROR_MSG} {err}"
                 )
                 continue

@@ -137,18 +140,21 @@ class ExecutionEndpoint(BaseMetaResource):
           the reference_id for the newly created execution if successful) and a integer wit the HTTP status code
         :rtype: Tuple(dict, integer)
         """
-
-        # TODO: should the schema field be cross validated with the instance schema field?
+
         config = current_app.config

         if "schema" not in kwargs:
             kwargs["schema"] = "solve_model_dag"

-        execution,
+        execution, _ = self.post_list(data=kwargs)
         instance = InstanceModel.get_one_object(
             user=self.get_user(), idx=execution.instance_id
         )

+        if execution.schema != instance.schema:
+            execution.delete()
+            raise InvalidData(error="Instance and execution schema mismatch")
+
         current_app.logger.debug(f"The request is: {request.args.get('run')}")
         # this allows testing without airflow interaction:
         if request.args.get("run", "1") == "0":
@@ -161,17 +167,17 @@ class ExecutionEndpoint(BaseMetaResource):
         # We now try to launch the task in airflow
         af_client = Airflow.from_config(config)
         if not af_client.is_alive():
-
-            current_app.logger.error(
+
+            current_app.logger.error(AIRFLOW_NOT_REACHABLE_MSG)
             execution.update_state(EXEC_STATE_ERROR_START)
             raise AirflowError(
-                error=
+                error=AIRFLOW_NOT_REACHABLE_MSG,
                 payload=dict(
                     message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
                     state=EXEC_STATE_ERROR_START,
                 ),
                 log_txt=f"Error while user {self.get_user()} tries to create an execution "
-                +
+                + AIRFLOW_NOT_REACHABLE_MSG,
             )
         # ask airflow if dag_name exists
         schema = execution.schema
@@ -231,23 +237,23 @@ class ExecutionEndpoint(BaseMetaResource):

         info = schema_info.json()
         if info["is_paused"]:
-
-            current_app.logger.error(
+
+            current_app.logger.error(DAG_PAUSED_MSG)
             execution.update_state(EXEC_STATE_ERROR_START)
             raise AirflowError(
-                error=
+                error=DAG_PAUSED_MSG,
                 payload=dict(
                     message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
                     state=EXEC_STATE_ERROR_START,
                 ),
                 log_txt=f"Error while user {self.get_user()} tries to create an execution. "
-                +
+                + DAG_PAUSED_MSG,
             )

         try:
             response = af_client.run_dag(execution.id, dag_name=schema)
         except AirflowError as err:
-            error = "
+            error = f"{AIRFLOW_ERROR_MSG} {err}"
             current_app.logger.error(error)
             execution.update_state(EXEC_STATE_ERROR)
             raise AirflowError(
@@ -339,17 +345,17 @@ class ExecutionRelaunchEndpoint(BaseMetaResource):
         # We now try to launch the task in airflow
         af_client = Airflow.from_config(config)
         if not af_client.is_alive():
-
-            current_app.logger.error(
+
+            current_app.logger.error(AIRFLOW_NOT_REACHABLE_MSG)
             execution.update_state(EXEC_STATE_ERROR_START)
             raise AirflowError(
-                error=
+                error=AIRFLOW_NOT_REACHABLE_MSG,
                 payload=dict(
                     message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
                     state=EXEC_STATE_ERROR_START,
                 ),
                 log_txt=f"Error while user {self.get_user()} tries to relaunch execution {idx}. "
-                +
+                + AIRFLOW_NOT_REACHABLE_MSG,
             )
         # ask airflow if dag_name exists
         schema = execution.schema
@@ -357,23 +363,23 @@ class ExecutionRelaunchEndpoint(BaseMetaResource):

         info = schema_info.json()
         if info["is_paused"]:
-
-            current_app.logger.error(
+
+            current_app.logger.error(DAG_PAUSED_MSG)
             execution.update_state(EXEC_STATE_ERROR_START)
             raise AirflowError(
-                error=
+                error=DAG_PAUSED_MSG,
                 payload=dict(
                     message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
                     state=EXEC_STATE_ERROR_START,
                 ),
                 log_txt=f"Error while user {self.get_user()} tries to relaunch execution {idx}. "
-                +
+                + DAG_PAUSED_MSG,
             )

         try:
             response = af_client.run_dag(execution.id, dag_name=schema)
         except AirflowError as err:
-            error = "
+            error = f"{AIRFLOW_ERROR_MSG} {err}"
             current_app.logger.error(error)
             execution.update_state(EXEC_STATE_ERROR)
             raise AirflowError(
@@ -490,13 +496,13 @@ class ExecutionDetailsEndpoint(ExecutionDetailsEndpointBase):
         )
         af_client = Airflow.from_config(current_app.config)
         if not af_client.is_alive():
-
+
             raise AirflowError(
-                error=
+                error=AIRFLOW_NOT_REACHABLE_MSG,
                 log_txt=f"Error while user {self.get_user()} tries to stop execution {idx}. "
-                +
+                + AIRFLOW_NOT_REACHABLE_MSG,
             )
-
+        af_client.set_dag_run_to_fail(
             dag_name=execution.schema, dag_run_id=execution.dag_run_id
         )
         execution.update_state(EXEC_STATE_STOPPED)
@@ -563,21 +569,21 @@ class ExecutionStatusEndpoint(BaseMetaResource):

         af_client = Airflow.from_config(current_app.config)
         if not af_client.is_alive():
-
+
             _raise_af_error(
                 execution,
-
+                AIRFLOW_NOT_REACHABLE_MSG,
                 log_txt=f"Error while user {self.get_user()} tries to get the status of execution {idx}. "
-                +
+                + AIRFLOW_NOT_REACHABLE_MSG,
             )

         try:
-
+
             response = af_client.get_dag_run_status(
                 dag_name=execution.schema, dag_run_id=dag_run_id
             )
         except AirflowError as err:
-            error = f"
+            error = f"{AIRFLOW_ERROR_MSG} {err}"
             _raise_af_error(
                 execution,
                 error,
cornflow/endpoints/health.py
CHANGED
@@ -4,16 +4,15 @@ It performs a health check to airflow and a health check to cornflow database
 """
 import os

-# Import from libraries
-from cornflow_client.airflow.api import Airflow
-from flask import current_app
-from flask_apispec import marshal_with, doc
-
 # Import from internal modules
 from cornflow.endpoints.meta_resource import BaseMetaResource
 from cornflow.models import UserModel
 from cornflow.schemas.health import HealthResponse
-from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY
+from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY, CORNFLOW_VERSION
+# Import from libraries
+from cornflow_client.airflow.api import Airflow
+from flask import current_app
+from flask_apispec import marshal_with, doc


 class HealthEndpoint(BaseMetaResource):
@@ -30,6 +29,7 @@ class HealthEndpoint(BaseMetaResource):
         af_client = Airflow.from_config(current_app.config)
         airflow_status = STATUS_UNHEALTHY
         cornflow_status = STATUS_UNHEALTHY
+        cornflow_version = CORNFLOW_VERSION
         if af_client.is_alive():
             airflow_status = STATUS_HEALTHY

@@ -42,4 +42,4 @@ class HealthEndpoint(BaseMetaResource):
         current_app.logger.info(
             f"Health check: cornflow {cornflow_status}, airflow {airflow_status}"
         )
-        return {"cornflow_status": cornflow_status, "airflow_status": airflow_status}
+        return {"cornflow_status": cornflow_status, "airflow_status": airflow_status, "cornflow_version":cornflow_version}
cornflow/endpoints/instance.py
CHANGED
@@ -34,7 +34,6 @@ from cornflow.shared.exceptions import InvalidUsage, InvalidData
 from cornflow.shared.validators import json_schema_validate_as_string


-
 # Initialize the schema that all endpoints are going to use
 ALLOWED_EXTENSIONS = {"mps", "lp"}

@@ -92,14 +91,18 @@ class InstanceEndpoint(BaseMetaResource):
         # We validate the instance data
         config = current_app.config

-        instance_schema = DeployedDAG.get_one_schema(
-
+        instance_schema = DeployedDAG.get_one_schema(
+            config, data_schema, INSTANCE_SCHEMA
+        )
+        instance_errors = json_schema_validate_as_string(
+            instance_schema, kwargs["data"]
+        )

         if instance_errors:
             raise InvalidData(
                 payload=dict(jsonschema_errors=instance_errors),
                 log_txt=f"Error while user {self.get_user()} tries to create an instance. "
-
+                f"Instance data do not match the jsonschema.",
             )

         # if we're here, we validated and the data seems to fit the schema
@@ -163,14 +166,18 @@ class InstanceDetailsEndpoint(InstanceDetailsEndpointBase):

         config = current_app.config

-        instance_schema = DeployedDAG.get_one_schema(
-
+        instance_schema = DeployedDAG.get_one_schema(
+            config, schema, INSTANCE_SCHEMA
+        )
+        instance_errors = json_schema_validate_as_string(
+            instance_schema, kwargs["data"]
+        )

         if instance_errors:
             raise InvalidData(
                 payload=dict(jsonschema_errors=instance_errors),
                 log_txt=f"Error while user {self.get_user()} tries to create an instance. "
-
+                f"Instance data do not match the jsonschema.",
             )

         response = self.put_detail(data=kwargs, user=self.get_user(), idx=idx)
@@ -268,15 +275,35 @@ class InstanceFileEndpoint(BaseMetaResource):
         sense = 1 if minimize else -1
         try:
             _vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
-        except:
+        except FileNotFoundError as e:
+            # Handle file not found specifically
             raise InvalidUsage(
-                error="
-                log_txt=f"Error
-
+                error=f"MPS file not found: {filename}",
+                log_txt=f"Error for user {self.get_user()}: MPS file '{filename}' not found. Details: {e}",
+                status_code=404,
+            ) from e
+        except PermissionError as e:
+            # Handle permission issues
+            raise InvalidUsage(
+                error=f"Permission denied reading MPS file: {filename}",
+                log_txt=f"Error for user {self.get_user()}: Permission denied for MPS file '{filename}'. Details: {e}",
+                status_code=403,
+            ) from e
+        except (ValueError, pulp.PulpError, OSError, IndexError) as e:
+            # Catch parsing errors, PuLP errors, and other IO errors
+            # Handle parsing, PuLP, or other OS errors
+            current_app.logger.error(
+                f"Error parsing MPS file {filename} for user {self.get_user()}: {e}",
+                exc_info=True,
             )
+            raise InvalidUsage(
+                error="Error reading or parsing the MPS file.",
+                log_txt=f"Error while user {self.get_user()} tries to create instance from MPS file {filename}. Details: {e}",
+            ) from e
+
         try:
             os.remove(filename)
-        except:
+        except FileNotFoundError:
             pass

         pb_data = dict(
cornflow/endpoints/meta_resource.py
CHANGED
@@ -1,6 +1,7 @@
 """
 This file has all the logic shared for all the resources
 """
+
 # Import from external libraries
 from flask_restful import Resource
 from flask import g, request
@@ -13,7 +14,6 @@ from cornflow.shared.const import ALL_DEFAULT_ROLES
 from cornflow.shared.exceptions import InvalidUsage, ObjectDoesNotExist, NoPermission


-
 class BaseMetaResource(Resource, MethodResource):
     """
     The base resource from all methods inherit from.
@@ -30,7 +30,6 @@ class BaseMetaResource(Resource, MethodResource):
         self.auth_class = None
         self.dependents = None
         self.unique = None
-        pass

     """
     METHODS USED FOR THE BASIC CRUD OPERATIONS: GET, POST, PUT, PATCH, DELETE
@@ -121,7 +120,7 @@ class BaseMetaResource(Resource, MethodResource):
         """
         item = self.data_model.get_one_object(**kwargs)
         if item is None:
-            raise ObjectDoesNotExist(
+            raise ObjectDoesNotExist()

         data = dict(data)

@@ -144,7 +143,7 @@ class BaseMetaResource(Resource, MethodResource):
         item = self.data_model.get_one_object(**kwargs)

         if item is None:
-            raise ObjectDoesNotExist(
+            raise ObjectDoesNotExist()

         data = dict(data)

@@ -164,7 +163,7 @@ class BaseMetaResource(Resource, MethodResource):
         """
         item = self.data_model.get_one_object(**kwargs)
         if item is None:
-            raise ObjectDoesNotExist(
+            raise ObjectDoesNotExist()
         if self.dependents is not None:
             for element in getattr(item, self.dependents):
                 element.delete()
cornflow/schemas/execution.py
CHANGED
@@ -95,7 +95,7 @@ class ExecutionDagPostRequest(ExecutionRequest, ExecutionDagRequest):


 class ExecutionDetailsEndpointResponse(BaseDataEndpointResponse):
-    config = fields.
+    config = fields.Raw()
     instance_id = fields.Str()
     state = fields.Int()
     message = fields.Str(attribute="state_message")
@@ -112,6 +112,14 @@ class ExecutionDetailsEndpointWithIndicatorsResponse(ExecutionDetailsEndpointRes
         return indicators_string[1:-1]

     indicators = fields.Method("get_indicators")
+    updated_at = fields.DateTime(dump_only=True)
+
+    def get_username(self, obj):
+        if hasattr(obj, "user") and obj.user is not None:
+            return obj.user.username
+        return None
+
+    username = fields.Method("get_username")


 class ExecutionDetailsWithIndicatorsAndLogResponse(
cornflow/schemas/health.py
CHANGED
cornflow/shared/authentication/auth.py
CHANGED
@@ -2,15 +2,16 @@
 This file contains the auth class that can be used for authentication on the request to the REST API
 """

-# Imports from external libraries
-import jwt
-import requests
-from jwt.algorithms import RSAAlgorithm
 from datetime import datetime, timedelta, timezone
-from flask import request, g, current_app, Request
 from functools import wraps
 from typing import Tuple
+
+# Imports from external libraries
+import jwt
+import requests
 from cachetools import TTLCache
+from flask import request, g, current_app, Request
+from jwt.algorithms import RSAAlgorithm
 from werkzeug.datastructures import Headers

 # Imports from internal modules
@@ -31,7 +32,6 @@ from cornflow.shared.exceptions import (
     InvalidData,
     InvalidUsage,
     NoPermission,
-    ObjectDoesNotExist,
 )

 # Cache for storing public keys with 1 hour TTL
@@ -122,15 +122,13 @@ class Auth:
     def decode_token(token: str = None) -> dict:
         """
         Decodes a given JSON Web token and extracts the username from the sub claim.
-        Works with both internal tokens and OpenID tokens.
+        Works with both internal tokens and OpenID tokens by attempting verification methods sequentially.

         :param str token: the given JSON Web Token
         :return: dictionary containing the username from the token's sub claim
         :rtype: dict
         """
-
         if token is None:
-
             raise InvalidCredentials(
                 "Must provide a token in Authorization header",
                 log_txt="Error while trying to decode token. Token is missing.",
@@ -138,48 +136,81 @@ class Auth:
             )

         try:
-            #
-
-
-            issuer = unverified_payload.get("iss")
-
-            # For internal tokens
-            if issuer == INTERNAL_TOKEN_ISSUER:
-
-                return jwt.decode(
-                    token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
-                )
-
-            # For OpenID tokens
-            if current_app.config["AUTH_TYPE"] == AUTH_OID:
-
-                return Auth().verify_token(
-                    token,
-                    current_app.config["OID_PROVIDER"],
-                    current_app.config["OID_EXPECTED_AUDIENCE"],
-                )
-
-            # If we get here, the issuer is not valid
-
-            raise InvalidCredentials(
-                "Invalid token issuer. Token must be issued by a valid provider",
-                log_txt="Error while trying to decode token. Invalid issuer.",
-                status_code=400,
+            # Attempt 1: Verify as an internal token (HS256)
+            payload = jwt.decode(
+                token, current_app.config["SECRET_TOKEN_KEY"], algorithms=["HS256"]
             )
+            if payload.get("iss") != INTERNAL_TOKEN_ISSUER:
+                raise jwt.InvalidIssuerError(
+                    "Internal token issuer mismatch after verification"
+                )
+            return payload

         except jwt.ExpiredSignatureError:
-
+            # Handle expiration specifically, could apply to either token type if caught here first
             raise InvalidCredentials(
                 "The token has expired, please login again",
                 log_txt="Error while trying to decode token. The token has expired.",
                 status_code=400,
             )
-        except
-
+        except (
+            jwt.InvalidSignatureError,
+            jwt.DecodeError,
+            jwt.InvalidTokenError,
+        ) as e_internal:
+            # Internal verification failed (signature, format, etc.). Try OIDC if configured.
+            if current_app.config["AUTH_TYPE"] == AUTH_OID:
+                try:
+                    # Attempt 2: Verify as an OIDC token (RS256) using the dedicated method
+                    return Auth().verify_token(
+                        token,
+                        current_app.config["OID_PROVIDER"],
+                        current_app.config["OID_EXPECTED_AUDIENCE"],
+                    )
+                except jwt.ExpiredSignatureError:
+                    # OIDC token expired
+                    raise InvalidCredentials(
+                        "The token has expired, please login again",
+                        log_txt="Error while trying to decode OIDC token. The token has expired.",
+                        status_code=400,
+                    )
+                except (
+                    jwt.InvalidTokenError,
+                    InvalidCredentials,
+                    CommunicationError,
+                ) as e_oidc:
+                    # OIDC verification failed (JWT format, signature, kid, audience, issuer, comms error)
+                    # Log details for debugging but return a generic error to the client.
+                    log_message = (
+                        f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
+                        f"OIDC verification failed ({type(e_oidc).__name__}: {str(e_oidc)})."
+                    )
+                    current_app.logger.warning(log_message)
+                    raise InvalidCredentials(
+                        "Invalid token format, signature, or configuration",
+                        log_txt=log_message,
+                        status_code=400,
+                    )
+            else:
+                # Internal verification failed, and OIDC is not configured
+                log_message = (
+                    f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
+                    f"OIDC is not configured."
+                )
+                current_app.logger.warning(log_message)
+                raise InvalidCredentials(
+                    "Invalid token format or signature",
+                    log_txt=log_message,
+                    status_code=400,
+                )
+        except Exception as e:
+            # Catch any other unexpected errors during the process
+            log_message = f"Unexpected error during token decoding: {str(e)}"
+            current_app.logger.error(log_message)
             raise InvalidCredentials(
-                "
-                log_txt=
-                status_code=
+                "Could not decode or verify token due to an unexpected server error",
+                log_txt=log_message,
+                status_code=500,
             )

     def get_token_from_header(self, headers: Headers = None) -> str:
cornflow/shared/const.py
CHANGED
@@ -1,7 +1,7 @@
 """
 In this file we import the values for different constants on cornflow server
 """
-
+CORNFLOW_VERSION = "1.2.3a1"
 INTERNAL_TOKEN_ISSUER = "cornflow"

 # endpoints responses for health check
@@ -112,3 +112,12 @@ BASE_PERMISSION_ASSIGNATION = [
 EXTRA_PERMISSION_ASSIGNATION = [
     (VIEWER_ROLE, PUT_ACTION, "user-detail"),
 ]
+
+# migrations constants
+MIGRATIONS_DEFAULT_PATH = "./cornflow/migrations"
+
+# Costants for messages that are given back on exceptions
+AIRFLOW_NOT_REACHABLE_MSG = "Airflow is not reachable"
+DAG_PAUSED_MSG = "The dag exists but it is paused in airflow"
+AIRFLOW_ERROR_MSG = "Airflow responded with an error:"
+DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"
cornflow/shared/exceptions.py
CHANGED
@@ -9,6 +9,8 @@ from cornflow_client.constants import AirflowError
 from werkzeug.exceptions import HTTPException
 import traceback

+from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
+

 class InvalidUsage(Exception):
     """
@@ -48,7 +50,7 @@ class ObjectDoesNotExist(InvalidUsage):
     """

     status_code = 404
-    error =
+    error = DATA_DOES_NOT_EXIST_MSG


 class ObjectAlreadyExists(InvalidUsage):
|