cornflow 1.2.1__py3-none-any.whl → 1.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/app.py +4 -2
- cornflow/cli/__init__.py +4 -0
- cornflow/cli/actions.py +4 -0
- cornflow/cli/config.py +4 -0
- cornflow/cli/migrations.py +13 -8
- cornflow/cli/permissions.py +4 -0
- cornflow/cli/roles.py +5 -1
- cornflow/cli/schemas.py +5 -0
- cornflow/cli/service.py +263 -131
- cornflow/cli/tools/api_generator.py +13 -10
- cornflow/cli/tools/endpoint_tools.py +191 -196
- cornflow/cli/tools/models_tools.py +87 -60
- cornflow/cli/tools/schema_generator.py +161 -67
- cornflow/cli/tools/schemas_tools.py +4 -5
- cornflow/cli/users.py +8 -0
- cornflow/cli/views.py +4 -0
- cornflow/commands/access.py +14 -3
- cornflow/commands/auxiliar.py +106 -0
- cornflow/commands/dag.py +3 -2
- cornflow/commands/permissions.py +186 -81
- cornflow/commands/roles.py +15 -14
- cornflow/commands/schemas.py +6 -4
- cornflow/commands/users.py +12 -17
- cornflow/commands/views.py +171 -41
- cornflow/endpoints/dag.py +27 -25
- cornflow/endpoints/data_check.py +128 -165
- cornflow/endpoints/example_data.py +9 -3
- cornflow/endpoints/execution.py +40 -34
- cornflow/endpoints/health.py +7 -7
- cornflow/endpoints/instance.py +39 -12
- cornflow/endpoints/meta_resource.py +4 -5
- cornflow/schemas/execution.py +9 -1
- cornflow/schemas/health.py +1 -0
- cornflow/shared/authentication/auth.py +76 -45
- cornflow/shared/const.py +10 -1
- cornflow/shared/exceptions.py +3 -1
- cornflow/shared/utils_tables.py +36 -8
- cornflow/shared/validators.py +1 -1
- cornflow/tests/const.py +1 -0
- cornflow/tests/custom_test_case.py +4 -4
- cornflow/tests/unit/test_alarms.py +1 -2
- cornflow/tests/unit/test_cases.py +4 -7
- cornflow/tests/unit/test_executions.py +22 -1
- cornflow/tests/unit/test_external_role_creation.py +785 -0
- cornflow/tests/unit/test_health.py +4 -1
- cornflow/tests/unit/test_log_in.py +46 -9
- cornflow/tests/unit/test_tables.py +3 -3
- {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/METADATA +2 -2
- {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/RECORD +52 -50
- {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/WHEEL +1 -1
- {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/entry_points.txt +0 -0
- {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/top_level.txt +0 -0
cornflow/endpoints/health.py
CHANGED
@@ -4,16 +4,15 @@ It performs a health check to airflow and a health check to cornflow database
|
|
4
4
|
"""
|
5
5
|
import os
|
6
6
|
|
7
|
-
# Import from libraries
|
8
|
-
from cornflow_client.airflow.api import Airflow
|
9
|
-
from flask import current_app
|
10
|
-
from flask_apispec import marshal_with, doc
|
11
|
-
|
12
7
|
# Import from internal modules
|
13
8
|
from cornflow.endpoints.meta_resource import BaseMetaResource
|
14
9
|
from cornflow.models import UserModel
|
15
10
|
from cornflow.schemas.health import HealthResponse
|
16
|
-
from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY
|
11
|
+
from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY, CORNFLOW_VERSION
|
12
|
+
# Import from libraries
|
13
|
+
from cornflow_client.airflow.api import Airflow
|
14
|
+
from flask import current_app
|
15
|
+
from flask_apispec import marshal_with, doc
|
17
16
|
|
18
17
|
|
19
18
|
class HealthEndpoint(BaseMetaResource):
|
@@ -30,6 +29,7 @@ class HealthEndpoint(BaseMetaResource):
|
|
30
29
|
af_client = Airflow.from_config(current_app.config)
|
31
30
|
airflow_status = STATUS_UNHEALTHY
|
32
31
|
cornflow_status = STATUS_UNHEALTHY
|
32
|
+
cornflow_version = CORNFLOW_VERSION
|
33
33
|
if af_client.is_alive():
|
34
34
|
airflow_status = STATUS_HEALTHY
|
35
35
|
|
@@ -42,4 +42,4 @@ class HealthEndpoint(BaseMetaResource):
|
|
42
42
|
current_app.logger.info(
|
43
43
|
f"Health check: cornflow {cornflow_status}, airflow {airflow_status}"
|
44
44
|
)
|
45
|
-
return {"cornflow_status": cornflow_status, "airflow_status": airflow_status}
|
45
|
+
return {"cornflow_status": cornflow_status, "airflow_status": airflow_status, "cornflow_version":cornflow_version}
|
cornflow/endpoints/instance.py
CHANGED
@@ -34,7 +34,6 @@ from cornflow.shared.exceptions import InvalidUsage, InvalidData
|
|
34
34
|
from cornflow.shared.validators import json_schema_validate_as_string
|
35
35
|
|
36
36
|
|
37
|
-
|
38
37
|
# Initialize the schema that all endpoints are going to use
|
39
38
|
ALLOWED_EXTENSIONS = {"mps", "lp"}
|
40
39
|
|
@@ -92,14 +91,18 @@ class InstanceEndpoint(BaseMetaResource):
|
|
92
91
|
# We validate the instance data
|
93
92
|
config = current_app.config
|
94
93
|
|
95
|
-
instance_schema = DeployedDAG.get_one_schema(
|
96
|
-
|
94
|
+
instance_schema = DeployedDAG.get_one_schema(
|
95
|
+
config, data_schema, INSTANCE_SCHEMA
|
96
|
+
)
|
97
|
+
instance_errors = json_schema_validate_as_string(
|
98
|
+
instance_schema, kwargs["data"]
|
99
|
+
)
|
97
100
|
|
98
101
|
if instance_errors:
|
99
102
|
raise InvalidData(
|
100
103
|
payload=dict(jsonschema_errors=instance_errors),
|
101
104
|
log_txt=f"Error while user {self.get_user()} tries to create an instance. "
|
102
|
-
|
105
|
+
f"Instance data do not match the jsonschema.",
|
103
106
|
)
|
104
107
|
|
105
108
|
# if we're here, we validated and the data seems to fit the schema
|
@@ -163,14 +166,18 @@ class InstanceDetailsEndpoint(InstanceDetailsEndpointBase):
|
|
163
166
|
|
164
167
|
config = current_app.config
|
165
168
|
|
166
|
-
instance_schema = DeployedDAG.get_one_schema(
|
167
|
-
|
169
|
+
instance_schema = DeployedDAG.get_one_schema(
|
170
|
+
config, schema, INSTANCE_SCHEMA
|
171
|
+
)
|
172
|
+
instance_errors = json_schema_validate_as_string(
|
173
|
+
instance_schema, kwargs["data"]
|
174
|
+
)
|
168
175
|
|
169
176
|
if instance_errors:
|
170
177
|
raise InvalidData(
|
171
178
|
payload=dict(jsonschema_errors=instance_errors),
|
172
179
|
log_txt=f"Error while user {self.get_user()} tries to create an instance. "
|
173
|
-
|
180
|
+
f"Instance data do not match the jsonschema.",
|
174
181
|
)
|
175
182
|
|
176
183
|
response = self.put_detail(data=kwargs, user=self.get_user(), idx=idx)
|
@@ -268,15 +275,35 @@ class InstanceFileEndpoint(BaseMetaResource):
|
|
268
275
|
sense = 1 if minimize else -1
|
269
276
|
try:
|
270
277
|
_vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
|
271
|
-
except:
|
278
|
+
except FileNotFoundError as e:
|
279
|
+
# Handle file not found specifically
|
272
280
|
raise InvalidUsage(
|
273
|
-
error="
|
274
|
-
log_txt=f"Error
|
275
|
-
|
281
|
+
error=f"MPS file not found: {filename}",
|
282
|
+
log_txt=f"Error for user {self.get_user()}: MPS file '{filename}' not found. Details: {e}",
|
283
|
+
status_code=404,
|
284
|
+
) from e
|
285
|
+
except PermissionError as e:
|
286
|
+
# Handle permission issues
|
287
|
+
raise InvalidUsage(
|
288
|
+
error=f"Permission denied reading MPS file: {filename}",
|
289
|
+
log_txt=f"Error for user {self.get_user()}: Permission denied for MPS file '{filename}'. Details: {e}",
|
290
|
+
status_code=403,
|
291
|
+
) from e
|
292
|
+
except (ValueError, pulp.PulpError, OSError, IndexError) as e:
|
293
|
+
# Catch parsing errors, PuLP errors, and other IO errors
|
294
|
+
# Handle parsing, PuLP, or other OS errors
|
295
|
+
current_app.logger.error(
|
296
|
+
f"Error parsing MPS file {filename} for user {self.get_user()}: {e}",
|
297
|
+
exc_info=True,
|
276
298
|
)
|
299
|
+
raise InvalidUsage(
|
300
|
+
error="Error reading or parsing the MPS file.",
|
301
|
+
log_txt=f"Error while user {self.get_user()} tries to create instance from MPS file {filename}. Details: {e}",
|
302
|
+
) from e
|
303
|
+
|
277
304
|
try:
|
278
305
|
os.remove(filename)
|
279
|
-
except:
|
306
|
+
except FileNotFoundError:
|
280
307
|
pass
|
281
308
|
|
282
309
|
pb_data = dict(
|
@@ -1,6 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
This file has all the logic shared for all the resources
|
3
3
|
"""
|
4
|
+
|
4
5
|
# Import from external libraries
|
5
6
|
from flask_restful import Resource
|
6
7
|
from flask import g, request
|
@@ -13,7 +14,6 @@ from cornflow.shared.const import ALL_DEFAULT_ROLES
|
|
13
14
|
from cornflow.shared.exceptions import InvalidUsage, ObjectDoesNotExist, NoPermission
|
14
15
|
|
15
16
|
|
16
|
-
|
17
17
|
class BaseMetaResource(Resource, MethodResource):
|
18
18
|
"""
|
19
19
|
The base resource from all methods inherit from.
|
@@ -30,7 +30,6 @@ class BaseMetaResource(Resource, MethodResource):
|
|
30
30
|
self.auth_class = None
|
31
31
|
self.dependents = None
|
32
32
|
self.unique = None
|
33
|
-
pass
|
34
33
|
|
35
34
|
"""
|
36
35
|
METHODS USED FOR THE BASIC CRUD OPERATIONS: GET, POST, PUT, PATCH, DELETE
|
@@ -121,7 +120,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
121
120
|
"""
|
122
121
|
item = self.data_model.get_one_object(**kwargs)
|
123
122
|
if item is None:
|
124
|
-
raise ObjectDoesNotExist(
|
123
|
+
raise ObjectDoesNotExist()
|
125
124
|
|
126
125
|
data = dict(data)
|
127
126
|
|
@@ -144,7 +143,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
144
143
|
item = self.data_model.get_one_object(**kwargs)
|
145
144
|
|
146
145
|
if item is None:
|
147
|
-
raise ObjectDoesNotExist(
|
146
|
+
raise ObjectDoesNotExist()
|
148
147
|
|
149
148
|
data = dict(data)
|
150
149
|
|
@@ -164,7 +163,7 @@ class BaseMetaResource(Resource, MethodResource):
|
|
164
163
|
"""
|
165
164
|
item = self.data_model.get_one_object(**kwargs)
|
166
165
|
if item is None:
|
167
|
-
raise ObjectDoesNotExist(
|
166
|
+
raise ObjectDoesNotExist()
|
168
167
|
if self.dependents is not None:
|
169
168
|
for element in getattr(item, self.dependents):
|
170
169
|
element.delete()
|
cornflow/schemas/execution.py
CHANGED
@@ -95,7 +95,7 @@ class ExecutionDagPostRequest(ExecutionRequest, ExecutionDagRequest):
|
|
95
95
|
|
96
96
|
|
97
97
|
class ExecutionDetailsEndpointResponse(BaseDataEndpointResponse):
|
98
|
-
config = fields.
|
98
|
+
config = fields.Raw()
|
99
99
|
instance_id = fields.Str()
|
100
100
|
state = fields.Int()
|
101
101
|
message = fields.Str(attribute="state_message")
|
@@ -112,6 +112,14 @@ class ExecutionDetailsEndpointWithIndicatorsResponse(ExecutionDetailsEndpointRes
|
|
112
112
|
return indicators_string[1:-1]
|
113
113
|
|
114
114
|
indicators = fields.Method("get_indicators")
|
115
|
+
updated_at = fields.DateTime(dump_only=True)
|
116
|
+
|
117
|
+
def get_username(self, obj):
|
118
|
+
if hasattr(obj, "user") and obj.user is not None:
|
119
|
+
return obj.user.username
|
120
|
+
return None
|
121
|
+
|
122
|
+
username = fields.Method("get_username")
|
115
123
|
|
116
124
|
|
117
125
|
class ExecutionDetailsWithIndicatorsAndLogResponse(
|
cornflow/schemas/health.py
CHANGED
@@ -2,15 +2,16 @@
|
|
2
2
|
This file contains the auth class that can be used for authentication on the request to the REST API
|
3
3
|
"""
|
4
4
|
|
5
|
-
# Imports from external libraries
|
6
|
-
import jwt
|
7
|
-
import requests
|
8
|
-
from jwt.algorithms import RSAAlgorithm
|
9
5
|
from datetime import datetime, timedelta, timezone
|
10
|
-
from flask import request, g, current_app, Request
|
11
6
|
from functools import wraps
|
12
7
|
from typing import Tuple
|
8
|
+
|
9
|
+
# Imports from external libraries
|
10
|
+
import jwt
|
11
|
+
import requests
|
13
12
|
from cachetools import TTLCache
|
13
|
+
from flask import request, g, current_app, Request
|
14
|
+
from jwt.algorithms import RSAAlgorithm
|
14
15
|
from werkzeug.datastructures import Headers
|
15
16
|
|
16
17
|
# Imports from internal modules
|
@@ -31,7 +32,6 @@ from cornflow.shared.exceptions import (
|
|
31
32
|
InvalidData,
|
32
33
|
InvalidUsage,
|
33
34
|
NoPermission,
|
34
|
-
ObjectDoesNotExist,
|
35
35
|
)
|
36
36
|
|
37
37
|
# Cache for storing public keys with 1 hour TTL
|
@@ -122,15 +122,13 @@ class Auth:
|
|
122
122
|
def decode_token(token: str = None) -> dict:
|
123
123
|
"""
|
124
124
|
Decodes a given JSON Web token and extracts the username from the sub claim.
|
125
|
-
Works with both internal tokens and OpenID tokens.
|
125
|
+
Works with both internal tokens and OpenID tokens by attempting verification methods sequentially.
|
126
126
|
|
127
127
|
:param str token: the given JSON Web Token
|
128
128
|
:return: dictionary containing the username from the token's sub claim
|
129
129
|
:rtype: dict
|
130
130
|
"""
|
131
|
-
|
132
131
|
if token is None:
|
133
|
-
|
134
132
|
raise InvalidCredentials(
|
135
133
|
"Must provide a token in Authorization header",
|
136
134
|
log_txt="Error while trying to decode token. Token is missing.",
|
@@ -138,48 +136,81 @@ class Auth:
|
|
138
136
|
)
|
139
137
|
|
140
138
|
try:
|
141
|
-
#
|
142
|
-
|
143
|
-
|
144
|
-
issuer = unverified_payload.get("iss")
|
145
|
-
|
146
|
-
# For internal tokens
|
147
|
-
if issuer == INTERNAL_TOKEN_ISSUER:
|
148
|
-
|
149
|
-
return jwt.decode(
|
150
|
-
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
151
|
-
)
|
152
|
-
|
153
|
-
# For OpenID tokens
|
154
|
-
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
155
|
-
|
156
|
-
return Auth().verify_token(
|
157
|
-
token,
|
158
|
-
current_app.config["OID_PROVIDER"],
|
159
|
-
current_app.config["OID_EXPECTED_AUDIENCE"],
|
160
|
-
)
|
161
|
-
|
162
|
-
# If we get here, the issuer is not valid
|
163
|
-
|
164
|
-
raise InvalidCredentials(
|
165
|
-
"Invalid token issuer. Token must be issued by a valid provider",
|
166
|
-
log_txt="Error while trying to decode token. Invalid issuer.",
|
167
|
-
status_code=400,
|
139
|
+
# Attempt 1: Verify as an internal token (HS256)
|
140
|
+
payload = jwt.decode(
|
141
|
+
token, current_app.config["SECRET_TOKEN_KEY"], algorithms=["HS256"]
|
168
142
|
)
|
143
|
+
if payload.get("iss") != INTERNAL_TOKEN_ISSUER:
|
144
|
+
raise jwt.InvalidIssuerError(
|
145
|
+
"Internal token issuer mismatch after verification"
|
146
|
+
)
|
147
|
+
return payload
|
169
148
|
|
170
149
|
except jwt.ExpiredSignatureError:
|
171
|
-
|
150
|
+
# Handle expiration specifically, could apply to either token type if caught here first
|
172
151
|
raise InvalidCredentials(
|
173
152
|
"The token has expired, please login again",
|
174
153
|
log_txt="Error while trying to decode token. The token has expired.",
|
175
154
|
status_code=400,
|
176
155
|
)
|
177
|
-
except
|
178
|
-
|
156
|
+
except (
|
157
|
+
jwt.InvalidSignatureError,
|
158
|
+
jwt.DecodeError,
|
159
|
+
jwt.InvalidTokenError,
|
160
|
+
) as e_internal:
|
161
|
+
# Internal verification failed (signature, format, etc.). Try OIDC if configured.
|
162
|
+
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
163
|
+
try:
|
164
|
+
# Attempt 2: Verify as an OIDC token (RS256) using the dedicated method
|
165
|
+
return Auth().verify_token(
|
166
|
+
token,
|
167
|
+
current_app.config["OID_PROVIDER"],
|
168
|
+
current_app.config["OID_EXPECTED_AUDIENCE"],
|
169
|
+
)
|
170
|
+
except jwt.ExpiredSignatureError:
|
171
|
+
# OIDC token expired
|
172
|
+
raise InvalidCredentials(
|
173
|
+
"The token has expired, please login again",
|
174
|
+
log_txt="Error while trying to decode OIDC token. The token has expired.",
|
175
|
+
status_code=400,
|
176
|
+
)
|
177
|
+
except (
|
178
|
+
jwt.InvalidTokenError,
|
179
|
+
InvalidCredentials,
|
180
|
+
CommunicationError,
|
181
|
+
) as e_oidc:
|
182
|
+
# OIDC verification failed (JWT format, signature, kid, audience, issuer, comms error)
|
183
|
+
# Log details for debugging but return a generic error to the client.
|
184
|
+
log_message = (
|
185
|
+
f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
|
186
|
+
f"OIDC verification failed ({type(e_oidc).__name__}: {str(e_oidc)})."
|
187
|
+
)
|
188
|
+
current_app.logger.warning(log_message)
|
189
|
+
raise InvalidCredentials(
|
190
|
+
"Invalid token format, signature, or configuration",
|
191
|
+
log_txt=log_message,
|
192
|
+
status_code=400,
|
193
|
+
)
|
194
|
+
else:
|
195
|
+
# Internal verification failed, and OIDC is not configured
|
196
|
+
log_message = (
|
197
|
+
f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
|
198
|
+
f"OIDC is not configured."
|
199
|
+
)
|
200
|
+
current_app.logger.warning(log_message)
|
201
|
+
raise InvalidCredentials(
|
202
|
+
"Invalid token format or signature",
|
203
|
+
log_txt=log_message,
|
204
|
+
status_code=400,
|
205
|
+
)
|
206
|
+
except Exception as e:
|
207
|
+
# Catch any other unexpected errors during the process
|
208
|
+
log_message = f"Unexpected error during token decoding: {str(e)}"
|
209
|
+
current_app.logger.error(log_message)
|
179
210
|
raise InvalidCredentials(
|
180
|
-
"
|
181
|
-
log_txt=
|
182
|
-
status_code=
|
211
|
+
"Could not decode or verify token due to an unexpected server error",
|
212
|
+
log_txt=log_message,
|
213
|
+
status_code=500,
|
183
214
|
)
|
184
215
|
|
185
216
|
def get_token_from_header(self, headers: Headers = None) -> str:
|
@@ -365,10 +396,10 @@ class Auth:
|
|
365
396
|
"The permission for this endpoint is not in the database."
|
366
397
|
)
|
367
398
|
raise NoPermission(
|
368
|
-
error="
|
399
|
+
error="The permission for this endpoint is not in the database.",
|
369
400
|
status_code=403,
|
370
401
|
log_txt=f"Error while user {user_id} tries to access endpoint. "
|
371
|
-
f"The
|
402
|
+
f"The permission for this endpoint is not in the database.",
|
372
403
|
)
|
373
404
|
|
374
405
|
for role in user_roles:
|
@@ -383,7 +414,7 @@ class Auth:
|
|
383
414
|
error="You do not have permission to access this endpoint",
|
384
415
|
status_code=403,
|
385
416
|
log_txt=f"Error while user {user_id} tries to access endpoint {view_id} with action {action_id}. "
|
386
|
-
f"The user does not permission to access. ",
|
417
|
+
f"The user does not have permission to access. ",
|
387
418
|
)
|
388
419
|
|
389
420
|
@staticmethod
|
cornflow/shared/const.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
In this file we import the values for different constants on cornflow server
|
3
3
|
"""
|
4
|
-
|
4
|
+
CORNFLOW_VERSION = "1.2.3"
|
5
5
|
INTERNAL_TOKEN_ISSUER = "cornflow"
|
6
6
|
|
7
7
|
# endpoints responses for health check
|
@@ -112,3 +112,12 @@ BASE_PERMISSION_ASSIGNATION = [
|
|
112
112
|
EXTRA_PERMISSION_ASSIGNATION = [
|
113
113
|
(VIEWER_ROLE, PUT_ACTION, "user-detail"),
|
114
114
|
]
|
115
|
+
|
116
|
+
# migrations constants
|
117
|
+
MIGRATIONS_DEFAULT_PATH = "./cornflow/migrations"
|
118
|
+
|
119
|
+
# Costants for messages that are given back on exceptions
|
120
|
+
AIRFLOW_NOT_REACHABLE_MSG = "Airflow is not reachable"
|
121
|
+
DAG_PAUSED_MSG = "The dag exists but it is paused in airflow"
|
122
|
+
AIRFLOW_ERROR_MSG = "Airflow responded with an error:"
|
123
|
+
DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"
|
cornflow/shared/exceptions.py
CHANGED
@@ -9,6 +9,8 @@ from cornflow_client.constants import AirflowError
|
|
9
9
|
from werkzeug.exceptions import HTTPException
|
10
10
|
import traceback
|
11
11
|
|
12
|
+
from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
|
13
|
+
|
12
14
|
|
13
15
|
class InvalidUsage(Exception):
|
14
16
|
"""
|
@@ -48,7 +50,7 @@ class ObjectDoesNotExist(InvalidUsage):
|
|
48
50
|
"""
|
49
51
|
|
50
52
|
status_code = 404
|
51
|
-
error =
|
53
|
+
error = DATA_DOES_NOT_EXIST_MSG
|
52
54
|
|
53
55
|
|
54
56
|
class ObjectAlreadyExists(InvalidUsage):
|
cornflow/shared/utils_tables.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
import inspect
|
3
3
|
import os
|
4
4
|
import sys
|
5
|
+
import collections
|
5
6
|
|
6
7
|
from importlib import import_module
|
7
8
|
from sqlalchemy.dialects.postgresql import TEXT
|
@@ -31,14 +32,42 @@ def import_models():
|
|
31
32
|
|
32
33
|
|
33
34
|
def all_subclasses(cls, models=None):
|
34
|
-
subclasses
|
35
|
+
"""Finds all direct and indirect subclasses of a given class.
|
36
|
+
|
37
|
+
Optionally filters a provided list of models first.
|
38
|
+
|
39
|
+
Args:
|
40
|
+
cls: The base class to find subclasses for.
|
41
|
+
models: An optional iterable of classes to pre-filter.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
A set containing all subclasses found.
|
45
|
+
"""
|
46
|
+
filtered_subclasses = set()
|
35
47
|
if models is not None:
|
36
48
|
for val in models:
|
37
|
-
|
38
|
-
|
49
|
+
# Ensure val is a class before checking issubclass
|
50
|
+
if isinstance(val, type) and issubclass(val, cls):
|
51
|
+
filtered_subclasses.add(val)
|
52
|
+
|
53
|
+
all_descendants = set()
|
54
|
+
# Use a deque for efficient pop(0)
|
55
|
+
queue = collections.deque(cls.__subclasses__())
|
56
|
+
# Keep track of visited classes during the traversal to handle potential complex hierarchies
|
57
|
+
# (though direct subclass relationships shouldn't form cycles)
|
58
|
+
# Initialize with direct subclasses as they are the starting point.
|
59
|
+
visited_for_queue = set(cls.__subclasses__())
|
60
|
+
|
61
|
+
while queue:
|
62
|
+
current_sub = queue.popleft()
|
63
|
+
all_descendants.add(current_sub)
|
64
|
+
|
65
|
+
for grandchild in current_sub.__subclasses__():
|
66
|
+
if grandchild not in visited_for_queue:
|
67
|
+
visited_for_queue.add(grandchild)
|
68
|
+
queue.append(grandchild)
|
39
69
|
|
40
|
-
return
|
41
|
-
[s for c in cls.__subclasses__() for s in all_subclasses(c)]))
|
70
|
+
return filtered_subclasses.union(all_descendants)
|
42
71
|
|
43
72
|
|
44
73
|
type_converter = {
|
@@ -76,6 +105,5 @@ def item_as_dict(item):
|
|
76
105
|
|
77
106
|
def items_as_dict_list(ls):
|
78
107
|
return [
|
79
|
-
{c.name: getattr(item, c.name) for c in item.__table__.columns}
|
80
|
-
|
81
|
-
]
|
108
|
+
{c.name: getattr(item, c.name) for c in item.__table__.columns} for item in ls
|
109
|
+
]
|
cornflow/shared/validators.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
"""
|
2
2
|
This file has several validators
|
3
3
|
"""
|
4
|
+
|
4
5
|
import re
|
5
6
|
from typing import Tuple, Union
|
6
7
|
|
7
8
|
from jsonschema import Draft7Validator, validators
|
8
9
|
from disposable_email_domains import blocklist
|
9
|
-
from jsonschema.protocols import Validator
|
10
10
|
|
11
11
|
|
12
12
|
def is_special_character(character):
|
cornflow/tests/const.py
CHANGED
@@ -17,6 +17,7 @@ INSTANCE_GC_20 = _get_file("./data/gc_20_7.json")
|
|
17
17
|
INSTANCE_FILE_FAIL = _get_file("./unit/test_instances.py")
|
18
18
|
|
19
19
|
EXECUTION_PATH = _get_file("./data/new_execution.json")
|
20
|
+
CUSTOM_CONFIG_PATH = _get_file("./data/new_execution_custom_config.json")
|
20
21
|
BAD_EXECUTION_PATH = _get_file("./data/bad_execution.json")
|
21
22
|
EXECUTION_SOLUTION_PATH = _get_file("./data/new_execution_solution.json")
|
22
23
|
EXECUTIONS_LIST = [EXECUTION_PATH, _get_file("./data/new_execution_2.json")]
|
@@ -635,9 +635,9 @@ class BaseTestCases:
|
|
635
635
|
# (we patch the request to airflow to check if the schema is valid)
|
636
636
|
# we create 4 instances
|
637
637
|
data_many = [self.payload for _ in range(4)]
|
638
|
-
|
639
|
-
|
640
|
-
self.apply_filter(self.url, dict(schema="timer"),
|
638
|
+
|
639
|
+
self.get_rows(self.url, data_many)
|
640
|
+
self.apply_filter(self.url, dict(schema="timer"), [])
|
641
641
|
|
642
642
|
def test_opt_filters_date_lte(self):
|
643
643
|
"""
|
@@ -1117,7 +1117,7 @@ class LoginTestCases:
|
|
1117
1117
|
|
1118
1118
|
self.assertEqual(400, response.status_code)
|
1119
1119
|
self.assertEqual(
|
1120
|
-
"Invalid token
|
1120
|
+
"Invalid token format or signature",
|
1121
1121
|
response.json["error"],
|
1122
1122
|
)
|
1123
1123
|
|
@@ -91,7 +91,6 @@ class TestAlarmsEndpoint(CustomTestCase):
|
|
91
91
|
class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint):
|
92
92
|
def setUp(self):
|
93
93
|
super().setUp()
|
94
|
-
self.url = self.url
|
95
94
|
self.idx = 0
|
96
95
|
self.payload = {
|
97
96
|
"name": "Alarm 1",
|
@@ -138,4 +137,4 @@ class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint)
|
|
138
137
|
# We check deleted at has a value
|
139
138
|
self.assertIsNotNone(row.deleted_at)
|
140
139
|
else:
|
141
|
-
self.assertIsNone(row.deleted_at)
|
140
|
+
self.assertIsNone(row.deleted_at)
|
@@ -42,6 +42,7 @@ import zlib
|
|
42
42
|
|
43
43
|
# Import from internal modules
|
44
44
|
from cornflow.models import CaseModel, ExecutionModel, InstanceModel, UserModel
|
45
|
+
from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
|
45
46
|
from cornflow.shared.utils import hash_json_256
|
46
47
|
from cornflow.tests.const import (
|
47
48
|
INSTANCE_URL,
|
@@ -227,12 +228,8 @@ class TestCasesFromInstanceExecutionEndpoint(CustomTestCase):
|
|
227
228
|
"execution_id": execution_id,
|
228
229
|
"schema": "solve_model_dag",
|
229
230
|
}
|
230
|
-
self.instance = InstanceModel.get_one_object(
|
231
|
-
|
232
|
-
)
|
233
|
-
self.execution = ExecutionModel.get_one_object(
|
234
|
-
user=self.user, idx=execution_id
|
235
|
-
)
|
231
|
+
self.instance = InstanceModel.get_one_object(user=self.user, idx=instance_id)
|
232
|
+
self.execution = ExecutionModel.get_one_object(user=self.user, idx=execution_id)
|
236
233
|
|
237
234
|
def test_new_case_execution(self):
|
238
235
|
"""
|
@@ -729,7 +726,7 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
|
|
729
726
|
headers=self.get_header_with_auth(self.token),
|
730
727
|
)
|
731
728
|
self.assertEqual(response.status_code, 404)
|
732
|
-
self.assertEqual(response.json["error"],
|
729
|
+
self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
|
733
730
|
|
734
731
|
|
735
732
|
class TestCaseJsonPatch(CustomTestCase):
|
@@ -18,6 +18,7 @@ from cornflow.tests.const import (
|
|
18
18
|
DAG_URL,
|
19
19
|
BAD_EXECUTION_PATH,
|
20
20
|
EXECUTION_SOLUTION_PATH,
|
21
|
+
CUSTOM_CONFIG_PATH,
|
21
22
|
)
|
22
23
|
from cornflow.tests.custom_test_case import CustomTestCase, BaseTestCases
|
23
24
|
from cornflow.tests.unit.tools import patch_af_client
|
@@ -43,6 +44,7 @@ class TestExecutionsListEndpoint(BaseTestCases.ListFilters):
|
|
43
44
|
self.bad_payload = load_file_fk(BAD_EXECUTION_PATH)
|
44
45
|
self.payloads = [load_file_fk(f) for f in EXECUTIONS_LIST]
|
45
46
|
self.solution = load_file_fk(EXECUTION_SOLUTION_PATH)
|
47
|
+
self.custom_config_payload = load_file_fk(CUSTOM_CONFIG_PATH)
|
46
48
|
self.keys_to_check = [
|
47
49
|
"data_hash",
|
48
50
|
"created_at",
|
@@ -57,11 +59,25 @@ class TestExecutionsListEndpoint(BaseTestCases.ListFilters):
|
|
57
59
|
"instance_id",
|
58
60
|
"name",
|
59
61
|
"indicators",
|
62
|
+
"username",
|
63
|
+
"updated_at"
|
60
64
|
]
|
61
65
|
|
62
66
|
def test_new_execution(self):
|
63
67
|
self.create_new_row(self.url, self.model, payload=self.payload)
|
64
68
|
|
69
|
+
def test_get_custom_config(self):
|
70
|
+
id = self.create_new_row(
|
71
|
+
self.url, self.model, payload=self.custom_config_payload
|
72
|
+
)
|
73
|
+
url = EXECUTION_URL + "/" + str(id) + "/" + "?run=0"
|
74
|
+
|
75
|
+
response = self.get_one_row(
|
76
|
+
url,
|
77
|
+
payload={**self.custom_config_payload, **dict(id=id)},
|
78
|
+
)
|
79
|
+
self.assertEqual(response["config"]["block_model"]["solver"], "mip.gurobi")
|
80
|
+
|
65
81
|
@patch("cornflow.endpoints.execution.Airflow")
|
66
82
|
def test_new_execution_run(self, af_client_class):
|
67
83
|
patch_af_client(af_client_class)
|
@@ -260,6 +276,8 @@ class TestExecutionsDetailEndpointMock(CustomTestCase):
|
|
260
276
|
"schema",
|
261
277
|
"user_id",
|
262
278
|
"indicators",
|
279
|
+
"username",
|
280
|
+
"updated_at"
|
263
281
|
}
|
264
282
|
# we only check the following because this endpoint does not return data
|
265
283
|
self.items_to_check = ["name", "description"]
|
@@ -274,7 +292,6 @@ class TestExecutionsDetailEndpoint(
|
|
274
292
|
):
|
275
293
|
def setUp(self):
|
276
294
|
super().setUp()
|
277
|
-
self.url = self.url
|
278
295
|
self.query_arguments = {"run": 0}
|
279
296
|
|
280
297
|
# TODO: this test should be moved as it is not using the detail endpoint
|
@@ -303,6 +320,8 @@ class TestExecutionsDetailEndpoint(
|
|
303
320
|
"name",
|
304
321
|
"created_at",
|
305
322
|
"state",
|
323
|
+
"username",
|
324
|
+
"updated_at"
|
306
325
|
]
|
307
326
|
execution = self.get_one_row(
|
308
327
|
self.url + idx,
|
@@ -450,6 +469,8 @@ class TestExecutionsLogEndpoint(TestExecutionsDetailEndpointMock):
|
|
450
469
|
"user_id",
|
451
470
|
"config",
|
452
471
|
"indicators",
|
472
|
+
"username",
|
473
|
+
"updated_at"
|
453
474
|
]
|
454
475
|
|
455
476
|
def test_get_one_execution(self):
|