arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@ from arkindex_worker.cache import (
15
15
  init_cache_db,
16
16
  )
17
17
  from arkindex_worker.models import Element
18
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
18
19
  from arkindex_worker.worker import ElementsWorker
19
20
  from arkindex_worker.worker.element import MissingTypeError
20
21
  from tests import CORPUS_ID
@@ -62,7 +63,7 @@ def test_check_required_types(mock_elements_worker):
62
63
  with pytest.raises(
63
64
  MissingTypeError,
64
65
  match=re.escape(
65
- "Element type(s) act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
66
+ "Element types act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
66
67
  ),
67
68
  ):
68
69
  assert mock_elements_worker.check_required_types("page", "text_line", "act")
@@ -1010,7 +1011,10 @@ def test_create_elements_api_error(responses, mock_elements_worker):
1010
1011
  ]
1011
1012
 
1012
1013
 
1013
- def test_create_elements_cached_element(responses, mock_elements_worker_with_cache):
1014
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1015
+ def test_create_elements_cached_element(
1016
+ batch_size, responses, mock_elements_worker_with_cache
1017
+ ):
1014
1018
  image = CachedImage.create(
1015
1019
  id=UUID("c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe"),
1016
1020
  width=42,
@@ -1023,12 +1027,28 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1023
1027
  image_id=image.id,
1024
1028
  polygon="[[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]",
1025
1029
  )
1026
- responses.add(
1027
- responses.POST,
1028
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1029
- status=200,
1030
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1031
- )
1030
+
1031
+ if batch_size > 1:
1032
+ responses.add(
1033
+ responses.POST,
1034
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1035
+ status=200,
1036
+ json=[
1037
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1038
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1039
+ ],
1040
+ )
1041
+ else:
1042
+ for elt_id in [
1043
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1044
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1045
+ ]:
1046
+ responses.add(
1047
+ responses.POST,
1048
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1049
+ status=200,
1050
+ json=[{"id": elt_id}],
1051
+ )
1032
1052
 
1033
1053
  created_ids = mock_elements_worker_with_cache.create_elements(
1034
1054
  parent=elt,
@@ -1037,30 +1057,69 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1037
1057
  "name": "0",
1038
1058
  "type": "something",
1039
1059
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1040
- }
1060
+ },
1061
+ {
1062
+ "name": "1",
1063
+ "type": "something",
1064
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1065
+ },
1041
1066
  ],
1067
+ batch_size=batch_size,
1042
1068
  )
1043
1069
 
1044
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1045
- assert [
1046
- (call.request.method, call.request.url) for call in responses.calls
1047
- ] == BASE_API_CALLS + [
1070
+ bulk_api_calls = [
1048
1071
  (
1049
1072
  "POST",
1050
1073
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1051
- ),
1074
+ )
1052
1075
  ]
1053
- assert json.loads(responses.calls[-1].request.body) == {
1054
- "elements": [
1055
- {
1056
- "name": "0",
1057
- "type": "something",
1058
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1059
- }
1060
- ],
1076
+ if batch_size != DEFAULT_BATCH_SIZE:
1077
+ bulk_api_calls.append(
1078
+ (
1079
+ "POST",
1080
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1081
+ )
1082
+ )
1083
+
1084
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1085
+ assert [
1086
+ (call.request.method, call.request.url) for call in responses.calls
1087
+ ] == BASE_API_CALLS + bulk_api_calls
1088
+
1089
+ first_elt = {
1090
+ "name": "0",
1091
+ "type": "something",
1092
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1093
+ }
1094
+ second_elt = {
1095
+ "name": "1",
1096
+ "type": "something",
1097
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1098
+ }
1099
+ empty_payload = {
1100
+ "elements": [],
1061
1101
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1062
1102
  }
1063
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1103
+
1104
+ bodies = []
1105
+ first_call_idx = None
1106
+ if batch_size > 1:
1107
+ first_call_idx = -1
1108
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1109
+ else:
1110
+ first_call_idx = -2
1111
+ bodies.append({**empty_payload, "elements": [first_elt]})
1112
+ bodies.append({**empty_payload, "elements": [second_elt]})
1113
+
1114
+ assert [
1115
+ json.loads(bulk_call.request.body)
1116
+ for bulk_call in responses.calls[first_call_idx:]
1117
+ ] == bodies
1118
+
1119
+ assert created_ids == [
1120
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1121
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1122
+ ]
1064
1123
 
1065
1124
  # Check that created elements were properly stored in SQLite cache
1066
1125
  assert list(CachedElement.select().order_by(CachedElement.id)) == [
@@ -1072,11 +1131,24 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1072
1131
  image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1073
1132
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1074
1133
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1134
+ confidence=None,
1135
+ ),
1136
+ CachedElement(
1137
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1138
+ parent_id=elt.id,
1139
+ type="something",
1140
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1141
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1142
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1143
+ confidence=None,
1075
1144
  ),
1076
1145
  ]
1077
1146
 
1078
1147
 
1079
- def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1148
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1149
+ def test_create_elements(
1150
+ batch_size, responses, mock_elements_worker_with_cache, tmp_path
1151
+ ):
1080
1152
  elt = Element(
1081
1153
  {
1082
1154
  "id": "12341234-1234-1234-1234-123412341234",
@@ -1090,12 +1162,28 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1090
1162
  },
1091
1163
  }
1092
1164
  )
1093
- responses.add(
1094
- responses.POST,
1095
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1096
- status=200,
1097
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1098
- )
1165
+
1166
+ if batch_size > 1:
1167
+ responses.add(
1168
+ responses.POST,
1169
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1170
+ status=200,
1171
+ json=[
1172
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1173
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1174
+ ],
1175
+ )
1176
+ else:
1177
+ for elt_id in [
1178
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1179
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1180
+ ]:
1181
+ responses.add(
1182
+ responses.POST,
1183
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1184
+ status=200,
1185
+ json=[{"id": elt_id}],
1186
+ )
1099
1187
 
1100
1188
  created_ids = mock_elements_worker_with_cache.create_elements(
1101
1189
  parent=elt,
@@ -1104,30 +1192,69 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1104
1192
  "name": "0",
1105
1193
  "type": "something",
1106
1194
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1107
- }
1195
+ },
1196
+ {
1197
+ "name": "1",
1198
+ "type": "something",
1199
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1200
+ },
1108
1201
  ],
1202
+ batch_size=batch_size,
1109
1203
  )
1110
1204
 
1111
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1112
- assert [
1113
- (call.request.method, call.request.url) for call in responses.calls
1114
- ] == BASE_API_CALLS + [
1205
+ bulk_api_calls = [
1115
1206
  (
1116
1207
  "POST",
1117
1208
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1118
- ),
1209
+ )
1119
1210
  ]
1120
- assert json.loads(responses.calls[-1].request.body) == {
1121
- "elements": [
1122
- {
1123
- "name": "0",
1124
- "type": "something",
1125
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1126
- }
1127
- ],
1211
+ if batch_size != DEFAULT_BATCH_SIZE:
1212
+ bulk_api_calls.append(
1213
+ (
1214
+ "POST",
1215
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1216
+ )
1217
+ )
1218
+
1219
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1220
+ assert [
1221
+ (call.request.method, call.request.url) for call in responses.calls
1222
+ ] == BASE_API_CALLS + bulk_api_calls
1223
+
1224
+ first_elt = {
1225
+ "name": "0",
1226
+ "type": "something",
1227
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1228
+ }
1229
+ second_elt = {
1230
+ "name": "1",
1231
+ "type": "something",
1232
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1233
+ }
1234
+ empty_payload = {
1235
+ "elements": [],
1128
1236
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1129
1237
  }
1130
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1238
+
1239
+ bodies = []
1240
+ first_call_idx = None
1241
+ if batch_size > 1:
1242
+ first_call_idx = -1
1243
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1244
+ else:
1245
+ first_call_idx = -2
1246
+ bodies.append({**empty_payload, "elements": [first_elt]})
1247
+ bodies.append({**empty_payload, "elements": [second_elt]})
1248
+
1249
+ assert [
1250
+ json.loads(bulk_call.request.body)
1251
+ for bulk_call in responses.calls[first_call_idx:]
1252
+ ] == bodies
1253
+
1254
+ assert created_ids == [
1255
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1256
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1257
+ ]
1131
1258
 
1132
1259
  # Check that created elements were properly stored in SQLite cache
1133
1260
  assert (tmp_path / "db.sqlite").is_file()
@@ -1141,7 +1268,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1141
1268
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1142
1269
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1143
1270
  confidence=None,
1144
- )
1271
+ ),
1272
+ CachedElement(
1273
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1274
+ parent_id=UUID("12341234-1234-1234-1234-123412341234"),
1275
+ type="something",
1276
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1277
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1278
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1279
+ confidence=None,
1280
+ ),
1145
1281
  ]
1146
1282
 
1147
1283
 
@@ -1268,9 +1404,9 @@ def test_create_elements_integrity_error(
1268
1404
  {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1269
1405
  ]
1270
1406
 
1271
- assert len(caplog.records) == 1
1272
- assert caplog.records[0].levelname == "WARNING"
1273
- assert caplog.records[0].message.startswith(
1407
+ assert len(caplog.records) == 3
1408
+ assert caplog.records[-1].levelname == "WARNING"
1409
+ assert caplog.records[-1].message.startswith(
1274
1410
  "Couldn't save created elements in local cache:"
1275
1411
  )
1276
1412
 
@@ -1364,6 +1500,139 @@ def test_create_element_parent(responses, mock_elements_worker):
1364
1500
  }
1365
1501
 
1366
1502
 
1503
+ @pytest.mark.parametrize(
1504
+ ("arg_name", "data", "error_message"),
1505
+ [
1506
+ (
1507
+ "parent",
1508
+ None,
1509
+ "parent shouldn't be null and should be of type Element",
1510
+ ),
1511
+ (
1512
+ "parent",
1513
+ "not element type",
1514
+ "parent shouldn't be null and should be of type Element",
1515
+ ),
1516
+ (
1517
+ "children",
1518
+ None,
1519
+ "children shouldn't be null and should be of type list",
1520
+ ),
1521
+ (
1522
+ "children",
1523
+ "not a list",
1524
+ "children shouldn't be null and should be of type list",
1525
+ ),
1526
+ (
1527
+ "children",
1528
+ [
1529
+ Element({"id": "11111111-1111-1111-1111-111111111111"}),
1530
+ "not element type",
1531
+ ],
1532
+ "Child at index 1 in children: Should be of type Element",
1533
+ ),
1534
+ ],
1535
+ )
1536
+ def test_create_element_children_wrong_params(
1537
+ arg_name, data, error_message, mock_elements_worker
1538
+ ):
1539
+ with pytest.raises(AssertionError, match=error_message):
1540
+ mock_elements_worker.create_element_children(
1541
+ **{
1542
+ "parent": Element({"id": "12341234-1234-1234-1234-123412341234"}),
1543
+ "children": [
1544
+ Element({"id": "11111111-1111-1111-1111-111111111111"}),
1545
+ Element({"id": "22222222-2222-2222-2222-222222222222"}),
1546
+ ],
1547
+ # Overwrite with wrong data
1548
+ arg_name: data,
1549
+ },
1550
+ )
1551
+
1552
+
1553
+ def test_create_element_children_api_error(responses, mock_elements_worker):
1554
+ parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
1555
+ responses.add(
1556
+ responses.POST,
1557
+ f"http://testserver/api/v1/element/parent/{parent.id}/",
1558
+ status=418,
1559
+ )
1560
+
1561
+ with pytest.raises(ErrorResponse):
1562
+ mock_elements_worker.create_element_children(
1563
+ parent=parent,
1564
+ children=[
1565
+ Element({"id": "11111111-1111-1111-1111-111111111111"}),
1566
+ Element({"id": "22222222-2222-2222-2222-222222222222"}),
1567
+ ],
1568
+ )
1569
+
1570
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
1571
+ assert [
1572
+ (call.request.method, call.request.url) for call in responses.calls
1573
+ ] == BASE_API_CALLS + [
1574
+ (
1575
+ "POST",
1576
+ f"http://testserver/api/v1/element/parent/{parent.id}/",
1577
+ )
1578
+ ]
1579
+
1580
+
1581
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1582
+ def test_create_element_children(batch_size, responses, mock_elements_worker):
1583
+ parent = Element({"id": "12341234-1234-1234-1234-123412341234"})
1584
+
1585
+ first_child = Element({"id": "11111111-1111-1111-1111-111111111111"})
1586
+ second_child = Element({"id": "22222222-2222-2222-2222-222222222222"})
1587
+
1588
+ responses.add(
1589
+ responses.POST,
1590
+ f"http://testserver/api/v1/element/parent/{parent.id}/",
1591
+ status=200,
1592
+ json={"children": []},
1593
+ )
1594
+
1595
+ mock_elements_worker.create_element_children(
1596
+ parent=parent,
1597
+ children=[first_child, second_child],
1598
+ batch_size=batch_size,
1599
+ )
1600
+
1601
+ bulk_api_calls = [
1602
+ (
1603
+ "POST",
1604
+ f"http://testserver/api/v1/element/parent/{parent.id}/",
1605
+ )
1606
+ ]
1607
+ if batch_size != DEFAULT_BATCH_SIZE:
1608
+ bulk_api_calls.append(
1609
+ (
1610
+ "POST",
1611
+ f"http://testserver/api/v1/element/parent/{parent.id}/",
1612
+ )
1613
+ )
1614
+
1615
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1616
+ assert [
1617
+ (call.request.method, call.request.url) for call in responses.calls
1618
+ ] == BASE_API_CALLS + bulk_api_calls
1619
+
1620
+ bodies = []
1621
+ first_call_idx = None
1622
+ if batch_size > 1:
1623
+ first_call_idx = -1
1624
+ bodies.append({"children": [first_child.id, second_child.id]})
1625
+ else:
1626
+ first_call_idx = -2
1627
+ bodies.append({"children": [first_child.id]})
1628
+ bodies.append({"children": [second_child.id]})
1629
+
1630
+ assert [
1631
+ json.loads(bulk_call.request.body)
1632
+ for bulk_call in responses.calls[first_call_idx:]
1633
+ ] == bodies
1634
+
1635
+
1367
1636
  @pytest.mark.parametrize(
1368
1637
  ("payload", "error"),
1369
1638
  [
@@ -13,6 +13,7 @@ from arkindex_worker.cache import (
13
13
  CachedTranscriptionEntity,
14
14
  )
15
15
  from arkindex_worker.models import Transcription
16
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
16
17
  from arkindex_worker.worker.entity import MissingEntityType
17
18
  from arkindex_worker.worker.transcription import TextOrientation
18
19
  from tests import CORPUS_ID
@@ -988,38 +989,89 @@ def test_create_transcription_entities_wrong_entity(
988
989
  )
989
990
 
990
991
 
991
- def test_create_transcription_entities(responses, mock_elements_worker):
992
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
993
+ def test_create_transcription_entities(batch_size, responses, mock_elements_worker):
992
994
  transcription = Transcription(id="transcription-id")
995
+
993
996
  # Call to Transcription entities creation in bulk
994
- responses.add(
995
- responses.POST,
996
- "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
997
- status=201,
998
- match=[
999
- matchers.json_params_matcher(
1000
- {
1001
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
997
+ if batch_size > 1:
998
+ responses.add(
999
+ responses.POST,
1000
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1001
+ status=201,
1002
+ match=[
1003
+ matchers.json_params_matcher(
1004
+ {
1005
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
1006
+ "entities": [
1007
+ {
1008
+ "name": "Teklia",
1009
+ "type_id": "22222222-2222-2222-2222-222222222222",
1010
+ "offset": 0,
1011
+ "length": 6,
1012
+ "confidence": 1.0,
1013
+ },
1014
+ {
1015
+ "name": "Team Rocket",
1016
+ "type_id": "22222222-2222-2222-2222-222222222222",
1017
+ "offset": 7,
1018
+ "length": 11,
1019
+ "confidence": 1.0,
1020
+ },
1021
+ ],
1022
+ }
1023
+ )
1024
+ ],
1025
+ json={
1026
+ "entities": [
1027
+ {
1028
+ "transcription_entity_id": "transc-entity-id",
1029
+ "entity_id": "entity-id1",
1030
+ },
1031
+ {
1032
+ "transcription_entity_id": "transc-entity-id",
1033
+ "entity_id": "entity-id2",
1034
+ },
1035
+ ]
1036
+ },
1037
+ )
1038
+ else:
1039
+ for idx, (name, offset, length) in enumerate(
1040
+ [
1041
+ ("Teklia", 0, 6),
1042
+ ("Team Rocket", 7, 11),
1043
+ ],
1044
+ start=1,
1045
+ ):
1046
+ responses.add(
1047
+ responses.POST,
1048
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1049
+ status=201,
1050
+ match=[
1051
+ matchers.json_params_matcher(
1052
+ {
1053
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
1054
+ "entities": [
1055
+ {
1056
+ "name": name,
1057
+ "type_id": "22222222-2222-2222-2222-222222222222",
1058
+ "offset": offset,
1059
+ "length": length,
1060
+ "confidence": 1.0,
1061
+ }
1062
+ ],
1063
+ }
1064
+ )
1065
+ ],
1066
+ json={
1002
1067
  "entities": [
1003
1068
  {
1004
- "name": "Teklia",
1005
- "type_id": "22222222-2222-2222-2222-222222222222",
1006
- "offset": 0,
1007
- "length": 6,
1008
- "confidence": 1.0,
1069
+ "transcription_entity_id": "transc-entity-id",
1070
+ "entity_id": f"entity-id{idx}",
1009
1071
  }
1010
- ],
1011
- }
1072
+ ]
1073
+ },
1012
1074
  )
1013
- ],
1014
- json={
1015
- "entities": [
1016
- {
1017
- "transcription_entity_id": "transc-entity-id",
1018
- "entity_id": "entity-id",
1019
- }
1020
- ]
1021
- },
1022
- )
1023
1075
 
1024
1076
  # Store entity type/slug correspondence on the worker
1025
1077
  mock_elements_worker.entity_types = {
@@ -1034,18 +1086,35 @@ def test_create_transcription_entities(responses, mock_elements_worker):
1034
1086
  "offset": 0,
1035
1087
  "length": 6,
1036
1088
  "confidence": 1.0,
1037
- }
1089
+ },
1090
+ {
1091
+ "name": "Team Rocket",
1092
+ "type_id": "22222222-2222-2222-2222-222222222222",
1093
+ "offset": 7,
1094
+ "length": 11,
1095
+ "confidence": 1.0,
1096
+ },
1038
1097
  ],
1098
+ batch_size=batch_size,
1039
1099
  )
1040
1100
 
1041
- assert len(created_objects) == 1
1101
+ assert len(created_objects) == 2
1042
1102
 
1043
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1044
- assert [
1045
- (call.request.method, call.request.url) for call in responses.calls
1046
- ] == BASE_API_CALLS + [
1103
+ bulk_api_calls = [
1047
1104
  (
1048
1105
  "POST",
1049
1106
  "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1050
- ),
1107
+ )
1051
1108
  ]
1109
+ if batch_size != DEFAULT_BATCH_SIZE:
1110
+ bulk_api_calls.append(
1111
+ (
1112
+ "POST",
1113
+ "http://testserver/api/v1/transcription/transcription-id/entities/bulk/",
1114
+ )
1115
+ )
1116
+
1117
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1118
+ assert [
1119
+ (call.request.method, call.request.url) for call in responses.calls
1120
+ ] == BASE_API_CALLS + bulk_api_calls