arkindex-base-worker 0.4.0a2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,12 @@ from arkindex_worker.cache import (
     unsupported_cache,
 )
 from arkindex_worker.models import Element, Transcription
+from arkindex_worker.utils import (
+    DEFAULT_BATCH_SIZE,
+    batch_publication,
+    make_batches,
+    pluralize,
+)
 
 
 class Entity(TypedDict):
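
Note: the `arkindex_worker.utils` helpers imported above are not shown in this diff. Judging only from their call sites in the hunks below, they plausibly behave like the following sketch; the function bodies, the logging, and the `DEFAULT_BATCH_SIZE` value are assumptions, not the package's actual implementation.

from functools import wraps
from logging import getLogger

logger = getLogger(__name__)

# Assumed default batch size; the real value lives in arkindex_worker.utils.
DEFAULT_BATCH_SIZE = 50


def pluralize(word: str, count: int) -> str:
    # Naive pluralization, enough for "entity" -> "entities" and "type" -> "types"
    if count > 1:
        return word[:-1] + "ies" if word.endswith("y") else word + "s"
    return word


def make_batches(objects: list, singular_name: str, batch_size: int):
    # Yield consecutive slices of at most batch_size items; the singular name
    # is presumably only used for progress logging
    for start in range(0, len(objects), batch_size):
        batch = objects[start : start + batch_size]
        logger.info(f"Processing {len(batch)} {pluralize(singular_name, len(batch))}")
        yield batch


def batch_publication(func):
    # Validate the batch_size keyword before running the decorated publication helper
    @wraps(func)
    def wrapper(*args, batch_size: int = DEFAULT_BATCH_SIZE, **kwargs):
        assert (
            isinstance(batch_size, int) and batch_size > 0
        ), "batch_size should be a strictly positive integer"
        return func(*args, batch_size=batch_size, **kwargs)

    return wrapper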
@@ -48,6 +54,7 @@ class EntityMixin:
         if not self.entity_types:
             # Load entity_types of corpus
             self.list_corpus_entity_types()
+
         for entity_type in entity_types:
             # Do nothing if type already exists
             if entity_type in self.entity_types:
@@ -60,7 +67,7 @@ class EntityMixin:
             )
 
             # Create type if non-existent
-            self.entity_types[entity_type] = self.request(
+            self.entity_types[entity_type] = self.api_client.request(
                 "CreateEntityType",
                 body={
                     "name": entity_type,
@@ -106,7 +113,7 @@ class EntityMixin:
         entity_type_id = self.entity_types.get(type)
         assert entity_type_id, f"Entity type `{type}` not found in the corpus."
 
-        entity = self.request(
+        entity = self.api_client.request(
             "CreateEntity",
             body={
                 "name": name,
@@ -188,7 +195,7 @@ class EntityMixin:
         if confidence is not None:
             body["confidence"] = confidence
 
-        transcription_ent = self.request(
+        transcription_ent = self.api_client.request(
             "CreateTranscriptionEntity",
             id=transcription.id,
             body=body,
@@ -212,10 +219,12 @@ class EntityMixin:
         return transcription_ent
 
     @unsupported_cache
+    @batch_publication
     def create_transcription_entities(
         self,
         transcription: Transcription,
         entities: list[Entity],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str]]:
         """
         Create multiple entities attached to a transcription in a single API request.
@@ -238,6 +247,8 @@ class EntityMixin:
             confidence (float or None)
                 Optional confidence score, between 0.0 and 1.0.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
         :return: List of dicts, with each dict having two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object.
         """
         assert transcription and isinstance(
@@ -289,16 +300,20 @@ class EntityMixin:
             )
             return
 
-        created_ids = self.request(
-            "CreateTranscriptionEntities",
-            id=transcription.id,
-            body={
-                "worker_run_id": self.worker_run_id,
-                "entities": entities,
-            },
-        )
+        created_entities = [
+            created_entity
+            for batch in make_batches(entities, "entities", batch_size)
+            for created_entity in self.api_client.request(
+                "CreateTranscriptionEntities",
+                id=transcription.id,
+                body={
+                    "worker_run_id": self.worker_run_id,
+                    "entities": batch,
+                },
+            )["entities"]
+        ]
 
-        return created_ids["entities"]
+        return created_entities
 
     def list_transcription_entities(
         self,
@@ -382,12 +397,10 @@ class EntityMixin:
         }
         count = len(self.entities)
         logger.info(
-            f'Loaded {count} entit{"ies" if count > 1 else "y"} in corpus ({self.corpus_id})'
+            f'Loaded {count} {pluralize("entity", count)} in corpus ({self.corpus_id})'
         )
 
-    def list_corpus_entity_types(
-        self,
-    ):
+    def list_corpus_entity_types(self):
         """
         Loads available entity types in corpus.
         """
@@ -399,5 +412,5 @@ class EntityMixin:
         }
         count = len(self.entity_types)
         logger.info(
-            f'Loaded {count} entity type{"s"[:count>1]} in corpus ({self.corpus_id}).'
+            f'Loaded {count} entity {pluralize("type", count)} in corpus ({self.corpus_id}).'
         )
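
Taken together, these hunks let a worker cap the payload of each `CreateTranscriptionEntities` request instead of sending all entities at once. A hypothetical call, where `worker`, `transcription` and `entities` are assumed to be set up elsewhere and the 100-item limit is purely illustrative:

# Sends ceil(len(entities) / 100) API requests and returns the
# flattened list of created objects, same shape as before the change.
created = worker.create_transcription_entities(
    transcription=transcription,
    entities=entities,
    batch_size=100,
)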
@@ -7,6 +7,7 @@ from enum import Enum
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement, unsupported_cache
 from arkindex_worker.models import Element
+from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
 
 
 class MetaType(Enum):
@@ -93,7 +94,7 @@ class MetaDataMixin:
             logger.warning("Cannot create metadata as this worker is in read-only mode")
             return
 
-        metadata = self.request(
+        metadata = self.api_client.request(
             "CreateMetaData",
             id=element.id,
             body={
@@ -108,10 +109,12 @@ class MetaDataMixin:
         return metadata["id"]
 
     @unsupported_cache
+    @batch_publication
     def create_metadata_bulk(
         self,
         element: Element | CachedElement,
         metadata_list: list[dict[str, MetaType | str | int | float | None]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str]]:
         """
         Create multiple metadata on an existing element.
@@ -123,6 +126,9 @@ class MetaDataMixin:
         - name: str
         - value: str | int | float
         - entity_id: str | None
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
+        :returns: A list of dicts as returned in the ``metadata_list`` field by the ``CreateMetaDataBulk`` API endpoint.
         """
         assert element and isinstance(
             element, Element | CachedElement
@@ -168,14 +174,18 @@ class MetaDataMixin:
             logger.warning("Cannot create metadata as this worker is in read-only mode")
             return
 
-        created_metadata_list = self.request(
-            "CreateMetaDataBulk",
-            id=element.id,
-            body={
-                "worker_run_id": self.worker_run_id,
-                "metadata_list": metas,
-            },
-        )["metadata_list"]
+        created_metadata_list = [
+            created_metadata
+            for batch in make_batches(metas, "metadata", batch_size)
+            for created_metadata in self.api_client.request(
+                "CreateMetaDataBulk",
+                id=element.id,
+                body={
+                    "worker_run_id": self.worker_run_id,
+                    "metadata_list": batch,
+                },
+            )["metadata_list"]
+        ]
 
         return created_metadata_list
 
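
The rewrite keeps the return shape intact: the per-batch `metadata_list` responses are flattened back into a single list, so callers only notice the extra HTTP requests. A hypothetical call (the element, metadata names and values are illustrative):

# Both metadata fit in one batch, since 2 <= batch_size
worker.create_metadata_bulk(
    element=element,
    metadata_list=[
        {"type": MetaType.Text, "name": "language", "value": "fra"},
        {"type": MetaType.Numeric, "name": "folio", "value": 3},
    ],
    batch_size=50,
)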
@@ -22,7 +22,7 @@ class TaskMixin:
             task_id, uuid.UUID
         ), "task_id shouldn't be null and should be an UUID"
 
-        results = self.request("ListArtifacts", id=task_id)
+        results = self.api_client.request("ListArtifacts", id=task_id)
 
         return map(Artifact, results)
 
@@ -43,4 +43,6 @@ class TaskMixin:
             artifact, Artifact
         ), "artifact shouldn't be null and should be an Artifact"
 
-        return self.request("DownloadArtifact", id=task_id, path=artifact.path)
+        return self.api_client.request(
+            "DownloadArtifact", id=task_id, path=artifact.path
+        )
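
Both hunks above follow the release's other recurring change: every `self.request(...)` call site is rewritten to go through the worker's `api_client` attribute. Nothing in this diff shows whether the old helper is kept around; if it were, a backward-compatibility shim could look like this purely hypothetical sketch:

import warnings

class BaseWorkerShim:
    # Hypothetical: forward legacy self.request() calls to the API client
    def request(self, *args, **kwargs):
        warnings.warn(
            "request() is deprecated, use api_client.request() instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.api_client.request(*args, **kwargs)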
@@ -185,7 +185,7 @@ class TrainingMixin:
         assert not self.model_version, "A model version has already been created."
 
         configuration = configuration or {}
-        self.model_version = self.request(
+        self.model_version = self.api_client.request(
             "CreateModelVersion",
             id=model_id,
             body=build_clean_payload(
@@ -217,7 +217,7 @@ class TrainingMixin:
         :param parent: ID of the parent model version
         """
         assert self.model_version, "No model version has been created yet."
-        self.model_version = self.request(
+        self.model_version = self.api_client.request(
             "UpdateModelVersion",
             id=self.model_version["id"],
             body=build_clean_payload(
@@ -273,7 +273,7 @@ class TrainingMixin:
         """
         assert self.model_version, "You must create the model version and upload its archive before validating it."
         try:
-            self.model_version = self.request(
+            self.model_version = self.api_client.request(
                 "PartialUpdateModelVersion",
                 id=self.model_version["id"],
                 body={
@@ -294,7 +294,7 @@ class TrainingMixin:
             pending_version_id = self.model_version["id"]
             logger.warning("Removing the pending model version.")
             try:
-                self.request("DestroyModelVersion", id=pending_version_id)
+                self.api_client.request("DestroyModelVersion", id=pending_version_id)
             except ErrorResponse as e:
                 msg = getattr(e, "content", str(e))
                 logger.error(
@@ -304,7 +304,7 @@ class TrainingMixin:
             logger.info("Retrieving the existing model version.")
             existing_version_id = model_version["id"].pop()
             try:
-                self.model_version = self.request(
+                self.model_version = self.api_client.request(
                     "RetrieveModelVersion", id=existing_version_id
                 )
             except ErrorResponse as e:
@@ -11,6 +11,7 @@ from peewee import IntegrityError
 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement, CachedTranscription
 from arkindex_worker.models import Element
+from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
 
 
 class TextOrientation(Enum):
@@ -77,7 +78,7 @@ class TranscriptionMixin:
             )
             return
 
-        created = self.request(
+        created = self.api_client.request(
             "CreateTranscription",
             id=element.id,
             body={
@@ -109,9 +110,11 @@ class TranscriptionMixin:
 
         return created
 
+    @batch_publication
     def create_transcriptions(
         self,
         transcriptions: list[dict[str, str | float | TextOrientation | None]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> list[dict[str, str | float]]:
         """
         Create multiple transcriptions at once on existing elements through the API,
@@ -128,6 +131,8 @@ class TranscriptionMixin:
             orientation (TextOrientation)
                 Optional. Orientation of the transcription's text.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
         :returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
         """
 
@@ -171,13 +176,19 @@ class TranscriptionMixin:
             )
             return
 
-        created_trs = self.request(
-            "CreateTranscriptions",
-            body={
-                "worker_run_id": self.worker_run_id,
-                "transcriptions": transcriptions_payload,
-            },
-        )["transcriptions"]
+        created_trs = [
+            created_tr
+            for batch in make_batches(
+                transcriptions_payload, "transcription", batch_size
+            )
+            for created_tr in self.api_client.request(
+                "CreateTranscriptions",
+                body={
+                    "worker_run_id": self.worker_run_id,
+                    "transcriptions": batch,
+                },
+            )["transcriptions"]
+        ]
 
         if self.use_cache:
             # Store transcriptions in local cache
@@ -201,11 +212,13 @@ class TranscriptionMixin:
 
         return created_trs
 
+    @batch_publication
     def create_element_transcriptions(
         self,
         element: Element | CachedElement,
         sub_element_type: str,
         transcriptions: list[dict[str, str | float]],
+        batch_size: int = DEFAULT_BATCH_SIZE,
     ) -> dict[str, str | bool]:
         """
         Create multiple elements and transcriptions at once on a single parent element through the API.
@@ -225,6 +238,8 @@ class TranscriptionMixin:
             element_confidence (float)
                 Optional. Confidence score of the element between 0 and 1.
 
+        :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
+
         :returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
         """
         assert element and isinstance(
@@ -291,16 +306,22 @@ class TranscriptionMixin:
             )
             return
 
-        annotations = self.request(
-            "CreateElementTranscriptions",
-            id=element.id,
-            body={
-                "element_type": sub_element_type,
-                "worker_run_id": self.worker_run_id,
-                "transcriptions": transcriptions_payload,
-                "return_elements": True,
-            },
-        )
+        annotations = [
+            annotation
+            for batch in make_batches(
+                transcriptions_payload, "transcription", batch_size
+            )
+            for annotation in self.api_client.request(
+                "CreateElementTranscriptions",
+                id=element.id,
+                body={
+                    "element_type": sub_element_type,
+                    "worker_run_id": self.worker_run_id,
+                    "transcriptions": batch,
+                    "return_elements": True,
+                },
+            )
+        ]
 
         for annotation in annotations:
             if annotation["created"]:
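
As with the other bulk endpoints, each batch becomes a separate HTTP request, so the request count grows as ceil(n / batch_size). A quick sketch of that arithmetic, using the assumed default of 50 from the utils sketch above:

import math

DEFAULT_BATCH_SIZE = 50  # assumed value

for n in (10, 50, 51, 120):
    calls = math.ceil(n / DEFAULT_BATCH_SIZE)
    print(f"{n} transcriptions -> {calls} CreateElementTranscriptions call(s)")
# 10 -> 1 call, 50 -> 1 call, 51 -> 2 calls, 120 -> 3 calls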
@@ -34,7 +34,9 @@ class WorkerVersionMixin:
         if worker_version_id in self._worker_version_cache:
             return self._worker_version_cache[worker_version_id]
 
-        worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
+        worker_version = self.api_client.request(
+            "RetrieveWorkerVersion", id=worker_version_id
+        )
         self._worker_version_cache[worker_version_id] = worker_version
 
         return worker_version
tests/test_base_worker.py CHANGED
@@ -658,7 +658,7 @@ def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error
 def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_path):
     responses.add(
         responses.GET,
-        "http://testserver/api/v1/task/my_task/from-agent/",
+        "http://testserver/api/v1/task/my_task/",
         status=200,
         json={"parents": ["first", "second", "third"]},
     )
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
 
 from arkindex_worker.cache import CachedClassification, CachedElement
 from arkindex_worker.models import Element
+from arkindex_worker.utils import DEFAULT_BATCH_SIZE
 from tests import CORPUS_ID
 
 from . import BASE_API_CALLS
@@ -692,7 +693,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
 }
 
 
-def test_create_classifications(responses, mock_elements_worker):
+@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
+def test_create_classifications(batch_size, responses, mock_elements_worker):
     mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
     elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
@@ -716,62 +718,98 @@ def test_create_classifications(responses, mock_elements_worker)
                 "high_confidence": False,
             },
         ],
+        batch_size=batch_size,
     )
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
+    if batch_size != DEFAULT_BATCH_SIZE:
+        bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/classification/bulk/"),
-    ]
+    ] == BASE_API_CALLS + bulk_api_calls
 
-    assert json.loads(responses.calls[-1].request.body) == {
+    first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
+    second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
+    empty_payload = {
         "parent": str(elt.id),
         "worker_run_id": "56785678-5678-5678-5678-567856785678",
-        "classifications": [
-            {
-                "confidence": 0.75,
-                "high_confidence": False,
-                "ml_class": "0000",
-            },
-            {
-                "confidence": 0.25,
-                "high_confidence": False,
-                "ml_class": "1111",
-            },
-        ],
+        "classifications": [],
     }
 
+    bodies = []
+    first_call_idx = None
+    if batch_size > 1:
+        first_call_idx = -1
+        bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
+    else:
+        first_call_idx = -2
+        bodies.append({**empty_payload, "classifications": [first_cl]})
+        bodies.append({**empty_payload, "classifications": [second_cl]})
+
+    assert [
+        json.loads(bulk_call.request.body)
+        for bulk_call in responses.calls[first_call_idx:]
+    ] == bodies
+
 
-def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
+@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
+def test_create_classifications_with_cache(
+    batch_size, responses, mock_elements_worker_with_cache
+):
     mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
     elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
 
-    responses.add(
-        responses.POST,
-        "http://testserver/api/v1/classification/bulk/",
-        status=200,
-        json={
-            "parent": str(elt.id),
-            "worker_run_id": "56785678-5678-5678-5678-567856785678",
-            "classifications": [
-                {
-                    "id": "00000000-0000-0000-0000-000000000000",
-                    "ml_class": "0000",
-                    "confidence": 0.75,
-                    "high_confidence": False,
-                    "state": "pending",
-                },
-                {
-                    "id": "11111111-1111-1111-1111-111111111111",
-                    "ml_class": "1111",
-                    "confidence": 0.25,
-                    "high_confidence": False,
-                    "state": "pending",
+    if batch_size > 1:
+        responses.add(
+            responses.POST,
+            "http://testserver/api/v1/classification/bulk/",
+            status=200,
+            json={
+                "parent": str(elt.id),
+                "worker_run_id": "56785678-5678-5678-5678-567856785678",
+                "classifications": [
+                    {
+                        "id": "00000000-0000-0000-0000-000000000000",
+                        "ml_class": "0000",
+                        "confidence": 0.75,
+                        "high_confidence": False,
+                        "state": "pending",
+                    },
+                    {
+                        "id": "11111111-1111-1111-1111-111111111111",
+                        "ml_class": "1111",
+                        "confidence": 0.25,
+                        "high_confidence": False,
+                        "state": "pending",
+                    },
+                ],
+            },
+        )
+    else:
+        for cl_id, cl_class, cl_conf in [
+            ("00000000-0000-0000-0000-000000000000", "0000", 0.75),
+            ("11111111-1111-1111-1111-111111111111", "1111", 0.25),
+        ]:
+            responses.add(
+                responses.POST,
+                "http://testserver/api/v1/classification/bulk/",
+                status=200,
+                json={
+                    "parent": str(elt.id),
+                    "worker_run_id": "56785678-5678-5678-5678-567856785678",
+                    "classifications": [
+                        {
+                            "id": cl_id,
+                            "ml_class": cl_class,
+                            "confidence": cl_conf,
+                            "high_confidence": False,
+                            "state": "pending",
+                        },
+                    ],
                 },
-            ],
-        },
-    )
+            )
 
     mock_elements_worker_with_cache.create_classifications(
         element=elt,
@@ -787,32 +825,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
                 "high_confidence": False,
             },
         ],
+        batch_size=batch_size,
     )
 
-    assert len(responses.calls) == len(BASE_API_CALLS) + 1
+    bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
+    if batch_size != DEFAULT_BATCH_SIZE:
+        bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
+
+    assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [
-        ("POST", "http://testserver/api/v1/classification/bulk/"),
-    ]
+    ] == BASE_API_CALLS + bulk_api_calls
 
-    assert json.loads(responses.calls[-1].request.body) == {
+    first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
+    second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
+    empty_payload = {
         "parent": str(elt.id),
         "worker_run_id": "56785678-5678-5678-5678-567856785678",
-        "classifications": [
-            {
-                "confidence": 0.75,
-                "high_confidence": False,
-                "ml_class": "0000",
-            },
-            {
-                "confidence": 0.25,
-                "high_confidence": False,
-                "ml_class": "1111",
-            },
-        ],
+        "classifications": [],
     }
 
+    bodies = []
+    first_call_idx = None
+    if batch_size > 1:
+        first_call_idx = -1
+        bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
+    else:
+        first_call_idx = -2
+        bodies.append({**empty_payload, "classifications": [first_cl]})
+        bodies.append({**empty_payload, "classifications": [second_cl]})
+
+    assert [
+        json.loads(bulk_call.request.body)
+        for bulk_call in responses.calls[first_call_idx:]
+    ] == bodies
+
     # Check that created classifications were properly stored in SQLite cache
     assert list(CachedClassification.select()) == [
         CachedClassification(