arkindex-base-worker 0.3.6rc5__py3-none-any.whl → 0.3.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/METADATA +14 -16
  2. arkindex_base_worker-0.3.7.post1.dist-info/RECORD +47 -0
  3. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +14 -0
  6. arkindex_worker/image.py +29 -19
  7. arkindex_worker/models.py +14 -2
  8. arkindex_worker/utils.py +17 -3
  9. arkindex_worker/worker/__init__.py +122 -125
  10. arkindex_worker/worker/base.py +25 -45
  11. arkindex_worker/worker/classification.py +18 -25
  12. arkindex_worker/worker/dataset.py +24 -18
  13. arkindex_worker/worker/element.py +45 -6
  14. arkindex_worker/worker/entity.py +35 -4
  15. arkindex_worker/worker/metadata.py +21 -11
  16. arkindex_worker/worker/training.py +16 -0
  17. arkindex_worker/worker/transcription.py +45 -5
  18. arkindex_worker/worker/version.py +22 -0
  19. hooks/pre_gen_project.py +3 -0
  20. tests/conftest.py +15 -7
  21. tests/test_base_worker.py +0 -6
  22. tests/test_dataset_worker.py +292 -410
  23. tests/test_elements_worker/test_classifications.py +365 -539
  24. tests/test_elements_worker/test_cli.py +1 -1
  25. tests/test_elements_worker/test_dataset.py +97 -116
  26. tests/test_elements_worker/test_elements.py +227 -61
  27. tests/test_elements_worker/test_entities.py +22 -2
  28. tests/test_elements_worker/test_metadata.py +53 -27
  29. tests/test_elements_worker/test_training.py +35 -0
  30. tests/test_elements_worker/test_transcriptions.py +149 -16
  31. tests/test_elements_worker/test_worker.py +19 -6
  32. tests/test_image.py +37 -0
  33. tests/test_utils.py +23 -1
  34. worker-demo/tests/__init__.py +0 -0
  35. worker-demo/tests/conftest.py +32 -0
  36. worker-demo/tests/test_worker.py +12 -0
  37. worker-demo/worker_demo/__init__.py +6 -0
  38. worker-demo/worker_demo/worker.py +19 -0
  39. arkindex_base_worker-0.3.6rc5.dist-info/RECORD +0 -41
  40. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/LICENSE +0 -0
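Note: the bulk of this release is the DatasetWorker rework visible in the tests/test_dataset_worker.py diff below: workers now iterate over dataset sets (list_sets, list_process_sets, process_set) instead of whole datasets (list_datasets, process_dataset), a check_dataset_set validator parses the new "<dataset_uuid>:<set_name>" CLI argument, and the downloaded dataset artifact is now cached on the worker. As a reading aid, here is a minimal sketch of what check_dataset_set appears to do, inferred purely from the new tests; the actual implementation in arkindex_worker/worker/__init__.py may differ:

    import uuid
    from argparse import ArgumentTypeError


    def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
        # Expected CLI format: "<dataset_uuid>:<set_name>"
        try:
            dataset_id, set_name = value.split(":")
            return uuid.UUID(dataset_id), set_name
        except ValueError as e:
            raise ArgumentTypeError(
                f"{value!r} is not a valid '<dataset_uuid>:<set_name>' pair"
            ) from e

This matches the tests below: "train" (no colon), "<uuid>:train:val" (too many parts) and "not_uuid:train" (invalid UUID) all raise ArgumentTypeError, while "<uuid>:train" returns the (UUID, set name) tuple.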
tests/test_dataset_worker.py
@@ -1,13 +1,59 @@
  import logging
+ import uuid
+ from argparse import ArgumentTypeError

  import pytest
  from apistar.exceptions import ErrorResponse

- from arkindex_worker.worker import MissingDatasetArchive
+ from arkindex_worker.models import Dataset, Set
+ from arkindex_worker.worker import MissingDatasetArchive, check_dataset_set
  from arkindex_worker.worker.dataset import DatasetState
  from tests.conftest import FIXTURES_DIR, PROCESS_ID
  from tests.test_elements_worker import BASE_API_CALLS

+ RANDOM_UUID = uuid.uuid4()
+
+
+ @pytest.fixture()
+ def tmp_archive(tmp_path):
+     archive = tmp_path / "test_archive.tar.zst"
+     archive.touch()
+
+     yield archive
+
+     archive.unlink(missing_ok=True)
+
+
+ @pytest.mark.parametrize(
+     ("value", "error"),
+     [("train", ""), (f"{RANDOM_UUID}:train:val", ""), ("not_uuid:train", "")],
+ )
+ def test_check_dataset_set_errors(value, error):
+     with pytest.raises(ArgumentTypeError, match=error):
+         check_dataset_set(value)
+
+
+ def test_check_dataset_set():
+     assert check_dataset_set(f"{RANDOM_UUID}:train") == (RANDOM_UUID, "train")
+
+
+ def test_cleanup_downloaded_artifact_no_download(mock_dataset_worker):
+     assert not mock_dataset_worker.downloaded_dataset_artifact
+     # Do nothing
+     mock_dataset_worker.cleanup_downloaded_artifact()
+
+
+ def test_cleanup_downloaded_artifact(mock_dataset_worker, tmp_archive):
+     mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+
+     assert mock_dataset_worker.downloaded_dataset_artifact.exists()
+     # Unlink the downloaded archive
+     mock_dataset_worker.cleanup_downloaded_artifact()
+     assert not mock_dataset_worker.downloaded_dataset_artifact.exists()
+
+     # Unlinking again does not raise an error even if the archive no longer exists
+     mock_dataset_worker.cleanup_downloaded_artifact()
+

  def test_download_dataset_artifact_list_api_error(
      responses, mock_dataset_worker, default_dataset
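The two cleanup tests above pin down the artifact-caching contract: cleanup_downloaded_artifact is a no-op when nothing was downloaded, and can be called repeatedly once the archive is gone. A minimal sketch of that behaviour, assuming downloaded_dataset_artifact is an optional pathlib.Path (names taken from the tests, not from the verified implementation):

    from pathlib import Path
    from typing import Optional


    class ArtifactCacheSketch:
        # Path to the archive fetched by download_dataset_artifact, if any
        downloaded_dataset_artifact: Optional[Path] = None

        def cleanup_downloaded_artifact(self) -> None:
            # Nothing was downloaded yet: do nothing
            if not self.downloaded_dataset_artifact:
                return
            # missing_ok=True makes repeated calls safe once the file is gone
            self.downloaded_dataset_artifact.unlink(missing_ok=True)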
@@ -127,8 +173,15 @@ def test_download_dataset_artifact_no_archive(
      ]


+ @pytest.mark.parametrize("downloaded_cache", [False, True])
  def test_download_dataset_artifact(
-     mocker, tmp_path, responses, mock_dataset_worker, default_dataset
+     mocker,
+     tmp_path,
+     responses,
+     mock_dataset_worker,
+     default_dataset,
+     downloaded_cache,
+     tmp_archive,
  ):
      task_id = default_dataset.task_id
      archive_path = (
@@ -176,10 +229,25 @@ def test_download_dataset_artifact(
          content_type="application/zstd",
      )

-     archive = mock_dataset_worker.download_dataset_artifact(default_dataset)
-     assert archive == tmp_path / "dataset_id.tar.zst"
-     assert archive.read_bytes() == archive_path.read_bytes()
-     archive.unlink()
+     if downloaded_cache:
+         mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+     previous_artifact = mock_dataset_worker.downloaded_dataset_artifact
+
+     mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+     # We removed the artifact that was downloaded previously
+     if previous_artifact:
+         assert not previous_artifact.exists()
+
+     assert (
+         mock_dataset_worker.downloaded_dataset_artifact
+         == tmp_path / "dataset_id.tar.zst"
+     )
+     assert (
+         mock_dataset_worker.downloaded_dataset_artifact.read_bytes()
+         == archive_path.read_bytes()
+     )
+     mock_dataset_worker.downloaded_dataset_artifact.unlink()

      assert len(responses.calls) == len(BASE_API_CALLS) + 2
      assert [
@@ -190,356 +258,218 @@ def test_download_dataset_artifact(
      ]


- def test_list_dataset_elements_per_split_api_error(
-     responses, mock_dataset_worker, default_dataset
+ def test_download_dataset_artifact_already_exists(
+     mocker, tmp_path, responses, mock_dataset_worker, default_dataset
  ):
+     mocker.patch(
+         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
+         return_value=tmp_path,
+     )
+     already_downloaded = tmp_path / "dataset_id.tar.zst"
+     already_downloaded.write_bytes(b"Some content")
+     mock_dataset_worker.downloaded_dataset_artifact = already_downloaded
+
+     mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+     assert mock_dataset_worker.downloaded_dataset_artifact == already_downloaded
+     already_downloaded.unlink()
+
+     assert len(responses.calls) == len(BASE_API_CALLS)
+     assert [
+         (call.request.method, call.request.url) for call in responses.calls
+     ] == BASE_API_CALLS
+
+
+ def test_list_sets_api_error(responses, mock_dataset_worker):
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
+         f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
          status=500,
      )

      with pytest.raises(
          Exception, match="Stopping pagination as data will be incomplete"
      ):
-         mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
+         next(mock_dataset_worker.list_sets())

      assert len(responses.calls) == len(BASE_API_CALLS) + 5
      assert [
          (call.request.method, call.request.url) for call in responses.calls
      ] == BASE_API_CALLS + [
          # The API call is retried 5 times
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
      ]


- def test_list_dataset_elements_per_split(
-     responses, mock_dataset_worker, default_dataset
- ):
+ def test_list_sets(responses, mock_dataset_worker):
      expected_results = [
          {
-             "set": "set_1",
-             "element": {
-                 "id": "0000",
-                 "type": "page",
-                 "name": "Test",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
-             },
-         },
-         {
-             "set": "set_1",
-             "element": {
-                 "id": "1111",
-                 "type": "page",
-                 "name": "Test 2",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
+             "id": "set_1",
+             "dataset": {
+                 "id": "dataset_1",
+                 "name": "Dataset 1",
+                 "description": "My first great dataset",
+                 "sets": [
+                     {"id": "set_1", "name": "train"},
+                     {"id": "set_2", "name": "val"},
+                 ],
+                 "state": "open",
+                 "corpus_id": "corpus_id",
+                 "creator": "test@teklia.com",
+                 "task_id": "task_id_1",
              },
+             "set_name": "train",
          },
          {
-             "set": "set_2",
-             "element": {
-                 "id": "2222",
-                 "type": "page",
-                 "name": "Test 3",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
+             "id": "set_2",
+             "dataset": {
+                 "id": "dataset_1",
+                 "name": "Dataset 1",
+                 "description": "My first great dataset",
+                 "sets": [
+                     {"id": "set_1", "name": "train"},
+                     {"id": "set_2", "name": "val"},
+                 ],
+                 "state": "open",
+                 "corpus_id": "corpus_id",
+                 "creator": "test@teklia.com",
+                 "task_id": "task_id_1",
              },
+             "set_name": "val",
          },
          {
-             "set": "set_3",
-             "element": {
-                 "id": "3333",
-                 "type": "page",
-                 "name": "Test 4",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
+             "id": "set_3",
+             "dataset": {
+                 "id": "dataset_2",
+                 "name": "Dataset 2",
+                 "description": "My second great dataset",
+                 "sets": [{"id": "set_3", "name": "my_set"}],
+                 "state": "complete",
+                 "corpus_id": "corpus_id",
+                 "creator": "test@teklia.com",
+                 "task_id": "task_id_2",
              },
+             "set_name": "my_set",
          },
      ]
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
+         f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
          status=200,
          json={
-             "count": 4,
+             "count": 3,
              "next": None,
              "results": expected_results,
          },
      )

-     assert list(
-         mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-     ) == [
-         ("set_1", [expected_results[0]["element"], expected_results[1]["element"]]),
-         ("set_2", [expected_results[2]["element"]]),
-         ("set_3", [expected_results[3]["element"]]),
-     ]
+     for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
+         assert isinstance(dataset_set, Set)
+         assert dataset_set.name == expected_results[idx]["set_name"]
+
+         assert isinstance(dataset_set.dataset, Dataset)
+         assert dataset_set.dataset == expected_results[idx]["dataset"]

      assert len(responses.calls) == len(BASE_API_CALLS) + 1
      assert [
          (call.request.method, call.request.url) for call in responses.calls
      ] == BASE_API_CALLS + [
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-         ),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
      ]


331
- assert list(mock_dev_dataset_worker.list_datasets()) == [
332
- "11111111-1111-1111-1111-111111111111",
333
- "22222222-2222-2222-2222-222222222222",
385
+ def test_list_sets_retrieve_dataset_api_error(
386
+ responses, mock_dev_dataset_worker, default_dataset
387
+ ):
388
+ mock_dev_dataset_worker.args.set = [
389
+ (default_dataset.id, "train"),
390
+ (default_dataset.id, "val"),
334
391
  ]
335
392
 
336
-
337
- def test_list_datasets_api_error(responses, mock_dataset_worker):
338
393
  responses.add(
339
394
  responses.GET,
340
- f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
395
+ f"http://testserver/api/v1/datasets/{default_dataset.id}/",
341
396
  status=500,
342
397
  )
343
398
 
344
- with pytest.raises(
345
- Exception, match="Stopping pagination as data will be incomplete"
346
- ):
347
- mock_dataset_worker.list_datasets()
399
+ with pytest.raises(ErrorResponse):
400
+ next(mock_dev_dataset_worker.list_sets())
348
401
 
349
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
350
- assert [
351
- (call.request.method, call.request.url) for call in responses.calls
352
- ] == BASE_API_CALLS + [
402
+ assert len(responses.calls) == 5
403
+ assert [(call.request.method, call.request.url) for call in responses.calls] == [
353
404
  # The API call is retried 5 times
354
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
355
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
356
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
357
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
358
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
405
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
406
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
407
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
408
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
409
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
359
410
  ]
360
411
 
361
412
 
362
- def test_list_datasets(responses, mock_dataset_worker):
363
- expected_results = [
364
- {
365
- "id": "dataset_1",
366
- "name": "Dataset 1",
367
- "description": "My first great dataset",
368
- "sets": ["train", "val", "test"],
369
- "state": "open",
370
- "corpus_id": "corpus_id",
371
- "creator": "test@teklia.com",
372
- "task_id": "task_id_1",
373
- },
374
- {
375
- "id": "dataset_2",
376
- "name": "Dataset 2",
377
- "description": "My second great dataset",
378
- "sets": ["train", "val"],
379
- "state": "complete",
380
- "corpus_id": "corpus_id",
381
- "creator": "test@teklia.com",
382
- "task_id": "task_id_2",
383
- },
384
- {
385
- "id": "dataset_3",
386
- "name": "Dataset 3 (TRASHME)",
387
- "description": "My third dataset, in error",
388
- "sets": ["nonsense", "random set"],
389
- "state": "error",
390
- "corpus_id": "corpus_id",
391
- "creator": "test@teklia.com",
392
- "task_id": "task_id_3",
393
- },
413
+ def test_list_sets_read_only(responses, mock_dev_dataset_worker, default_dataset):
414
+ mock_dev_dataset_worker.args.set = [
415
+ (default_dataset.id, "train"),
416
+ (default_dataset.id, "val"),
394
417
  ]
418
+
395
419
  responses.add(
396
420
  responses.GET,
397
- f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
421
+ f"http://testserver/api/v1/datasets/{default_dataset.id}/",
398
422
  status=200,
399
- json={
400
- "count": 3,
401
- "next": None,
402
- "results": expected_results,
403
- },
423
+ json=default_dataset,
404
424
  )
405
425
 
406
- assert list(mock_dataset_worker.list_datasets()) == expected_results
407
-
408
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
409
- assert [
410
- (call.request.method, call.request.url) for call in responses.calls
411
- ] == BASE_API_CALLS + [
412
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
426
+ assert list(mock_dev_dataset_worker.list_sets()) == [
427
+ Set(name="train", dataset=default_dataset),
428
+ Set(name="val", dataset=default_dataset),
413
429
  ]
414
430
 
415
-
416
- @pytest.mark.parametrize("generator", [True, False])
417
- def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
418
- mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
419
- mock_dataset_worker.generator = generator
420
-
421
- with pytest.raises(SystemExit):
422
- mock_dataset_worker.run()
423
-
424
- assert [(level, message) for _, level, message in caplog.record_tuples] == [
425
- (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
426
- (logging.WARNING, "No datasets to process, stopping."),
431
+ assert len(responses.calls) == 1
432
+ assert [(call.request.method, call.request.url) for call in responses.calls] == [
433
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
427
434
  ]
428
435
 
429
436
 
430
- @pytest.mark.parametrize(
-     ("generator", "error"),
-     [
-         (True, "When generating a new dataset, its state should be Open."),
-         (False, "When processing an existing dataset, its state should be Complete."),
-     ],
- )
- def test_run_initial_dataset_state_error(
-     mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, error
- ):
-     default_dataset.state = DatasetState.Building.value
-     mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
-     )
-     mock_dataset_worker.generator = generator
-
-     extra_call = []
-     if generator:
-         responses.add(
-             responses.PATCH,
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-             status=200,
-             json={},
-         )
-         extra_call = [
-             ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ]
+ def test_run_no_sets(mocker, caplog, mock_dataset_worker):
+     mocker.patch("arkindex_worker.worker.DatasetWorker.list_sets", return_value=[])

      with pytest.raises(SystemExit):
          mock_dataset_worker.run()

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_call)
-     assert [
-         (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + extra_call
-
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (
-             logging.WARNING,
-             f"Failed running worker on dataset dataset_id: AssertionError('{error}')",
-         ),
-         (
-             logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
-         ),
+         (logging.WARNING, "No sets to process, stopping."),
      ]


- def test_run_update_dataset_state_api_error(
+ def test_run_initial_dataset_state_error(
      mocker, responses, caplog, mock_dataset_worker, default_dataset
  ):
+     default_dataset.state = DatasetState.Building.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
-     )
-     mock_dataset_worker.generator = True
-
-     responses.add(
-         responses.PATCH,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-         status=500,
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )

      with pytest.raises(SystemExit):
          mock_dataset_worker.run()

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 10
+     assert len(responses.calls) == len(BASE_API_CALLS) * 2
      assert [
          (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + [
-         # We retry 5 times the API call to update the Dataset as Building
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         # We retry 5 times the API call to update the Dataset as in Error
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-     ]
+     ] == BASE_API_CALLS * 2

-     retries = [3.0, 4.0, 8.0, 16.0]
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-         *[
-             (
-                 logging.INFO,
-                 f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-             )
-             for retry in retries
-         ],
          (
              logging.WARNING,
-             "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
-         ),
-         *[
-             (
-                 logging.INFO,
-                 f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-             )
-             for retry in retries
-         ],
-         (
-             logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
          ),
+         (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
      ]


@@ -552,10 +482,9 @@ def test_run_download_dataset_artifact_api_error(
      default_dataset,
  ):
      default_dataset.state = DatasetState.Complete.value
-
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -585,27 +514,30 @@ def test_run_download_dataset_artifact_api_error(

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
          *[
              (
                  logging.INFO,
-                 f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
+                 f"Retrying arkindex.client.ArkindexClient.request in {retry} seconds as it raised ErrorResponse: .",
              )
              for retry in [3.0, 4.0, 8.0, 16.0]
          ],
          (
              logging.WARNING,
-             "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
+             "An API error occurred while processing Set (train) from Dataset (dataset_id): 500 Internal Server Error - None",
          ),
          (
              logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Ran on 1 set: 0 completed, 1 failed",
          ),
      ]


- def test_run_no_downloaded_artifact_error(
+ def test_run_no_downloaded_dataset_artifact_error(
      mocker,
      tmp_path,
      responses,
@@ -614,10 +546,9 @@ def test_run_no_downloaded_artifact_error(
      default_dataset,
  ):
      default_dataset.state = DatasetState.Complete.value
-
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -643,22 +574,22 @@ def test_run_no_downloaded_artifact_error(

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
          (
              logging.WARNING,
-             "Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
+             "Failed running worker on Set (train) from Dataset (dataset_id): MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
          ),
          (
              logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Ran on 1 set: 0 completed, 1 failed",
          ),
      ]


- @pytest.mark.parametrize(
-     ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
- )
  def test_run(
      mocker,
      tmp_path,
@@ -667,90 +598,68 @@ def test_run(
      mock_dataset_worker,
      default_dataset,
      default_artifact,
-     generator,
-     state,
  ):
-     mock_dataset_worker.generator = generator
-     default_dataset.state = state.value
-
+     default_dataset.state = DatasetState.Complete.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
          return_value=tmp_path,
      )
-     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
-
-     extra_calls = []
-     extra_logs = []
-     if generator:
-         responses.add(
-             responses.PATCH,
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-             status=200,
-             json={},
-         )
-         extra_calls += [
-             ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ] * 2
-         extra_logs += [
-             (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-             (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-         ]
-     else:
-         archive_path = (
-             FIXTURES_DIR
-             / "extract_parent_archives"
-             / "first_parent"
-             / "arkindex_data.tar.zst"
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             status=200,
-             json=[default_artifact],
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             status=200,
-             body=archive_path.read_bytes(),
-             content_type="application/zstd",
-         )
-         extra_calls += [
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             ),
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             ),
-         ]
-         extra_logs += [
-             (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-         ]
+     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")
+
+     archive_path = (
+         FIXTURES_DIR
+         / "extract_parent_archives"
+         / "first_parent"
+         / "arkindex_data.tar.zst"
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         status=200,
+         json=[default_artifact],
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         status=200,
+         body=archive_path.read_bytes(),
+         content_type="application/zstd",
+     )

      mock_dataset_worker.run()

      assert mock_process.call_count == 1

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
+     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 2
      assert [
          (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + extra_calls
+     ] == BASE_API_CALLS * 2 + [
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         ),
+     ]

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-     ] + extra_logs
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+         (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+         (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
+     ]


- @pytest.mark.parametrize(
-     ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
- )
  def test_run_read_only(
      mocker,
      tmp_path,
@@ -759,88 +668,61 @@ def test_run_read_only(
      mock_dev_dataset_worker,
      default_dataset,
      default_artifact,
-     generator,
-     state,
  ):
-     mock_dev_dataset_worker.generator = generator
-     default_dataset.state = state.value
-
+     default_dataset.state = DatasetState.Complete.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset.id],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
          return_value=tmp_path,
      )
-     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
+     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")

+     archive_path = (
+         FIXTURES_DIR
+         / "extract_parent_archives"
+         / "first_parent"
+         / "arkindex_data.tar.zst"
+     )
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
          status=200,
-         json=default_dataset,
+         json=[default_artifact],
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         status=200,
+         body=archive_path.read_bytes(),
+         content_type="application/zstd",
      )
-
-     extra_calls = []
-     extra_logs = []
-     if generator:
-         extra_logs += [
-             (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "Cannot update dataset as this worker is in read-only mode",
-             ),
-             (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "Cannot update dataset as this worker is in read-only mode",
-             ),
-         ]
-     else:
-         archive_path = (
-             FIXTURES_DIR
-             / "extract_parent_archives"
-             / "first_parent"
-             / "arkindex_data.tar.zst"
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             status=200,
-             json=[default_artifact],
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             status=200,
-             body=archive_path.read_bytes(),
-             content_type="application/zstd",
-         )
-         extra_calls += [
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             ),
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             ),
-         ]
-         extra_logs += [
-             (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-         ]

      mock_dev_dataset_worker.run()

      assert mock_process.call_count == 1

-     assert len(responses.calls) == 1 + len(extra_calls)
+     assert len(responses.calls) == 2
      assert [(call.request.method, call.request.url) for call in responses.calls] == [
-         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
-     ] + extra_calls
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         ),
+     ]

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.WARNING, "Running without any extra configuration"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-     ] + extra_logs
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+         (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+         (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
+     ]
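Taken together, the rewritten test_run and test_run_read_only describe the new entrypoint contract: run() retrieves each set returned by list_sets, downloads the parent dataset's artifact, and hands each set to process_set. A hypothetical subclass under that contract; the tests only patch process_set, so its exact signature is an assumption:

    import logging

    from arkindex_worker.models import Set
    from arkindex_worker.worker import DatasetWorker

    logger = logging.getLogger(__name__)


    class DemoDatasetWorker(DatasetWorker):
        def process_set(self, dataset_set: Set) -> None:
            # By the time this runs, run() has already downloaded the parent
            # dataset's archive, so this method only has to consume it
            logger.info(
                "Processing set %s of dataset %s",
                dataset_set.name,
                dataset_set.dataset.name,
            )


    if __name__ == "__main__":
        DemoDatasetWorker().run()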