arkindex_base_worker-0.3.6rc1-py3-none-any.whl → arkindex_base_worker-0.3.6rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. arkindex_base_worker-0.3.6rc2.dist-info/METADATA +39 -0
  2. arkindex_base_worker-0.3.6rc2.dist-info/RECORD +40 -0
  3. arkindex_worker/__init__.py +0 -1
  4. arkindex_worker/cache.py +19 -25
  5. arkindex_worker/image.py +16 -17
  6. arkindex_worker/models.py +17 -21
  7. arkindex_worker/utils.py +16 -17
  8. arkindex_worker/worker/__init__.py +14 -23
  9. arkindex_worker/worker/base.py +12 -7
  10. arkindex_worker/worker/classification.py +13 -15
  11. arkindex_worker/worker/dataset.py +3 -4
  12. arkindex_worker/worker/element.py +80 -75
  13. arkindex_worker/worker/entity.py +27 -29
  14. arkindex_worker/worker/metadata.py +19 -25
  15. arkindex_worker/worker/task.py +2 -3
  16. arkindex_worker/worker/training.py +21 -22
  17. arkindex_worker/worker/transcription.py +37 -34
  18. arkindex_worker/worker/version.py +1 -2
  19. tests/conftest.py +55 -75
  20. tests/test_base_worker.py +37 -31
  21. tests/test_cache.py +14 -7
  22. tests/test_dataset_worker.py +4 -4
  23. tests/test_element.py +0 -1
  24. tests/test_elements_worker/__init__.py +0 -1
  25. tests/test_elements_worker/test_classifications.py +0 -1
  26. tests/test_elements_worker/test_cli.py +22 -17
  27. tests/test_elements_worker/test_dataset.py +9 -10
  28. tests/test_elements_worker/test_elements.py +58 -63
  29. tests/test_elements_worker/test_entities.py +10 -20
  30. tests/test_elements_worker/test_metadata.py +72 -96
  31. tests/test_elements_worker/test_task.py +9 -10
  32. tests/test_elements_worker/test_training.py +20 -13
  33. tests/test_elements_worker/test_transcriptions.py +6 -10
  34. tests/test_elements_worker/test_worker.py +16 -14
  35. tests/test_image.py +21 -20
  36. tests/test_merge.py +5 -6
  37. tests/test_utils.py +0 -1
  38. arkindex_base_worker-0.3.6rc1.dist-info/METADATA +0 -27
  39. arkindex_base_worker-0.3.6rc1.dist-info/RECORD +0 -42
  40. arkindex_worker/git.py +0 -392
  41. tests/test_git.py +0 -480
  42. {arkindex_base_worker-0.3.6rc1.dist-info → arkindex_base_worker-0.3.6rc2.dist-info}/WHEEL +0 -0
  43. {arkindex_base_worker-0.3.6rc1.dist-info → arkindex_base_worker-0.3.6rc2.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  The base class for all Arkindex workers.
4
3
  """
@@ -9,7 +8,6 @@ import os
9
8
  import shutil
10
9
  from pathlib import Path
11
10
  from tempfile import mkdtemp
12
- from typing import List, Optional
13
11
 
14
12
  import gnupg
15
13
  import yaml
@@ -52,15 +50,15 @@ class ExtrasDirNotFoundError(Exception):
52
50
  """
53
51
 
54
52
 
55
- class BaseWorker(object):
53
+ class BaseWorker:
56
54
  """
57
55
  Base class for Arkindex workers.
58
56
  """
59
57
 
60
58
  def __init__(
61
59
  self,
62
- description: Optional[str] = "Arkindex Base Worker",
63
- support_cache: Optional[bool] = False,
60
+ description: str | None = "Arkindex Base Worker",
61
+ support_cache: bool | None = False,
64
62
  ):
65
63
  """
66
64
  Initialize the worker.
@@ -217,6 +215,9 @@ class BaseWorker(object):
217
215
  # Define model_version_id from environment
218
216
  self.model_version_id = os.environ.get("ARKINDEX_MODEL_VERSION_ID")
219
217
 
218
+ # Define model_details from environment
219
+ self.model_details = {"id": os.environ.get("ARKINDEX_MODEL_ID")}
220
+
220
221
  # Load all required secrets
221
222
  self.secrets = {name: self.load_secret(Path(name)) for name in required_secrets}
222
223
 
@@ -259,6 +260,9 @@ class BaseWorker(object):
259
260
  # Set model_version ID as worker attribute
260
261
  self.model_version_id = model_version.get("id")
261
262
 
263
+ # Set model details as worker attribute
264
+ self.model_details = model_version.get("model")
265
+
262
266
  # Retrieve initial configuration from API
263
267
  self.config = worker_version["configuration"].get("configuration", {})
264
268
  if "user_configuration" in worker_version["configuration"]:
@@ -347,7 +351,8 @@ class BaseWorker(object):
347
351
 
348
352
  try:
349
353
  gpg = gnupg.GPG()
350
- decrypted = gpg.decrypt_file(open(path, "rb"))
354
+ with path.open("rb") as gpg_file:
355
+ decrypted = gpg.decrypt_file(gpg_file)
351
356
  assert (
352
357
  decrypted.ok
353
358
  ), f"GPG error: {decrypted.status} - {decrypted.stderr}"
@@ -406,7 +411,7 @@ class BaseWorker(object):
406
411
  )
407
412
  return extras_dir
408
413
 
409
- def find_parents_file_paths(self, filename: Path) -> List[Path]:
414
+ def find_parents_file_paths(self, filename: Path) -> list[Path]:
410
415
  """
411
416
  Find the paths of a specific file from the parent tasks.
412
417
  Only works if the task_parents attributes is updated, so if the cache is supported,
@@ -1,9 +1,7 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  ElementsWorker methods for classifications and ML classes.
4
3
  """
5
4
 
6
- from typing import Dict, List, Optional, Union
7
5
  from uuid import UUID
8
6
 
9
7
  from apistar.exceptions import ErrorResponse
@@ -14,7 +12,7 @@ from arkindex_worker.cache import CachedClassification, CachedElement
14
12
  from arkindex_worker.models import Element
15
13
 
16
14
 
17
- class ClassificationMixin(object):
15
+ class ClassificationMixin:
18
16
  def load_corpus_classes(self):
19
17
  """
20
18
  Load all ML classes available in the worker's corpus and store them in the ``self.classes`` cache.
@@ -91,11 +89,11 @@ class ClassificationMixin(object):
91
89
 
92
90
  def create_classification(
93
91
  self,
94
- element: Union[Element, CachedElement],
92
+ element: Element | CachedElement,
95
93
  ml_class: str,
96
94
  confidence: float,
97
- high_confidence: Optional[bool] = False,
98
- ) -> Dict[str, str]:
95
+ high_confidence: bool = False,
96
+ ) -> dict[str, str]:
99
97
  """
100
98
  Create a classification on the given element through the API.
101
99
 
@@ -106,7 +104,7 @@ class ClassificationMixin(object):
106
104
  :returns: The created classification, as returned by the ``CreateClassification`` API endpoint.
107
105
  """
108
106
  assert element and isinstance(
109
- element, (Element, CachedElement)
107
+ element, Element | CachedElement
110
108
  ), "element shouldn't be null and should be an Element or CachedElement"
111
109
  assert ml_class and isinstance(
112
110
  ml_class, str
@@ -180,9 +178,9 @@ class ClassificationMixin(object):
180
178
 
181
179
  def create_classifications(
182
180
  self,
183
- element: Union[Element, CachedElement],
184
- classifications: List[Dict[str, Union[str, float, bool]]],
185
- ) -> List[Dict[str, Union[str, float, bool]]]:
181
+ element: Element | CachedElement,
182
+ classifications: list[dict[str, str | float | bool]],
183
+ ) -> list[dict[str, str | float | bool]]:
186
184
  """
187
185
  Create multiple classifications at once on the given element through the API.
188
186
 
@@ -196,7 +194,7 @@ class ClassificationMixin(object):
196
194
  the ``CreateClassifications`` API endpoint.
197
195
  """
198
196
  assert element and isinstance(
199
- element, (Element, CachedElement)
197
+ element, Element | CachedElement
200
198
  ), "element shouldn't be null and should be an Element or CachedElement"
201
199
  assert classifications and isinstance(
202
200
  classifications, list
@@ -204,17 +202,17 @@ class ClassificationMixin(object):
204
202
 
205
203
  for index, classification in enumerate(classifications):
206
204
  ml_class_id = classification.get("ml_class_id")
207
- assert ml_class_id and isinstance(
208
- ml_class_id, str
205
+ assert (
206
+ ml_class_id and isinstance(ml_class_id, str)
209
207
  ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
210
208
 
211
209
  # Make sure it's a valid UUID
212
210
  try:
213
211
  UUID(ml_class_id)
214
- except ValueError:
212
+ except ValueError as e:
215
213
  raise ValueError(
216
214
  f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
217
- )
215
+ ) from e
218
216
 
219
217
  confidence = classification.get("confidence")
220
218
  assert (
@@ -1,10 +1,9 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  BaseWorker methods for datasets.
4
3
  """
5
4
 
5
+ from collections.abc import Iterator
6
6
  from enum import Enum
7
- from typing import Iterator, Tuple
8
7
 
9
8
  from arkindex_worker import logger
10
9
  from arkindex_worker.models import Dataset, Element
@@ -36,7 +35,7 @@ class DatasetState(Enum):
36
35
  """
37
36
 
38
37
 
39
- class DatasetMixin(object):
38
+ class DatasetMixin:
40
39
  def list_process_datasets(self) -> Iterator[Dataset]:
41
40
  """
42
41
  List datasets associated to the worker's process. This helper is not available in developer mode.
@@ -51,7 +50,7 @@ class DatasetMixin(object):
51
50
 
52
51
  return map(Dataset, list(results))
53
52
 
54
- def list_dataset_elements(self, dataset: Dataset) -> Iterator[Tuple[str, Element]]:
53
+ def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
55
54
  """
56
55
  List elements in a dataset.
57
56
 
@@ -1,8 +1,8 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  ElementsWorker methods for elements and element types.
4
3
  """
5
- from typing import Dict, Iterable, List, NamedTuple, Optional, Union
4
+ from collections.abc import Iterable
5
+ from typing import NamedTuple
6
6
  from uuid import UUID
7
7
 
8
8
  from peewee import IntegrityError
@@ -28,8 +28,8 @@ class MissingTypeError(Exception):
28
28
  """
29
29
 
30
30
 
31
- class ElementMixin(object):
32
- def create_required_types(self, element_types: List[ElementType]):
31
+ class ElementMixin:
32
+ def create_required_types(self, element_types: list[ElementType]):
33
33
  """Creates given element types in the corpus.
34
34
 
35
35
  :param element_types: The missing element types to create.
@@ -86,9 +86,9 @@ class ElementMixin(object):
86
86
  element: Element,
87
87
  type: str,
88
88
  name: str,
89
- polygon: List[List[Union[int, float]]],
90
- confidence: Optional[float] = None,
91
- slim_output: Optional[bool] = True,
89
+ polygon: list[list[int | float]],
90
+ confidence: float | None = None,
91
+ slim_output: bool = True,
92
92
  ) -> str:
93
93
  """
94
94
  Create a child element on the given element through the API.
@@ -117,7 +117,7 @@ class ElementMixin(object):
117
117
  isinstance(point, list) and len(point) == 2 for point in polygon
118
118
  ), "polygon points should be lists of two items"
119
119
  assert all(
120
- isinstance(coord, (int, float)) for point in polygon for coord in point
120
+ isinstance(coord, int | float) for point in polygon for coord in point
121
121
  ), "polygon points should be lists of two numbers"
122
122
  assert confidence is None or (
123
123
  isinstance(confidence, float) and 0 <= confidence <= 1
@@ -146,11 +146,9 @@ class ElementMixin(object):
146
146
 
147
147
  def create_elements(
148
148
  self,
149
- parent: Union[Element, CachedElement],
150
- elements: List[
151
- Dict[str, Union[str, List[List[Union[int, float]]], float, None]]
152
- ],
153
- ) -> List[Dict[str, str]]:
149
+ parent: Element | CachedElement,
150
+ elements: list[dict[str, str | list[list[int | float]] | float | None]],
151
+ ) -> list[dict[str, str]]:
154
152
  """
155
153
  Create child elements on the given element in a single API request.
156
154
 
@@ -195,18 +193,18 @@ class ElementMixin(object):
195
193
  ), f"Element at index {index} in elements: Should be of type dict"
196
194
 
197
195
  name = element.get("name")
198
- assert name and isinstance(
199
- name, str
196
+ assert (
197
+ name and isinstance(name, str)
200
198
  ), f"Element at index {index} in elements: name shouldn't be null and should be of type str"
201
199
 
202
200
  type = element.get("type")
203
- assert type and isinstance(
204
- type, str
201
+ assert (
202
+ type and isinstance(type, str)
205
203
  ), f"Element at index {index} in elements: type shouldn't be null and should be of type str"
206
204
 
207
205
  polygon = element.get("polygon")
208
- assert polygon and isinstance(
209
- polygon, list
206
+ assert (
207
+ polygon and isinstance(polygon, list)
210
208
  ), f"Element at index {index} in elements: polygon shouldn't be null and should be of type list"
211
209
  assert (
212
210
  len(polygon) >= 3
@@ -215,12 +213,13 @@ class ElementMixin(object):
215
213
  isinstance(point, list) and len(point) == 2 for point in polygon
216
214
  ), f"Element at index {index} in elements: polygon points should be lists of two items"
217
215
  assert all(
218
- isinstance(coord, (int, float)) for point in polygon for coord in point
216
+ isinstance(coord, int | float) for point in polygon for coord in point
219
217
  ), f"Element at index {index} in elements: polygon points should be lists of two numbers"
220
218
 
221
219
  confidence = element.get("confidence")
222
- assert confidence is None or (
223
- isinstance(confidence, float) and 0 <= confidence <= 1
220
+ assert (
221
+ confidence is None
222
+ or (isinstance(confidence, float) and 0 <= confidence <= 1)
224
223
  ), f"Element at index {index} in elements: confidence should be None or a float in [0..1] range"
225
224
 
226
225
  if self.is_read_only:
@@ -272,7 +271,7 @@ class ElementMixin(object):
272
271
  return created_ids
273
272
 
274
273
  def partial_update_element(
275
- self, element: Union[Element, CachedElement], **kwargs
274
+ self, element: Element | CachedElement, **kwargs
276
275
  ) -> dict:
277
276
  """
278
277
  Partially updates an element through the API.
@@ -292,7 +291,7 @@ class ElementMixin(object):
292
291
  :returns: A dict from the ``PartialUpdateElement`` API endpoint,
293
292
  """
294
293
  assert element and isinstance(
295
- element, (Element, CachedElement)
294
+ element, Element | CachedElement
296
295
  ), "element shouldn't be null and should be an Element or CachedElement"
297
296
 
298
297
  if "type" in kwargs:
@@ -309,7 +308,7 @@ class ElementMixin(object):
309
308
  isinstance(point, list) and len(point) == 2 for point in polygon
310
309
  ), "polygon points should be lists of two items"
311
310
  assert all(
312
- isinstance(coord, (int, float)) for point in polygon for coord in point
311
+ isinstance(coord, int | float) for point in polygon for coord in point
313
312
  ), "polygon points should be lists of two numbers"
314
313
 
315
314
  if "confidence" in kwargs:
@@ -363,21 +362,21 @@ class ElementMixin(object):
363
362
 
364
363
  def list_element_children(
365
364
  self,
366
- element: Union[Element, CachedElement],
367
- folder: Optional[bool] = None,
368
- name: Optional[str] = None,
369
- recursive: Optional[bool] = None,
370
- transcription_worker_version: Optional[Union[str, bool]] = None,
371
- transcription_worker_run: Optional[Union[str, bool]] = None,
372
- type: Optional[str] = None,
373
- with_classes: Optional[bool] = None,
374
- with_corpus: Optional[bool] = None,
375
- with_metadata: Optional[bool] = None,
376
- with_has_children: Optional[bool] = None,
377
- with_zone: Optional[bool] = None,
378
- worker_version: Optional[Union[str, bool]] = None,
379
- worker_run: Optional[Union[str, bool]] = None,
380
- ) -> Union[Iterable[dict], Iterable[CachedElement]]:
365
+ element: Element | CachedElement,
366
+ folder: bool | None = None,
367
+ name: str | None = None,
368
+ recursive: bool | None = None,
369
+ transcription_worker_version: str | bool | None = None,
370
+ transcription_worker_run: str | bool | None = None,
371
+ type: str | None = None,
372
+ with_classes: bool | None = None,
373
+ with_corpus: bool | None = None,
374
+ with_metadata: bool | None = None,
375
+ with_has_children: bool | None = None,
376
+ with_zone: bool | None = None,
377
+ worker_version: str | bool | None = None,
378
+ worker_run: str | bool | None = None,
379
+ ) -> Iterable[dict] | Iterable[CachedElement]:
381
380
  """
382
381
  List children of an element.
383
382
 
@@ -412,7 +411,7 @@ class ElementMixin(object):
412
411
  or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
413
412
  """
414
413
  assert element and isinstance(
415
- element, (Element, CachedElement)
414
+ element, Element | CachedElement
416
415
  ), "element shouldn't be null and should be an Element or CachedElement"
417
416
  query_params = {}
418
417
  if folder is not None:
@@ -426,7 +425,7 @@ class ElementMixin(object):
426
425
  query_params["recursive"] = recursive
427
426
  if transcription_worker_version is not None:
428
427
  assert isinstance(
429
- transcription_worker_version, (str, bool)
428
+ transcription_worker_version, str | bool
430
429
  ), "transcription_worker_version should be of type str or bool"
431
430
  if isinstance(transcription_worker_version, bool):
432
431
  assert (
@@ -435,7 +434,7 @@ class ElementMixin(object):
435
434
  query_params["transcription_worker_version"] = transcription_worker_version
436
435
  if transcription_worker_run is not None:
437
436
  assert isinstance(
438
- transcription_worker_run, (str, bool)
437
+ transcription_worker_run, str | bool
439
438
  ), "transcription_worker_run should be of type str or bool"
440
439
  if isinstance(transcription_worker_run, bool):
441
440
  assert (
@@ -466,7 +465,7 @@ class ElementMixin(object):
466
465
  query_params["with_zone"] = with_zone
467
466
  if worker_version is not None:
468
467
  assert isinstance(
469
- worker_version, (str, bool)
468
+ worker_version, str | bool
470
469
  ), "worker_version should be of type str or bool"
471
470
  if isinstance(worker_version, bool):
472
471
  assert (
@@ -475,7 +474,7 @@ class ElementMixin(object):
475
474
  query_params["worker_version"] = worker_version
476
475
  if worker_run is not None:
477
476
  assert isinstance(
478
- worker_run, (str, bool)
477
+ worker_run, str | bool
479
478
  ), "worker_run should be of type str or bool"
480
479
  if isinstance(worker_run, bool):
481
480
  assert (
@@ -485,11 +484,14 @@ class ElementMixin(object):
485
484
 
486
485
  if self.use_cache:
487
486
  # Checking that we only received query_params handled by the cache
488
- assert set(query_params.keys()) <= {
489
- "type",
490
- "worker_version",
491
- "worker_run",
492
- }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
487
+ assert (
488
+ set(query_params.keys())
489
+ <= {
490
+ "type",
491
+ "worker_version",
492
+ "worker_run",
493
+ }
494
+ ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
493
495
 
494
496
  query = CachedElement.select().where(CachedElement.parent_id == element.id)
495
497
  if type:
@@ -522,21 +524,21 @@ class ElementMixin(object):
522
524
 
523
525
  def list_element_parents(
524
526
  self,
525
- element: Union[Element, CachedElement],
526
- folder: Optional[bool] = None,
527
- name: Optional[str] = None,
528
- recursive: Optional[bool] = None,
529
- transcription_worker_version: Optional[Union[str, bool]] = None,
530
- transcription_worker_run: Optional[Union[str, bool]] = None,
531
- type: Optional[str] = None,
532
- with_classes: Optional[bool] = None,
533
- with_corpus: Optional[bool] = None,
534
- with_metadata: Optional[bool] = None,
535
- with_has_children: Optional[bool] = None,
536
- with_zone: Optional[bool] = None,
537
- worker_version: Optional[Union[str, bool]] = None,
538
- worker_run: Optional[Union[str, bool]] = None,
539
- ) -> Union[Iterable[dict], Iterable[CachedElement]]:
527
+ element: Element | CachedElement,
528
+ folder: bool | None = None,
529
+ name: str | None = None,
530
+ recursive: bool | None = None,
531
+ transcription_worker_version: str | bool | None = None,
532
+ transcription_worker_run: str | bool | None = None,
533
+ type: str | None = None,
534
+ with_classes: bool | None = None,
535
+ with_corpus: bool | None = None,
536
+ with_metadata: bool | None = None,
537
+ with_has_children: bool | None = None,
538
+ with_zone: bool | None = None,
539
+ worker_version: str | bool | None = None,
540
+ worker_run: str | bool | None = None,
541
+ ) -> Iterable[dict] | Iterable[CachedElement]:
540
542
  """
541
543
  List parents of an element.
542
544
 
@@ -571,7 +573,7 @@ class ElementMixin(object):
571
573
  or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
572
574
  """
573
575
  assert element and isinstance(
574
- element, (Element, CachedElement)
576
+ element, Element | CachedElement
575
577
  ), "element shouldn't be null and should be an Element or CachedElement"
576
578
  query_params = {}
577
579
  if folder is not None:
@@ -585,7 +587,7 @@ class ElementMixin(object):
585
587
  query_params["recursive"] = recursive
586
588
  if transcription_worker_version is not None:
587
589
  assert isinstance(
588
- transcription_worker_version, (str, bool)
590
+ transcription_worker_version, str | bool
589
591
  ), "transcription_worker_version should be of type str or bool"
590
592
  if isinstance(transcription_worker_version, bool):
591
593
  assert (
@@ -594,7 +596,7 @@ class ElementMixin(object):
594
596
  query_params["transcription_worker_version"] = transcription_worker_version
595
597
  if transcription_worker_run is not None:
596
598
  assert isinstance(
597
- transcription_worker_run, (str, bool)
599
+ transcription_worker_run, str | bool
598
600
  ), "transcription_worker_run should be of type str or bool"
599
601
  if isinstance(transcription_worker_run, bool):
600
602
  assert (
@@ -625,7 +627,7 @@ class ElementMixin(object):
625
627
  query_params["with_zone"] = with_zone
626
628
  if worker_version is not None:
627
629
  assert isinstance(
628
- worker_version, (str, bool)
630
+ worker_version, str | bool
629
631
  ), "worker_version should be of type str or bool"
630
632
  if isinstance(worker_version, bool):
631
633
  assert (
@@ -634,7 +636,7 @@ class ElementMixin(object):
634
636
  query_params["worker_version"] = worker_version
635
637
  if worker_run is not None:
636
638
  assert isinstance(
637
- worker_run, (str, bool)
639
+ worker_run, str | bool
638
640
  ), "worker_run should be of type str or bool"
639
641
  if isinstance(worker_run, bool):
640
642
  assert (
@@ -644,11 +646,14 @@ class ElementMixin(object):
644
646
 
645
647
  if self.use_cache:
646
648
  # Checking that we only received query_params handled by the cache
647
- assert set(query_params.keys()) <= {
648
- "type",
649
- "worker_version",
650
- "worker_run",
651
- }, "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
649
+ assert (
650
+ set(query_params.keys())
651
+ <= {
652
+ "type",
653
+ "worker_version",
654
+ "worker_run",
655
+ }
656
+ ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
652
657
 
653
658
  parent_ids = CachedElement.select(CachedElement.parent_id).where(
654
659
  CachedElement.id == element.id
@@ -1,10 +1,9 @@
1
- # -*- coding: utf-8 -*-
2
1
  """
3
2
  ElementsWorker methods for entities.
4
3
  """
5
4
 
6
5
  from operator import itemgetter
7
- from typing import Dict, List, Optional, TypedDict, Union
6
+ from typing import TypedDict
8
7
 
9
8
  from peewee import IntegrityError
10
9
 
@@ -12,16 +11,13 @@ from arkindex_worker import logger
12
11
  from arkindex_worker.cache import CachedEntity, CachedTranscriptionEntity
13
12
  from arkindex_worker.models import Element, Transcription
14
13
 
15
- Entity = TypedDict(
16
- "Entity",
17
- {
18
- "name": str,
19
- "type_id": str,
20
- "length": int,
21
- "offset": int,
22
- "confidence": Optional[float],
23
- },
24
- )
14
+
15
+ class Entity(TypedDict):
16
+ name: str
17
+ type_id: str
18
+ length: int
19
+ offset: int
20
+ confidence: float | None
25
21
 
26
22
 
27
23
  class MissingEntityType(Exception):
@@ -31,9 +27,9 @@ class MissingEntityType(Exception):
31
27
  """
32
28
 
33
29
 
34
- class EntityMixin(object):
30
+ class EntityMixin:
35
31
  def check_required_entity_types(
36
- self, entity_types: List[str], create_missing: bool = True
32
+ self, entity_types: list[str], create_missing: bool = True
37
33
  ):
38
34
  """Checks that every entity type needed is available in the corpus.
39
35
  Missing ones may be created automatically if needed.
@@ -71,7 +67,7 @@ class EntityMixin(object):
71
67
  self,
72
68
  name: str,
73
69
  type: str,
74
- metas=dict(),
70
+ metas=None,
75
71
  validated=None,
76
72
  ):
77
73
  """
@@ -87,6 +83,7 @@ class EntityMixin(object):
87
83
  assert type and isinstance(
88
84
  type, str
89
85
  ), "type shouldn't be null and should be of type str"
86
+ metas = metas or {}
90
87
  if metas:
91
88
  assert isinstance(metas, dict), "metas should be of type dict"
92
89
  if validated is not None:
@@ -140,8 +137,8 @@ class EntityMixin(object):
140
137
  entity: str,
141
138
  offset: int,
142
139
  length: int,
143
- confidence: Optional[float] = None,
144
- ) -> Optional[Dict[str, Union[str, int]]]:
140
+ confidence: float | None = None,
141
+ ) -> dict[str, str | int] | None:
145
142
  """
146
143
  Create a link between an existing entity and an existing transcription.
147
144
  If cache support is enabled, a `CachedTranscriptionEntity` will also be created.
@@ -211,8 +208,8 @@ class EntityMixin(object):
211
208
  def create_transcription_entities(
212
209
  self,
213
210
  transcription: Transcription,
214
- entities: List[Entity],
215
- ) -> List[Dict[str, str]]:
211
+ entities: list[Entity],
212
+ ) -> list[dict[str, str]]:
216
213
  """
217
214
  Create multiple entities attached to a transcription in a single API request.
218
215
 
@@ -250,13 +247,13 @@ class EntityMixin(object):
250
247
  ), f"Entity at index {index} in entities: Should be of type dict"
251
248
 
252
249
  name = entity.get("name")
253
- assert name and isinstance(
254
- name, str
250
+ assert (
251
+ name and isinstance(name, str)
255
252
  ), f"Entity at index {index} in entities: name shouldn't be null and should be of type str"
256
253
 
257
254
  type_id = entity.get("type_id")
258
- assert type_id and isinstance(
259
- type_id, str
255
+ assert (
256
+ type_id and isinstance(type_id, str)
260
257
  ), f"Entity at index {index} in entities: type_id shouldn't be null and should be of type str"
261
258
 
262
259
  offset = entity.get("offset")
@@ -270,8 +267,9 @@ class EntityMixin(object):
270
267
  ), f"Entity at index {index} in entities: length shouldn't be null and should be a strictly positive integer"
271
268
 
272
269
  confidence = entity.get("confidence")
273
- assert confidence is None or (
274
- isinstance(confidence, float) and 0 <= confidence <= 1
270
+ assert (
271
+ confidence is None
272
+ or (isinstance(confidence, float) and 0 <= confidence <= 1)
275
273
  ), f"Entity at index {index} in entities: confidence should be None or a float in [0..1] range"
276
274
 
277
275
  assert len(entities) == len(
@@ -298,7 +296,7 @@ class EntityMixin(object):
298
296
  def list_transcription_entities(
299
297
  self,
300
298
  transcription: Transcription,
301
- worker_version: Optional[Union[str, bool]] = None,
299
+ worker_version: str | bool | None = None,
302
300
  ):
303
301
  """
304
302
  List existing entities on a transcription
@@ -314,7 +312,7 @@ class EntityMixin(object):
314
312
 
315
313
  if worker_version is not None:
316
314
  assert isinstance(
317
- worker_version, (str, bool)
315
+ worker_version, str | bool
318
316
  ), "worker_version should be of type str or bool"
319
317
 
320
318
  if isinstance(worker_version, bool):
@@ -329,8 +327,8 @@ class EntityMixin(object):
329
327
 
330
328
  def list_corpus_entities(
331
329
  self,
332
- name: Optional[str] = None,
333
- parent: Optional[Element] = None,
330
+ name: str | None = None,
331
+ parent: Element | None = None,
334
332
  ):
335
333
  """
336
334
  List all entities in the worker's corpus