arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b3__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
@@ -11,6 +11,7 @@ from peewee import IntegrityError
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement, CachedTranscription
 from arkindex_worker.models import Element
+from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
 
 
 class TextOrientation(Enum):
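
The import above pulls in the batching helpers that the rest of this diff relies on. Their implementation is not part of this diff; a minimal sketch of what they plausibly look like, inferred only from how they are called below (`make_batches(payload, "transcription", batch_size)` slicing a payload, and `@batch_publication` guarding a `batch_size` keyword), could be:

```python
# Hypothetical sketch only: the real helpers live in arkindex_worker.utils
# and are not shown in this diff.
from collections.abc import Iterator
from functools import wraps

DEFAULT_BATCH_SIZE = 50  # assumed value, not confirmed by this diff


def batch_publication(func):
    """Validate the ``batch_size`` keyword before running a publication method."""

    @wraps(func)
    def wrapper(*args, batch_size: int = DEFAULT_BATCH_SIZE, **kwargs):
        # Reject nonsensical batch sizes up front instead of failing mid-publication
        assert isinstance(batch_size, int) and batch_size > 0, "batch_size must be a positive integer"
        return func(*args, batch_size=batch_size, **kwargs)

    return wrapper


def make_batches(items: list, label: str, batch_size: int) -> Iterator[list]:
    """Yield successive ``batch_size``-sized slices of ``items``.

    The ``label`` argument presumably names the published objects in log messages.
    """
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]
```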
@@ -109,9 +110,11 @@ class TranscriptionMixin:
 
         return created
 
+    @batch_publication
     def create_transcriptions(
         self,
         transcriptions: list[dict[str, str | float | TextOrientation | None]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str | float]]:
         """
         Create multiple transcriptions at once on existing elements through the API,
@@ -128,6 +131,8 @@ class TranscriptionMixin:
             orientation (TextOrientation)
                 Optional. Orientation of the transcription's text.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
         :returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
         """
 
@@ -171,13 +176,19 @@ class TranscriptionMixin:
             )
             return
 
-        created_trs = self.api_client.request(
-            "CreateTranscriptions",
-            body={
-                "worker_run_id": self.worker_run_id,
-                "transcriptions": transcriptions_payload,
-            },
-        )["transcriptions"]
+        created_trs = [
+            created_tr
+            for batch in make_batches(
+                transcriptions_payload, "transcription", batch_size
+            )
+            for created_tr in self.api_client.request(
+                "CreateTranscriptions",
+                body={
+                    "worker_run_id": self.worker_run_id,
+                    "transcriptions": batch,
+                },
+            )["transcriptions"]
+        ]
 
         if self.use_cache:
             # Store transcriptions in local cache
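
The nested comprehension introduced above fans `transcriptions_payload` out into batches, issues one `CreateTranscriptions` call per batch, and flattens the per-batch responses back into a single list, preserving input order. A self-contained demonstration of that batch-and-flatten pattern, using plain slicing in place of `make_batches` and a stub in place of the API client (both stand-ins, not the library's code):

```python
# Stub standing in for self.api_client.request("CreateTranscriptions", ...);
# it simply echoes the batch it receives, like the endpoint returns created rows.
def fake_request(batch: list[dict]) -> dict:
    return {"transcriptions": batch}


payload = [{"text": f"line {i}"} for i in range(5)]
batch_size = 2

created = [
    created_tr
    # Plain slicing stands in for make_batches(payload, "transcription", batch_size)
    for start in range(0, len(payload), batch_size)
    for created_tr in fake_request(payload[start : start + batch_size])["transcriptions"]
]

# Three calls went out (batches of 2, 2 and 1) and input order is preserved
assert created == payload
```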
@@ -201,11 +212,13 @@ class TranscriptionMixin:
 
         return created_trs
 
+    @batch_publication
     def create_element_transcriptions(
         self,
         element: Element | CachedElement,
         sub_element_type: str,
         transcriptions: list[dict[str, str | float]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> dict[str, str | bool]:
         """
         Create multiple elements and transcriptions at once on a single parent element through the API.
@@ -225,6 +238,8 @@ class TranscriptionMixin:
             element_confidence (float)
                 Optional. Confidence score of the element between 0 and 1.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
         :returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
         """
         assert element and isinstance(
@@ -291,16 +306,22 @@ class TranscriptionMixin:
             )
             return
 
-        annotations = self.api_client.request(
-            "CreateElementTranscriptions",
-            id=element.id,
-            body={
-                "element_type": sub_element_type,
-                "worker_run_id": self.worker_run_id,
-                "transcriptions": transcriptions_payload,
-                "return_elements": True,
-            },
-        )
+        annotations = [
+            annotation
+            for batch in make_batches(
+                transcriptions_payload, "transcription", batch_size
+            )
+            for annotation in self.api_client.request(
+                "CreateElementTranscriptions",
+                id=element.id,
+                body={
+                    "element_type": sub_element_type,
+                    "worker_run_id": self.worker_run_id,
+                    "transcriptions": batch,
+                    "return_elements": True,
+                },
+            )
+        ]
 
         for annotation in annotations:
             if annotation["created"]:
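
`create_element_transcriptions` gets the same treatment: the single bulk call becomes one call per batch, with the flattened `annotations` list keeping its original shape for the cache-handling loop that follows. A hypothetical usage sketch (the `worker` instance, `page` element, and line data are placeholders, not taken from this diff):

```python
# Hypothetical: `worker` is a configured ElementsWorker using TranscriptionMixin,
# `page` an Element it is processing; the polygons and texts below are made up.
detected_lines = [
    ([[0, 0], [0, 10], [100, 10], [100, 0]], "first line"),
    ([[0, 12], [0, 22], [100, 22], [100, 12]], "second line"),
    ([[0, 24], [0, 34], [100, 34], [100, 24]], "third line"),
]
annotations = worker.create_element_transcriptions(
    element=page,
    sub_element_type="text_line",
    transcriptions=[
        {"polygon": polygon, "text": text, "confidence": 0.9}
        for polygon, text in detected_lines
    ],
    batch_size=2,  # the three transcriptions go out in two API calls
)
```

The remaining hunks apply to a second file, a classifications test module whose path is not shown in this diff; they parametrize the existing tests over `batch_size` to cover both the single-call default and the one-item-per-call extreme.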
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
 
 from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element
+from arkindex_worker.utils import DEFAULT_BATCH_SIZE
 from tests import CORPUS_ID
 
 from . import BASE_API_CALLS
@@ -692,7 +693,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
     }
 
 
-def test_create_classifications(responses, mock_elements_worker):
+@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
+def test_create_classifications(batch_size, responses, mock_elements_worker):
     mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
@@ -716,62 +718,98 @@ def test_create_classifications(responses, mock_elements_worker):
                 "high_confidence": False,
             },
         ],
+        batch_size=batch_size,
     )
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
+    if batch_size != DEFAULT_BATCH_SIZE:
+        bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/classification/bulk/"),
-    ]
+    ] == BASE_API_CALLS + bulk_api_calls
 
-    assert json.loads(responses.calls[-1].request.body) == {
+    first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
+    second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
+    empty_payload = {
         "parent": str(elt.id),
         "worker_run_id": "56785678-5678-5678-5678-567856785678",
-        "classifications": [
-            {
-                "confidence": 0.75,
-                "high_confidence": False,
-                "ml_class": "0000",
-            },
-            {
-                "confidence": 0.25,
-                "high_confidence": False,
-                "ml_class": "1111",
-            },
-        ],
+        "classifications": [],
     }
 
+    bodies = []
+    first_call_idx = None
+    if batch_size > 1:
+        first_call_idx = -1
+        bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
+    else:
+        first_call_idx = -2
+        bodies.append({**empty_payload, "classifications": [first_cl]})
+        bodies.append({**empty_payload, "classifications": [second_cl]})
+
+    assert [
+        json.loads(bulk_call.request.body)
+        for bulk_call in responses.calls[first_call_idx:]
+    ] == bodies
+
 
-def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
+@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
+def test_create_classifications_with_cache(
+    batch_size, responses, mock_elements_worker_with_cache
+):
     mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
 
-    responses.add(
-        responses.POST,
-        "http://testserver/api/v1/classification/bulk/",
-        status=200,
-        json={
-            "parent": str(elt.id),
-            "worker_run_id": "56785678-5678-5678-5678-567856785678",
-            "classifications": [
-                {
-                    "id": "00000000-0000-0000-0000-000000000000",
-                    "ml_class": "0000",
-                    "confidence": 0.75,
-                    "high_confidence": False,
-                    "state": "pending",
-                },
-                {
-                    "id": "11111111-1111-1111-1111-111111111111",
-                    "ml_class": "1111",
-                    "confidence": 0.25,
-                    "high_confidence": False,
-                    "state": "pending",
+    if batch_size > 1:
+        responses.add(
+            responses.POST,
+            "http://testserver/api/v1/classification/bulk/",
+            status=200,
+            json={
+                "parent": str(elt.id),
+                "worker_run_id": "56785678-5678-5678-5678-567856785678",
+                "classifications": [
+                    {
+                        "id": "00000000-0000-0000-0000-000000000000",
+                        "ml_class": "0000",
+                        "confidence": 0.75,
+                        "high_confidence": False,
+                        "state": "pending",
+                    },
+                    {
+                        "id": "11111111-1111-1111-1111-111111111111",
+                        "ml_class": "1111",
+                        "confidence": 0.25,
+                        "high_confidence": False,
+                        "state": "pending",
+                    },
+                ],
+            },
+        )
+    else:
+        for cl_id, cl_class, cl_conf in [
+            ("00000000-0000-0000-0000-000000000000", "0000", 0.75),
+            ("11111111-1111-1111-1111-111111111111", "1111", 0.25),
+        ]:
+            responses.add(
+                responses.POST,
+                "http://testserver/api/v1/classification/bulk/",
+                status=200,
+                json={
+                    "parent": str(elt.id),
+                    "worker_run_id": "56785678-5678-5678-5678-567856785678",
+                    "classifications": [
+                        {
+                            "id": cl_id,
+                            "ml_class": cl_class,
+                            "confidence": cl_conf,
+                            "high_confidence": False,
+                            "state": "pending",
+                        },
+                    ],
                 },
-            ],
-        },
-    )
+            )
 
     mock_elements_worker_with_cache.create_classifications(
         element=elt,
@@ -787,32 +825,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
                 "high_confidence": False,
             },
         ],
+        batch_size=batch_size,
     )
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
+    if batch_size != DEFAULT_BATCH_SIZE:
+        bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/classification/bulk/"),
-    ]
+    ] == BASE_API_CALLS + bulk_api_calls
 
-    assert json.loads(responses.calls[-1].request.body) == {
+    first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
+    second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
+    empty_payload = {
         "parent": str(elt.id),
         "worker_run_id": "56785678-5678-5678-5678-567856785678",
-        "classifications": [
-            {
-                "confidence": 0.75,
-                "high_confidence": False,
-                "ml_class": "0000",
-            },
-            {
-                "confidence": 0.25,
-                "high_confidence": False,
-                "ml_class": "1111",
-            },
-        ],
+        "classifications": [],
     }
 
+    bodies = []
+    first_call_idx = None
+    if batch_size > 1:
+        first_call_idx = -1
+        bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
+    else:
+        first_call_idx = -2
+        bodies.append({**empty_payload, "classifications": [first_cl]})
+        bodies.append({**empty_payload, "classifications": [second_cl]})
+
+    assert [
+        json.loads(bulk_call.request.body)
+        for bulk_call in responses.calls[first_call_idx:]
+    ] == bodies
+
     # Check that created classifications were properly stored in SQLite cache
     assert list(CachedClassification.select()) == [
         CachedClassification(