cornflow 1.2.1-py3-none-any.whl → 1.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. cornflow/app.py +4 -2
  2. cornflow/cli/__init__.py +4 -0
  3. cornflow/cli/actions.py +4 -0
  4. cornflow/cli/config.py +4 -0
  5. cornflow/cli/migrations.py +13 -8
  6. cornflow/cli/permissions.py +4 -0
  7. cornflow/cli/roles.py +5 -1
  8. cornflow/cli/schemas.py +5 -0
  9. cornflow/cli/service.py +263 -131
  10. cornflow/cli/tools/api_generator.py +13 -10
  11. cornflow/cli/tools/endpoint_tools.py +191 -196
  12. cornflow/cli/tools/models_tools.py +87 -60
  13. cornflow/cli/tools/schema_generator.py +161 -67
  14. cornflow/cli/tools/schemas_tools.py +4 -5
  15. cornflow/cli/users.py +8 -0
  16. cornflow/cli/views.py +4 -0
  17. cornflow/commands/access.py +14 -3
  18. cornflow/commands/auxiliar.py +106 -0
  19. cornflow/commands/dag.py +3 -2
  20. cornflow/commands/permissions.py +186 -81
  21. cornflow/commands/roles.py +15 -14
  22. cornflow/commands/schemas.py +6 -4
  23. cornflow/commands/users.py +12 -17
  24. cornflow/commands/views.py +171 -41
  25. cornflow/endpoints/dag.py +27 -25
  26. cornflow/endpoints/data_check.py +128 -165
  27. cornflow/endpoints/example_data.py +9 -3
  28. cornflow/endpoints/execution.py +40 -34
  29. cornflow/endpoints/health.py +7 -7
  30. cornflow/endpoints/instance.py +39 -12
  31. cornflow/endpoints/meta_resource.py +4 -5
  32. cornflow/schemas/execution.py +9 -1
  33. cornflow/schemas/health.py +1 -0
  34. cornflow/shared/authentication/auth.py +76 -45
  35. cornflow/shared/const.py +10 -1
  36. cornflow/shared/exceptions.py +3 -1
  37. cornflow/shared/utils_tables.py +36 -8
  38. cornflow/shared/validators.py +1 -1
  39. cornflow/tests/const.py +1 -0
  40. cornflow/tests/custom_test_case.py +4 -4
  41. cornflow/tests/unit/test_alarms.py +1 -2
  42. cornflow/tests/unit/test_cases.py +4 -7
  43. cornflow/tests/unit/test_executions.py +22 -1
  44. cornflow/tests/unit/test_external_role_creation.py +785 -0
  45. cornflow/tests/unit/test_health.py +4 -1
  46. cornflow/tests/unit/test_log_in.py +46 -9
  47. cornflow/tests/unit/test_tables.py +3 -3
  48. {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/METADATA +2 -2
  49. {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/RECORD +52 -50
  50. {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/WHEEL +1 -1
  51. {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/entry_points.txt +0 -0
  52. {cornflow-1.2.1.dist-info → cornflow-1.2.3.dist-info}/top_level.txt +0 -0
cornflow/endpoints/health.py CHANGED
@@ -4,16 +4,15 @@ It performs a health check to airflow and a health check to cornflow database
 """
 import os
 
-# Import from libraries
-from cornflow_client.airflow.api import Airflow
-from flask import current_app
-from flask_apispec import marshal_with, doc
-
 # Import from internal modules
 from cornflow.endpoints.meta_resource import BaseMetaResource
 from cornflow.models import UserModel
 from cornflow.schemas.health import HealthResponse
-from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY
+from cornflow.shared.const import STATUS_HEALTHY, STATUS_UNHEALTHY, CORNFLOW_VERSION
+# Import from libraries
+from cornflow_client.airflow.api import Airflow
+from flask import current_app
+from flask_apispec import marshal_with, doc
 
 
 class HealthEndpoint(BaseMetaResource):
@@ -30,6 +29,7 @@ class HealthEndpoint(BaseMetaResource):
         af_client = Airflow.from_config(current_app.config)
         airflow_status = STATUS_UNHEALTHY
         cornflow_status = STATUS_UNHEALTHY
+        cornflow_version = CORNFLOW_VERSION
         if af_client.is_alive():
             airflow_status = STATUS_HEALTHY
 
@@ -42,4 +42,4 @@ class HealthEndpoint(BaseMetaResource):
         current_app.logger.info(
             f"Health check: cornflow {cornflow_status}, airflow {airflow_status}"
         )
-        return {"cornflow_status": cornflow_status, "airflow_status": airflow_status}
+        return {"cornflow_status": cornflow_status, "airflow_status": airflow_status, "cornflow_version": cornflow_version}
cornflow/endpoints/instance.py CHANGED
@@ -34,7 +34,6 @@ from cornflow.shared.exceptions import InvalidUsage, InvalidData
 from cornflow.shared.validators import json_schema_validate_as_string
 
 
-
 # Initialize the schema that all endpoints are going to use
 ALLOWED_EXTENSIONS = {"mps", "lp"}
 
@@ -92,14 +91,18 @@ class InstanceEndpoint(BaseMetaResource):
         # We validate the instance data
         config = current_app.config
 
-        instance_schema = DeployedDAG.get_one_schema(config, data_schema, INSTANCE_SCHEMA)
-        instance_errors = json_schema_validate_as_string(instance_schema, kwargs["data"])
+        instance_schema = DeployedDAG.get_one_schema(
+            config, data_schema, INSTANCE_SCHEMA
+        )
+        instance_errors = json_schema_validate_as_string(
+            instance_schema, kwargs["data"]
+        )
 
         if instance_errors:
             raise InvalidData(
                 payload=dict(jsonschema_errors=instance_errors),
                 log_txt=f"Error while user {self.get_user()} tries to create an instance. "
-                f"Instance data do not match the jsonschema.",
+                f"Instance data do not match the jsonschema.",
             )
 
         # if we're here, we validated and the data seems to fit the schema
@@ -163,14 +166,18 @@ class InstanceDetailsEndpoint(InstanceDetailsEndpointBase):
 
         config = current_app.config
 
-        instance_schema = DeployedDAG.get_one_schema(config, schema, INSTANCE_SCHEMA)
-        instance_errors = json_schema_validate_as_string(instance_schema, kwargs["data"])
+        instance_schema = DeployedDAG.get_one_schema(
+            config, schema, INSTANCE_SCHEMA
+        )
+        instance_errors = json_schema_validate_as_string(
+            instance_schema, kwargs["data"]
+        )
 
         if instance_errors:
             raise InvalidData(
                 payload=dict(jsonschema_errors=instance_errors),
                 log_txt=f"Error while user {self.get_user()} tries to create an instance. "
-                f"Instance data do not match the jsonschema.",
+                f"Instance data do not match the jsonschema.",
             )
 
         response = self.put_detail(data=kwargs, user=self.get_user(), idx=idx)
@@ -268,15 +275,35 @@ class InstanceFileEndpoint(BaseMetaResource):
         sense = 1 if minimize else -1
         try:
             _vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
-        except:
+        except FileNotFoundError as e:
+            # Handle file not found specifically
             raise InvalidUsage(
-                error="There was an error reading the file",
-                log_txt=f"Error while user {self.get_user()} tries to create instance from mps file. "
-                f"There was an error reading the file.",
+                error=f"MPS file not found: {filename}",
+                log_txt=f"Error for user {self.get_user()}: MPS file '{filename}' not found. Details: {e}",
+                status_code=404,
+            ) from e
+        except PermissionError as e:
+            # Handle permission issues
+            raise InvalidUsage(
+                error=f"Permission denied reading MPS file: {filename}",
+                log_txt=f"Error for user {self.get_user()}: Permission denied for MPS file '{filename}'. Details: {e}",
+                status_code=403,
+            ) from e
+        except (ValueError, pulp.PulpError, OSError, IndexError) as e:
+            # Catch parsing errors, PuLP errors, and other IO errors
+            # Handle parsing, PuLP, or other OS errors
+            current_app.logger.error(
+                f"Error parsing MPS file {filename} for user {self.get_user()}: {e}",
+                exc_info=True,
             )
+            raise InvalidUsage(
+                error="Error reading or parsing the MPS file.",
+                log_txt=f"Error while user {self.get_user()} tries to create instance from MPS file {filename}. Details: {e}",
+            ) from e
+
         try:
             os.remove(filename)
-        except:
+        except FileNotFoundError:
            pass
 
         pb_data = dict(
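Note: the bare except clauses around MPS parsing were replaced with specific exception types chained via "from e". A standalone sketch of the same pattern follows; it uses a plain ValueError in place of the endpoint's InvalidUsage and omits the user/logging context, so it is illustrative only.

# Illustrative sketch of the narrowed exception handling; not the real endpoint.
import os
import pulp

def read_mps(filename, sense=1):
    try:
        _vars, problem = pulp.LpProblem.fromMPS(filename, sense=sense)
    except FileNotFoundError as e:
        # Chaining with "from e" keeps the root cause visible in the traceback.
        raise ValueError(f"MPS file not found: {filename}") from e
    except (pulp.PulpError, OSError, IndexError) as e:
        raise ValueError(f"Error reading or parsing MPS file: {filename}") from e
    # Clean up the uploaded file; ignore the case where it is already gone.
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass
    return problem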
cornflow/endpoints/meta_resource.py CHANGED
@@ -1,6 +1,7 @@
 """
 This file has all the logic shared for all the resources
 """
+
 # Import from external libraries
 from flask_restful import Resource
 from flask import g, request
@@ -13,7 +14,6 @@ from cornflow.shared.const import ALL_DEFAULT_ROLES
 from cornflow.shared.exceptions import InvalidUsage, ObjectDoesNotExist, NoPermission
 
 
-
 class BaseMetaResource(Resource, MethodResource):
     """
     The base resource from all methods inherit from.
@@ -30,7 +30,6 @@ class BaseMetaResource(Resource, MethodResource):
         self.auth_class = None
         self.dependents = None
         self.unique = None
-        pass
 
     """
     METHODS USED FOR THE BASIC CRUD OPERATIONS: GET, POST, PUT, PATCH, DELETE
@@ -121,7 +120,7 @@ class BaseMetaResource(Resource, MethodResource):
         """
         item = self.data_model.get_one_object(**kwargs)
         if item is None:
-            raise ObjectDoesNotExist("The data entity does not exist on the database")
+            raise ObjectDoesNotExist()
 
         data = dict(data)
 
@@ -144,7 +143,7 @@ class BaseMetaResource(Resource, MethodResource):
         item = self.data_model.get_one_object(**kwargs)
 
         if item is None:
-            raise ObjectDoesNotExist("The data entity does not exist on the database")
+            raise ObjectDoesNotExist()
 
         data = dict(data)
 
@@ -164,7 +163,7 @@ class BaseMetaResource(Resource, MethodResource):
         """
         item = self.data_model.get_one_object(**kwargs)
         if item is None:
-            raise ObjectDoesNotExist("The data entity does not exist on the database")
+            raise ObjectDoesNotExist()
         if self.dependents is not None:
             for element in getattr(item, self.dependents):
                 element.delete()
cornflow/schemas/execution.py CHANGED
@@ -95,7 +95,7 @@ class ExecutionDagPostRequest(ExecutionRequest, ExecutionDagRequest):
 
 
 class ExecutionDetailsEndpointResponse(BaseDataEndpointResponse):
-    config = fields.Nested(ConfigSchemaResponse)
+    config = fields.Raw()
     instance_id = fields.Str()
     state = fields.Int()
     message = fields.Str(attribute="state_message")
@@ -112,6 +112,14 @@ class ExecutionDetailsEndpointWithIndicatorsResponse(ExecutionDetailsEndpointRes
         return indicators_string[1:-1]
 
     indicators = fields.Method("get_indicators")
+    updated_at = fields.DateTime(dump_only=True)
+
+    def get_username(self, obj):
+        if hasattr(obj, "user") and obj.user is not None:
+            return obj.user.username
+        return None
+
+    username = fields.Method("get_username")
 
 
 class ExecutionDetailsWithIndicatorsAndLogResponse(
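Note: the execution response now exposes "username" through a marshmallow fields.Method and "updated_at" as a dump-only datetime. A self-contained sketch of that pattern follows; the schema and object names are hypothetical, not cornflow classes.

# Hypothetical ItemSchema illustrating fields.Method and dump_only serialization.
from datetime import datetime, timezone
from types import SimpleNamespace
from marshmallow import Schema, fields

class ItemSchema(Schema):
    updated_at = fields.DateTime(dump_only=True)  # only serialized, never loaded
    username = fields.Method("get_username")

    def get_username(self, obj):
        # Guard against a missing or detached user relationship.
        if getattr(obj, "user", None) is not None:
            return obj.user.username
        return None

item = SimpleNamespace(
    user=SimpleNamespace(username="ada"),
    updated_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
)
print(ItemSchema().dump(item))
# e.g. {'updated_at': '2025-01-01T00:00:00+00:00', 'username': 'ada'}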
cornflow/schemas/health.py CHANGED
@@ -5,3 +5,4 @@ class HealthResponse(Schema):
 
     cornflow_status = fields.Str()
     airflow_status = fields.Str()
+    cornflow_version = fields.Str()
cornflow/shared/authentication/auth.py CHANGED
@@ -2,15 +2,16 @@
 This file contains the auth class that can be used for authentication on the request to the REST API
 """
 
-# Imports from external libraries
-import jwt
-import requests
-from jwt.algorithms import RSAAlgorithm
 from datetime import datetime, timedelta, timezone
-from flask import request, g, current_app, Request
 from functools import wraps
 from typing import Tuple
+
+# Imports from external libraries
+import jwt
+import requests
 from cachetools import TTLCache
+from flask import request, g, current_app, Request
+from jwt.algorithms import RSAAlgorithm
 from werkzeug.datastructures import Headers
 
 # Imports from internal modules
@@ -31,7 +32,6 @@ from cornflow.shared.exceptions import (
     InvalidData,
     InvalidUsage,
     NoPermission,
-    ObjectDoesNotExist,
 )
 
 # Cache for storing public keys with 1 hour TTL
@@ -122,15 +122,13 @@ class Auth:
     def decode_token(token: str = None) -> dict:
         """
         Decodes a given JSON Web token and extracts the username from the sub claim.
-        Works with both internal tokens and OpenID tokens.
+        Works with both internal tokens and OpenID tokens by attempting verification methods sequentially.
 
         :param str token: the given JSON Web Token
         :return: dictionary containing the username from the token's sub claim
        :rtype: dict
         """
-
         if token is None:
-
             raise InvalidCredentials(
                 "Must provide a token in Authorization header",
                 log_txt="Error while trying to decode token. Token is missing.",
@@ -138,48 +136,81 @@ class Auth:
             )
 
         try:
-            # First try to decode header to validate basic token structure
-
-            unverified_payload = jwt.decode(token, options={"verify_signature": False})
-            issuer = unverified_payload.get("iss")
-
-            # For internal tokens
-            if issuer == INTERNAL_TOKEN_ISSUER:
-
-                return jwt.decode(
-                    token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
-                )
-
-            # For OpenID tokens
-            if current_app.config["AUTH_TYPE"] == AUTH_OID:
-
-                return Auth().verify_token(
-                    token,
-                    current_app.config["OID_PROVIDER"],
-                    current_app.config["OID_EXPECTED_AUDIENCE"],
-                )
-
-            # If we get here, the issuer is not valid
-
-            raise InvalidCredentials(
-                "Invalid token issuer. Token must be issued by a valid provider",
-                log_txt="Error while trying to decode token. Invalid issuer.",
-                status_code=400,
+            # Attempt 1: Verify as an internal token (HS256)
+            payload = jwt.decode(
+                token, current_app.config["SECRET_TOKEN_KEY"], algorithms=["HS256"]
             )
+            if payload.get("iss") != INTERNAL_TOKEN_ISSUER:
+                raise jwt.InvalidIssuerError(
+                    "Internal token issuer mismatch after verification"
+                )
+            return payload
 
         except jwt.ExpiredSignatureError:
-
+            # Handle expiration specifically, could apply to either token type if caught here first
             raise InvalidCredentials(
                 "The token has expired, please login again",
                 log_txt="Error while trying to decode token. The token has expired.",
                 status_code=400,
             )
-        except jwt.InvalidTokenError as e:
-
+        except (
+            jwt.InvalidSignatureError,
+            jwt.DecodeError,
+            jwt.InvalidTokenError,
+        ) as e_internal:
+            # Internal verification failed (signature, format, etc.). Try OIDC if configured.
+            if current_app.config["AUTH_TYPE"] == AUTH_OID:
+                try:
+                    # Attempt 2: Verify as an OIDC token (RS256) using the dedicated method
+                    return Auth().verify_token(
+                        token,
+                        current_app.config["OID_PROVIDER"],
+                        current_app.config["OID_EXPECTED_AUDIENCE"],
+                    )
+                except jwt.ExpiredSignatureError:
+                    # OIDC token expired
+                    raise InvalidCredentials(
+                        "The token has expired, please login again",
+                        log_txt="Error while trying to decode OIDC token. The token has expired.",
+                        status_code=400,
+                    )
+                except (
+                    jwt.InvalidTokenError,
+                    InvalidCredentials,
+                    CommunicationError,
+                ) as e_oidc:
+                    # OIDC verification failed (JWT format, signature, kid, audience, issuer, comms error)
+                    # Log details for debugging but return a generic error to the client.
+                    log_message = (
+                        f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
+                        f"OIDC verification failed ({type(e_oidc).__name__}: {str(e_oidc)})."
+                    )
+                    current_app.logger.warning(log_message)
+                    raise InvalidCredentials(
+                        "Invalid token format, signature, or configuration",
+                        log_txt=log_message,
+                        status_code=400,
+                    )
+            else:
+                # Internal verification failed, and OIDC is not configured
+                log_message = (
+                    f"Error decoding token. Internal verification failed ({type(e_internal).__name__}). "
+                    f"OIDC is not configured."
+                )
+                current_app.logger.warning(log_message)
+                raise InvalidCredentials(
+                    "Invalid token format or signature",
+                    log_txt=log_message,
+                    status_code=400,
+                )
+        except Exception as e:
+            # Catch any other unexpected errors during the process
+            log_message = f"Unexpected error during token decoding: {str(e)}"
+            current_app.logger.error(log_message)
             raise InvalidCredentials(
-                "Invalid token format or signature",
-                log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
-                status_code=400,
+                "Could not decode or verify token due to an unexpected server error",
+                log_txt=log_message,
+                status_code=500,
             )
 
     def get_token_from_header(self, headers: Headers = None) -> str:
@@ -365,10 +396,10 @@ class Auth:
                 "The permission for this endpoint is not in the database."
             )
             raise NoPermission(
-                error="You do not have permission to access this endpoint",
+                error="The permission for this endpoint is not in the database.",
                 status_code=403,
                 log_txt=f"Error while user {user_id} tries to access endpoint. "
-                f"The user does not permission to access. ",
+                f"The permission for this endpoint is not in the database.",
             )
 
         for role in user_roles:
@@ -383,7 +414,7 @@ class Auth:
                 error="You do not have permission to access this endpoint",
                 status_code=403,
                 log_txt=f"Error while user {user_id} tries to access endpoint {view_id} with action {action_id}. "
-                f"The user does not permission to access. ",
+                f"The user does not have permission to access. ",
             )
 
     @staticmethod
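Note: decode_token no longer branches on an unverified "iss" claim; it verifies the internal HS256 signature first and only then, if OIDC is configured, falls back to the provider flow. A condensed, standalone sketch of that control flow follows (PyJWT; the secret and the oidc_verify callable stand in for cornflow's config lookups and Auth().verify_token).

# Condensed sketch of sequential token verification; names are illustrative.
import jwt

INTERNAL_ISSUER = "cornflow"

def decode_token(token, secret, oidc_verify=None):
    try:
        # Attempt 1: internal token, verified with HS256 and the shared secret.
        payload = jwt.decode(token, secret, algorithms=["HS256"])
        if payload.get("iss") != INTERNAL_ISSUER:
            raise jwt.InvalidIssuerError("issuer mismatch")
        return payload
    except jwt.ExpiredSignatureError:
        # Expiry is reported as such, not retried against the OIDC provider.
        raise PermissionError("token expired")
    except jwt.InvalidTokenError:
        # Attempt 2: fall back to the OIDC verifier when one is configured.
        if oidc_verify is None:
            raise PermissionError("invalid token and OIDC is not configured")
        return oidc_verify(token)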
cornflow/shared/const.py CHANGED
@@ -1,7 +1,7 @@
 """
 In this file we import the values for different constants on cornflow server
 """
-
+CORNFLOW_VERSION = "1.2.3"
 INTERNAL_TOKEN_ISSUER = "cornflow"
 
 # endpoints responses for health check
@@ -112,3 +112,12 @@ BASE_PERMISSION_ASSIGNATION = [
 EXTRA_PERMISSION_ASSIGNATION = [
     (VIEWER_ROLE, PUT_ACTION, "user-detail"),
 ]
+
+# migrations constants
+MIGRATIONS_DEFAULT_PATH = "./cornflow/migrations"
+
+# Costants for messages that are given back on exceptions
+AIRFLOW_NOT_REACHABLE_MSG = "Airflow is not reachable"
+DAG_PAUSED_MSG = "The dag exists but it is paused in airflow"
+AIRFLOW_ERROR_MSG = "Airflow responded with an error:"
+DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"
cornflow/shared/exceptions.py CHANGED
@@ -9,6 +9,8 @@ from cornflow_client.constants import AirflowError
 from werkzeug.exceptions import HTTPException
 import traceback
 
+from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
+
 
 class InvalidUsage(Exception):
     """
@@ -48,7 +50,7 @@ class ObjectDoesNotExist(InvalidUsage):
     """
 
     status_code = 404
-    error = "The object does not exist"
+    error = DATA_DOES_NOT_EXIST_MSG
 
 
 class ObjectAlreadyExists(InvalidUsage):
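Note: the default message of ObjectDoesNotExist now comes from the shared DATA_DOES_NOT_EXIST_MSG constant, which is also what the tests assert against. A tiny sketch of the class-attribute default pattern follows; the simplified base class here is illustrative, not cornflow's full InvalidUsage.

# Simplified sketch; cornflow's InvalidUsage also carries payload, log_txt, etc.
DATA_DOES_NOT_EXIST_MSG = "The data entity does not exist on the database"

class InvalidUsage(Exception):
    status_code = 400
    error = "Unknown error"

class ObjectDoesNotExist(InvalidUsage):
    status_code = 404
    error = DATA_DOES_NOT_EXIST_MSG  # single source of truth shared with the tests

print(ObjectDoesNotExist.error)  # The data entity does not exist on the database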
cornflow/shared/utils_tables.py CHANGED
@@ -2,6 +2,7 @@
 import inspect
 import os
 import sys
+import collections
 
 from importlib import import_module
 from sqlalchemy.dialects.postgresql import TEXT
@@ -31,14 +32,42 @@ def import_models():
 
 
 def all_subclasses(cls, models=None):
-    subclasses = set()
+    """Finds all direct and indirect subclasses of a given class.
+
+    Optionally filters a provided list of models first.
+
+    Args:
+        cls: The base class to find subclasses for.
+        models: An optional iterable of classes to pre-filter.
+
+    Returns:
+        A set containing all subclasses found.
+    """
+    filtered_subclasses = set()
     if models is not None:
         for val in models:
-            if issubclass(val, cls):
-                subclasses.add(val)
+            # Ensure val is a class before checking issubclass
+            if isinstance(val, type) and issubclass(val, cls):
+                filtered_subclasses.add(val)
+
+    all_descendants = set()
+    # Use a deque for efficient pop(0)
+    queue = collections.deque(cls.__subclasses__())
+    # Keep track of visited classes during the traversal to handle potential complex hierarchies
+    # (though direct subclass relationships shouldn't form cycles)
+    # Initialize with direct subclasses as they are the starting point.
+    visited_for_queue = set(cls.__subclasses__())
+
+    while queue:
+        current_sub = queue.popleft()
+        all_descendants.add(current_sub)
+
+        for grandchild in current_sub.__subclasses__():
+            if grandchild not in visited_for_queue:
+                visited_for_queue.add(grandchild)
+                queue.append(grandchild)
 
-    return subclasses.union(set(cls.__subclasses__()).union(
-        [s for c in cls.__subclasses__() for s in all_subclasses(c)]))
+    return filtered_subclasses.union(all_descendants)
 
 
 type_converter = {
@@ -76,6 +105,5 @@ def item_as_dict(item):
 
 def items_as_dict_list(ls):
     return [
-        {c.name: getattr(item, c.name) for c in item.__table__.columns}
-        for item in ls
-    ]
+        {c.name: getattr(item, c.name) for c in item.__table__.columns} for item in ls
+    ]
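Note: all_subclasses now walks the class hierarchy iteratively, a deque-based breadth-first traversal with a visited set, instead of recursing. A standalone sketch of that traversal follows; the function and class names are illustrative, not the cornflow function itself.

# Standalone sketch of the breadth-first subclass walk; names are illustrative.
import collections

def walk_subclasses(cls):
    queue = collections.deque(cls.__subclasses__())
    seen = set(queue)  # doubles as the result: every queued class is a descendant
    while queue:
        current = queue.popleft()
        for child in current.__subclasses__():
            if child not in seen:
                seen.add(child)
                queue.append(child)
    return seen

class Base: ...
class A(Base): ...
class B(A): ...

print(walk_subclasses(Base) == {A, B})  # True: indirect subclasses are included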
cornflow/shared/validators.py CHANGED
@@ -1,12 +1,12 @@
 """
 This file has several validators
 """
+
 import re
 from typing import Tuple, Union
 
 from jsonschema import Draft7Validator, validators
 from disposable_email_domains import blocklist
-from jsonschema.protocols import Validator
 
 
 def is_special_character(character):
cornflow/tests/const.py CHANGED
@@ -17,6 +17,7 @@ INSTANCE_GC_20 = _get_file("./data/gc_20_7.json")
 INSTANCE_FILE_FAIL = _get_file("./unit/test_instances.py")
 
 EXECUTION_PATH = _get_file("./data/new_execution.json")
+CUSTOM_CONFIG_PATH = _get_file("./data/new_execution_custom_config.json")
 BAD_EXECUTION_PATH = _get_file("./data/bad_execution.json")
 EXECUTION_SOLUTION_PATH = _get_file("./data/new_execution_solution.json")
 EXECUTIONS_LIST = [EXECUTION_PATH, _get_file("./data/new_execution_2.json")]
cornflow/tests/custom_test_case.py CHANGED
@@ -635,9 +635,9 @@ class BaseTestCases:
         # (we patch the request to airflow to check if the schema is valid)
         # we create 4 instances
         data_many = [self.payload for _ in range(4)]
-        data_many[-1] = {**data_many[-1], **dict(schema="timer")}
-        allrows = self.get_rows(self.url, data_many)
-        self.apply_filter(self.url, dict(schema="timer"), allrows.json[:1])
+
+        self.get_rows(self.url, data_many)
+        self.apply_filter(self.url, dict(schema="timer"), [])
 
     def test_opt_filters_date_lte(self):
         """
@@ -1117,7 +1117,7 @@ class LoginTestCases:
 
         self.assertEqual(400, response.status_code)
         self.assertEqual(
-            "Invalid token issuer. Token must be issued by a valid provider",
+            "Invalid token format or signature",
             response.json["error"],
         )
 
cornflow/tests/unit/test_alarms.py CHANGED
@@ -91,7 +91,6 @@ class TestAlarmsEndpoint(CustomTestCase):
 class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint):
     def setUp(self):
         super().setUp()
-        self.url = self.url
         self.idx = 0
         self.payload = {
             "name": "Alarm 1",
@@ -138,4 +137,4 @@ class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint)
             # We check deleted at has a value
             self.assertIsNotNone(row.deleted_at)
         else:
-            self.assertIsNone(row.deleted_at)
+            self.assertIsNone(row.deleted_at)
cornflow/tests/unit/test_cases.py CHANGED
@@ -42,6 +42,7 @@ import zlib
 
 # Import from internal modules
 from cornflow.models import CaseModel, ExecutionModel, InstanceModel, UserModel
+from cornflow.shared.const import DATA_DOES_NOT_EXIST_MSG
 from cornflow.shared.utils import hash_json_256
 from cornflow.tests.const import (
     INSTANCE_URL,
@@ -227,12 +228,8 @@ class TestCasesFromInstanceExecutionEndpoint(CustomTestCase):
             "execution_id": execution_id,
             "schema": "solve_model_dag",
         }
-        self.instance = InstanceModel.get_one_object(
-            user=self.user, idx=instance_id
-        )
-        self.execution = ExecutionModel.get_one_object(
-            user=self.user, idx=execution_id
-        )
+        self.instance = InstanceModel.get_one_object(user=self.user, idx=instance_id)
+        self.execution = ExecutionModel.get_one_object(user=self.user, idx=execution_id)
 
     def test_new_case_execution(self):
         """
@@ -729,7 +726,7 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
             headers=self.get_header_with_auth(self.token),
         )
         self.assertEqual(response.status_code, 404)
-        self.assertEqual(response.json["error"], "The object does not exist")
+        self.assertEqual(response.json["error"], DATA_DOES_NOT_EXIST_MSG)
 
 
 class TestCaseJsonPatch(CustomTestCase):
cornflow/tests/unit/test_executions.py CHANGED
@@ -18,6 +18,7 @@ from cornflow.tests.const import (
     DAG_URL,
     BAD_EXECUTION_PATH,
     EXECUTION_SOLUTION_PATH,
+    CUSTOM_CONFIG_PATH,
 )
 from cornflow.tests.custom_test_case import CustomTestCase, BaseTestCases
 from cornflow.tests.unit.tools import patch_af_client
@@ -43,6 +44,7 @@ class TestExecutionsListEndpoint(BaseTestCases.ListFilters):
         self.bad_payload = load_file_fk(BAD_EXECUTION_PATH)
         self.payloads = [load_file_fk(f) for f in EXECUTIONS_LIST]
         self.solution = load_file_fk(EXECUTION_SOLUTION_PATH)
+        self.custom_config_payload = load_file_fk(CUSTOM_CONFIG_PATH)
         self.keys_to_check = [
             "data_hash",
             "created_at",
@@ -57,11 +59,25 @@ class TestExecutionsListEndpoint(BaseTestCases.ListFilters):
             "instance_id",
             "name",
             "indicators",
+            "username",
+            "updated_at"
         ]
 
     def test_new_execution(self):
         self.create_new_row(self.url, self.model, payload=self.payload)
 
+    def test_get_custom_config(self):
+        id = self.create_new_row(
+            self.url, self.model, payload=self.custom_config_payload
+        )
+        url = EXECUTION_URL + "/" + str(id) + "/" + "?run=0"
+
+        response = self.get_one_row(
+            url,
+            payload={**self.custom_config_payload, **dict(id=id)},
+        )
+        self.assertEqual(response["config"]["block_model"]["solver"], "mip.gurobi")
+
     @patch("cornflow.endpoints.execution.Airflow")
     def test_new_execution_run(self, af_client_class):
         patch_af_client(af_client_class)
@@ -260,6 +276,8 @@ class TestExecutionsDetailEndpointMock(CustomTestCase):
             "schema",
             "user_id",
             "indicators",
+            "username",
+            "updated_at"
         }
         # we only check the following because this endpoint does not return data
         self.items_to_check = ["name", "description"]
@@ -274,7 +292,6 @@ class TestExecutionsDetailEndpoint(
 ):
     def setUp(self):
         super().setUp()
-        self.url = self.url
         self.query_arguments = {"run": 0}
 
     # TODO: this test should be moved as it is not using the detail endpoint
@@ -303,6 +320,8 @@ class TestExecutionsDetailEndpoint(
             "name",
             "created_at",
             "state",
+            "username",
+            "updated_at"
         ]
         execution = self.get_one_row(
             self.url + idx,
@@ -450,6 +469,8 @@ class TestExecutionsLogEndpoint(TestExecutionsDetailEndpointMock):
             "user_id",
             "config",
             "indicators",
+            "username",
+            "updated_at"
         ]
 
     def test_get_one_execution(self):