arkindex-base-worker 0.3.7rc5__py3-none-any.whl → 0.5.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
- arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
- {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
- {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
- arkindex_worker/cache.py +1 -1
- arkindex_worker/image.py +167 -2
- arkindex_worker/models.py +18 -0
- arkindex_worker/utils.py +98 -4
- arkindex_worker/worker/__init__.py +117 -218
- arkindex_worker/worker/base.py +39 -46
- arkindex_worker/worker/classification.py +34 -18
- arkindex_worker/worker/corpus.py +86 -0
- arkindex_worker/worker/dataset.py +89 -26
- arkindex_worker/worker/element.py +352 -91
- arkindex_worker/worker/entity.py +13 -11
- arkindex_worker/worker/image.py +21 -0
- arkindex_worker/worker/metadata.py +26 -16
- arkindex_worker/worker/process.py +92 -0
- arkindex_worker/worker/task.py +5 -4
- arkindex_worker/worker/training.py +25 -10
- arkindex_worker/worker/transcription.py +89 -68
- arkindex_worker/worker/version.py +3 -1
- hooks/pre_gen_project.py +3 -0
- tests/__init__.py +8 -0
- tests/conftest.py +47 -58
- tests/test_base_worker.py +212 -12
- tests/test_dataset_worker.py +294 -437
- tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
- tests/test_elements_worker/test_cli.py +3 -11
- tests/test_elements_worker/test_corpus.py +168 -0
- tests/test_elements_worker/test_dataset.py +106 -157
- tests/test_elements_worker/test_element.py +427 -0
- tests/test_elements_worker/test_element_create_multiple.py +715 -0
- tests/test_elements_worker/test_element_create_single.py +528 -0
- tests/test_elements_worker/test_element_list_children.py +969 -0
- tests/test_elements_worker/test_element_list_parents.py +530 -0
- tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
- tests/test_elements_worker/test_entity_list_and_check.py +160 -0
- tests/test_elements_worker/test_image.py +66 -0
- tests/test_elements_worker/test_metadata.py +252 -161
- tests/test_elements_worker/test_process.py +89 -0
- tests/test_elements_worker/test_task.py +8 -18
- tests/test_elements_worker/test_training.py +17 -8
- tests/test_elements_worker/test_transcription_create.py +873 -0
- tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
- tests/test_elements_worker/test_transcription_list.py +450 -0
- tests/test_elements_worker/test_version.py +60 -0
- tests/test_elements_worker/test_worker.py +578 -293
- tests/test_image.py +542 -209
- tests/test_merge.py +1 -2
- tests/test_utils.py +89 -4
- worker-demo/tests/__init__.py +0 -0
- worker-demo/tests/conftest.py +32 -0
- worker-demo/tests/test_worker.py +12 -0
- worker-demo/worker_demo/__init__.py +6 -0
- worker-demo/worker_demo/worker.py +19 -0
- arkindex_base_worker-0.3.7rc5.dist-info/RECORD +0 -41
- tests/test_elements_worker/test_elements.py +0 -2713
- tests/test_elements_worker/test_transcriptions.py +0 -2119
- {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
|
@@ -3,10 +3,12 @@ import re
|
|
|
3
3
|
from uuid import UUID
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
-
from apistar.exceptions import ErrorResponse
|
|
7
6
|
|
|
7
|
+
from arkindex.exceptions import ErrorResponse
|
|
8
8
|
from arkindex_worker.cache import CachedClassification, CachedElement
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
11
|
+
from tests import CORPUS_ID
|
|
10
12
|
|
|
11
13
|
from . import BASE_API_CALLS
|
|
12
14
|
|
|
@@ -15,11 +17,96 @@ from . import BASE_API_CALLS
|
|
|
15
17
|
DELETE_PARAMETER = "DELETE_PARAMETER"
|
|
16
18
|
|
|
17
19
|
|
|
20
|
+
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
|
|
21
|
+
responses.add(
|
|
22
|
+
responses.GET,
|
|
23
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
24
|
+
status=418,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
assert not mock_elements_worker.classes
|
|
28
|
+
with pytest.raises(
|
|
29
|
+
Exception, match="Stopping pagination as data will be incomplete"
|
|
30
|
+
):
|
|
31
|
+
mock_elements_worker.load_corpus_classes()
|
|
32
|
+
|
|
33
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 5
|
|
34
|
+
assert [
|
|
35
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
36
|
+
] == BASE_API_CALLS + [
|
|
37
|
+
# We do 5 retries
|
|
38
|
+
(
|
|
39
|
+
"GET",
|
|
40
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
41
|
+
),
|
|
42
|
+
(
|
|
43
|
+
"GET",
|
|
44
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
45
|
+
),
|
|
46
|
+
(
|
|
47
|
+
"GET",
|
|
48
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
49
|
+
),
|
|
50
|
+
(
|
|
51
|
+
"GET",
|
|
52
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
53
|
+
),
|
|
54
|
+
(
|
|
55
|
+
"GET",
|
|
56
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
57
|
+
),
|
|
58
|
+
]
|
|
59
|
+
assert not mock_elements_worker.classes
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_load_corpus_classes(responses, mock_elements_worker):
|
|
63
|
+
responses.add(
|
|
64
|
+
responses.GET,
|
|
65
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
66
|
+
status=200,
|
|
67
|
+
json={
|
|
68
|
+
"count": 3,
|
|
69
|
+
"next": None,
|
|
70
|
+
"results": [
|
|
71
|
+
{
|
|
72
|
+
"id": "0000",
|
|
73
|
+
"name": "good",
|
|
74
|
+
},
|
|
75
|
+
{
|
|
76
|
+
"id": "1111",
|
|
77
|
+
"name": "average",
|
|
78
|
+
},
|
|
79
|
+
{
|
|
80
|
+
"id": "2222",
|
|
81
|
+
"name": "bad",
|
|
82
|
+
},
|
|
83
|
+
],
|
|
84
|
+
},
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
assert not mock_elements_worker.classes
|
|
88
|
+
mock_elements_worker.load_corpus_classes()
|
|
89
|
+
|
|
90
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
91
|
+
assert [
|
|
92
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
93
|
+
] == BASE_API_CALLS + [
|
|
94
|
+
(
|
|
95
|
+
"GET",
|
|
96
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
97
|
+
),
|
|
98
|
+
]
|
|
99
|
+
assert mock_elements_worker.classes == {
|
|
100
|
+
"good": "0000",
|
|
101
|
+
"average": "1111",
|
|
102
|
+
"bad": "2222",
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
|
|
18
106
|
def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
19
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
20
107
|
responses.add(
|
|
21
108
|
responses.GET,
|
|
22
|
-
f"http://testserver/api/v1/corpus/{
|
|
109
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
23
110
|
status=200,
|
|
24
111
|
json={
|
|
25
112
|
"count": 1,
|
|
@@ -42,7 +129,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
42
129
|
] == BASE_API_CALLS + [
|
|
43
130
|
(
|
|
44
131
|
"GET",
|
|
45
|
-
f"http://testserver/api/v1/corpus/{
|
|
132
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
46
133
|
),
|
|
47
134
|
]
|
|
48
135
|
assert mock_elements_worker.classes == {"good": "0000"}
|
|
@@ -51,12 +138,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
51
138
|
|
|
52
139
|
def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
|
|
53
140
|
# A missing class is now created automatically
|
|
54
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
55
141
|
mock_elements_worker.classes = {"good": "0000"}
|
|
56
142
|
|
|
57
143
|
responses.add(
|
|
58
144
|
responses.POST,
|
|
59
|
-
f"http://testserver/api/v1/corpus/{
|
|
145
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
60
146
|
status=201,
|
|
61
147
|
json={"id": "new-ml-class-1234"},
|
|
62
148
|
)
|
|
@@ -82,12 +168,10 @@ def test_get_ml_class_id(mock_elements_worker):
|
|
|
82
168
|
|
|
83
169
|
|
|
84
170
|
def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
85
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
86
|
-
|
|
87
171
|
# Add some initial classes
|
|
88
172
|
responses.add(
|
|
89
173
|
responses.GET,
|
|
90
|
-
f"http://testserver/api/v1/corpus/{
|
|
174
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
91
175
|
json={
|
|
92
176
|
"count": 1,
|
|
93
177
|
"next": None,
|
|
@@ -103,7 +187,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
103
187
|
# Invalid response when trying to create class2
|
|
104
188
|
responses.add(
|
|
105
189
|
responses.POST,
|
|
106
|
-
f"http://testserver/api/v1/corpus/{
|
|
190
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
107
191
|
status=400,
|
|
108
192
|
json={"non_field_errors": "Already exists"},
|
|
109
193
|
)
|
|
@@ -111,7 +195,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
111
195
|
# Add both classes (class2 is created by another process)
|
|
112
196
|
responses.add(
|
|
113
197
|
responses.GET,
|
|
114
|
-
f"http://testserver/api/v1/corpus/{
|
|
198
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
115
199
|
json={
|
|
116
200
|
"count": 2,
|
|
117
201
|
"next": None,
|
|
@@ -141,15 +225,15 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
141
225
|
] == BASE_API_CALLS + [
|
|
142
226
|
(
|
|
143
227
|
"GET",
|
|
144
|
-
f"http://testserver/api/v1/corpus/{
|
|
228
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
145
229
|
),
|
|
146
230
|
(
|
|
147
231
|
"POST",
|
|
148
|
-
f"http://testserver/api/v1/corpus/{
|
|
232
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
149
233
|
),
|
|
150
234
|
(
|
|
151
235
|
"GET",
|
|
152
|
-
f"http://testserver/api/v1/corpus/{
|
|
236
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
153
237
|
),
|
|
154
238
|
]
|
|
155
239
|
|
|
@@ -169,7 +253,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
169
253
|
"""
|
|
170
254
|
responses.add(
|
|
171
255
|
responses.GET,
|
|
172
|
-
f"http://testserver/api/v1/corpus/{
|
|
256
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
173
257
|
status=200,
|
|
174
258
|
json={
|
|
175
259
|
"count": 1,
|
|
@@ -189,7 +273,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
189
273
|
] == BASE_API_CALLS + [
|
|
190
274
|
(
|
|
191
275
|
"GET",
|
|
192
|
-
f"http://testserver/api/v1/corpus/{
|
|
276
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
193
277
|
),
|
|
194
278
|
]
|
|
195
279
|
|
|
@@ -276,7 +360,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
276
360
|
responses.add(
|
|
277
361
|
responses.POST,
|
|
278
362
|
"http://testserver/api/v1/classifications/",
|
|
279
|
-
status=
|
|
363
|
+
status=418,
|
|
280
364
|
)
|
|
281
365
|
|
|
282
366
|
with pytest.raises(ErrorResponse):
|
|
@@ -287,17 +371,10 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
287
371
|
high_confidence=True,
|
|
288
372
|
)
|
|
289
373
|
|
|
290
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
374
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
291
375
|
assert [
|
|
292
376
|
(call.request.method, call.request.url) for call in responses.calls
|
|
293
|
-
] == BASE_API_CALLS + [
|
|
294
|
-
# We retry 5 times the API call
|
|
295
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
296
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
297
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
298
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
299
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
300
|
-
]
|
|
377
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classifications/")]
|
|
301
378
|
|
|
302
379
|
|
|
303
380
|
def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
@@ -306,7 +383,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
306
383
|
# Automatically create a missing class!
|
|
307
384
|
responses.add(
|
|
308
385
|
responses.POST,
|
|
309
|
-
"http://testserver/api/v1/corpus/
|
|
386
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
310
387
|
status=201,
|
|
311
388
|
json={"id": "new-ml-class-1234"},
|
|
312
389
|
)
|
|
@@ -330,7 +407,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
330
407
|
for call in responses.calls[-2:]
|
|
331
408
|
] == [
|
|
332
409
|
(
|
|
333
|
-
"http://testserver/api/v1/corpus/
|
|
410
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
334
411
|
{"name": "a_class"},
|
|
335
412
|
),
|
|
336
413
|
(
|
|
@@ -609,7 +686,7 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
|
609
686
|
responses.add(
|
|
610
687
|
responses.POST,
|
|
611
688
|
"http://testserver/api/v1/classification/bulk/",
|
|
612
|
-
status=
|
|
689
|
+
status=418,
|
|
613
690
|
)
|
|
614
691
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
615
692
|
classes = [
|
|
@@ -630,17 +707,10 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
|
630
707
|
element=elt, classifications=classes
|
|
631
708
|
)
|
|
632
709
|
|
|
633
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
710
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
634
711
|
assert [
|
|
635
712
|
(call.request.method, call.request.url) for call in responses.calls
|
|
636
|
-
] == BASE_API_CALLS + [
|
|
637
|
-
# We retry 5 times the API call
|
|
638
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
639
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
640
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
641
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
642
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
643
|
-
]
|
|
713
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
644
714
|
|
|
645
715
|
|
|
646
716
|
def test_create_classifications_create_ml_class(mock_elements_worker, responses):
|
|
@@ -649,7 +719,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
649
719
|
# Automatically create a missing class!
|
|
650
720
|
responses.add(
|
|
651
721
|
responses.POST,
|
|
652
|
-
"http://testserver/api/v1/corpus/
|
|
722
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
653
723
|
status=201,
|
|
654
724
|
json={"id": "new-ml-class-1234"},
|
|
655
725
|
)
|
|
@@ -690,7 +760,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
690
760
|
] == BASE_API_CALLS + [
|
|
691
761
|
(
|
|
692
762
|
"POST",
|
|
693
|
-
"http://testserver/api/v1/corpus/
|
|
763
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
694
764
|
),
|
|
695
765
|
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
696
766
|
]
|
|
@@ -709,7 +779,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
709
779
|
}
|
|
710
780
|
|
|
711
781
|
|
|
712
|
-
|
|
782
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
783
|
+
def test_create_classifications(batch_size, responses, mock_elements_worker):
|
|
713
784
|
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
|
|
714
785
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
715
786
|
responses.add(
|
|
@@ -733,62 +804,98 @@ def test_create_classifications(responses, mock_elements_worker):
|
|
|
733
804
|
"high_confidence": False,
|
|
734
805
|
},
|
|
735
806
|
],
|
|
807
|
+
batch_size=batch_size,
|
|
736
808
|
)
|
|
737
809
|
|
|
738
|
-
|
|
810
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
811
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
812
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
813
|
+
|
|
814
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
739
815
|
assert [
|
|
740
816
|
(call.request.method, call.request.url) for call in responses.calls
|
|
741
|
-
] == BASE_API_CALLS +
|
|
742
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
743
|
-
]
|
|
817
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
744
818
|
|
|
745
|
-
|
|
819
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
820
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
821
|
+
empty_payload = {
|
|
746
822
|
"parent": str(elt.id),
|
|
747
823
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
748
|
-
"classifications": [
|
|
749
|
-
{
|
|
750
|
-
"confidence": 0.75,
|
|
751
|
-
"high_confidence": False,
|
|
752
|
-
"ml_class": "0000",
|
|
753
|
-
},
|
|
754
|
-
{
|
|
755
|
-
"confidence": 0.25,
|
|
756
|
-
"high_confidence": False,
|
|
757
|
-
"ml_class": "1111",
|
|
758
|
-
},
|
|
759
|
-
],
|
|
824
|
+
"classifications": [],
|
|
760
825
|
}
|
|
761
826
|
|
|
827
|
+
bodies = []
|
|
828
|
+
first_call_idx = None
|
|
829
|
+
if batch_size > 1:
|
|
830
|
+
first_call_idx = -1
|
|
831
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
832
|
+
else:
|
|
833
|
+
first_call_idx = -2
|
|
834
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
835
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
836
|
+
|
|
837
|
+
assert [
|
|
838
|
+
json.loads(bulk_call.request.body)
|
|
839
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
840
|
+
] == bodies
|
|
762
841
|
|
|
763
|
-
|
|
842
|
+
|
|
843
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
844
|
+
def test_create_classifications_with_cache(
|
|
845
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
846
|
+
):
|
|
764
847
|
mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
|
|
765
848
|
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
766
849
|
|
|
767
|
-
|
|
768
|
-
responses.
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
850
|
+
if batch_size > 1:
|
|
851
|
+
responses.add(
|
|
852
|
+
responses.POST,
|
|
853
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
854
|
+
status=200,
|
|
855
|
+
json={
|
|
856
|
+
"parent": str(elt.id),
|
|
857
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
858
|
+
"classifications": [
|
|
859
|
+
{
|
|
860
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
861
|
+
"ml_class": "0000",
|
|
862
|
+
"confidence": 0.75,
|
|
863
|
+
"high_confidence": False,
|
|
864
|
+
"state": "pending",
|
|
865
|
+
},
|
|
866
|
+
{
|
|
867
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
868
|
+
"ml_class": "1111",
|
|
869
|
+
"confidence": 0.25,
|
|
870
|
+
"high_confidence": False,
|
|
871
|
+
"state": "pending",
|
|
872
|
+
},
|
|
873
|
+
],
|
|
874
|
+
},
|
|
875
|
+
)
|
|
876
|
+
else:
|
|
877
|
+
for cl_id, cl_class, cl_conf in [
|
|
878
|
+
("00000000-0000-0000-0000-000000000000", "0000", 0.75),
|
|
879
|
+
("11111111-1111-1111-1111-111111111111", "1111", 0.25),
|
|
880
|
+
]:
|
|
881
|
+
responses.add(
|
|
882
|
+
responses.POST,
|
|
883
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
884
|
+
status=200,
|
|
885
|
+
json={
|
|
886
|
+
"parent": str(elt.id),
|
|
887
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
888
|
+
"classifications": [
|
|
889
|
+
{
|
|
890
|
+
"id": cl_id,
|
|
891
|
+
"ml_class": cl_class,
|
|
892
|
+
"confidence": cl_conf,
|
|
893
|
+
"high_confidence": False,
|
|
894
|
+
"state": "pending",
|
|
895
|
+
},
|
|
896
|
+
],
|
|
788
897
|
},
|
|
789
|
-
|
|
790
|
-
},
|
|
791
|
-
)
|
|
898
|
+
)
|
|
792
899
|
|
|
793
900
|
mock_elements_worker_with_cache.create_classifications(
|
|
794
901
|
element=elt,
|
|
@@ -804,32 +911,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
|
|
|
804
911
|
"high_confidence": False,
|
|
805
912
|
},
|
|
806
913
|
],
|
|
914
|
+
batch_size=batch_size,
|
|
807
915
|
)
|
|
808
916
|
|
|
809
|
-
|
|
917
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
918
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
919
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
920
|
+
|
|
921
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
810
922
|
assert [
|
|
811
923
|
(call.request.method, call.request.url) for call in responses.calls
|
|
812
|
-
] == BASE_API_CALLS +
|
|
813
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
814
|
-
]
|
|
924
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
815
925
|
|
|
816
|
-
|
|
926
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
927
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
928
|
+
empty_payload = {
|
|
817
929
|
"parent": str(elt.id),
|
|
818
930
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
819
|
-
"classifications": [
|
|
820
|
-
{
|
|
821
|
-
"confidence": 0.75,
|
|
822
|
-
"high_confidence": False,
|
|
823
|
-
"ml_class": "0000",
|
|
824
|
-
},
|
|
825
|
-
{
|
|
826
|
-
"confidence": 0.25,
|
|
827
|
-
"high_confidence": False,
|
|
828
|
-
"ml_class": "1111",
|
|
829
|
-
},
|
|
830
|
-
],
|
|
931
|
+
"classifications": [],
|
|
831
932
|
}
|
|
832
933
|
|
|
934
|
+
bodies = []
|
|
935
|
+
first_call_idx = None
|
|
936
|
+
if batch_size > 1:
|
|
937
|
+
first_call_idx = -1
|
|
938
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
939
|
+
else:
|
|
940
|
+
first_call_idx = -2
|
|
941
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
942
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
943
|
+
|
|
944
|
+
assert [
|
|
945
|
+
json.loads(bulk_call.request.body)
|
|
946
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
947
|
+
] == bodies
|
|
948
|
+
|
|
833
949
|
# Check that created classifications were properly stored in SQLite cache
|
|
834
950
|
assert list(CachedClassification.select()) == [
|
|
835
951
|
CachedClassification(
|
|
@@ -2,7 +2,6 @@ import json
|
|
|
2
2
|
import sys
|
|
3
3
|
import tempfile
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from uuid import UUID
|
|
6
5
|
|
|
7
6
|
import pytest
|
|
8
7
|
|
|
@@ -58,13 +57,6 @@ def test_cli_arg_elements_list_given(mocker):
|
|
|
58
57
|
path.unlink()
|
|
59
58
|
|
|
60
59
|
|
|
61
|
-
def test_cli_arg_element_one_given_not_uuid(mocker):
|
|
62
|
-
mocker.patch.object(sys, "argv", ["worker", "--element", "1234"])
|
|
63
|
-
worker = ElementsWorker()
|
|
64
|
-
with pytest.raises(SystemExit):
|
|
65
|
-
worker.configure()
|
|
66
|
-
|
|
67
|
-
|
|
68
60
|
@pytest.mark.usefixtures("_mock_worker_run_api")
|
|
69
61
|
def test_cli_arg_element_one_given(mocker):
|
|
70
62
|
mocker.patch.object(
|
|
@@ -73,7 +65,7 @@ def test_cli_arg_element_one_given(mocker):
|
|
|
73
65
|
worker = ElementsWorker()
|
|
74
66
|
worker.configure()
|
|
75
67
|
|
|
76
|
-
assert worker.args.element == [
|
|
68
|
+
assert worker.args.element == ["12341234-1234-1234-1234-123412341234"]
|
|
77
69
|
# elements_list is None because TASK_ELEMENTS environment variable isn't set
|
|
78
70
|
assert not worker.args.elements_list
|
|
79
71
|
|
|
@@ -94,8 +86,8 @@ def test_cli_arg_element_many_given(mocker):
|
|
|
94
86
|
worker.configure()
|
|
95
87
|
|
|
96
88
|
assert worker.args.element == [
|
|
97
|
-
|
|
98
|
-
|
|
89
|
+
"12341234-1234-1234-1234-123412341234",
|
|
90
|
+
"43214321-4321-4321-4321-432143214321",
|
|
99
91
|
]
|
|
100
92
|
# elements_list is None because TASK_ELEMENTS environment variable isn't set
|
|
101
93
|
assert not worker.args.elements_list
|