cornflow 1.0.11a1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. cornflow/cli/service.py +4 -0
  2. cornflow/commands/__init__.py +1 -1
  3. cornflow/commands/schemas.py +31 -0
  4. cornflow/config.py +6 -0
  5. cornflow/endpoints/__init__.py +15 -20
  6. cornflow/endpoints/example_data.py +64 -13
  7. cornflow/endpoints/execution.py +2 -1
  8. cornflow/endpoints/login.py +16 -13
  9. cornflow/endpoints/user.py +2 -2
  10. cornflow/migrations/versions/991b98e24225_.py +33 -0
  11. cornflow/models/user.py +4 -0
  12. cornflow/schemas/example_data.py +7 -2
  13. cornflow/schemas/execution.py +8 -1
  14. cornflow/schemas/solution_log.py +11 -5
  15. cornflow/schemas/user.py +3 -0
  16. cornflow/shared/authentication/auth.py +1 -1
  17. cornflow/shared/licenses.py +17 -54
  18. cornflow/tests/custom_test_case.py +17 -3
  19. cornflow/tests/integration/test_cornflowclient.py +20 -14
  20. cornflow/tests/unit/test_cases.py +95 -6
  21. cornflow/tests/unit/test_cli.py +5 -5
  22. cornflow/tests/unit/test_dags.py +48 -1
  23. cornflow/tests/unit/test_example_data.py +85 -12
  24. cornflow/tests/unit/test_executions.py +98 -8
  25. cornflow/tests/unit/test_instances.py +43 -5
  26. cornflow/tests/unit/test_main_alarms.py +8 -8
  27. cornflow/tests/unit/test_schemas.py +12 -1
  28. cornflow/tests/unit/test_token.py +17 -0
  29. cornflow/tests/unit/test_users.py +16 -0
  30. {cornflow-1.0.11a1.dist-info → cornflow-1.1.0.dist-info}/METADATA +2 -2
  31. {cornflow-1.0.11a1.dist-info → cornflow-1.1.0.dist-info}/RECORD +34 -33
  32. {cornflow-1.0.11a1.dist-info → cornflow-1.1.0.dist-info}/WHEEL +0 -0
  33. {cornflow-1.0.11a1.dist-info → cornflow-1.1.0.dist-info}/entry_points.txt +0 -0
  34. {cornflow-1.0.11a1.dist-info → cornflow-1.1.0.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,9 @@ This file contains the different custom test classes used to generalize the unit
5
5
  # Import from libraries
6
6
  import logging as log
7
7
  from datetime import datetime, timedelta
8
+
9
+ from typing import List
10
+
8
11
  from flask import current_app
9
12
  from flask_testing import TestCase
10
13
  import json
@@ -27,7 +30,6 @@ from cornflow.tests.const import (
27
30
  TOKEN_URL,
28
31
  )
29
32
 
30
-
31
33
  try:
32
34
  date_from_str = datetime.fromisoformat
33
35
  except:
@@ -172,7 +174,9 @@ class CustomTestCase(TestCase):
172
174
  self.assertEqual(getattr(row, key), payload[key])
173
175
  return row.id
174
176
 
175
- def get_rows(self, url, data, token=None, check_data=True):
177
+ def get_rows(
178
+ self, url, data, token=None, check_data=True, keys_to_check: List[str] = None
179
+ ):
176
180
  token = token or self.token
177
181
 
178
182
  codes = [
@@ -187,6 +191,8 @@ class CustomTestCase(TestCase):
187
191
  if check_data:
188
192
  for i in range(len(data)):
189
193
  self.assertEqual(rows_data[i]["id"], codes[i])
194
+ if keys_to_check:
195
+ self.assertCountEqual(list(rows_data[i].keys()), keys_to_check)
190
196
  for key in self.get_keys_to_check(data[i]):
191
197
  self.assertIn(key, rows_data[i])
192
198
  if key in data[i]:
@@ -199,7 +205,13 @@ class CustomTestCase(TestCase):
199
205
  return payload.keys()
200
206
 
201
207
  def get_one_row(
202
- self, url, payload, expected_status=200, check_payload=True, token=None
208
+ self,
209
+ url,
210
+ payload,
211
+ expected_status=200,
212
+ check_payload=True,
213
+ token=None,
214
+ keys_to_check: List[str] = None,
203
215
  ):
204
216
  token = token or self.token
205
217
 
@@ -210,6 +222,8 @@ class CustomTestCase(TestCase):
210
222
  self.assertEqual(expected_status, row.status_code)
211
223
  if not check_payload:
212
224
  return row.json
225
+ if keys_to_check:
226
+ self.assertCountEqual(list(row.json.keys()), keys_to_check)
213
227
  self.assertEqual(row.json["id"], payload["id"])
214
228
  for key in self.get_keys_to_check(payload):
215
229
  self.assertIn(key, row.json)
@@ -35,6 +35,24 @@ class TestCornflowClientBasic(CustomTestCaseLive):
35
35
  super().setUp()
36
36
  self.items_to_check = ["name", "description"]
37
37
 
38
def check_status_evolution(
    self, execution, end_state=EXEC_STATE_CORRECT, max_polls=100, poll_interval=1
):
    """
    Poll an execution's status until it reaches ``end_state``, then check that
    it went through the queued and running states, in that order, before it.

    :param dict execution: the execution as returned on creation; its ``state``
      is taken as the first observed status and its ``id`` is used for polling.
    :param end_state: the state the execution is expected to reach.
    :param int max_polls: maximum number of observed statuses (including the
      initial one) before polling stops.
    :param poll_interval: seconds to sleep between two status requests.
    """
    statuses = [execution["state"]]
    while end_state not in statuses and len(statuses) < max_polls:
        time.sleep(poll_interval)
        status = self.client.get_status(execution["id"])
        statuses.append(status["state"])

    # Report the whole observed sequence on any failure so a flaky run can be
    # diagnosed without re-running it.
    msg = f"observed statuses: {statuses}"
    for expected_state in (EXEC_STATE_QUEUED, EXEC_STATE_RUNNING, end_state):
        self.assertIn(expected_state, statuses, msg)

    # list.index returns the position of the first occurrence of each state.
    queued_idx = statuses.index(EXEC_STATE_QUEUED)
    running_idx = statuses.index(EXEC_STATE_RUNNING)
    end_state_idx = statuses.index(end_state)

    self.assertLess(queued_idx, running_idx, msg)
    self.assertLess(running_idx, end_state_idx, msg)
+
38
56
  def create_new_instance_file(self, mps_file):
39
57
  name = "test_instance1"
40
58
  description = "description123"
@@ -141,7 +159,6 @@ class TestCornflowClientBasic(CustomTestCaseLive):
141
159
 
142
160
 
143
161
  class TestCornflowClientOpen(TestCornflowClientBasic):
144
-
145
162
  # TODO: user management
146
163
  # TODO: infeasible execution
147
164
 
@@ -242,7 +259,6 @@ class TestCornflowClientOpen(TestCornflowClientBasic):
242
259
  self.client.create_instance(**payload)
243
260
 
244
261
  def test_new_instance_with_schema_good(self):
245
-
246
262
  payload = load_file(INSTANCE_PATH)
247
263
  payload["schema"] = "solve_model_dag"
248
264
  self.create_new_instance_payload(payload)
@@ -335,23 +351,13 @@ class TestCornflowClientAdmin(TestCornflowClientBasic):
335
351
 
336
352
  def test_status_solving(self):
337
353
  execution = self.create_instance_and_execution()
338
- time.sleep(10)
339
- status = self.client.get_status(execution["id"])
340
- self.assertEqual(status["state"], EXEC_STATE_CORRECT)
354
+ self.check_status_evolution(execution, EXEC_STATE_CORRECT)
341
355
 
342
356
  def test_status_solving_timer(self):
343
357
  execution = self.create_timer_instance_and_execution(10)
344
- status = self.client.get_status(execution["id"])
345
- self.assertEqual(status["state"], EXEC_STATE_QUEUED)
346
- time.sleep(5)
347
- status = self.client.get_status(execution["id"])
348
- self.assertEqual(status["state"], EXEC_STATE_RUNNING)
349
- time.sleep(12)
350
- status = self.client.get_status(execution["id"])
351
- self.assertEqual(status["state"], EXEC_STATE_CORRECT)
358
+ self.check_status_evolution(execution, EXEC_STATE_CORRECT)
352
359
 
353
360
  def test_manual_execution(self):
354
-
355
361
  instance_payload = load_file(INSTANCE_PATH)
356
362
  one_instance = self.create_new_instance_payload(instance_payload)
357
363
  name = "test_execution_name_123"
@@ -201,8 +201,29 @@ class TestCasesRawDataEndpoint(CustomTestCase):
201
201
  self.payload.pop("solution")
202
202
  self.items_to_check = ["name", "description", "schema", "data"]
203
203
  _id = self.create_new_row(self.url, self.model, self.payload)
204
+ keys_to_check = [
205
+ "data",
206
+ "solution_checks",
207
+ "updated_at",
208
+ "id",
209
+ "schema",
210
+ "data_hash",
211
+ "path",
212
+ "solution_hash",
213
+ "user_id",
214
+ "indicators",
215
+ "solution",
216
+ "is_dir",
217
+ "description",
218
+ "name",
219
+ "checks",
220
+ "created_at",
221
+ ]
204
222
  data = self.get_one_row(
205
- self.url + "/" + str(_id) + "/data/", payload={}, check_payload=False
223
+ self.url + "/" + str(_id) + "/data/",
224
+ payload={},
225
+ check_payload=False,
226
+ keys_to_check=keys_to_check,
206
227
  )
207
228
  self.assertIsNone(data["solution"])
208
229
 
@@ -293,7 +314,21 @@ class TestCaseListEndpoint(BaseTestCases.ListFilters):
293
314
  self.url = CASE_URL
294
315
 
295
316
  def test_get_rows(self):
296
- self.get_rows(self.url, self.payloads)
317
+ keys_to_check = [
318
+ "data_hash",
319
+ "created_at",
320
+ "is_dir",
321
+ "path",
322
+ "schema",
323
+ "description",
324
+ "solution_hash",
325
+ "id",
326
+ "user_id",
327
+ "updated_at",
328
+ "name",
329
+ "indicators",
330
+ ]
331
+ self.get_rows(self.url, self.payloads, keys_to_check=keys_to_check)
297
332
 
298
333
 
299
334
  class TestCaseDetailEndpoint(BaseTestCases.DetailEndpoint):
@@ -365,7 +400,19 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
365
400
  )
366
401
 
367
402
  payload = response.json
368
- result = self.get_one_row(INSTANCE_URL + payload["id"] + "/", payload)
403
+ keys_to_check = [
404
+ "id",
405
+ "schema",
406
+ "data_hash",
407
+ "executions",
408
+ "user_id",
409
+ "description",
410
+ "name",
411
+ "created_at",
412
+ ]
413
+ result = self.get_one_row(
414
+ INSTANCE_URL + payload["id"] + "/", payload, keys_to_check=keys_to_check
415
+ )
369
416
  dif = self.response_items.symmetric_difference(result.keys())
370
417
  self.assertEqual(len(dif), 0)
371
418
 
@@ -382,7 +429,23 @@ class TestCaseToInstanceEndpoint(CustomTestCase):
382
429
  ]
383
430
  self.response_items = set(self.items_to_check)
384
431
 
385
- result = self.get_one_row(INSTANCE_URL + payload["id"] + "/data/", payload)
432
+ keys_to_check = [
433
+ "data",
434
+ "id",
435
+ "schema",
436
+ "data_hash",
437
+ "user_id",
438
+ "description",
439
+ "name",
440
+ "checks",
441
+ "created_at",
442
+ ]
443
+
444
+ result = self.get_one_row(
445
+ INSTANCE_URL + payload["id"] + "/data/",
446
+ payload,
447
+ keys_to_check=keys_to_check,
448
+ )
386
449
  dif = self.response_items.symmetric_difference(result.keys())
387
450
  self.assertEqual(len(dif), 0)
388
451
 
@@ -539,11 +602,37 @@ class TestCaseDataEndpoint(CustomTestCase):
539
602
  ]
540
603
 
541
604
  def test_get_data(self):
542
- self.get_one_row(self.url + str(self.payload["id"]) + "/data/", self.payload)
605
+ keys_to_check = [
606
+ "data",
607
+ "solution_checks",
608
+ "updated_at",
609
+ "id",
610
+ "schema",
611
+ "data_hash",
612
+ "path",
613
+ "solution_hash",
614
+ "user_id",
615
+ "indicators",
616
+ "solution",
617
+ "is_dir",
618
+ "description",
619
+ "name",
620
+ "checks",
621
+ "created_at",
622
+ ]
623
+ self.get_one_row(
624
+ self.url + str(self.payload["id"]) + "/data/",
625
+ self.payload,
626
+ keys_to_check=keys_to_check,
627
+ )
543
628
 
544
629
  def test_get_no_data(self):
545
630
  self.get_one_row(
546
- self.url + str(500) + "/data/", {}, expected_status=404, check_payload=False
631
+ self.url + str(500) + "/data/",
632
+ {},
633
+ expected_status=404,
634
+ check_payload=False,
635
+ keys_to_check=["error"],
547
636
  )
548
637
 
549
638
  def test_get_compressed_data(self):
@@ -131,7 +131,7 @@ class CLITests(TestCase):
131
131
  result = runner.invoke(cli, ["views", "init", "-v"])
132
132
  self.assertEqual(result.exit_code, 0)
133
133
  views = ViewModel.get_all_objects().all()
134
- self.assertEqual(len(views), 48)
134
+ self.assertEqual(len(views), 49)
135
135
 
136
136
  def test_permissions_entrypoint(self):
137
137
  runner = CliRunner()
@@ -155,8 +155,8 @@ class CLITests(TestCase):
155
155
  permissions = PermissionViewRoleModel.get_all_objects().all()
156
156
  self.assertEqual(len(actions), 5)
157
157
  self.assertEqual(len(roles), 4)
158
- self.assertEqual(len(views), 48)
159
- self.assertEqual(len(permissions), 530)
158
+ self.assertEqual(len(views), 49)
159
+ self.assertEqual(len(permissions), 546)
160
160
 
161
161
  def test_permissions_base_command(self):
162
162
  runner = CliRunner()
@@ -171,8 +171,8 @@ class CLITests(TestCase):
171
171
  permissions = PermissionViewRoleModel.get_all_objects().all()
172
172
  self.assertEqual(len(actions), 5)
173
173
  self.assertEqual(len(roles), 4)
174
- self.assertEqual(len(views), 48)
175
- self.assertEqual(len(permissions), 530)
174
+ self.assertEqual(len(views), 49)
175
+ self.assertEqual(len(permissions), 546)
176
176
 
177
177
  def test_service_entrypoint(self):
178
178
  runner = CliRunner()
@@ -24,6 +24,7 @@ from cornflow.tests.const import (
24
24
  LOGIN_URL,
25
25
  SIGNUP_URL,
26
26
  USER_URL,
27
+ EXECUTION_URL,
27
28
  )
28
29
  from cornflow.tests.unit.test_executions import TestExecutionsDetailEndpointMock
29
30
  from cornflow_client import get_pulp_jsonschema, get_empty_schema
@@ -90,13 +91,31 @@ class TestDagDetailEndpoint(TestExecutionsDetailEndpointMock):
90
91
  idx = self.create_new_row(EXECUTION_URL_NORUN, self.model, self.payload)
91
92
  with open(CASE_PATH) as f:
92
93
  payload = json.load(f)
94
+
95
+ log_json = {
96
+ "time": 10.3,
97
+ "solver": "dummy",
98
+ "status": "feasible",
99
+ "status_code": 2,
100
+ "sol_code": 1,
101
+ "some_other_key": "this should be excluded",
102
+ }
103
+
93
104
  data = dict(
94
105
  data=payload["data"],
95
106
  state=EXEC_STATE_CORRECT,
107
+ log_json={
108
+ "time": 10.3,
109
+ "solver": "dummy",
110
+ "status": "feasible",
111
+ "status_code": 2,
112
+ "sol_code": 1,
113
+ "some_other_key": "this should be excluded",
114
+ },
96
115
  )
97
116
  payload_to_check = {**self.payload, **data}
98
117
  token = self.create_service_user()
99
- data = self.update_row(
118
+ self.update_row(
100
119
  url=DAG_URL + idx + "/",
101
120
  payload_to_check=payload_to_check,
102
121
  change=data,
@@ -104,19 +123,46 @@ class TestDagDetailEndpoint(TestExecutionsDetailEndpointMock):
104
123
  check_payload=False,
105
124
  )
106
125
 
126
+ data = self.get_one_row(
127
+ url=EXECUTION_URL + idx + "/log/",
128
+ token=token,
129
+ check_payload=False,
130
+ payload=self.payload,
131
+ expected_status=200,
132
+ )
133
+
134
+ for key in data["log"]:
135
+ self.assertEqual(data["log"][key], log_json[key])
136
+
137
+ self.assertNotIn("some_other_key", data["log"].keys())
138
+
107
139
  def test_get_dag(self):
108
140
  idx = self.create_new_row(EXECUTION_URL_NORUN, self.model, self.payload)
109
141
  token = self.create_service_user()
142
+ keys_to_check = ["id", "data", "solution_data", "config"]
110
143
  data = self.get_one_row(
111
144
  url=DAG_URL + idx + "/",
112
145
  token=token,
113
146
  check_payload=False,
114
147
  payload=self.payload,
148
+ keys_to_check=keys_to_check,
115
149
  )
150
+ keys_to_check = [
151
+ "data",
152
+ "id",
153
+ "schema",
154
+ "data_hash",
155
+ "user_id",
156
+ "description",
157
+ "name",
158
+ "checks",
159
+ "created_at",
160
+ ]
116
161
  instance_data = self.get_one_row(
117
162
  url=INSTANCE_URL + self.payload["instance_id"] + "/data/",
118
163
  payload=dict(),
119
164
  check_payload=False,
165
+ keys_to_check=keys_to_check,
120
166
  )
121
167
  self.assertEqual(data["data"], instance_data["data"])
122
168
  self.assertEqual(data["config"], self.payload["config"])
@@ -130,6 +176,7 @@ class TestDagDetailEndpoint(TestExecutionsDetailEndpointMock):
130
176
  check_payload=False,
131
177
  payload=self.payload,
132
178
  expected_status=403,
179
+ keys_to_check=["error"],
133
180
  )
134
181
 
135
182
 
@@ -1,14 +1,9 @@
1
- """
2
-
3
- """
4
- # General imports
5
1
  import json
6
2
 
7
- # Partial imports
8
- from unittest.mock import patch
9
3
 
4
+ from unittest.mock import patch
10
5
 
11
- # Imports from internal modules
6
+ from cornflow.models import PermissionsDAG
12
7
  from cornflow.tests.const import EXAMPLE_URL, INSTANCE_PATH
13
8
  from cornflow.tests.custom_test_case import CustomTestCase
14
9
 
@@ -22,7 +17,18 @@ class TestExampleDataEndpoint(CustomTestCase):
22
17
  temp = json.load(f)
23
18
  return temp
24
19
 
25
- self.example = load_file(INSTANCE_PATH)
20
+ self.example = [
21
+ {
22
+ "name": "test_example_1",
23
+ "description": "some_description",
24
+ "instance": load_file(INSTANCE_PATH),
25
+ },
26
+ {
27
+ "name": "test_example_2",
28
+ "description": "some_description",
29
+ "instance": load_file(INSTANCE_PATH),
30
+ },
31
+ ]
26
32
  self.url = EXAMPLE_URL
27
33
  self.schema_name = "solve_model_dag"
28
34
 
@@ -37,15 +43,82 @@ class TestExampleDataEndpoint(CustomTestCase):
37
43
  af_client.get_all_schemas.return_value = [{"name": self.schema_name}]
38
44
  return af_client
39
45
 
46
def patch_af_client_not_alive(self, Airflow_mock):
    """
    Configure the patched ``Airflow.from_config`` so that the client it returns
    reports Airflow as not alive.

    :param Airflow_mock: the mock patched over ``Airflow.from_config``.
    :return: the mocked client (``Airflow_mock.return_value``).
    """
    af_client = Airflow_mock.return_value
    af_client.is_alive.return_value = False
    return af_client
51
+
40
52
  @patch("cornflow.endpoints.example_data.Airflow.from_config")
41
- def test_get_example(self, airflow_init):
53
def test_get_list_of_examples(self, airflow_init):
    """
    Check that listing the examples of a schema returns one entry per
    configured example, in order, each with its name and description.
    """
    # Called for its side effect of configuring the Airflow mock.
    self.patch_af_client(airflow_init)
    examples = self.get_one_row(
        f"{self.url}/{self.schema_name}/",
        {},
        expected_status=200,
        check_payload=False,
    )

    # Without this, the comparison loop below would pass vacuously on an
    # empty (or truncated) response.
    self.assertEqual(len(self.example), len(examples))
    for expected, item in zip(self.example, examples):
        self.assertIn("name", item)
        self.assertEqual(expected["name"], item["name"])
        self.assertIn("description", item)
        self.assertEqual(expected["description"], item["description"])
67
+
68
+ @patch("cornflow.endpoints.example_data.Airflow.from_config")
69
+ def test_get_one_example(self, airflow_init):
70
+ def load_file(_file):
71
+ with open(_file) as f:
72
+ temp = json.load(f)
73
+ return temp
74
+
75
+ af_client = self.patch_af_client(airflow_init)
76
+ keys_to_check = ["name", "examples"]
43
77
  example = self.get_one_row(
44
- self.url + "{}/".format(self.schema_name),
78
+ f"{self.url}/{self.schema_name}/test_example_1/",
45
79
  {},
46
80
  expected_status=200,
47
81
  check_payload=False,
82
+ keys_to_check=keys_to_check,
48
83
  )
49
- self.assertIn("examples", example)
84
+
50
85
  self.assertIn("name", example)
51
- self.assertEqual(example["examples"], self.example)
86
+ self.assertEqual("test_example_1", example["name"])
87
+ self.assertIn("description", example)
88
+ self.assertIn("instance", example)
89
+ self.assertEqual(load_file(INSTANCE_PATH), example["instance"])
90
+
91
@patch("cornflow.endpoints.example_data.Airflow.from_config")
def test_airflow_not_available(self, airflow_init):
    """
    Check that both the example detail endpoint and the example list endpoint
    answer 400 when the Airflow client reports that Airflow is not alive.
    """
    # Called for its side effect of configuring the mock; the returned client
    # is not needed here.
    self.patch_af_client_not_alive(airflow_init)
    for url in (
        f"{self.url}/{self.schema_name}/test_example_1/",
        f"{self.url}/{self.schema_name}/",
    ):
        self.get_one_row(url, {}, expected_status=400, check_payload=False)
107
+
108
def test_if_no_permission(self):
    """
    Check that both the example list endpoint and the example detail endpoint
    answer 403 when ``PermissionsDAG.check_if_has_permissions`` returns False.
    """
    with patch.object(
        PermissionsDAG, "check_if_has_permissions", return_value=False
    ):
        for url in (
            f"{self.url}/{self.schema_name}/",
            f"{self.url}/{self.schema_name}/test_example_1/",
        ):
            self.get_one_row(url, {}, expected_status=403, check_payload=False)