cornflow 2.0.0a13__py3-none-any.whl → 2.0.0a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. cornflow/app.py +3 -1
  2. cornflow/cli/__init__.py +4 -0
  3. cornflow/cli/actions.py +4 -0
  4. cornflow/cli/config.py +4 -0
  5. cornflow/cli/migrations.py +13 -8
  6. cornflow/cli/permissions.py +4 -0
  7. cornflow/cli/roles.py +4 -0
  8. cornflow/cli/schemas.py +5 -0
  9. cornflow/cli/service.py +260 -147
  10. cornflow/cli/tools/api_generator.py +13 -10
  11. cornflow/cli/tools/endpoint_tools.py +191 -196
  12. cornflow/cli/tools/models_tools.py +87 -60
  13. cornflow/cli/tools/schema_generator.py +161 -67
  14. cornflow/cli/tools/schemas_tools.py +4 -5
  15. cornflow/cli/users.py +8 -0
  16. cornflow/cli/views.py +4 -0
  17. cornflow/commands/dag.py +3 -2
  18. cornflow/commands/schemas.py +6 -4
  19. cornflow/commands/users.py +12 -17
  20. cornflow/config.py +3 -2
  21. cornflow/endpoints/dag.py +27 -25
  22. cornflow/endpoints/data_check.py +102 -164
  23. cornflow/endpoints/example_data.py +9 -3
  24. cornflow/endpoints/execution.py +27 -23
  25. cornflow/endpoints/health.py +4 -5
  26. cornflow/endpoints/instance.py +39 -12
  27. cornflow/endpoints/meta_resource.py +4 -5
  28. cornflow/shared/airflow.py +157 -0
  29. cornflow/shared/authentication/auth.py +73 -42
  30. cornflow/shared/const.py +9 -0
  31. cornflow/shared/databricks.py +10 -10
  32. cornflow/shared/exceptions.py +3 -1
  33. cornflow/shared/utils_tables.py +36 -8
  34. cornflow/shared/validators.py +1 -1
  35. cornflow/tests/custom_test_case.py +4 -4
  36. cornflow/tests/unit/test_alarms.py +1 -2
  37. cornflow/tests/unit/test_cases.py +4 -7
  38. cornflow/tests/unit/test_executions.py +29 -20
  39. cornflow/tests/unit/test_log_in.py +46 -9
  40. cornflow/tests/unit/test_tables.py +3 -3
  41. cornflow/tests/unit/tools.py +31 -13
  42. {cornflow-2.0.0a13.dist-info → cornflow-2.0.0a14.dist-info}/METADATA +2 -2
  43. {cornflow-2.0.0a13.dist-info → cornflow-2.0.0a14.dist-info}/RECORD +46 -45
  44. {cornflow-2.0.0a13.dist-info → cornflow-2.0.0a14.dist-info}/WHEEL +1 -1
  45. {cornflow-2.0.0a13.dist-info → cornflow-2.0.0a14.dist-info}/entry_points.txt +0 -0
  46. {cornflow-2.0.0a13.dist-info → cornflow-2.0.0a14.dist-info}/top_level.txt +0 -0
@@ -6,10 +6,9 @@ These endpoints have different access URLs, but manage the same data entities
6
6
 
7
7
  # Import from libraries
8
8
  from cornflow_client.airflow.api import Airflow
9
+
9
10
  from cornflow_client.databricks.api import Databricks
10
11
  from cornflow_client.constants import INSTANCE_SCHEMA, CONFIG_SCHEMA, SOLUTION_SCHEMA
11
-
12
- # TODO AGA: Porqué el import no funcina correctamente
13
12
  from flask import request, current_app
14
13
  from flask_apispec import marshal_with, use_kwargs, doc
15
14
 
@@ -37,6 +36,9 @@ from cornflow.shared.const import (
37
36
  DATABRICKS_BACKEND,
38
37
  )
39
38
  from cornflow.shared.const import (
39
+ AIRFLOW_ERROR_MSG,
40
+ AIRFLOW_NOT_REACHABLE_MSG,
41
+ DAG_PAUSED_MSG,
40
42
  EXEC_STATE_RUNNING,
41
43
  EXEC_STATE_ERROR,
42
44
  EXEC_STATE_ERROR_START,
@@ -64,7 +66,7 @@ from cornflow.shared.validators import (
64
66
 
65
67
  class ExecutionEndpoint(BaseMetaResource):
66
68
  """
67
- Endpoint used to create and get executions
69
+ Endpoint used to create a new execution or get all the executions and their information back
68
70
  """
69
71
 
70
72
  def __init__(self):
@@ -157,10 +159,10 @@ class ExecutionEndpoint(BaseMetaResource):
157
159
  )
158
160
  continue
159
161
 
160
- if not self.orch_client.is_alive():
162
+ if not self.orch_client.is_alive(config=current_app.config):
161
163
  current_app.logger.warning(
162
- "Error while the app tried to update the status of all running executions."
163
- "Airflow is not accessible."
164
+ f"Error while the app tried to update the status of all running executions."
165
+ f"{AIRFLOW_NOT_REACHABLE_MSG}"
164
166
  )
165
167
  continue
166
168
  try:
@@ -170,7 +172,7 @@ class ExecutionEndpoint(BaseMetaResource):
170
172
  except self.orch_error as err:
171
173
  current_app.logger.warning(
172
174
  "Error while the app tried to update the status of all running executions."
173
- f"Orchestrator responded with an error: {err}"
175
+ f"{AIRFLOW_ERROR_MSG} {err}"
174
176
  )
175
177
  continue
176
178
 
@@ -207,6 +209,10 @@ class ExecutionEndpoint(BaseMetaResource):
207
209
  user=self.get_user(), idx=execution.instance_id
208
210
  )
209
211
 
212
+ if execution.schema != instance.schema:
213
+ execution.delete()
214
+ raise InvalidData(error="Instance and execution schema mismatch")
215
+
210
216
  current_app.logger.debug(f"The request is: {request.args.get('run')}")
211
217
  # this allows testing without orchestrator interaction:
212
218
  if request.args.get("run", "1") == "0":
@@ -286,22 +292,21 @@ class ExecutionEndpoint(BaseMetaResource):
286
292
  if self.orch_type == AIRFLOW_BACKEND:
287
293
  info = schema_info.json()
288
294
  if info["is_paused"]:
289
- err = "The dag exists but it is paused in airflow"
290
- current_app.logger.error(err)
295
+ current_app.logger.error(DAG_PAUSED_MSG)
291
296
  execution.update_state(EXEC_STATE_ERROR_START)
292
297
  raise self.orch_error(
293
- error=err,
298
+ error=DAG_PAUSED_MSG,
294
299
  payload=dict(
295
300
  message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
296
301
  state=EXEC_STATE_ERROR_START,
297
302
  ),
298
303
  log_txt=f"Error while user {self.get_user()} tries to create an execution. "
299
- + err,
304
+ + DAG_PAUSED_MSG,
300
305
  )
301
306
  # TODO AGA: revisar si hay que hacer alguna verificación a los JOBS
302
307
 
303
308
  try:
304
- # TODO AGA: Hay que genestionar la posible eliminación de execution.id como
309
+ # TODO AGA: Hay que gestionar la posible eliminación de execution.id como
305
310
  # parámetro, ya que no se puede seleccionar el id en databricks
306
311
  # revisar las consecuencias que puede tener
307
312
  response = self.orch_client.run_workflow(execution.id, orch_name=schema)
@@ -321,7 +326,8 @@ class ExecutionEndpoint(BaseMetaResource):
321
326
 
322
327
  # if we succeed, we register the run_id in the execution table:
323
328
  orch_data = response.json()
324
- print("orch data is ", orch_data)
329
+ info = "orch data is " + str(orch_data)
330
+ current_app.logger.info(info)
325
331
  execution.run_id = orch_data[self.orch_const["run_id"]]
326
332
  execution.update_state(EXEC_STATE_QUEUED)
327
333
  current_app.logger.info(
@@ -446,17 +452,16 @@ class ExecutionRelaunchEndpoint(BaseMetaResource):
446
452
  info = schema_info.json()
447
453
  if self.orch_type == AIRFLOW_BACKEND:
448
454
  if info["is_paused"]:
449
- err = "The dag exists but it is paused in airflow"
450
- current_app.logger.error(err)
455
+ current_app.logger.error(AIRFLOW_NOT_REACHABLE_MSG)
451
456
  execution.update_state(EXEC_STATE_ERROR_START)
452
457
  raise self.orch_error(
453
- error=err,
458
+ error=AIRFLOW_NOT_REACHABLE_MSG,
454
459
  payload=dict(
455
460
  message=EXECUTION_STATE_MESSAGE_DICT[EXEC_STATE_ERROR_START],
456
461
  state=EXEC_STATE_ERROR_START,
457
462
  ),
458
463
  log_txt=f"Error while user {self.get_user()} tries to relaunch execution {idx}. "
459
- + err,
464
+ + AIRFLOW_NOT_REACHABLE_MSG,
460
465
  )
461
466
  # TODO GG: revisar si hay que hacer alguna comprobación del estilo a databricks
462
467
  try:
@@ -626,14 +631,13 @@ class ExecutionDetailsEndpoint(ExecutionDetailsEndpointBase):
626
631
  f"The execution does not exist."
627
632
  )
628
633
 
629
- if not self.orch_client.is_alive():
630
- err = self.orch_const["name"] + " is not accessible"
634
+ if not self.orch_client.is_alive(config=current_app.config):
631
635
  raise self.orch_error(
632
- error=err,
633
- log_txt=f"Error while user {self.get_user()} tries to stop execution {idx}. {err}",
636
+ error=AIRFLOW_NOT_REACHABLE_MSG,
637
+ log_txt=f"Error while user {self.get_user()} tries to stop execution {idx}. {AIRFLOW_NOT_REACHABLE_MSG}",
634
638
  )
635
639
 
636
- response = self.orch_client.set_dag_run_to_fail(
640
+ self.orch_client.set_dag_run_to_fail(
637
641
  dag_name=execution.schema, run_id=execution.run_id
638
642
  )
639
643
  # We should check if the execution has been stopped
@@ -897,7 +901,7 @@ def get_databricks(schema, execution, message="tries to create an execution"):
897
901
  """
898
902
  db_client = Databricks.from_config(current_app.config)
899
903
  schema_info = db_client.get_orch_info(schema)
900
- if not db_client.is_alive():
904
+ if not db_client.is_alive(config=current_app.config):
901
905
  err = "Databricks is not accessible"
902
906
  current_app.logger.error(err)
903
907
  execution.update_state(EXEC_STATE_ERROR_START)
@@ -20,7 +20,7 @@ from cornflow.shared.const import (
20
20
  STATUS_HEALTHY,
21
21
  STATUS_UNHEALTHY,
22
22
  )
23
- from cornflow.shared.databricks import Databricks
23
+ from cornflow_client.databricks.api import Databricks
24
24
  from cornflow.shared.exceptions import EndpointNotImplemented
25
25
 
26
26
 
@@ -35,11 +35,10 @@ class HealthEndpoint(BaseMetaResource):
35
35
  :rtype: dict
36
36
  :doc-author: baobab soluciones
37
37
  """
38
-
38
+
39
39
  backend_status = self.check_backend_status()
40
40
 
41
41
  cornflow_status = STATUS_UNHEALTHY
42
-
43
42
 
44
43
  if (
45
44
  UserModel.get_one_user_by_username(os.getenv("CORNFLOW_SERVICE_USER"))
@@ -75,5 +74,5 @@ class HealthEndpoint(BaseMetaResource):
75
74
  databricks_status = STATUS_UNHEALTHY
76
75
  if db_client.is_alive():
77
76
  databricks_status = STATUS_HEALTHY
78
-
79
- return databricks_status
77
+
78
+ return databricks_status
@@ -34,7 +34,6 @@ from cornflow.shared.exceptions import InvalidUsage, InvalidData
34
34
  from cornflow.shared.validators import json_schema_validate_as_string
35
35
 
36
36
 
37
-
38
37
  # Initialize the schema that all endpoints are going to use
39
38
  ALLOWED_EXTENSIONS = {"mps", "lp"}
40
39
 
@@ -92,14 +91,18 @@ class InstanceEndpoint(BaseMetaResource):
92
91
  # We validate the instance data
93
92
  config = current_app.config
94
93
 
95
- instance_schema = DeployedOrch.get_one_schema(config, data_schema, INSTANCE_SCHEMA)
96
- instance_errors = json_schema_validate_as_string(instance_schema, kwargs["data"])
94
+ instance_schema = DeployedOrch.get_one_schema(
95
+ config, data_schema, INSTANCE_SCHEMA
96
+ )
97
+ instance_errors = json_schema_validate_as_string(
98
+ instance_schema, kwargs["data"]
99
+ )
97
100
 
98
101
  if instance_errors:
99
102
  raise InvalidData(
100
103
  payload=dict(jsonschema_errors=instance_errors),
101
104
  log_txt=f"Error while user {self.get_user()} tries to create an instance. "
102
- f"Instance data do not match the jsonschema.",
105
+ f"Instance data do not match the jsonschema.",
103
106
  )
104
107
 
105
108
  # if we're here, we validated and the data seems to fit the schema
@@ -163,14 +166,18 @@ class InstanceDetailsEndpoint(InstanceDetailsEndpointBase):
163
166
 
164
167
  config = current_app.config
165
168
 
166
- instance_schema = DeployedOrch.get_one_schema(config, schema, INSTANCE_SCHEMA)
167
- instance_errors = json_schema_validate_as_string(instance_schema, kwargs["data"])
169
+ instance_schema = DeployedOrch.get_one_schema(
170
+ config, schema, INSTANCE_SCHEMA
171
+ )
172
+ instance_errors = json_schema_validate_as_string(
173
+ instance_schema, kwargs["data"]
174
+ )
168
175
 
169
176
  if instance_errors:
170
177
  raise InvalidData(
171
178
  payload=dict(jsonschema_errors=instance_errors),
172
179
  log_txt=f"Error while user {self.get_user()} tries to create an instance. "
173
- f"Instance data do not match the jsonschema.",
180
+ f"Instance data do not match the jsonschema.",
174
181
  )
175
182
 
176
183
  response = self.put_detail(data=kwargs, user=self.get_user(), idx=idx)
@@ -268,15 +275,35 @@ class InstanceFileEndpoint(BaseMetaResource):
268
275
  sense = 1 if minimize else -1
269
276
  try:
270
277
  _vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
271
- except:
278
+ except FileNotFoundError as e:
279
+ # Handle file not found specifically
272
280
  raise InvalidUsage(
273
- error="There was an error reading the file",
274
- log_txt=f"Error while user {self.get_user()} tries to create instance from mps file. "
275
- f"There was an error reading the file.",
281
+ error=f"MPS file not found: {filename}",
282
+ log_txt=f"Error for user {self.get_user()}: MPS file '{filename}' not found. Details: {e}",
283
+ status_code=404,
284
+ ) from e
285
+ except PermissionError as e:
286
+ # Handle permission issues
287
+ raise InvalidUsage(
288
+ error=f"Permission denied reading MPS file: {filename}",
289
+ log_txt=f"Error for user {self.get_user()}: Permission denied for MPS file '{filename}'. Details: {e}",
290
+ status_code=403,
291
+ ) from e
292
+ except (ValueError, pulp.PulpError, OSError, IndexError) as e:
293
+ # Catch parsing errors, PuLP errors, and other IO errors
294
+ # Handle parsing, PuLP, or other OS errors
295
+ current_app.logger.error(
296
+ f"Error parsing MPS file {filename} for user {self.get_user()}: {e}",
297
+ exc_info=True,
276
298
  )
299
+ raise InvalidUsage(
300
+ error="Error reading or parsing the MPS file.",
301
+ log_txt=f"Error while user {self.get_user()} tries to create instance from MPS file {filename}. Details: {e}",
302
+ ) from e
303
+
277
304
  try:
278
305
  os.remove(filename)
279
- except:
306
+ except FileNotFoundError:
280
307
  pass
281
308
 
282
309
  pb_data = dict(
@@ -1,6 +1,7 @@
1
1
  """
2
2
  This file has all the logic shared for all the resources
3
3
  """
4
+
4
5
  # Import from external libraries
5
6
  from flask_restful import Resource
6
7
  from flask import g, request
@@ -13,7 +14,6 @@ from cornflow.shared.const import ALL_DEFAULT_ROLES
13
14
  from cornflow.shared.exceptions import InvalidUsage, ObjectDoesNotExist, NoPermission
14
15
 
15
16
 
16
-
17
17
  class BaseMetaResource(Resource, MethodResource):
18
18
  """
19
19
  The base resource from all methods inherit from.
@@ -30,7 +30,6 @@ class BaseMetaResource(Resource, MethodResource):
30
30
  self.auth_class = None
31
31
  self.dependents = None
32
32
  self.unique = None
33
- pass
34
33
 
35
34
  """
36
35
  METHODS USED FOR THE BASIC CRUD OPERATIONS: GET, POST, PUT, PATCH, DELETE
@@ -121,7 +120,7 @@ class BaseMetaResource(Resource, MethodResource):
121
120
  """
122
121
  item = self.data_model.get_one_object(**kwargs)
123
122
  if item is None:
124
- raise ObjectDoesNotExist("The data entity does not exist on the database")
123
+ raise ObjectDoesNotExist()
125
124
 
126
125
  data = dict(data)
127
126
 
@@ -144,7 +143,7 @@ class BaseMetaResource(Resource, MethodResource):
144
143
  item = self.data_model.get_one_object(**kwargs)
145
144
 
146
145
  if item is None:
147
- raise ObjectDoesNotExist("The data entity does not exist on the database")
146
+ raise ObjectDoesNotExist()
148
147
 
149
148
  data = dict(data)
150
149
 
@@ -164,7 +163,7 @@ class BaseMetaResource(Resource, MethodResource):
164
163
  """
165
164
  item = self.data_model.get_one_object(**kwargs)
166
165
  if item is None:
167
- raise ObjectDoesNotExist("The data entity does not exist on the database")
166
+ raise ObjectDoesNotExist()
168
167
  if self.dependents is not None:
169
168
  for element in getattr(item, self.dependents):
170
169
  element.delete()
@@ -0,0 +1,157 @@
1
+ """ """
2
+
3
+ # Full imports
4
+ import json
5
+ import requests
6
+
7
+ # Partial imports
8
+ from requests.auth import HTTPBasicAuth
9
+ from requests.exceptions import ConnectionError, HTTPError
10
+
11
+ # Imports from modules
12
+ from cornflow_client.constants import AirflowError
13
+ from cornflow_client.orchestrator_constants import config_orchestrator
14
+
15
+
16
+ class Airflow(object):
17
+ def __init__(self, url, user, pwd):
18
+ self.url = f"{url}/api/v1"
19
+ self.auth = HTTPBasicAuth(user, pwd)
20
+ self.constants = config_orchestrator["airflow"]
21
+
22
+ @classmethod
23
+ def from_config(cls, config):
24
+ data = dict(
25
+ url=config["AIRFLOW_URL"],
26
+ user=config["AIRFLOW_USER"],
27
+ pwd=config["AIRFLOW_PWD"],
28
+ )
29
+ return cls(**data)
30
+
31
+ def is_alive(self, config=None):
32
+ try:
33
+ response = requests.get(f"{self.url}/health")
34
+ except (ConnectionError, HTTPError):
35
+ return False
36
+ try:
37
+ data = response.json()
38
+ database = data["metadatabase"]["status"] == "healthy"
39
+ scheduler = data["scheduler"]["status"] == "healthy"
40
+ except json.JSONDecodeError:
41
+ return False
42
+ except KeyError:
43
+ return False
44
+
45
+ return database and scheduler
46
+
47
+ def get_dag_info(self, dag_name, method="GET"):
48
+ url = f"{self.url}/dags/{dag_name}"
49
+ return self.request_headers_auth(method=method, url=url)
50
+
51
+ def request_headers_auth(self, status=200, **kwargs):
52
+ def_headers = {"Content-type": "application/json", "Accept": "application/json"}
53
+ headers = kwargs.get("headers", def_headers)
54
+ response = requests.request(headers=headers, auth=self.auth, **kwargs)
55
+ if response.status_code != status:
56
+ raise AirflowError(error=response.text, status_code=response.status_code)
57
+ return response
58
+
59
+ def consume_dag_run(self, dag_name, payload, dag_run_id=None, method="POST"):
60
+ url = f"{self.url}/dags/{dag_name}/dagRuns"
61
+ if dag_run_id is not None:
62
+ url = url + f"/{dag_run_id}"
63
+ response = self.request_headers_auth(method=method, url=url, json=payload)
64
+ return response
65
+
66
+ def set_dag_run_state(self, dag_name, payload):
67
+ url = f"{self.url}/dags/{dag_name}/updateTaskInstancesState"
68
+ return self.request_headers_auth(method="POST", url=url, json=payload)
69
+
70
+ def run_workflow(
71
+ self,
72
+ execution_id,
73
+ orch_name=config_orchestrator["airflow"]["def_schema"],
74
+ checks_only=False,
75
+ case_id=None,
76
+ ):
77
+ conf = dict(exec_id=execution_id, checks_only=checks_only)
78
+ if case_id is not None:
79
+ conf["case_id"] = case_id
80
+ payload = dict(conf=conf)
81
+ return self.consume_dag_run(orch_name, payload=payload, method="POST")
82
+
83
+ def update_schemas(self, dag_name="update_all_schemas"):
84
+ return self.consume_dag_run(dag_name, payload={}, method="POST")
85
+
86
+ def update_dag_registry(self, dag_name="update_dag_registry"):
87
+ return self.consume_dag_run(dag_name, payload={}, method="POST")
88
+
89
+ def get_run_status(self, schema, run_id):
90
+ return self.consume_dag_run(
91
+ schema, payload=None, dag_run_id=run_id, method="GET"
92
+ )
93
+
94
+ def set_dag_run_to_fail(self, dag_name, run_id, new_status="failed"):
95
+ # here, two calls have to be done:
96
+ # first we get information on the dag_run
97
+ dag_run = self.consume_dag_run(
98
+ dag_name, payload=None, dag_run_id=run_id, method="GET"
99
+ )
100
+ dag_run_data = dag_run.json()
101
+ # then, we use the "execution_date" to build a call to the change state API
102
+ # TODO: We assume the solving task is named after its parent dag!
103
+ payload = dict(
104
+ dry_run=False,
105
+ include_downstream=True,
106
+ include_future=False,
107
+ include_past=False,
108
+ include_upstream=True,
109
+ new_state=new_status,
110
+ task_id=dag_name,
111
+ execution_date=dag_run_data["execution_date"],
112
+ )
113
+ return self.set_dag_run_state(dag_name, payload=payload)
114
+
115
+ def get_all_dag_runs(self, dag_name):
116
+ return self.consume_dag_run(dag_name=dag_name, payload=None, method="GET")
117
+
118
+ def get_orch_info(self, orch_name, method="GET"):
119
+ url = f"{self.url}/dags/{orch_name}"
120
+ schema_info = self.request_headers_auth(method=method, url=url)
121
+ if schema_info.status_code != 200:
122
+ raise AirflowError("DAG not available")
123
+ return schema_info
124
+
125
+ def get_one_variable(self, variable):
126
+ url = f"{self.url}/variables/{variable}"
127
+ return self.request_headers_auth(method="GET", url=url).json()
128
+
129
+ def get_all_variables(self):
130
+ return self.request_headers_auth(
131
+ method="GET", url=f"{self.url}/variables"
132
+ ).json()
133
+
134
+ def get_one_schema(self, dag_name, schema):
135
+ return self.get_schemas_for_dag_name(dag_name)[schema]
136
+
137
+ def get_schemas_for_dag_name(self, dag_name):
138
+ response = self.get_one_variable(dag_name)
139
+ result = json.loads(response["value"])
140
+ result["name"] = response["key"]
141
+ return result
142
+
143
+ def get_all_schemas(self):
144
+ response = self.get_all_variables()
145
+ return [dict(name=variable["key"]) for variable in response["variables"]]
146
+
147
+ def get_all_dags(self, method="GET"):
148
+ url = f"{self.url}/dags"
149
+ return self.request_headers_auth(method=method, url=url)
150
+
151
+ def get_internal_dags(self, method="GET"):
152
+ url = f"{self.url}/dags?tags=internal"
153
+ return self.request_headers_auth(method=method, url=url)
154
+
155
+ def get_model_dags(self, method="GET"):
156
+ url = f"{self.url}/dags?tags=model"
157
+ return self.request_headers_auth(method=method, url=url)
@@ -2,15 +2,16 @@
2
2
  This file contains the auth class that can be used for authentication on the request to the REST API
3
3
  """
4
4
 
5
- # Imports from external libraries
6
- import jwt
7
- import requests
8
- from jwt.algorithms import RSAAlgorithm
9
5
  from datetime import datetime, timedelta, timezone
10
- from flask import request, g, current_app, Request
11
6
  from functools import wraps
12
7
  from typing import Tuple
8
+
9
+ # Imports from external libraries
10
+ import jwt
11
+ import requests
13
12
  from cachetools import TTLCache
13
+ from flask import request, g, current_app, Request
14
+ from jwt.algorithms import RSAAlgorithm
14
15
  from werkzeug.datastructures import Headers
15
16
 
16
17
  # Imports from internal modules
@@ -31,7 +32,6 @@ from cornflow.shared.exceptions import (
31
32
  InvalidData,
32
33
  InvalidUsage,
33
34
  NoPermission,
34
- ObjectDoesNotExist,
35
35
  )
36
36
 
37
37
  # Cache for storing public keys with 1 hour TTL
@@ -122,15 +122,13 @@ class Auth:
122
122
  def decode_token(token: str = None) -> dict:
123
123
  """
124
124
  Decodes a given JSON Web token and extracts the username from the sub claim.
125
- Works with both internal tokens and OpenID tokens.
125
+ Works with both internal tokens and OpenID tokens by attempting verification methods sequentially.
126
126
 
127
127
  :param str token: the given JSON Web Token
128
128
  :return: dictionary containing the username from the token's sub claim
129
129
  :rtype: dict
130
130
  """
131
-
132
131
  if token is None:
133
-
134
132
  raise InvalidCredentials(
135
133
  "Must provide a token in Authorization header",
136
134
  log_txt="Error while trying to decode token. Token is missing.",
@@ -138,48 +136,81 @@ class Auth:
138
136
  )
139
137
 
140
138
  try:
141
- # First try to decode header to validate basic token structure
142
-
143
- unverified_payload = jwt.decode(token, options={"verify_signature": False})
144
- issuer = unverified_payload.get("iss")
145
-
146
- # For internal tokens
147
- if issuer == INTERNAL_TOKEN_ISSUER:
148
-
149
- return jwt.decode(
150
- token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
151
- )
152
-
153
- # For OpenID tokens
154
- if current_app.config["AUTH_TYPE"] == AUTH_OID:
155
-
156
- return Auth().verify_token(
157
- token,
158
- current_app.config["OID_PROVIDER"],
159
- current_app.config["OID_EXPECTED_AUDIENCE"],
160
- )
161
-
162
- # If we get here, the issuer is not valid
163
-
164
- raise InvalidCredentials(
165
- "Invalid token issuer. Token must be issued by a valid provider",
166
- log_txt="Error while trying to decode token. Invalid issuer.",
167
- status_code=400,
139
+ # Attempt 1: Verify as an internal token (HS256)
140
+ payload = jwt.decode(
141
+ token, current_app.config["SECRET_TOKEN_KEY"], algorithms=["HS256"]
168
142
  )
143
+ if payload.get("iss") != INTERNAL_TOKEN_ISSUER:
144
+ raise jwt.InvalidIssuerError(
145
+ "Internal token issuer mismatch after verification"
146
+ )
147
+ return payload
169
148
 
170
149
  except jwt.ExpiredSignatureError:
171
-
150
+ # Handle expiration specifically, could apply to either token type if caught here first
172
151
  raise InvalidCredentials(
173
152
  "The token has expired, please login again",
174
153
  log_txt="Error while trying to decode token. The token has expired.",
175
154
  status_code=400,
176
155
  )
177
- except jwt.InvalidTokenError as e:
178
-
156
+ except (
157
+ jwt.InvalidSignatureError,
158
+ jwt.DecodeError,
159
+ jwt.InvalidTokenError,
160
+ ) as e_internal:
161
+ # Internal verification failed (signature, format, etc.). Try OIDC if configured.
162
+ if current_app.config["AUTH_TYPE"] == AUTH_OID:
163
+ try:
164
+ # Attempt 2: Verify as an OIDC token (RS256) using the dedicated method
165
+ return Auth().verify_token(
166
+ token,
167
+ current_app.config["OID_PROVIDER"],
168
+ current_app.config["OID_EXPECTED_AUDIENCE"],
169
+ )
170
+ except jwt.ExpiredSignatureError:
171
+ # OIDC token expired
172
+ raise InvalidCredentials(
173
+ "The token has expired, please login again",
174
+ log_txt="Error while trying to decode OIDC token. The token has expired.",
175
+ status_code=400,
176
+ )
177
+ except (
178
+ jwt.InvalidTokenError,
179
+ InvalidCredentials,
180
+ CommunicationError,
181
+ ) as e_oidc:
182
+ # OIDC verification failed (JWT format, signature, kid, audience, issuer, comms error)
183
+ # Log details for debugging but return a generic error to the client.
184
+ log_message = (
185
+ f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
186
+ f"OIDC verification failed ({type(e_oidc).__name__}: {str(e_oidc)})."
187
+ )
188
+ current_app.logger.warning(log_message)
189
+ raise InvalidCredentials(
190
+ "Invalid token format, signature, or configuration",
191
+ log_txt=log_message,
192
+ status_code=400,
193
+ )
194
+ else:
195
+ # Internal verification failed, and OIDC is not configured
196
+ log_message = (
197
+ f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
198
+ f"OIDC is not configured."
199
+ )
200
+ current_app.logger.warning(log_message)
201
+ raise InvalidCredentials(
202
+ "Invalid token format or signature",
203
+ log_txt=log_message,
204
+ status_code=400,
205
+ )
206
+ except Exception as e:
207
+ # Catch any other unexpected errors during the process
208
+ log_message = f"Unexpected error during token decoding: {str(e)}"
209
+ current_app.logger.error(log_message)
179
210
  raise InvalidCredentials(
180
- "Invalid token format or signature",
181
- log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
182
- status_code=400,
211
+ "Could not decode or verify token due to an unexpected server error",
212
+ log_txt=log_message,
213
+ status_code=500,
183
214
  )
184
215
 
185
216
  def get_token_from_header(self, headers: Headers = None) -> str:
cornflow/shared/const.py CHANGED
@@ -135,3 +135,12 @@ BASE_PERMISSION_ASSIGNATION = [
135
135
  EXTRA_PERMISSION_ASSIGNATION = [
136
136
  (VIEWER_ROLE, PUT_ACTION, "user-detail"),
137
137
  ]
138
+
139
+ # migrations constants
140
+ MIGRATIONS_DEFAULT_PATH = "./cornflow/migrations"
141
+
142
+ # Constants for messages that are given back on exceptions
143
+ AIRFLOW_NOT_REACHABLE_MSG = "Airflow is not reachable"
144
+ DAG_PAUSED_MSG = "The dag exists but it is paused in airflow"
145
+ AIRFLOW_ERROR_MSG = "Airflow responded with an error:"
146
+ DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"