cornflow 1.0.11a1__py3-none-any.whl → 1.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cornflow/cli/service.py +4 -0
- cornflow/commands/__init__.py +1 -1
- cornflow/commands/schemas.py +31 -0
- cornflow/config.py +6 -0
- cornflow/endpoints/execution.py +2 -1
- cornflow/endpoints/login.py +16 -13
- cornflow/endpoints/user.py +2 -2
- cornflow/migrations/versions/991b98e24225_.py +33 -0
- cornflow/models/user.py +4 -0
- cornflow/schemas/execution.py +8 -1
- cornflow/schemas/solution_log.py +11 -5
- cornflow/schemas/user.py +3 -0
- cornflow/shared/authentication/auth.py +1 -1
- cornflow/shared/licenses.py +17 -54
- cornflow/tests/custom_test_case.py +17 -3
- cornflow/tests/integration/test_cornflowclient.py +20 -14
- cornflow/tests/unit/test_cases.py +95 -6
- cornflow/tests/unit/test_dags.py +48 -1
- cornflow/tests/unit/test_example_data.py +3 -0
- cornflow/tests/unit/test_executions.py +98 -8
- cornflow/tests/unit/test_instances.py +43 -5
- cornflow/tests/unit/test_main_alarms.py +8 -8
- cornflow/tests/unit/test_schemas.py +12 -1
- cornflow/tests/unit/test_token.py +17 -0
- cornflow/tests/unit/test_users.py +16 -0
- {cornflow-1.0.11a1.dist-info → cornflow-1.1.0a1.dist-info}/METADATA +2 -2
- {cornflow-1.0.11a1.dist-info → cornflow-1.1.0a1.dist-info}/RECORD +30 -29
- {cornflow-1.0.11a1.dist-info → cornflow-1.1.0a1.dist-info}/WHEEL +0 -0
- {cornflow-1.0.11a1.dist-info → cornflow-1.1.0a1.dist-info}/entry_points.txt +0 -0
- {cornflow-1.0.11a1.dist-info → cornflow-1.1.0a1.dist-info}/top_level.txt +0 -0
cornflow/cli/service.py
CHANGED
@@ -14,6 +14,7 @@ from cornflow.commands import (
|
|
14
14
|
register_deployed_dags_command,
|
15
15
|
register_dag_permissions_command,
|
16
16
|
update_schemas_command,
|
17
|
+
update_dag_registry_command,
|
17
18
|
)
|
18
19
|
from cornflow.shared.const import AUTH_DB, ADMIN_ROLE, SERVICE_ROLE
|
19
20
|
from cornflow.shared import db
|
@@ -211,6 +212,9 @@ def init_cornflow_service():
|
|
211
212
|
)
|
212
213
|
register_dag_permissions_command(open_deployment, verbose=True)
|
213
214
|
update_schemas_command(airflow_url, airflow_user, airflow_pwd, verbose=True)
|
215
|
+
update_dag_registry_command(
|
216
|
+
airflow_url, airflow_user, airflow_pwd, verbose=True
|
217
|
+
)
|
214
218
|
|
215
219
|
os.system(
|
216
220
|
f"/usr/local/bin/gunicorn -c python:cornflow.gunicorn "
|
cornflow/commands/__init__.py
CHANGED
@@ -6,7 +6,7 @@ from .permissions import (
|
|
6
6
|
register_dag_permissions_command,
|
7
7
|
)
|
8
8
|
from .roles import register_roles_command
|
9
|
-
from .schemas import update_schemas_command
|
9
|
+
from .schemas import update_schemas_command, update_dag_registry_command
|
10
10
|
from .users import (
|
11
11
|
create_user_with_role,
|
12
12
|
create_service_user_command,
|
cornflow/commands/schemas.py
CHANGED
@@ -27,3 +27,34 @@ def update_schemas_command(url, user, pwd, verbose: bool = False):
|
|
27
27
|
current_app.logger.info("The DAGs schemas were not updated properly")
|
28
28
|
|
29
29
|
return True
|
30
|
+
|
31
|
+
|
32
|
+
def update_dag_registry_command(
    url, user, pwd, verbose: bool = False, max_attempts: int = 20, wait: float = 15
):
    """
    Ask Airflow to update the DAG registry, waiting first until Airflow is reachable.

    :param str url: URL of the Airflow webserver.
    :param str user: user used to authenticate against Airflow.
    :param str pwd: password of that user.
    :param bool verbose: if True, progress and results are logged with the Flask app logger.
    :param int max_attempts: maximum number of failed reachability checks (each one
      followed by a ``wait`` seconds sleep) before giving up.
    :param float wait: seconds to sleep after each failed reachability check.
    :return: True if the update request was answered with HTTP status 200,
      False if Airflow is not reachable or the update request did not succeed.
    :rtype: bool
    """
    import time

    from flask import current_app

    from cornflow_client.airflow.api import Airflow

    af_client = Airflow(url, user, pwd)
    attempts = 0
    while not af_client.is_alive() and attempts < max_attempts:
        attempts += 1
        if verbose:
            current_app.logger.info(f"Airflow is not reachable (attempt {attempts})")
        time.sleep(wait)

    if not af_client.is_alive():
        if verbose:
            current_app.logger.info("Airflow is not reachable")
        return False

    response = af_client.update_dag_registry()
    if response.status_code == 200:
        if verbose:
            current_app.logger.info("The DAG registry was updated")
        return True

    # Report the failure to the caller instead of pretending the update succeeded.
    if verbose:
        current_app.logger.info("The DAG registry was not updated properly")
    return False
|
cornflow/config.py
CHANGED
@@ -76,6 +76,12 @@ class DefaultConfig(object):
|
|
76
76
|
# Alarms endpoints
|
77
77
|
ALARMS_ENDPOINTS = os.getenv("CF_ALARMS_ENDPOINT", 0)
|
78
78
|
|
79
|
+
# Token duration in hours
|
80
|
+
TOKEN_DURATION = os.getenv("TOKEN_DURATION", 24)
|
81
|
+
|
82
|
+
# Password rotation time in days
|
83
|
+
PWD_ROTATION_TIME = os.getenv("PWD_ROTATION_TIME", 120)
|
84
|
+
|
79
85
|
|
80
86
|
class Development(DefaultConfig):
|
81
87
|
|
cornflow/endpoints/execution.py
CHANGED
@@ -24,6 +24,7 @@ from cornflow.schemas.execution import (
|
|
24
24
|
ExecutionEditRequest,
|
25
25
|
QueryFiltersExecution,
|
26
26
|
ReLaunchExecutionRequest,
|
27
|
+
ExecutionDetailsWithIndicatorsAndLogResponse
|
27
28
|
)
|
28
29
|
from cornflow.shared.authentication import Auth, authenticate
|
29
30
|
from cornflow.shared.compress import compressed
|
@@ -58,7 +59,7 @@ class ExecutionEndpoint(BaseMetaResource):
|
|
58
59
|
|
59
60
|
@doc(description="Get all executions", tags=["Executions"])
|
60
61
|
@authenticate(auth_class=Auth())
|
61
|
-
@marshal_with(
|
62
|
+
@marshal_with(ExecutionDetailsWithIndicatorsAndLogResponse(many=True))
|
62
63
|
@use_kwargs(QueryFiltersExecution, location="query")
|
63
64
|
def get(self, **kwargs):
|
64
65
|
"""
|
cornflow/endpoints/login.py
CHANGED
@@ -6,10 +6,11 @@ External endpoint for the user to login to the cornflow webserver
|
|
6
6
|
from flask import current_app
|
7
7
|
from flask_apispec import use_kwargs, doc
|
8
8
|
from sqlalchemy.exc import IntegrityError, DBAPIError
|
9
|
+
from datetime import datetime, timedelta
|
9
10
|
|
10
11
|
# Import from internal modules
|
11
12
|
from cornflow.endpoints.meta_resource import BaseMetaResource
|
12
|
-
from cornflow.models import
|
13
|
+
from cornflow.models import UserModel, UserRoleModel
|
13
14
|
from cornflow.schemas.user import LoginEndpointRequest, LoginOpenAuthRequest
|
14
15
|
from cornflow.shared import db
|
15
16
|
from cornflow.shared.authentication import Auth, LDAPBase
|
@@ -47,9 +48,11 @@ class LoginBaseEndpoint(BaseMetaResource):
|
|
47
48
|
:rtype: dict
|
48
49
|
"""
|
49
50
|
auth_type = current_app.config["AUTH_TYPE"]
|
51
|
+
response = {}
|
50
52
|
|
51
53
|
if auth_type == AUTH_DB:
|
52
54
|
user = self.auth_db_authenticate(**kwargs)
|
55
|
+
response.update({"change_password": check_last_password_change(user)})
|
53
56
|
elif auth_type == AUTH_LDAP:
|
54
57
|
user = self.auth_ldap_authenticate(**kwargs)
|
55
58
|
elif auth_type == AUTH_OID:
|
@@ -62,7 +65,9 @@ class LoginBaseEndpoint(BaseMetaResource):
|
|
62
65
|
except Exception as e:
|
63
66
|
raise InvalidUsage(f"Error in generating user token: {str(e)}", 400)
|
64
67
|
|
65
|
-
|
68
|
+
response.update({"token": token, "id": user.id})
|
69
|
+
|
70
|
+
return response, 200
|
66
71
|
|
67
72
|
def auth_db_authenticate(self, username, password):
|
68
73
|
"""
|
@@ -176,6 +181,13 @@ class LoginBaseEndpoint(BaseMetaResource):
|
|
176
181
|
return user
|
177
182
|
|
178
183
|
|
184
|
+
def check_last_password_change(user):
    """
    Tell whether the user's password is older than the configured rotation period.

    :param user: the user being logged in; only its ``pwd_last_change`` attribute is read.
    :return: True when ``pwd_last_change`` is set and lies more than
      ``PWD_ROTATION_TIME`` days (read from the app config) in the past,
      False otherwise.
    :rtype: bool
    """
    last_change = user.pwd_last_change
    if not last_change:
        # No recorded change date (the column is nullable): never force a rotation.
        return False

    rotation_period = timedelta(days=int(current_app.config["PWD_ROTATION_TIME"]))
    return last_change + rotation_period < datetime.utcnow()
|
189
|
+
|
190
|
+
|
179
191
|
class LoginEndpoint(LoginBaseEndpoint):
|
180
192
|
"""
|
181
193
|
Endpoint used to do the login to the cornflow webserver
|
@@ -198,11 +210,7 @@ class LoginEndpoint(LoginBaseEndpoint):
|
|
198
210
|
:rtype: Tuple(dict, integer)
|
199
211
|
"""
|
200
212
|
|
201
|
-
|
202
|
-
if int(current_app.config["OPEN_DEPLOYMENT"]) == 1:
|
203
|
-
PermissionsDAG.delete_all_permissions_from_user(content["id"])
|
204
|
-
PermissionsDAG.add_all_permissions_to_user(content["id"])
|
205
|
-
return content, status
|
213
|
+
return self.log_in(**kwargs)
|
206
214
|
|
207
215
|
|
208
216
|
class LoginOpenAuthEndpoint(LoginBaseEndpoint):
|
@@ -218,9 +226,4 @@ class LoginOpenAuthEndpoint(LoginBaseEndpoint):
|
|
218
226
|
@use_kwargs(LoginOpenAuthRequest, location="json")
|
219
227
|
def post(self, **kwargs):
|
220
228
|
""" """
|
221
|
-
|
222
|
-
content, status = self.log_in(**kwargs)
|
223
|
-
if int(current_app.config["OPEN_DEPLOYMENT"]) == 1:
|
224
|
-
PermissionsDAG.delete_all_permissions_from_user(content["id"])
|
225
|
-
PermissionsDAG.add_all_permissions_to_user(content["id"])
|
226
|
-
return content, status
|
229
|
+
return self.log_in(**kwargs)
|
cornflow/endpoints/user.py
CHANGED
@@ -170,7 +170,7 @@ class UserDetailsEndpoint(BaseMetaResource):
|
|
170
170
|
f"To edit a user, go to the OID provider.",
|
171
171
|
)
|
172
172
|
|
173
|
-
if data.get("password"):
|
173
|
+
if data.get("password") is not None:
|
174
174
|
check, msg = check_password_pattern(data.get("password"))
|
175
175
|
if not check:
|
176
176
|
raise InvalidCredentials(
|
@@ -179,7 +179,7 @@ class UserDetailsEndpoint(BaseMetaResource):
|
|
179
179
|
f"The new password is not valid.",
|
180
180
|
)
|
181
181
|
|
182
|
-
if data.get("email"):
|
182
|
+
if data.get("email") is not None:
|
183
183
|
check, msg = check_email_pattern(data.get("email"))
|
184
184
|
if not check:
|
185
185
|
raise InvalidCredentials(
|
@@ -0,0 +1,33 @@
|
|
1
|
+
"""
|
2
|
+
Added pwd_last_change column to users table
|
3
|
+
|
4
|
+
Revision ID: 991b98e24225
|
5
|
+
Revises: ebdd955fcc5e
|
6
|
+
Create Date: 2024-01-31 19:17:18.009264
|
7
|
+
|
8
|
+
"""
|
9
|
+
from alembic import op
|
10
|
+
import sqlalchemy as sa
|
11
|
+
|
12
|
+
|
13
|
+
# revision identifiers, used by Alembic.
|
14
|
+
revision = "991b98e24225"
|
15
|
+
down_revision = "ebdd955fcc5e"
|
16
|
+
branch_labels = None
|
17
|
+
depends_on = None
|
18
|
+
|
19
|
+
|
20
|
+
def upgrade():
    """Add the nullable ``pwd_last_change`` DateTime column to the ``users`` table."""
    pwd_last_change_column = sa.Column("pwd_last_change", sa.DateTime(), nullable=True)
    # Batch mode lets Alembic apply the ALTER on backends with limited ALTER support.
    with op.batch_alter_table("users", schema=None) as users_batch:
        users_batch.add_column(pwd_last_change_column)
|
26
|
+
|
27
|
+
|
28
|
+
def downgrade():
    """Drop the ``pwd_last_change`` column from the ``users`` table."""
    with op.batch_alter_table("users", schema=None) as users_batch:
        users_batch.drop_column("pwd_last_change")
|
cornflow/models/user.py
CHANGED
@@ -4,6 +4,7 @@ This file contains the UserModel
|
|
4
4
|
# Imports from external libraries
|
5
5
|
import random
|
6
6
|
import string
|
7
|
+
from datetime import datetime
|
7
8
|
|
8
9
|
# Imports from internal modules
|
9
10
|
from cornflow.models.meta_models import TraceAttributesModel
|
@@ -51,6 +52,7 @@ class UserModel(TraceAttributesModel):
|
|
51
52
|
last_name = db.Column(db.String(128), nullable=True)
|
52
53
|
username = db.Column(db.String(128), nullable=False, unique=True)
|
53
54
|
password = db.Column(db.String(128), nullable=True)
|
55
|
+
pwd_last_change = db.Column(db.DateTime, nullable=True)
|
54
56
|
email = db.Column(db.String(128), nullable=False, unique=True)
|
55
57
|
|
56
58
|
user_roles = db.relationship(
|
@@ -93,6 +95,7 @@ class UserModel(TraceAttributesModel):
|
|
93
95
|
self.first_name = data.get("first_name")
|
94
96
|
self.last_name = data.get("last_name")
|
95
97
|
self.username = data.get("username")
|
98
|
+
self.pwd_last_change = datetime.utcnow()
|
96
99
|
# TODO: handle better None passwords that can be found when using ldap
|
97
100
|
check_pass, msg = check_password_pattern(data.get("password"))
|
98
101
|
if check_pass:
|
@@ -123,6 +126,7 @@ class UserModel(TraceAttributesModel):
|
|
123
126
|
if new_password:
|
124
127
|
new_password = self.__generate_hash(new_password)
|
125
128
|
data["password"] = new_password
|
129
|
+
data["pwd_last_change"] = datetime.utcnow()
|
126
130
|
super().update(data)
|
127
131
|
|
128
132
|
def comes_from_external_provider(self):
|
cornflow/schemas/execution.py
CHANGED
@@ -4,7 +4,7 @@ from marshmallow import fields, Schema, validate
|
|
4
4
|
# Imports from internal modules
|
5
5
|
from cornflow.shared.const import MIN_EXECUTION_STATUS_CODE, MAX_EXECUTION_STATUS_CODE
|
6
6
|
from .common import QueryFilters, BaseDataEndpointResponse
|
7
|
-
from .solution_log import LogSchema
|
7
|
+
from .solution_log import LogSchema, BasicLogSchema
|
8
8
|
|
9
9
|
|
10
10
|
class QueryFiltersExecution(QueryFilters):
|
@@ -114,6 +114,12 @@ class ExecutionDetailsEndpointWithIndicatorsResponse(ExecutionDetailsEndpointRes
|
|
114
114
|
indicators = fields.Method("get_indicators")
|
115
115
|
|
116
116
|
|
117
|
+
class ExecutionDetailsWithIndicatorsAndLogResponse(ExecutionDetailsEndpointWithIndicatorsResponse):
    """
    Execution details with indicators, extended with a summary of the solver log.

    The ``log`` key is serialized from the model's ``log_json`` attribute using
    ``BasicLogSchema``.
    """

    log = fields.Nested(BasicLogSchema, attribute="log_json")
|
121
|
+
|
122
|
+
|
117
123
|
class ExecutionStatusEndpointResponse(Schema):
|
118
124
|
id = fields.Str()
|
119
125
|
state = fields.Int()
|
@@ -129,6 +135,7 @@ class ExecutionStatusEndpointUpdate(Schema):
|
|
129
135
|
class ExecutionDataEndpointResponse(ExecutionDetailsEndpointResponse):
|
130
136
|
data = fields.Raw()
|
131
137
|
checks = fields.Raw()
|
138
|
+
log = fields.Nested(BasicLogSchema, attribute="log_json")
|
132
139
|
|
133
140
|
|
134
141
|
class ExecutionLogEndpointResponse(ExecutionDetailsEndpointWithIndicatorsResponse):
|
cornflow/schemas/solution_log.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from marshmallow import fields, Schema
|
1
|
+
from marshmallow import fields, Schema, EXCLUDE
|
2
2
|
|
3
3
|
options = dict(required=True, allow_none=True)
|
4
4
|
log_options = dict(required=False, allow_none=True)
|
@@ -49,10 +49,18 @@ class FirstSolution(Schema):
|
|
49
49
|
CutsBestBound = fields.Float(**options)
|
50
50
|
|
51
51
|
|
52
|
-
class
|
52
|
+
class BasicLogSchema(Schema):
    """
    Minimal view of a solver log containing only its status fields.

    Every field is optional and may be null (``log_options``).
    """

    status = fields.String(**log_options)
    status_code = fields.Integer(**log_options)
    sol_code = fields.Integer(**log_options)
|
56
|
+
|
57
|
+
|
58
|
+
class LogSchema(BasicLogSchema):
|
59
|
+
class Meta:
|
60
|
+
unknown = EXCLUDE
|
61
|
+
|
53
62
|
version = fields.Str(**log_options)
|
54
63
|
solver = fields.Str(**log_options)
|
55
|
-
status = fields.Str(**log_options)
|
56
64
|
best_bound = fields.Float(**log_options)
|
57
65
|
best_solution = fields.Float(**log_options)
|
58
66
|
gap = fields.Float(**log_options)
|
@@ -63,8 +71,6 @@ class LogSchema(Schema):
|
|
63
71
|
presolve = fields.Nested(PresolveSchema, **log_options)
|
64
72
|
first_relaxed = fields.Float(**log_options)
|
65
73
|
first_solution = fields.Nested(FirstSolution, **log_options)
|
66
|
-
status_code = fields.Int(**log_options)
|
67
|
-
sol_code = fields.Int(**log_options)
|
68
74
|
nodes = fields.Int(**log_options)
|
69
75
|
progress = fields.Nested(ProgressSchema, required=False)
|
70
76
|
cut_info = fields.Raw(**log_options)
|
cornflow/schemas/user.py
CHANGED
@@ -25,6 +25,7 @@ class UserEndpointResponse(Schema):
|
|
25
25
|
last_name = fields.Str()
|
26
26
|
email = fields.Str()
|
27
27
|
created_at = fields.Str()
|
28
|
+
pwd_last_change = fields.Str()
|
28
29
|
|
29
30
|
|
30
31
|
class UserDetailsEndpointResponse(Schema):
|
@@ -33,6 +34,7 @@ class UserDetailsEndpointResponse(Schema):
|
|
33
34
|
last_name = fields.Str()
|
34
35
|
username = fields.Str()
|
35
36
|
email = fields.Str()
|
37
|
+
pwd_last_change = fields.Str()
|
36
38
|
|
37
39
|
|
38
40
|
class TokenEndpointResponse(Schema):
|
@@ -49,6 +51,7 @@ class UserEditRequest(Schema):
|
|
49
51
|
last_name = fields.Str(required=False)
|
50
52
|
email = fields.Str(required=False)
|
51
53
|
password = fields.Str(required=False)
|
54
|
+
pwd_last_change = fields.DateTime(required=False)
|
52
55
|
|
53
56
|
|
54
57
|
class LoginEndpointRequest(Schema):
|
cornflow/shared/licenses.py
CHANGED
@@ -1,63 +1,21 @@
|
|
1
|
-
import
|
1
|
+
import importlib.metadata as metadata
|
2
2
|
|
3
3
|
|
4
|
-
def
|
5
|
-
if pkg.has_metadata("LICENSE"):
|
6
|
-
lic = pkg.get_metadata("LICENSE")
|
7
|
-
else:
|
8
|
-
lic = "(license detail not found)"
|
9
|
-
return lic
|
10
|
-
|
11
|
-
|
12
|
-
def get_info(name, lines):
|
4
|
+
def get_info(name, pkg):
|
13
5
|
"""
|
14
|
-
Search information in a
|
6
|
+
Search information in a package metadata.
|
15
7
|
The expected format of the line is "name: info"
|
16
|
-
This function
|
8
|
+
This function searches for the name and returns the info.
|
17
9
|
|
18
|
-
:param name: name to be
|
19
|
-
:param
|
10
|
+
:param name: name to be searched.
|
11
|
+
:param pkg: a dictionary representing the package metadata.
|
20
12
|
:return: the info part of the line for the given name.
|
21
13
|
"""
|
22
|
-
|
23
|
-
|
24
|
-
if line.startswith(sep):
|
25
|
-
return line.split(sep, maxsplit=1)[1]
|
14
|
+
if name in pkg:
|
15
|
+
return pkg[name]
|
26
16
|
return f"({name} not found)"
|
27
17
|
|
28
18
|
|
29
|
-
def get_main_info(pkg):
|
30
|
-
"""
|
31
|
-
Get information from libraries.
|
32
|
-
|
33
|
-
:param pkg: a package object from pkg_resources.working_set
|
34
|
-
:return: a dict with library, license, version, author, description and home page.
|
35
|
-
"""
|
36
|
-
lines1 = []
|
37
|
-
lines2 = []
|
38
|
-
# Find info in metadata
|
39
|
-
if pkg.has_metadata("METADATA"):
|
40
|
-
lines1 = pkg.get_metadata_lines("METADATA")
|
41
|
-
# find info in PKG-INFO
|
42
|
-
if pkg.has_metadata("PKG-INFO"):
|
43
|
-
lines2 = pkg.get_metadata_lines("PKG-INFO")
|
44
|
-
# Transform lines into list
|
45
|
-
lines = [l for l in lines1] + [l for l in lines2]
|
46
|
-
|
47
|
-
# Manage case where license is UNKNOWN
|
48
|
-
lic = get_info("License", lines)
|
49
|
-
if lic == "UNKNOWN":
|
50
|
-
lic = get_info("Classifier: License :", lines)
|
51
|
-
return {
|
52
|
-
"library": get_info("Name", lines),
|
53
|
-
"license": lic,
|
54
|
-
"version": get_info("Version", lines),
|
55
|
-
"author": get_info("Author", lines),
|
56
|
-
"description": get_info("Summary", lines),
|
57
|
-
"home page": get_info("Home-page", lines),
|
58
|
-
}
|
59
|
-
|
60
|
-
|
61
19
|
def get_licenses_summary():
|
62
20
|
"""
|
63
21
|
Get a list of dicts with licenses and library information.
|
@@ -65,12 +23,17 @@ def get_licenses_summary():
|
|
65
23
|
:return: a list of dicts with library, license, version, author, description, home page and license text.
|
66
24
|
"""
|
67
25
|
license_list = []
|
68
|
-
|
69
|
-
|
26
|
+
for pkg in sorted(metadata.distributions(), key=lambda x: x.metadata['Name'].lower()):
|
27
|
+
pkg_metadata = dict(pkg.metadata.items())
|
70
28
|
license_list += [
|
71
29
|
{
|
72
|
-
|
73
|
-
"
|
30
|
+
"library": get_info("Name", pkg_metadata),
|
31
|
+
"license": get_info("License", pkg_metadata),
|
32
|
+
"version": get_info("Version", pkg_metadata),
|
33
|
+
"author": get_info("Author", pkg_metadata),
|
34
|
+
"description": get_info("Summary", pkg_metadata),
|
35
|
+
"home page": get_info("Home-page", pkg_metadata),
|
74
36
|
}
|
75
37
|
]
|
38
|
+
|
76
39
|
return license_list
|
@@ -5,6 +5,9 @@ This file contains the different custom test classes used to generalize the unit
|
|
5
5
|
# Import from libraries
|
6
6
|
import logging as log
|
7
7
|
from datetime import datetime, timedelta
|
8
|
+
|
9
|
+
from typing import List
|
10
|
+
|
8
11
|
from flask import current_app
|
9
12
|
from flask_testing import TestCase
|
10
13
|
import json
|
@@ -27,7 +30,6 @@ from cornflow.tests.const import (
|
|
27
30
|
TOKEN_URL,
|
28
31
|
)
|
29
32
|
|
30
|
-
|
31
33
|
try:
|
32
34
|
date_from_str = datetime.fromisoformat
|
33
35
|
except:
|
@@ -172,7 +174,9 @@ class CustomTestCase(TestCase):
|
|
172
174
|
self.assertEqual(getattr(row, key), payload[key])
|
173
175
|
return row.id
|
174
176
|
|
175
|
-
def get_rows(
|
177
|
+
def get_rows(
|
178
|
+
self, url, data, token=None, check_data=True, keys_to_check: List[str] = None
|
179
|
+
):
|
176
180
|
token = token or self.token
|
177
181
|
|
178
182
|
codes = [
|
@@ -187,6 +191,8 @@ class CustomTestCase(TestCase):
|
|
187
191
|
if check_data:
|
188
192
|
for i in range(len(data)):
|
189
193
|
self.assertEqual(rows_data[i]["id"], codes[i])
|
194
|
+
if keys_to_check:
|
195
|
+
self.assertCountEqual(list(rows_data[i].keys()), keys_to_check)
|
190
196
|
for key in self.get_keys_to_check(data[i]):
|
191
197
|
self.assertIn(key, rows_data[i])
|
192
198
|
if key in data[i]:
|
@@ -199,7 +205,13 @@ class CustomTestCase(TestCase):
|
|
199
205
|
return payload.keys()
|
200
206
|
|
201
207
|
def get_one_row(
|
202
|
-
self,
|
208
|
+
self,
|
209
|
+
url,
|
210
|
+
payload,
|
211
|
+
expected_status=200,
|
212
|
+
check_payload=True,
|
213
|
+
token=None,
|
214
|
+
keys_to_check: List[str] = None,
|
203
215
|
):
|
204
216
|
token = token or self.token
|
205
217
|
|
@@ -210,6 +222,8 @@ class CustomTestCase(TestCase):
|
|
210
222
|
self.assertEqual(expected_status, row.status_code)
|
211
223
|
if not check_payload:
|
212
224
|
return row.json
|
225
|
+
if keys_to_check:
|
226
|
+
self.assertCountEqual(list(row.json.keys()), keys_to_check)
|
213
227
|
self.assertEqual(row.json["id"], payload["id"])
|
214
228
|
for key in self.get_keys_to_check(payload):
|
215
229
|
self.assertIn(key, row.json)
|
@@ -35,6 +35,24 @@ class TestCornflowClientBasic(CustomTestCaseLive):
|
|
35
35
|
super().setUp()
|
36
36
|
self.items_to_check = ["name", "description"]
|
37
37
|
|
38
|
+
def check_status_evolution(
    self, execution, end_state=EXEC_STATE_CORRECT, max_polls=100, poll_interval=1
):
    """
    Poll an execution until it reaches ``end_state``, then check that it went
    through the queued and running states, in that order, before ending.

    :param dict execution: the execution returned by the client; its ``id`` and
      initial ``state`` are read.
    :param int end_state: the state the execution is expected to finish in.
    :param int max_polls: maximum number of collected states (the initial one
      included) before polling stops.
    :param float poll_interval: seconds to wait between two status requests.
    """
    statuses = [execution["state"]]
    while end_state not in statuses and len(statuses) < max_polls:
        time.sleep(poll_interval)
        status = self.client.get_status(execution["id"])
        statuses.append(status["state"])

    # Check the final state first so a polling timeout is reported as such,
    # instead of surfacing as a missing intermediate state.
    self.assertIn(
        end_state,
        statuses,
        f"Execution did not reach state {end_state}; observed states: {statuses}",
    )
    self.assertIn(EXEC_STATE_QUEUED, statuses)
    self.assertIn(EXEC_STATE_RUNNING, statuses)

    queued_idx = statuses.index(EXEC_STATE_QUEUED)
    running_idx = statuses.index(EXEC_STATE_RUNNING)
    end_state_idx = statuses.index(end_state)

    self.assertLess(queued_idx, running_idx)
    self.assertLess(running_idx, end_state_idx)
|
55
|
+
|
38
56
|
def create_new_instance_file(self, mps_file):
|
39
57
|
name = "test_instance1"
|
40
58
|
description = "description123"
|
@@ -141,7 +159,6 @@ class TestCornflowClientBasic(CustomTestCaseLive):
|
|
141
159
|
|
142
160
|
|
143
161
|
class TestCornflowClientOpen(TestCornflowClientBasic):
|
144
|
-
|
145
162
|
# TODO: user management
|
146
163
|
# TODO: infeasible execution
|
147
164
|
|
@@ -242,7 +259,6 @@ class TestCornflowClientOpen(TestCornflowClientBasic):
|
|
242
259
|
self.client.create_instance(**payload)
|
243
260
|
|
244
261
|
def test_new_instance_with_schema_good(self):
|
245
|
-
|
246
262
|
payload = load_file(INSTANCE_PATH)
|
247
263
|
payload["schema"] = "solve_model_dag"
|
248
264
|
self.create_new_instance_payload(payload)
|
@@ -335,23 +351,13 @@ class TestCornflowClientAdmin(TestCornflowClientBasic):
|
|
335
351
|
|
336
352
|
def test_status_solving(self):
|
337
353
|
execution = self.create_instance_and_execution()
|
338
|
-
|
339
|
-
status = self.client.get_status(execution["id"])
|
340
|
-
self.assertEqual(status["state"], EXEC_STATE_CORRECT)
|
354
|
+
self.check_status_evolution(execution, EXEC_STATE_CORRECT)
|
341
355
|
|
342
356
|
def test_status_solving_timer(self):
|
343
357
|
execution = self.create_timer_instance_and_execution(10)
|
344
|
-
|
345
|
-
self.assertEqual(status["state"], EXEC_STATE_QUEUED)
|
346
|
-
time.sleep(5)
|
347
|
-
status = self.client.get_status(execution["id"])
|
348
|
-
self.assertEqual(status["state"], EXEC_STATE_RUNNING)
|
349
|
-
time.sleep(12)
|
350
|
-
status = self.client.get_status(execution["id"])
|
351
|
-
self.assertEqual(status["state"], EXEC_STATE_CORRECT)
|
358
|
+
self.check_status_evolution(execution, EXEC_STATE_CORRECT)
|
352
359
|
|
353
360
|
def test_manual_execution(self):
|
354
|
-
|
355
361
|
instance_payload = load_file(INSTANCE_PATH)
|
356
362
|
one_instance = self.create_new_instance_payload(instance_payload)
|
357
363
|
name = "test_execution_name_123"
|