arkindex-base-worker 0.3.7rc9__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/METADATA +16 -20
- arkindex_base_worker-0.4.0.dist-info/RECORD +61 -0
- {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/WHEEL +1 -1
- arkindex_worker/cache.py +1 -1
- arkindex_worker/image.py +120 -1
- arkindex_worker/models.py +6 -0
- arkindex_worker/utils.py +85 -4
- arkindex_worker/worker/__init__.py +68 -162
- arkindex_worker/worker/base.py +39 -34
- arkindex_worker/worker/classification.py +34 -18
- arkindex_worker/worker/corpus.py +86 -0
- arkindex_worker/worker/dataset.py +71 -1
- arkindex_worker/worker/element.py +352 -91
- arkindex_worker/worker/entity.py +11 -11
- arkindex_worker/worker/image.py +21 -0
- arkindex_worker/worker/metadata.py +19 -9
- arkindex_worker/worker/process.py +92 -0
- arkindex_worker/worker/task.py +5 -4
- arkindex_worker/worker/training.py +25 -10
- arkindex_worker/worker/transcription.py +89 -68
- arkindex_worker/worker/version.py +3 -1
- tests/__init__.py +8 -0
- tests/conftest.py +36 -52
- tests/test_base_worker.py +212 -12
- tests/test_dataset_worker.py +21 -45
- tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
- tests/test_elements_worker/test_cli.py +3 -11
- tests/test_elements_worker/test_corpus.py +168 -0
- tests/test_elements_worker/test_dataset.py +7 -12
- tests/test_elements_worker/test_element.py +427 -0
- tests/test_elements_worker/test_element_create_multiple.py +715 -0
- tests/test_elements_worker/test_element_create_single.py +528 -0
- tests/test_elements_worker/test_element_list_children.py +969 -0
- tests/test_elements_worker/test_element_list_parents.py +530 -0
- tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
- tests/test_elements_worker/test_entity_list_and_check.py +160 -0
- tests/test_elements_worker/test_image.py +66 -0
- tests/test_elements_worker/test_metadata.py +230 -139
- tests/test_elements_worker/test_process.py +89 -0
- tests/test_elements_worker/test_task.py +8 -18
- tests/test_elements_worker/test_training.py +17 -8
- tests/test_elements_worker/test_transcription_create.py +873 -0
- tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
- tests/test_elements_worker/test_transcription_list.py +450 -0
- tests/test_elements_worker/test_version.py +60 -0
- tests/test_elements_worker/test_worker.py +563 -279
- tests/test_image.py +432 -209
- tests/test_merge.py +1 -2
- tests/test_utils.py +66 -3
- arkindex_base_worker-0.3.7rc9.dist-info/RECORD +0 -47
- tests/test_elements_worker/test_elements.py +0 -2713
- tests/test_elements_worker/test_transcriptions.py +0 -2119
- {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/LICENSE +0 -0
- {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/top_level.txt +0 -0
tests/test_dataset_worker.py
CHANGED
|
@@ -3,18 +3,21 @@ import uuid
|
|
|
3
3
|
from argparse import ArgumentTypeError
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
-
from apistar.exceptions import ErrorResponse
|
|
7
6
|
|
|
7
|
+
from arkindex.exceptions import ErrorResponse
|
|
8
8
|
from arkindex_worker.models import Dataset, Set
|
|
9
|
-
from arkindex_worker.worker import
|
|
10
|
-
|
|
11
|
-
|
|
9
|
+
from arkindex_worker.worker.dataset import (
|
|
10
|
+
DatasetState,
|
|
11
|
+
MissingDatasetArchive,
|
|
12
|
+
check_dataset_set,
|
|
13
|
+
)
|
|
14
|
+
from tests import FIXTURES_DIR, PROCESS_ID
|
|
12
15
|
from tests.test_elements_worker import BASE_API_CALLS
|
|
13
16
|
|
|
14
17
|
RANDOM_UUID = uuid.uuid4()
|
|
15
18
|
|
|
16
19
|
|
|
17
|
-
@pytest.fixture
|
|
20
|
+
@pytest.fixture
|
|
18
21
|
def tmp_archive(tmp_path):
|
|
19
22
|
archive = tmp_path / "test_archive.tar.zst"
|
|
20
23
|
archive.touch()
|
|
@@ -63,22 +66,17 @@ def test_download_dataset_artifact_list_api_error(
|
|
|
63
66
|
responses.add(
|
|
64
67
|
responses.GET,
|
|
65
68
|
f"http://testserver/api/v1/task/{task_id}/artifacts/",
|
|
66
|
-
status=
|
|
69
|
+
status=418,
|
|
67
70
|
)
|
|
68
71
|
|
|
69
72
|
with pytest.raises(ErrorResponse):
|
|
70
73
|
mock_dataset_worker.download_dataset_artifact(default_dataset)
|
|
71
74
|
|
|
72
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
75
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
73
76
|
assert [
|
|
74
77
|
(call.request.method, call.request.url) for call in responses.calls
|
|
75
78
|
] == BASE_API_CALLS + [
|
|
76
|
-
|
|
77
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
78
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
79
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
80
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
81
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
79
|
+
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/")
|
|
82
80
|
]
|
|
83
81
|
|
|
84
82
|
|
|
@@ -116,22 +114,17 @@ def test_download_dataset_artifact_download_api_error(
|
|
|
116
114
|
responses.add(
|
|
117
115
|
responses.GET,
|
|
118
116
|
f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst",
|
|
119
|
-
status=
|
|
117
|
+
status=418,
|
|
120
118
|
)
|
|
121
119
|
|
|
122
120
|
with pytest.raises(ErrorResponse):
|
|
123
121
|
mock_dataset_worker.download_dataset_artifact(default_dataset)
|
|
124
122
|
|
|
125
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
123
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 2
|
|
126
124
|
assert [
|
|
127
125
|
(call.request.method, call.request.url) for call in responses.calls
|
|
128
126
|
] == BASE_API_CALLS + [
|
|
129
127
|
("GET", f"http://testserver/api/v1/task/{task_id}/artifacts/"),
|
|
130
|
-
# The API call is retried 5 times
|
|
131
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
|
|
132
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
|
|
133
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
|
|
134
|
-
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
|
|
135
128
|
("GET", f"http://testserver/api/v1/task/{task_id}/artifact/dataset_id.tar.zst"),
|
|
136
129
|
]
|
|
137
130
|
|
|
@@ -284,7 +277,7 @@ def test_list_sets_api_error(responses, mock_dataset_worker):
|
|
|
284
277
|
responses.add(
|
|
285
278
|
responses.GET,
|
|
286
279
|
f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
|
|
287
|
-
status=
|
|
280
|
+
status=418,
|
|
288
281
|
)
|
|
289
282
|
|
|
290
283
|
with pytest.raises(
|
|
@@ -393,20 +386,15 @@ def test_list_sets_retrieve_dataset_api_error(
|
|
|
393
386
|
responses.add(
|
|
394
387
|
responses.GET,
|
|
395
388
|
f"http://testserver/api/v1/datasets/{default_dataset.id}/",
|
|
396
|
-
status=
|
|
389
|
+
status=418,
|
|
397
390
|
)
|
|
398
391
|
|
|
399
392
|
with pytest.raises(ErrorResponse):
|
|
400
393
|
next(mock_dev_dataset_worker.list_sets())
|
|
401
394
|
|
|
402
|
-
assert len(responses.calls) ==
|
|
395
|
+
assert len(responses.calls) == 1
|
|
403
396
|
assert [(call.request.method, call.request.url) for call in responses.calls] == [
|
|
404
|
-
|
|
405
|
-
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
|
|
406
|
-
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
|
|
407
|
-
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
|
|
408
|
-
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
|
|
409
|
-
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
|
|
397
|
+
("GET", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
|
|
410
398
|
]
|
|
411
399
|
|
|
412
400
|
|
|
@@ -494,22 +482,17 @@ def test_run_download_dataset_artifact_api_error(
|
|
|
494
482
|
responses.add(
|
|
495
483
|
responses.GET,
|
|
496
484
|
f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/",
|
|
497
|
-
status=
|
|
485
|
+
status=418,
|
|
498
486
|
)
|
|
499
487
|
|
|
500
488
|
with pytest.raises(SystemExit):
|
|
501
489
|
mock_dataset_worker.run()
|
|
502
490
|
|
|
503
|
-
assert len(responses.calls) == len(BASE_API_CALLS) * 2 +
|
|
491
|
+
assert len(responses.calls) == len(BASE_API_CALLS) * 2 + 1
|
|
504
492
|
assert [
|
|
505
493
|
(call.request.method, call.request.url) for call in responses.calls
|
|
506
494
|
] == BASE_API_CALLS * 2 + [
|
|
507
|
-
|
|
508
|
-
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
|
|
509
|
-
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
|
|
510
|
-
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
|
|
511
|
-
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
|
|
512
|
-
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/"),
|
|
495
|
+
("GET", f"http://testserver/api/v1/task/{default_dataset.task_id}/artifacts/")
|
|
513
496
|
]
|
|
514
497
|
|
|
515
498
|
assert [(level, message) for _, level, message in caplog.record_tuples] == [
|
|
@@ -519,16 +502,9 @@ def test_run_download_dataset_artifact_api_error(
|
|
|
519
502
|
"Retrieving data for Set (train) from Dataset (dataset_id) (1/1)",
|
|
520
503
|
),
|
|
521
504
|
(logging.INFO, "Downloading artifact for Dataset (dataset_id)"),
|
|
522
|
-
*[
|
|
523
|
-
(
|
|
524
|
-
logging.INFO,
|
|
525
|
-
f"Retrying arkindex_worker.worker.base.BaseWorker.request in {retry} seconds as it raised ErrorResponse: .",
|
|
526
|
-
)
|
|
527
|
-
for retry in [3.0, 4.0, 8.0, 16.0]
|
|
528
|
-
],
|
|
529
505
|
(
|
|
530
506
|
logging.WARNING,
|
|
531
|
-
"An API error occurred while processing Set (train) from Dataset (dataset_id):
|
|
507
|
+
"An API error occurred while processing Set (train) from Dataset (dataset_id): 418 I'm a Teapot - None",
|
|
532
508
|
),
|
|
533
509
|
(
|
|
534
510
|
logging.ERROR,
|
|
@@ -3,10 +3,12 @@ import re
|
|
|
3
3
|
from uuid import UUID
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
-
from apistar.exceptions import ErrorResponse
|
|
7
6
|
|
|
7
|
+
from arkindex.exceptions import ErrorResponse
|
|
8
8
|
from arkindex_worker.cache import CachedClassification, CachedElement
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
11
|
+
from tests import CORPUS_ID
|
|
10
12
|
|
|
11
13
|
from . import BASE_API_CALLS
|
|
12
14
|
|
|
@@ -15,11 +17,96 @@ from . import BASE_API_CALLS
|
|
|
15
17
|
DELETE_PARAMETER = "DELETE_PARAMETER"
|
|
16
18
|
|
|
17
19
|
|
|
20
|
+
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
|
|
21
|
+
responses.add(
|
|
22
|
+
responses.GET,
|
|
23
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
24
|
+
status=418,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
assert not mock_elements_worker.classes
|
|
28
|
+
with pytest.raises(
|
|
29
|
+
Exception, match="Stopping pagination as data will be incomplete"
|
|
30
|
+
):
|
|
31
|
+
mock_elements_worker.load_corpus_classes()
|
|
32
|
+
|
|
33
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 5
|
|
34
|
+
assert [
|
|
35
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
36
|
+
] == BASE_API_CALLS + [
|
|
37
|
+
# We do 5 retries
|
|
38
|
+
(
|
|
39
|
+
"GET",
|
|
40
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
41
|
+
),
|
|
42
|
+
(
|
|
43
|
+
"GET",
|
|
44
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
45
|
+
),
|
|
46
|
+
(
|
|
47
|
+
"GET",
|
|
48
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
49
|
+
),
|
|
50
|
+
(
|
|
51
|
+
"GET",
|
|
52
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
53
|
+
),
|
|
54
|
+
(
|
|
55
|
+
"GET",
|
|
56
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
57
|
+
),
|
|
58
|
+
]
|
|
59
|
+
assert not mock_elements_worker.classes
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_load_corpus_classes(responses, mock_elements_worker):
|
|
63
|
+
responses.add(
|
|
64
|
+
responses.GET,
|
|
65
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
66
|
+
status=200,
|
|
67
|
+
json={
|
|
68
|
+
"count": 3,
|
|
69
|
+
"next": None,
|
|
70
|
+
"results": [
|
|
71
|
+
{
|
|
72
|
+
"id": "0000",
|
|
73
|
+
"name": "good",
|
|
74
|
+
},
|
|
75
|
+
{
|
|
76
|
+
"id": "1111",
|
|
77
|
+
"name": "average",
|
|
78
|
+
},
|
|
79
|
+
{
|
|
80
|
+
"id": "2222",
|
|
81
|
+
"name": "bad",
|
|
82
|
+
},
|
|
83
|
+
],
|
|
84
|
+
},
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
assert not mock_elements_worker.classes
|
|
88
|
+
mock_elements_worker.load_corpus_classes()
|
|
89
|
+
|
|
90
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
91
|
+
assert [
|
|
92
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
93
|
+
] == BASE_API_CALLS + [
|
|
94
|
+
(
|
|
95
|
+
"GET",
|
|
96
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
97
|
+
),
|
|
98
|
+
]
|
|
99
|
+
assert mock_elements_worker.classes == {
|
|
100
|
+
"good": "0000",
|
|
101
|
+
"average": "1111",
|
|
102
|
+
"bad": "2222",
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
|
|
18
106
|
def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
19
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
20
107
|
responses.add(
|
|
21
108
|
responses.GET,
|
|
22
|
-
f"http://testserver/api/v1/corpus/{
|
|
109
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
23
110
|
status=200,
|
|
24
111
|
json={
|
|
25
112
|
"count": 1,
|
|
@@ -42,7 +129,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
42
129
|
] == BASE_API_CALLS + [
|
|
43
130
|
(
|
|
44
131
|
"GET",
|
|
45
|
-
f"http://testserver/api/v1/corpus/{
|
|
132
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
46
133
|
),
|
|
47
134
|
]
|
|
48
135
|
assert mock_elements_worker.classes == {"good": "0000"}
|
|
@@ -51,12 +138,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
51
138
|
|
|
52
139
|
def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
|
|
53
140
|
# A missing class is now created automatically
|
|
54
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
55
141
|
mock_elements_worker.classes = {"good": "0000"}
|
|
56
142
|
|
|
57
143
|
responses.add(
|
|
58
144
|
responses.POST,
|
|
59
|
-
f"http://testserver/api/v1/corpus/{
|
|
145
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
60
146
|
status=201,
|
|
61
147
|
json={"id": "new-ml-class-1234"},
|
|
62
148
|
)
|
|
@@ -82,12 +168,10 @@ def test_get_ml_class_id(mock_elements_worker):
|
|
|
82
168
|
|
|
83
169
|
|
|
84
170
|
def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
85
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
86
|
-
|
|
87
171
|
# Add some initial classes
|
|
88
172
|
responses.add(
|
|
89
173
|
responses.GET,
|
|
90
|
-
f"http://testserver/api/v1/corpus/{
|
|
174
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
91
175
|
json={
|
|
92
176
|
"count": 1,
|
|
93
177
|
"next": None,
|
|
@@ -103,7 +187,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
103
187
|
# Invalid response when trying to create class2
|
|
104
188
|
responses.add(
|
|
105
189
|
responses.POST,
|
|
106
|
-
f"http://testserver/api/v1/corpus/{
|
|
190
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
107
191
|
status=400,
|
|
108
192
|
json={"non_field_errors": "Already exists"},
|
|
109
193
|
)
|
|
@@ -111,7 +195,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
111
195
|
# Add both classes (class2 is created by another process)
|
|
112
196
|
responses.add(
|
|
113
197
|
responses.GET,
|
|
114
|
-
f"http://testserver/api/v1/corpus/{
|
|
198
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
115
199
|
json={
|
|
116
200
|
"count": 2,
|
|
117
201
|
"next": None,
|
|
@@ -141,15 +225,15 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
141
225
|
] == BASE_API_CALLS + [
|
|
142
226
|
(
|
|
143
227
|
"GET",
|
|
144
|
-
f"http://testserver/api/v1/corpus/{
|
|
228
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
145
229
|
),
|
|
146
230
|
(
|
|
147
231
|
"POST",
|
|
148
|
-
f"http://testserver/api/v1/corpus/{
|
|
232
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
149
233
|
),
|
|
150
234
|
(
|
|
151
235
|
"GET",
|
|
152
|
-
f"http://testserver/api/v1/corpus/{
|
|
236
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
153
237
|
),
|
|
154
238
|
]
|
|
155
239
|
|
|
@@ -169,7 +253,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
169
253
|
"""
|
|
170
254
|
responses.add(
|
|
171
255
|
responses.GET,
|
|
172
|
-
f"http://testserver/api/v1/corpus/{
|
|
256
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
173
257
|
status=200,
|
|
174
258
|
json={
|
|
175
259
|
"count": 1,
|
|
@@ -189,7 +273,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
189
273
|
] == BASE_API_CALLS + [
|
|
190
274
|
(
|
|
191
275
|
"GET",
|
|
192
|
-
f"http://testserver/api/v1/corpus/{
|
|
276
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
193
277
|
),
|
|
194
278
|
]
|
|
195
279
|
|
|
@@ -276,7 +360,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
276
360
|
responses.add(
|
|
277
361
|
responses.POST,
|
|
278
362
|
"http://testserver/api/v1/classifications/",
|
|
279
|
-
status=
|
|
363
|
+
status=418,
|
|
280
364
|
)
|
|
281
365
|
|
|
282
366
|
with pytest.raises(ErrorResponse):
|
|
@@ -287,17 +371,10 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
287
371
|
high_confidence=True,
|
|
288
372
|
)
|
|
289
373
|
|
|
290
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
374
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
291
375
|
assert [
|
|
292
376
|
(call.request.method, call.request.url) for call in responses.calls
|
|
293
|
-
] == BASE_API_CALLS + [
|
|
294
|
-
# We retry 5 times the API call
|
|
295
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
296
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
297
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
298
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
299
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
300
|
-
]
|
|
377
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classifications/")]
|
|
301
378
|
|
|
302
379
|
|
|
303
380
|
def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
@@ -306,7 +383,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
306
383
|
# Automatically create a missing class!
|
|
307
384
|
responses.add(
|
|
308
385
|
responses.POST,
|
|
309
|
-
"http://testserver/api/v1/corpus/
|
|
386
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
310
387
|
status=201,
|
|
311
388
|
json={"id": "new-ml-class-1234"},
|
|
312
389
|
)
|
|
@@ -330,7 +407,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
330
407
|
for call in responses.calls[-2:]
|
|
331
408
|
] == [
|
|
332
409
|
(
|
|
333
|
-
"http://testserver/api/v1/corpus/
|
|
410
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
334
411
|
{"name": "a_class"},
|
|
335
412
|
),
|
|
336
413
|
(
|
|
@@ -609,7 +686,7 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
|
609
686
|
responses.add(
|
|
610
687
|
responses.POST,
|
|
611
688
|
"http://testserver/api/v1/classification/bulk/",
|
|
612
|
-
status=
|
|
689
|
+
status=418,
|
|
613
690
|
)
|
|
614
691
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
615
692
|
classes = [
|
|
@@ -630,17 +707,10 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
|
630
707
|
element=elt, classifications=classes
|
|
631
708
|
)
|
|
632
709
|
|
|
633
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
710
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
634
711
|
assert [
|
|
635
712
|
(call.request.method, call.request.url) for call in responses.calls
|
|
636
|
-
] == BASE_API_CALLS + [
|
|
637
|
-
# We retry 5 times the API call
|
|
638
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
639
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
640
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
641
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
642
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
643
|
-
]
|
|
713
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
644
714
|
|
|
645
715
|
|
|
646
716
|
def test_create_classifications_create_ml_class(mock_elements_worker, responses):
|
|
@@ -649,7 +719,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
649
719
|
# Automatically create a missing class!
|
|
650
720
|
responses.add(
|
|
651
721
|
responses.POST,
|
|
652
|
-
"http://testserver/api/v1/corpus/
|
|
722
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
653
723
|
status=201,
|
|
654
724
|
json={"id": "new-ml-class-1234"},
|
|
655
725
|
)
|
|
@@ -690,7 +760,7 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
690
760
|
] == BASE_API_CALLS + [
|
|
691
761
|
(
|
|
692
762
|
"POST",
|
|
693
|
-
"http://testserver/api/v1/corpus/
|
|
763
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
694
764
|
),
|
|
695
765
|
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
696
766
|
]
|
|
@@ -709,7 +779,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
709
779
|
}
|
|
710
780
|
|
|
711
781
|
|
|
712
|
-
|
|
782
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
783
|
+
def test_create_classifications(batch_size, responses, mock_elements_worker):
|
|
713
784
|
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
|
|
714
785
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
715
786
|
responses.add(
|
|
@@ -733,62 +804,98 @@ def test_create_classifications(responses, mock_elements_worker):
|
|
|
733
804
|
"high_confidence": False,
|
|
734
805
|
},
|
|
735
806
|
],
|
|
807
|
+
batch_size=batch_size,
|
|
736
808
|
)
|
|
737
809
|
|
|
738
|
-
|
|
810
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
811
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
812
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
813
|
+
|
|
814
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
739
815
|
assert [
|
|
740
816
|
(call.request.method, call.request.url) for call in responses.calls
|
|
741
|
-
] == BASE_API_CALLS +
|
|
742
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
743
|
-
]
|
|
817
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
744
818
|
|
|
745
|
-
|
|
819
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
820
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
821
|
+
empty_payload = {
|
|
746
822
|
"parent": str(elt.id),
|
|
747
823
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
748
|
-
"classifications": [
|
|
749
|
-
{
|
|
750
|
-
"confidence": 0.75,
|
|
751
|
-
"high_confidence": False,
|
|
752
|
-
"ml_class": "0000",
|
|
753
|
-
},
|
|
754
|
-
{
|
|
755
|
-
"confidence": 0.25,
|
|
756
|
-
"high_confidence": False,
|
|
757
|
-
"ml_class": "1111",
|
|
758
|
-
},
|
|
759
|
-
],
|
|
824
|
+
"classifications": [],
|
|
760
825
|
}
|
|
761
826
|
|
|
827
|
+
bodies = []
|
|
828
|
+
first_call_idx = None
|
|
829
|
+
if batch_size > 1:
|
|
830
|
+
first_call_idx = -1
|
|
831
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
832
|
+
else:
|
|
833
|
+
first_call_idx = -2
|
|
834
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
835
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
836
|
+
|
|
837
|
+
assert [
|
|
838
|
+
json.loads(bulk_call.request.body)
|
|
839
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
840
|
+
] == bodies
|
|
762
841
|
|
|
763
|
-
|
|
842
|
+
|
|
843
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
844
|
+
def test_create_classifications_with_cache(
|
|
845
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
846
|
+
):
|
|
764
847
|
mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
|
|
765
848
|
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
766
849
|
|
|
767
|
-
|
|
768
|
-
responses.
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
850
|
+
if batch_size > 1:
|
|
851
|
+
responses.add(
|
|
852
|
+
responses.POST,
|
|
853
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
854
|
+
status=200,
|
|
855
|
+
json={
|
|
856
|
+
"parent": str(elt.id),
|
|
857
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
858
|
+
"classifications": [
|
|
859
|
+
{
|
|
860
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
861
|
+
"ml_class": "0000",
|
|
862
|
+
"confidence": 0.75,
|
|
863
|
+
"high_confidence": False,
|
|
864
|
+
"state": "pending",
|
|
865
|
+
},
|
|
866
|
+
{
|
|
867
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
868
|
+
"ml_class": "1111",
|
|
869
|
+
"confidence": 0.25,
|
|
870
|
+
"high_confidence": False,
|
|
871
|
+
"state": "pending",
|
|
872
|
+
},
|
|
873
|
+
],
|
|
874
|
+
},
|
|
875
|
+
)
|
|
876
|
+
else:
|
|
877
|
+
for cl_id, cl_class, cl_conf in [
|
|
878
|
+
("00000000-0000-0000-0000-000000000000", "0000", 0.75),
|
|
879
|
+
("11111111-1111-1111-1111-111111111111", "1111", 0.25),
|
|
880
|
+
]:
|
|
881
|
+
responses.add(
|
|
882
|
+
responses.POST,
|
|
883
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
884
|
+
status=200,
|
|
885
|
+
json={
|
|
886
|
+
"parent": str(elt.id),
|
|
887
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
888
|
+
"classifications": [
|
|
889
|
+
{
|
|
890
|
+
"id": cl_id,
|
|
891
|
+
"ml_class": cl_class,
|
|
892
|
+
"confidence": cl_conf,
|
|
893
|
+
"high_confidence": False,
|
|
894
|
+
"state": "pending",
|
|
895
|
+
},
|
|
896
|
+
],
|
|
788
897
|
},
|
|
789
|
-
|
|
790
|
-
},
|
|
791
|
-
)
|
|
898
|
+
)
|
|
792
899
|
|
|
793
900
|
mock_elements_worker_with_cache.create_classifications(
|
|
794
901
|
element=elt,
|
|
@@ -804,32 +911,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
|
|
|
804
911
|
"high_confidence": False,
|
|
805
912
|
},
|
|
806
913
|
],
|
|
914
|
+
batch_size=batch_size,
|
|
807
915
|
)
|
|
808
916
|
|
|
809
|
-
|
|
917
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
918
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
919
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
920
|
+
|
|
921
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
810
922
|
assert [
|
|
811
923
|
(call.request.method, call.request.url) for call in responses.calls
|
|
812
|
-
] == BASE_API_CALLS +
|
|
813
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
814
|
-
]
|
|
924
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
815
925
|
|
|
816
|
-
|
|
926
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
927
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
928
|
+
empty_payload = {
|
|
817
929
|
"parent": str(elt.id),
|
|
818
930
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
819
|
-
"classifications": [
|
|
820
|
-
{
|
|
821
|
-
"confidence": 0.75,
|
|
822
|
-
"high_confidence": False,
|
|
823
|
-
"ml_class": "0000",
|
|
824
|
-
},
|
|
825
|
-
{
|
|
826
|
-
"confidence": 0.25,
|
|
827
|
-
"high_confidence": False,
|
|
828
|
-
"ml_class": "1111",
|
|
829
|
-
},
|
|
830
|
-
],
|
|
931
|
+
"classifications": [],
|
|
831
932
|
}
|
|
832
933
|
|
|
934
|
+
bodies = []
|
|
935
|
+
first_call_idx = None
|
|
936
|
+
if batch_size > 1:
|
|
937
|
+
first_call_idx = -1
|
|
938
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
939
|
+
else:
|
|
940
|
+
first_call_idx = -2
|
|
941
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
942
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
943
|
+
|
|
944
|
+
assert [
|
|
945
|
+
json.loads(bulk_call.request.body)
|
|
946
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
947
|
+
] == bodies
|
|
948
|
+
|
|
833
949
|
# Check that created classifications were properly stored in SQLite cache
|
|
834
950
|
assert list(CachedClassification.select()) == [
|
|
835
951
|
CachedClassification(
|