cornflow 2.0.0a12__py3-none-any.whl → 2.0.0a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/app.py +3 -1
- cornflow/cli/__init__.py +4 -0
- cornflow/cli/actions.py +4 -0
- cornflow/cli/config.py +4 -0
- cornflow/cli/migrations.py +13 -8
- cornflow/cli/permissions.py +4 -0
- cornflow/cli/roles.py +4 -0
- cornflow/cli/schemas.py +5 -0
- cornflow/cli/service.py +260 -147
- cornflow/cli/tools/api_generator.py +13 -10
- cornflow/cli/tools/endpoint_tools.py +191 -196
- cornflow/cli/tools/models_tools.py +87 -60
- cornflow/cli/tools/schema_generator.py +161 -67
- cornflow/cli/tools/schemas_tools.py +4 -5
- cornflow/cli/users.py +8 -0
- cornflow/cli/views.py +4 -0
- cornflow/commands/dag.py +3 -2
- cornflow/commands/schemas.py +6 -4
- cornflow/commands/users.py +12 -17
- cornflow/config.py +3 -2
- cornflow/endpoints/dag.py +27 -25
- cornflow/endpoints/data_check.py +102 -164
- cornflow/endpoints/example_data.py +9 -3
- cornflow/endpoints/execution.py +27 -23
- cornflow/endpoints/health.py +4 -5
- cornflow/endpoints/instance.py +39 -12
- cornflow/endpoints/meta_resource.py +4 -5
- cornflow/schemas/execution.py +1 -0
- cornflow/shared/airflow.py +157 -0
- cornflow/shared/authentication/auth.py +73 -42
- cornflow/shared/const.py +9 -0
- cornflow/shared/databricks.py +10 -10
- cornflow/shared/exceptions.py +3 -1
- cornflow/shared/utils_tables.py +36 -8
- cornflow/shared/validators.py +1 -1
- cornflow/tests/const.py +1 -0
- cornflow/tests/custom_test_case.py +4 -4
- cornflow/tests/unit/test_alarms.py +1 -2
- cornflow/tests/unit/test_cases.py +4 -7
- cornflow/tests/unit/test_executions.py +105 -43
- cornflow/tests/unit/test_log_in.py +46 -9
- cornflow/tests/unit/test_tables.py +3 -3
- cornflow/tests/unit/tools.py +31 -13
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/METADATA +2 -2
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/RECORD +48 -47
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/WHEEL +1 -1
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/entry_points.txt +0 -0
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/top_level.txt +0 -0
cornflow/endpoints/execution.py
CHANGED
@@ -6,10 +6,9 @@ These endpoints hve different access url, but manage the same data entities
|
|
6
6
|
|
7
7
|
# Import from libraries
|
8
8
|
from cornflow_client.airflow.api import Airflow
|
9
|
+
|
9
10
|
from cornflow_client.databricks.api import Databricks
|
10
11
|
from cornflow_client.constants import INSTANCE_SCHEMA, CONFIG_SCHEMA, SOLUTION_SCHEMA
|
11
|
-
|
12
|
-
# TODO AGA: Porqué el import no funcina correctamente
|
13
12
|
from flask import request, current_app
|
14
13
|
from flask_apispec import marshal_with, use_kwargs, doc
|
15
14
|
|
@@ -37,6 +36,9 @@ from cornflow.shared.const import (
|
|
37
36
|
DATABRICKS_BACKEND,
|
38
37
|
)
|
39
38
|
from cornflow.shared.const import (
|
39
|
+
AIRFLOW_ERROR_MSG,
|
40
|
+
AIRFLOW_NOT_REACHABLE_MSG,
|
41
|
+
DAG_PAUSED_MSG,
|
40
42
|
EXEC_STATE_RUNNING,
|
41
43
|
EXEC_STATE_ERROR,
|
42
44
|
EXEC_STATE_ERROR_START,
|
@@ -64,7 +66,7 @@ from cornflow.shared.validators import (
|
|
64
66
|
|
65
67
|
class ExecutionEndpoint(BaseMetaResource):
|
66
68
|
"""
|
67
|
-
Endpoint used to create
|
69
|
+
Endpoint used to create a new execution or get all the executions and their information back
|
68
70
|
"""
|
69
71
|
|
70
72
|
def __init__(self):
|
@@ -157,10 +159,10 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
157
159
|
)
|
158
160
|
continue
|
159
161
|
|
160
|
-
if not self.orch_client.is_alive():
|
162
|
+
if not self.orch_client.is_alive(config=current_app.config):
|
161
163
|
current_app.logger.warning(
|
162
|
-
"Error while the app tried to update the status of all running executions."
|
163
|
-
"
|
164
|
+
f"Error while the app tried to update the status of all running executions."
|
165
|
+
f"{AIRFLOW_NOT_REACHABLE_MSG}"
|
164
166
|
)
|
165
167
|
continue
|
166
168
|
try:
|
@@ -170,7 +172,7 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
170
172
|
except self.orch_error as err:
|
171
173
|
current_app.logger.warning(
|
172
174
|
"Error while the app tried to update the status of all running executions."
|
173
|
-
f"
|
175
|
+
f"{AIRFLOW_ERROR_MSG} {err}"
|
174
176
|
)
|
175
177
|
continue
|
176
178
|
|
@@ -207,6 +209,10 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
207
209
|
user=self.get_user(), idx=execution.instance_id
|
208
210
|
)
|
209
211
|
|
212
|
+
if execution.schema != instance.schema:
|
213
|
+
execution.delete()
|
214
|
+
raise InvalidData(error="Instance and execution schema mismatch")
|
215
|
+
|
210
216
|
current_app.logger.debug(f"The request is: {request.args.get('run')}")
|
211
217
|
# this allows testing without orchestrator interaction:
|
212
218
|
if request.args.get("run", "1") == "0":
|
@@ -286,22 +292,21 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
286
292
|
if self.orch_type == AIRFLOW_BACKEND:
|
287
293
|
info = schema_info.json()
|
288
294
|
if info["is_paused"]:
|
289
|
-
|
290
|
-
current_app.logger.error(err)
|
295
|
+
current_app.logger.error(DAG_PAUSED_MSG)
|
291
296
|
execution.update_state(EXEC_STATE_ERROR_START)
|
292
297
|
raise self.orch_error(
|
293
|
-
error=
|
298
|
+
error=DAG_PAUSED_MSG,
|
294
299
|
payload=dict(
|
295
300
|
message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
|
296
301
|
state=EXEC_STATE_ERROR_START,
|
297
302
|
),
|
298
303
|
log_txt=f"Error while user {self.get_user()} tries to create an execution. "
|
299
|
-
+
|
304
|
+
+ DAG_PAUSED_MSG,
|
300
305
|
)
|
301
306
|
# TODO AGA: revisar si hay que hacer alguna verificación a los JOBS
|
302
307
|
|
303
308
|
try:
|
304
|
-
# TODO AGA: Hay que
|
309
|
+
# TODO AGA: Hay que gestionar la posible eliminación de execution.id como
|
305
310
|
# parámetro, ya que no se puede seleccionar el id en databricks
|
306
311
|
# revisar las consecuencias que puede tener
|
307
312
|
response = self.orch_client.run_workflow(execution.id, orch_name=schema)
|
@@ -321,7 +326,8 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
321
326
|
|
322
327
|
# if we succeed, we register the run_id in the execution table:
|
323
328
|
orch_data = response.json()
|
324
|
-
|
329
|
+
info = "orch data is " + str(orch_data)
|
330
|
+
current_app.logger.info(info)
|
325
331
|
execution.run_id = orch_data[self.orch_const["run_id"]]
|
326
332
|
execution.update_state(EXEC_STATE_QUEUED)
|
327
333
|
current_app.logger.info(
|
@@ -446,17 +452,16 @@ class ExecutionRelaunchEndpoint(BaseMetaResource):
|
|
446
452
|
info = schema_info.json()
|
447
453
|
if self.orch_type == AIRFLOW_BACKEND:
|
448
454
|
if info["is_paused"]:
|
449
|
-
|
450
|
-
current_app.logger.error(err)
|
455
|
+
current_app.logger.error(AIRFLOW_NOT_REACHABLE_MSG)
|
451
456
|
execution.update_state(EXEC_STATE_ERROR_START)
|
452
457
|
raise self.orch_error(
|
453
|
-
error=
|
458
|
+
error=AIRFLOW_NOT_REACHABLE_MSG,
|
454
459
|
payload=dict(
|
455
460
|
message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
|
456
461
|
state=EXEC_STATE_ERROR_START,
|
457
462
|
),
|
458
463
|
log_txt=f"Error while user {self.get_user()} tries to relaunch execution {idx}. "
|
459
|
-
+
|
464
|
+
+ AIRFLOW_NOT_REACHABLE_MSG,
|
460
465
|
)
|
461
466
|
# TODO GG: revisar si hay que hacer alguna comprobación del estilo a databricks
|
462
467
|
try:
|
@@ -626,14 +631,13 @@ class ExecutionDetailsEndpoint(ExecutionDetailsEndpointBase):
|
|
626
631
|
f"The execution does not exist."
|
627
632
|
)
|
628
633
|
|
629
|
-
if not self.orch_client.is_alive():
|
630
|
-
err = self.orch_const["name"] + " is not accessible"
|
634
|
+
if not self.orch_client.is_alive(config=current_app.config):
|
631
635
|
raise self.orch_error(
|
632
|
-
error=
|
633
|
-
log_txt=f"Error while user {self.get_user()} tries to stop execution {idx}. {
|
636
|
+
error=AIRFLOW_NOT_REACHABLE_MSG,
|
637
|
+
log_txt=f"Error while user {self.get_user()} tries to stop execution {idx}. {AIRFLOW_NOT_REACHABLE_MSG}",
|
634
638
|
)
|
635
639
|
|
636
|
-
|
640
|
+
self.orch_client.set_dag_run_to_fail(
|
637
641
|
dag_name=execution.schema, run_id=execution.run_id
|
638
642
|
)
|
639
643
|
# We should check if the execution has been stopped
|
@@ -897,7 +901,7 @@ def get_databricks(schema, execution, message="tries to create an execution"):
|
|
897
901
|
"""
|
898
902
|
db_client = Databricks.from_config(current_app.config)
|
899
903
|
schema_info = db_client.get_orch_info(schema)
|
900
|
-
if not db_client.is_alive():
|
904
|
+
if not db_client.is_alive(config=current_app.config):
|
901
905
|
err = "Databricks is not accessible"
|
902
906
|
current_app.logger.error(err)
|
903
907
|
execution.update_state(EXEC_STATE_ERROR_START)
|
cornflow/endpoints/health.py
CHANGED
@@ -20,7 +20,7 @@ from cornflow.shared.const import (
|
|
20
20
|
STATUS_HEALTHY,
|
21
21
|
STATUS_UNHEALTHY,
|
22
22
|
)
|
23
|
-
from
|
23
|
+
from cornflow_client.databricks.api import Databricks
|
24
24
|
from cornflow.shared.exceptions import EndpointNotImplemented
|
25
25
|
|
26
26
|
|
@@ -35,11 +35,10 @@ class HealthEndpoint(BaseMetaResource):
|
|
35
35
|
:rtype: dict
|
36
36
|
:doc-author: baobab soluciones
|
37
37
|
"""
|
38
|
-
|
38
|
+
|
39
39
|
backend_status = self.check_backend_status()
|
40
40
|
|
41
41
|
cornflow_status = STATUS_UNHEALTHY
|
42
|
-
|
43
42
|
|
44
43
|
if (
|
45
44
|
UserModel.get_one_user_by_username(os.getenv("CORNFLOW_SERVICE_USER"))
|
@@ -75,5 +74,5 @@ class HealthEndpoint(BaseMetaResource):
|
|
75
74
|
databricks_status = STATUS_UNHEALTHY
|
76
75
|
if db_client.is_alive():
|
77
76
|
databricks_status = STATUS_HEALTHY
|
78
|
-
|
79
|
-
return databricks_status
|
77
|
+
|
78
|
+
return databricks_status
|
cornflow/endpoints/instance.py
CHANGED
@@ -34,7 +34,6 @@ from cornflow.shared.exceptions import InvalidUsage, InvalidData
|
|
34
34
|
from cornflow.shared.validators import json_schema_validate_as_string
|
35
35
|
|
36
36
|
|
37
|
-
|
38
37
|
# Initialize the schema that all endpoints are going to use
|
39
38
|
ALLOWED_EXTENSIONS = {"mps", "lp"}
|
40
39
|
|
@@ -92,14 +91,18 @@ class InstanceEndpoint(BaseMetaResource):
|
|
92
91
|
# We validate the instance data
|
93
92
|
config = current_app.config
|
94
93
|
|
95
|
-
instance_schema = DeployedOrch.get_one_schema(
|
96
|
-
|
94
|
+
instance_schema = DeployedOrch.get_one_schema(
|
95
|
+
config, data_schema, INSTANCE_SCHEMA
|
96
|
+
)
|
97
|
+
instance_errors = json_schema_validate_as_string(
|
98
|
+
instance_schema, kwargs["data"]
|
99
|
+
)
|
97
100
|
|
98
101
|
if instance_errors:
|
99
102
|
raise InvalidData(
|
100
103
|
payload=dict(jsonschema_errors=instance_errors),
|
101
104
|
log_txt=f"Error while user {self.get_user()} tries to create an instance. "
|
102
|
-
|
105
|
+
f"Instance data do not match the jsonschema.",
|
103
106
|
)
|
104
107
|
|
105
108
|
# if we're here, we validated and the data seems to fit the schema
|
@@ -163,14 +166,18 @@ class InstanceDetailsEndpoint(InstanceDetailsEndpointBase):
|
|
163
166
|
|
164
167
|
config = current_app.config
|
165
168
|
|
166
|
-
instance_schema = DeployedOrch.get_one_schema(
|
167
|
-
|
169
|
+
instance_schema = DeployedOrch.get_one_schema(
|
170
|
+
config, schema, INSTANCE_SCHEMA
|
171
|
+
)
|
172
|
+
instance_errors = json_schema_validate_as_string(
|
173
|
+
instance_schema, kwargs["data"]
|
174
|
+
)
|
168
175
|
|
169
176
|
if instance_errors:
|
170
177
|
raise InvalidData(
|
171
178
|
payload=dict(jsonschema_errors=instance_errors),
|
172
179
|
log_txt=f"Error while user {self.get_user()} tries to create an instance. "
|
173
|
-
|
180
|
+
f"Instance data do not match the jsonschema.",
|
174
181
|
)
|
175
182
|
|
176
183
|
response = self.put_detail(data=kwargs, user=self.get_user(), idx=idx)
|
@@ -268,15 +275,35 @@ class InstanceFileEndpoint(BaseMetaResource):
|
|
268
275
|
sense = 1 if minimize else -1
|
269
276
|
try:
|
270
277
|
_vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
|
271
|
-
except:
|
278
|
+
except FileNotFoundError as e:
|
279
|
+
# Handle file not found specifically
|
272
280
|
raise InvalidUsage(
|
273
|
-
error="
|
274
|
-
log_txt=f"Error
|
275
|
-
|
281
|
+
error=f"MPS file not found: {filename}",
|
282
|
+
log_txt=f"Error for user {self.get_user()}: MPS file '{filename}' not found. Details: {e}",
|
283
|
+
status_code=404,
|
284
|
+
) from e
|
285
|
+
except PermissionError as e:
|
286
|
+
# Handle permission issues
|
287
|
+
raise InvalidUsage(
|
288
|
+
error=f"Permission denied reading MPS file: {filename}",
|
289
|
+
log_txt=f"Error for user {self.get_user()}: Permission denied for MPS file '{filename}'. Details: {e}",
|
290
|
+
status_code=403,
|
291
|
+
) from e
|
292
|
+
except (ValueError, pulp.PulpError, OSError, IndexError) as e:
|
293
|
+
# Catch parsing errors, PuLP errors, and other IO errors
|
294
|
+
# Handle parsing, PuLP, or other OS errors
|
295
|
+
current_app.logger.error(
|
296
|
+
f"Error parsing MPS file {filename} for user {self.get_user()}: {e}",
|
297
|
+
exc_info=True,
|
276
298
|
)
|
299
|
+
raise InvalidUsage(
|
300
|
+
error="Error reading or parsing the MPS file.",
|
301
|
+
log_txt=f"Error while user {self.get_user()} tries to create instance from MPS file {filename}. Details: {e}",
|
302
|
+
) from e
|
303
|
+
|
277
304
|
try:
|
278
305
|
os.remove(filename)
|
279
|
-
except:
|
306
|
+
except FileNotFoundError:
|
280
307
|
pass
|
281
308
|
|
282
309
|
pb_data = dict(
|
@@ -1,6 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
This file has all the logic shared for all the resources
|
3
3
|
"""
|
4
|
+
|
4
5
|
# Import from external libraries
|
5
6
|
from flask_restful import Resource
|
6
7
|
from flask import g, request
|
@@ -13,7 +14,6 @@ from cornflow.shared.const import ALL_DEFAULT_ROLES
|
|
13
14
|
from cornflow.shared.exceptions import InvalidUsage, ObjectDoesNotExist, NoPermission
|
14
15
|
|
15
16
|
|
16
|
-
|
17
17
|
class BaseMetaResource(Resource, MethodResource):
|
18
18
|
"""
|
19
19
|
The base resource from all methods inherit from.
|
@@ -30,7 +30,6 @@ class BaseMetaResource(Resource, MethodResource):
|
|
30
30
|
self.auth_class = None
|
31
31
|
self.dependents = None
|
32
32
|
self.unique = None
|
33
|
-
pass
|
34
33
|
|
35
34
|
"""
|
36
35
|
METHODS USED FOR THE BASIC CRUD OPERATIONS: GET, POST, PUT, PATCH, DELETE
|
@@ -121,7 +120,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
121
120
|
"""
|
122
121
|
item = self.data_model.get_one_object(**kwargs)
|
123
122
|
if item is None:
|
124
|
-
raise ObjectDoesNotExist(
|
123
|
+
raise ObjectDoesNotExist()
|
125
124
|
|
126
125
|
data = dict(data)
|
127
126
|
|
@@ -144,7 +143,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
144
143
|
item = self.data_model.get_one_object(**kwargs)
|
145
144
|
|
146
145
|
if item is None:
|
147
|
-
raise ObjectDoesNotExist(
|
146
|
+
raise ObjectDoesNotExist()
|
148
147
|
|
149
148
|
data = dict(data)
|
150
149
|
|
@@ -164,7 +163,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
164
163
|
"""
|
165
164
|
item = self.data_model.get_one_object(**kwargs)
|
166
165
|
if item is None:
|
167
|
-
raise ObjectDoesNotExist(
|
166
|
+
raise ObjectDoesNotExist()
|
168
167
|
if self.dependents is not None:
|
169
168
|
for element in getattr(item, self.dependents):
|
170
169
|
element.delete()
|
cornflow/schemas/execution.py
CHANGED
@@ -0,0 +1,157 @@
|
|
1
|
+
""" """
|
2
|
+
|
3
|
+
# Full imports
|
4
|
+
import json
|
5
|
+
import requests
|
6
|
+
|
7
|
+
# Partial imports
|
8
|
+
from requests.auth import HTTPBasicAuth
|
9
|
+
from requests.exceptions import ConnectionError, HTTPError
|
10
|
+
|
11
|
+
# Imports from modules
|
12
|
+
from cornflow_client.constants import AirflowError
|
13
|
+
from cornflow_client.orchestrator_constants import config_orchestrator
|
14
|
+
|
15
|
+
|
16
|
+
class Airflow(object):
|
17
|
+
def __init__(self, url, user, pwd):
|
18
|
+
self.url = f"{url}/api/v1"
|
19
|
+
self.auth = HTTPBasicAuth(user, pwd)
|
20
|
+
self.constants = config_orchestrator["airflow"]
|
21
|
+
|
22
|
+
@classmethod
|
23
|
+
def from_config(cls, config):
|
24
|
+
data = dict(
|
25
|
+
url=config["AIRFLOW_URL"],
|
26
|
+
user=config["AIRFLOW_USER"],
|
27
|
+
pwd=config["AIRFLOW_PWD"],
|
28
|
+
)
|
29
|
+
return cls(**data)
|
30
|
+
|
31
|
+
def is_alive(self, config=None):
|
32
|
+
try:
|
33
|
+
response = requests.get(f"{self.url}/health")
|
34
|
+
except (ConnectionError, HTTPError):
|
35
|
+
return False
|
36
|
+
try:
|
37
|
+
data = response.json()
|
38
|
+
database = data["metadatabase"]["status"] == "healthy"
|
39
|
+
scheduler = data["scheduler"]["status"] == "healthy"
|
40
|
+
except json.JSONDecodeError:
|
41
|
+
return False
|
42
|
+
except KeyError:
|
43
|
+
return False
|
44
|
+
|
45
|
+
return database and scheduler
|
46
|
+
|
47
|
+
def get_dag_info(self, dag_name, method="GET"):
|
48
|
+
url = f"{self.url}/dags/{dag_name}"
|
49
|
+
return self.request_headers_auth(method=method, url=url)
|
50
|
+
|
51
|
+
def request_headers_auth(self, status=200, **kwargs):
|
52
|
+
def_headers = {"Content-type": "application/json", "Accept": "application/json"}
|
53
|
+
headers = kwargs.get("headers", def_headers)
|
54
|
+
response = requests.request(headers=headers, auth=self.auth, **kwargs)
|
55
|
+
if response.status_code != status:
|
56
|
+
raise AirflowError(error=response.text, status_code=response.status_code)
|
57
|
+
return response
|
58
|
+
|
59
|
+
def consume_dag_run(self, dag_name, payload, dag_run_id=None, method="POST"):
|
60
|
+
url = f"{self.url}/dags/{dag_name}/dagRuns"
|
61
|
+
if dag_run_id is not None:
|
62
|
+
url = url + f"/{dag_run_id}"
|
63
|
+
response = self.request_headers_auth(method=method, url=url, json=payload)
|
64
|
+
return response
|
65
|
+
|
66
|
+
def set_dag_run_state(self, dag_name, payload):
|
67
|
+
url = f"{self.url}/dags/{dag_name}/updateTaskInstancesState"
|
68
|
+
return self.request_headers_auth(method="POST", url=url, json=payload)
|
69
|
+
|
70
|
+
def run_workflow(
|
71
|
+
self,
|
72
|
+
execution_id,
|
73
|
+
orch_name=config_orchestrator["airflow"]["def_schema"],
|
74
|
+
checks_only=False,
|
75
|
+
case_id=None,
|
76
|
+
):
|
77
|
+
conf = dict(exec_id=execution_id, checks_only=checks_only)
|
78
|
+
if case_id is not None:
|
79
|
+
conf["case_id"] = case_id
|
80
|
+
payload = dict(conf=conf)
|
81
|
+
return self.consume_dag_run(orch_name, payload=payload, method="POST")
|
82
|
+
|
83
|
+
def update_schemas(self, dag_name="update_all_schemas"):
|
84
|
+
return self.consume_dag_run(dag_name, payload={}, method="POST")
|
85
|
+
|
86
|
+
def update_dag_registry(self, dag_name="update_dag_registry"):
|
87
|
+
return self.consume_dag_run(dag_name, payload={}, method="POST")
|
88
|
+
|
89
|
+
def get_run_status(self, schema, run_id):
|
90
|
+
return self.consume_dag_run(
|
91
|
+
schema, payload=None, dag_run_id=run_id, method="GET"
|
92
|
+
)
|
93
|
+
|
94
|
+
def set_dag_run_to_fail(self, dag_name, run_id, new_status="failed"):
|
95
|
+
# here, two calls have to be done:
|
96
|
+
# first we get information on the dag_run
|
97
|
+
dag_run = self.consume_dag_run(
|
98
|
+
dag_name, payload=None, dag_run_id=run_id, method="GET"
|
99
|
+
)
|
100
|
+
dag_run_data = dag_run.json()
|
101
|
+
# then, we use the "executed_date" to build a call to the change state api
|
102
|
+
# TODO: We assume the solving task is named as is parent dag!
|
103
|
+
payload = dict(
|
104
|
+
dry_run=False,
|
105
|
+
include_downstream=True,
|
106
|
+
include_future=False,
|
107
|
+
include_past=False,
|
108
|
+
include_upstream=True,
|
109
|
+
new_state=new_status,
|
110
|
+
task_id=dag_name,
|
111
|
+
execution_date=dag_run_data["execution_date"],
|
112
|
+
)
|
113
|
+
return self.set_dag_run_state(dag_name, payload=payload)
|
114
|
+
|
115
|
+
def get_all_dag_runs(self, dag_name):
|
116
|
+
return self.consume_dag_run(dag_name=dag_name, payload=None, method="GET")
|
117
|
+
|
118
|
+
def get_orch_info(self, orch_name, method="GET"):
|
119
|
+
url = f"{self.url}/dags/{orch_name}"
|
120
|
+
schema_info = self.request_headers_auth(method=method, url=url)
|
121
|
+
if schema_info.status_code != 200:
|
122
|
+
raise AirflowError("DAG not available")
|
123
|
+
return schema_info
|
124
|
+
|
125
|
+
def get_one_variable(self, variable):
|
126
|
+
url = f"{self.url}/variables/{variable}"
|
127
|
+
return self.request_headers_auth(method="GET", url=url).json()
|
128
|
+
|
129
|
+
def get_all_variables(self):
|
130
|
+
return self.request_headers_auth(
|
131
|
+
method="GET", url=f"{self.url}/variables"
|
132
|
+
).json()
|
133
|
+
|
134
|
+
def get_one_schema(self, dag_name, schema):
|
135
|
+
return self.get_schemas_for_dag_name(dag_name)[schema]
|
136
|
+
|
137
|
+
def get_schemas_for_dag_name(self, dag_name):
|
138
|
+
response = self.get_one_variable(dag_name)
|
139
|
+
result = json.loads(response["value"])
|
140
|
+
result["name"] = response["key"]
|
141
|
+
return result
|
142
|
+
|
143
|
+
def get_all_schemas(self):
|
144
|
+
response = self.get_all_variables()
|
145
|
+
return [dict(name=variable["key"]) for variable in response["variables"]]
|
146
|
+
|
147
|
+
def get_all_dags(self, method="GET"):
|
148
|
+
url = f"{self.url}/dags"
|
149
|
+
return self.request_headers_auth(method=method, url=url)
|
150
|
+
|
151
|
+
def get_internal_dags(self, method="GET"):
|
152
|
+
url = f"{self.url}/dags?tags=internal"
|
153
|
+
return self.request_headers_auth(method=method, url=url)
|
154
|
+
|
155
|
+
def get_model_dags(self, method="GET"):
|
156
|
+
url = f"{self.url}/dags?tags=model"
|
157
|
+
return self.request_headers_auth(method=method, url=url)
|
@@ -2,15 +2,16 @@
|
|
2
2
|
This file contains the auth class that can be used for authentication on the request to the REST API
|
3
3
|
"""
|
4
4
|
|
5
|
-
# Imports from external libraries
|
6
|
-
import jwt
|
7
|
-
import requests
|
8
|
-
from jwt.algorithms import RSAAlgorithm
|
9
5
|
from datetime import datetime, timedelta, timezone
|
10
|
-
from flask import request, g, current_app, Request
|
11
6
|
from functools import wraps
|
12
7
|
from typing import Tuple
|
8
|
+
|
9
|
+
# Imports from external libraries
|
10
|
+
import jwt
|
11
|
+
import requests
|
13
12
|
from cachetools import TTLCache
|
13
|
+
from flask import request, g, current_app, Request
|
14
|
+
from jwt.algorithms import RSAAlgorithm
|
14
15
|
from werkzeug.datastructures import Headers
|
15
16
|
|
16
17
|
# Imports from internal modules
|
@@ -31,7 +32,6 @@ from cornflow.shared.exceptions import (
|
|
31
32
|
InvalidData,
|
32
33
|
InvalidUsage,
|
33
34
|
NoPermission,
|
34
|
-
ObjectDoesNotExist,
|
35
35
|
)
|
36
36
|
|
37
37
|
# Cache for storing public keys with 1 hour TTL
|
@@ -122,15 +122,13 @@ class Auth:
|
|
122
122
|
def decode_token(token: str = None) -> dict:
|
123
123
|
"""
|
124
124
|
Decodes a given JSON Web token and extracts the username from the sub claim.
|
125
|
-
Works with both internal tokens and OpenID tokens.
|
125
|
+
Works with both internal tokens and OpenID tokens by attempting verification methods sequentially.
|
126
126
|
|
127
127
|
:param str token: the given JSON Web Token
|
128
128
|
:return: dictionary containing the username from the token's sub claim
|
129
129
|
:rtype: dict
|
130
130
|
"""
|
131
|
-
|
132
131
|
if token is None:
|
133
|
-
|
134
132
|
raise InvalidCredentials(
|
135
133
|
"Must provide a token in Authorization header",
|
136
134
|
log_txt="Error while trying to decode token. Token is missing.",
|
@@ -138,48 +136,81 @@ class Auth:
|
|
138
136
|
)
|
139
137
|
|
140
138
|
try:
|
141
|
-
#
|
142
|
-
|
143
|
-
|
144
|
-
issuer = unverified_payload.get("iss")
|
145
|
-
|
146
|
-
# For internal tokens
|
147
|
-
if issuer == INTERNAL_TOKEN_ISSUER:
|
148
|
-
|
149
|
-
return jwt.decode(
|
150
|
-
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
151
|
-
)
|
152
|
-
|
153
|
-
# For OpenID tokens
|
154
|
-
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
155
|
-
|
156
|
-
return Auth().verify_token(
|
157
|
-
token,
|
158
|
-
current_app.config["OID_PROVIDER"],
|
159
|
-
current_app.config["OID_EXPECTED_AUDIENCE"],
|
160
|
-
)
|
161
|
-
|
162
|
-
# If we get here, the issuer is not valid
|
163
|
-
|
164
|
-
raise InvalidCredentials(
|
165
|
-
"Invalid token issuer. Token must be issued by a valid provider",
|
166
|
-
log_txt="Error while trying to decode token. Invalid issuer.",
|
167
|
-
status_code=400,
|
139
|
+
# Attempt 1: Verify as an internal token (HS256)
|
140
|
+
payload = jwt.decode(
|
141
|
+
token, current_app.config["SECRET_TOKEN_KEY"], algorithms=["HS256"]
|
168
142
|
)
|
143
|
+
if payload.get("iss") != INTERNAL_TOKEN_ISSUER:
|
144
|
+
raise jwt.InvalidIssuerError(
|
145
|
+
"Internal token issuer mismatch after verification"
|
146
|
+
)
|
147
|
+
return payload
|
169
148
|
|
170
149
|
except jwt.ExpiredSignatureError:
|
171
|
-
|
150
|
+
# Handle expiration specifically, could apply to either token type if caught here first
|
172
151
|
raise InvalidCredentials(
|
173
152
|
"The token has expired, please login again",
|
174
153
|
log_txt="Error while trying to decode token. The token has expired.",
|
175
154
|
status_code=400,
|
176
155
|
)
|
177
|
-
except
|
178
|
-
|
156
|
+
except (
|
157
|
+
jwt.InvalidSignatureError,
|
158
|
+
jwt.DecodeError,
|
159
|
+
jwt.InvalidTokenError,
|
160
|
+
) as e_internal:
|
161
|
+
# Internal verification failed (signature, format, etc.). Try OIDC if configured.
|
162
|
+
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
163
|
+
try:
|
164
|
+
# Attempt 2: Verify as an OIDC token (RS256) using the dedicated method
|
165
|
+
return Auth().verify_token(
|
166
|
+
token,
|
167
|
+
current_app.config["OID_PROVIDER"],
|
168
|
+
current_app.config["OID_EXPECTED_AUDIENCE"],
|
169
|
+
)
|
170
|
+
except jwt.ExpiredSignatureError:
|
171
|
+
# OIDC token expired
|
172
|
+
raise InvalidCredentials(
|
173
|
+
"The token has expired, please login again",
|
174
|
+
log_txt="Error while trying to decode OIDC token. The token has expired.",
|
175
|
+
status_code=400,
|
176
|
+
)
|
177
|
+
except (
|
178
|
+
jwt.InvalidTokenError,
|
179
|
+
InvalidCredentials,
|
180
|
+
CommunicationError,
|
181
|
+
) as e_oidc:
|
182
|
+
# OIDC verification failed (JWT format, signature, kid, audience, issuer, comms error)
|
183
|
+
# Log details for debugging but return a generic error to the client.
|
184
|
+
log_message = (
|
185
|
+
f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
|
186
|
+
f"OIDC verification failed ({type(e_oidc).__name__}: {str(e_oidc)})."
|
187
|
+
)
|
188
|
+
current_app.logger.warning(log_message)
|
189
|
+
raise InvalidCredentials(
|
190
|
+
"Invalid token format, signature, or configuration",
|
191
|
+
log_txt=log_message,
|
192
|
+
status_code=400,
|
193
|
+
)
|
194
|
+
else:
|
195
|
+
# Internal verification failed, and OIDC is not configured
|
196
|
+
log_message = (
|
197
|
+
f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
|
198
|
+
f"OIDC is not configured."
|
199
|
+
)
|
200
|
+
current_app.logger.warning(log_message)
|
201
|
+
raise InvalidCredentials(
|
202
|
+
"Invalid token format or signature",
|
203
|
+
log_txt=log_message,
|
204
|
+
status_code=400,
|
205
|
+
)
|
206
|
+
except Exception as e:
|
207
|
+
# Catch any other unexpected errors during the process
|
208
|
+
log_message = f"Unexpected error during token decoding: {str(e)}"
|
209
|
+
current_app.logger.error(log_message)
|
179
210
|
raise InvalidCredentials(
|
180
|
-
"
|
181
|
-
log_txt=
|
182
|
-
status_code=
|
211
|
+
"Could not decode or verify token due to an unexpected server error",
|
212
|
+
log_txt=log_message,
|
213
|
+
status_code=500,
|
183
214
|
)
|
184
215
|
|
185
216
|
def get_token_from_header(self, headers: Headers = None) -> str:
|
cornflow/shared/const.py
CHANGED
@@ -135,3 +135,12 @@ BASE_PERMISSION_ASSIGNATION = [
|
|
135
135
|
EXTRA_PERMISSION_ASSIGNATION = [
|
136
136
|
(VIEWER_ROLE, PUT_ACTION, "user-detail"),
|
137
137
|
]
|
138
|
+
|
139
|
+
# migrations constants
|
140
|
+
MIGRATIONS_DEFAULT_PATH = "./cornflow/migrations"
|
141
|
+
|
142
|
+
# Costants for messages that are given back on exceptions
|
143
|
+
AIRFLOW_NOT_REACHABLE_MSG = "Airflow is not reachable"
|
144
|
+
DAG_PAUSED_MSG = "The dag exists but it is paused in airflow"
|
145
|
+
AIRFLOW_ERROR_MSG = "Airflow responded with an error:"
|
146
|
+
DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"
|