cornflow 2.0.0a12__py3-none-any.whl → 2.0.0a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. cornflow/app.py +3 -1
  2. cornflow/cli/__init__.py +4 -0
  3. cornflow/cli/actions.py +4 -0
  4. cornflow/cli/config.py +4 -0
  5. cornflow/cli/migrations.py +13 -8
  6. cornflow/cli/permissions.py +4 -0
  7. cornflow/cli/roles.py +4 -0
  8. cornflow/cli/schemas.py +5 -0
  9. cornflow/cli/service.py +260 -147
  10. cornflow/cli/tools/api_generator.py +13 -10
  11. cornflow/cli/tools/endpoint_tools.py +191 -196
  12. cornflow/cli/tools/models_tools.py +87 -60
  13. cornflow/cli/tools/schema_generator.py +161 -67
  14. cornflow/cli/tools/schemas_tools.py +4 -5
  15. cornflow/cli/users.py +8 -0
  16. cornflow/cli/views.py +4 -0
  17. cornflow/commands/dag.py +3 -2
  18. cornflow/commands/schemas.py +6 -4
  19. cornflow/commands/users.py +12 -17
  20. cornflow/config.py +3 -2
  21. cornflow/endpoints/dag.py +27 -25
  22. cornflow/endpoints/data_check.py +102 -164
  23. cornflow/endpoints/example_data.py +9 -3
  24. cornflow/endpoints/execution.py +27 -23
  25. cornflow/endpoints/health.py +4 -5
  26. cornflow/endpoints/instance.py +39 -12
  27. cornflow/endpoints/meta_resource.py +4 -5
  28. cornflow/schemas/execution.py +1 -0
  29. cornflow/shared/airflow.py +157 -0
  30. cornflow/shared/authentication/auth.py +73 -42
  31. cornflow/shared/const.py +9 -0
  32. cornflow/shared/databricks.py +10 -10
  33. cornflow/shared/exceptions.py +3 -1
  34. cornflow/shared/utils_tables.py +36 -8
  35. cornflow/shared/validators.py +1 -1
  36. cornflow/tests/const.py +1 -0
  37. cornflow/tests/custom_test_case.py +4 -4
  38. cornflow/tests/unit/test_alarms.py +1 -2
  39. cornflow/tests/unit/test_cases.py +4 -7
  40. cornflow/tests/unit/test_executions.py +105 -43
  41. cornflow/tests/unit/test_log_in.py +46 -9
  42. cornflow/tests/unit/test_tables.py +3 -3
  43. cornflow/tests/unit/tools.py +31 -13
  44. {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/METADATA +2 -2
  45. {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/RECORD +48 -47
  46. {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/WHEEL +1 -1
  47. {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/entry_points.txt +0 -0
  48. {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/top_level.txt +0 -0
@@ -48,12 +48,16 @@ class Databricks:
48
48
  oauth_token = oauth_response.json()["access_token"]
49
49
  return oauth_token
50
50
 
51
- def is_alive(self):
51
+ def is_alive(self, config=None):
52
52
  try:
53
- # TODO: this url is project specific. Either it has to be a config option or some other way has to be found
54
- path = (
55
- "/Workspace/Repos/nippon/nippon_production_scheduling/requirements.txt"
56
- )
53
+ if config is None or config["DATABRICKS_HEALTH_PATH"] == "default path":
54
+ # We raise an error because the default path is not valid
55
+ raise DatabricksError(
56
+ "Invalid default path. Please set DATABRICKS_HEALTH_PATH as an environment variable"
57
+ )
58
+ else:
59
+ path = config["DATABRICKS_HEALTH_PATH"]
60
+
57
61
  url = f"{self.url}/api/2.0/workspace/get-status?path={path}"
58
62
  response = self.request_headers_auth(method="GET", url=url)
59
63
  if "error_code" in response.json().keys():
@@ -75,8 +79,6 @@ class Databricks:
75
79
  raise DatabricksError("JOB not available")
76
80
  return schema_info
77
81
 
78
- # TODO AGA: incluir un id de job por defecto o hacer obligatorio el uso el parámetro.
79
- # Revisar los efectos secundarios de eliminar execution_id y usar el predeterminado
80
82
  def run_workflow(
81
83
  self,
82
84
  execution_id,
@@ -87,13 +89,11 @@ class Databricks:
87
89
  """
88
90
  Run a job in Databricks
89
91
  """
90
- # TODO AGA: revisar si la url esta bien/si acepta asi los parámetros
91
92
  url = f"{self.url}/api/2.1/jobs/run-now/"
92
- # TODO AGA: revisar si deben ser notebook parameters o job parameters.
93
93
  # Entender cómo se usa checks_only
94
94
  payload = dict(
95
95
  job_id=orch_name,
96
- notebook_parameters=dict(
96
+ job_parameters=dict(
97
97
  checks_only=checks_only,
98
98
  execution_id=execution_id,
99
99
  ),
@@ -9,6 +9,8 @@ from cornflow_client.constants import AirflowError, DatabricksError
9
9
  from werkzeug.exceptions import HTTPException
10
10
  import traceback
11
11
 
12
+ from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
13
+
12
14
 
13
15
  class InvalidUsage(Exception):
14
16
  """
@@ -48,7 +50,7 @@ class ObjectDoesNotExist(InvalidUsage):
48
50
  """
49
51
 
50
52
  status_code = 404
51
- error = "The object does not exist"
53
+ error = DATA_DOES_NOT_EXIST_MSG
52
54
 
53
55
 
54
56
  class ObjectAlreadyExists(InvalidUsage):
@@ -2,6 +2,7 @@
2
2
  import inspect
3
3
  import os
4
4
  import sys
5
+ import collections
5
6
 
6
7
  from importlib import import_module
7
8
  from sqlalchemy.dialects.postgresql import TEXT
@@ -31,14 +32,42 @@ def import_models():
31
32
 
32
33
 
33
34
  def all_subclasses(cls, models=None):
34
- subclasses = set()
35
+ """Finds all direct and indirect subclasses of a given class.
36
+
37
+ Optionally filters a provided list of models first.
38
+
39
+ Args:
40
+ cls: The base class to find subclasses for.
41
+ models: An optional iterable of classes to pre-filter.
42
+
43
+ Returns:
44
+ A set containing all subclasses found.
45
+ """
46
+ filtered_subclasses = set()
35
47
  if models is not None:
36
48
  for val in models:
37
- if issubclass(val, cls):
38
- subclasses.add(val)
49
+ # Ensure val is a class before checking issubclass
50
+ if isinstance(val, type) and issubclass(val, cls):
51
+ filtered_subclasses.add(val)
52
+
53
+ all_descendants = set()
54
+ # Use a deque for efficient pop(0)
55
+ queue = collections.deque(cls.__subclasses__())
56
+ # Keep track of visited classes during the traversal to handle potential complex hierarchies
57
+ # (though direct subclass relationships shouldn't form cycles)
58
+ # Initialize with direct subclasses as they are the starting point.
59
+ visited_for_queue = set(cls.__subclasses__())
60
+
61
+ while queue:
62
+ current_sub = queue.popleft()
63
+ all_descendants.add(current_sub)
64
+
65
+ for grandchild in current_sub.__subclasses__():
66
+ if grandchild not in visited_for_queue:
67
+ visited_for_queue.add(grandchild)
68
+ queue.append(grandchild)
39
69
 
40
- return subclasses.union(set(cls.__subclasses__()).union(
41
- [s for c in cls.__subclasses__() for s in all_subclasses(c)]))
70
+ return filtered_subclasses.union(all_descendants)
42
71
 
43
72
 
44
73
  type_converter = {
@@ -76,6 +105,5 @@ def item_as_dict(item):
76
105
 
77
106
  def items_as_dict_list(ls):
78
107
  return [
79
- {c.name: getattr(item, c.name) for c in item.__table__.columns}
80
- for item in ls
81
- ]
108
+ {c.name: getattr(item, c.name) for c in item.__table__.columns} for item in ls
109
+ ]
@@ -1,12 +1,12 @@
1
1
  """
2
2
  This file has several validators
3
3
  """
4
+
4
5
  import re
5
6
  from typing import Tuple, Union
6
7
 
7
8
  from jsonschema import Draft7Validator, validators
8
9
  from disposable_email_domains import blocklist
9
- from jsonschema.protocols import Validator
10
10
 
11
11
 
12
12
  def is_special_character(character):
cornflow/tests/const.py CHANGED
@@ -15,6 +15,7 @@ INSTANCE_URL = PREFIX + "/instance/"
15
15
  INSTANCE_MPS = _get_file("./data/test_mps.mps")
16
16
  INSTANCE_GC_20 = _get_file("./data/gc_20_7.json")
17
17
  INSTANCE_FILE_FAIL = _get_file("./unit/test_instances.py")
18
+ EDIT_EXECUTION_SOLUTION = _get_file("./data/edit_execution_solution.json")
18
19
 
19
20
  EXECUTION_PATH = _get_file("./data/new_execution.json")
20
21
  BAD_EXECUTION_PATH = _get_file("./data/bad_execution.json")
@@ -635,9 +635,9 @@ class BaseTestCases:
635
635
  # (we patch the request to airflow to check if the schema is valid)
636
636
  # we create 4 instances
637
637
  data_many = [self.payload for _ in range(4)]
638
- data_many[-1] = {**data_many[-1], **dict(schema="timer")}
639
- allrows = self.get_rows(self.url, data_many)
640
- self.apply_filter(self.url, dict(schema="timer"), allrows.json[:1])
638
+
639
+ self.get_rows(self.url, data_many)
640
+ self.apply_filter(self.url, dict(schema="timer"), [])
641
641
 
642
642
  def test_opt_filters_date_lte(self):
643
643
  """
@@ -1117,7 +1117,7 @@ class LoginTestCases:
1117
1117
 
1118
1118
  self.assertEqual(400, response.status_code)
1119
1119
  self.assertEqual(
1120
- "Invalid token issuer. Token must be issued by a valid provider",
1120
+ "Invalid token format or signature",
1121
1121
  response.json["error"],
1122
1122
  )
1123
1123
 
@@ -91,7 +91,6 @@ class TestAlarmsEndpoint(CustomTestCase):
91
91
  class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint):
92
92
  def setUp(self):
93
93
  super().setUp()
94
- self.url = self.url
95
94
  self.idx = 0
96
95
  self.payload = {
97
96
  "name": "Alarm 1",
@@ -138,4 +137,4 @@ class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint)
138
137
  # We check deleted at has a value
139
138
  self.assertIsNotNone(row.deleted_at)
140
139
  else:
141
- self.assertIsNone(row.deleted_at)
140
+ self.assertIsNone(row.deleted_at)
@@ -42,6 +42,7 @@ import zlib
42
42
 
43
43
  # Import from internal modules
44
44
  from cornflow.models import CaseModel, ExecutionModel, InstanceModel, UserModel
45
+ from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
45
46
  from cornflow.shared.utils import hash_json_256
46
47
  from cornflow.tests.const import (
47
48
  INSTANCE_URL,
@@ -227,12 +228,8 @@ class TestCasesFromInstanceExecutionEndpoint(CustomTestCase):
227
228
  "execution_id": execution_id,
228
229
  "schema": "solve_model_dag",
229
230
  }
230
- self.instance = InstanceModel.get_one_object(
231
- user=self.user, idx=instance_id
232
- )
233
- self.execution = ExecutionModel.get_one_object(
234
- user=self.user, idx=execution_id
235
- )
231
+ self.instance = InstanceModel.get_one_object(user=self.user, idx=instance_id)
232
+ self.execution = ExecutionModel.get_one_object(user=self.user, idx=execution_id)
236
233
 
237
234
  def test_new_case_execution(self):
238
235
  """
@@ -729,7 +726,7 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
729
726
  headers=self.get_header_with_auth(self.token),
730
727
  )
731
728
  self.assertEqual(response.status_code, 404)
732
- self.assertEqual(response.json["error"], "The object does not exist")
729
+ self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
733
730
 
734
731
 
735
732
  class TestCaseJsonPatch(CustomTestCase):
@@ -20,6 +20,7 @@ from cornflow.tests.const import (
20
20
  DAG_URL,
21
21
  BAD_EXECUTION_PATH,
22
22
  EXECUTION_SOLUTION_PATH,
23
+ EDIT_EXECUTION_SOLUTION,
23
24
  )
24
25
  from cornflow.tests.custom_test_case import CustomTestCase, BaseTestCases
25
26
  from cornflow.tests.unit.tools import patch_af_client, patch_db_client
@@ -186,7 +187,9 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
186
187
  app = create_app("testing-databricks")
187
188
  return app
188
189
 
189
- def test_new_execution(self):
190
+ @patch("cornflow.endpoints.execution.Databricks")
191
+ def test_new_execution(self, db_client_class):
192
+ patch_db_client(db_client_class)
190
193
  self.create_new_row(self.url, self.model, payload=self.payload)
191
194
 
192
195
  @patch("cornflow.endpoints.execution.Databricks")
@@ -231,7 +234,6 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
231
234
 
232
235
  @patch("cornflow.endpoints.execution.Databricks")
233
236
  def test_new_execution_with_solution_bad(self, db_client_class):
234
- patch_db_client(db_client_class)
235
237
  patch_db_client(db_client_class)
236
238
  self.payload["data"] = {"message": "THIS IS NOT A VALID SOLUTION"}
237
239
  response = self.create_new_row(
@@ -244,7 +246,9 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
244
246
  self.assertIn("error", response)
245
247
  self.assertIn("jsonschema_errors", response)
246
248
 
247
- def test_new_execution_no_instance(self):
249
+ @patch("cornflow.endpoints.execution.Databricks")
250
+ def test_new_execution_no_instance(self, db_client_class):
251
+ patch_db_client(db_client_class)
248
252
  payload = dict(self.payload)
249
253
  payload["instance_id"] = "bad_id"
250
254
  response = self.client.post(
@@ -256,13 +260,19 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
256
260
  self.assertEqual(404, response.status_code)
257
261
  self.assertTrue("error" in response.json)
258
262
 
259
- def test_get_executions(self):
263
+ @patch("cornflow.endpoints.execution.Databricks")
264
+ def test_get_executions(self, db_client_class):
265
+ patch_db_client(db_client_class)
260
266
  self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)
261
267
 
262
- def test_get_no_executions(self):
268
+ @patch("cornflow.endpoints.execution.Databricks")
269
+ def test_get_no_executions(self, db_client_class):
270
+ patch_db_client(db_client_class)
263
271
  self.get_no_rows(self.url)
264
272
 
265
- def test_get_executions_superadmin(self):
273
+ @patch("cornflow.endpoints.execution.Databricks")
274
+ def test_get_executions_superadmin(self, db_client_class):
275
+ patch_db_client(db_client_class)
266
276
  self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)
267
277
  token = self.create_service_user()
268
278
  rows = self.client.get(
@@ -473,6 +483,7 @@ class TestExecutionsDetailEndpointMock(CustomTestCase):
473
483
  with open(INSTANCE_PATH) as f:
474
484
  payload = json.load(f)
475
485
  fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, payload)
486
+ self.instance_payload = payload
476
487
  self.model = ExecutionModel
477
488
  self.response_items = {
478
489
  "id",
@@ -501,7 +512,6 @@ class TestExecutionsDetailEndpointAirflow(
501
512
  ):
502
513
  def setUp(self):
503
514
  super().setUp()
504
- self.url = self.url
505
515
  self.query_arguments = {"run": 0}
506
516
 
507
517
  def create_app(self):
@@ -523,51 +533,41 @@ class TestExecutionsDetailEndpointAirflow(
523
533
  self.assertEqual(200, response.status_code)
524
534
  self.assertEqual(response.json["message"], "The execution has been stopped")
525
535
 
536
+ def test_edit_execution(self):
526
537
 
527
- class TestExecutionsDetailEndpointDatabricks(
528
- TestExecutionsDetailEndpointMock, BaseTestCases.DetailEndpoint
529
- ):
530
- def setUp(self):
531
- super().setUp()
532
- self.url = self.url
533
- self.query_arguments = {"run": 0}
534
-
535
- def create_app(self):
536
- app = create_app("testing-databricks")
537
- return app
538
-
539
- @patch("cornflow.endpoints.execution.Databricks")
540
- def test_stop_execution(self, db_client_class):
541
- patch_db_client(db_client_class)
542
-
543
- idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
544
-
545
- response = self.client.post(
546
- self.url + str(idx) + "/",
547
- follow_redirects=True,
548
- headers=self.get_header_with_auth(self.token),
538
+ id_new_instance = self.create_new_row(
539
+ INSTANCE_URL, InstanceModel, self.instance_payload
549
540
  )
550
-
551
- self.assertEqual(200, response.status_code)
552
- self.assertEqual(
553
- response.json["message"], "This feature is not available for Databricks"
541
+ idx = self.create_new_row(
542
+ self.url_with_query_arguments(), self.model, self.payload
554
543
  )
555
544
 
545
+ # Extract the data from data/edit_execution_solution.json
546
+ with open(EDIT_EXECUTION_SOLUTION) as f:
547
+ data = json.load(f)
556
548
 
557
- class TestExecutionsStatusEndpointAirflow(TestExecutionsDetailEndpointMock):
558
- def setUp(self):
559
- super().setUp()
560
- self.response_items = {"id", "name", "status"}
561
- self.items_to_check = []
562
-
563
- def create_app(self):
564
- app = create_app("testing")
565
- return app
549
+ data = {
550
+ "name": "new_name",
551
+ "description": "Updated description",
552
+ "data": data,
553
+ "instance_id": id_new_instance,
554
+ }
555
+ payload_to_check = {
556
+ "id": idx,
557
+ "name": "new_name",
558
+ "description": "Updated description",
559
+ "data_hash": "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b",
560
+ "instance_id": "805bad3280c95e45384dc6bd91a41317f9a7858c",
561
+ }
562
+ self.update_row(
563
+ self.url + str(idx) + "/",
564
+ data,
565
+ payload_to_check,
566
+ )
566
567
 
567
568
  @patch("cornflow.endpoints.execution.Airflow")
568
569
  def test_get_one_status(self, af_client_class):
569
570
  patch_af_client(af_client_class)
570
-
571
571
  idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
572
572
  payload = dict(self.payload)
573
573
  payload["id"] = idx
@@ -598,6 +598,68 @@ class TestExecutionsStatusEndpointAirflow(TestExecutionsDetailEndpointMock):
598
598
  self.assertEqual(f"execution {idx} updated correctly", response.json["message"])
599
599
 
600
600
 
601
+ class TestExecutionsDetailEndpointDatabricks(
602
+ TestExecutionsDetailEndpointMock, BaseTestCases.DetailEndpoint
603
+ ):
604
+ def setUp(self):
605
+ super().setUp()
606
+ self.url = self.url
607
+ self.query_arguments = {"run": 0}
608
+
609
+ def create_app(self):
610
+ app = create_app("testing-databricks")
611
+ return app
612
+
613
+ @patch("cornflow.endpoints.execution.Databricks")
614
+ def test_stop_execution(self, db_client_class):
615
+ patch_db_client(db_client_class)
616
+
617
+ idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
618
+
619
+ response = self.client.post(
620
+ self.url + str(idx) + "/",
621
+ follow_redirects=True,
622
+ headers=self.get_header_with_auth(self.token),
623
+ )
624
+
625
+ self.assertEqual(200, response.status_code)
626
+ self.assertEqual(
627
+ response.json["message"], "This feature is not available for Databricks"
628
+ )
629
+
630
+ def test_edit_execution(self):
631
+
632
+ id_new_instance = self.create_new_row(
633
+ INSTANCE_URL, InstanceModel, self.instance_payload
634
+ )
635
+ idx = self.create_new_row(
636
+ self.url_with_query_arguments(), self.model, self.payload
637
+ )
638
+
639
+ # Extract the data from data/edit_execution_solution.json
640
+ with open(EDIT_EXECUTION_SOLUTION) as f:
641
+ data = json.load(f)
642
+
643
+ data = {
644
+ "name": "new_name",
645
+ "description": "Updated description",
646
+ "data": data,
647
+ "instance_id": id_new_instance,
648
+ }
649
+ payload_to_check = {
650
+ "id": idx,
651
+ "name": "new_name",
652
+ "description": "Updated description",
653
+ "data_hash": "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b",
654
+ "instance_id": "805bad3280c95e45384dc6bd91a41317f9a7858c",
655
+ }
656
+ self.update_row(
657
+ self.url + str(idx) + "/",
658
+ data,
659
+ payload_to_check,
660
+ )
661
+
662
+
601
663
  class TestExecutionsStatusEndpointDatabricks(TestExecutionsDetailEndpointMock):
602
664
  def setUp(self):
603
665
  super().setUp()
@@ -135,11 +135,20 @@ class TestLogInOpenAuth(CustomTestCase):
135
135
  Tests token validation failure when the kid is not found in public keys
136
136
  """
137
137
  # Import the real exceptions to ensure they are preserved
138
- from jwt.exceptions import InvalidTokenError, ExpiredSignatureError
138
+ from jwt.exceptions import (
139
+ InvalidTokenError,
140
+ ExpiredSignatureError,
141
+ InvalidIssuerError,
142
+ InvalidSignatureError,
143
+ DecodeError,
144
+ )
139
145
 
140
146
  # Keep the real exception classes in the mock
141
147
  mock_jwt.ExpiredSignatureError = ExpiredSignatureError
142
148
  mock_jwt.InvalidTokenError = InvalidTokenError
149
+ mock_jwt.InvalidIssuerError = InvalidIssuerError
150
+ mock_jwt.InvalidSignatureError = InvalidSignatureError
151
+ mock_jwt.DecodeError = DecodeError
143
152
 
144
153
  mock_jwt.get_unverified_header.return_value = {"kid": "test_kid"}
145
154
 
@@ -165,7 +174,7 @@ class TestLogInOpenAuth(CustomTestCase):
165
174
 
166
175
  self.assertEqual(400, response.status_code)
167
176
  self.assertEqual(
168
- response.json["error"], "Invalid token: Unknown key identifier (kid)"
177
+ response.json["error"], "Invalid token format, signature, or configuration"
169
178
  )
170
179
 
171
180
  @mock.patch("cornflow.shared.authentication.auth.jwt")
@@ -174,11 +183,20 @@ class TestLogInOpenAuth(CustomTestCase):
174
183
  Tests token validation failure when the token header is missing the kid
175
184
  """
176
185
  # Import the real exceptions to ensure they are preserved
177
- from jwt.exceptions import InvalidTokenError, ExpiredSignatureError
186
+ from jwt.exceptions import (
187
+ InvalidTokenError,
188
+ ExpiredSignatureError,
189
+ InvalidIssuerError,
190
+ InvalidSignatureError,
191
+ DecodeError,
192
+ )
178
193
 
179
194
  # Keep the real exception classes in the mock
180
195
  mock_jwt.ExpiredSignatureError = ExpiredSignatureError
181
196
  mock_jwt.InvalidTokenError = InvalidTokenError
197
+ mock_jwt.InvalidIssuerError = InvalidIssuerError
198
+ mock_jwt.InvalidSignatureError = InvalidSignatureError
199
+ mock_jwt.DecodeError = DecodeError
182
200
 
183
201
  # Mock jwt.get_unverified_header to return a header without kid
184
202
  mock_jwt.get_unverified_header.return_value = {"alg": "RS256"}
@@ -194,8 +212,7 @@ class TestLogInOpenAuth(CustomTestCase):
194
212
 
195
213
  self.assertEqual(400, response.status_code)
196
214
  self.assertEqual(
197
- response.json["error"],
198
- "Invalid token: Missing key identifier (kid) in token header",
215
+ "Invalid token format, signature, or configuration", response.json["error"]
199
216
  )
200
217
 
201
218
  @mock.patch("cornflow.shared.authentication.auth.requests.get")
@@ -205,11 +222,20 @@ class TestLogInOpenAuth(CustomTestCase):
205
222
  Tests failure when trying to fetch public keys from the OIDC provider
206
223
  """
207
224
  # Import the real exceptions to ensure they are preserved
208
- from jwt.exceptions import InvalidTokenError, ExpiredSignatureError
225
+ from jwt.exceptions import (
226
+ InvalidTokenError,
227
+ ExpiredSignatureError,
228
+ InvalidIssuerError,
229
+ InvalidSignatureError,
230
+ DecodeError,
231
+ )
209
232
 
210
233
  # Keep the real exception classes in the mock
211
234
  mock_jwt.ExpiredSignatureError = ExpiredSignatureError
212
235
  mock_jwt.InvalidTokenError = InvalidTokenError
236
+ mock_jwt.InvalidIssuerError = InvalidIssuerError
237
+ mock_jwt.InvalidSignatureError = InvalidSignatureError
238
+ mock_jwt.DecodeError = DecodeError
213
239
 
214
240
  # Clear the cache
215
241
  from cornflow.shared.authentication.auth import public_keys_cache
@@ -240,8 +266,8 @@ class TestLogInOpenAuth(CustomTestCase):
240
266
 
241
267
  self.assertEqual(400, response.status_code)
242
268
  self.assertEqual(
269
+ "Invalid token format, signature, or configuration",
243
270
  response.json["error"],
244
- "Failed to fetch public keys from authentication provider",
245
271
  )
246
272
 
247
273
  @mock.patch("cornflow.shared.authentication.Auth.decode_token")
@@ -289,11 +315,20 @@ class TestLogInOpenAuth(CustomTestCase):
289
315
  the system fetches fresh keys from the provider.
290
316
  """
291
317
  # Import the real exceptions to ensure they are preserved
292
- from jwt.exceptions import InvalidTokenError, ExpiredSignatureError
318
+ from jwt.exceptions import (
319
+ InvalidTokenError,
320
+ ExpiredSignatureError,
321
+ InvalidIssuerError,
322
+ InvalidSignatureError,
323
+ DecodeError,
324
+ )
293
325
 
294
326
  # Keep the real exception classes in the mock
295
327
  mock_jwt.ExpiredSignatureError = ExpiredSignatureError
296
328
  mock_jwt.InvalidTokenError = InvalidTokenError
329
+ mock_jwt.InvalidIssuerError = InvalidIssuerError
330
+ mock_jwt.InvalidSignatureError = InvalidSignatureError
331
+ mock_jwt.DecodeError = DecodeError
297
332
 
298
333
  # Mock jwt to return valid unverified header and payload
299
334
  mock_jwt.get_unverified_header.return_value = {"kid": "test_kid"}
@@ -531,4 +566,6 @@ class TestLogInOpenAuthService(CustomTestCase):
531
566
  )
532
567
 
533
568
  self.assertEqual(400, response.status_code)
534
- self.assertEqual(response.json["error"], "Invalid token format or signature")
569
+ self.assertEqual(
570
+ response.json["error"], "Invalid token format, signature, or configuration"
571
+ )
@@ -11,7 +11,7 @@ from cornflow.app import create_app
11
11
  from cornflow.commands.access import access_init_command
12
12
  from cornflow.models import UserRoleModel
13
13
  from cornflow.shared import db
14
- from cornflow.shared.const import ADMIN_ROLE, SERVICE_ROLE
14
+ from cornflow.shared.const import ADMIN_ROLE, SERVICE_ROLE, DATA_DOES_NOT_EXIST_MSG
15
15
  from cornflow.tests.const import LOGIN_URL, SIGNUP_URL, TABLES_URL
16
16
 
17
17
 
@@ -220,7 +220,7 @@ class TestTablesDetailEndpoint(TestCase):
220
220
  },
221
221
  )
222
222
  self.assertEqual(response.status_code, 404)
223
- self.assertEqual(response.json["error"], "The object does not exist")
223
+ self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
224
224
 
225
225
 
226
226
  class TestTablesEndpointAdmin(TestCase):
@@ -291,4 +291,4 @@ class TestTablesEndpointAdmin(TestCase):
291
291
  self.assertEqual(response.status_code, 403)
292
292
  self.assertEqual(
293
293
  response.json["error"], "You do not have permission to access this endpoint"
294
- )
294
+ )
@@ -3,7 +3,10 @@ from flask import Flask
3
3
  import json
4
4
  import os
5
5
 
6
- from cornflow.shared.const import DATABRICKS_TERMINATE_STATE, DATABRICKS_FINISH_TO_STATE_MAP
6
+ from cornflow.shared.const import (
7
+ DATABRICKS_TERMINATE_STATE,
8
+ DATABRICKS_FINISH_TO_STATE_MAP,
9
+ )
7
10
 
8
11
 
9
12
  def create_test_app():
@@ -13,9 +16,7 @@ def create_test_app():
13
16
 
14
17
 
15
18
  def patch_af_client(af_client_class):
16
- with patch(
17
- "cornflow.endpoints.execution.current_app.config"
18
- ) as mock_config:
19
+ with patch("cornflow.endpoints.execution.current_app.config") as mock_config:
19
20
  mock_config.__getitem__.side_effect = lambda key: (
20
21
  1 if key == "CORNFLOW_BACKEND" else {}
21
22
  )
@@ -27,9 +28,24 @@ def patch_af_client(af_client_class):
27
28
  "state": "success",
28
29
  }
29
30
  af_client_mock.is_alive.return_value = True
30
- af_client_mock.get_orch_info.return_value = responses_mock
31
- af_client_mock.run_workflow.return_value = responses_mock
32
- af_client_mock.get_run_status.return_value = responses_mock
31
+ af_client_mock.get_dag_info.return_value = responses_mock
32
+
33
+ # Configurar get_orch_info
34
+ schema_info_mock = Mock()
35
+ schema_info_mock.json.return_value = {"is_paused": False}
36
+ af_client_mock.get_orch_info.return_value = schema_info_mock
37
+
38
+ # Configurar run_workflow
39
+ run_response_mock = Mock()
40
+ run_response_mock.json.return_value = {"dag_run_id": "12345"}
41
+ af_client_mock.run_workflow.return_value = run_response_mock
42
+
43
+ # Configurar get_run_status
44
+ status_mock = Mock()
45
+ status_mock.json.return_value = {"state": "success"}
46
+ af_client_mock.get_run_status.return_value = status_mock
47
+
48
+ af_client_mock.run_dag.return_value = responses_mock
33
49
  af_client_mock.set_dag_run_to_fail.return_value = None
34
50
  af_client_class.from_config.return_value = af_client_mock
35
51
 
@@ -37,9 +53,7 @@ def patch_af_client(af_client_class):
37
53
  def patch_db_client(db_client_class):
38
54
  mock_config = {"CORNFLOW_BACKEND": 2}
39
55
 
40
- with patch(
41
- "cornflow.endpoints.execution.current_app.config", mock_config
42
- ):
56
+ with patch("cornflow.endpoints.execution.current_app.config", mock_config):
43
57
  db_client_mock = Mock()
44
58
  responses_mock = Mock()
45
59
  responses_mock.json.return_value = {
@@ -65,10 +79,14 @@ def patch_db_client(db_client_class):
65
79
  state = response_get_run_status.json.return_value["status"]["state"]
66
80
  if state == DATABRICKS_TERMINATE_STATE:
67
81
  if (
68
- response_get_run_status.json.return_value["status"]["termination_details"]["code"]
69
- in DATABRICKS_FINISH_TO_STATE_MAP.keys()
82
+ response_get_run_status.json.return_value["status"][
83
+ "termination_details"
84
+ ]["code"]
85
+ in DATABRICKS_FINISH_TO_STATE_MAP.keys()
70
86
  ):
71
- response_get_run_status = response_get_run_status.json.return_value["status"]["termination_details"]["code"]
87
+ response_get_run_status = response_get_run_status.json.return_value[
88
+ "status"
89
+ ]["termination_details"]["code"]
72
90
  else:
73
91
  response_get_run_status = "OTHER_FINISH_ERROR"
74
92
  db_client_mock.is_alive.return_value = True