cornflow 2.0.0a12__py3-none-any.whl → 2.0.0a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/app.py +3 -1
- cornflow/cli/__init__.py +4 -0
- cornflow/cli/actions.py +4 -0
- cornflow/cli/config.py +4 -0
- cornflow/cli/migrations.py +13 -8
- cornflow/cli/permissions.py +4 -0
- cornflow/cli/roles.py +4 -0
- cornflow/cli/schemas.py +5 -0
- cornflow/cli/service.py +260 -147
- cornflow/cli/tools/api_generator.py +13 -10
- cornflow/cli/tools/endpoint_tools.py +191 -196
- cornflow/cli/tools/models_tools.py +87 -60
- cornflow/cli/tools/schema_generator.py +161 -67
- cornflow/cli/tools/schemas_tools.py +4 -5
- cornflow/cli/users.py +8 -0
- cornflow/cli/views.py +4 -0
- cornflow/commands/dag.py +3 -2
- cornflow/commands/schemas.py +6 -4
- cornflow/commands/users.py +12 -17
- cornflow/config.py +3 -2
- cornflow/endpoints/dag.py +27 -25
- cornflow/endpoints/data_check.py +102 -164
- cornflow/endpoints/example_data.py +9 -3
- cornflow/endpoints/execution.py +27 -23
- cornflow/endpoints/health.py +4 -5
- cornflow/endpoints/instance.py +39 -12
- cornflow/endpoints/meta_resource.py +4 -5
- cornflow/schemas/execution.py +1 -0
- cornflow/shared/airflow.py +157 -0
- cornflow/shared/authentication/auth.py +73 -42
- cornflow/shared/const.py +9 -0
- cornflow/shared/databricks.py +10 -10
- cornflow/shared/exceptions.py +3 -1
- cornflow/shared/utils_tables.py +36 -8
- cornflow/shared/validators.py +1 -1
- cornflow/tests/const.py +1 -0
- cornflow/tests/custom_test_case.py +4 -4
- cornflow/tests/unit/test_alarms.py +1 -2
- cornflow/tests/unit/test_cases.py +4 -7
- cornflow/tests/unit/test_executions.py +105 -43
- cornflow/tests/unit/test_log_in.py +46 -9
- cornflow/tests/unit/test_tables.py +3 -3
- cornflow/tests/unit/tools.py +31 -13
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/METADATA +2 -2
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/RECORD +48 -47
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/WHEEL +1 -1
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/entry_points.txt +0 -0
- {cornflow-2.0.0a12.dist-info → cornflow-2.0.0a14.dist-info}/top_level.txt +0 -0
cornflow/shared/databricks.py
CHANGED
@@ -48,12 +48,16 @@ class Databricks:
         oauth_token = oauth_response.json()["access_token"]
         return oauth_token
 
-    def is_alive(self):
+    def is_alive(self, config=None):
         try:
-
-
-
-
+            if config is None or config["DATABRICKS_HEALTH_PATH"] == "default path":
+                # We raise an error because the default path is not valid
+                raise DatabricksError(
+                    "Invalid default path. Please set DATABRICKS_HEALTH_PATH as an environment variable"
+                )
+            else:
+                path = config["DATABRICKS_HEALTH_PATH"]
+
             url = f"{self.url}/api/2.0/workspace/get-status?path={path}"
             response = self.request_headers_auth(method="GET", url=url)
             if "error_code" in response.json().keys():
@@ -75,8 +79,6 @@ class Databricks:
             raise DatabricksError("JOB not available")
         return schema_info
 
-    # TODO AGA: incluir un id de job por defecto o hacer obligatorio el uso el parámetro.
-    # Revisar los efectos secundarios de eliminar execution_id y usar el predeterminado
     def run_workflow(
         self,
         execution_id,
@@ -87,13 +89,11 @@ class Databricks:
         """
         Run a job in Databricks
         """
-        # TODO AGA: revisar si la url esta bien/si acepta asi los parámetros
         url = f"{self.url}/api/2.1/jobs/run-now/"
-        # TODO AGA: revisar si deben ser notebook parameters o job parameters.
         # Entender cómo se usa checks_only
         payload = dict(
             job_id=orch_name,
-
+            job_parameters=dict(
                 checks_only=checks_only,
                 execution_id=execution_id,
             ),
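The `is_alive` signature change means callers now pass the application config so the health check can read `DATABRICKS_HEALTH_PATH` instead of relying on a hard-coded path. A minimal usage sketch follows; the `client` object and the config values are illustrative assumptions, not cornflow's exact wiring:

```python
# Hedged sketch of calling the new is_alive(config=...) signature.
# In cornflow the config comes from the Flask application config; the dict
# shown in the comment below is only an assumed example.
from cornflow_client.constants import DatabricksError


def databricks_is_healthy(client, config):
    """Best-effort health probe; treats a DatabricksError as 'not healthy'."""
    try:
        # Raises DatabricksError when DATABRICKS_HEALTH_PATH is unset or
        # still equal to "default path".
        return client.is_alive(config=config)
    except DatabricksError:
        return False


# Assumed configuration for the sketch:
# config = {"DATABRICKS_HEALTH_PATH": "/Workspace/Shared/health_check"}
```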
cornflow/shared/exceptions.py
CHANGED
@@ -9,6 +9,8 @@ from cornflow_client.constants import AirflowError, DatabricksError
 from werkzeug.exceptions import HTTPException
 import traceback
 
+from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
+
 
 class InvalidUsage(Exception):
     """
@@ -48,7 +50,7 @@ class ObjectDoesNotExist(InvalidUsage):
     """
 
     status_code = 404
-    error =
+    error = DATA_DOES_NOT_EXIST_MSG
 
 
 class ObjectAlreadyExists(InvalidUsage):
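The 404 message now comes from a shared constant, so the exception class and the tests that assert on `response.json["error"]` stay in sync. A small standalone sketch of the pattern (the constant's text below is a placeholder; the real value lives in `cornflow/shared/const.py`):

```python
# Placeholder value for illustration only; cornflow defines the real text
# in cornflow/shared/const.py as DATA_DOES_NOT_EXIST_MSG.
DATA_DOES_NOT_EXIST_MSG = "The object does not exist"


class InvalidUsage(Exception):
    status_code = 400
    error = "Bad request"


class ObjectDoesNotExist(InvalidUsage):
    status_code = 404
    error = DATA_DOES_NOT_EXIST_MSG


# Tests can now compare against the same constant instead of a hard-coded string:
assert ObjectDoesNotExist.error == DATA_DOES_NOT_EXIST_MSG
```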
cornflow/shared/utils_tables.py
CHANGED
@@ -2,6 +2,7 @@
 import inspect
 import os
 import sys
+import collections
 
 from importlib import import_module
 from sqlalchemy.dialects.postgresql import TEXT
@@ -31,14 +32,42 @@ def import_models():
 
 
 def all_subclasses(cls, models=None):
-    subclasses
+    """Finds all direct and indirect subclasses of a given class.
+
+    Optionally filters a provided list of models first.
+
+    Args:
+        cls: The base class to find subclasses for.
+        models: An optional iterable of classes to pre-filter.
+
+    Returns:
+        A set containing all subclasses found.
+    """
+    filtered_subclasses = set()
     if models is not None:
         for val in models:
-
-
+            # Ensure val is a class before checking issubclass
+            if isinstance(val, type) and issubclass(val, cls):
+                filtered_subclasses.add(val)
+
+    all_descendants = set()
+    # Use a deque for efficient pop(0)
+    queue = collections.deque(cls.__subclasses__())
+    # Keep track of visited classes during the traversal to handle potential complex hierarchies
+    # (though direct subclass relationships shouldn't form cycles)
+    # Initialize with direct subclasses as they are the starting point.
+    visited_for_queue = set(cls.__subclasses__())
+
+    while queue:
+        current_sub = queue.popleft()
+        all_descendants.add(current_sub)
+
+        for grandchild in current_sub.__subclasses__():
+            if grandchild not in visited_for_queue:
+                visited_for_queue.add(grandchild)
+                queue.append(grandchild)
 
-    return
-        [s for c in cls.__subclasses__() for s in all_subclasses(c)]))
+    return filtered_subclasses.union(all_descendants)
 
 
 type_converter = {
@@ -76,6 +105,5 @@ def item_as_dict(item):
 
 def items_as_dict_list(ls):
     return [
-        {c.name: getattr(item, c.name) for c in item.__table__.columns}
-
-    ]
+        {c.name: getattr(item, c.name) for c in item.__table__.columns} for item in ls
+    ]
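The rewritten `all_subclasses` replaces the old recursive comprehension with an iterative breadth-first walk over `__subclasses__()`, plus an optional pre-filter of the `models` argument. A small self-contained sketch of the same traversal idea, using standalone names rather than the cornflow module itself:

```python
import collections


class Base: ...
class Child(Base): ...
class GrandChild(Child): ...


def iter_subclasses(cls):
    """Breadth-first walk over cls.__subclasses__(), collecting every descendant."""
    queue = collections.deque(cls.__subclasses__())
    seen = set(cls.__subclasses__())
    found = set()
    while queue:
        current = queue.popleft()
        found.add(current)
        for sub in current.__subclasses__():
            if sub not in seen:
                seen.add(sub)
                queue.append(sub)
    return found


print(iter_subclasses(Base))  # {<class 'Child'>, <class 'GrandChild'>}
```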
cornflow/shared/validators.py
CHANGED
@@ -1,12 +1,12 @@
 """
 This file has several validators
 """
+
 import re
 from typing import Tuple, Union
 
 from jsonschema import Draft7Validator, validators
 from disposable_email_domains import blocklist
-from jsonschema.protocols import Validator
 
 
 def is_special_character(character):
cornflow/tests/const.py
CHANGED
@@ -15,6 +15,7 @@ INSTANCE_URL = PREFIX + "/instance/"
 INSTANCE_MPS = _get_file("./data/test_mps.mps")
 INSTANCE_GC_20 = _get_file("./data/gc_20_7.json")
 INSTANCE_FILE_FAIL = _get_file("./unit/test_instances.py")
+EDIT_EXECUTION_SOLUTION = _get_file("./data/edit_execution_solution.json")
 
 EXECUTION_PATH = _get_file("./data/new_execution.json")
 BAD_EXECUTION_PATH = _get_file("./data/bad_execution.json")
cornflow/tests/custom_test_case.py
CHANGED
@@ -635,9 +635,9 @@ class BaseTestCases:
         # (we patch the request to airflow to check if the schema is valid)
         # we create 4 instances
         data_many = [self.payload for _ in range(4)]
-
-
-        self.apply_filter(self.url, dict(schema="timer"),
+
+        self.get_rows(self.url, data_many)
+        self.apply_filter(self.url, dict(schema="timer"), [])
 
     def test_opt_filters_date_lte(self):
         """
@@ -1117,7 +1117,7 @@ class LoginTestCases:
 
         self.assertEqual(400, response.status_code)
         self.assertEqual(
-            "Invalid token
+            "Invalid token format or signature",
             response.json["error"],
         )
cornflow/tests/unit/test_alarms.py
CHANGED
@@ -91,7 +91,6 @@ class TestAlarmsEndpoint(CustomTestCase):
 class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint):
     def setUp(self):
         super().setUp()
-        self.url = self.url
         self.idx = 0
         self.payload = {
             "name": "Alarm 1",
@@ -138,4 +137,4 @@ class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint)
             # We check deleted at has a value
             self.assertIsNotNone(row.deleted_at)
         else:
-            self.assertIsNone(row.deleted_at)
+            self.assertIsNone(row.deleted_at)
cornflow/tests/unit/test_cases.py
CHANGED
@@ -42,6 +42,7 @@ import zlib
 
 # Import from internal modules
 from cornflow.models import CaseModel, ExecutionModel, InstanceModel, UserModel
+from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
 from cornflow.shared.utils import hash_json_256
 from cornflow.tests.const import (
     INSTANCE_URL,
@@ -227,12 +228,8 @@ class TestCasesFromInstanceExecutionEndpoint(CustomTestCase):
             "execution_id": execution_id,
             "schema": "solve_model_dag",
         }
-        self.instance = InstanceModel.get_one_object(
-
-        )
-        self.execution = ExecutionModel.get_one_object(
-            user=self.user, idx=execution_id
-        )
+        self.instance = InstanceModel.get_one_object(user=self.user, idx=instance_id)
+        self.execution = ExecutionModel.get_one_object(user=self.user, idx=execution_id)
 
     def test_new_case_execution(self):
         """
@@ -729,7 +726,7 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
             headers=self.get_header_with_auth(self.token),
         )
         self.assertEqual(response.status_code, 404)
-        self.assertEqual(response.json["error"],
+        self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
 
 
 class TestCaseJsonPatch(CustomTestCase):
cornflow/tests/unit/test_executions.py
CHANGED
@@ -20,6 +20,7 @@ from cornflow.tests.const import (
     DAG_URL,
     BAD_EXECUTION_PATH,
     EXECUTION_SOLUTION_PATH,
+    EDIT_EXECUTION_SOLUTION,
 )
 from cornflow.tests.custom_test_case import CustomTestCase, BaseTestCases
 from cornflow.tests.unit.tools import patch_af_client, patch_db_client
@@ -186,7 +187,9 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
         app = create_app("testing-databricks")
         return app
 
-
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_new_execution(self, db_client_class):
+        patch_db_client(db_client_class)
         self.create_new_row(self.url, self.model, payload=self.payload)
 
     @patch("cornflow.endpoints.execution.Databricks")
@@ -231,7 +234,6 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
 
     @patch("cornflow.endpoints.execution.Databricks")
     def test_new_execution_with_solution_bad(self, db_client_class):
-        patch_db_client(db_client_class)
         patch_db_client(db_client_class)
         self.payload["data"] = {"message": "THIS IS NOT A VALID SOLUTION"}
         response = self.create_new_row(
@@ -244,7 +246,9 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
         self.assertIn("error", response)
         self.assertIn("jsonschema_errors", response)
 
-
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_new_execution_no_instance(self, db_client_class):
+        patch_db_client(db_client_class)
         payload = dict(self.payload)
         payload["instance_id"] = "bad_id"
         response = self.client.post(
@@ -256,13 +260,19 @@ class TestExecutionsListEndpointDatabricks(BaseTestCases.ListFilters):
         self.assertEqual(404, response.status_code)
         self.assertTrue("error" in response.json)
 
-
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_get_executions(self, db_client_class):
+        patch_db_client(db_client_class)
         self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)
 
-
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_get_no_executions(self, db_client_class):
+        patch_db_client(db_client_class)
         self.get_no_rows(self.url)
 
-
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_get_executions_superadmin(self, db_client_class):
+        patch_db_client(db_client_class)
         self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)
         token = self.create_service_user()
         rows = self.client.get(
@@ -473,6 +483,7 @@ class TestExecutionsDetailEndpointMock(CustomTestCase):
         with open(INSTANCE_PATH) as f:
             payload = json.load(f)
         fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, payload)
+        self.instance_payload = payload
         self.model = ExecutionModel
         self.response_items = {
             "id",
@@ -501,7 +512,6 @@ class TestExecutionsDetailEndpointAirflow(
 ):
     def setUp(self):
         super().setUp()
-        self.url = self.url
         self.query_arguments = {"run": 0}
 
     def create_app(self):
@@ -523,51 +533,41 @@ class TestExecutionsDetailEndpointAirflow(
         self.assertEqual(200, response.status_code)
         self.assertEqual(response.json["message"], "The execution has been stopped")
 
+    def test_edit_execution(self):
 
-
-
-    ):
-    def setUp(self):
-        super().setUp()
-        self.url = self.url
-        self.query_arguments = {"run": 0}
-
-    def create_app(self):
-        app = create_app("testing-databricks")
-        return app
-
-    @patch("cornflow.endpoints.execution.Databricks")
-    def test_stop_execution(self, db_client_class):
-        patch_db_client(db_client_class)
-
-        idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
-
-        response = self.client.post(
-            self.url + str(idx) + "/",
-            follow_redirects=True,
-            headers=self.get_header_with_auth(self.token),
+        id_new_instance = self.create_new_row(
+            INSTANCE_URL, InstanceModel, self.instance_payload
         )
-
-
-        self.assertEqual(
-            response.json["message"], "This feature is not available for Databricks"
+        idx = self.create_new_row(
+            self.url_with_query_arguments(), self.model, self.payload
         )
 
+        # Extract the data from data/edit_execution_solution.json
+        with open(EDIT_EXECUTION_SOLUTION) as f:
+            data = json.load(f)
 
-
-
-
-
-
-
-
-
+        data = {
+            "name": "new_name",
+            "description": "Updated description",
+            "data": data,
+            "instance_id": id_new_instance,
+        }
+        payload_to_check = {
+            "id": idx,
+            "name": "new_name",
+            "description": "Updated description",
+            "data_hash": "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b",
+            "instance_id": "805bad3280c95e45384dc6bd91a41317f9a7858c",
+        }
+        self.update_row(
+            self.url + str(idx) + "/",
+            data,
+            payload_to_check,
+        )
 
     @patch("cornflow.endpoints.execution.Airflow")
     def test_get_one_status(self, af_client_class):
         patch_af_client(af_client_class)
-
         idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
         payload = dict(self.payload)
         payload["id"] = idx
@@ -598,6 +598,68 @@ class TestExecutionsStatusEndpointAirflow(TestExecutionsDetailEndpointMock):
         self.assertEqual(f"execution {idx} updated correctly", response.json["message"])
 
 
+class TestExecutionsDetailEndpointDatabricks(
+    TestExecutionsDetailEndpointMock, BaseTestCases.DetailEndpoint
+):
+    def setUp(self):
+        super().setUp()
+        self.url = self.url
+        self.query_arguments = {"run": 0}
+
+    def create_app(self):
+        app = create_app("testing-databricks")
+        return app
+
+    @patch("cornflow.endpoints.execution.Databricks")
+    def test_stop_execution(self, db_client_class):
+        patch_db_client(db_client_class)
+
+        idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
+
+        response = self.client.post(
+            self.url + str(idx) + "/",
+            follow_redirects=True,
+            headers=self.get_header_with_auth(self.token),
+        )
+
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json["message"], "This feature is not available for Databricks"
+        )
+
+    def test_edit_execution(self):
+
+        id_new_instance = self.create_new_row(
+            INSTANCE_URL, InstanceModel, self.instance_payload
+        )
+        idx = self.create_new_row(
+            self.url_with_query_arguments(), self.model, self.payload
+        )
+
+        # Extract the data from data/edit_execution_solution.json
+        with open(EDIT_EXECUTION_SOLUTION) as f:
+            data = json.load(f)
+
+        data = {
+            "name": "new_name",
+            "description": "Updated description",
+            "data": data,
+            "instance_id": id_new_instance,
+        }
+        payload_to_check = {
+            "id": idx,
+            "name": "new_name",
+            "description": "Updated description",
+            "data_hash": "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b",
+            "instance_id": "805bad3280c95e45384dc6bd91a41317f9a7858c",
+        }
+        self.update_row(
+            self.url + str(idx) + "/",
+            data,
+            payload_to_check,
+        )
+
+
 class TestExecutionsStatusEndpointDatabricks(TestExecutionsDetailEndpointMock):
     def setUp(self):
         super().setUp()
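The Databricks list-endpoint tests above now each carry their own `@patch(...)` decorator and call the `patch_db_client` helper explicitly instead of relying on inherited behavior. A minimal, self-contained sketch of that decorator-per-test pattern follows; the `Databricks` stand-in class and `configure_client` helper are placeholders defined only for the sketch, while cornflow's tests patch `"cornflow.endpoints.execution.Databricks"` instead:

```python
from unittest import TestCase, main
from unittest.mock import Mock, patch


class Databricks:
    """Stand-in for the real client; defined here so the sketch is self-contained."""

    @classmethod
    def from_config(cls, config):
        raise RuntimeError("should be patched in tests")


def run_health_check(config):
    """Code under test: builds a client from config and probes it."""
    client = Databricks.from_config(config)
    return client.is_alive()


class ExecutionTests(TestCase):
    # patch.object targets the local stand-in; the cornflow tests instead use
    # @patch("cornflow.endpoints.execution.Databricks") plus patch_db_client.
    @patch.object(Databricks, "from_config")
    def test_health_check(self, from_config):
        client = Mock()
        client.is_alive.return_value = True
        from_config.return_value = client
        self.assertTrue(run_health_check({}))


if __name__ == "__main__":
    main()
```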
cornflow/tests/unit/test_log_in.py
CHANGED
@@ -135,11 +135,20 @@ class TestLogInOpenAuth(CustomTestCase):
         Tests token validation failure when the kid is not found in public keys
         """
         # Import the real exceptions to ensure they are preserved
-        from jwt.exceptions import
+        from jwt.exceptions import (
+            InvalidTokenError,
+            ExpiredSignatureError,
+            InvalidIssuerError,
+            InvalidSignatureError,
+            DecodeError,
+        )
 
         # Keep the real exception classes in the mock
         mock_jwt.ExpiredSignatureError = ExpiredSignatureError
         mock_jwt.InvalidTokenError = InvalidTokenError
+        mock_jwt.InvalidIssuerError = InvalidIssuerError
+        mock_jwt.InvalidSignatureError = InvalidSignatureError
+        mock_jwt.DecodeError = DecodeError
 
         mock_jwt.get_unverified_header.return_value = {"kid": "test_kid"}
 
@@ -165,7 +174,7 @@ class TestLogInOpenAuth(CustomTestCase):
 
         self.assertEqual(400, response.status_code)
         self.assertEqual(
-            response.json["error"], "Invalid token
+            response.json["error"], "Invalid token format, signature, or configuration"
         )
 
     @mock.patch("cornflow.shared.authentication.auth.jwt")
@@ -174,11 +183,20 @@ class TestLogInOpenAuth(CustomTestCase):
         Tests token validation failure when the token header is missing the kid
         """
         # Import the real exceptions to ensure they are preserved
-        from jwt.exceptions import
+        from jwt.exceptions import (
+            InvalidTokenError,
+            ExpiredSignatureError,
+            InvalidIssuerError,
+            InvalidSignatureError,
+            DecodeError,
+        )
 
         # Keep the real exception classes in the mock
         mock_jwt.ExpiredSignatureError = ExpiredSignatureError
         mock_jwt.InvalidTokenError = InvalidTokenError
+        mock_jwt.InvalidIssuerError = InvalidIssuerError
+        mock_jwt.InvalidSignatureError = InvalidSignatureError
+        mock_jwt.DecodeError = DecodeError
 
         # Mock jwt.get_unverified_header to return a header without kid
         mock_jwt.get_unverified_header.return_value = {"alg": "RS256"}
@@ -194,8 +212,7 @@ class TestLogInOpenAuth(CustomTestCase):
 
         self.assertEqual(400, response.status_code)
         self.assertEqual(
-            response.json["error"]
-            "Invalid token: Missing key identifier (kid) in token header",
+            "Invalid token format, signature, or configuration", response.json["error"]
         )
 
     @mock.patch("cornflow.shared.authentication.auth.requests.get")
@@ -205,11 +222,20 @@ class TestLogInOpenAuth(CustomTestCase):
         Tests failure when trying to fetch public keys from the OIDC provider
         """
         # Import the real exceptions to ensure they are preserved
-        from jwt.exceptions import
+        from jwt.exceptions import (
+            InvalidTokenError,
+            ExpiredSignatureError,
+            InvalidIssuerError,
+            InvalidSignatureError,
+            DecodeError,
+        )
 
         # Keep the real exception classes in the mock
         mock_jwt.ExpiredSignatureError = ExpiredSignatureError
         mock_jwt.InvalidTokenError = InvalidTokenError
+        mock_jwt.InvalidIssuerError = InvalidIssuerError
+        mock_jwt.InvalidSignatureError = InvalidSignatureError
+        mock_jwt.DecodeError = DecodeError
 
         # Clear the cache
         from cornflow.shared.authentication.auth import public_keys_cache
@@ -240,8 +266,8 @@ class TestLogInOpenAuth(CustomTestCase):
 
         self.assertEqual(400, response.status_code)
         self.assertEqual(
+            "Invalid token format, signature, or configuration",
             response.json["error"],
-            "Failed to fetch public keys from authentication provider",
         )
 
     @mock.patch("cornflow.shared.authentication.Auth.decode_token")
@@ -289,11 +315,20 @@ class TestLogInOpenAuth(CustomTestCase):
         the system fetches fresh keys from the provider.
         """
         # Import the real exceptions to ensure they are preserved
-        from jwt.exceptions import
+        from jwt.exceptions import (
+            InvalidTokenError,
+            ExpiredSignatureError,
+            InvalidIssuerError,
+            InvalidSignatureError,
+            DecodeError,
+        )
 
         # Keep the real exception classes in the mock
         mock_jwt.ExpiredSignatureError = ExpiredSignatureError
         mock_jwt.InvalidTokenError = InvalidTokenError
+        mock_jwt.InvalidIssuerError = InvalidIssuerError
+        mock_jwt.InvalidSignatureError = InvalidSignatureError
+        mock_jwt.DecodeError = DecodeError
 
         # Mock jwt to return valid unverified header and payload
         mock_jwt.get_unverified_header.return_value = {"kid": "test_kid"}
@@ -531,4 +566,6 @@ class TestLogInOpenAuthService(CustomTestCase):
         )
 
         self.assertEqual(400, response.status_code)
-        self.assertEqual(
+        self.assertEqual(
+            response.json["error"], "Invalid token format, signature, or configuration"
+        )
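These tests mock the whole `jwt` module, so they must re-attach the real exception classes to the mock: an `except mock_jwt.InvalidTokenError:` clause needs an actual exception type, not a `Mock` attribute. A small standalone sketch of why this matters (it only assumes PyJWT, the library the tests already use):

```python
# Sketch: a mocked jwt module must keep real exception classes, otherwise
# "except mocked.InvalidTokenError" raises
# TypeError: catching classes that do not inherit from BaseException.
from unittest.mock import Mock

from jwt.exceptions import InvalidTokenError  # real class from PyJWT


def safe_decode(token, jwt_module):
    """Decode a token with the given jwt-like module, returning None on failure."""
    try:
        return jwt_module.decode(token, "secret", algorithms=["HS256"])
    except jwt_module.InvalidTokenError:
        return None


mock_jwt = Mock()
mock_jwt.decode.side_effect = InvalidTokenError("bad token")
mock_jwt.InvalidTokenError = InvalidTokenError  # keep the real class on the mock
print(safe_decode("not-a-token", mock_jwt))  # -> None
```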
cornflow/tests/unit/test_tables.py
CHANGED
@@ -11,7 +11,7 @@ from cornflow.app import create_app
 from cornflow.commands.access import access_init_command
 from cornflow.models import UserRoleModel
 from cornflow.shared import db
-from cornflow.shared.const import ADMIN_ROLE, SERVICE_ROLE
+from cornflow.shared.const import ADMIN_ROLE, SERVICE_ROLE, DATA_DOES_NOT_EXIST_MSG
 from cornflow.tests.const import LOGIN_URL, SIGNUP_URL, TABLES_URL
 
 
@@ -220,7 +220,7 @@ class TestTablesDetailEndpoint(TestCase):
             },
         )
         self.assertEqual(response.status_code, 404)
-        self.assertEqual(response.json["error"],
+        self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
 
 
 class TestTablesEndpointAdmin(TestCase):
@@ -291,4 +291,4 @@ class TestTablesEndpointAdmin(TestCase):
         self.assertEqual(response.status_code, 403)
         self.assertEqual(
             response.json["error"], "You do not have permission to access this endpoint"
-        )
+        )
cornflow/tests/unit/tools.py
CHANGED
@@ -3,7 +3,10 @@ from flask import Flask
 import json
 import os
 
-from cornflow.shared.const import
+from cornflow.shared.const import (
+    DATABRICKS_TERMINATE_STATE,
+    DATABRICKS_FINISH_TO_STATE_MAP,
+)
 
 
 def create_test_app():
@@ -13,9 +16,7 @@ def create_test_app():
 
 
 def patch_af_client(af_client_class):
-    with patch(
-        "cornflow.endpoints.execution.current_app.config"
-    ) as mock_config:
+    with patch("cornflow.endpoints.execution.current_app.config") as mock_config:
         mock_config.__getitem__.side_effect = lambda key: (
             1 if key == "CORNFLOW_BACKEND" else {}
         )
@@ -27,9 +28,24 @@ def patch_af_client(af_client_class):
             "state": "success",
         }
         af_client_mock.is_alive.return_value = True
-        af_client_mock.
-
-
+        af_client_mock.get_dag_info.return_value = responses_mock
+
+        # Configurar get_orch_info
+        schema_info_mock = Mock()
+        schema_info_mock.json.return_value = {"is_paused": False}
+        af_client_mock.get_orch_info.return_value = schema_info_mock
+
+        # Configurar run_workflow
+        run_response_mock = Mock()
+        run_response_mock.json.return_value = {"dag_run_id": "12345"}
+        af_client_mock.run_workflow.return_value = run_response_mock
+
+        # Configurar get_run_status
+        status_mock = Mock()
+        status_mock.json.return_value = {"state": "success"}
+        af_client_mock.get_run_status.return_value = status_mock
+
+        af_client_mock.run_dag.return_value = responses_mock
         af_client_mock.set_dag_run_to_fail.return_value = None
         af_client_class.from_config.return_value = af_client_mock
 
@@ -37,9 +53,7 @@ def patch_af_client(af_client_class):
 def patch_db_client(db_client_class):
     mock_config = {"CORNFLOW_BACKEND": 2}
 
-    with patch(
-        "cornflow.endpoints.execution.current_app.config", mock_config
-    ):
+    with patch("cornflow.endpoints.execution.current_app.config", mock_config):
         db_client_mock = Mock()
         responses_mock = Mock()
         responses_mock.json.return_value = {
@@ -65,10 +79,14 @@ def patch_db_client(db_client_class):
         state = response_get_run_status.json.return_value["status"]["state"]
         if state == DATABRICKS_TERMINATE_STATE:
            if (
-
-
+                response_get_run_status.json.return_value["status"][
+                    "termination_details"
+                ]["code"]
+                in DATABRICKS_FINISH_TO_STATE_MAP.keys()
            ):
-                response_get_run_status = response_get_run_status.json.return_value[
+                response_get_run_status = response_get_run_status.json.return_value[
+                    "status"
+                ]["termination_details"]["code"]
            else:
                response_get_run_status = "OTHER_FINISH_ERROR"
        db_client_mock.is_alive.return_value = True
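The `patch_db_client` helper above resolves the mocked run's termination code against `DATABRICKS_FINISH_TO_STATE_MAP`, falling back to a generic `"OTHER_FINISH_ERROR"` marker when the code is unknown. A small standalone sketch of that lookup logic follows; the dictionary contents are placeholders, since the real map lives in `cornflow/shared/const.py`:

```python
# Placeholder mapping for illustration only; cornflow defines the real keys in
# cornflow/shared/const.py as DATABRICKS_FINISH_TO_STATE_MAP.
FINISH_TO_STATE_MAP = {"SUCCESS": "finished", "USER_CANCELED": "stopped"}


def termination_code(run_status: dict) -> str:
    """Return the run's termination code if it is a known one, else a generic marker."""
    code = run_status["status"]["termination_details"]["code"]
    return code if code in FINISH_TO_STATE_MAP else "OTHER_FINISH_ERROR"


print(termination_code({"status": {"termination_details": {"code": "SUCCESS"}}}))  # SUCCESS
print(termination_code({"status": {"termination_details": {"code": "UNKNOWN"}}}))  # OTHER_FINISH_ERROR
```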
|