arkindex-base-worker 0.3.7rc9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/METADATA +16 -20
  2. arkindex_base_worker-0.4.0.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/WHEEL +1 -1
  4. arkindex_worker/cache.py +1 -1
  5. arkindex_worker/image.py +120 -1
  6. arkindex_worker/models.py +6 -0
  7. arkindex_worker/utils.py +85 -4
  8. arkindex_worker/worker/__init__.py +68 -162
  9. arkindex_worker/worker/base.py +39 -34
  10. arkindex_worker/worker/classification.py +34 -18
  11. arkindex_worker/worker/corpus.py +86 -0
  12. arkindex_worker/worker/dataset.py +71 -1
  13. arkindex_worker/worker/element.py +352 -91
  14. arkindex_worker/worker/entity.py +11 -11
  15. arkindex_worker/worker/image.py +21 -0
  16. arkindex_worker/worker/metadata.py +19 -9
  17. arkindex_worker/worker/process.py +92 -0
  18. arkindex_worker/worker/task.py +5 -4
  19. arkindex_worker/worker/training.py +25 -10
  20. arkindex_worker/worker/transcription.py +89 -68
  21. arkindex_worker/worker/version.py +3 -1
  22. tests/__init__.py +8 -0
  23. tests/conftest.py +36 -52
  24. tests/test_base_worker.py +212 -12
  25. tests/test_dataset_worker.py +21 -45
  26. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  27. tests/test_elements_worker/test_cli.py +3 -11
  28. tests/test_elements_worker/test_corpus.py +168 -0
  29. tests/test_elements_worker/test_dataset.py +7 -12
  30. tests/test_elements_worker/test_element.py +427 -0
  31. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  32. tests/test_elements_worker/test_element_create_single.py +528 -0
  33. tests/test_elements_worker/test_element_list_children.py +969 -0
  34. tests/test_elements_worker/test_element_list_parents.py +530 -0
  35. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  36. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  37. tests/test_elements_worker/test_image.py +66 -0
  38. tests/test_elements_worker/test_metadata.py +230 -139
  39. tests/test_elements_worker/test_process.py +89 -0
  40. tests/test_elements_worker/test_task.py +8 -18
  41. tests/test_elements_worker/test_training.py +17 -8
  42. tests/test_elements_worker/test_transcription_create.py +873 -0
  43. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  44. tests/test_elements_worker/test_transcription_list.py +450 -0
  45. tests/test_elements_worker/test_version.py +60 -0
  46. tests/test_elements_worker/test_worker.py +563 -279
  47. tests/test_image.py +432 -209
  48. tests/test_merge.py +1 -2
  49. tests/test_utils.py +66 -3
  50. arkindex_base_worker-0.3.7rc9.dist-info/RECORD +0 -47
  51. tests/test_elements_worker/test_elements.py +0 -2713
  52. tests/test_elements_worker/test_transcriptions.py +0 -2119
  53. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/LICENSE +0 -0
  54. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/top_level.txt +0 -0
@@ -3,18 +3,21 @@ import uuid
3
3
  from argparse import ArgumentTypeError
4
4
 
5
5
  import pytest
6
- from apistar.exceptions import ErrorResponse
7
6
 
7
+ from arkindex.exceptions import ErrorResponse
8
8
  from arkindex_worker.models import Dataset, Set
9
- from arkindex_worker.worker import MissingDatasetArchive, check_dataset_set
10
- from arkindex_worker.worker.dataset import DatasetState
11
- from tests.conftest import FIXTURES_DIR, PROCESS_ID
9
+ from arkindex_worker.worker.dataset import (
10
+ DatasetState,
11
+ MissingDatasetArchive,
12
+ check_dataset_set,
13
+ )
14
+ from tests import FIXTURES_DIR, PROCESS_ID
12
15
  from tests.test_elements_worker import BASE_API_CALLS
13
16
 
14
17
  RANDOM_UUID = uuid.uuid4()
15
18
 
16
19
 
17
- @pytest.fixture()
20
+ @pytest.fixture
18
21
  def tmp_archive(tmp_path):
19
22
  archive = tmp_path / "test_archive.tar.zst"
20
23
  archive.touch()
@@ -63,22 +66,17 @@ def test_download_dataset_artifact_list_api_error(
63
66
  responses.add(
64
67
  responses.GET,
65
68
  f"http://testserver/api/v1/task/{task_id}/artifacts/",
66
- status=500,
69
+ status=418,
67
70
  )
68
71
 
69
72
  with pytest.raises(ErrorResponse):
70
73
  mock_dataset_worker.download_dataset_artifact(default_dataset)
71
74
 
72
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
75
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
73
76
  assert [
74
77
  (call.request.method, call.request.url) for call in responses.calls
75
78
  ] == BASE_API_CALLS + [
76
- # The API call is retried 5 times
77
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
78
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
79
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
80
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
81
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
79
+ ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/")
82
80
  ]
83
81
 
84
82
 
@@ -116,22 +114,17 @@ def test_download_dataset_artifact_download_api_error(
116
114
  responses.add(
117
115
  responses.GET,
118
116
  f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst",
119
- status=500,
117
+ status=418,
120
118
  )
121
119
 
122
120
  with pytest.raises(ErrorResponse):
123
121
  mock_dataset_worker.download_dataset_artifact(default_dataset)
124
122
 
125
- assert len(responses.calls) == len(BASE_API_CALLS) + 6
123
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
126
124
  assert [
127
125
  (call.request.method, call.request.url) for call in responses.calls
128
126
  ] == BASE_API_CALLS + [
129
127
  ("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
130
- # The API call is retried 5 times
131
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
132
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
133
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
134
- ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
135
128
  ("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
136
129
  ]
137
130
 
@@ -284,7 +277,7 @@ def test_list_sets_api_error(responses, mock_dataset_worker):
284
277
  responses.add(
285
278
  responses.GET,
286
279
  f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
287
- status=500,
280
+ status=418,
288
281
  )
289
282
 
290
283
  with pytest.raises(
@@ -393,20 +386,15 @@ def test_list_sets_retrieve_dataset_api_error(
393
386
  responses.add(
394
387
  responses.GET,
395
388
  f"http://testserver/api/v1/datasets/{default_dataset.id}/",
396
- status=500,
389
+ status=418,
397
390
  )
398
391
 
399
392
  with pytest.raises(ErrorResponse):
400
393
  next(mock_dev_dataset_worker.list_sets())
401
394
 
402
- assert len(responses.calls) == 5
395
+ assert len(responses.calls) == 1
403
396
  assert [(call.request.method, call.request.url) for call in responses.calls] == [
404
- # The API call is retried 5 times
405
- ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
406
- ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
407
- ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
408
- ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
409
- ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
397
+ ("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
410
398
  ]
411
399
 
412
400
 
@@ -494,22 +482,17 @@ def test_run_download_dataset_artifact_api_error(
494
482
  responses.add(
495
483
  responses.GET,
496
484
  f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
497
- status=500,
485
+ status=418,
498
486
  )
499
487
 
500
488
  with pytest.raises(SystemExit):
501
489
  mock_dataset_worker.run()
502
490
 
503
- assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 5
491
+ assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 1
504
492
  assert [
505
493
  (call.request.method, call.request.url) for call in responses.calls
506
494
  ] == BASE_API_CALLS * 2 + [
507
- # We retry 5 times the API call
508
- ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
509
- ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
510
- ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
511
- ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
512
- ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
495
+ ("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/")
513
496
  ]
514
497
 
515
498
  assert [(level, message) for _, level, message in caplog.record_tuples] == [
@@ -519,16 +502,9 @@ def test_run_download_dataset_artifact_api_error(
519
502
  "Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
520
503
  ),
521
504
  (logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
522
- *[
523
- (
524
- logging.INFO,
525
- f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
526
- )
527
- for retry in [3.0, 4.0, 8.0, 16.0]
528
- ],
529
505
  (
530
506
  logging.WARNING,
531
- "An API error occurred while processing Set (train) from Dataset (dataset_id): 500 Internal Server Error - None",
507
+ "An API error occurred while processing Set (train) from Dataset (dataset_id): 418 I'm a Teapot - None",
532
508
  ),
533
509
  (
534
510
  logging.ERROR,
@@ -3,10 +3,12 @@ import re
3
3
  from uuid import UUID
4
4
 
5
5
  import pytest
6
- from apistar.exceptions import ErrorResponse
7
6
 
7
+ from arkindex.exceptions import ErrorResponse
8
8
  from arkindex_worker.cache import CachedClassification, CachedElement
9
9
  from arkindex_worker.models import Element
10
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
11
+ from tests import CORPUS_ID
10
12
 
11
13
  from . import BASE_API_CALLS
12
14
 
@@ -15,11 +17,96 @@ from . import BASE_API_CALLS
15
17
  DELETE_PARAMETER = "DELETE_PARAMETER"
16
18
 
17
19
 
20
+ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
21
+ responses.add(
22
+ responses.GET,
23
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
24
+ status=418,
25
+ )
26
+
27
+ assert not mock_elements_worker.classes
28
+ with pytest.raises(
29
+ Exception, match="Stopping pagination as data will be incomplete"
30
+ ):
31
+ mock_elements_worker.load_corpus_classes()
32
+
33
+ assert len(responses.calls) == len(BASE_API_CALLS) + 5
34
+ assert [
35
+ (call.request.method, call.request.url) for call in responses.calls
36
+ ] == BASE_API_CALLS + [
37
+ # We do 5 retries
38
+ (
39
+ "GET",
40
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
41
+ ),
42
+ (
43
+ "GET",
44
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
45
+ ),
46
+ (
47
+ "GET",
48
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
49
+ ),
50
+ (
51
+ "GET",
52
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
53
+ ),
54
+ (
55
+ "GET",
56
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
57
+ ),
58
+ ]
59
+ assert not mock_elements_worker.classes
60
+
61
+
62
+ def test_load_corpus_classes(responses, mock_elements_worker):
63
+ responses.add(
64
+ responses.GET,
65
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
66
+ status=200,
67
+ json={
68
+ "count": 3,
69
+ "next": None,
70
+ "results": [
71
+ {
72
+ "id": "0000",
73
+ "name": "good",
74
+ },
75
+ {
76
+ "id": "1111",
77
+ "name": "average",
78
+ },
79
+ {
80
+ "id": "2222",
81
+ "name": "bad",
82
+ },
83
+ ],
84
+ },
85
+ )
86
+
87
+ assert not mock_elements_worker.classes
88
+ mock_elements_worker.load_corpus_classes()
89
+
90
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
91
+ assert [
92
+ (call.request.method, call.request.url) for call in responses.calls
93
+ ] == BASE_API_CALLS + [
94
+ (
95
+ "GET",
96
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
97
+ ),
98
+ ]
99
+ assert mock_elements_worker.classes == {
100
+ "good": "0000",
101
+ "average": "1111",
102
+ "bad": "2222",
103
+ }
104
+
105
+
18
106
  def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
19
- corpus_id = "11111111-1111-1111-1111-111111111111"
20
107
  responses.add(
21
108
  responses.GET,
22
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
109
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
23
110
  status=200,
24
111
  json={
25
112
  "count": 1,
@@ -42,7 +129,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
42
129
  ] == BASE_API_CALLS + [
43
130
  (
44
131
  "GET",
45
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
132
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
46
133
  ),
47
134
  ]
48
135
  assert mock_elements_worker.classes == {"good": "0000"}
@@ -51,12 +138,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
51
138
 
52
139
  def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
53
140
  # A missing class is now created automatically
54
- corpus_id = "11111111-1111-1111-1111-111111111111"
55
141
  mock_elements_worker.classes = {"good": "0000"}
56
142
 
57
143
  responses.add(
58
144
  responses.POST,
59
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
145
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
60
146
  status=201,
61
147
  json={"id": "new-ml-class-1234"},
62
148
  )
@@ -82,12 +168,10 @@ def test_get_ml_class_id(mock_elements_worker):
82
168
 
83
169
 
84
170
  def test_get_ml_class_reload(responses, mock_elements_worker):
85
- corpus_id = "11111111-1111-1111-1111-111111111111"
86
-
87
171
  # Add some initial classes
88
172
  responses.add(
89
173
  responses.GET,
90
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
174
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
91
175
  json={
92
176
  "count": 1,
93
177
  "next": None,
@@ -103,7 +187,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
103
187
  # Invalid response when trying to create class2
104
188
  responses.add(
105
189
  responses.POST,
106
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
190
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
107
191
  status=400,
108
192
  json={"non_field_errors": "Already exists"},
109
193
  )
@@ -111,7 +195,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
111
195
  # Add both classes (class2 is created by another process)
112
196
  responses.add(
113
197
  responses.GET,
114
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
198
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
115
199
  json={
116
200
  "count": 2,
117
201
  "next": None,
@@ -141,15 +225,15 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
141
225
  ] == BASE_API_CALLS + [
142
226
  (
143
227
  "GET",
144
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
228
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
145
229
  ),
146
230
  (
147
231
  "POST",
148
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
232
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
149
233
  ),
150
234
  (
151
235
  "GET",
152
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
236
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
153
237
  ),
154
238
  ]
155
239
 
@@ -169,7 +253,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
169
253
  """
170
254
  responses.add(
171
255
  responses.GET,
172
- f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
256
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
173
257
  status=200,
174
258
  json={
175
259
  "count": 1,
@@ -189,7 +273,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
189
273
  ] == BASE_API_CALLS + [
190
274
  (
191
275
  "GET",
192
- f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
276
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
193
277
  ),
194
278
  ]
195
279
 
@@ -276,7 +360,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
276
360
  responses.add(
277
361
  responses.POST,
278
362
  "http://testserver/api/v1/classifications/",
279
- status=500,
363
+ status=418,
280
364
  )
281
365
 
282
366
  with pytest.raises(ErrorResponse):
@@ -287,17 +371,10 @@ def test_create_classification_api_error(responses, mock_elements_worker):
287
371
  high_confidence=True,
288
372
  )
289
373
 
290
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
374
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
291
375
  assert [
292
376
  (call.request.method, call.request.url) for call in responses.calls
293
- ] == BASE_API_CALLS + [
294
- # We retry 5 times the API call
295
- ("POST", "http://testserver/api/v1/classifications/"),
296
- ("POST", "http://testserver/api/v1/classifications/"),
297
- ("POST", "http://testserver/api/v1/classifications/"),
298
- ("POST", "http://testserver/api/v1/classifications/"),
299
- ("POST", "http://testserver/api/v1/classifications/"),
300
- ]
377
+ ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classifications/")]
301
378
 
302
379
 
303
380
  def test_create_classification_create_ml_class(mock_elements_worker, responses):
@@ -306,7 +383,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
306
383
  # Automatically create a missing class!
307
384
  responses.add(
308
385
  responses.POST,
309
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
386
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
310
387
  status=201,
311
388
  json={"id": "new-ml-class-1234"},
312
389
  )
@@ -330,7 +407,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
330
407
  for call in responses.calls[-2:]
331
408
  ] == [
332
409
  (
333
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
410
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
334
411
  {"name": "a_class"},
335
412
  ),
336
413
  (
@@ -609,7 +686,7 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
609
686
  responses.add(
610
687
  responses.POST,
611
688
  "http://testserver/api/v1/classification/bulk/",
612
- status=500,
689
+ status=418,
613
690
  )
614
691
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
615
692
  classes = [
@@ -630,17 +707,10 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
630
707
  element=elt, classifications=classes
631
708
  )
632
709
 
633
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
710
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
634
711
  assert [
635
712
  (call.request.method, call.request.url) for call in responses.calls
636
- ] == BASE_API_CALLS + [
637
- # We retry 5 times the API call
638
- ("POST", "http://testserver/api/v1/classification/bulk/"),
639
- ("POST", "http://testserver/api/v1/classification/bulk/"),
640
- ("POST", "http://testserver/api/v1/classification/bulk/"),
641
- ("POST", "http://testserver/api/v1/classification/bulk/"),
642
- ("POST", "http://testserver/api/v1/classification/bulk/"),
643
- ]
713
+ ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classification/bulk/")]
644
714
 
645
715
 
646
716
  def test_create_classifications_create_ml_class(mock_elements_worker, responses):
@@ -649,7 +719,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
649
719
  # Automatically create a missing class!
650
720
  responses.add(
651
721
  responses.POST,
652
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
722
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
653
723
  status=201,
654
724
  json={"id": "new-ml-class-1234"},
655
725
  )
@@ -690,7 +760,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
690
760
  ] == BASE_API_CALLS + [
691
761
  (
692
762
  "POST",
693
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
763
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
694
764
  ),
695
765
  ("POST", "http://testserver/api/v1/classification/bulk/"),
696
766
  ]
@@ -709,7 +779,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
709
779
  }
710
780
 
711
781
 
712
- def test_create_classifications(responses, mock_elements_worker):
782
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
783
+ def test_create_classifications(batch_size, responses, mock_elements_worker):
713
784
  mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
714
785
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
715
786
  responses.add(
@@ -733,62 +804,98 @@ def test_create_classifications(responses, mock_elements_worker):
733
804
  "high_confidence": False,
734
805
  },
735
806
  ],
807
+ batch_size=batch_size,
736
808
  )
737
809
 
738
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
810
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
811
+ if batch_size != DEFAULT_BATCH_SIZE:
812
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
813
+
814
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
739
815
  assert [
740
816
  (call.request.method, call.request.url) for call in responses.calls
741
- ] == BASE_API_CALLS + [
742
- ("POST", "http://testserver/api/v1/classification/bulk/"),
743
- ]
817
+ ] == BASE_API_CALLS + bulk_api_calls
744
818
 
745
- assert json.loads(responses.calls[-1].request.body) == {
819
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
820
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
821
+ empty_payload = {
746
822
  "parent": str(elt.id),
747
823
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
748
- "classifications": [
749
- {
750
- "confidence": 0.75,
751
- "high_confidence": False,
752
- "ml_class": "0000",
753
- },
754
- {
755
- "confidence": 0.25,
756
- "high_confidence": False,
757
- "ml_class": "1111",
758
- },
759
- ],
824
+ "classifications": [],
760
825
  }
761
826
 
827
+ bodies = []
828
+ first_call_idx = None
829
+ if batch_size > 1:
830
+ first_call_idx = -1
831
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
832
+ else:
833
+ first_call_idx = -2
834
+ bodies.append({**empty_payload, "classifications": [first_cl]})
835
+ bodies.append({**empty_payload, "classifications": [second_cl]})
836
+
837
+ assert [
838
+ json.loads(bulk_call.request.body)
839
+ for bulk_call in responses.calls[first_call_idx:]
840
+ ] == bodies
762
841
 
763
- def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
842
+
843
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
844
+ def test_create_classifications_with_cache(
845
+ batch_size, responses, mock_elements_worker_with_cache
846
+ ):
764
847
  mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
765
848
  elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
766
849
 
767
- responses.add(
768
- responses.POST,
769
- "http://testserver/api/v1/classification/bulk/",
770
- status=200,
771
- json={
772
- "parent": str(elt.id),
773
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
774
- "classifications": [
775
- {
776
- "id": "00000000-0000-0000-0000-000000000000",
777
- "ml_class": "0000",
778
- "confidence": 0.75,
779
- "high_confidence": False,
780
- "state": "pending",
781
- },
782
- {
783
- "id": "11111111-1111-1111-1111-111111111111",
784
- "ml_class": "1111",
785
- "confidence": 0.25,
786
- "high_confidence": False,
787
- "state": "pending",
850
+ if batch_size > 1:
851
+ responses.add(
852
+ responses.POST,
853
+ "http://testserver/api/v1/classification/bulk/",
854
+ status=200,
855
+ json={
856
+ "parent": str(elt.id),
857
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
858
+ "classifications": [
859
+ {
860
+ "id": "00000000-0000-0000-0000-000000000000",
861
+ "ml_class": "0000",
862
+ "confidence": 0.75,
863
+ "high_confidence": False,
864
+ "state": "pending",
865
+ },
866
+ {
867
+ "id": "11111111-1111-1111-1111-111111111111",
868
+ "ml_class": "1111",
869
+ "confidence": 0.25,
870
+ "high_confidence": False,
871
+ "state": "pending",
872
+ },
873
+ ],
874
+ },
875
+ )
876
+ else:
877
+ for cl_id, cl_class, cl_conf in [
878
+ ("00000000-0000-0000-0000-000000000000", "0000", 0.75),
879
+ ("11111111-1111-1111-1111-111111111111", "1111", 0.25),
880
+ ]:
881
+ responses.add(
882
+ responses.POST,
883
+ "http://testserver/api/v1/classification/bulk/",
884
+ status=200,
885
+ json={
886
+ "parent": str(elt.id),
887
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
888
+ "classifications": [
889
+ {
890
+ "id": cl_id,
891
+ "ml_class": cl_class,
892
+ "confidence": cl_conf,
893
+ "high_confidence": False,
894
+ "state": "pending",
895
+ },
896
+ ],
788
897
  },
789
- ],
790
- },
791
- )
898
+ )
792
899
 
793
900
  mock_elements_worker_with_cache.create_classifications(
794
901
  element=elt,
@@ -804,32 +911,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
804
911
  "high_confidence": False,
805
912
  },
806
913
  ],
914
+ batch_size=batch_size,
807
915
  )
808
916
 
809
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
917
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
918
+ if batch_size != DEFAULT_BATCH_SIZE:
919
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
920
+
921
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
810
922
  assert [
811
923
  (call.request.method, call.request.url) for call in responses.calls
812
- ] == BASE_API_CALLS + [
813
- ("POST", "http://testserver/api/v1/classification/bulk/"),
814
- ]
924
+ ] == BASE_API_CALLS + bulk_api_calls
815
925
 
816
- assert json.loads(responses.calls[-1].request.body) == {
926
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
927
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
928
+ empty_payload = {
817
929
  "parent": str(elt.id),
818
930
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
819
- "classifications": [
820
- {
821
- "confidence": 0.75,
822
- "high_confidence": False,
823
- "ml_class": "0000",
824
- },
825
- {
826
- "confidence": 0.25,
827
- "high_confidence": False,
828
- "ml_class": "1111",
829
- },
830
- ],
931
+ "classifications": [],
831
932
  }
832
933
 
934
+ bodies = []
935
+ first_call_idx = None
936
+ if batch_size > 1:
937
+ first_call_idx = -1
938
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
939
+ else:
940
+ first_call_idx = -2
941
+ bodies.append({**empty_payload, "classifications": [first_cl]})
942
+ bodies.append({**empty_payload, "classifications": [second_cl]})
943
+
944
+ assert [
945
+ json.loads(bulk_call.request.body)
946
+ for bulk_call in responses.calls[first_call_idx:]
947
+ ] == bodies
948
+
833
949
  # Check that created classifications were properly stored in SQLite cache
834
950
  assert list(CachedClassification.select()) == [
835
951
  CachedClassification(