arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
7
7
 
8
8
  from arkindex_worker.cache import CachedClassification, CachedElement
9
9
  from arkindex_worker.models import Element
10
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
10
11
  from tests import CORPUS_ID
11
12
 
12
13
  from . import BASE_API_CALLS
@@ -692,7 +693,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
692
693
  }
693
694
 
694
695
 
695
- def test_create_classifications(responses, mock_elements_worker):
696
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
697
+ def test_create_classifications(batch_size, responses, mock_elements_worker):
696
698
  mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
697
699
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
698
700
  responses.add(
@@ -716,62 +718,98 @@ def test_create_classifications(responses, mock_elements_worker):
716
718
  "high_confidence": False,
717
719
  },
718
720
  ],
721
+ batch_size=batch_size,
719
722
  )
720
723
 
721
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
724
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
725
+ if batch_size != DEFAULT_BATCH_SIZE:
726
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
727
+
728
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
722
729
  assert [
723
730
  (call.request.method, call.request.url) for call in responses.calls
724
- ] == BASE_API_CALLS + [
725
- ("POST", "http://testserver/api/v1/classification/bulk/"),
726
- ]
731
+ ] == BASE_API_CALLS + bulk_api_calls
727
732
 
728
- assert json.loads(responses.calls[-1].request.body) == {
733
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
734
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
735
+ empty_payload = {
729
736
  "parent": str(elt.id),
730
737
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
731
- "classifications": [
732
- {
733
- "confidence": 0.75,
734
- "high_confidence": False,
735
- "ml_class": "0000",
736
- },
737
- {
738
- "confidence": 0.25,
739
- "high_confidence": False,
740
- "ml_class": "1111",
741
- },
742
- ],
738
+ "classifications": [],
743
739
  }
744
740
 
741
+ bodies = []
742
+ first_call_idx = None
743
+ if batch_size > 1:
744
+ first_call_idx = -1
745
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
746
+ else:
747
+ first_call_idx = -2
748
+ bodies.append({**empty_payload, "classifications": [first_cl]})
749
+ bodies.append({**empty_payload, "classifications": [second_cl]})
750
+
751
+ assert [
752
+ json.loads(bulk_call.request.body)
753
+ for bulk_call in responses.calls[first_call_idx:]
754
+ ] == bodies
755
+
745
756
 
746
- def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
757
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
758
+ def test_create_classifications_with_cache(
759
+ batch_size, responses, mock_elements_worker_with_cache
760
+ ):
747
761
  mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
748
762
  elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
749
763
 
750
- responses.add(
751
- responses.POST,
752
- "http://testserver/api/v1/classification/bulk/",
753
- status=200,
754
- json={
755
- "parent": str(elt.id),
756
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
757
- "classifications": [
758
- {
759
- "id": "00000000-0000-0000-0000-000000000000",
760
- "ml_class": "0000",
761
- "confidence": 0.75,
762
- "high_confidence": False,
763
- "state": "pending",
764
- },
765
- {
766
- "id": "11111111-1111-1111-1111-111111111111",
767
- "ml_class": "1111",
768
- "confidence": 0.25,
769
- "high_confidence": False,
770
- "state": "pending",
764
+ if batch_size > 1:
765
+ responses.add(
766
+ responses.POST,
767
+ "http://testserver/api/v1/classification/bulk/",
768
+ status=200,
769
+ json={
770
+ "parent": str(elt.id),
771
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
772
+ "classifications": [
773
+ {
774
+ "id": "00000000-0000-0000-0000-000000000000",
775
+ "ml_class": "0000",
776
+ "confidence": 0.75,
777
+ "high_confidence": False,
778
+ "state": "pending",
779
+ },
780
+ {
781
+ "id": "11111111-1111-1111-1111-111111111111",
782
+ "ml_class": "1111",
783
+ "confidence": 0.25,
784
+ "high_confidence": False,
785
+ "state": "pending",
786
+ },
787
+ ],
788
+ },
789
+ )
790
+ else:
791
+ for cl_id, cl_class, cl_conf in [
792
+ ("00000000-0000-0000-0000-000000000000", "0000", 0.75),
793
+ ("11111111-1111-1111-1111-111111111111", "1111", 0.25),
794
+ ]:
795
+ responses.add(
796
+ responses.POST,
797
+ "http://testserver/api/v1/classification/bulk/",
798
+ status=200,
799
+ json={
800
+ "parent": str(elt.id),
801
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
802
+ "classifications": [
803
+ {
804
+ "id": cl_id,
805
+ "ml_class": cl_class,
806
+ "confidence": cl_conf,
807
+ "high_confidence": False,
808
+ "state": "pending",
809
+ },
810
+ ],
771
811
  },
772
- ],
773
- },
774
- )
812
+ )
775
813
 
776
814
  mock_elements_worker_with_cache.create_classifications(
777
815
  element=elt,
@@ -787,32 +825,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
787
825
  "high_confidence": False,
788
826
  },
789
827
  ],
828
+ batch_size=batch_size,
790
829
  )
791
830
 
792
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
831
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
832
+ if batch_size != DEFAULT_BATCH_SIZE:
833
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
834
+
835
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
793
836
  assert [
794
837
  (call.request.method, call.request.url) for call in responses.calls
795
- ] == BASE_API_CALLS + [
796
- ("POST", "http://testserver/api/v1/classification/bulk/"),
797
- ]
838
+ ] == BASE_API_CALLS + bulk_api_calls
798
839
 
799
- assert json.loads(responses.calls[-1].request.body) == {
840
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
841
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
842
+ empty_payload = {
800
843
  "parent": str(elt.id),
801
844
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
802
- "classifications": [
803
- {
804
- "confidence": 0.75,
805
- "high_confidence": False,
806
- "ml_class": "0000",
807
- },
808
- {
809
- "confidence": 0.25,
810
- "high_confidence": False,
811
- "ml_class": "1111",
812
- },
813
- ],
845
+ "classifications": [],
814
846
  }
815
847
 
848
+ bodies = []
849
+ first_call_idx = None
850
+ if batch_size > 1:
851
+ first_call_idx = -1
852
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
853
+ else:
854
+ first_call_idx = -2
855
+ bodies.append({**empty_payload, "classifications": [first_cl]})
856
+ bodies.append({**empty_payload, "classifications": [second_cl]})
857
+
858
+ assert [
859
+ json.loads(bulk_call.request.body)
860
+ for bulk_call in responses.calls[first_call_idx:]
861
+ ] == bodies
862
+
816
863
  # Check that created classifications were properly stored in SQLite cache
817
864
  assert list(CachedClassification.select()) == [
818
865
  CachedClassification(
@@ -15,6 +15,7 @@ from arkindex_worker.cache import (
15
15
  init_cache_db,
16
16
  )
17
17
  from arkindex_worker.models import Element
18
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
18
19
  from arkindex_worker.worker import ElementsWorker
19
20
  from arkindex_worker.worker.element import MissingTypeError
20
21
  from tests import CORPUS_ID
@@ -62,7 +63,7 @@ def test_check_required_types(mock_elements_worker):
62
63
  with pytest.raises(
63
64
  MissingTypeError,
64
65
  match=re.escape(
65
- "Element type(s) act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
66
+ "Element types act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
66
67
  ),
67
68
  ):
68
69
  assert mock_elements_worker.check_required_types("page", "text_line", "act")
@@ -1010,7 +1011,10 @@ def test_create_elements_api_error(responses, mock_elements_worker):
1010
1011
  ]
1011
1012
 
1012
1013
 
1013
- def test_create_elements_cached_element(responses, mock_elements_worker_with_cache):
1014
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1015
+ def test_create_elements_cached_element(
1016
+ batch_size, responses, mock_elements_worker_with_cache
1017
+ ):
1014
1018
  image = CachedImage.create(
1015
1019
  id=UUID("c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe"),
1016
1020
  width=42,
@@ -1023,12 +1027,28 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1023
1027
  image_id=image.id,
1024
1028
  polygon="[[0, 0], [0, 1000], [1000, 1000], [1000, 0], [0, 0]]",
1025
1029
  )
1026
- responses.add(
1027
- responses.POST,
1028
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1029
- status=200,
1030
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1031
- )
1030
+
1031
+ if batch_size > 1:
1032
+ responses.add(
1033
+ responses.POST,
1034
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1035
+ status=200,
1036
+ json=[
1037
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1038
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1039
+ ],
1040
+ )
1041
+ else:
1042
+ for elt_id in [
1043
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1044
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1045
+ ]:
1046
+ responses.add(
1047
+ responses.POST,
1048
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1049
+ status=200,
1050
+ json=[{"id": elt_id}],
1051
+ )
1032
1052
 
1033
1053
  created_ids = mock_elements_worker_with_cache.create_elements(
1034
1054
  parent=elt,
@@ -1037,30 +1057,69 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1037
1057
  "name": "0",
1038
1058
  "type": "something",
1039
1059
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1040
- }
1060
+ },
1061
+ {
1062
+ "name": "1",
1063
+ "type": "something",
1064
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1065
+ },
1041
1066
  ],
1067
+ batch_size=batch_size,
1042
1068
  )
1043
1069
 
1044
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1045
- assert [
1046
- (call.request.method, call.request.url) for call in responses.calls
1047
- ] == BASE_API_CALLS + [
1070
+ bulk_api_calls = [
1048
1071
  (
1049
1072
  "POST",
1050
1073
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1051
- ),
1074
+ )
1052
1075
  ]
1053
- assert json.loads(responses.calls[-1].request.body) == {
1054
- "elements": [
1055
- {
1056
- "name": "0",
1057
- "type": "something",
1058
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1059
- }
1060
- ],
1076
+ if batch_size != DEFAULT_BATCH_SIZE:
1077
+ bulk_api_calls.append(
1078
+ (
1079
+ "POST",
1080
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1081
+ )
1082
+ )
1083
+
1084
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1085
+ assert [
1086
+ (call.request.method, call.request.url) for call in responses.calls
1087
+ ] == BASE_API_CALLS + bulk_api_calls
1088
+
1089
+ first_elt = {
1090
+ "name": "0",
1091
+ "type": "something",
1092
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1093
+ }
1094
+ second_elt = {
1095
+ "name": "1",
1096
+ "type": "something",
1097
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1098
+ }
1099
+ empty_payload = {
1100
+ "elements": [],
1061
1101
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1062
1102
  }
1063
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1103
+
1104
+ bodies = []
1105
+ first_call_idx = None
1106
+ if batch_size > 1:
1107
+ first_call_idx = -1
1108
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1109
+ else:
1110
+ first_call_idx = -2
1111
+ bodies.append({**empty_payload, "elements": [first_elt]})
1112
+ bodies.append({**empty_payload, "elements": [second_elt]})
1113
+
1114
+ assert [
1115
+ json.loads(bulk_call.request.body)
1116
+ for bulk_call in responses.calls[first_call_idx:]
1117
+ ] == bodies
1118
+
1119
+ assert created_ids == [
1120
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1121
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1122
+ ]
1064
1123
 
1065
1124
  # Check that created elements were properly stored in SQLite cache
1066
1125
  assert list(CachedElement.select().order_by(CachedElement.id)) == [
@@ -1072,11 +1131,24 @@ def test_create_elements_cached_element(responses, mock_elements_worker_with_cac
1072
1131
  image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1073
1132
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1074
1133
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1134
+ confidence=None,
1135
+ ),
1136
+ CachedElement(
1137
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1138
+ parent_id=elt.id,
1139
+ type="something",
1140
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1141
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1142
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1143
+ confidence=None,
1075
1144
  ),
1076
1145
  ]
1077
1146
 
1078
1147
 
1079
- def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1148
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
1149
+ def test_create_elements(
1150
+ batch_size, responses, mock_elements_worker_with_cache, tmp_path
1151
+ ):
1080
1152
  elt = Element(
1081
1153
  {
1082
1154
  "id": "12341234-1234-1234-1234-123412341234",
@@ -1090,12 +1162,28 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1090
1162
  },
1091
1163
  }
1092
1164
  )
1093
- responses.add(
1094
- responses.POST,
1095
- "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1096
- status=200,
1097
- json=[{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}],
1098
- )
1165
+
1166
+ if batch_size > 1:
1167
+ responses.add(
1168
+ responses.POST,
1169
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1170
+ status=200,
1171
+ json=[
1172
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1173
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1174
+ ],
1175
+ )
1176
+ else:
1177
+ for elt_id in [
1178
+ "497f6eca-6276-4993-bfeb-53cbbbba6f08",
1179
+ "5468c358-b9c4-499d-8b92-d6349c58e88d",
1180
+ ]:
1181
+ responses.add(
1182
+ responses.POST,
1183
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1184
+ status=200,
1185
+ json=[{"id": elt_id}],
1186
+ )
1099
1187
 
1100
1188
  created_ids = mock_elements_worker_with_cache.create_elements(
1101
1189
  parent=elt,
@@ -1104,30 +1192,69 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1104
1192
  "name": "0",
1105
1193
  "type": "something",
1106
1194
  "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1107
- }
1195
+ },
1196
+ {
1197
+ "name": "1",
1198
+ "type": "something",
1199
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1200
+ },
1108
1201
  ],
1202
+ batch_size=batch_size,
1109
1203
  )
1110
1204
 
1111
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1112
- assert [
1113
- (call.request.method, call.request.url) for call in responses.calls
1114
- ] == BASE_API_CALLS + [
1205
+ bulk_api_calls = [
1115
1206
  (
1116
1207
  "POST",
1117
1208
  "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1118
- ),
1209
+ )
1119
1210
  ]
1120
- assert json.loads(responses.calls[-1].request.body) == {
1121
- "elements": [
1122
- {
1123
- "name": "0",
1124
- "type": "something",
1125
- "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1126
- }
1127
- ],
1211
+ if batch_size != DEFAULT_BATCH_SIZE:
1212
+ bulk_api_calls.append(
1213
+ (
1214
+ "POST",
1215
+ "http://testserver/api/v1/element/12341234-1234-1234-1234-123412341234/children/bulk/",
1216
+ )
1217
+ )
1218
+
1219
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1220
+ assert [
1221
+ (call.request.method, call.request.url) for call in responses.calls
1222
+ ] == BASE_API_CALLS + bulk_api_calls
1223
+
1224
+ first_elt = {
1225
+ "name": "0",
1226
+ "type": "something",
1227
+ "polygon": [[1, 1], [2, 2], [2, 1], [1, 2]],
1228
+ }
1229
+ second_elt = {
1230
+ "name": "1",
1231
+ "type": "something",
1232
+ "polygon": [[4, 4], [5, 5], [5, 4], [4, 5]],
1233
+ }
1234
+ empty_payload = {
1235
+ "elements": [],
1128
1236
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1129
1237
  }
1130
- assert created_ids == [{"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"}]
1238
+
1239
+ bodies = []
1240
+ first_call_idx = None
1241
+ if batch_size > 1:
1242
+ first_call_idx = -1
1243
+ bodies.append({**empty_payload, "elements": [first_elt, second_elt]})
1244
+ else:
1245
+ first_call_idx = -2
1246
+ bodies.append({**empty_payload, "elements": [first_elt]})
1247
+ bodies.append({**empty_payload, "elements": [second_elt]})
1248
+
1249
+ assert [
1250
+ json.loads(bulk_call.request.body)
1251
+ for bulk_call in responses.calls[first_call_idx:]
1252
+ ] == bodies
1253
+
1254
+ assert created_ids == [
1255
+ {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1256
+ {"id": "5468c358-b9c4-499d-8b92-d6349c58e88d"},
1257
+ ]
1131
1258
 
1132
1259
  # Check that created elements were properly stored in SQLite cache
1133
1260
  assert (tmp_path / "db.sqlite").is_file()
@@ -1141,7 +1268,16 @@ def test_create_elements(responses, mock_elements_worker_with_cache, tmp_path):
1141
1268
  polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
1142
1269
  worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1143
1270
  confidence=None,
1144
- )
1271
+ ),
1272
+ CachedElement(
1273
+ id=UUID("5468c358-b9c4-499d-8b92-d6349c58e88d"),
1274
+ parent_id=UUID("12341234-1234-1234-1234-123412341234"),
1275
+ type="something",
1276
+ image_id="c0fec0fe-c0fe-c0fe-c0fe-c0fec0fec0fe",
1277
+ polygon=[[4, 4], [5, 5], [5, 4], [4, 5]],
1278
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
1279
+ confidence=None,
1280
+ ),
1145
1281
  ]
1146
1282
 
1147
1283
 
@@ -1268,9 +1404,9 @@ def test_create_elements_integrity_error(
1268
1404
  {"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08"},
1269
1405
  ]
1270
1406
 
1271
- assert len(caplog.records) == 1
1272
- assert caplog.records[0].levelname == "WARNING"
1273
- assert caplog.records[0].message.startswith(
1407
+ assert len(caplog.records) == 3
1408
+ assert caplog.records[-1].levelname == "WARNING"
1409
+ assert caplog.records[-1].message.startswith(
1274
1410
  "Couldn't save created elements in local cache:"
1275
1411
  )
1276
1412