arkindex-base-worker 0.4.0__py3-none-any.whl → 0.4.0a2__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (51)
  1. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/METADATA +13 -15
  2. arkindex_base_worker-0.4.0a2.dist-info/RECORD +51 -0
  3. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/WHEEL +1 -1
  4. arkindex_worker/cache.py +1 -1
  5. arkindex_worker/image.py +1 -120
  6. arkindex_worker/utils.py +0 -82
  7. arkindex_worker/worker/__init__.py +161 -46
  8. arkindex_worker/worker/base.py +11 -36
  9. arkindex_worker/worker/classification.py +18 -34
  10. arkindex_worker/worker/corpus.py +4 -21
  11. arkindex_worker/worker/dataset.py +1 -71
  12. arkindex_worker/worker/element.py +91 -352
  13. arkindex_worker/worker/entity.py +11 -11
  14. arkindex_worker/worker/metadata.py +9 -19
  15. arkindex_worker/worker/task.py +4 -5
  16. arkindex_worker/worker/training.py +6 -6
  17. arkindex_worker/worker/transcription.py +68 -89
  18. arkindex_worker/worker/version.py +1 -3
  19. tests/__init__.py +1 -1
  20. tests/conftest.py +45 -33
  21. tests/test_base_worker.py +3 -204
  22. tests/test_dataset_worker.py +4 -7
  23. tests/test_elements_worker/{test_classification.py → test_classifications.py} +61 -194
  24. tests/test_elements_worker/test_corpus.py +1 -32
  25. tests/test_elements_worker/test_dataset.py +1 -1
  26. tests/test_elements_worker/test_elements.py +2734 -0
  27. tests/test_elements_worker/{test_entity_create.py → test_entities.py} +160 -26
  28. tests/test_elements_worker/test_image.py +1 -2
  29. tests/test_elements_worker/test_metadata.py +99 -224
  30. tests/test_elements_worker/test_task.py +1 -1
  31. tests/test_elements_worker/test_training.py +2 -2
  32. tests/test_elements_worker/test_transcriptions.py +2102 -0
  33. tests/test_elements_worker/test_worker.py +280 -563
  34. tests/test_image.py +204 -429
  35. tests/test_merge.py +2 -1
  36. tests/test_utils.py +3 -66
  37. arkindex_base_worker-0.4.0.dist-info/RECORD +0 -61
  38. arkindex_worker/worker/process.py +0 -92
  39. tests/test_elements_worker/test_element.py +0 -427
  40. tests/test_elements_worker/test_element_create_multiple.py +0 -715
  41. tests/test_elements_worker/test_element_create_single.py +0 -528
  42. tests/test_elements_worker/test_element_list_children.py +0 -969
  43. tests/test_elements_worker/test_element_list_parents.py +0 -530
  44. tests/test_elements_worker/test_entity_list_and_check.py +0 -160
  45. tests/test_elements_worker/test_process.py +0 -89
  46. tests/test_elements_worker/test_transcription_create.py +0 -873
  47. tests/test_elements_worker/test_transcription_create_with_elements.py +0 -951
  48. tests/test_elements_worker/test_transcription_list.py +0 -450
  49. tests/test_elements_worker/test_version.py +0 -60
  50. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/LICENSE +0 -0
  51. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,6 @@ from arkindex_worker.cache import (
15
15
  unsupported_cache,
16
16
  )
17
17
  from arkindex_worker.models import Element, Transcription
18
- from arkindex_worker.utils import pluralize
19
18
 
20
19
 
21
20
  class Entity(TypedDict):
@@ -49,7 +48,6 @@ class EntityMixin:
49
48
  if not self.entity_types:
50
49
  # Load entity_types of corpus
51
50
  self.list_corpus_entity_types()
52
-
53
51
  for entity_type in entity_types:
54
52
  # Do nothing if type already exists
55
53
  if entity_type in self.entity_types:
@@ -62,7 +60,7 @@ class EntityMixin:
62
60
  )
63
61
 
64
62
  # Create type if non-existent
65
- self.entity_types[entity_type] = self.api_client.request(
63
+ self.entity_types[entity_type] = self.request(
66
64
  "CreateEntityType",
67
65
  body={
68
66
  "name": entity_type,
@@ -108,7 +106,7 @@ class EntityMixin:
108
106
  entity_type_id = self.entity_types.get(type)
109
107
  assert entity_type_id, f"Entity type `{type}` not found in the corpus."
110
108
 
111
- entity = self.api_client.request(
109
+ entity = self.request(
112
110
  "CreateEntity",
113
111
  body={
114
112
  "name": name,
@@ -190,7 +188,7 @@ class EntityMixin:
190
188
  if confidence is not None:
191
189
  body["confidence"] = confidence
192
190
 
193
- transcription_ent = self.api_client.request(
191
+ transcription_ent = self.request(
194
192
  "CreateTranscriptionEntity",
195
193
  id=transcription.id,
196
194
  body=body,
@@ -291,16 +289,16 @@ class EntityMixin:
291
289
  )
292
290
  return
293
291
 
294
- created_entities = self.api_client.request(
292
+ created_ids = self.request(
295
293
  "CreateTranscriptionEntities",
296
294
  id=transcription.id,
297
295
  body={
298
296
  "worker_run_id": self.worker_run_id,
299
297
  "entities": entities,
300
298
  },
301
- )["entities"]
299
+ )
302
300
 
303
- return created_entities
301
+ return created_ids["entities"]
304
302
 
305
303
  def list_transcription_entities(
306
304
  self,
@@ -384,10 +382,12 @@ class EntityMixin:
384
382
  }
385
383
  count = len(self.entities)
386
384
  logger.info(
387
- f'Loaded {count} {pluralize("entity", count)} in corpus ({self.corpus_id})'
385
+ f'Loaded {count} entit{"ies" if count > 1 else "y"} in corpus ({self.corpus_id})'
388
386
  )
389
387
 
390
- def list_corpus_entity_types(self):
388
+ def list_corpus_entity_types(
389
+ self,
390
+ ):
391
391
  """
392
392
  Loads available entity types in corpus.
393
393
  """
@@ -399,5 +399,5 @@ class EntityMixin:
399
399
  }
400
400
  count = len(self.entity_types)
401
401
  logger.info(
402
- f'Loaded {count} entity {pluralize("type", count)} in corpus ({self.corpus_id}).'
402
+ f'Loaded {count} entity type{"s"[:count>1]} in corpus ({self.corpus_id}).'
403
403
  )
@@ -7,7 +7,6 @@ from enum import Enum
7
7
  from arkindex_worker import logger
8
8
  from arkindex_worker.cache import CachedElement, unsupported_cache
9
9
  from arkindex_worker.models import Element
10
- from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
11
10
 
12
11
 
13
12
  class MetaType(Enum):
@@ -94,7 +93,7 @@ class MetaDataMixin:
94
93
  logger.warning("Cannot create metadata as this worker is in read-only mode")
95
94
  return
96
95
 
97
- metadata = self.api_client.request(
96
+ metadata = self.request(
98
97
  "CreateMetaData",
99
98
  id=element.id,
100
99
  body={
@@ -109,12 +108,10 @@ class MetaDataMixin:
109
108
  return metadata["id"]
110
109
 
111
110
  @unsupported_cache
112
- @batch_publication
113
111
  def create_metadata_bulk(
114
112
  self,
115
113
  element: Element | CachedElement,
116
114
  metadata_list: list[dict[str, MetaType | str | int | float | None]],
117
- batch_size: int = DEFAULT_BATCH_SIZE,
118
115
  ) -> list[dict[str, str]]:
119
116
  """
120
117
  Create multiple metadata on an existing element.
@@ -126,9 +123,6 @@ class MetaDataMixin:
126
123
  - name: str
127
124
  - value: str | int | float
128
125
  - entity_id: str | None
129
- :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
130
-
131
- :returns: A list of dicts as returned in the ``metadata_list`` field by the ``CreateMetaDataBulk`` API endpoint.
132
126
  """
133
127
  assert element and isinstance(
134
128
  element, Element | CachedElement
@@ -174,18 +168,14 @@ class MetaDataMixin:
174
168
  logger.warning("Cannot create metadata as this worker is in read-only mode")
175
169
  return
176
170
 
177
- created_metadata_list = [
178
- created_metadata
179
- for batch in make_batches(metas, "metadata", batch_size)
180
- for created_metadata in self.api_client.request(
181
- "CreateMetaDataBulk",
182
- id=element.id,
183
- body={
184
- "worker_run_id": self.worker_run_id,
185
- "metadata_list": batch,
186
- },
187
- )["metadata_list"]
188
- ]
171
+ created_metadata_list = self.request(
172
+ "CreateMetaDataBulk",
173
+ id=element.id,
174
+ body={
175
+ "worker_run_id": self.worker_run_id,
176
+ "metadata_list": metas,
177
+ },
178
+ )["metadata_list"]
189
179
 
190
180
  return created_metadata_list
191
181
 
@@ -5,7 +5,8 @@ BaseWorker methods for tasks.
5
5
  import uuid
6
6
  from collections.abc import Iterator
7
7
 
8
- from arkindex.compat import DownloadedFile
8
+ from apistar.compat import DownloadedFile
9
+
9
10
  from arkindex_worker.models import Artifact
10
11
 
11
12
 
@@ -21,7 +22,7 @@ class TaskMixin:
21
22
  task_id, uuid.UUID
22
23
  ), "task_id shouldn't be null and should be an UUID"
23
24
 
24
- results = self.api_client.request("ListArtifacts", id=task_id)
25
+ results = self.request("ListArtifacts", id=task_id)
25
26
 
26
27
  return map(Artifact, results)
27
28
 
@@ -42,6 +43,4 @@ class TaskMixin:
42
43
  artifact, Artifact
43
44
  ), "artifact shouldn't be null and should be an Artifact"
44
45
 
45
- return self.api_client.request(
46
- "DownloadArtifact", id=task_id, path=artifact.path
47
- )
46
+ return self.request("DownloadArtifact", id=task_id, path=artifact.path)
@@ -9,8 +9,8 @@ from typing import NewType
9
9
  from uuid import UUID
10
10
 
11
11
  import requests
12
+ from apistar.exceptions import ErrorResponse
12
13
 
13
- from arkindex.exceptions import ErrorResponse
14
14
  from arkindex_worker import logger
15
15
  from arkindex_worker.utils import close_delete_file, create_tar_zst_archive
16
16
 
@@ -185,7 +185,7 @@ class TrainingMixin:
185
185
  assert not self.model_version, "A model version has already been created."
186
186
 
187
187
  configuration = configuration or {}
188
- self.model_version = self.api_client.request(
188
+ self.model_version = self.request(
189
189
  "CreateModelVersion",
190
190
  id=model_id,
191
191
  body=build_clean_payload(
@@ -217,7 +217,7 @@ class TrainingMixin:
217
217
  :param parent: ID of the parent model version
218
218
  """
219
219
  assert self.model_version, "No model version has been created yet."
220
- self.model_version = self.api_client.request(
220
+ self.model_version = self.request(
221
221
  "UpdateModelVersion",
222
222
  id=self.model_version["id"],
223
223
  body=build_clean_payload(
@@ -273,7 +273,7 @@ class TrainingMixin:
273
273
  """
274
274
  assert self.model_version, "You must create the model version and upload its archive before validating it."
275
275
  try:
276
- self.model_version = self.api_client.request(
276
+ self.model_version = self.request(
277
277
  "PartialUpdateModelVersion",
278
278
  id=self.model_version["id"],
279
279
  body={
@@ -294,7 +294,7 @@ class TrainingMixin:
294
294
  pending_version_id = self.model_version["id"]
295
295
  logger.warning("Removing the pending model version.")
296
296
  try:
297
- self.api_client.request("DestroyModelVersion", id=pending_version_id)
297
+ self.request("DestroyModelVersion", id=pending_version_id)
298
298
  except ErrorResponse as e:
299
299
  msg = getattr(e, "content", str(e))
300
300
  logger.error(
@@ -304,7 +304,7 @@ class TrainingMixin:
304
304
  logger.info("Retrieving the existing model version.")
305
305
  existing_version_id = model_version["id"].pop()
306
306
  try:
307
- self.model_version = self.api_client.request(
307
+ self.model_version = self.request(
308
308
  "RetrieveModelVersion", id=existing_version_id
309
309
  )
310
310
  except ErrorResponse as e:
@@ -11,7 +11,6 @@ from peewee import IntegrityError
11
11
  from arkindex_worker import logger
12
12
  from arkindex_worker.cache import CachedElement, CachedTranscription
13
13
  from arkindex_worker.models import Element
14
- from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
15
14
 
16
15
 
17
16
  class TextOrientation(Enum):
@@ -78,7 +77,7 @@ class TranscriptionMixin:
78
77
  )
79
78
  return
80
79
 
81
- created = self.api_client.request(
80
+ created = self.request(
82
81
  "CreateTranscription",
83
82
  id=element.id,
84
83
  body={
@@ -110,11 +109,9 @@ class TranscriptionMixin:
110
109
 
111
110
  return created
112
111
 
113
- @batch_publication
114
112
  def create_transcriptions(
115
113
  self,
116
114
  transcriptions: list[dict[str, str | float | TextOrientation | None]],
117
- batch_size: int = DEFAULT_BATCH_SIZE,
118
115
  ) -> list[dict[str, str | float]]:
119
116
  """
120
117
  Create multiple transcriptions at once on existing elements through the API,
@@ -131,8 +128,6 @@ class TranscriptionMixin:
131
128
  orientation (TextOrientation)
132
129
  Optional. Orientation of the transcription's text.
133
130
 
134
- :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
135
-
136
131
  :returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
137
132
  """
138
133
 
@@ -176,19 +171,13 @@ class TranscriptionMixin:
176
171
  )
177
172
  return
178
173
 
179
- created_trs = [
180
- created_tr
181
- for batch in make_batches(
182
- transcriptions_payload, "transcription", batch_size
183
- )
184
- for created_tr in self.api_client.request(
185
- "CreateTranscriptions",
186
- body={
187
- "worker_run_id": self.worker_run_id,
188
- "transcriptions": batch,
189
- },
190
- )["transcriptions"]
191
- ]
174
+ created_trs = self.request(
175
+ "CreateTranscriptions",
176
+ body={
177
+ "worker_run_id": self.worker_run_id,
178
+ "transcriptions": transcriptions_payload,
179
+ },
180
+ )["transcriptions"]
192
181
 
193
182
  if self.use_cache:
194
183
  # Store transcriptions in local cache
@@ -212,13 +201,11 @@ class TranscriptionMixin:
212
201
 
213
202
  return created_trs
214
203
 
215
- @batch_publication
216
204
  def create_element_transcriptions(
217
205
  self,
218
206
  element: Element | CachedElement,
219
207
  sub_element_type: str,
220
208
  transcriptions: list[dict[str, str | float]],
221
- batch_size: int = DEFAULT_BATCH_SIZE,
222
209
  ) -> dict[str, str | bool]:
223
210
  """
224
211
  Create multiple elements and transcriptions at once on a single parent element through the API.
@@ -238,8 +225,6 @@ class TranscriptionMixin:
238
225
  element_confidence (float)
239
226
  Optional. Confidence score of the element between 0 and 1.
240
227
 
241
- :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
242
-
243
228
  :returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
244
229
  """
245
230
  assert element and isinstance(
@@ -306,22 +291,16 @@ class TranscriptionMixin:
306
291
  )
307
292
  return
308
293
 
309
- annotations = [
310
- annotation
311
- for batch in make_batches(
312
- transcriptions_payload, "transcription", batch_size
313
- )
314
- for annotation in self.api_client.request(
315
- "CreateElementTranscriptions",
316
- id=element.id,
317
- body={
318
- "element_type": sub_element_type,
319
- "worker_run_id": self.worker_run_id,
320
- "transcriptions": batch,
321
- "return_elements": True,
322
- },
323
- )
324
- ]
294
+ annotations = self.request(
295
+ "CreateElementTranscriptions",
296
+ id=element.id,
297
+ body={
298
+ "element_type": sub_element_type,
299
+ "worker_run_id": self.worker_run_id,
300
+ "transcriptions": transcriptions_payload,
301
+ "return_elements": True,
302
+ },
303
+ )
325
304
 
326
305
  for annotation in annotations:
327
306
  if annotation["created"]:
@@ -441,60 +420,60 @@ class TranscriptionMixin:
441
420
  ), "if of type bool, worker_run can only be set to False"
442
421
  query_params["worker_run"] = worker_run
443
422
 
444
- if not self.use_cache:
445
- return self.api_client.paginate(
446
- "ListTranscriptions", id=element.id, **query_params
447
- )
448
-
449
- if not recursive:
450
- # In this case we don't have to return anything, it's easier to use an
451
- # impossible condition (False) rather than filtering by type for nothing
452
- if element_type and element_type != element.type:
453
- return CachedTranscription.select().where(False)
454
- transcriptions = CachedTranscription.select().where(
455
- CachedTranscription.element_id == element.id
456
- )
457
- else:
458
- base_case = (
459
- CachedElement.select()
460
- .where(CachedElement.id == element.id)
461
- .cte("base", recursive=True)
462
- )
463
- recursive = CachedElement.select().join(
464
- base_case, on=(CachedElement.parent_id == base_case.c.id)
465
- )
466
- cte = base_case.union_all(recursive)
467
- transcriptions = (
468
- CachedTranscription.select()
469
- .join(cte, on=(CachedTranscription.element_id == cte.c.id))
470
- .with_cte(cte)
471
- )
472
-
473
- if element_type:
474
- transcriptions = transcriptions.where(cte.c.type == element_type)
475
-
476
- if worker_version is not None:
477
- # If worker_version=False, filter by manual worker_version e.g. None
478
- worker_version_id = worker_version or None
479
- if worker_version_id:
480
- transcriptions = transcriptions.where(
481
- CachedTranscription.worker_version_id == worker_version_id
423
+ if self.use_cache:
424
+ if not recursive:
425
+ # In this case we don't have to return anything, it's easier to use an
426
+ # impossible condition (False) rather than filtering by type for nothing
427
+ if element_type and element_type != element.type:
428
+ return CachedTranscription.select().where(False)
429
+ transcriptions = CachedTranscription.select().where(
430
+ CachedTranscription.element_id == element.id
482
431
  )
483
432
  else:
484
- transcriptions = transcriptions.where(
485
- CachedTranscription.worker_version_id.is_null()
433
+ base_case = (
434
+ CachedElement.select()
435
+ .where(CachedElement.id == element.id)
436
+ .cte("base", recursive=True)
486
437
  )
487
-
488
- if worker_run is not None:
489
- # If worker_run=False, filter by manual worker_run e.g. None
490
- worker_run_id = worker_run or None
491
- if worker_run_id:
492
- transcriptions = transcriptions.where(
493
- CachedTranscription.worker_run_id == worker_run_id
438
+ recursive = CachedElement.select().join(
439
+ base_case, on=(CachedElement.parent_id == base_case.c.id)
494
440
  )
495
- else:
496
- transcriptions = transcriptions.where(
497
- CachedTranscription.worker_run_id.is_null()
441
+ cte = base_case.union_all(recursive)
442
+ transcriptions = (
443
+ CachedTranscription.select()
444
+ .join(cte, on=(CachedTranscription.element_id == cte.c.id))
445
+ .with_cte(cte)
498
446
  )
499
447
 
448
+ if element_type:
449
+ transcriptions = transcriptions.where(cte.c.type == element_type)
450
+
451
+ if worker_version is not None:
452
+ # If worker_version=False, filter by manual worker_version e.g. None
453
+ worker_version_id = worker_version or None
454
+ if worker_version_id:
455
+ transcriptions = transcriptions.where(
456
+ CachedTranscription.worker_version_id == worker_version_id
457
+ )
458
+ else:
459
+ transcriptions = transcriptions.where(
460
+ CachedTranscription.worker_version_id.is_null()
461
+ )
462
+
463
+ if worker_run is not None:
464
+ # If worker_run=False, filter by manual worker_run e.g. None
465
+ worker_run_id = worker_run or None
466
+ if worker_run_id:
467
+ transcriptions = transcriptions.where(
468
+ CachedTranscription.worker_run_id == worker_run_id
469
+ )
470
+ else:
471
+ transcriptions = transcriptions.where(
472
+ CachedTranscription.worker_run_id.is_null()
473
+ )
474
+ else:
475
+ transcriptions = self.api_client.paginate(
476
+ "ListTranscriptions", id=element.id, **query_params
477
+ )
478
+
500
479
  return transcriptions
@@ -34,9 +34,7 @@ class WorkerVersionMixin:
34
34
  if worker_version_id in self._worker_version_cache:
35
35
  return self._worker_version_cache[worker_version_id]
36
36
 
37
- worker_version = self.api_client.request(
38
- "RetrieveWorkerVersion", id=worker_version_id
39
- )
37
+ worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
40
38
  self._worker_version_cache[worker_version_id] = worker_version
41
39
 
42
40
  return worker_version
tests/__init__.py CHANGED
@@ -5,4 +5,4 @@ FIXTURES_DIR = BASE_DIR / "data"
5
5
  SAMPLES_DIR = BASE_DIR / "samples"
6
6
 
7
7
  CORPUS_ID = "11111111-1111-1111-1111-111111111111"
8
- PROCESS_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"
8
+ PROCESS_ID = "cafecafe-cafe-cafe-cafe-cafecafecafe"