cornflow 1.2.4__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. cornflow/cli/service.py +91 -42
  2. cornflow/commands/dag.py +7 -7
  3. cornflow/commands/permissions.py +9 -5
  4. cornflow/config.py +23 -3
  5. cornflow/endpoints/case.py +37 -21
  6. cornflow/endpoints/dag.py +5 -5
  7. cornflow/endpoints/data_check.py +8 -7
  8. cornflow/endpoints/example_data.py +4 -2
  9. cornflow/endpoints/execution.py +215 -127
  10. cornflow/endpoints/health.py +30 -11
  11. cornflow/endpoints/instance.py +3 -3
  12. cornflow/endpoints/login.py +9 -2
  13. cornflow/endpoints/schemas.py +3 -3
  14. cornflow/migrations/versions/999b98e24225.py +34 -0
  15. cornflow/migrations/versions/cef1df240b27_.py +34 -0
  16. cornflow/models/__init__.py +2 -1
  17. cornflow/models/dag.py +8 -9
  18. cornflow/models/dag_permissions.py +3 -3
  19. cornflow/models/execution.py +2 -3
  20. cornflow/models/permissions.py +1 -0
  21. cornflow/models/user.py +1 -1
  22. cornflow/schemas/execution.py +14 -1
  23. cornflow/schemas/health.py +1 -1
  24. cornflow/shared/authentication/auth.py +14 -1
  25. cornflow/shared/authentication/decorators.py +0 -1
  26. cornflow/shared/const.py +44 -1
  27. cornflow/shared/exceptions.py +2 -1
  28. cornflow/tests/base_test_execution.py +798 -0
  29. cornflow/tests/const.py +1 -0
  30. cornflow/tests/integration/test_commands.py +2 -2
  31. cornflow/tests/integration/test_cornflowclient.py +2 -1
  32. cornflow/tests/unit/test_cases.py +1 -1
  33. cornflow/tests/unit/test_commands.py +5 -5
  34. cornflow/tests/unit/test_dags.py +3 -3
  35. cornflow/tests/unit/test_example_data.py +1 -1
  36. cornflow/tests/unit/test_executions.py +115 -535
  37. cornflow/tests/unit/test_health.py +84 -3
  38. cornflow/tests/unit/test_main_alarms.py +1 -1
  39. cornflow/tests/unit/test_roles.py +2 -1
  40. cornflow/tests/unit/test_schema_from_models.py +1 -1
  41. cornflow/tests/unit/test_schemas.py +1 -1
  42. cornflow/tests/unit/tools.py +93 -10
  43. {cornflow-1.2.4.dist-info → cornflow-1.3.0rc1.dist-info}/METADATA +2 -2
  44. {cornflow-1.2.4.dist-info → cornflow-1.3.0rc1.dist-info}/RECORD +47 -44
  45. {cornflow-1.2.4.dist-info → cornflow-1.3.0rc1.dist-info}/WHEEL +0 -0
  46. {cornflow-1.2.4.dist-info → cornflow-1.3.0rc1.dist-info}/entry_points.txt +0 -0
  47. {cornflow-1.2.4.dist-info → cornflow-1.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,798 @@
1
+ # Import from libraries
2
+ import json
3
+ from unittest.mock import patch
4
+
5
+ from cornflow.app import create_app
6
+
7
+ # Import from internal modules
8
+ from cornflow.models import ExecutionModel, InstanceModel
9
+ from cornflow.tests.const import (
10
+ INSTANCE_PATH,
11
+ EXECUTION_PATH,
12
+ EXECUTIONS_LIST,
13
+ EXECUTION_URL,
14
+ EXECUTION_URL_NORUN,
15
+ INSTANCE_URL,
16
+ DAG_URL,
17
+ BAD_EXECUTION_PATH,
18
+ EXECUTION_SOLUTION_PATH,
19
+ EDIT_EXECUTION_SOLUTION,
20
+ CUSTOM_CONFIG_PATH,
21
+ )
22
+ from cornflow.tests.custom_test_case import CustomTestCase, BaseTestCases
23
+ from cornflow.tests.unit.tools import patch_af_client, patch_db_client
24
+ from abc import ABC, abstractmethod
25
+
26
+
27
class TestExecutionsDetailEndpointMock(CustomTestCase):
    """Fixture base for execution detail tests.

    ``setUp`` registers an instance through ``INSTANCE_URL`` and prepares an
    execution payload (loaded from ``EXECUTION_PATH``) that references it.
    """

    def setUp(self):
        super().setUp()
        with open(INSTANCE_PATH) as instance_file:
            instance_data = json.load(instance_file)
        instance_id = self.create_new_row(INSTANCE_URL, InstanceModel, instance_data)

        self.instance_payload = instance_data
        self.model = ExecutionModel
        self.url = EXECUTION_URL
        # Keys the detail endpoint is expected to return.
        self.response_items = {
            "id",
            "name",
            "description",
            "created_at",
            "updated_at",
            "instance_id",
            "data_hash",
            "message",
            "state",
            "config",
            "schema",
            "user_id",
            "username",
            "indicators",
        }
        # Only these fields are compared because this endpoint does not return data.
        self.items_to_check = ["name", "description"]

        with open(EXECUTION_PATH) as execution_file:
            execution_data = json.load(execution_file)
        execution_data["instance_id"] = instance_id
        self.payload = execution_data
58
+
59
class BaseExecutionList(BaseTestCases.ListFilters, ABC):
    """Orchestrator-agnostic tests for the execution list / create endpoint.

    Concrete subclasses select the orchestrator by providing
    ``orchestrator_patch_target`` (the dotted path handed to
    ``unittest.mock.patch``) and ``orchestrator_patch_fn`` (called with the
    resulting mock; skipped when falsy).
    """

    @property
    @abstractmethod
    def orchestrator_patch_target(self):
        """Dotted path of the orchestrator client to patch (subclass-defined)."""

    @property
    @abstractmethod
    def orchestrator_patch_fn(self):
        """Callable that configures the patched client (subclass-defined)."""

    def setUp(self):
        super().setUp()

        with open(INSTANCE_PATH) as f:
            payload = json.load(f)
        fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, payload)
        self.url = EXECUTION_URL_NORUN
        self.model = ExecutionModel

        def load_file_fk(_file):
            """Load a JSON fixture and point it at the instance created above."""
            with open(_file) as f:
                temp = json.load(f)
            temp["instance_id"] = fk_id
            return temp

        self.payload = load_file_fk(EXECUTION_PATH)
        self.bad_payload = load_file_fk(BAD_EXECUTION_PATH)
        self.payloads = [load_file_fk(f) for f in EXECUTIONS_LIST]
        self.solution = load_file_fk(EXECUTION_SOLUTION_PATH)
        self.custom_config_payload = load_file_fk(CUSTOM_CONFIG_PATH)
        self.keys_to_check = [
            "data_hash",
            "created_at",
            "config",
            "state",
            "message",
            "schema",
            "description",
            "id",
            "user_id",
            "log",
            "instance_id",
            "name",
            "indicators",
            "username",
            "first_name",
            "last_name",
            "updated_at",
        ]

    def patch_orchestrator(self, client_class):
        """Configure the patched orchestrator mock, if a patch function is set."""
        if self.orchestrator_patch_fn:
            self.orchestrator_patch_fn(client_class)

    def test_new_execution(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.create_new_row(self.url, self.model, payload=self.payload)

    def test_get_custom_config(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(
                self.url, self.model, payload=self.custom_config_payload
            )
            # EXECUTION_URL already ends with "/" (see the other tests that do
            # EXECUTION_URL + idx + "/..."), so no extra separator is added.
            url = EXECUTION_URL + str(idx) + "/"

            response = self.get_one_row(
                url,
                payload={**self.custom_config_payload, **dict(id=idx)},
            )
            self.assertEqual(response["config"]["block_model"]["solver"], "mip.gurobi")

    def test_new_execution_run(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)

    def test_new_execution_bad_config(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            response = self.create_new_row(
                EXECUTION_URL,
                self.model,
                payload=self.bad_payload,
                expected_status=400,
                check_payload=False,
            )
            self.assertIn("error", response)
            self.assertIn("jsonschema_errors", response)

    def test_new_execution_partial_config(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            # Drop the solver so the endpoint has to fill in its default.
            self.payload["config"].pop("solver")
            response = self.create_new_row(
                EXECUTION_URL, self.model, payload=self.payload, check_payload=False
            )
            self.assertIn("solver", response["config"])
            self.assertEqual(response["config"]["solver"], "cbc")

    def test_new_execution_with_solution(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.payload["data"] = self.solution
            # create_new_row performs the status assertion itself.
            self.create_new_row(
                EXECUTION_URL,
                self.model,
                payload=self.payload,
                check_payload=False,
            )

    def test_new_execution_with_solution_bad(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.payload["data"] = {"message": "THIS IS NOT A VALID SOLUTION"}
            response = self.create_new_row(
                EXECUTION_URL,
                self.model,
                payload=self.payload,
                check_payload=False,
                expected_status=400,
            )
            self.assertIn("error", response)
            self.assertIn("jsonschema_errors", response)

    def test_new_execution_no_instance(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            payload = dict(self.payload)
            payload["instance_id"] = "bad_id"
            response = self.client.post(
                self.url,
                data=json.dumps(payload),
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )
            self.assertEqual(404, response.status_code)
            self.assertIn("error", response.json)

    def test_get_executions(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)

    def test_get_no_executions(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.get_no_rows(self.url)

    def test_get_executions_superadmin(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            self.get_rows(self.url, self.payloads, keys_to_check=self.keys_to_check)
            token = self.create_service_user()
            rows = self.client.get(
                self.url,
                follow_redirects=True,
                headers=self.get_header_with_auth(token),
            )
            # The service user must see every execution created above.
            self.assertEqual(len(rows.json), len(self.payloads))
224
+
225
+
226
class BaseExecutionRelaunch(CustomTestCase, ABC):
    """Orchestrator-agnostic tests for ``/execution/<id>/relaunch/``.

    Concrete subclasses provide ``orchestrator_patch_target`` (dotted path
    handed to ``unittest.mock.patch``) and ``orchestrator_patch_fn`` (called
    with the resulting mock; skipped when falsy).
    """

    @property
    @abstractmethod
    def orchestrator_patch_target(self):
        """Dotted path of the orchestrator client to patch (subclass-defined)."""

    @property
    @abstractmethod
    def orchestrator_patch_fn(self):
        """Callable that configures the patched client (subclass-defined)."""

    def setUp(self):
        super().setUp()

        with open(INSTANCE_PATH) as f:
            payload = json.load(f)
        fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, payload)
        self.url = EXECUTION_URL_NORUN
        self.model = ExecutionModel

        def load_file_fk(_file):
            """Load a JSON fixture and point it at the instance created above."""
            with open(_file) as f:
                temp = json.load(f)
            temp["instance_id"] = fk_id
            return temp

        self.payload = load_file_fk(EXECUTION_PATH)

    def patch_orchestrator(self, client_class):
        """Configure the patched orchestrator mock, if a patch function is set."""
        if self.orchestrator_patch_fn:
            self.orchestrator_patch_fn(client_class)

    def _relaunch_and_check(self, relaunch_suffix):
        """Create an execution, attach checks, relaunch it and verify the result.

        :param relaunch_suffix: appended to ``EXECUTION_URL + idx`` to form the
            relaunch URL (e.g. ``"/relaunch/"`` or ``"/relaunch/?run=0"``).
        """
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(self.url, self.model, payload=self.payload)

            # Add solution checks to see if they are deleted correctly
            token = self.create_service_user()
            self.update_row(
                url=DAG_URL + idx + "/",
                payload_to_check=dict(),
                change=dict(solution_schema="_data_checks", checks=dict(check_1=[])),
                token=token,
                check_payload=False,
            )

            self.payload["config"]["warmStart"] = False
            response = self.client.post(
                EXECUTION_URL + idx + relaunch_suffix,
                data=json.dumps({"config": self.payload["config"]}),
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )
            self.assertEqual(201, response.status_code)

            row = self.client.get(
                EXECUTION_URL + idx + "/data",
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            ).json

            # The relaunch stores the new config and the checks added above
            # must no longer be present.
            self.assertEqual(row["config"], self.payload["config"])
            self.assertIsNone(row["checks"])

    def test_relaunch_execution(self):
        """Relaunch with ``run=0``."""
        self._relaunch_and_check("/relaunch/?run=0")

    def test_relaunch_execution_run(self):
        """Relaunch without the ``run=0`` query argument."""
        self._relaunch_and_check("/relaunch/")

    def test_relaunch_invalid_execution(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = "thisIsAnInvalidExecutionId"
            url = EXECUTION_URL + idx + "/relaunch/?run=0"
            self.payload["config"]["warmStart"] = False
            response = self.client.post(
                url,
                data=json.dumps({"config": self.payload["config"]}),
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )
            self.assertEqual(404, response.status_code)
344
+
345
+
346
class BaseExecutionDetail(BaseTestCases.DetailEndpoint, ABC):
    """Orchestrator-agnostic tests for the single-execution endpoints.

    Concrete subclasses provide ``orchestrator_patch_target`` (dotted path
    handed to ``unittest.mock.patch``) and ``orchestrator_patch_fn`` (called
    with the resulting mock; skipped when falsy).
    """

    @property
    @abstractmethod
    def orchestrator_patch_target(self):
        """Dotted path of the orchestrator client to patch (subclass-defined)."""

    @property
    @abstractmethod
    def orchestrator_patch_fn(self):
        """Callable that configures the patched client (subclass-defined)."""

    def setUp(self):
        super().setUp()
        with open(INSTANCE_PATH) as f:
            payload = json.load(f)
        fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, payload)
        self.instance_payload = payload
        self.model = ExecutionModel
        self.response_items = {
            "id",
            "name",
            "description",
            "created_at",
            "instance_id",
            "data_hash",
            "message",
            "state",
            "config",
            "schema",
            "user_id",
            "indicators",
            "username",
            "first_name",
            "last_name",
            "updated_at",
        }
        # we only check the following because this endpoint does not return data
        self.items_to_check = ["name", "description"]
        self.url = EXECUTION_URL
        with open(EXECUTION_PATH) as f:
            self.payload = json.load(f)
        self.payload["instance_id"] = fk_id
        # Presumably consumed by url_with_query_arguments() to append "?run=0".
        self.query_arguments = {"run": 0}

    def patch_orchestrator(self, client_class):
        """Configure the patched orchestrator mock, if a patch function is set."""
        if self.orchestrator_patch_fn:
            self.orchestrator_patch_fn(client_class)

    def test_incomplete_payload2(self):
        payload = {"description": "arg", "instance_id": self.payload["instance_id"]}
        # create_new_row asserts the 400 status itself.
        self.create_new_row(
            self.url + "?run=0",
            self.model,
            payload,
            expected_status=400,
            check_payload=False,
        )

    def test_create_delete_instance_load(self):
        idx = self.create_new_row(self.url + "?run=0", self.model, self.payload)
        execution_keys = [
            "message",
            "id",
            "schema",
            "data_hash",
            "config",
            "instance_id",
            "user_id",
            "indicators",
            "description",
            "name",
            "created_at",
            "state",
            "username",
            "first_name",
            "last_name",
            "updated_at",
        ]
        execution = self.get_one_row(
            self.url + idx,
            payload={**self.payload, **dict(id=idx)},
            keys_to_check=execution_keys,
        )
        self.delete_row(self.url + idx + "/")
        instance_keys = [
            "id",
            "schema",
            "description",
            "name",
            "user_id",
            "executions",
            "created_at",
            "data_hash",
        ]
        instance = self.get_one_row(
            INSTANCE_URL + execution["instance_id"] + "/",
            payload={},
            expected_status=200,
            check_payload=False,
            keys_to_check=instance_keys,
        )
        # The deleted execution must no longer be listed under its instance.
        execution_ids = [item["id"] for item in instance["executions"]]
        self.assertNotIn(idx, execution_ids)

    def test_delete_instance_deletes_execution(self):
        # this test should be agnostic of the orchestrator
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            # we create a new instance
            with open(INSTANCE_PATH) as f:
                instance_data = json.load(f)
            fk_id = self.create_new_row(INSTANCE_URL, InstanceModel, instance_data)
            execution_payload = {**self.payload, **dict(instance_id=fk_id)}
            # we create an execution for that instance
            idx = self.create_new_row(self.url + "?run=0", self.model, execution_payload)
            self.get_one_row(self.url + idx, payload={**self.payload, **dict(id=idx)})
            # we delete the new instance
            self.delete_row(INSTANCE_URL + fk_id + "/")
            # we check the execution does not exist
            self.get_one_row(
                self.url + idx, payload={}, expected_status=404, check_payload=False
            )

    def test_update_one_row_data(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(
                self.url_with_query_arguments(), self.model, self.payload
            )
            with open(INSTANCE_PATH) as f:
                instance_data = json.load(f)
            instance_data["data"]["parameters"]["name"] = "NewName"
            new_data = instance_data["data"]

            url = self.url + str(idx) + "/"
            expected = {
                **self.payload,
                **dict(id=idx, name="new_name", data=new_data),
            }
            self.update_row(
                url,
                dict(name="new_name", data=new_data),
                expected,
            )

            url += "data/"
            row = self.client.get(
                url,
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )

            # Editing the data must leave no stale checks behind.
            self.assertIsNone(row.json["checks"])

    def test_stop_execution(self):
        # TODO: stopping executions is not implemented for Databricks yet;
        # skip before creating any patch or rows for that orchestrator.
        if self.orchestrator_patch_target == "cornflow.endpoints.execution.Databricks":
            self.skipTest("This feature is not implemented for databricks")

        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)

            response = self.client.post(
                self.url + str(idx) + "/",
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )

            self.assertEqual(200, response.status_code)
            self.assertEqual(response.json["message"], "The execution has been stopped")

    def test_edit_execution(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            id_new_instance = self.create_new_row(
                INSTANCE_URL, InstanceModel, self.instance_payload
            )
            idx = self.create_new_row(
                self.url_with_query_arguments(), self.model, self.payload
            )

            # Extract the data from data/edit_execution_solution.json
            with open(EDIT_EXECUTION_SOLUTION) as f:
                solution_data = json.load(f)

            data = {
                "name": "new_name",
                "description": "Updated description",
                "data": solution_data,
                "instance_id": id_new_instance,
            }
            # NOTE(review): instance_id here is a hard-coded hash rather than
            # id_new_instance; confirm this is the intended expectation.
            payload_to_check = {
                "id": idx,
                "name": "new_name",
                "description": "Updated description",
                "data_hash": "74234e98afe7498fb5daf1f36ac2d78acc339464f950703b8c019892f982b90b",
                "instance_id": "805bad3280c95e45384dc6bd91a41317f9a7858c",
            }
            self.update_row(
                self.url + str(idx) + "/",
                data,
                payload_to_check,
            )

    def test_get_one_status(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
            payload = dict(self.payload)
            payload["id"] = idx
            keys_to_check = ["state", "message", "id", "data_hash"]
            data = self.get_one_row(
                EXECUTION_URL + idx + "/status/",
                payload,
                check_payload=False,
                keys_to_check=keys_to_check,
            )
            # In the patch we assign success as the state
            self.assertEqual(data["state"], 1)

    def test_put_one_status(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(EXECUTION_URL, self.model, payload=self.payload)
            response = self.client.put(
                EXECUTION_URL + idx + "/status/",
                data=json.dumps({"status": 0}),
                follow_redirects=True,
                headers=self.get_header_with_auth(self.token),
            )

            self.assertEqual(200, response.status_code)
            self.assertEqual(
                f"execution {idx} updated correctly", response.json["message"]
            )
590
+
591
+
592
class BaseExecutionData(TestExecutionsDetailEndpointMock, ABC):
    """Tests for the ``/execution/<id>/data/`` endpoint.

    Concrete subclasses must set ``orchestrator_patch_target`` and
    ``orchestrator_patch_fn``; the ``None`` defaults below are placeholders
    (``None`` is not a valid ``unittest.mock.patch`` target).
    """

    # e.g. "cornflow.endpoints.execution.Airflow"
    orchestrator_patch_target = None
    # e.g. patch_af_client.
    # NOTE(review): a plain function assigned as a class attribute becomes a
    # bound method when read through ``self``; wrap it with staticmethod() or
    # expose it through a property.
    orchestrator_patch_fn = None

    def setUp(self):
        super().setUp()
        self.response_items = {
            "id",
            "name",
            "description",
            "created_at",
            "user_id",
            "data_hash",
            "schema",
            "config",
            "instance_id",
            "state",
            "message",
            "indicators",
            "updated_at",
            "username",
        }
        self.items_to_check = ["name"]
        self.keys_to_check = [
            "created_at",
            "checks",
            "instance_id",
            "schema",
            "data",
            "user_id",
            "message",
            "data_hash",
            "log",
            "config",
            "description",
            "state",
            "name",
            "id",
        ]

    def patch_orchestrator(self, client_class):
        """Patch the orchestrator client for testing"""
        if self.orchestrator_patch_fn:
            self.orchestrator_patch_fn(client_class)

    def test_get_one_execution(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(EXECUTION_URL_NORUN, self.model, self.payload)
            # Use a local URL instead of overwriting the fixture's self.url.
            data_url = EXECUTION_URL + idx + "/data/"
            payload = dict(self.payload)
            payload["id"] = idx
            self.get_one_row(data_url, payload, keys_to_check=self.keys_to_check)

    def test_get_one_execution_superadmin(self):
        with patch(self.orchestrator_patch_target) as client:
            self.patch_orchestrator(client)
            idx = self.create_new_row(EXECUTION_URL_NORUN, self.model, self.payload)
            payload = dict(self.payload)
            payload["id"] = idx
            token = self.create_service_user()
            self.get_one_row(
                EXECUTION_URL + idx + "/data/",
                payload,
                token=token,
                keys_to_check=self.keys_to_check,
            )
661
+
662
+
663
class BaseExecutionLog(BaseExecutionDetail, ABC):
    """Tests for the ``/execution/<id>/log/`` endpoint.

    Inherits every ``BaseExecutionDetail`` test; subclasses must replace the
    ``None`` placeholders for the orchestrator patch attributes.
    """

    # e.g. "cornflow.endpoints.execution.Airflow"
    orchestrator_patch_target = None
    # e.g. patch_af_client
    orchestrator_patch_fn = None

    def setUp(self):
        super().setUp()
        # Fields returned by the log endpoint specifically (not referenced by
        # the tests in this class).
        self.log_response_items = {
            "id",
            "name",
            "description",
            "created_at",
            "updated_at",
            "user_id",
            "username",
            "first_name",
            "last_name",
            "data_hash",
            "schema",
            "config",
            "instance_id",
            "state",
            "message",
            "indicators",
            "log",
            "log_text",
        }
        self.items_to_check = ["name"]
        self.keys_to_check = [
            "created_at",
            "id",
            "log_text",
            "instance_id",
            "state",
            "message",
            "description",
            "data_hash",
            "name",
            "log",
            "schema",
            "user_id",
            "config",
            "indicators",
            "username",
            "first_name",
            "last_name",
            "updated_at",
        ]

    def _create_log_execution(self):
        """Create an execution without running it; return ``(idx, expected)``."""
        execution_id = self.create_new_row(EXECUTION_URL_NORUN, self.model, self.payload)
        return execution_id, {**self.payload, "id": execution_id}

    def test_get_one_execution(self):
        with patch(self.orchestrator_patch_target) as mocked_client:
            self.patch_orchestrator(mocked_client)
            execution_id, expected = self._create_log_execution()
            log_url = EXECUTION_URL + execution_id + "/log/"
            self.get_one_row(log_url, expected, keys_to_check=self.keys_to_check)

    def test_get_one_execution_superadmin(self):
        with patch(self.orchestrator_patch_target) as mocked_client:
            self.patch_orchestrator(mocked_client)
            execution_id, expected = self._create_log_execution()
            service_token = self.create_service_user()
            log_url = EXECUTION_URL + execution_id + "/log/"
            self.get_one_row(
                log_url,
                expected,
                token=service_token,
                keys_to_check=self.keys_to_check,
            )
737
+
738
+
739
class BaseExecutionModel(BaseExecutionDetail, ABC):
    """``__repr__`` / ``__str__`` checks for stored executions.

    Subclasses must replace the ``None`` orchestrator patch placeholders.
    """

    # e.g. "cornflow.endpoints.execution.Airflow"
    orchestrator_patch_target = None
    # e.g. patch_af_client
    orchestrator_patch_fn = None

    def test_repr_method(self):
        with patch(self.orchestrator_patch_target) as mocked_client:
            self.patch_orchestrator(mocked_client)
            execution_id = self.create_new_row(
                self.url + "?run=0", self.model, self.payload
            )
            expected_text = f"<Execution {execution_id}>"
            self.repr_method(execution_id, expected_text)

    def test_str_method(self):
        with patch(self.orchestrator_patch_target) as mocked_client:
            self.patch_orchestrator(mocked_client)
            execution_id = self.create_new_row(
                self.url + "?run=0", self.model, self.payload
            )
            expected_text = f"<Execution {execution_id}>"
            self.str_method(execution_id, expected_text)
756
+
757
+
758
class BaseExecutionStatus(BaseExecutionDetail, ABC):
    """Status endpoint tests (``/execution/<id>/status/``).

    ``test_get_one_status`` and ``test_put_one_status`` are inherited unchanged
    from ``BaseExecutionDetail``. They were previously re-declared here
    verbatim, along with a ``setUp`` that only called ``super()``. Subclasses
    must replace the ``None`` orchestrator patch placeholders.
    """

    # e.g. "cornflow.endpoints.execution.Airflow"
    orchestrator_patch_target = None
    # e.g. patch_af_client
    orchestrator_patch_fn = None