arkindex-base-worker 0.3.7rc5__py3-none-any.whl → 0.5.0a1__py3-none-any.whl

This diff compares the published contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (60)
  1. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
  2. arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +1 -1
  6. arkindex_worker/image.py +167 -2
  7. arkindex_worker/models.py +18 -0
  8. arkindex_worker/utils.py +98 -4
  9. arkindex_worker/worker/__init__.py +117 -218
  10. arkindex_worker/worker/base.py +39 -46
  11. arkindex_worker/worker/classification.py +34 -18
  12. arkindex_worker/worker/corpus.py +86 -0
  13. arkindex_worker/worker/dataset.py +89 -26
  14. arkindex_worker/worker/element.py +352 -91
  15. arkindex_worker/worker/entity.py +13 -11
  16. arkindex_worker/worker/image.py +21 -0
  17. arkindex_worker/worker/metadata.py +26 -16
  18. arkindex_worker/worker/process.py +92 -0
  19. arkindex_worker/worker/task.py +5 -4
  20. arkindex_worker/worker/training.py +25 -10
  21. arkindex_worker/worker/transcription.py +89 -68
  22. arkindex_worker/worker/version.py +3 -1
  23. hooks/pre_gen_project.py +3 -0
  24. tests/__init__.py +8 -0
  25. tests/conftest.py +47 -58
  26. tests/test_base_worker.py +212 -12
  27. tests/test_dataset_worker.py +294 -437
  28. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  29. tests/test_elements_worker/test_cli.py +3 -11
  30. tests/test_elements_worker/test_corpus.py +168 -0
  31. tests/test_elements_worker/test_dataset.py +106 -157
  32. tests/test_elements_worker/test_element.py +427 -0
  33. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  34. tests/test_elements_worker/test_element_create_single.py +528 -0
  35. tests/test_elements_worker/test_element_list_children.py +969 -0
  36. tests/test_elements_worker/test_element_list_parents.py +530 -0
  37. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  38. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  39. tests/test_elements_worker/test_image.py +66 -0
  40. tests/test_elements_worker/test_metadata.py +252 -161
  41. tests/test_elements_worker/test_process.py +89 -0
  42. tests/test_elements_worker/test_task.py +8 -18
  43. tests/test_elements_worker/test_training.py +17 -8
  44. tests/test_elements_worker/test_transcription_create.py +873 -0
  45. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  46. tests/test_elements_worker/test_transcription_list.py +450 -0
  47. tests/test_elements_worker/test_version.py +60 -0
  48. tests/test_elements_worker/test_worker.py +578 -293
  49. tests/test_image.py +542 -209
  50. tests/test_merge.py +1 -2
  51. tests/test_utils.py +89 -4
  52. worker-demo/tests/__init__.py +0 -0
  53. worker-demo/tests/conftest.py +32 -0
  54. worker-demo/tests/test_worker.py +12 -0
  55. worker-demo/worker_demo/__init__.py +6 -0
  56. worker-demo/worker_demo/worker.py +19 -0
  57. arkindex_base_worker-0.3.7rc5.dist-info/RECORD +0 -41
  58. tests/test_elements_worker/test_elements.py +0 -2713
  59. tests/test_elements_worker/test_transcriptions.py +0 -2119
  60. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
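The hunks reproduced in detail below come, by their content, from the dataset worker test suite (entry 27 above, tests/test_dataset_worker.py) and track this release's central change: DatasetWorker no longer iterates over whole datasets but over individual dataset sets (train, val, ...). In dev mode, sets are selected on the command line as "<dataset_uuid>:<set_name>" values parsed by the new check_dataset_set helper. A hypothetical sketch of that wiring — the flag name and parser layout are assumptions; only check_dataset_set itself appears in this diff:

    import argparse

    from arkindex_worker.worker.dataset import check_dataset_set  # added in 0.5.0a1

    # Hypothetical dev-mode argument wiring; the real wiring lives in
    # arkindex_worker/worker/__init__.py, which is not reproduced in this diff.
    parser = argparse.ArgumentParser(description="Run a worker on explicit dataset sets")
    parser.add_argument(
        "--set",
        type=check_dataset_set,  # "<dataset_uuid>:<set_name>" -> (UUID, set_name)
        nargs="+",
        help="Dataset set to process, e.g. 11111111-1111-1111-1111-111111111111:train",
    )

    args = parser.parse_args(["--set", "11111111-1111-1111-1111-111111111111:train"])
    print(args.set)  # [(UUID('11111111-1111-1111-1111-111111111111'), 'train')]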
@@ -1,13 +1,62 @@
 import logging
+import uuid
+from argparse import ArgumentTypeError
 
 import pytest
-from apistar.exceptions import ErrorResponse
 
-from arkindex_worker.worker import MissingDatasetArchive
-from arkindex_worker.worker.dataset import DatasetState
-from tests.conftest import FIXTURES_DIR, PROCESS_ID
+from arkindex.exceptions import ErrorResponse
+from arkindex_worker.models import Dataset, Set
+from arkindex_worker.worker.dataset import (
+    DatasetState,
+    MissingDatasetArchive,
+    check_dataset_set,
+)
+from tests import FIXTURES_DIR, PROCESS_ID
 from tests.test_elements_worker import BASE_API_CALLS
 
+RANDOM_UUID = uuid.uuid4()
+
+
+@pytest.fixture
+def tmp_archive(tmp_path):
+    archive = tmp_path / "test_archive.tar.zst"
+    archive.touch()
+
+    yield archive
+
+    archive.unlink(missing_ok=True)
+
+
+@pytest.mark.parametrize(
+    ("value", "error"),
+    [("train", ""), (f"{RANDOM_UUID}:train:val", ""), ("not_uuid:train", "")],
+)
+def test_check_dataset_set_errors(value, error):
+    with pytest.raises(ArgumentTypeError, match=error):
+        check_dataset_set(value)
+
+
+def test_check_dataset_set():
+    assert check_dataset_set(f"{RANDOM_UUID}:train") == (RANDOM_UUID, "train")
+
+
+def test_cleanup_downloaded_artifact_no_download(mock_dataset_worker):
+    assert not mock_dataset_worker.downloaded_dataset_artifact
+    # Do nothing
+    mock_dataset_worker.cleanup_downloaded_artifact()
+
+
+def test_cleanup_downloaded_artifact(mock_dataset_worker, tmp_archive):
+    mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+
+    assert mock_dataset_worker.downloaded_dataset_artifact.exists()
+    # Unlink the downloaded archive
+    mock_dataset_worker.cleanup_downloaded_artifact()
+    assert not mock_dataset_worker.downloaded_dataset_artifact.exists()
+
+    # Unlinking again does not raise an error even if the archive no longer exists
+    mock_dataset_worker.cleanup_downloaded_artifact()
+
 
 def test_download_dataset_artifact_list_api_error(
     responses, mock_dataset_worker, default_dataset
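The new tests above pin down the contract of check_dataset_set: a valid value is exactly one UUID and one set name joined by a single colon, and anything else raises argparse.ArgumentTypeError. The implementation itself is not part of this hunk; a minimal sketch consistent with these assertions:

    from argparse import ArgumentTypeError
    from uuid import UUID


    def check_dataset_set(value: str) -> tuple[UUID, str]:
        # Expect exactly "<dataset_uuid>:<set_name>", i.e. a single colon
        parts = value.split(":")
        if len(parts) != 2:
            raise ArgumentTypeError(
                f"'{value}' is not in the expected '<dataset_uuid>:<set_name>' format"
            )
        dataset_id, set_name = parts
        try:
            return UUID(dataset_id), set_name
        except ValueError as e:
            raise ArgumentTypeError(f"'{dataset_id}' is not a valid UUID") from e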
@@ -17,22 +66,17 @@ def test_download_dataset_artifact_list_api_error(
     responses.add(
         responses.GET,
         f"http://testserver/api/v1/task/{task_id}/artifacts/",
-        status=500,
+        status=418,
     )
 
     with pytest.raises(ErrorResponse):
        mock_dataset_worker.download_dataset_artifact(default_dataset)
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 5
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/")
     ]
 
 
@@ -70,22 +114,17 @@ def test_download_dataset_artifact_download_api_error(
     responses.add(
         responses.GET,
         f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst",
-        status=500,
+        status=418,
     )
 
     with pytest.raises(ErrorResponse):
         mock_dataset_worker.download_dataset_artifact(default_dataset)
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 6
+    assert len(responses.calls) == len(BASE_API_CALLS) + 2
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
-        # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
-        ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
         ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
     ]
 
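Both error tests above switch the mocked status from 500 to 418 and drop the repeated-call assertions: the 0.3.7rc5 suite expected five attempts because the client retried server errors with exponential backoff (3, 4, 8, 16 seconds, per the log assertions removed further down), while a 418 client error now fails on the first call. A sketch of such a status-based retry policy using tenacity — the names and parameters here are illustrative, not the library's actual code:

    from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


    def _is_server_error(exc: BaseException) -> bool:
        # Retry only HTTP 5xx responses; 4xx errors like 418 are raised immediately
        return getattr(exc, "status_code", 0) >= 500


    @retry(
        retry=retry_if_exception(_is_server_error),
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=2, min=3),  # waits 3, 4, 8, 16 seconds
        reraise=True,
    )
    def request(client, operation, **kwargs):
        # Single API call; tenacity re-invokes it on matching failures
        return client.request(operation, **kwargs)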
@@ -127,8 +166,15 @@ def test_download_dataset_artifact_no_archive(
     ]
 
 
+@pytest.mark.parametrize("downloaded_cache", [False, True])
 def test_download_dataset_artifact(
-    mocker, tmp_path, responses, mock_dataset_worker, default_dataset
+    mocker,
+    tmp_path,
+    responses,
+    mock_dataset_worker,
+    default_dataset,
+    downloaded_cache,
+    tmp_archive,
 ):
     task_id = default_dataset.task_id
     archive_path = (
@@ -176,10 +222,25 @@ def test_download_dataset_artifact(
         content_type="application/zstd",
     )
 
-    archive = mock_dataset_worker.download_dataset_artifact(default_dataset)
-    assert archive == tmp_path / "dataset_id.tar.zst"
-    assert archive.read_bytes() == archive_path.read_bytes()
-    archive.unlink()
+    if downloaded_cache:
+        mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+    previous_artifact = mock_dataset_worker.downloaded_dataset_artifact
+
+    mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+    # We removed the artifact that was downloaded previously
+    if previous_artifact:
+        assert not previous_artifact.exists()
+
+    assert (
+        mock_dataset_worker.downloaded_dataset_artifact
+        == tmp_path / "dataset_id.tar.zst"
+    )
+    assert (
+        mock_dataset_worker.downloaded_dataset_artifact.read_bytes()
+        == archive_path.read_bytes()
+    )
+    mock_dataset_worker.downloaded_dataset_artifact.unlink()
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 2
     assert [
@@ -190,189 +251,107 @@ def test_download_dataset_artifact(
     ]
 
 
-def test_list_dataset_elements_per_split_api_error(
-    responses, mock_dataset_worker, default_dataset
+def test_download_dataset_artifact_already_exists(
+    mocker, tmp_path, responses, mock_dataset_worker, default_dataset
 ):
-    responses.add(
-        responses.GET,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        status=500,
+    mocker.patch(
+        "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
+        return_value=tmp_path,
     )
+    already_downloaded = tmp_path / "dataset_id.tar.zst"
+    already_downloaded.write_bytes(b"Some content")
+    mock_dataset_worker.downloaded_dataset_artifact = already_downloaded
 
-    with pytest.raises(
-        Exception, match="Stopping pagination as data will be incomplete"
-    ):
-        mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-
-    assert len(responses.calls) == len(BASE_API_CALLS) + 5
-    assert [
-        (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        # The API call is retried 5 times
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-    ]
-
+    mock_dataset_worker.download_dataset_artifact(default_dataset)
 
-def test_list_dataset_elements_per_split(
-    responses, mock_dataset_worker, default_dataset
-):
-    expected_results = []
-    for selected_set in default_dataset.selected_sets:
-        index = selected_set[-1]
-        expected_results.append(
-            {
-                "set": selected_set,
-                "element": {
-                    "id": str(index) * 4,
-                    "type": "page",
-                    "name": f"Test {index}",
-                    "corpus": {},
-                    "thumbnail_url": None,
-                    "zone": {},
-                    "best_classes": None,
-                    "has_children": None,
-                    "worker_version_id": None,
-                    "worker_run_id": None,
-                },
-            }
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
-            status=200,
-            json={
-                "count": 1,
-                "next": None,
-                "results": [expected_results[-1]],
-            },
-        )
-
-    assert list(
-        mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-    ) == [
-        ("set_1", [expected_results[0]["element"]]),
-        ("set_2", [expected_results[1]["element"]]),
-        ("set_3", [expected_results[2]["element"]]),
-    ]
+    assert mock_dataset_worker.downloaded_dataset_artifact == already_downloaded
+    already_downloaded.unlink()
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 3
+    assert len(responses.calls) == len(BASE_API_CALLS)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_2&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_3&with_count=true",
-        ),
-    ]
+    ] == BASE_API_CALLS
 
 
-def test_list_datasets_read_only(mock_dev_dataset_worker):
-    assert list(mock_dev_dataset_worker.list_datasets()) == [
-        "11111111-1111-1111-1111-111111111111",
-        "22222222-2222-2222-2222-222222222222",
-    ]
-
-
-def test_list_datasets_api_error(responses, mock_dataset_worker):
+def test_list_sets_api_error(responses, mock_dataset_worker):
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
-        status=500,
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
+        status=418,
     )
 
     with pytest.raises(
         Exception, match="Stopping pagination as data will be incomplete"
     ):
-        next(mock_dataset_worker.list_datasets())
+        next(mock_dataset_worker.list_sets())
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 5
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]
 
 
-def test_list_datasets(responses, mock_dataset_worker):
+def test_list_sets(responses, mock_dataset_worker):
     expected_results = [
         {
-            "id": "process_dataset_1",
+            "id": "set_1",
             "dataset": {
                 "id": "dataset_1",
                 "name": "Dataset 1",
                 "description": "My first great dataset",
-                "sets": ["train", "val", "test"],
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
                 "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
                 "task_id": "task_id_1",
             },
-            "sets": ["test"],
+            "set_name": "train",
         },
         {
-            "id": "process_dataset_2",
+            "id": "set_2",
             "dataset": {
-                "id": "dataset_2",
-                "name": "Dataset 2",
-                "description": "My second great dataset",
-                "sets": ["train", "val"],
-                "state": "complete",
+                "id": "dataset_1",
+                "name": "Dataset 1",
+                "description": "My first great dataset",
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
+                "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_2",
+                "task_id": "task_id_1",
             },
-            "sets": ["train", "val"],
+            "set_name": "val",
         },
         {
-            "id": "process_dataset_3",
+            "id": "set_3",
             "dataset": {
-                "id": "dataset_3",
-                "name": "Dataset 3 (TRASHME)",
-                "description": "My third dataset, in error",
-                "sets": ["nonsense", "random set"],
-                "state": "error",
+                "id": "dataset_2",
+                "name": "Dataset 2",
+                "description": "My second great dataset",
+                "sets": [{"id": "set_3", "name": "my_set"}],
+                "state": "complete",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_3",
+                "task_id": "task_id_2",
             },
-            "sets": ["random set"],
+            "set_name": "my_set",
         },
     ]
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
         status=200,
         json={
             "count": 3,
@@ -381,154 +360,104 @@ def test_list_datasets(responses, mock_dataset_worker):
         },
     )
 
-    for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
-        assert dataset == {
-            **expected_results[idx]["dataset"],
-            "selected_sets": expected_results[idx]["sets"],
-        }
+    for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
+        assert isinstance(dataset_set, Set)
+        assert dataset_set.name == expected_results[idx]["set_name"]
+
+        assert isinstance(dataset_set.dataset, Dataset)
+        assert dataset_set.dataset == expected_results[idx]["dataset"]
 
     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]
 
 
-@pytest.mark.parametrize("generator", [True, False])
-def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
-    mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
-    mock_dataset_worker.generator = generator
+def test_list_sets_retrieve_dataset_api_error(
+    responses, mock_dev_dataset_worker, default_dataset
+):
+    mock_dev_dataset_worker.args.set = [
+        (default_dataset.id, "train"),
+        (default_dataset.id, "val"),
+    ]
 
-    with pytest.raises(SystemExit):
-        mock_dataset_worker.run()
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        status=418,
+    )
 
-    assert [(level, message) for _, level, message in caplog.record_tuples] == [
-        (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.WARNING, "No datasets to process, stopping."),
+    with pytest.raises(ErrorResponse):
+        next(mock_dev_dataset_worker.list_sets())
+
+    assert len(responses.calls) == 1
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
     ]
 
 
-@pytest.mark.parametrize(
-    ("generator", "error"),
-    [
-        (True, "When generating a new dataset, its state should be Open or Error."),
-        (False, "When processing an existing dataset, its state should be Complete."),
-    ],
-)
-def test_run_initial_dataset_state_error(
-    mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, error
-):
-    default_dataset.state = DatasetState.Building.value
-    mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
-    )
-    mock_dataset_worker.generator = generator
-
-    extra_call = []
-    if generator:
-        responses.add(
-            responses.PATCH,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-            status=200,
-            json={},
-        )
-        extra_call = [
-            ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ]
+def test_list_sets_read_only(responses, mock_dev_dataset_worker, default_dataset):
+    mock_dev_dataset_worker.args.set = [
+        (default_dataset.id, "train"),
+        (default_dataset.id, "val"),
+    ]
+
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        status=200,
+        json=default_dataset,
+    )
+
+    assert list(mock_dev_dataset_worker.list_sets()) == [
+        Set(name="train", dataset=default_dataset),
+        Set(name="val", dataset=default_dataset),
+    ]
+
+    assert len(responses.calls) == 1
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+    ]
+
+
+def test_run_no_sets(mocker, caplog, mock_dataset_worker):
+    mocker.patch("arkindex_worker.worker.DatasetWorker.list_sets", return_value=[])
 
     with pytest.raises(SystemExit):
         mock_dataset_worker.run()
 
-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_call)
-    assert [
-        (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + extra_call
-
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (
-            logging.WARNING,
-            f"Failed running worker on dataset dataset_id: AssertionError('{error}')",
-        ),
-    ] + (
-        [
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            )
-        ]
-        if generator
-        else []
-    ) + [
-        (logging.ERROR, "Ran on 1 dataset: 0 completed, 1 failed"),
+        (logging.WARNING, "No sets to process, stopping."),
     ]
 
 
-def test_run_update_dataset_state_api_error(
+def test_run_initial_dataset_state_error(
     mocker, responses, caplog, mock_dataset_worker, default_dataset
 ):
+    default_dataset.state = DatasetState.Building.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
-    )
-    mock_dataset_worker.generator = True
-
-    responses.add(
-        responses.PATCH,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-        status=500,
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
 
     with pytest.raises(SystemExit):
         mock_dataset_worker.run()
 
-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 10
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + [
-        # We retry 5 times the API call to update the Dataset as Building
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        # We retry 5 times the API call to update the Dataset as in Error
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-    ]
+    ] == BASE_API_CALLS * 2
 
-    retries = [3.0, 4.0, 8.0, 16.0]
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-        *[
-            (
-                logging.INFO,
-                f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-            )
-            for retry in retries
-        ],
         (
             logging.WARNING,
-            "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
-        ),
-        *[
-            (
-                logging.INFO,
-                f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-            )
-            for retry in retries
-        ],
-        (
-            logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
         ),
+        (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
     ]
 
 
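The Set model imported at the top of the file is new in arkindex_worker.models (entry 7 lists +18 lines there, not reproduced here). These tests only require that it carry a name and its parent Dataset and compare by value, so a dataclass-style sketch is enough to read them — the real definition may differ:

    from dataclasses import dataclass

    from arkindex_worker.models import Dataset  # existing model class


    @dataclass
    class Set:
        # A single dataset split such as "train" or "val", tied to its Dataset
        name: str
        dataset: Dataset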
@@ -541,10 +470,9 @@ def test_run_download_dataset_artifact_api_error(
     default_dataset,
 ):
     default_dataset.state = DatasetState.Complete.value
-
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -554,47 +482,38 @@
     responses.add(
         responses.GET,
         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-        status=500,
+        status=418,
     )
 
     with pytest.raises(SystemExit):
         mock_dataset_worker.run()
 
-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 5
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS * 2 + [
-        # We retry 5 times the API call
-        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
-        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
+        ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/")
     ]
 
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-        *[
-            (
-                logging.INFO,
-                f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-            )
-            for retry in [3.0, 4.0, 8.0, 16.0]
-        ],
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
         (
             logging.WARNING,
-            "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
+            "An API error occurred while processing Set (train) from Dataset (dataset_id): 418 I'm a Teapot - None",
         ),
         (
             logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Ran on 1 set: 0 completed, 1 failed",
         ),
     ]
 
 
-def test_run_no_downloaded_artifact_error(
+def test_run_no_downloaded_dataset_artifact_error(
     mocker,
     tmp_path,
     responses,
@@ -603,10 +522,9 @@ def test_run_no_downloaded_artifact_error(
     default_dataset,
 ):
     default_dataset.state = DatasetState.Complete.value
-
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -632,22 +550,22 @@
 
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
         (
             logging.WARNING,
-            "Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
+            "Failed running worker on Set (train) from Dataset (dataset_id): MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
         ),
         (
             logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Ran on 1 set: 0 completed, 1 failed",
         ),
     ]
 
 
-@pytest.mark.parametrize(
-    ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
-)
 def test_run(
     mocker,
     tmp_path,
@@ -656,100 +574,68 @@ def test_run(
     mock_dataset_worker,
     default_dataset,
     default_artifact,
-    generator,
-    state,
 ):
-    mock_dataset_worker.generator = generator
-    default_dataset.state = state.value
-
+    default_dataset.state = DatasetState.Complete.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
         return_value=tmp_path,
     )
-    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
-
-    extra_calls = []
-    extra_logs = []
-    if generator:
-        responses.add(
-            responses.PATCH,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-            status=200,
-            json={},
-        )
-        extra_calls += [
-            ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ] * 2
-        extra_logs += [
-            (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            ),
-            (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            ),
-        ]
-    else:
-        archive_path = (
-            FIXTURES_DIR
-            / "extract_parent_archives"
-            / "first_parent"
-            / "arkindex_data.tar.zst"
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            status=200,
-            json=[default_artifact],
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            status=200,
-            body=archive_path.read_bytes(),
-            content_type="application/zstd",
-        )
-        extra_calls += [
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            ),
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            ),
-        ]
-        extra_logs += [
-            (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-        ]
+    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")
+
+    archive_path = (
+        FIXTURES_DIR
+        / "extract_parent_archives"
+        / "first_parent"
+        / "arkindex_data.tar.zst"
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        status=200,
+        json=[default_artifact],
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        status=200,
+        body=archive_path.read_bytes(),
+        content_type="application/zstd",
+    )
 
     mock_dataset_worker.run()
 
     assert mock_process.call_count == 1
 
-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 2
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + extra_calls
+    ] == BASE_API_CALLS * 2 + [
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        ),
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        ),
+    ]
 
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        *extra_logs,
-        (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+        (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+        (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
     ]
 
 
-@pytest.mark.parametrize(
-    ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
-)
 def test_run_read_only(
     mocker,
     tmp_path,
@@ -758,90 +644,61 @@ def test_run_read_only(
     mock_dev_dataset_worker,
     default_dataset,
     default_artifact,
-    generator,
-    state,
 ):
-    mock_dev_dataset_worker.generator = generator
-    default_dataset.state = state.value
-
+    default_dataset.state = DatasetState.Complete.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset.id],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
         return_value=tmp_path,
     )
-    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
+    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")
 
+    archive_path = (
+        FIXTURES_DIR
+        / "extract_parent_archives"
+        / "first_parent"
+        / "arkindex_data.tar.zst"
+    )
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
         status=200,
-        json=default_dataset,
+        json=[default_artifact],
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        status=200,
+        body=archive_path.read_bytes(),
+        content_type="application/zstd",
     )
-
-    extra_calls = []
-    extra_logs = []
-    if generator:
-        extra_logs += [
-            (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "Cannot update dataset as this worker is in read-only mode",
-            ),
-            (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "Cannot update dataset as this worker is in read-only mode",
-            ),
-        ]
-    else:
-        archive_path = (
-            FIXTURES_DIR
-            / "extract_parent_archives"
-            / "first_parent"
-            / "arkindex_data.tar.zst"
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            status=200,
-            json=[default_artifact],
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            status=200,
-            body=archive_path.read_bytes(),
-            content_type="application/zstd",
-        )
-        extra_calls += [
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            ),
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            ),
-        ]
-        extra_logs += [
-            (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-        ]
 
     mock_dev_dataset_worker.run()
 
     assert mock_process.call_count == 1
 
-    assert len(responses.calls) == 1 + len(extra_calls)
+    assert len(responses.calls) == 2
     assert [(call.request.method, call.request.url) for call in responses.calls] == [
-        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
-    ] + extra_calls
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        ),
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        ),
+    ]
 
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.WARNING, "Running without any extra configuration"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        *extra_logs,
-        (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+        (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+        (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
     ]
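Every run() test now patches DatasetWorker.process_set, the per-set hook that replaces 0.3.x's process_dataset; by the time it is called, the set's dataset archive has already been downloaded. A hedged sketch of a concrete 0.5.0a1-style worker built on that apparent API — the class name and description are invented for illustration:

    from arkindex_worker import logger
    from arkindex_worker.models import Set
    from arkindex_worker.worker import DatasetWorker


    class DemoSetWorker(DatasetWorker):
        def process_set(self, dataset_set: Set) -> None:
            # Called once per selected set, after the dataset artifact download
            logger.info(
                f"Working on set {dataset_set.name} "
                f"from dataset {dataset_set.dataset.name}"
            )


    if __name__ == "__main__":
        DemoSetWorker(description="Demo set-level worker").run()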