arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/METADATA +1 -1
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/RECORD +19 -19
- arkindex_worker/image.py +2 -1
- arkindex_worker/utils.py +81 -0
- arkindex_worker/worker/__init__.py +3 -2
- arkindex_worker/worker/classification.py +31 -15
- arkindex_worker/worker/element.py +71 -10
- arkindex_worker/worker/entity.py +25 -11
- arkindex_worker/worker/metadata.py +18 -8
- arkindex_worker/worker/transcription.py +38 -17
- tests/test_elements_worker/test_classifications.py +107 -60
- tests/test_elements_worker/test_elements.py +318 -49
- tests/test_elements_worker/test_entities.py +102 -33
- tests/test_elements_worker/test_metadata.py +223 -98
- tests/test_elements_worker/test_transcriptions.py +293 -143
- tests/test_utils.py +28 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/LICENSE +0 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/WHEEL +0 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/top_level.txt +0 -0
|
@@ -11,6 +11,7 @@ from peewee import IntegrityError
|
|
|
11
11
|
from arkindex_worker import logger
|
|
12
12
|
from arkindex_worker.cache import CachedElement, CachedTranscription
|
|
13
13
|
from arkindex_worker.models import Element
|
|
14
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class TextOrientation(Enum):
|
|
@@ -109,9 +110,11 @@ class TranscriptionMixin:
|
|
|
109
110
|
|
|
110
111
|
return created
|
|
111
112
|
|
|
113
|
+
@batch_publication
|
|
112
114
|
def create_transcriptions(
|
|
113
115
|
self,
|
|
114
116
|
transcriptions: list[dict[str, str | float | TextOrientation | None]],
|
|
117
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
115
118
|
) -> list[dict[str, str | float]]:
|
|
116
119
|
"""
|
|
117
120
|
Create multiple transcriptions at once on existing elements through the API,
|
|
@@ -128,6 +131,8 @@ class TranscriptionMixin:
|
|
|
128
131
|
orientation (TextOrientation)
|
|
129
132
|
Optional. Orientation of the transcription's text.
|
|
130
133
|
|
|
134
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
135
|
+
|
|
131
136
|
:returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
|
|
132
137
|
"""
|
|
133
138
|
|
|
@@ -171,13 +176,19 @@ class TranscriptionMixin:
|
|
|
171
176
|
)
|
|
172
177
|
return
|
|
173
178
|
|
|
174
|
-
created_trs =
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
"
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
179
|
+
created_trs = [
|
|
180
|
+
created_tr
|
|
181
|
+
for batch in make_batches(
|
|
182
|
+
transcriptions_payload, "transcription", batch_size
|
|
183
|
+
)
|
|
184
|
+
for created_tr in self.api_client.request(
|
|
185
|
+
"CreateTranscriptions",
|
|
186
|
+
body={
|
|
187
|
+
"worker_run_id": self.worker_run_id,
|
|
188
|
+
"transcriptions": batch,
|
|
189
|
+
},
|
|
190
|
+
)["transcriptions"]
|
|
191
|
+
]
|
|
181
192
|
|
|
182
193
|
if self.use_cache:
|
|
183
194
|
# Store transcriptions in local cache
|
|
@@ -201,11 +212,13 @@ class TranscriptionMixin:
|
|
|
201
212
|
|
|
202
213
|
return created_trs
|
|
203
214
|
|
|
215
|
+
@batch_publication
|
|
204
216
|
def create_element_transcriptions(
|
|
205
217
|
self,
|
|
206
218
|
element: Element | CachedElement,
|
|
207
219
|
sub_element_type: str,
|
|
208
220
|
transcriptions: list[dict[str, str | float]],
|
|
221
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
209
222
|
) -> dict[str, str | bool]:
|
|
210
223
|
"""
|
|
211
224
|
Create multiple elements and transcriptions at once on a single parent element through the API.
|
|
@@ -225,6 +238,8 @@ class TranscriptionMixin:
|
|
|
225
238
|
element_confidence (float)
|
|
226
239
|
Optional. Confidence score of the element between 0 and 1.
|
|
227
240
|
|
|
241
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
242
|
+
|
|
228
243
|
:returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
|
|
229
244
|
"""
|
|
230
245
|
assert element and isinstance(
|
|
@@ -291,16 +306,22 @@ class TranscriptionMixin:
|
|
|
291
306
|
)
|
|
292
307
|
return
|
|
293
308
|
|
|
294
|
-
annotations =
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
"
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
309
|
+
annotations = [
|
|
310
|
+
annotation
|
|
311
|
+
for batch in make_batches(
|
|
312
|
+
transcriptions_payload, "transcription", batch_size
|
|
313
|
+
)
|
|
314
|
+
for annotation in self.api_client.request(
|
|
315
|
+
"CreateElementTranscriptions",
|
|
316
|
+
id=element.id,
|
|
317
|
+
body={
|
|
318
|
+
"element_type": sub_element_type,
|
|
319
|
+
"worker_run_id": self.worker_run_id,
|
|
320
|
+
"transcriptions": batch,
|
|
321
|
+
"return_elements": True,
|
|
322
|
+
},
|
|
323
|
+
)
|
|
324
|
+
]
|
|
304
325
|
|
|
305
326
|
for annotation in annotations:
|
|
306
327
|
if annotation["created"]:
|
|
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
|
|
|
7
7
|
|
|
8
8
|
from arkindex_worker.cache import CachedClassification, CachedElement
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
10
11
|
from tests import CORPUS_ID
|
|
11
12
|
|
|
12
13
|
from . import BASE_API_CALLS
|
|
@@ -692,7 +693,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
692
693
|
}
|
|
693
694
|
|
|
694
695
|
|
|
695
|
-
|
|
696
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
697
|
+
def test_create_classifications(batch_size, responses, mock_elements_worker):
|
|
696
698
|
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
|
|
697
699
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
698
700
|
responses.add(
|
|
@@ -716,62 +718,98 @@ def test_create_classifications(responses, mock_elements_worker):
|
|
|
716
718
|
"high_confidence": False,
|
|
717
719
|
},
|
|
718
720
|
],
|
|
721
|
+
batch_size=batch_size,
|
|
719
722
|
)
|
|
720
723
|
|
|
721
|
-
|
|
724
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
725
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
726
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
727
|
+
|
|
728
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
722
729
|
assert [
|
|
723
730
|
(call.request.method, call.request.url) for call in responses.calls
|
|
724
|
-
] == BASE_API_CALLS +
|
|
725
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
726
|
-
]
|
|
731
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
727
732
|
|
|
728
|
-
|
|
733
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
734
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
735
|
+
empty_payload = {
|
|
729
736
|
"parent": str(elt.id),
|
|
730
737
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
731
|
-
"classifications": [
|
|
732
|
-
{
|
|
733
|
-
"confidence": 0.75,
|
|
734
|
-
"high_confidence": False,
|
|
735
|
-
"ml_class": "0000",
|
|
736
|
-
},
|
|
737
|
-
{
|
|
738
|
-
"confidence": 0.25,
|
|
739
|
-
"high_confidence": False,
|
|
740
|
-
"ml_class": "1111",
|
|
741
|
-
},
|
|
742
|
-
],
|
|
738
|
+
"classifications": [],
|
|
743
739
|
}
|
|
744
740
|
|
|
741
|
+
bodies = []
|
|
742
|
+
first_call_idx = None
|
|
743
|
+
if batch_size > 1:
|
|
744
|
+
first_call_idx = -1
|
|
745
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
746
|
+
else:
|
|
747
|
+
first_call_idx = -2
|
|
748
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
749
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
750
|
+
|
|
751
|
+
assert [
|
|
752
|
+
json.loads(bulk_call.request.body)
|
|
753
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
754
|
+
] == bodies
|
|
755
|
+
|
|
745
756
|
|
|
746
|
-
|
|
757
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
758
|
+
def test_create_classifications_with_cache(
|
|
759
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
760
|
+
):
|
|
747
761
|
mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
|
|
748
762
|
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
749
763
|
|
|
750
|
-
|
|
751
|
-
responses.
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
764
|
+
if batch_size > 1:
|
|
765
|
+
responses.add(
|
|
766
|
+
responses.POST,
|
|
767
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
768
|
+
status=200,
|
|
769
|
+
json={
|
|
770
|
+
"parent": str(elt.id),
|
|
771
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
772
|
+
"classifications": [
|
|
773
|
+
{
|
|
774
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
775
|
+
"ml_class": "0000",
|
|
776
|
+
"confidence": 0.75,
|
|
777
|
+
"high_confidence": False,
|
|
778
|
+
"state": "pending",
|
|
779
|
+
},
|
|
780
|
+
{
|
|
781
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
782
|
+
"ml_class": "1111",
|
|
783
|
+
"confidence": 0.25,
|
|
784
|
+
"high_confidence": False,
|
|
785
|
+
"state": "pending",
|
|
786
|
+
},
|
|
787
|
+
],
|
|
788
|
+
},
|
|
789
|
+
)
|
|
790
|
+
else:
|
|
791
|
+
for cl_id, cl_class, cl_conf in [
|
|
792
|
+
("00000000-0000-0000-0000-000000000000", "0000", 0.75),
|
|
793
|
+
("11111111-1111-1111-1111-111111111111", "1111", 0.25),
|
|
794
|
+
]:
|
|
795
|
+
responses.add(
|
|
796
|
+
responses.POST,
|
|
797
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
798
|
+
status=200,
|
|
799
|
+
json={
|
|
800
|
+
"parent": str(elt.id),
|
|
801
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
802
|
+
"classifications": [
|
|
803
|
+
{
|
|
804
|
+
"id": cl_id,
|
|
805
|
+
"ml_class": cl_class,
|
|
806
|
+
"confidence": cl_conf,
|
|
807
|
+
"high_confidence": False,
|
|
808
|
+
"state": "pending",
|
|
809
|
+
},
|
|
810
|
+
],
|
|
771
811
|
},
|
|
772
|
-
|
|
773
|
-
},
|
|
774
|
-
)
|
|
812
|
+
)
|
|
775
813
|
|
|
776
814
|
mock_elements_worker_with_cache.create_classifications(
|
|
777
815
|
element=elt,
|
|
@@ -787,32 +825,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
|
|
|
787
825
|
"high_confidence": False,
|
|
788
826
|
},
|
|
789
827
|
],
|
|
828
|
+
batch_size=batch_size,
|
|
790
829
|
)
|
|
791
830
|
|
|
792
|
-
|
|
831
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
832
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
833
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
834
|
+
|
|
835
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
793
836
|
assert [
|
|
794
837
|
(call.request.method, call.request.url) for call in responses.calls
|
|
795
|
-
] == BASE_API_CALLS +
|
|
796
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
797
|
-
]
|
|
838
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
798
839
|
|
|
799
|
-
|
|
840
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
841
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
842
|
+
empty_payload = {
|
|
800
843
|
"parent": str(elt.id),
|
|
801
844
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
802
|
-
"classifications": [
|
|
803
|
-
{
|
|
804
|
-
"confidence": 0.75,
|
|
805
|
-
"high_confidence": False,
|
|
806
|
-
"ml_class": "0000",
|
|
807
|
-
},
|
|
808
|
-
{
|
|
809
|
-
"confidence": 0.25,
|
|
810
|
-
"high_confidence": False,
|
|
811
|
-
"ml_class": "1111",
|
|
812
|
-
},
|
|
813
|
-
],
|
|
845
|
+
"classifications": [],
|
|
814
846
|
}
|
|
815
847
|
|
|
848
|
+
bodies = []
|
|
849
|
+
first_call_idx = None
|
|
850
|
+
if batch_size > 1:
|
|
851
|
+
first_call_idx = -1
|
|
852
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
853
|
+
else:
|
|
854
|
+
first_call_idx = -2
|
|
855
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
856
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
857
|
+
|
|
858
|
+
assert [
|
|
859
|
+
json.loads(bulk_call.request.body)
|
|
860
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
861
|
+
] == bodies
|
|
862
|
+
|
|
816
863
|
# Check that created classifications were properly stored in SQLite cache
|
|
817
864
|
assert list(CachedClassification.select()) == [
|
|
818
865
|
CachedClassification(
|