cornflow 1.2.0a1__py3-none-any.whl → 1.2.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow_config/airflow_local_settings.py +1 -1
- cornflow/cli/migrations.py +23 -3
- cornflow/cli/service.py +3 -9
- cornflow/cli/utils.py +16 -1
- cornflow/config.py +1 -1
- cornflow/endpoints/__init__.py +7 -1
- cornflow/endpoints/alarms.py +66 -2
- cornflow/endpoints/login.py +59 -38
- cornflow/endpoints/meta_resource.py +11 -3
- cornflow/models/base_data_model.py +4 -32
- cornflow/models/meta_models.py +28 -22
- cornflow/models/user.py +7 -10
- cornflow/schemas/alarms.py +8 -0
- cornflow/schemas/query.py +2 -1
- cornflow/schemas/user.py +2 -3
- cornflow/shared/authentication/auth.py +19 -39
- cornflow/tests/const.py +1 -0
- cornflow/tests/custom_test_case.py +42 -12
- cornflow/tests/unit/test_alarms.py +55 -1
- cornflow/tests/unit/test_apiview.py +106 -1
- cornflow/tests/unit/test_cli.py +6 -5
- cornflow/tests/unit/test_dags.py +0 -1
- cornflow/tests/unit/test_instances.py +12 -0
- cornflow/tests/unit/test_log_in.py +8 -5
- cornflow/tests/unit/test_roles.py +38 -0
- cornflow/tests/unit/test_token.py +11 -3
- cornflow/tests/unit/test_users.py +22 -6
- {cornflow-1.2.0a1.dist-info → cornflow-1.2.0a3.dist-info}/METADATA +13 -13
- {cornflow-1.2.0a1.dist-info → cornflow-1.2.0a3.dist-info}/RECORD +32 -32
- {cornflow-1.2.0a1.dist-info → cornflow-1.2.0a3.dist-info}/WHEEL +1 -1
- {cornflow-1.2.0a1.dist-info → cornflow-1.2.0a3.dist-info}/entry_points.txt +0 -0
- {cornflow-1.2.0a1.dist-info → cornflow-1.2.0a3.dist-info}/top_level.txt +0 -0
cornflow/models/user.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
1
|
"""
|
2
2
|
This file contains the UserModel
|
3
3
|
"""
|
4
|
+
|
4
5
|
# Imports from external libraries
|
5
6
|
import random
|
6
7
|
import string
|
7
|
-
from datetime import datetime
|
8
|
+
from datetime import datetime, timezone, timedelta
|
8
9
|
|
9
10
|
# Imports from internal modules
|
10
11
|
from cornflow.models.meta_models import TraceAttributesModel
|
@@ -55,9 +56,7 @@ class UserModel(TraceAttributesModel):
|
|
55
56
|
pwd_last_change = db.Column(db.DateTime, nullable=True)
|
56
57
|
email = db.Column(db.String(128), nullable=False, unique=True)
|
57
58
|
|
58
|
-
user_roles = db.relationship(
|
59
|
-
"UserRoleModel", cascade="all,delete", backref="users"
|
60
|
-
)
|
59
|
+
user_roles = db.relationship("UserRoleModel", cascade="all,delete", backref="users")
|
61
60
|
|
62
61
|
instances = db.relationship(
|
63
62
|
"InstanceModel",
|
@@ -95,15 +94,14 @@ class UserModel(TraceAttributesModel):
|
|
95
94
|
self.first_name = data.get("first_name")
|
96
95
|
self.last_name = data.get("last_name")
|
97
96
|
self.username = data.get("username")
|
98
|
-
self.pwd_last_change = datetime.
|
97
|
+
self.pwd_last_change = datetime.now(timezone.utc)
|
99
98
|
# TODO: handle better None passwords that can be found when using ldap
|
100
99
|
check_pass, msg = check_password_pattern(data.get("password"))
|
101
100
|
if check_pass:
|
102
101
|
self.password = self.__generate_hash(data.get("password"))
|
103
102
|
else:
|
104
103
|
raise InvalidCredentials(
|
105
|
-
msg,
|
106
|
-
log_txt="Error while trying to create a new user. " + msg
|
104
|
+
msg, log_txt="Error while trying to create a new user. " + msg
|
107
105
|
)
|
108
106
|
|
109
107
|
check_email, msg = check_email_pattern(data.get("email"))
|
@@ -111,8 +109,7 @@ class UserModel(TraceAttributesModel):
|
|
111
109
|
self.email = data.get("email")
|
112
110
|
else:
|
113
111
|
raise InvalidCredentials(
|
114
|
-
msg,
|
115
|
-
log_txt="Error while trying to create a new user. " + msg
|
112
|
+
msg, log_txt="Error while trying to create a new user. " + msg
|
116
113
|
)
|
117
114
|
|
118
115
|
def update(self, data):
|
@@ -126,7 +123,7 @@ class UserModel(TraceAttributesModel):
|
|
126
123
|
if new_password:
|
127
124
|
new_password = self.__generate_hash(new_password)
|
128
125
|
data["password"] = new_password
|
129
|
-
data["pwd_last_change"] = datetime.
|
126
|
+
data["pwd_last_change"] = datetime.now(timezone.utc)
|
130
127
|
super().update(data)
|
131
128
|
|
132
129
|
def comes_from_external_provider(self):
|
cornflow/schemas/alarms.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""
|
2
2
|
This file contains the schemas used for the table alarms defined in the application None
|
3
3
|
"""
|
4
|
+
|
4
5
|
from marshmallow import fields, Schema
|
5
6
|
|
6
7
|
|
@@ -18,3 +19,10 @@ class AlarmsResponse(AlarmsPostRequest):
|
|
18
19
|
class QueryFiltersAlarms(Schema):
|
19
20
|
schema = fields.Str(required=False)
|
20
21
|
criticality = fields.Number(required=False)
|
22
|
+
|
23
|
+
|
24
|
+
class AlarmEditRequest(Schema):
|
25
|
+
name = fields.Str(required=False)
|
26
|
+
criticality = fields.Number(required=False)
|
27
|
+
description = fields.Str(required=False)
|
28
|
+
schema = fields.Str(required=False)
|
cornflow/schemas/query.py
CHANGED
@@ -1,13 +1,14 @@
|
|
1
1
|
"""
|
2
2
|
This file contains the schemas used to query the results of a GET request
|
3
3
|
"""
|
4
|
+
|
4
5
|
from marshmallow import fields, Schema
|
5
6
|
|
6
7
|
|
7
8
|
class BaseQueryFilters(Schema):
|
8
9
|
"""This is the schema of the base query arguments"""
|
9
10
|
|
10
|
-
limit = fields.Int(required=False, dump_default=
|
11
|
+
limit = fields.Int(required=False, dump_default=10)
|
11
12
|
offset = fields.Int(required=False, dump_default=0)
|
12
13
|
creation_date_gte = fields.DateTime(required=False)
|
13
14
|
creation_date_lte = fields.DateTime(required=False)
|
cornflow/schemas/user.py
CHANGED
@@ -27,7 +27,7 @@ class UserEndpointResponse(Schema):
|
|
27
27
|
last_name = fields.Str()
|
28
28
|
email = fields.Str()
|
29
29
|
created_at = fields.Str()
|
30
|
-
pwd_last_change = fields.
|
30
|
+
pwd_last_change = fields.DateTime()
|
31
31
|
|
32
32
|
|
33
33
|
class UserDetailsEndpointResponse(Schema):
|
@@ -36,7 +36,7 @@ class UserDetailsEndpointResponse(Schema):
|
|
36
36
|
last_name = fields.Str()
|
37
37
|
username = fields.Str()
|
38
38
|
email = fields.Str()
|
39
|
-
pwd_last_change = fields.
|
39
|
+
pwd_last_change = fields.DateTime()
|
40
40
|
|
41
41
|
|
42
42
|
class TokenEndpointResponse(Schema):
|
@@ -53,7 +53,6 @@ class UserEditRequest(Schema):
|
|
53
53
|
last_name = fields.Str(required=False)
|
54
54
|
email = fields.Str(required=False)
|
55
55
|
password = fields.Str(required=False)
|
56
|
-
pwd_last_change = fields.DateTime(required=False)
|
57
56
|
|
58
57
|
|
59
58
|
class LoginEndpointRequest(Schema):
|
@@ -6,7 +6,7 @@ This file contains the auth class that can be used for authentication on the req
|
|
6
6
|
import jwt
|
7
7
|
import requests
|
8
8
|
from jwt.algorithms import RSAAlgorithm
|
9
|
-
from datetime import datetime, timedelta
|
9
|
+
from datetime import datetime, timedelta, timezone
|
10
10
|
from flask import request, g, current_app, Request
|
11
11
|
from functools import wraps
|
12
12
|
from typing import Tuple
|
@@ -107,9 +107,9 @@ class Auth:
|
|
107
107
|
)
|
108
108
|
|
109
109
|
payload = {
|
110
|
-
"exp": datetime.
|
110
|
+
"exp": datetime.now(timezone.utc)
|
111
111
|
+ timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
|
112
|
-
"iat": datetime.
|
112
|
+
"iat": datetime.now(timezone.utc),
|
113
113
|
"sub": user.username,
|
114
114
|
"iss": INTERNAL_TOKEN_ISSUER,
|
115
115
|
}
|
@@ -128,9 +128,9 @@ class Auth:
|
|
128
128
|
:return: dictionary containing the username from the token's sub claim
|
129
129
|
:rtype: dict
|
130
130
|
"""
|
131
|
-
|
131
|
+
|
132
132
|
if token is None:
|
133
|
-
|
133
|
+
|
134
134
|
raise InvalidCredentials(
|
135
135
|
"Must provide a token in Authorization header",
|
136
136
|
log_txt="Error while trying to decode token. Token is missing.",
|
@@ -139,31 +139,28 @@ class Auth:
|
|
139
139
|
|
140
140
|
try:
|
141
141
|
# First try to decode header to validate basic token structure
|
142
|
-
|
142
|
+
|
143
143
|
unverified_payload = jwt.decode(token, options={"verify_signature": False})
|
144
144
|
issuer = unverified_payload.get("iss")
|
145
|
-
print(f"[decode_token] Token issuer: {issuer}")
|
146
145
|
|
147
146
|
# For internal tokens
|
148
147
|
if issuer == INTERNAL_TOKEN_ISSUER:
|
149
|
-
|
150
|
-
|
148
|
+
|
149
|
+
return jwt.decode(
|
151
150
|
token, current_app.config["SECRET_TOKEN_KEY"], algorithms="HS256"
|
152
151
|
)
|
153
|
-
return {"username": payload["sub"]}
|
154
152
|
|
155
153
|
# For OpenID tokens
|
156
154
|
if current_app.config["AUTH_TYPE"] == AUTH_OID:
|
157
|
-
|
158
|
-
|
155
|
+
|
156
|
+
return Auth().verify_token(
|
159
157
|
token,
|
160
158
|
current_app.config["OID_PROVIDER"],
|
161
159
|
current_app.config["OID_EXPECTED_AUDIENCE"],
|
162
160
|
)
|
163
|
-
return {"username": decoded["sub"]}
|
164
161
|
|
165
162
|
# If we get here, the issuer is not valid
|
166
|
-
|
163
|
+
|
167
164
|
raise InvalidCredentials(
|
168
165
|
"Invalid token issuer. Token must be issued by a valid provider",
|
169
166
|
log_txt="Error while trying to decode token. Invalid issuer.",
|
@@ -171,14 +168,14 @@ class Auth:
|
|
171
168
|
)
|
172
169
|
|
173
170
|
except jwt.ExpiredSignatureError:
|
174
|
-
|
171
|
+
|
175
172
|
raise InvalidCredentials(
|
176
173
|
"The token has expired, please login again",
|
177
174
|
log_txt="Error while trying to decode token. The token has expired.",
|
178
175
|
status_code=400,
|
179
176
|
)
|
180
177
|
except jwt.InvalidTokenError as e:
|
181
|
-
|
178
|
+
|
182
179
|
raise InvalidCredentials(
|
183
180
|
"Invalid token format or signature",
|
184
181
|
log_txt=f"Error while trying to decode token. The token format is invalid: {str(e)}",
|
@@ -251,7 +248,7 @@ class Auth:
|
|
251
248
|
token = self.get_token_from_header(headers)
|
252
249
|
data = self.decode_token(token)
|
253
250
|
|
254
|
-
user = self.user_model.get_one_object(username=data["
|
251
|
+
user = self.user_model.get_one_object(username=data["sub"])
|
255
252
|
|
256
253
|
if user is None:
|
257
254
|
err = "User not found. Please ensure you are using valid credentials"
|
@@ -306,15 +303,13 @@ class Auth:
|
|
306
303
|
:return: The decoded token claims
|
307
304
|
:rtype: dict
|
308
305
|
"""
|
309
|
-
|
306
|
+
|
310
307
|
# Get unverified header - this will raise jwt.InvalidTokenError if token format is invalid
|
311
|
-
print("[verify_token] Getting unverified header")
|
312
308
|
unverified_header = jwt.get_unverified_header(token)
|
313
|
-
print(f"[verify_token] Unverified header: {unverified_header}")
|
314
309
|
|
315
310
|
# Check for kid in header
|
316
311
|
if "kid" not in unverified_header:
|
317
|
-
|
312
|
+
|
318
313
|
raise InvalidCredentials(
|
319
314
|
"Invalid token: Missing key identifier (kid) in token header",
|
320
315
|
log_txt="Error while verifying token. Token header is missing 'kid'.",
|
@@ -322,32 +317,24 @@ class Auth:
|
|
322
317
|
)
|
323
318
|
|
324
319
|
kid = unverified_header["kid"]
|
325
|
-
print(f"[verify_token] Found kid: {kid}")
|
326
320
|
|
327
321
|
# Check if we have the keys in cache and if the kid exists
|
328
|
-
print("[verify_token] Checking cache for public keys")
|
329
322
|
public_key = None
|
330
323
|
if provider_url in public_keys_cache:
|
331
|
-
print(f"[verify_token] Found keys in cache for {provider_url}")
|
332
324
|
cached_keys = public_keys_cache[provider_url]
|
333
325
|
if kid in cached_keys:
|
334
|
-
print("[verify_token] Found matching kid in cached keys")
|
335
326
|
public_key = cached_keys[kid]
|
336
327
|
|
337
328
|
# If kid not in cache, fetch fresh keys
|
338
329
|
if public_key is None:
|
339
|
-
print("[verify_token] Public key not found in cache, fetching fresh keys")
|
340
330
|
public_keys = self.get_public_keys(provider_url)
|
341
331
|
if kid not in public_keys:
|
342
|
-
print(f"[verify_token] Kid {kid} not found in fresh public keys")
|
343
|
-
print("[verify_token] About to raise InvalidCredentials")
|
344
332
|
raise InvalidCredentials(
|
345
333
|
"Invalid token: Unknown key identifier (kid)",
|
346
334
|
log_txt="Error while verifying token. Key ID not found in public keys.",
|
347
335
|
status_code=400,
|
348
336
|
)
|
349
337
|
public_key = public_keys[kid]
|
350
|
-
print("[verify_token] Found public key in fresh keys")
|
351
338
|
|
352
339
|
# Verify token - this will raise appropriate jwt exceptions that will be caught in decode_token
|
353
340
|
return jwt.decode(
|
@@ -425,16 +412,11 @@ class BIAuth(Auth):
|
|
425
412
|
:return: dictionary containing the username from the token's sub claim
|
426
413
|
:rtype: dict
|
427
414
|
"""
|
428
|
-
if token is None:
|
429
|
-
err = "The provided token is not valid."
|
430
|
-
raise InvalidUsage(
|
431
|
-
err, log_txt="Error while trying to decode token. " + err
|
432
|
-
)
|
433
415
|
try:
|
434
|
-
|
416
|
+
return jwt.decode(
|
435
417
|
token, current_app.config["SECRET_BI_KEY"], algorithms="HS256"
|
436
418
|
)
|
437
|
-
|
419
|
+
|
438
420
|
except jwt.InvalidTokenError:
|
439
421
|
raise InvalidCredentials(
|
440
422
|
"Invalid token, please try again with a new token",
|
@@ -466,9 +448,7 @@ class BIAuth(Auth):
|
|
466
448
|
)
|
467
449
|
|
468
450
|
payload = {
|
469
|
-
"
|
470
|
-
+ timedelta(hours=float(current_app.config["TOKEN_DURATION"])),
|
471
|
-
"iat": datetime.utcnow(),
|
451
|
+
"iat": datetime.now(timezone.utc),
|
472
452
|
"sub": user.username,
|
473
453
|
"iss": INTERNAL_TOKEN_ISSUER,
|
474
454
|
}
|
cornflow/tests/const.py
CHANGED
@@ -9,6 +9,7 @@ def _get_file(relative_path):
|
|
9
9
|
|
10
10
|
PREFIX = ""
|
11
11
|
INSTANCE_PATH = _get_file("./data/new_instance.json")
|
12
|
+
EMPTY_INSTANCE_PATH = _get_file("./data/empty_instance.json")
|
12
13
|
INSTANCES_LIST = [INSTANCE_PATH, _get_file("./data/new_instance_2.json")]
|
13
14
|
INSTANCE_URL = PREFIX + "/instance/"
|
14
15
|
INSTANCE_MPS = _get_file("./data/test_mps.mps")
|
@@ -14,16 +14,16 @@ LoginTestCases
|
|
14
14
|
Test cases for login functionality
|
15
15
|
"""
|
16
16
|
|
17
|
+
import json
|
18
|
+
|
17
19
|
# Import from libraries
|
18
20
|
import logging as log
|
19
|
-
from datetime import datetime, timedelta
|
20
|
-
|
21
|
+
from datetime import datetime, timedelta, timezone
|
21
22
|
from typing import List
|
22
23
|
|
24
|
+
import jwt
|
23
25
|
from flask import current_app
|
24
26
|
from flask_testing import TestCase
|
25
|
-
import json
|
26
|
-
import jwt
|
27
27
|
|
28
28
|
# Import from internal modules
|
29
29
|
from cornflow.app import create_app
|
@@ -31,8 +31,15 @@ from cornflow.models import UserRoleModel, UserModel
|
|
31
31
|
from cornflow.commands.access import access_init_command
|
32
32
|
from cornflow.commands.dag import register_deployed_dags_command_test
|
33
33
|
from cornflow.commands.permissions import register_dag_permissions_command
|
34
|
+
from cornflow.models import UserRoleModel
|
35
|
+
from cornflow.shared import db
|
34
36
|
from cornflow.shared.authentication import Auth
|
35
|
-
from cornflow.shared.const import
|
37
|
+
from cornflow.shared.const import (
|
38
|
+
ADMIN_ROLE,
|
39
|
+
PLANNER_ROLE,
|
40
|
+
SERVICE_ROLE,
|
41
|
+
INTERNAL_TOKEN_ISSUER,
|
42
|
+
)
|
36
43
|
from cornflow.shared import db
|
37
44
|
from cornflow.tests.const import (
|
38
45
|
LOGIN_URL,
|
@@ -122,7 +129,7 @@ class CustomTestCase(TestCase):
|
|
122
129
|
).json["token"]
|
123
130
|
|
124
131
|
data = Auth().decode_token(self.token)
|
125
|
-
self.user = UserModel.get_one_object(username=data["
|
132
|
+
self.user = UserModel.get_one_object(username=data["sub"])
|
126
133
|
self.url = None
|
127
134
|
self.model = None
|
128
135
|
self.copied_items = set()
|
@@ -588,6 +595,14 @@ class BaseTestCases:
|
|
588
595
|
allrows = self.get_rows(self.url, data_many)
|
589
596
|
self.apply_filter(self.url, dict(limit=1), [allrows.json[0]])
|
590
597
|
|
598
|
+
def test_opt_filters_limit_none(self):
|
599
|
+
"""
|
600
|
+
Tests the limit filter option
|
601
|
+
"""
|
602
|
+
data_many = [self.payload for _ in range(4)]
|
603
|
+
allrows = self.get_rows(self.url, data_many)
|
604
|
+
self.apply_filter(self.url, dict(limit=None), allrows.json)
|
605
|
+
|
591
606
|
def test_opt_filters_offset(self):
|
592
607
|
"""
|
593
608
|
Tests the offset filter option.
|
@@ -597,6 +612,22 @@ class BaseTestCases:
|
|
597
612
|
allrows = self.get_rows(self.url, data_many)
|
598
613
|
self.apply_filter(self.url, dict(offset=1, limit=2), allrows.json[1:3])
|
599
614
|
|
615
|
+
def test_opt_filters_offset_zero(self):
|
616
|
+
"""
|
617
|
+
Tests the offset filter option with a zero value.
|
618
|
+
"""
|
619
|
+
data_many = [self.payload for _ in range(4)]
|
620
|
+
allrows = self.get_rows(self.url, data_many)
|
621
|
+
self.apply_filter(self.url, dict(offset=0), allrows.json)
|
622
|
+
|
623
|
+
def test_opt_filters_offset_none(self):
|
624
|
+
"""
|
625
|
+
Tests the offset filter option with a None value.
|
626
|
+
"""
|
627
|
+
data_many = [self.payload for _ in range(4)]
|
628
|
+
allrows = self.get_rows(self.url, data_many)
|
629
|
+
self.apply_filter(self.url, dict(offset=None), allrows.json)
|
630
|
+
|
600
631
|
def test_opt_filters_schema(self):
|
601
632
|
"""
|
602
633
|
Tests the schema filter option.
|
@@ -1005,7 +1036,7 @@ class LoginTestCases:
|
|
1005
1036
|
expired_token = jwt.encode(
|
1006
1037
|
expired_payload,
|
1007
1038
|
current_app.config["SECRET_TOKEN_KEY"],
|
1008
|
-
algorithm="HS256"
|
1039
|
+
algorithm="HS256",
|
1009
1040
|
)
|
1010
1041
|
|
1011
1042
|
# Try to use the expired token
|
@@ -1020,8 +1051,7 @@ class LoginTestCases:
|
|
1020
1051
|
|
1021
1052
|
self.assertEqual(400, response.status_code)
|
1022
1053
|
self.assertEqual(
|
1023
|
-
"The token has expired, please login again",
|
1024
|
-
response.json["error"]
|
1054
|
+
"The token has expired, please login again", response.json["error"]
|
1025
1055
|
)
|
1026
1056
|
|
1027
1057
|
def test_bad_format_token(self):
|
@@ -1072,7 +1102,7 @@ class LoginTestCases:
|
|
1072
1102
|
invalid_token = jwt.encode(
|
1073
1103
|
invalid_payload,
|
1074
1104
|
current_app.config["SECRET_TOKEN_KEY"],
|
1075
|
-
algorithm="HS256"
|
1105
|
+
algorithm="HS256",
|
1076
1106
|
)
|
1077
1107
|
|
1078
1108
|
# Try to use the invalid token
|
@@ -1113,7 +1143,7 @@ class LoginTestCases:
|
|
1113
1143
|
)
|
1114
1144
|
|
1115
1145
|
self.assertAlmostEqual(
|
1116
|
-
datetime.
|
1117
|
-
datetime.
|
1146
|
+
datetime.now(timezone.utc),
|
1147
|
+
datetime.fromtimestamp(decoded_token["iat"], timezone.utc),
|
1118
1148
|
delta=timedelta(seconds=2),
|
1119
1149
|
)
|
@@ -13,9 +13,10 @@ TestAlarmsEndpoint
|
|
13
13
|
"""
|
14
14
|
|
15
15
|
# Imports from internal modules
|
16
|
+
import json
|
16
17
|
from cornflow.models import AlarmsModel
|
17
18
|
from cornflow.tests.const import ALARMS_URL
|
18
|
-
from cornflow.tests.custom_test_case import CustomTestCase
|
19
|
+
from cornflow.tests.custom_test_case import BaseTestCases, CustomTestCase
|
19
20
|
|
20
21
|
|
21
22
|
class TestAlarmsEndpoint(CustomTestCase):
|
@@ -85,3 +86,56 @@ class TestAlarmsEndpoint(CustomTestCase):
|
|
85
86
|
self.assertIn(key, rows_data[i])
|
86
87
|
if key in data[i]:
|
87
88
|
self.assertEqual(rows_data[i][key], data[i][key])
|
89
|
+
|
90
|
+
|
91
|
+
class TestAlarmsDetailEndpoint(TestAlarmsEndpoint, BaseTestCases.DetailEndpoint):
|
92
|
+
def setUp(self):
|
93
|
+
super().setUp()
|
94
|
+
self.url = self.url
|
95
|
+
self.idx = 0
|
96
|
+
self.payload = {
|
97
|
+
"name": "Alarm 1",
|
98
|
+
"description": "Description Alarm 1",
|
99
|
+
"criticality": 1,
|
100
|
+
}
|
101
|
+
|
102
|
+
def test_disable_alarm_detail(self):
|
103
|
+
"""
|
104
|
+
The idea would be to read the alarm_id from the query and be able to disable the entire row for that alarm in the database.
|
105
|
+
To check this, I would use an example data set of different alarms and, after giving a specific id, be able to return the same data set,
|
106
|
+
excluding those related to the given alarm_id.
|
107
|
+
|
108
|
+
Verifies:
|
109
|
+
- Retrieval of a single alarm using its ID
|
110
|
+
- Correct validation of alarm data fields
|
111
|
+
|
112
|
+
"""
|
113
|
+
|
114
|
+
data = {
|
115
|
+
"name": "Alarm 1",
|
116
|
+
"description": "Description Alarm 1",
|
117
|
+
"criticality": 1,
|
118
|
+
}
|
119
|
+
id = self.create_new_row(self.url, self.model, data)
|
120
|
+
data = {
|
121
|
+
"name": "Alarm 2",
|
122
|
+
"description": "Description Alarm 2",
|
123
|
+
"criticality": 1,
|
124
|
+
}
|
125
|
+
self.idx = id
|
126
|
+
url = self.url + str(id) + "/"
|
127
|
+
response = self.client.delete(
|
128
|
+
url,
|
129
|
+
data=json.dumps(data),
|
130
|
+
follow_redirects=True,
|
131
|
+
headers=self.get_header_with_auth(self.token),
|
132
|
+
)
|
133
|
+
self.assertEqual(response.status_code, 200)
|
134
|
+
self.assertEqual(response.json, {"message": "Object marked as disabled"})
|
135
|
+
all_rows = AlarmsModel.query.all()
|
136
|
+
for row in all_rows:
|
137
|
+
if row.id == id:
|
138
|
+
# We check deleted at has a value
|
139
|
+
self.assertIsNotNone(row.deleted_at)
|
140
|
+
else:
|
141
|
+
self.assertIsNone(row.deleted_at)
|
@@ -13,7 +13,8 @@ TestApiViewListEndpoint
|
|
13
13
|
"""
|
14
14
|
|
15
15
|
# Import from internal modules
|
16
|
-
from cornflow.endpoints import ApiViewListEndpoint, resources
|
16
|
+
from cornflow.endpoints import ApiViewListEndpoint, resources, alarms_resources
|
17
|
+
from cornflow.models import ViewModel
|
17
18
|
from cornflow.shared.const import ROLES_MAP
|
18
19
|
from cornflow.tests.const import APIVIEW_URL
|
19
20
|
from cornflow.tests.custom_test_case import CustomTestCase
|
@@ -102,3 +103,107 @@ class TestApiViewListEndpoint(CustomTestCase):
|
|
102
103
|
)
|
103
104
|
|
104
105
|
self.assertEqual(403, response.status_code)
|
106
|
+
|
107
|
+
|
108
|
+
class TestApiViewModel(CustomTestCase):
|
109
|
+
"""
|
110
|
+
Test cases for the API views list endpoint.
|
111
|
+
|
112
|
+
This class tests the functionality of listing available API views, including:
|
113
|
+
- Authorization checks for different user roles
|
114
|
+
- Validation of returned API view data
|
115
|
+
- Access control for authorized and unauthorized roles
|
116
|
+
"""
|
117
|
+
|
118
|
+
def setUp(self):
|
119
|
+
"""
|
120
|
+
Set up test environment before each test.
|
121
|
+
|
122
|
+
Initializes test data including:
|
123
|
+
- Base test case setup
|
124
|
+
- Roles with access permissions
|
125
|
+
- Test payload with API view data
|
126
|
+
- Items to check in responses
|
127
|
+
"""
|
128
|
+
super().setUp()
|
129
|
+
self.roles_with_access = ApiViewListEndpoint.ROLES_WITH_ACCESS
|
130
|
+
self.payload = [
|
131
|
+
{
|
132
|
+
"name": view["endpoint"],
|
133
|
+
"url_rule": view["urls"],
|
134
|
+
"description": view["resource"].DESCRIPTION,
|
135
|
+
}
|
136
|
+
for view in resources
|
137
|
+
]
|
138
|
+
self.items_to_check = ["name", "description", "url_rule"]
|
139
|
+
|
140
|
+
def test_get_all_objects(self):
|
141
|
+
"""
|
142
|
+
Test that the get_all_objects method works properly
|
143
|
+
"""
|
144
|
+
expected_count = len(resources) + len(alarms_resources)
|
145
|
+
# Test getting all objects
|
146
|
+
all_instances = ViewModel.get_all_objects().all()
|
147
|
+
self.assertEqual(len(all_instances), expected_count)
|
148
|
+
|
149
|
+
# Check that all resources are included in the results
|
150
|
+
api_view_names = [view.name for view in all_instances]
|
151
|
+
expected_names = [view["endpoint"] for view in resources]
|
152
|
+
|
153
|
+
# Verify all expected resource names are in the results
|
154
|
+
for name in expected_names:
|
155
|
+
self.assertIn(name, api_view_names)
|
156
|
+
|
157
|
+
# Test with offset and limit
|
158
|
+
limited_instances = ViewModel.get_all_objects(offset=10, limit=10).all()
|
159
|
+
self.assertEqual(len(limited_instances), 10)
|
160
|
+
|
161
|
+
# Verify these are different from the first 10 results
|
162
|
+
first_ten = [view.id for view in all_instances[:10]]
|
163
|
+
offset_ten = [view.id for view in limited_instances]
|
164
|
+
self.assertNotEqual(first_ten, offset_ten)
|
165
|
+
|
166
|
+
# Test with only offset
|
167
|
+
offset_instances = ViewModel.get_all_objects(offset=10).all()
|
168
|
+
self.assertEqual(len(offset_instances), expected_count - 10)
|
169
|
+
|
170
|
+
# Verify these match with the correct slice of all results
|
171
|
+
offset_ids = [view.id for view in offset_instances]
|
172
|
+
expected_offset_ids = [view.id for view in all_instances[10:]]
|
173
|
+
self.assertEqual(offset_ids, expected_offset_ids)
|
174
|
+
|
175
|
+
# Test with only limit
|
176
|
+
limit_instances = ViewModel.get_all_objects(limit=10).all()
|
177
|
+
self.assertEqual(len(limit_instances), 10)
|
178
|
+
|
179
|
+
# Verify these match with the first 10 of all results
|
180
|
+
limit_ids = [view.id for view in limit_instances]
|
181
|
+
expected_limit_ids = [view.id for view in all_instances[:10]]
|
182
|
+
self.assertEqual(limit_ids, expected_limit_ids)
|
183
|
+
|
184
|
+
def test_get_one_object_by_name(self):
|
185
|
+
"""
|
186
|
+
Test that the get_one_by_name method works properly
|
187
|
+
"""
|
188
|
+
instance = ViewModel.get_one_by_name(name="instance")
|
189
|
+
self.assertEqual(instance.name, "instance")
|
190
|
+
|
191
|
+
def test_get_one_object(self):
|
192
|
+
"""
|
193
|
+
Test that the get_one_object method works properly
|
194
|
+
"""
|
195
|
+
instance = ViewModel.get_one_object(idx=1)
|
196
|
+
self.assertEqual(instance.name, "instance")
|
197
|
+
|
198
|
+
def test_get_one_object_without_idx(self):
|
199
|
+
"""
|
200
|
+
Test that the get_one_object method works properly when called without any filters
|
201
|
+
"""
|
202
|
+
# Call get_one_object without any filters
|
203
|
+
instance = ViewModel.get_one_object()
|
204
|
+
|
205
|
+
# It should return the first instance by default
|
206
|
+
first_instance = ViewModel.get_all_objects().first()
|
207
|
+
self.assertIsNotNone(instance)
|
208
|
+
self.assertEqual(instance.id, first_instance.id)
|
209
|
+
self.assertEqual(instance.name, first_instance.name)
|
cornflow/tests/unit/test_cli.py
CHANGED
@@ -33,6 +33,7 @@ from cornflow.models import (
|
|
33
33
|
from cornflow.models import UserModel
|
34
34
|
from cornflow.shared import db
|
35
35
|
from cornflow.shared.exceptions import NoPermission, ObjectDoesNotExist
|
36
|
+
from cornflow.endpoints import resources, alarms_resources
|
36
37
|
|
37
38
|
|
38
39
|
class CLITests(TestCase):
|
@@ -281,7 +282,7 @@ class CLITests(TestCase):
|
|
281
282
|
result = runner.invoke(cli, ["views", "init", "-v"])
|
282
283
|
self.assertEqual(result.exit_code, 0)
|
283
284
|
views = ViewModel.get_all_objects().all()
|
284
|
-
self.assertEqual(len(views),
|
285
|
+
self.assertEqual(len(views), (len(resources) + len(alarms_resources)))
|
285
286
|
|
286
287
|
def test_permissions_entrypoint(self):
|
287
288
|
"""
|
@@ -323,8 +324,8 @@ class CLITests(TestCase):
|
|
323
324
|
permissions = PermissionViewRoleModel.get_all_objects().all()
|
324
325
|
self.assertEqual(len(actions), 5)
|
325
326
|
self.assertEqual(len(roles), 4)
|
326
|
-
self.assertEqual(len(views),
|
327
|
-
self.assertEqual(len(permissions),
|
327
|
+
self.assertEqual(len(views), (len(resources) + len(alarms_resources)))
|
328
|
+
self.assertEqual(len(permissions), 562)
|
328
329
|
|
329
330
|
def test_permissions_base_command(self):
|
330
331
|
"""
|
@@ -348,8 +349,8 @@ class CLITests(TestCase):
|
|
348
349
|
permissions = PermissionViewRoleModel.get_all_objects().all()
|
349
350
|
self.assertEqual(len(actions), 5)
|
350
351
|
self.assertEqual(len(roles), 4)
|
351
|
-
self.assertEqual(len(views),
|
352
|
-
self.assertEqual(len(permissions),
|
352
|
+
self.assertEqual(len(views), (len(resources) + len(alarms_resources)))
|
353
|
+
self.assertEqual(len(permissions), 562)
|
353
354
|
|
354
355
|
def test_service_entrypoint(self):
|
355
356
|
"""
|
cornflow/tests/unit/test_dags.py
CHANGED