arkindex-base-worker 0.4.0a2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/METADATA +7 -7
- arkindex_base_worker-0.4.0b2.dist-info/RECORD +51 -0
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/WHEEL +1 -1
- arkindex_worker/image.py +2 -1
- arkindex_worker/utils.py +76 -0
- arkindex_worker/worker/__init__.py +24 -14
- arkindex_worker/worker/base.py +3 -9
- arkindex_worker/worker/classification.py +33 -17
- arkindex_worker/worker/corpus.py +3 -1
- arkindex_worker/worker/dataset.py +1 -1
- arkindex_worker/worker/element.py +45 -16
- arkindex_worker/worker/entity.py +30 -17
- arkindex_worker/worker/metadata.py +19 -9
- arkindex_worker/worker/task.py +4 -2
- arkindex_worker/worker/training.py +5 -5
- arkindex_worker/worker/transcription.py +39 -18
- arkindex_worker/worker/version.py +3 -1
- tests/test_base_worker.py +1 -1
- tests/test_elements_worker/test_classifications.py +107 -60
- tests/test_elements_worker/test_elements.py +213 -70
- tests/test_elements_worker/test_entities.py +102 -33
- tests/test_elements_worker/test_metadata.py +223 -98
- tests/test_elements_worker/test_transcriptions.py +293 -143
- tests/test_merge.py +1 -1
- tests/test_utils.py +28 -0
- arkindex_base_worker-0.4.0a2.dist-info/RECORD +0 -51
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/LICENSE +0 -0
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/top_level.txt +0 -0
|
@@ -15,6 +15,7 @@ from arkindex_worker.cache import (
|
|
|
15
15
|
init_cache_db,
|
|
16
16
|
)
|
|
17
17
|
from arkindex_worker.models import Element
|
|
18
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
18
19
|
from arkindex_worker.worker import ElementsWorker
|
|
19
20
|
from arkindex_worker.worker.element import MissingTypeError
|
|
20
21
|
from tests import CORPUS_ID
|
|
@@ -22,6 +23,24 @@ from tests import CORPUS_ID
|
|
|
22
23
|
from . import BASE_API_CALLS
|
|
23
24
|
|
|
24
25
|
|
|
26
|
+
def test_list_corpus_types(responses, mock_elements_worker):
|
|
27
|
+
responses.add(
|
|
28
|
+
responses.GET,
|
|
29
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
|
|
30
|
+
json={
|
|
31
|
+
"id": CORPUS_ID,
|
|
32
|
+
"types": [{"slug": "folder"}, {"slug": "page"}],
|
|
33
|
+
},
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
mock_elements_worker.list_corpus_types()
|
|
37
|
+
|
|
38
|
+
assert mock_elements_worker.corpus_types == {
|
|
39
|
+
"folder": {"slug": "folder"},
|
|
40
|
+
"page": {"slug": "page"},
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
|
|
25
44
|
def test_check_required_types_argument_types(mock_elements_worker):
|
|
26
45
|
with pytest.raises(
|
|
27
46
|
AssertionError, match="At least one element type slug is required."
|
|
@@ -32,17 +51,11 @@ def test_check_required_types_argument_types(mock_elements_worker):
|
|
|
32
51
|
mock_elements_worker.check_required_types("lol", 42)
|
|
33
52
|
|
|
34
53
|
|
|
35
|
-
def test_check_required_types(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
"id": CORPUS_ID,
|
|
41
|
-
"name": "Some Corpus",
|
|
42
|
-
"types": [{"slug": "folder"}, {"slug": "page"}],
|
|
43
|
-
},
|
|
44
|
-
)
|
|
45
|
-
mock_elements_worker.setup_api_client()
|
|
54
|
+
def test_check_required_types(mock_elements_worker):
|
|
55
|
+
mock_elements_worker.corpus_types = {
|
|
56
|
+
"folder": {"slug": "folder"},
|
|
57
|
+
"page": {"slug": "page"},
|
|
58
|
+
}
|
|
46
59
|
|
|
47
60
|
assert mock_elements_worker.check_required_types("page")
|
|
48
61
|
assert mock_elements_worker.check_required_types("page", "folder")
|
|
@@ -50,22 +63,18 @@ def test_check_required_types(responses, mock_elements_worker):
|
|
|
50
63
|
with pytest.raises(
|
|
51
64
|
MissingTypeError,
|
|
52
65
|
match=re.escape(
|
|
53
|
-
"Element
|
|
66
|
+
"Element types act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
|
|
54
67
|
),
|
|
55
68
|
):
|
|
56
69
|
assert mock_elements_worker.check_required_types("page", "text_line", "act")
|
|
57
70
|
|
|
58
71
|
|
|
59
72
|
def test_create_missing_types(responses, mock_elements_worker):
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
"name": "Some Corpus",
|
|
66
|
-
"types": [{"slug": "folder"}, {"slug": "page"}],
|
|
67
|
-
},
|
|
68
|
-
)
|
|
73
|
+
mock_elements_worker.corpus_types = {
|
|
74
|
+
"folder": {"slug": "folder"},
|
|
75
|
+
"page": {"slug": "page"},
|
|
76
|
+
}
|
|
77
|
+
|
|
69
78
|
responses.add(
|
|
70
79
|
responses.POST,
|
|
71
80
|
"http://testserver/api/v1/elements/type/",
|
|
@@ -94,7 +103,6 @@ def test_create_missing_types(responses, mock_elements_worker):
|
|
|
94
103
|
)
|
|
95
104
|
],
|
|
96
105
|
)
|
|
97
|
-
mock_elements_worker.setup_api_client()
|
|
98
106
|
|
|
99
107
|
assert mock_elements_worker.check_required_types(
|
|
100
108
|
"page", "text_line", "act", create_missing=True
|
|
@@ -1003,7 +1011,10 @@ def test_create_elements_api_error(responses, mock_elements_worker):
|
|
|
1003
1011
|
]
|
|
1004
1012
|
|
|
1005
1013
|
|
|
1006
|
-
|
|
1014
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
1015
|
+
def test_create_elements_cached_element(
|
|
1016
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
1017
|
+
):
|
|
1007
1018
|
image = CachedImage.create(
|
|
1008
1019
|
id=UUID("c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe"),
|
|
1009
1020
|
width=42,
|
|
@@ -1016,12 +1027,28 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
|
|
|
1016
1027
|
image_id=image.id,
|
|
1017
1028
|
polygon="[[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]",
|
|
1018
1029
|
)
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1030
|
+
|
|
1031
|
+
if batch_size > 1:
|
|
1032
|
+
responses.add(
|
|
1033
|
+
responses.POST,
|
|
1034
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1035
|
+
status=200,
|
|
1036
|
+
json=[
|
|
1037
|
+
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
|
|
1038
|
+
{"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
|
|
1039
|
+
],
|
|
1040
|
+
)
|
|
1041
|
+
else:
|
|
1042
|
+
for elt_id in [
|
|
1043
|
+
"497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
|
1044
|
+
"5468c358-b9c4-499d-8b92-d6349c58e88d",
|
|
1045
|
+
]:
|
|
1046
|
+
responses.add(
|
|
1047
|
+
responses.POST,
|
|
1048
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1049
|
+
status=200,
|
|
1050
|
+
json=[{"id": elt_id}],
|
|
1051
|
+
)
|
|
1025
1052
|
|
|
1026
1053
|
created_ids = mock_elements_worker_with_cache.create_elements(
|
|
1027
1054
|
parent=elt,
|
|
@@ -1030,30 +1057,69 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
|
|
|
1030
1057
|
"name": "0",
|
|
1031
1058
|
"type": "something",
|
|
1032
1059
|
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1033
|
-
}
|
|
1060
|
+
},
|
|
1061
|
+
{
|
|
1062
|
+
"name": "1",
|
|
1063
|
+
"type": "something",
|
|
1064
|
+
"polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1065
|
+
},
|
|
1034
1066
|
],
|
|
1067
|
+
batch_size=batch_size,
|
|
1035
1068
|
)
|
|
1036
1069
|
|
|
1037
|
-
|
|
1038
|
-
assert [
|
|
1039
|
-
(call.request.method, call.request.url) for call in responses.calls
|
|
1040
|
-
] == BASE_API_CALLS + [
|
|
1070
|
+
bulk_api_calls = [
|
|
1041
1071
|
(
|
|
1042
1072
|
"POST",
|
|
1043
1073
|
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1044
|
-
)
|
|
1074
|
+
)
|
|
1045
1075
|
]
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
"
|
|
1050
|
-
"
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1076
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
1077
|
+
bulk_api_calls.append(
|
|
1078
|
+
(
|
|
1079
|
+
"POST",
|
|
1080
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1081
|
+
)
|
|
1082
|
+
)
|
|
1083
|
+
|
|
1084
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
1085
|
+
assert [
|
|
1086
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
1087
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
1088
|
+
|
|
1089
|
+
first_elt = {
|
|
1090
|
+
"name": "0",
|
|
1091
|
+
"type": "something",
|
|
1092
|
+
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1093
|
+
}
|
|
1094
|
+
second_elt = {
|
|
1095
|
+
"name": "1",
|
|
1096
|
+
"type": "something",
|
|
1097
|
+
"polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1098
|
+
}
|
|
1099
|
+
empty_payload = {
|
|
1100
|
+
"elements": [],
|
|
1054
1101
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1055
1102
|
}
|
|
1056
|
-
|
|
1103
|
+
|
|
1104
|
+
bodies = []
|
|
1105
|
+
first_call_idx = None
|
|
1106
|
+
if batch_size > 1:
|
|
1107
|
+
first_call_idx = -1
|
|
1108
|
+
bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
|
|
1109
|
+
else:
|
|
1110
|
+
first_call_idx = -2
|
|
1111
|
+
bodies.append({**empty_payload, "elements": [first_elt]})
|
|
1112
|
+
bodies.append({**empty_payload, "elements": [second_elt]})
|
|
1113
|
+
|
|
1114
|
+
assert [
|
|
1115
|
+
json.loads(bulk_call.request.body)
|
|
1116
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
1117
|
+
] == bodies
|
|
1118
|
+
|
|
1119
|
+
assert created_ids == [
|
|
1120
|
+
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
|
|
1121
|
+
{"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
|
|
1122
|
+
]
|
|
1057
1123
|
|
|
1058
1124
|
# Check that created elements were properly stored in SQLite cache
|
|
1059
1125
|
assert list(CachedElement.select().order_by(CachedElement.id)) == [
|
|
@@ -1065,11 +1131,24 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
|
|
|
1065
1131
|
image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
|
|
1066
1132
|
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1067
1133
|
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
1134
|
+
confidence=None,
|
|
1135
|
+
),
|
|
1136
|
+
CachedElement(
|
|
1137
|
+
id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
|
|
1138
|
+
parent_id=elt.id,
|
|
1139
|
+
type="something",
|
|
1140
|
+
image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
|
|
1141
|
+
polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1142
|
+
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
1143
|
+
confidence=None,
|
|
1068
1144
|
),
|
|
1069
1145
|
]
|
|
1070
1146
|
|
|
1071
1147
|
|
|
1072
|
-
|
|
1148
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
1149
|
+
def test_create_elements(
|
|
1150
|
+
batch_size, responses, mock_elements_worker_with_cache, tmp_path
|
|
1151
|
+
):
|
|
1073
1152
|
elt = Element(
|
|
1074
1153
|
{
|
|
1075
1154
|
"id": "12341234-1234-1234-1234-123412341234",
|
|
@@ -1083,12 +1162,28 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
|
|
|
1083
1162
|
},
|
|
1084
1163
|
}
|
|
1085
1164
|
)
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1165
|
+
|
|
1166
|
+
if batch_size > 1:
|
|
1167
|
+
responses.add(
|
|
1168
|
+
responses.POST,
|
|
1169
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1170
|
+
status=200,
|
|
1171
|
+
json=[
|
|
1172
|
+
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
|
|
1173
|
+
{"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
|
|
1174
|
+
],
|
|
1175
|
+
)
|
|
1176
|
+
else:
|
|
1177
|
+
for elt_id in [
|
|
1178
|
+
"497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
|
1179
|
+
"5468c358-b9c4-499d-8b92-d6349c58e88d",
|
|
1180
|
+
]:
|
|
1181
|
+
responses.add(
|
|
1182
|
+
responses.POST,
|
|
1183
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1184
|
+
status=200,
|
|
1185
|
+
json=[{"id": elt_id}],
|
|
1186
|
+
)
|
|
1092
1187
|
|
|
1093
1188
|
created_ids = mock_elements_worker_with_cache.create_elements(
|
|
1094
1189
|
parent=elt,
|
|
@@ -1097,30 +1192,69 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
|
|
|
1097
1192
|
"name": "0",
|
|
1098
1193
|
"type": "something",
|
|
1099
1194
|
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1100
|
-
}
|
|
1195
|
+
},
|
|
1196
|
+
{
|
|
1197
|
+
"name": "1",
|
|
1198
|
+
"type": "something",
|
|
1199
|
+
"polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1200
|
+
},
|
|
1101
1201
|
],
|
|
1202
|
+
batch_size=batch_size,
|
|
1102
1203
|
)
|
|
1103
1204
|
|
|
1104
|
-
|
|
1105
|
-
assert [
|
|
1106
|
-
(call.request.method, call.request.url) for call in responses.calls
|
|
1107
|
-
] == BASE_API_CALLS + [
|
|
1205
|
+
bulk_api_calls = [
|
|
1108
1206
|
(
|
|
1109
1207
|
"POST",
|
|
1110
1208
|
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1111
|
-
)
|
|
1209
|
+
)
|
|
1112
1210
|
]
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
"
|
|
1117
|
-
"
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1211
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
1212
|
+
bulk_api_calls.append(
|
|
1213
|
+
(
|
|
1214
|
+
"POST",
|
|
1215
|
+
"http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
|
|
1216
|
+
)
|
|
1217
|
+
)
|
|
1218
|
+
|
|
1219
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
1220
|
+
assert [
|
|
1221
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
1222
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
1223
|
+
|
|
1224
|
+
first_elt = {
|
|
1225
|
+
"name": "0",
|
|
1226
|
+
"type": "something",
|
|
1227
|
+
"polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1228
|
+
}
|
|
1229
|
+
second_elt = {
|
|
1230
|
+
"name": "1",
|
|
1231
|
+
"type": "something",
|
|
1232
|
+
"polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1233
|
+
}
|
|
1234
|
+
empty_payload = {
|
|
1235
|
+
"elements": [],
|
|
1121
1236
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1122
1237
|
}
|
|
1123
|
-
|
|
1238
|
+
|
|
1239
|
+
bodies = []
|
|
1240
|
+
first_call_idx = None
|
|
1241
|
+
if batch_size > 1:
|
|
1242
|
+
first_call_idx = -1
|
|
1243
|
+
bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
|
|
1244
|
+
else:
|
|
1245
|
+
first_call_idx = -2
|
|
1246
|
+
bodies.append({**empty_payload, "elements": [first_elt]})
|
|
1247
|
+
bodies.append({**empty_payload, "elements": [second_elt]})
|
|
1248
|
+
|
|
1249
|
+
assert [
|
|
1250
|
+
json.loads(bulk_call.request.body)
|
|
1251
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
1252
|
+
] == bodies
|
|
1253
|
+
|
|
1254
|
+
assert created_ids == [
|
|
1255
|
+
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
|
|
1256
|
+
{"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
|
|
1257
|
+
]
|
|
1124
1258
|
|
|
1125
1259
|
# Check that created elements were properly stored in SQLite cache
|
|
1126
1260
|
assert (tmp_path / "db.sqlite").is_file()
|
|
@@ -1134,7 +1268,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
|
|
|
1134
1268
|
polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
|
|
1135
1269
|
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
1136
1270
|
confidence=None,
|
|
1137
|
-
)
|
|
1271
|
+
),
|
|
1272
|
+
CachedElement(
|
|
1273
|
+
id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
|
|
1274
|
+
parent_id=UUID("12341234-1234-1234-1234-123412341234"),
|
|
1275
|
+
type="something",
|
|
1276
|
+
image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
|
|
1277
|
+
polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
|
|
1278
|
+
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
1279
|
+
confidence=None,
|
|
1280
|
+
),
|
|
1138
1281
|
]
|
|
1139
1282
|
|
|
1140
1283
|
|
|
@@ -1261,9 +1404,9 @@ def test_create_elements_integrity_error(
|
|
|
1261
1404
|
{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
|
|
1262
1405
|
]
|
|
1263
1406
|
|
|
1264
|
-
assert len(caplog.records) ==
|
|
1265
|
-
assert caplog.records[
|
|
1266
|
-
assert caplog.records[
|
|
1407
|
+
assert len(caplog.records) == 3
|
|
1408
|
+
assert caplog.records[-1].levelname == "WARNING"
|
|
1409
|
+
assert caplog.records[-1].message.startswith(
|
|
1267
1410
|
"Couldn't save created elements in local cache:"
|
|
1268
1411
|
)
|
|
1269
1412
|
|
|
@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
|
|
|
13
13
|
CachedTranscriptionEntity,
|
|
14
14
|
)
|
|
15
15
|
from arkindex_worker.models import Transcription
|
|
16
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
16
17
|
from arkindex_worker.worker.entity import MissingEntityType
|
|
17
18
|
from arkindex_worker.worker.transcription import TextOrientation
|
|
18
19
|
from tests import CORPUS_ID
|
|
@@ -988,38 +989,89 @@ def test_create_transcription_entities_wrong_entity(
|
|
|
988
989
|
)
|
|
989
990
|
|
|
990
991
|
|
|
991
|
-
|
|
992
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
993
|
+
def test_create_transcription_entities(batch_size, responses, mock_elements_worker):
|
|
992
994
|
transcription = Transcription(id="transcription-id")
|
|
995
|
+
|
|
993
996
|
# Call to Transcription entities creation in bulk
|
|
994
|
-
|
|
995
|
-
responses.
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
997
|
+
if batch_size > 1:
|
|
998
|
+
responses.add(
|
|
999
|
+
responses.POST,
|
|
1000
|
+
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
|
|
1001
|
+
status=201,
|
|
1002
|
+
match=[
|
|
1003
|
+
matchers.json_params_matcher(
|
|
1004
|
+
{
|
|
1005
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1006
|
+
"entities": [
|
|
1007
|
+
{
|
|
1008
|
+
"name": "Teklia",
|
|
1009
|
+
"type_id": "22222222-2222-2222-2222-222222222222",
|
|
1010
|
+
"offset": 0,
|
|
1011
|
+
"length": 6,
|
|
1012
|
+
"confidence": 1.0,
|
|
1013
|
+
},
|
|
1014
|
+
{
|
|
1015
|
+
"name": "Team Rocket",
|
|
1016
|
+
"type_id": "22222222-2222-2222-2222-222222222222",
|
|
1017
|
+
"offset": 7,
|
|
1018
|
+
"length": 11,
|
|
1019
|
+
"confidence": 1.0,
|
|
1020
|
+
},
|
|
1021
|
+
],
|
|
1022
|
+
}
|
|
1023
|
+
)
|
|
1024
|
+
],
|
|
1025
|
+
json={
|
|
1026
|
+
"entities": [
|
|
1027
|
+
{
|
|
1028
|
+
"transcription_entity_id": "transc-entity-id",
|
|
1029
|
+
"entity_id": "entity-id1",
|
|
1030
|
+
},
|
|
1031
|
+
{
|
|
1032
|
+
"transcription_entity_id": "transc-entity-id",
|
|
1033
|
+
"entity_id": "entity-id2",
|
|
1034
|
+
},
|
|
1035
|
+
]
|
|
1036
|
+
},
|
|
1037
|
+
)
|
|
1038
|
+
else:
|
|
1039
|
+
for idx, (name, offset, length) in enumerate(
|
|
1040
|
+
[
|
|
1041
|
+
("Teklia", 0, 6),
|
|
1042
|
+
("Team Rocket", 7, 11),
|
|
1043
|
+
],
|
|
1044
|
+
start=1,
|
|
1045
|
+
):
|
|
1046
|
+
responses.add(
|
|
1047
|
+
responses.POST,
|
|
1048
|
+
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
|
|
1049
|
+
status=201,
|
|
1050
|
+
match=[
|
|
1051
|
+
matchers.json_params_matcher(
|
|
1052
|
+
{
|
|
1053
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1054
|
+
"entities": [
|
|
1055
|
+
{
|
|
1056
|
+
"name": name,
|
|
1057
|
+
"type_id": "22222222-2222-2222-2222-222222222222",
|
|
1058
|
+
"offset": offset,
|
|
1059
|
+
"length": length,
|
|
1060
|
+
"confidence": 1.0,
|
|
1061
|
+
}
|
|
1062
|
+
],
|
|
1063
|
+
}
|
|
1064
|
+
)
|
|
1065
|
+
],
|
|
1066
|
+
json={
|
|
1002
1067
|
"entities": [
|
|
1003
1068
|
{
|
|
1004
|
-
"
|
|
1005
|
-
"
|
|
1006
|
-
"offset": 0,
|
|
1007
|
-
"length": 6,
|
|
1008
|
-
"confidence": 1.0,
|
|
1069
|
+
"transcription_entity_id": "transc-entity-id",
|
|
1070
|
+
"entity_id": f"entity-id{idx}",
|
|
1009
1071
|
}
|
|
1010
|
-
]
|
|
1011
|
-
}
|
|
1072
|
+
]
|
|
1073
|
+
},
|
|
1012
1074
|
)
|
|
1013
|
-
],
|
|
1014
|
-
json={
|
|
1015
|
-
"entities": [
|
|
1016
|
-
{
|
|
1017
|
-
"transcription_entity_id": "transc-entity-id",
|
|
1018
|
-
"entity_id": "entity-id",
|
|
1019
|
-
}
|
|
1020
|
-
]
|
|
1021
|
-
},
|
|
1022
|
-
)
|
|
1023
1075
|
|
|
1024
1076
|
# Store entity type/slug correspondence on the worker
|
|
1025
1077
|
mock_elements_worker.entity_types = {
|
|
@@ -1034,18 +1086,35 @@ def test_create_transcription_entities(responses, mock_elements_worker):
|
|
|
1034
1086
|
"offset": 0,
|
|
1035
1087
|
"length": 6,
|
|
1036
1088
|
"confidence": 1.0,
|
|
1037
|
-
}
|
|
1089
|
+
},
|
|
1090
|
+
{
|
|
1091
|
+
"name": "Team Rocket",
|
|
1092
|
+
"type_id": "22222222-2222-2222-2222-222222222222",
|
|
1093
|
+
"offset": 7,
|
|
1094
|
+
"length": 11,
|
|
1095
|
+
"confidence": 1.0,
|
|
1096
|
+
},
|
|
1038
1097
|
],
|
|
1098
|
+
batch_size=batch_size,
|
|
1039
1099
|
)
|
|
1040
1100
|
|
|
1041
|
-
assert len(created_objects) ==
|
|
1101
|
+
assert len(created_objects) == 2
|
|
1042
1102
|
|
|
1043
|
-
|
|
1044
|
-
assert [
|
|
1045
|
-
(call.request.method, call.request.url) for call in responses.calls
|
|
1046
|
-
] == BASE_API_CALLS + [
|
|
1103
|
+
bulk_api_calls = [
|
|
1047
1104
|
(
|
|
1048
1105
|
"POST",
|
|
1049
1106
|
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
|
|
1050
|
-
)
|
|
1107
|
+
)
|
|
1051
1108
|
]
|
|
1109
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
1110
|
+
bulk_api_calls.append(
|
|
1111
|
+
(
|
|
1112
|
+
"POST",
|
|
1113
|
+
"http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
|
|
1114
|
+
)
|
|
1115
|
+
)
|
|
1116
|
+
|
|
1117
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
1118
|
+
assert [
|
|
1119
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
1120
|
+
] == BASE_API_CALLS + bulk_api_calls
|