arkindex-base-worker 0.3.7rc6__py3-none-any.whl → 0.3.7rc8__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
@@ -1,13 +1,59 @@
  import logging
+ import uuid
+ from argparse import ArgumentTypeError

  import pytest
  from apistar.exceptions import ErrorResponse

- from arkindex_worker.worker import MissingDatasetArchive
+ from arkindex_worker.models import Dataset, Set
+ from arkindex_worker.worker import MissingDatasetArchive, check_dataset_set
  from arkindex_worker.worker.dataset import DatasetState
  from tests.conftest import FIXTURES_DIR, PROCESS_ID
  from tests.test_elements_worker import BASE_API_CALLS

+ RANDOM_UUID = uuid.uuid4()
+
+
+ @pytest.fixture()
+ def tmp_archive(tmp_path):
+     archive = tmp_path / "test_archive.tar.zst"
+     archive.touch()
+
+     yield archive
+
+     archive.unlink(missing_ok=True)
+
+
+ @pytest.mark.parametrize(
+     ("value", "error"),
+     [("train", ""), (f"{RANDOM_UUID}:train:val", ""), ("not_uuid:train", "")],
+ )
+ def test_check_dataset_set_errors(value, error):
+     with pytest.raises(ArgumentTypeError, match=error):
+         check_dataset_set(value)
+
+
+ def test_check_dataset_set():
+     assert check_dataset_set(f"{RANDOM_UUID}:train") == (RANDOM_UUID, "train")
+
+
+ def test_cleanup_downloaded_artifact_no_download(mock_dataset_worker):
+     assert not mock_dataset_worker.downloaded_dataset_artifact
+     # Do nothing
+     mock_dataset_worker.cleanup_downloaded_artifact()
+
+
+ def test_cleanup_downloaded_artifact(mock_dataset_worker, tmp_archive):
+     mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+
+     assert mock_dataset_worker.downloaded_dataset_artifact.exists()
+     # Unlink the downloaded archive
+     mock_dataset_worker.cleanup_downloaded_artifact()
+     assert not mock_dataset_worker.downloaded_dataset_artifact.exists()
+
+     # Unlinking again does not raise an error even if the archive no longer exists
+     mock_dataset_worker.cleanup_downloaded_artifact()
+

  def test_download_dataset_artifact_list_api_error(
      responses, mock_dataset_worker, default_dataset
@@ -127,8 +173,15 @@ def test_download_dataset_artifact_no_archive(
      ]


+ @pytest.mark.parametrize("downloaded_cache", [False, True])
  def test_download_dataset_artifact(
-     mocker, tmp_path, responses, mock_dataset_worker, default_dataset
+     mocker,
+     tmp_path,
+     responses,
+     mock_dataset_worker,
+     default_dataset,
+     downloaded_cache,
+     tmp_archive,
  ):
      task_id = default_dataset.task_id
      archive_path = (
@@ -176,10 +229,25 @@ def test_download_dataset_artifact(
          content_type="application/zstd",
      )

-     archive = mock_dataset_worker.download_dataset_artifact(default_dataset)
-     assert archive == tmp_path / "dataset_id.tar.zst"
-     assert archive.read_bytes() == archive_path.read_bytes()
-     archive.unlink()
+     if downloaded_cache:
+         mock_dataset_worker.downloaded_dataset_artifact = tmp_archive
+     previous_artifact = mock_dataset_worker.downloaded_dataset_artifact
+
+     mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+     # We removed the artifact that was downloaded previously
+     if previous_artifact:
+         assert not previous_artifact.exists()
+
+     assert (
+         mock_dataset_worker.downloaded_dataset_artifact
+         == tmp_path / "dataset_id.tar.zst"
+     )
+     assert (
+         mock_dataset_worker.downloaded_dataset_artifact.read_bytes()
+         == archive_path.read_bytes()
+     )
+     mock_dataset_worker.downloaded_dataset_artifact.unlink()

      assert len(responses.calls) == len(BASE_API_CALLS) + 2
      assert [
@@ -190,189 +258,107 @@ def test_download_dataset_artifact(
      ]


- def test_list_dataset_elements_per_split_api_error(
-     responses, mock_dataset_worker, default_dataset
+ def test_download_dataset_artifact_already_exists(
+     mocker, tmp_path, responses, mock_dataset_worker, default_dataset
  ):
-     responses.add(
-         responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         status=500,
+     mocker.patch(
+         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
+         return_value=tmp_path,
      )
+     already_downloaded = tmp_path / "dataset_id.tar.zst"
+     already_downloaded.write_bytes(b"Some content")
+     mock_dataset_worker.downloaded_dataset_artifact = already_downloaded

-     with pytest.raises(
-         Exception, match="Stopping pagination as data will be incomplete"
-     ):
-         mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-
-     assert len(responses.calls) == len(BASE_API_CALLS) + 5
-     assert [
-         (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS + [
-         # The API call is retried 5 times
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-     ]
-
+     mock_dataset_worker.download_dataset_artifact(default_dataset)

- def test_list_dataset_elements_per_split(
-     responses, mock_dataset_worker, default_dataset
- ):
-     expected_results = []
-     for selected_set in default_dataset.selected_sets:
-         index = selected_set[-1]
-         expected_results.append(
-             {
-                 "set": selected_set,
-                 "element": {
-                     "id": str(index) * 4,
-                     "type": "page",
-                     "name": f"Test {index}",
-                     "corpus": {},
-                     "thumbnail_url": None,
-                     "zone": {},
-                     "best_classes": None,
-                     "has_children": None,
-                     "worker_version_id": None,
-                     "worker_run_id": None,
-                 },
-             }
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
-             status=200,
-             json={
-                 "count": 1,
-                 "next": None,
-                 "results": [expected_results[-1]],
-             },
-         )
-
-     assert list(
-         mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-     ) == [
-         ("set_1", [expected_results[0]["element"]]),
-         ("set_2", [expected_results[1]["element"]]),
-         ("set_3", [expected_results[2]["element"]]),
-     ]
+     assert mock_dataset_worker.downloaded_dataset_artifact == already_downloaded
+     already_downloaded.unlink()

-     assert len(responses.calls) == len(BASE_API_CALLS) + 3
+     assert len(responses.calls) == len(BASE_API_CALLS)
      assert [
          (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS + [
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_2&with_count=true",
-         ),
-         (
-             "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_3&with_count=true",
-         ),
-     ]
+     ] == BASE_API_CALLS


- def test_list_datasets_read_only(mock_dev_dataset_worker):
-     assert list(mock_dev_dataset_worker.list_datasets()) == [
-         "11111111-1111-1111-1111-111111111111",
-         "22222222-2222-2222-2222-222222222222",
-     ]
-
-
- def test_list_datasets_api_error(responses, mock_dataset_worker):
+ def test_list_sets_api_error(responses, mock_dataset_worker):
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+         f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
          status=500,
      )

      with pytest.raises(
          Exception, match="Stopping pagination as data will be incomplete"
      ):
-         next(mock_dataset_worker.list_datasets())
+         next(mock_dataset_worker.list_sets())

      assert len(responses.calls) == len(BASE_API_CALLS) + 5
      assert [
          (call.request.method, call.request.url) for call in responses.calls
      ] == BASE_API_CALLS + [
          # The API call is retried 5 times
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
      ]


- def test_list_datasets(responses, mock_dataset_worker):
+ def test_list_sets(responses, mock_dataset_worker):
      expected_results = [
          {
-             "id": "process_dataset_1",
+             "id": "set_1",
              "dataset": {
                  "id": "dataset_1",
                  "name": "Dataset 1",
                  "description": "My first great dataset",
-                 "sets": ["train", "val", "test"],
+                 "sets": [
+                     {"id": "set_1", "name": "train"},
+                     {"id": "set_2", "name": "val"},
+                 ],
                  "state": "open",
                  "corpus_id": "corpus_id",
                  "creator": "test@teklia.com",
                  "task_id": "task_id_1",
              },
-             "sets": ["test"],
+             "set_name": "train",
          },
          {
-             "id": "process_dataset_2",
+             "id": "set_2",
              "dataset": {
-                 "id": "dataset_2",
-                 "name": "Dataset 2",
-                 "description": "My second great dataset",
-                 "sets": ["train", "val"],
-                 "state": "complete",
+                 "id": "dataset_1",
+                 "name": "Dataset 1",
+                 "description": "My first great dataset",
+                 "sets": [
+                     {"id": "set_1", "name": "train"},
+                     {"id": "set_2", "name": "val"},
+                 ],
+                 "state": "open",
                  "corpus_id": "corpus_id",
                  "creator": "test@teklia.com",
-                 "task_id": "task_id_2",
+                 "task_id": "task_id_1",
              },
-             "sets": ["train", "val"],
+             "set_name": "val",
          },
          {
-             "id": "process_dataset_3",
+             "id": "set_3",
              "dataset": {
-                 "id": "dataset_3",
-                 "name": "Dataset 3 (TRASHME)",
-                 "description": "My third dataset, in error",
-                 "sets": ["nonsense", "random set"],
-                 "state": "error",
+                 "id": "dataset_2",
+                 "name": "Dataset 2",
+                 "description": "My second great dataset",
+                 "sets": [{"id": "set_3", "name": "my_set"}],
+                 "state": "complete",
                  "corpus_id": "corpus_id",
                  "creator": "test@teklia.com",
-                 "task_id": "task_id_3",
+                 "task_id": "task_id_2",
              },
-             "sets": ["random set"],
+             "set_name": "my_set",
          },
      ]
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+         f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
          status=200,
          json={
              "count": 3,
@@ -381,154 +367,109 @@ def test_list_datasets(responses, mock_dataset_worker):
          },
      )

-     for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
-         assert dataset == {
-             **expected_results[idx]["dataset"],
-             "selected_sets": expected_results[idx]["sets"],
-         }
+     for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
+         assert isinstance(dataset_set, Set)
+         assert dataset_set.name == expected_results[idx]["set_name"]
+
+         assert isinstance(dataset_set.dataset, Dataset)
+         assert dataset_set.dataset == expected_results[idx]["dataset"]

      assert len(responses.calls) == len(BASE_API_CALLS) + 1
      assert [
          (call.request.method, call.request.url) for call in responses.calls
      ] == BASE_API_CALLS + [
-         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+         ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
      ]


- @pytest.mark.parametrize("generator", [True, False])
- def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
-     mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
-     mock_dataset_worker.generator = generator
+ def test_list_sets_retrieve_dataset_api_error(
+     responses, mock_dev_dataset_worker, default_dataset
+ ):
+     mock_dev_dataset_worker.args.set = [
+         (default_dataset.id, "train"),
+         (default_dataset.id, "val"),
+     ]

-     with pytest.raises(SystemExit):
-         mock_dataset_worker.run()
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+         status=500,
+     )

-     assert [(level, message) for _, level, message in caplog.record_tuples] == [
-         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.WARNING, "No datasets to process, stopping."),
+     with pytest.raises(ErrorResponse):
+         next(mock_dev_dataset_worker.list_sets())
+
+     assert len(responses.calls) == 5
+     assert [(call.request.method, call.request.url) for call in responses.calls] == [
+         # The API call is retried 5 times
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
      ]


- @pytest.mark.parametrize(
-     ("generator", "error"),
-     [
-         (True, "When generating a new dataset, its state should be Open or Error."),
-         (False, "When processing an existing dataset, its state should be Complete."),
-     ],
- )
- def test_run_initial_dataset_state_error(
-     mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, error
- ):
-     default_dataset.state = DatasetState.Building.value
-     mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
-     )
-     mock_dataset_worker.generator = generator
-
-     extra_call = []
-     if generator:
-         responses.add(
-             responses.PATCH,
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-             status=200,
-             json={},
-         )
-         extra_call = [
-             ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ]
+ def test_list_sets_read_only(responses, mock_dev_dataset_worker, default_dataset):
+     mock_dev_dataset_worker.args.set = [
+         (default_dataset.id, "train"),
+         (default_dataset.id, "val"),
+     ]
+
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+         status=200,
+         json=default_dataset,
+     )
+
+     assert list(mock_dev_dataset_worker.list_sets()) == [
+         Set(name="train", dataset=default_dataset),
+         Set(name="val", dataset=default_dataset),
+     ]
+
+     assert len(responses.calls) == 1
+     assert [(call.request.method, call.request.url) for call in responses.calls] == [
+         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+     ]
+
+
+ def test_run_no_sets(mocker, caplog, mock_dataset_worker):
+     mocker.patch("arkindex_worker.worker.DatasetWorker.list_sets", return_value=[])

      with pytest.raises(SystemExit):
          mock_dataset_worker.run()

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_call)
-     assert [
-         (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + extra_call
-
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (
-             logging.WARNING,
-             f"Failed running worker on dataset dataset_id: AssertionError('{error}')",
-         ),
-     ] + (
-         [
-             (
-                 logging.WARNING,
-                 "This API helper `update_dataset_state` did not update the cache database",
-             )
-         ]
-         if generator
-         else []
-     ) + [
-         (logging.ERROR, "Ran on 1 dataset: 0 completed, 1 failed"),
+         (logging.WARNING, "No sets to process, stopping."),
      ]


- def test_run_update_dataset_state_api_error(
+ def test_run_initial_dataset_state_error(
      mocker, responses, caplog, mock_dataset_worker, default_dataset
  ):
+     default_dataset.state = DatasetState.Building.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
-     )
-     mock_dataset_worker.generator = True
-
-     responses.add(
-         responses.PATCH,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-         status=500,
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )

      with pytest.raises(SystemExit):
          mock_dataset_worker.run()

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 10
+     assert len(responses.calls) == len(BASE_API_CALLS) * 2
      assert [
          (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + [
-         # We retry 5 times the API call to update the Dataset as Building
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         # We retry 5 times the API call to update the Dataset as in Error
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-     ]
+     ] == BASE_API_CALLS * 2

-     retries = [3.0, 4.0, 8.0, 16.0]
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-         *[
-             (
-                 logging.INFO,
-                 f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-             )
-             for retry in retries
-         ],
          (
              logging.WARNING,
-             "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
-         ),
-         *[
-             (
-                 logging.INFO,
-                 f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-             )
-             for retry in retries
-         ],
-         (
-             logging.ERROR,
-             "Ran on 1 dataset: 0 completed, 1 failed",
+             "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
          ),
+         (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
      ]


@@ -541,10 +482,9 @@ def test_run_download_dataset_artifact_api_error(
      default_dataset,
  ):
      default_dataset.state = DatasetState.Complete.value
-
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -574,8 +514,11 @@ def test_run_download_dataset_artifact_api_error(

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
          *[
              (
                  logging.INFO,
@@ -585,16 +528,16 @@ def test_run_download_dataset_artifact_api_error(
          ],
          (
              logging.WARNING,
-             "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
+             "An API error occurred while processing Set (train) from Dataset (dataset_id): 500 Internal Server Error - None",
          ),
          (
              logging.ERROR,
-             "Ran on 1 dataset: 0 completed, 1 failed",
+             "Ran on 1 set: 0 completed, 1 failed",
          ),
      ]


- def test_run_no_downloaded_artifact_error(
+ def test_run_no_downloaded_dataset_artifact_error(
      mocker,
      tmp_path,
      responses,
@@ -603,10 +546,9 @@ def test_run_no_downloaded_artifact_error(
      default_dataset,
  ):
      default_dataset.state = DatasetState.Complete.value
-
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -632,22 +574,22 @@ def test_run_no_downloaded_artifact_error(

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
          (
              logging.WARNING,
-             "Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
+             "Failed running worker on Set (train) from Dataset (dataset_id): MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
          ),
          (
              logging.ERROR,
-             "Ran on 1 dataset: 0 completed, 1 failed",
+             "Ran on 1 set: 0 completed, 1 failed",
          ),
      ]


- @pytest.mark.parametrize(
-     ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
- )
  def test_run(
      mocker,
      tmp_path,
@@ -656,100 +598,68 @@ def test_run(
      mock_dataset_worker,
      default_dataset,
      default_artifact,
-     generator,
-     state,
  ):
-     mock_dataset_worker.generator = generator
-     default_dataset.state = state.value
-
+     default_dataset.state = DatasetState.Complete.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
          return_value=tmp_path,
      )
-     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
-
-     extra_calls = []
-     extra_logs = []
-     if generator:
-         responses.add(
-             responses.PATCH,
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-             status=200,
-             json={},
-         )
-         extra_calls += [
-             ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-         ] * 2
-         extra_logs += [
-             (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "This API helper `update_dataset_state` did not update the cache database",
-             ),
-             (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "This API helper `update_dataset_state` did not update the cache database",
-             ),
-         ]
-     else:
-         archive_path = (
-             FIXTURES_DIR
-             / "extract_parent_archives"
-             / "first_parent"
-             / "arkindex_data.tar.zst"
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             status=200,
-             json=[default_artifact],
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             status=200,
-             body=archive_path.read_bytes(),
-             content_type="application/zstd",
-         )
-         extra_calls += [
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             ),
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             ),
-         ]
-         extra_logs += [
-             (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-         ]
+     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")
+
+     archive_path = (
+         FIXTURES_DIR
+         / "extract_parent_archives"
+         / "first_parent"
+         / "arkindex_data.tar.zst"
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         status=200,
+         json=[default_artifact],
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         status=200,
+         body=archive_path.read_bytes(),
+         content_type="application/zstd",
+     )

      mock_dataset_worker.run()

      assert mock_process.call_count == 1

-     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
+     assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 2
      assert [
          (call.request.method, call.request.url) for call in responses.calls
-     ] == BASE_API_CALLS * 2 + extra_calls
+     ] == BASE_API_CALLS * 2 + [
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         ),
+     ]

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         *extra_logs,
-         (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+         (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+         (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
      ]


- @pytest.mark.parametrize(
-     ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
- )
  def test_run_read_only(
      mocker,
      tmp_path,
@@ -758,90 +668,61 @@ def test_run_read_only(
      mock_dev_dataset_worker,
      default_dataset,
      default_artifact,
-     generator,
-     state,
  ):
-     mock_dev_dataset_worker.generator = generator
-     default_dataset.state = state.value
-
+     default_dataset.state = DatasetState.Complete.value
      mocker.patch(
-         "arkindex_worker.worker.DatasetWorker.list_datasets",
-         return_value=[default_dataset.id],
+         "arkindex_worker.worker.DatasetWorker.list_sets",
+         return_value=[Set(name="train", dataset=default_dataset)],
      )
      mocker.patch(
          "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
          return_value=tmp_path,
      )
-     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
+     mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")

+     archive_path = (
+         FIXTURES_DIR
+         / "extract_parent_archives"
+         / "first_parent"
+         / "arkindex_data.tar.zst"
+     )
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
          status=200,
-         json=default_dataset,
+         json=[default_artifact],
+     )
+     responses.add(
+         responses.GET,
+         f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         status=200,
+         body=archive_path.read_bytes(),
+         content_type="application/zstd",
      )
-
-     extra_calls = []
-     extra_logs = []
-     if generator:
-         extra_logs += [
-             (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "Cannot update dataset as this worker is in read-only mode",
-             ),
-             (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-             (
-                 logging.WARNING,
-                 "Cannot update dataset as this worker is in read-only mode",
-             ),
-         ]
-     else:
-         archive_path = (
-             FIXTURES_DIR
-             / "extract_parent_archives"
-             / "first_parent"
-             / "arkindex_data.tar.zst"
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             status=200,
-             json=[default_artifact],
-         )
-         responses.add(
-             responses.GET,
-             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             status=200,
-             body=archive_path.read_bytes(),
-             content_type="application/zstd",
-         )
-         extra_calls += [
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-             ),
-             (
-                 "GET",
-                 f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-             ),
-         ]
-         extra_logs += [
-             (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-         ]

      mock_dev_dataset_worker.run()

      assert mock_process.call_count == 1

-     assert len(responses.calls) == 1 + len(extra_calls)
+     assert len(responses.calls) == 2
      assert [(call.request.method, call.request.url) for call in responses.calls] == [
-         ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
-     ] + extra_calls
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+         ),
+     ]

      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.WARNING, "Running without any extra configuration"),
-         (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-         *extra_logs,
-         (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+         (
+             logging.INFO,
+             "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+         ),
+         (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+         (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+         (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
      ]
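
The new tests test_check_dataset_set and test_check_dataset_set_errors pin down the contract of the check_dataset_set helper imported from arkindex_worker.worker: it accepts a single "dataset_id:set_name" pair, returns a (uuid.UUID, str) tuple, and raises argparse.ArgumentTypeError for anything else. The following is a minimal sketch consistent with those assertions, not the package's actual implementation; the error messages are illustrative only (the tests match against an empty pattern, so they do not constrain the wording).

import uuid
from argparse import ArgumentTypeError


def check_dataset_set(value):
    # Hypothetical sketch inferred from the tests above.
    # Expect exactly one colon-separated "dataset_id:set_name" pair.
    parts = value.split(":")
    if len(parts) != 2:
        raise ArgumentTypeError(f"{value!r} is not in the dataset_id:set_name format")
    dataset_id, set_name = parts
    try:
        # The tests compare the first element against a uuid.UUID instance
        return uuid.UUID(dataset_id), set_name
    except ValueError as e:
        raise ArgumentTypeError(f"{dataset_id!r} is not a valid UUID") from e

Raising ArgumentTypeError suggests the helper is meant to serve as an argparse type callback, presumably wired up as something like parser.add_argument("--set", type=check_dataset_set, action="append"), which would also explain the (dataset_id, set_name) tuples assigned to mock_dev_dataset_worker.args.set in the read-only tests.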