arkindex-base-worker 0.4.0a2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@ from arkindex_worker.cache import (
15
15
  init_cache_db,
16
16
  )
17
17
  from arkindex_worker.models import Element
18
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
18
19
  from arkindex_worker.worker import ElementsWorker
19
20
  from arkindex_worker.worker.element import MissingTypeError
20
21
  from tests import CORPUS_ID
@@ -22,6 +23,24 @@ from tests import CORPUS_ID
22
23
  from . import BASE_API_CALLS
23
24
 
24
25
 
26
+ def test_list_corpus_types(responses, mock_elements_worker):
27
+ responses.add(
28
+ responses.GET,
29
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
30
+ json={
31
+ "id": CORPUS_ID,
32
+ "types": [{"slug": "folder"}, {"slug": "page"}],
33
+ },
34
+ )
35
+
36
+ mock_elements_worker.list_corpus_types()
37
+
38
+ assert mock_elements_worker.corpus_types == {
39
+ "folder": {"slug": "folder"},
40
+ "page": {"slug": "page"},
41
+ }
42
+
43
+
25
44
  def test_check_required_types_argument_types(mock_elements_worker):
26
45
  with pytest.raises(
27
46
  AssertionError, match="At least one element type slug is required."
@@ -32,17 +51,11 @@ def test_check_required_types_argument_types(mock_elements_worker):
32
51
  mock_elements_worker.check_required_types("lol", 42)
33
52
 
34
53
 
35
- def test_check_required_types(responses, mock_elements_worker):
36
- responses.add(
37
- responses.GET,
38
- f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
39
- json={
40
- "id": CORPUS_ID,
41
- "name": "Some Corpus",
42
- "types": [{"slug": "folder"}, {"slug": "page"}],
43
- },
44
- )
45
- mock_elements_worker.setup_api_client()
54
+ def test_check_required_types(mock_elements_worker):
55
+ mock_elements_worker.corpus_types = {
56
+ "folder": {"slug": "folder"},
57
+ "page": {"slug": "page"},
58
+ }
46
59
 
47
60
  assert mock_elements_worker.check_required_types("page")
48
61
  assert mock_elements_worker.check_required_types("page", "folder")
@@ -50,22 +63,18 @@ def test_check_required_types(responses, mock_elements_worker):
50
63
  with pytest.raises(
51
64
  MissingTypeError,
52
65
  match=re.escape(
53
- "Element type(s) act, text_line were not found in the Some Corpus corpus (11111111-1111-1111-1111-111111111111)."
66
+ "Element types act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
54
67
  ),
55
68
  ):
56
69
  assert mock_elements_worker.check_required_types("page", "text_line", "act")
57
70
 
58
71
 
59
72
  def test_create_missing_types(responses, mock_elements_worker):
60
- responses.add(
61
- responses.GET,
62
- f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
63
- json={
64
- "id": CORPUS_ID,
65
- "name": "Some Corpus",
66
- "types": [{"slug": "folder"}, {"slug": "page"}],
67
- },
68
- )
73
+ mock_elements_worker.corpus_types = {
74
+ "folder": {"slug": "folder"},
75
+ "page": {"slug": "page"},
76
+ }
77
+
69
78
  responses.add(
70
79
  responses.POST,
71
80
  "http://testserver/api/v1/elements/type/",
@@ -94,7 +103,6 @@ def test_create_missing_types(responses, mock_elements_worker):
94
103
  )
95
104
  ],
96
105
  )
97
- mock_elements_worker.setup_api_client()
98
106
 
99
107
  assert mock_elements_worker.check_required_types(
100
108
  "page", "text_line", "act", create_missing=True
@@ -1003,7 +1011,10 @@ def test_create_elements_api_error(responses, mock_elements_worker):
1003
1011
  ]
1004
1012
 
1005
1013
 
1006
- def test_create_elements_cached_element(responses, mock_elements_worker_with_cache):
1014
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1015
+ def test_create_elements_cached_element(
1016
+ batch_size, responses, mock_elements_worker_with_cache
1017
+ ):
1007
1018
  image = CachedImage.create(
1008
1019
  id=UUID("c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe"),
1009
1020
  width=42,
@@ -1016,12 +1027,28 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1016
1027
  image_id=image.id,
1017
1028
  polygon="[[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]",
1018
1029
  )
1019
- responses.add(
1020
- responses.POST,
1021
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1022
- status=200,
1023
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1024
- )
1030
+
1031
+ if batch_size > 1:
1032
+ responses.add(
1033
+ responses.POST,
1034
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1035
+ status=200,
1036
+ json=[
1037
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1038
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1039
+ ],
1040
+ )
1041
+ else:
1042
+ for elt_id in [
1043
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1044
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1045
+ ]:
1046
+ responses.add(
1047
+ responses.POST,
1048
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1049
+ status=200,
1050
+ json=[{"id": elt_id}],
1051
+ )
1025
1052
 
1026
1053
  created_ids = mock_elements_worker_with_cache.create_elements(
1027
1054
  parent=elt,
@@ -1030,30 +1057,69 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1030
1057
  "name": "0",
1031
1058
  "type": "something",
1032
1059
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1033
- }
1060
+ },
1061
+ {
1062
+ "name": "1",
1063
+ "type": "something",
1064
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1065
+ },
1034
1066
  ],
1067
+ batch_size=batch_size,
1035
1068
  )
1036
1069
 
1037
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1038
- assert [
1039
- (call.request.method, call.request.url) for call in responses.calls
1040
- ] == BASE_API_CALLS + [
1070
+ bulk_api_calls = [
1041
1071
  (
1042
1072
  "POST",
1043
1073
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1044
- ),
1074
+ )
1045
1075
  ]
1046
- assert json.loads(responses.calls[-1].request.body) == {
1047
- "elements": [
1048
- {
1049
- "name": "0",
1050
- "type": "something",
1051
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1052
- }
1053
- ],
1076
+ if batch_size != DEFAULT_BATCH_SIZE:
1077
+ bulk_api_calls.append(
1078
+ (
1079
+ "POST",
1080
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1081
+ )
1082
+ )
1083
+
1084
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1085
+ assert [
1086
+ (call.request.method, call.request.url) for call in responses.calls
1087
+ ] == BASE_API_CALLS + bulk_api_calls
1088
+
1089
+ first_elt = {
1090
+ "name": "0",
1091
+ "type": "something",
1092
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1093
+ }
1094
+ second_elt = {
1095
+ "name": "1",
1096
+ "type": "something",
1097
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1098
+ }
1099
+ empty_payload = {
1100
+ "elements": [],
1054
1101
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1055
1102
  }
1056
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1103
+
1104
+ bodies = []
1105
+ first_call_idx = None
1106
+ if batch_size > 1:
1107
+ first_call_idx = -1
1108
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1109
+ else:
1110
+ first_call_idx = -2
1111
+ bodies.append({**empty_payload, "elements": [first_elt]})
1112
+ bodies.append({**empty_payload, "elements": [second_elt]})
1113
+
1114
+ assert [
1115
+ json.loads(bulk_call.request.body)
1116
+ for bulk_call in responses.calls[first_call_idx:]
1117
+ ] == bodies
1118
+
1119
+ assert created_ids == [
1120
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1121
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1122
+ ]
1057
1123
 
1058
1124
  # Check that created elements were properly stored in SQLite cache
1059
1125
  assert list(CachedElement.select().order_by(CachedElement.id)) == [
@@ -1065,11 +1131,24 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1065
1131
  image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1066
1132
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1067
1133
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1134
+ confidence=None,
1135
+ ),
1136
+ CachedElement(
1137
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1138
+ parent_id=elt.id,
1139
+ type="something",
1140
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1141
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1142
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1143
+ confidence=None,
1068
1144
  ),
1069
1145
  ]
1070
1146
 
1071
1147
 
1072
- def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1148
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1149
+ def test_create_elements(
1150
+ batch_size, responses, mock_elements_worker_with_cache, tmp_path
1151
+ ):
1073
1152
  elt = Element(
1074
1153
  {
1075
1154
  "id": "12341234-1234-1234-1234-123412341234",
@@ -1083,12 +1162,28 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1083
1162
  },
1084
1163
  }
1085
1164
  )
1086
- responses.add(
1087
- responses.POST,
1088
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1089
- status=200,
1090
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1091
- )
1165
+
1166
+ if batch_size > 1:
1167
+ responses.add(
1168
+ responses.POST,
1169
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1170
+ status=200,
1171
+ json=[
1172
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1173
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1174
+ ],
1175
+ )
1176
+ else:
1177
+ for elt_id in [
1178
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1179
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1180
+ ]:
1181
+ responses.add(
1182
+ responses.POST,
1183
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1184
+ status=200,
1185
+ json=[{"id": elt_id}],
1186
+ )
1092
1187
 
1093
1188
  created_ids = mock_elements_worker_with_cache.create_elements(
1094
1189
  parent=elt,
@@ -1097,30 +1192,69 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1097
1192
  "name": "0",
1098
1193
  "type": "something",
1099
1194
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1100
- }
1195
+ },
1196
+ {
1197
+ "name": "1",
1198
+ "type": "something",
1199
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1200
+ },
1101
1201
  ],
1202
+ batch_size=batch_size,
1102
1203
  )
1103
1204
 
1104
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1105
- assert [
1106
- (call.request.method, call.request.url) for call in responses.calls
1107
- ] == BASE_API_CALLS + [
1205
+ bulk_api_calls = [
1108
1206
  (
1109
1207
  "POST",
1110
1208
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1111
- ),
1209
+ )
1112
1210
  ]
1113
- assert json.loads(responses.calls[-1].request.body) == {
1114
- "elements": [
1115
- {
1116
- "name": "0",
1117
- "type": "something",
1118
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1119
- }
1120
- ],
1211
+ if batch_size != DEFAULT_BATCH_SIZE:
1212
+ bulk_api_calls.append(
1213
+ (
1214
+ "POST",
1215
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1216
+ )
1217
+ )
1218
+
1219
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1220
+ assert [
1221
+ (call.request.method, call.request.url) for call in responses.calls
1222
+ ] == BASE_API_CALLS + bulk_api_calls
1223
+
1224
+ first_elt = {
1225
+ "name": "0",
1226
+ "type": "something",
1227
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1228
+ }
1229
+ second_elt = {
1230
+ "name": "1",
1231
+ "type": "something",
1232
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1233
+ }
1234
+ empty_payload = {
1235
+ "elements": [],
1121
1236
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1122
1237
  }
1123
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1238
+
1239
+ bodies = []
1240
+ first_call_idx = None
1241
+ if batch_size > 1:
1242
+ first_call_idx = -1
1243
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1244
+ else:
1245
+ first_call_idx = -2
1246
+ bodies.append({**empty_payload, "elements": [first_elt]})
1247
+ bodies.append({**empty_payload, "elements": [second_elt]})
1248
+
1249
+ assert [
1250
+ json.loads(bulk_call.request.body)
1251
+ for bulk_call in responses.calls[first_call_idx:]
1252
+ ] == bodies
1253
+
1254
+ assert created_ids == [
1255
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1256
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1257
+ ]
1124
1258
 
1125
1259
  # Check that created elements were properly stored in SQLite cache
1126
1260
  assert (tmp_path / "db.sqlite").is_file()
@@ -1134,7 +1268,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1134
1268
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1135
1269
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1136
1270
  confidence=None,
1137
- )
1271
+ ),
1272
+ CachedElement(
1273
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1274
+ parent_id=UUID("12341234-1234-1234-1234-123412341234"),
1275
+ type="something",
1276
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1277
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1278
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1279
+ confidence=None,
1280
+ ),
1138
1281
  ]
1139
1282
 
1140
1283
 
@@ -1261,9 +1404,9 @@ def test_create_elements_integrity_error(
1261
1404
  {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1262
1405
  ]
1263
1406
 
1264
- assert len(caplog.records) == 1
1265
- assert caplog.records[0].levelname == "WARNING"
1266
- assert caplog.records[0].message.startswith(
1407
+ assert len(caplog.records) == 3
1408
+ assert caplog.records[-1].levelname == "WARNING"
1409
+ assert caplog.records[-1].message.startswith(
1267
1410
  "Couldn't save created elements in local cache:"
1268
1411
  )
1269
1412
 
@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
13
13
  CachedTranscriptionEntity,
14
14
  )
15
15
  from arkindex_worker.models import Transcription
16
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
16
17
  from arkindex_worker.worker.entity import MissingEntityType
17
18
  from arkindex_worker.worker.transcription import TextOrientation
18
19
  from tests import CORPUS_ID
@@ -988,38 +989,89 @@ def test_create_transcription_entities_wrong_entity(
988
989
  )
989
990
 
990
991
 
991
- def test_create_transcription_entities(responses, mock_elements_worker):
992
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
993
+ def test_create_transcription_entities(batch_size, responses, mock_elements_worker):
992
994
  transcription = Transcription(id="transcription-id")
995
+
993
996
  # Call to Transcription entities creation in bulk
994
- responses.add(
995
- responses.POST,
996
- "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
997
- status=201,
998
- match=[
999
- matchers.json_params_matcher(
1000
- {
1001
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
997
+ if batch_size > 1:
998
+ responses.add(
999
+ responses.POST,
1000
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1001
+ status=201,
1002
+ match=[
1003
+ matchers.json_params_matcher(
1004
+ {
1005
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
1006
+ "entities": [
1007
+ {
1008
+ "name": "Teklia",
1009
+ "type_id": "22222222-2222-2222-2222-222222222222",
1010
+ "offset": 0,
1011
+ "length": 6,
1012
+ "confidence": 1.0,
1013
+ },
1014
+ {
1015
+ "name": "Team Rocket",
1016
+ "type_id": "22222222-2222-2222-2222-222222222222",
1017
+ "offset": 7,
1018
+ "length": 11,
1019
+ "confidence": 1.0,
1020
+ },
1021
+ ],
1022
+ }
1023
+ )
1024
+ ],
1025
+ json={
1026
+ "entities": [
1027
+ {
1028
+ "transcription_entity_id": "transc-entity-id",
1029
+ "entity_id": "entity-id1",
1030
+ },
1031
+ {
1032
+ "transcription_entity_id": "transc-entity-id",
1033
+ "entity_id": "entity-id2",
1034
+ },
1035
+ ]
1036
+ },
1037
+ )
1038
+ else:
1039
+ for idx, (name, offset, length) in enumerate(
1040
+ [
1041
+ ("Teklia", 0, 6),
1042
+ ("Team Rocket", 7, 11),
1043
+ ],
1044
+ start=1,
1045
+ ):
1046
+ responses.add(
1047
+ responses.POST,
1048
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1049
+ status=201,
1050
+ match=[
1051
+ matchers.json_params_matcher(
1052
+ {
1053
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
1054
+ "entities": [
1055
+ {
1056
+ "name": name,
1057
+ "type_id": "22222222-2222-2222-2222-222222222222",
1058
+ "offset": offset,
1059
+ "length": length,
1060
+ "confidence": 1.0,
1061
+ }
1062
+ ],
1063
+ }
1064
+ )
1065
+ ],
1066
+ json={
1002
1067
  "entities": [
1003
1068
  {
1004
- "name": "Teklia",
1005
- "type_id": "22222222-2222-2222-2222-222222222222",
1006
- "offset": 0,
1007
- "length": 6,
1008
- "confidence": 1.0,
1069
+ "transcription_entity_id": "transc-entity-id",
1070
+ "entity_id": f"entity-id{idx}",
1009
1071
  }
1010
- ],
1011
- }
1072
+ ]
1073
+ },
1012
1074
  )
1013
- ],
1014
- json={
1015
- "entities": [
1016
- {
1017
- "transcription_entity_id": "transc-entity-id",
1018
- "entity_id": "entity-id",
1019
- }
1020
- ]
1021
- },
1022
- )
1023
1075
 
1024
1076
  # Store entity type/slug correspondence on the worker
1025
1077
  mock_elements_worker.entity_types = {
@@ -1034,18 +1086,35 @@ def test_create_transcription_entities(responses, mock_elements_worker):
1034
1086
  "offset": 0,
1035
1087
  "length": 6,
1036
1088
  "confidence": 1.0,
1037
- }
1089
+ },
1090
+ {
1091
+ "name": "Team Rocket",
1092
+ "type_id": "22222222-2222-2222-2222-222222222222",
1093
+ "offset": 7,
1094
+ "length": 11,
1095
+ "confidence": 1.0,
1096
+ },
1038
1097
  ],
1098
+ batch_size=batch_size,
1039
1099
  )
1040
1100
 
1041
- assert len(created_objects) == 1
1101
+ assert len(created_objects) == 2
1042
1102
 
1043
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1044
- assert [
1045
- (call.request.method, call.request.url) for call in responses.calls
1046
- ] == BASE_API_CALLS + [
1103
+ bulk_api_calls = [
1047
1104
  (
1048
1105
  "POST",
1049
1106
  "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1050
- ),
1107
+ )
1051
1108
  ]
1109
+ if batch_size != DEFAULT_BATCH_SIZE:
1110
+ bulk_api_calls.append(
1111
+ (
1112
+ "POST",
1113
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1114
+ )
1115
+ )
1116
+
1117
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1118
+ assert [
1119
+ (call.request.method, call.request.url) for call in responses.calls
1120
+ ] == BASE_API_CALLS + bulk_api_calls