arkindex-base-worker 0.3.7rc6__py3-none-any.whl → 0.3.7rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
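The headline change in this release is that `DatasetWorker` now iterates over dataset *sets* rather than whole datasets: `list_datasets`/`process_dataset` give way to `list_sets`/`process_set`, a `Set` model is used alongside `Dataset`, and a new `check_dataset_set` helper validates `--set` command-line values. As orientation, here is a minimal sketch of the parsing behavior that the new `test_check_dataset_set*` tests below pin down (a hypothetical reimplementation inferred from the assertions; the real helper is exported by `arkindex_worker.worker`):

```python
import uuid
from argparse import ArgumentTypeError


def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
    """Parse a '<dataset_uuid>:<set_name>' CLI value into (UUID, set name).

    Sketch inferred from the test_check_dataset_set* tests below;
    not the actual arkindex_worker.worker implementation.
    """
    try:
        dataset_id, set_name = value.split(":")
        return uuid.UUID(dataset_id), set_name
    except ValueError as e:
        # Raised for a missing separator, too many fields, or a bad UUID
        raise ArgumentTypeError(
            f"'{value}' does not match the expected '<dataset_uuid>:<set_name>' format"
        ) from e
```

Under that reading, a worker would register the helper with something like `parser.add_argument("--set", type=check_dataset_set, action="append")`, which would produce the lists of `(dataset_id, set_name)` tuples that the read-only tests assign to `args.set` by hand.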
@@ -1,13 +1,59 @@
 import logging
+import uuid
+from argparse import ArgumentTypeError

 import pytest
 from apistar.exceptions import ErrorResponse

-from arkindex_worker.worker import MissingDatasetArchive
+from arkindex_worker.models import Dataset, Set
+from arkindex_worker.worker import MissingDatasetArchive, check_dataset_set
 from arkindex_worker.worker.dataset import DatasetState
 from tests.conftest import FIXTURES_DIR, PROCESS_ID
 from tests.test_elements_worker import BASE_API_CALLS

+RANDOM_UUID = uuid.uuid4()
+
+
+@pytest.fixture()
+def tmp_archive(tmp_path):
+    archive = tmp_path / "test_archive.tar.zst"
+    archive.touch()
+
+    yield archive
+
+    archive.unlink(missing_ok=True)
+
+
+@pytest.mark.parametrize(
+    ("value", "error"),
+    [("train", ""), (f"{RANDOM_UUID}:train:val", ""), ("not_uuid:train", "")],
+)
+def test_check_dataset_set_errors(value, error):
+    with pytest.raises(ArgumentTypeError, match=error):
+        check_dataset_set(value)
+
+
+def test_check_dataset_set():
+    assert check_dataset_set(f"{RANDOM_UUID}:train") == (RANDOM_UUID, "train")
+
+
+def test_cleanup_downloaded_artifact_no_download(mock_dataset_worker):
+    assert not mock_dataset_worker.downloaded_artifact
+    # Do nothing
+    mock_dataset_worker.cleanup_downloaded_artifact()
+
+
+def test_cleanup_downloaded_artifact(mock_dataset_worker, tmp_archive):
+    mock_dataset_worker.downloaded_artifact = tmp_archive
+
+    assert mock_dataset_worker.downloaded_artifact.exists()
+    # Unlink the downloaded archive
+    mock_dataset_worker.cleanup_downloaded_artifact()
+    assert not mock_dataset_worker.downloaded_artifact.exists()
+
+    # Unlinking again does not raise an error even if the archive no longer exists
+    mock_dataset_worker.cleanup_downloaded_artifact()
+

 def test_download_dataset_artifact_list_api_error(
     responses, mock_dataset_worker, default_dataset
@@ -127,8 +173,15 @@ def test_download_dataset_artifact_no_archive(
     ]


+@pytest.mark.parametrize("downloaded_cache", [False, True])
 def test_download_dataset_artifact(
-    mocker, tmp_path, responses, mock_dataset_worker, default_dataset
+    mocker,
+    tmp_path,
+    responses,
+    mock_dataset_worker,
+    default_dataset,
+    downloaded_cache,
+    tmp_archive,
 ):
     task_id = default_dataset.task_id
     archive_path = (
@@ -176,10 +229,22 @@ def test_download_dataset_artifact(
         content_type="application/zstd",
     )

-    archive = mock_dataset_worker.download_dataset_artifact(default_dataset)
-    assert archive == tmp_path / "dataset_id.tar.zst"
-    assert archive.read_bytes() == archive_path.read_bytes()
-    archive.unlink()
+    if downloaded_cache:
+        mock_dataset_worker.downloaded_artifact = tmp_archive
+    previous_artifact = mock_dataset_worker.downloaded_artifact
+
+    mock_dataset_worker.download_dataset_artifact(default_dataset)
+
+    # We removed the artifact that was downloaded previously
+    if previous_artifact:
+        assert not previous_artifact.exists()
+
+    assert mock_dataset_worker.downloaded_artifact == tmp_path / "dataset_id.tar.zst"
+    assert (
+        mock_dataset_worker.downloaded_artifact.read_bytes()
+        == archive_path.read_bytes()
+    )
+    mock_dataset_worker.downloaded_artifact.unlink()

     assert len(responses.calls) == len(BASE_API_CALLS) + 2
     assert [
@@ -190,189 +255,107 @@ def test_download_dataset_artifact(
     ]


-def test_list_dataset_elements_per_split_api_error(
-    responses, mock_dataset_worker, default_dataset
+def test_download_dataset_artifact_already_exists(
+    mocker, tmp_path, responses, mock_dataset_worker, default_dataset
 ):
-    responses.add(
-        responses.GET,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        status=500,
+    mocker.patch(
+        "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
+        return_value=tmp_path,
     )
+    already_downloaded = tmp_path / "dataset_id.tar.zst"
+    already_downloaded.write_bytes(b"Some content")
+    mock_dataset_worker.downloaded_artifact = already_downloaded

-    with pytest.raises(
-        Exception, match="Stopping pagination as data will be incomplete"
-    ):
-        mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
+    mock_dataset_worker.download_dataset_artifact(default_dataset)

-    assert len(responses.calls) == len(BASE_API_CALLS) + 5
-    assert [
-        (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        # The API call is retried 5 times
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-    ]
+    assert mock_dataset_worker.downloaded_artifact == already_downloaded
+    already_downloaded.unlink()

-
-def test_list_dataset_elements_per_split(
-    responses, mock_dataset_worker, default_dataset
-):
-    expected_results = []
-    for selected_set in default_dataset.selected_sets:
-        index = selected_set[-1]
-        expected_results.append(
-            {
-                "set": selected_set,
-                "element": {
-                    "id": str(index) * 4,
-                    "type": "page",
-                    "name": f"Test {index}",
-                    "corpus": {},
-                    "thumbnail_url": None,
-                    "zone": {},
-                    "best_classes": None,
-                    "has_children": None,
-                    "worker_version_id": None,
-                    "worker_run_id": None,
-                },
-            }
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
-            status=200,
-            json={
-                "count": 1,
-                "next": None,
-                "results": [expected_results[-1]],
-            },
-        )
-
-    assert list(
-        mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
-    ) == [
-        ("set_1", [expected_results[0]["element"]]),
-        ("set_2", [expected_results[1]["element"]]),
-        ("set_3", [expected_results[2]["element"]]),
-    ]
-
-    assert len(responses.calls) == len(BASE_API_CALLS) + 3
+    assert len(responses.calls) == len(BASE_API_CALLS)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_2&with_count=true",
-        ),
-        (
-            "GET",
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_3&with_count=true",
-        ),
-    ]
+    ] == BASE_API_CALLS


-def test_list_datasets_read_only(mock_dev_dataset_worker):
-    assert list(mock_dev_dataset_worker.list_datasets()) == [
-        "11111111-1111-1111-1111-111111111111",
-        "22222222-2222-2222-2222-222222222222",
-    ]
-
-
-def test_list_datasets_api_error(responses, mock_dataset_worker):
+def test_list_sets_api_error(responses, mock_dataset_worker):
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
         status=500,
     )

     with pytest.raises(
         Exception, match="Stopping pagination as data will be incomplete"
     ):
-        next(mock_dataset_worker.list_datasets())
+        next(mock_dataset_worker.list_sets())

     assert len(responses.calls) == len(BASE_API_CALLS) + 5
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]


-def test_list_datasets(responses, mock_dataset_worker):
+def test_list_sets(responses, mock_dataset_worker):
     expected_results = [
         {
-            "id": "process_dataset_1",
+            "id": "set_1",
             "dataset": {
                 "id": "dataset_1",
                 "name": "Dataset 1",
                 "description": "My first great dataset",
-                "sets": ["train", "val", "test"],
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
                 "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
                 "task_id": "task_id_1",
             },
-            "sets": ["test"],
+            "set_name": "train",
         },
         {
-            "id": "process_dataset_2",
+            "id": "set_2",
             "dataset": {
-                "id": "dataset_2",
-                "name": "Dataset 2",
-                "description": "My second great dataset",
-                "sets": ["train", "val"],
-                "state": "complete",
+                "id": "dataset_1",
+                "name": "Dataset 1",
+                "description": "My first great dataset",
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
+                "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_2",
+                "task_id": "task_id_1",
             },
-            "sets": ["train", "val"],
+            "set_name": "val",
         },
         {
-            "id": "process_dataset_3",
+            "id": "set_3",
             "dataset": {
-                "id": "dataset_3",
-                "name": "Dataset 3 (TRASHME)",
-                "description": "My third dataset, in error",
-                "sets": ["nonsense", "random set"],
-                "state": "error",
+                "id": "dataset_2",
+                "name": "Dataset 2",
+                "description": "My second great dataset",
+                "sets": [{"id": "set_3", "name": "my_set"}],
+                "state": "complete",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_3",
+                "task_id": "task_id_2",
             },
-            "sets": ["random set"],
+            "set_name": "my_set",
         },
     ]
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
         status=200,
         json={
             "count": 3,
@@ -381,154 +364,109 @@ def test_list_datasets(responses, mock_dataset_worker):
         },
     )

-    for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
-        assert dataset == {
-            **expected_results[idx]["dataset"],
-            "selected_sets": expected_results[idx]["sets"],
-        }
+    for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
+        assert isinstance(dataset_set, Set)
+        assert dataset_set.name == expected_results[idx]["set_name"]
+
+        assert isinstance(dataset_set.dataset, Dataset)
+        assert dataset_set.dataset == expected_results[idx]["dataset"]

     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]


-@pytest.mark.parametrize("generator", [True, False])
-def test_run_no_datasets(mocker, caplog, mock_dataset_worker, generator):
-    mocker.patch("arkindex_worker.worker.DatasetWorker.list_datasets", return_value=[])
-    mock_dataset_worker.generator = generator
+def test_list_sets_retrieve_dataset_api_error(
+    responses, mock_dev_dataset_worker, default_dataset
+):
+    mock_dev_dataset_worker.args.set = [
+        (default_dataset.id, "train"),
+        (default_dataset.id, "val"),
+    ]

-    with pytest.raises(SystemExit):
-        mock_dataset_worker.run()
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        status=500,
+    )

-    assert [(level, message) for _, level, message in caplog.record_tuples] == [
-        (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.WARNING, "No datasets to process, stopping."),
+    with pytest.raises(ErrorResponse):
+        next(mock_dev_dataset_worker.list_sets())
+
+    assert len(responses.calls) == 5
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        # The API call is retried 5 times
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
     ]


-@pytest.mark.parametrize(
-    ("generator", "error"),
-    [
-        (True, "When generating a new dataset, its state should be Open or Error."),
-        (False, "When processing an existing dataset, its state should be Complete."),
-    ],
-)
-def test_run_initial_dataset_state_error(
-    mocker, responses, caplog, mock_dataset_worker, default_dataset, generator, error
-):
-    default_dataset.state = DatasetState.Building.value
-    mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+def test_list_sets_read_only(responses, mock_dev_dataset_worker, default_dataset):
+    mock_dev_dataset_worker.args.set = [
+        (default_dataset.id, "train"),
+        (default_dataset.id, "val"),
+    ]
+
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        status=200,
+        json=default_dataset,
     )
-    mock_dataset_worker.generator = generator
-
-    extra_call = []
-    if generator:
-        responses.add(
-            responses.PATCH,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-            status=200,
-            json={},
-        )
-        extra_call = [
-            ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ]
+
+    assert list(mock_dev_dataset_worker.list_sets()) == [
+        Set(name="train", dataset=default_dataset),
+        Set(name="val", dataset=default_dataset),
+    ]
+
+    assert len(responses.calls) == 1
+    assert [(call.request.method, call.request.url) for call in responses.calls] == [
+        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
+    ]
+
+
+def test_run_no_sets(mocker, caplog, mock_dataset_worker):
+    mocker.patch("arkindex_worker.worker.DatasetWorker.list_sets", return_value=[])

     with pytest.raises(SystemExit):
         mock_dataset_worker.run()

-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_call)
-    assert [
-        (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + extra_call
-
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (
-            logging.WARNING,
-            f"Failed running worker on dataset dataset_id: AssertionError('{error}')",
-        ),
-    ] + (
-        [
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            )
-        ]
-        if generator
-        else []
-    ) + [
-        (logging.ERROR, "Ran on 1 dataset: 0 completed, 1 failed"),
+        (logging.WARNING, "No sets to process, stopping."),
     ]


-def test_run_update_dataset_state_api_error(
+def test_run_initial_dataset_state_error(
     mocker, responses, caplog, mock_dataset_worker, default_dataset
 ):
+    default_dataset.state = DatasetState.Building.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
-    )
-    mock_dataset_worker.generator = True
-
-    responses.add(
-        responses.PATCH,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-        status=500,
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )

     with pytest.raises(SystemExit):
         mock_dataset_worker.run()

-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 10
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + [
-        # We retry 5 times the API call to update the Dataset as Building
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        # We retry 5 times the API call to update the Dataset as in Error
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-    ]
+    ] == BASE_API_CALLS * 2

-    retries = [3.0, 4.0, 8.0, 16.0]
     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-        *[
-            (
-                logging.INFO,
-                f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-            )
-            for retry in retries
-        ],
         (
             logging.WARNING,
-            "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
-        ),
-        *[
-            (
-                logging.INFO,
-                f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
-            )
-            for retry in retries
-        ],
-        (
-            logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Failed running worker on Set (train) from Dataset (dataset_id): AssertionError('When processing a set, its dataset state should be Complete.')",
         ),
+        (logging.ERROR, "Ran on 1 set: 0 completed, 1 failed"),
     ]


@@ -541,10 +479,9 @@ def test_run_download_dataset_artifact_api_error(
     default_dataset,
 ):
     default_dataset.state = DatasetState.Complete.value
-
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -574,8 +511,11 @@ def test_run_download_dataset_artifact_api_error(

     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
         *[
             (
                 logging.INFO,
@@ -585,11 +525,11 @@ def test_run_download_dataset_artifact_api_error(
         ],
         (
             logging.WARNING,
-            "An API error occurred while processing dataset dataset_id: 500 Internal Server Error - None",
+            "An API error occurred while processing Set (train) from Dataset (dataset_id): 500 Internal Server Error - None",
         ),
         (
             logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Ran on 1 set: 0 completed, 1 failed",
         ),
     ]

@@ -603,10 +543,9 @@ def test_run_no_downloaded_artifact_error(
     default_dataset,
 ):
     default_dataset.state = DatasetState.Complete.value
-
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
@@ -632,22 +571,22 @@ def test_run_no_downloaded_artifact_error(

     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
         (
             logging.WARNING,
-            "Failed running worker on dataset dataset_id: MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
+            "Failed running worker on Set (train) from Dataset (dataset_id): MissingDatasetArchive('The dataset compressed archive artifact was not found.')",
         ),
         (
             logging.ERROR,
-            "Ran on 1 dataset: 0 completed, 1 failed",
+            "Ran on 1 set: 0 completed, 1 failed",
         ),
     ]


-@pytest.mark.parametrize(
-    ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
-)
 def test_run(
     mocker,
     tmp_path,
@@ -656,100 +595,68 @@ def test_run(
     mock_dataset_worker,
     default_dataset,
     default_artifact,
-    generator,
-    state,
 ):
-    mock_dataset_worker.generator = generator
-    default_dataset.state = state.value
-
+    default_dataset.state = DatasetState.Complete.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
         return_value=tmp_path,
     )
-    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
-
-    extra_calls = []
-    extra_logs = []
-    if generator:
-        responses.add(
-            responses.PATCH,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/",
-            status=200,
-            json={},
-        )
-        extra_calls += [
-            ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
-        ] * 2
-        extra_logs += [
-            (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            ),
-            (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "This API helper `update_dataset_state` did not update the cache database",
-            ),
-        ]
-    else:
-        archive_path = (
-            FIXTURES_DIR
-            / "extract_parent_archives"
-            / "first_parent"
-            / "arkindex_data.tar.zst"
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            status=200,
-            json=[default_artifact],
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            status=200,
-            body=archive_path.read_bytes(),
-            content_type="application/zstd",
-        )
-        extra_calls += [
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            ),
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            ),
-        ]
-        extra_logs += [
-            (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-        ]
+    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")
+
+    archive_path = (
+        FIXTURES_DIR
+        / "extract_parent_archives"
+        / "first_parent"
+        / "arkindex_data.tar.zst"
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        status=200,
+        json=[default_artifact],
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        status=200,
+        body=archive_path.read_bytes(),
+        content_type="application/zstd",
+    )

     mock_dataset_worker.run()

     assert mock_process.call_count == 1

-    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + len(extra_calls)
+    assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 2
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS * 2 + extra_calls
+    ] == BASE_API_CALLS * 2 + [
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        ),
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        ),
+    ]

     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        *extra_logs,
-        (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+        (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+        (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
     ]


-@pytest.mark.parametrize(
-    ("generator", "state"), [(True, DatasetState.Open), (False, DatasetState.Complete)]
-)
 def test_run_read_only(
     mocker,
     tmp_path,
@@ -758,90 +665,61 @@ def test_run_read_only(
     mock_dev_dataset_worker,
     default_dataset,
     default_artifact,
-    generator,
-    state,
 ):
-    mock_dev_dataset_worker.generator = generator
-    default_dataset.state = state.value
-
+    default_dataset.state = DatasetState.Complete.value
     mocker.patch(
-        "arkindex_worker.worker.DatasetWorker.list_datasets",
-        return_value=[default_dataset.id],
+        "arkindex_worker.worker.DatasetWorker.list_sets",
+        return_value=[Set(name="train", dataset=default_dataset)],
     )
     mocker.patch(
         "arkindex_worker.worker.base.BaseWorker.find_extras_directory",
         return_value=tmp_path,
     )
-    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_dataset")
+    mock_process = mocker.patch("arkindex_worker.worker.DatasetWorker.process_set")

+    archive_path = (
+        FIXTURES_DIR
+        / "extract_parent_archives"
+        / "first_parent"
+        / "arkindex_data.tar.zst"
+    )
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/datasets/{default_dataset.id}/",
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
         status=200,
-        json=default_dataset,
+        json=[default_artifact],
+    )
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        status=200,
+        body=archive_path.read_bytes(),
+        content_type="application/zstd",
     )
-
-    extra_calls = []
-    extra_logs = []
-    if generator:
-        extra_logs += [
-            (logging.INFO, "Building Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "Cannot update dataset as this worker is in read-only mode",
-            ),
-            (logging.INFO, "Completed Dataset (dataset_id) (1/1)"),
-            (
-                logging.WARNING,
-                "Cannot update dataset as this worker is in read-only mode",
-            ),
-        ]
-    else:
-        archive_path = (
-            FIXTURES_DIR
-            / "extract_parent_archives"
-            / "first_parent"
-            / "arkindex_data.tar.zst"
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            status=200,
-            json=[default_artifact],
-        )
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            status=200,
-            body=archive_path.read_bytes(),
-            content_type="application/zstd",
-        )
-        extra_calls += [
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
-            ),
-            (
-                "GET",
-                f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
-            ),
-        ]
-        extra_logs += [
-            (logging.INFO, "Downloading data for Dataset (dataset_id) (1/1)"),
-        ]

     mock_dev_dataset_worker.run()

     assert mock_process.call_count == 1

-    assert len(responses.calls) == 1 + len(extra_calls)
+    assert len(responses.calls) == 2
     assert [(call.request.method, call.request.url) for call in responses.calls] == [
-        ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
-    ] + extra_calls
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
+        ),
+        (
+            "GET",
+            f"http://testserver/api/v1/task/{default_dataset.task_id}/artifact/dataset_id.tar.zst",
+        ),
+    ]

     assert [(level, message) for _, level, message in caplog.record_tuples] == [
         (logging.WARNING, "Running without any extra configuration"),
-        (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-        *extra_logs,
-        (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+        (
+            logging.INFO,
+            "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
+        ),
+        (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
+        (logging.INFO, "Processing Set (train) from Dataset (dataset_id) (1/1)"),
+        (logging.INFO, "Ran on 1 set: 1 completed, 0 failed"),
     ]
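The `test_cleanup_downloaded_artifact*` tests near the top of the diff also pin down the contract of the new `downloaded_artifact` cache: cleanup is a no-op before any download, and unlinking twice must not raise. A minimal sketch consistent with those assertions (the surrounding class and attribute layout are assumed for illustration; the actual implementation lives in `arkindex_worker.worker`):

```python
from pathlib import Path


class DatasetWorkerSketch:
    """Illustrative stand-in for DatasetWorker's artifact caching."""

    # Path of the last downloaded dataset archive, if any
    downloaded_artifact: Path | None = None

    def cleanup_downloaded_artifact(self) -> None:
        # No-op when nothing has been downloaded yet
        if not self.downloaded_artifact:
            return
        # missing_ok=True makes repeated cleanup calls safe
        self.downloaded_artifact.unlink(missing_ok=True)
```

This pairs with the new `downloaded_cache` parametrization of `test_download_dataset_artifact`, which checks that downloading a fresh archive first removes whatever artifact was cached previously.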