arkindex-base-worker 0.4.0a1__py3-none-any.whl → 0.4.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: arkindex-base-worker
3
- Version: 0.4.0a1
3
+ Version: 0.4.0b1
4
4
  Summary: Base Worker to easily build Arkindex ML workflows
5
5
  Author-email: Teklia <contact@teklia.com>
6
6
  Maintainer-email: Teklia <contact@teklia.com>
@@ -41,17 +41,17 @@ Requires-Python: >=3.10
41
41
  Description-Content-Type: text/markdown
42
42
  License-File: LICENSE
43
43
  Requires-Dist: peewee ~=3.17
44
- Requires-Dist: Pillow ==10.3.0
44
+ Requires-Dist: Pillow ==10.4.0
45
45
  Requires-Dist: python-gnupg ==0.5.2
46
- Requires-Dist: shapely ==2.0.3
46
+ Requires-Dist: shapely ==2.0.5
47
47
  Requires-Dist: teklia-toolbox ==0.1.5
48
48
  Requires-Dist: zstandard ==0.22.0
49
49
  Provides-Extra: docs
50
- Requires-Dist: black ==24.4.0 ; extra == 'docs'
51
- Requires-Dist: mkdocs-material ==9.5.17 ; extra == 'docs'
52
- Requires-Dist: mkdocstrings-python ==1.9.2 ; extra == 'docs'
50
+ Requires-Dist: black ==24.4.2 ; extra == 'docs'
51
+ Requires-Dist: mkdocs-material ==9.5.31 ; extra == 'docs'
52
+ Requires-Dist: mkdocstrings-python ==1.10.7 ; extra == 'docs'
53
53
  Provides-Extra: tests
54
- Requires-Dist: pytest ==8.1.1 ; extra == 'tests'
54
+ Requires-Dist: pytest ==8.3.2 ; extra == 'tests'
55
55
  Requires-Dist: pytest-mock ==3.14.0 ; extra == 'tests'
56
56
  Requires-Dist: pytest-responses ==0.5.1 ; extra == 'tests'
57
57
 
@@ -3,40 +3,40 @@ arkindex_worker/cache.py,sha256=FTlB0coXofn5zTNRTcVIvh709mcw4a1bPGqkwWjKs3w,1124
3
3
  arkindex_worker/image.py,sha256=5ymIGaTm2D7Sp2YYQkbuheuGnx5VJo0_AzYAEIvNGhs,14267
4
4
  arkindex_worker/models.py,sha256=bPQzGZNs5a6z6DEcygsa8T33VOqPlMUbwKzHqlKzwbw,9923
5
5
  arkindex_worker/utils.py,sha256=KXWIACda7D3IpdToaAplLoAgnCK8bKWw7aWUyq-IWUA,7211
6
- arkindex_worker/worker/__init__.py,sha256=3sJ_EPB7yG-kPfgunbm2B7B7DzoeOi5ZNpQwC_3QuZ0,19429
7
- arkindex_worker/worker/base.py,sha256=c9u37W1BNHt5RoQV2ZrYUYv6tBs-CjiSgUAAg7p7GA0,18876
8
- arkindex_worker/worker/classification.py,sha256=JVz-6YEeuavOy7zGfQi4nE_wpj9hwMUZDXTem-hXQY8,10328
9
- arkindex_worker/worker/corpus.py,sha256=ZHAAYE4PRPXqqaZm71wjrsxYETFqU6TAz-3VYgIXzac,1794
10
- arkindex_worker/worker/dataset.py,sha256=roX2IMMNA-icteTtRADiFSZiZSRPClqS62ZPJm9s2JI,2923
11
- arkindex_worker/worker/element.py,sha256=AWK3YJSHWy3j4ajntJloi_2X4zxsgXZ6c6dzphgq3OI,33848
12
- arkindex_worker/worker/entity.py,sha256=suhycfikC9oTPEWmX48_cnvFEw-Wu5zBA8n_00K4KUk,14714
6
+ arkindex_worker/worker/__init__.py,sha256=belqRtbs0raTdFJoQJoGBoDJkUOrEE3wyXv90f85bTs,19760
7
+ arkindex_worker/worker/base.py,sha256=JStHpwSP3bis9LLvV2C2n6GTWtLUVIDA9JPgPJEt17o,18717
8
+ arkindex_worker/worker/classification.py,sha256=4YAY4weF6kMSMsoYiz6oia3SN21PzRR1bAdhMJCGBbw,10361
9
+ arkindex_worker/worker/corpus.py,sha256=s9bCxOszJMwRq1WWAmKjWq888mjDfbaJ18Wo7h-rNOw,1827
10
+ arkindex_worker/worker/dataset.py,sha256=UXElhhARca9m7Himp-yxD5dAqWbdxDKWOUJUGgeCZXI,2934
11
+ arkindex_worker/worker/element.py,sha256=kMaJNXEfZbFBK4YYc3XLqyGvPyNvJs7mJG2T_a1c7D0,34294
12
+ arkindex_worker/worker/entity.py,sha256=BbQp56kxTPmOQI482TUFZ8KOXISj7KtQAyHRT0CmedM,14744
13
13
  arkindex_worker/worker/image.py,sha256=t_Az6IGnj0EZyvcA4XxfPikOUjn_pztgsyxTkFZhaXU,621
14
- arkindex_worker/worker/metadata.py,sha256=Bouuc_JaXogKykVXOTKDVP3tX--OUQeHoazxIGrGrJI,6702
15
- arkindex_worker/worker/task.py,sha256=cz3wJNPgogZv1lm_3lm7WScitQtYQtL6H6I7Xokq208,1475
16
- arkindex_worker/worker/training.py,sha256=hkwCBjVE4bByXzHUmCZF73Bl5JxARdXWjYgFE6ydAT0,10749
17
- arkindex_worker/worker/transcription.py,sha256=6R7ofcGnNqX4rjT0kRKIE-G9FHq2TJ1tfztNM5sTqYE,20464
18
- arkindex_worker/worker/version.py,sha256=cs2pdlDxpKRO2Oldvcu54w-D_DQhf1cdeEt4tKX_QYs,1927
14
+ arkindex_worker/worker/metadata.py,sha256=PFO0oJc8N91HIpj4yHLscwGW5UFRXtuyQYfEXW27-WQ,6724
15
+ arkindex_worker/worker/task.py,sha256=1O9zrWXxe3na3TOcoHX5Pxn1875v7EU08BSsCPnb62g,1519
16
+ arkindex_worker/worker/training.py,sha256=qnBFEk11JOWWPLTbjF-lZ9iFBdTPpQzZAzQ9a03J1j4,10874
17
+ arkindex_worker/worker/transcription.py,sha256=9TC3E6zu_CnQKWsaTAzI83TrSfMuzh3KSMOCLdbEG18,20497
18
+ arkindex_worker/worker/version.py,sha256=JIT7OI3Mo7RPkNrjOB9hfqrsG-FYygz_zi4l8PbkuAo,1960
19
19
  hooks/pre_gen_project.py,sha256=xQJERv3vv9VzIqcBHI281eeWLWREXUF4mMw7PvJHHXM,269
20
20
  tests/__init__.py,sha256=6aeTMHf4q_dKY4jIZWg1KT70VKaLvVlzCxh-Uu_cWiQ,241
21
21
  tests/conftest.py,sha256=-ZQTV4rg7TgW84-5Ioqndqv8byNILfDOpyUt8wecEiI,21967
22
- tests/test_base_worker.py,sha256=qG45O3nPbASXN5a5RadXU1BAXn3EIaTK6Hvjj3s4Ozs,24292
22
+ tests/test_base_worker.py,sha256=LdFV0LFdNU2IOyEKlX59MB1kuyxHCuhy4Tm7eE_iPiU,24281
23
23
  tests/test_cache.py,sha256=ii0gyr0DrG7ChEs7pmT8hMdSguAOAcCze4bRMiFQxuk,10640
24
24
  tests/test_dataset_worker.py,sha256=d9HG36qnO5HXu9vQ0UTBvdTSRR21FVq1FNoXM-vZbPk,22105
25
25
  tests/test_element.py,sha256=2G9M15TLxQRmvrWM9Kw2ucnElh4kSv_oF_5FYwwAxTY,13181
26
26
  tests/test_image.py,sha256=Fs9vKYgQ7mEFylbzI4YIO_JyOLeAcs-WxUXpzewxCd8,16188
27
- tests/test_merge.py,sha256=Q4zCbtZbe0wBfqE56gvAD06c6pDuhqnjKaioFqIgAQw,8331
27
+ tests/test_merge.py,sha256=FMdpsm_ncHNmIvOrJ1vcwlyn8o9-SPcpFTcbAsXwK-w,8320
28
28
  tests/test_utils.py,sha256=vpeHMeL7bJQonv5ZEbJmlJikqVKn5VWlVEbvmYFzDYA,1650
29
29
  tests/test_elements_worker/__init__.py,sha256=Fh4nkbbyJSMv_VtjQxnWrOqTnxXaaWI8S9WU0VrzCHs,179
30
30
  tests/test_elements_worker/test_classifications.py,sha256=DYRKhPpplFp144GCXKyFG1hz4Ra9vk5FiAN6dhfMP6k,25511
31
31
  tests/test_elements_worker/test_cli.py,sha256=a23i1pUDbXi23MUtbWwGEcLLrmc_YlrbDgOG3h66wLM,2620
32
32
  tests/test_elements_worker/test_corpus.py,sha256=c_LUHvkJIYgk_wXF06VQPNOoWfiZ06XpjOXrJ7MRiBc,4479
33
33
  tests/test_elements_worker/test_dataset.py,sha256=lSXqubhg1EEq2Y2goE8Y2RYaqIpM9Iejq6fGNW2BczU,11411
34
- tests/test_elements_worker/test_elements.py,sha256=2_kdeo99biCH3Uez6HB8ltS_iIizZ7ir5uOkFjIXfjM,84812
34
+ tests/test_elements_worker/test_elements.py,sha256=HH8jUU4xHp5gXcrGJLQlo4kLFh7oYfMxO3QQEYo2itg,84885
35
35
  tests/test_elements_worker/test_entities.py,sha256=jirb_IKAMqMhwxeDgjO-rsr1fTP9GdXwuyhncUjCJFM,33494
36
36
  tests/test_elements_worker/test_image.py,sha256=_E3UGdDOwTo1MW5KMS81PrdeSPBPWinWYoQPNy2F9Ro,2077
37
37
  tests/test_elements_worker/test_metadata.py,sha256=-cZhlVAh4o2uRnHz8fPf_thfavRnJrtJYN_p4BmHISU,17566
38
38
  tests/test_elements_worker/test_task.py,sha256=7Sr3fbjdgWUXJUhJEiC9CwnbhQIQX3rCInmHMIrmA38,5573
39
- tests/test_elements_worker/test_training.py,sha256=wVYWdMdeSA6T2XyhH5AJJNGemYq3LOViiZvj0dblACA,9468
39
+ tests/test_elements_worker/test_training.py,sha256=Qxi9EzGr_uKcn2Fh5aE6jNrq1K8QKLiOiSew4upASPs,8721
40
40
  tests/test_elements_worker/test_transcriptions.py,sha256=7HDkIW8IDK7pKAfpSdAPB7YOyKyeBJTn2_alvVK46SA,72411
41
41
  tests/test_elements_worker/test_worker.py,sha256=AwdP8uSXNQ_SJavXxJV2s3_J3OiCafShVjMV1dgt4xo,17162
42
42
  worker-demo/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -44,8 +44,8 @@ worker-demo/tests/conftest.py,sha256=XzNMNeg6pmABUAH8jN6eZTlZSFGLYjS3-DTXjiRN6Yc
44
44
  worker-demo/tests/test_worker.py,sha256=3DLd4NRK4bfyatG5P_PK4k9P9tJHx9XQq5_ryFEEFVg,304
45
45
  worker-demo/worker_demo/__init__.py,sha256=2BPomV8ZMNf3YXJgloatKeHQCE6QOkwmsHGkO6MkQuM,125
46
46
  worker-demo/worker_demo/worker.py,sha256=Rt-DjWa5iBP08k58NDZMfeyPuFbtNcbX6nc5jFX7GNo,440
47
- arkindex_base_worker-0.4.0a1.dist-info/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
48
- arkindex_base_worker-0.4.0a1.dist-info/METADATA,sha256=PBTlbhWTCvvkkcGqQew6yvJIdncf9mKZ71yI_QSX2iM,3269
49
- arkindex_base_worker-0.4.0a1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
50
- arkindex_base_worker-0.4.0a1.dist-info/top_level.txt,sha256=58NuslgxQC2vT4DiqZEgO4JqJRrYa2yeNI9QvkbfGQU,40
51
- arkindex_base_worker-0.4.0a1.dist-info/RECORD,,
47
+ arkindex_base_worker-0.4.0b1.dist-info/LICENSE,sha256=NVshRi1efwVezMfW7xXYLrdDr2Li1AfwfGOd5WuH1kQ,1063
48
+ arkindex_base_worker-0.4.0b1.dist-info/METADATA,sha256=02rPRlcFlghY1Trb-_trpdCCMME1A9FmPzrY8wzzLDg,3270
49
+ arkindex_base_worker-0.4.0b1.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
50
+ arkindex_base_worker-0.4.0b1.dist-info/top_level.txt,sha256=58NuslgxQC2vT4DiqZEgO4JqJRrYa2yeNI9QvkbfGQU,40
51
+ arkindex_base_worker-0.4.0b1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.43.0)
2
+ Generator: setuptools (72.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -83,7 +83,20 @@ class ElementsWorker(
83
83
  """
84
84
  super().__init__(description, support_cache)
85
85
 
86
- # Add mandatory argument to process elements
86
+ self.classes = {}
87
+
88
+ self.entity_types = {}
89
+ """Known and available entity types in processed corpus
90
+ """
91
+
92
+ self.corpus_types = {}
93
+ """Known and available element types in processed corpus
94
+ """
95
+
96
+ self._worker_version_cache = {}
97
+
98
+ def add_arguments(self):
99
+ """Define specific ``argparse`` arguments for this worker"""
87
100
  self.parser.add_argument(
88
101
  "--elements-list",
89
102
  help="JSON elements list to use",
@@ -97,14 +110,6 @@ class ElementsWorker(
97
110
  help="One or more Arkindex element ID",
98
111
  )
99
112
 
100
- self.classes = {}
101
-
102
- self.entity_types = {}
103
- """Known and available entity types in processed corpus
104
- """
105
-
106
- self._worker_version_cache = {}
107
-
108
113
  def list_elements(self) -> Iterable[CachedElement] | list[str]:
109
114
  """
110
115
  List the elements to be processed, either from the CLI arguments or
@@ -222,7 +227,9 @@ class ElementsWorker(
222
227
  element = item
223
228
  else:
224
229
  # Load element using the Arkindex API
225
- element = Element(**self.request("RetrieveElement", id=item))
230
+ element = Element(
231
+ **self.api_client.request("RetrieveElement", id=item)
232
+ )
226
233
 
227
234
  logger.info(f"Processing {element} ({i}/{count})")
228
235
 
@@ -301,7 +308,7 @@ class ElementsWorker(
301
308
  assert isinstance(state, ActivityState), "state should be an ActivityState"
302
309
 
303
310
  try:
304
- self.request(
311
+ self.api_client.request(
305
312
  "UpdateWorkerActivity",
306
313
  id=self.worker_run_id,
307
314
  body={
@@ -376,6 +383,8 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
376
383
  # Set as an instance variable as dataset workers might use it to easily extract its content
377
384
  self.downloaded_dataset_artifact: Path | None = None
378
385
 
386
+ def add_arguments(self):
387
+ """Define specific ``argparse`` arguments for this worker"""
379
388
  self.parser.add_argument(
380
389
  "--set",
381
390
  type=check_dataset_set,
@@ -472,7 +481,7 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
472
481
  # Retrieving dataset information is not already cached
473
482
  if dataset_id not in datasets:
474
483
  datasets[dataset_id] = Dataset(
475
- **self.request("RetrieveDataset", id=dataset_id)
484
+ **self.api_client.request("RetrieveDataset", id=dataset_id)
476
485
  )
477
486
 
478
487
  yield Set(name=set_name, dataset=datasets[dataset_id])
@@ -231,7 +231,7 @@ class BaseWorker:
231
231
  logger.debug("Debug output enabled")
232
232
 
233
233
  # Load worker run information
234
- worker_run = self.request("RetrieveWorkerRun", id=self.worker_run_id)
234
+ worker_run = self.api_client.request("RetrieveWorkerRun", id=self.worker_run_id)
235
235
 
236
236
  # Load process information
237
237
  self.process_information = worker_run["process"]
@@ -290,7 +290,7 @@ class BaseWorker:
290
290
  if self.support_cache and self.args.database is not None:
291
291
  self.use_cache = True
292
292
  elif self.support_cache and self.task_id:
293
- task = self.request("RetrieveTaskFromAgent", id=self.task_id)
293
+ task = self.api_client.request("RetrieveTask", id=self.task_id)
294
294
  self.task_parents = task["parents"]
295
295
  paths = self.find_parents_file_paths(Path("db.sqlite"))
296
296
  self.use_cache = len(paths) > 0
@@ -331,7 +331,7 @@ class BaseWorker:
331
331
 
332
332
  # Load from the backend
333
333
  try:
334
- resp = self.request("RetrieveSecret", name=str(name))
334
+ resp = self.api_client.request("RetrieveSecret", name=str(name))
335
335
  secret = resp["content"]
336
336
  logging.info(f"Loaded API secret {name}")
337
337
  except ErrorResponse as e:
@@ -471,12 +471,6 @@ class BaseWorker:
471
471
  # Clean up
472
472
  shutil.rmtree(base_extracted_path)
473
473
 
474
- def request(self, *args, **kwargs):
475
- """
476
- Wrapper around the ``ArkindexClient.request`` method.
477
- """
478
- return self.api_client.request(*args, **kwargs)
479
-
480
474
  def add_arguments(self):
481
475
  """Override this method to add ``argparse`` arguments to this worker"""
482
476
 
@@ -39,7 +39,7 @@ class ClassificationMixin:
39
39
  if ml_class_id is None:
40
40
  logger.info(f"Creating ML class {ml_class} on corpus {self.corpus_id}")
41
41
  try:
42
- response = self.request(
42
+ response = self.api_client.request(
43
43
  "CreateMLClass", id=self.corpus_id, body={"name": ml_class}
44
44
  )
45
45
  ml_class_id = self.classes[ml_class] = response["id"]
@@ -119,7 +119,7 @@ class ClassificationMixin:
119
119
  )
120
120
  return
121
121
  try:
122
- created = self.request(
122
+ created = self.api_client.request(
123
123
  "CreateClassification",
124
124
  body={
125
125
  "element": str(element.id),
@@ -220,7 +220,7 @@ class ClassificationMixin:
220
220
  )
221
221
  return
222
222
 
223
- created_cls = self.request(
223
+ created_cls = self.api_client.request(
224
224
  "CreateClassifications",
225
225
  body={
226
226
  "parent": str(element.id),
@@ -63,7 +63,9 @@ class CorpusMixin:
63
63
  # Download latest export
64
64
  export_id: str = exports[0]["id"]
65
65
  logger.info(f"Downloading export ({export_id})...")
66
- export: _TemporaryFileWrapper = self.request("DownloadExport", id=export_id)
66
+ export: _TemporaryFileWrapper = self.api_client.request(
67
+ "DownloadExport", id=export_id
68
+ )
67
69
  logger.info(f"Downloaded export ({export_id}) @ `{export.name}`")
68
70
 
69
71
  return export
@@ -93,7 +93,7 @@ class DatasetMixin:
93
93
  logger.warning("Cannot update dataset as this worker is in read-only mode")
94
94
  return
95
95
 
96
- updated_dataset = self.request(
96
+ updated_dataset = self.api_client.request(
97
97
  "PartialUpdateDataset",
98
98
  id=dataset.id,
99
99
  body={"state": state.value},
@@ -31,6 +31,21 @@ class MissingTypeError(Exception):
31
31
 
32
32
 
33
33
  class ElementMixin:
34
+ def list_corpus_types(self):
35
+ """
36
+ Loads available element types in corpus.
37
+ """
38
+ self.corpus_types = {
39
+ element_type["slug"]: element_type
40
+ for element_type in self.api_client.request(
41
+ "RetrieveCorpus", id=self.corpus_id
42
+ )["types"]
43
+ }
44
+ count = len(self.corpus_types)
45
+ logger.info(
46
+ f'Loaded {count} element type{"s"[:count>1]} in corpus ({self.corpus_id}).'
47
+ )
48
+
34
49
  @unsupported_cache
35
50
  def create_required_types(self, element_types: list[ElementType]):
36
51
  """Creates given element types in the corpus.
@@ -38,7 +53,7 @@ class ElementMixin:
38
53
  :param element_types: The missing element types to create.
39
54
  """
40
55
  for element_type in element_types:
41
- self.request(
56
+ self.api_client.request(
42
57
  "CreateElementType",
43
58
  body={
44
59
  "slug": element_type.slug,
@@ -66,10 +81,10 @@ class ElementMixin:
66
81
  isinstance(slug, str) for slug in type_slugs
67
82
  ), "Element type slugs must be strings."
68
83
 
69
- corpus = self.request("RetrieveCorpus", id=self.corpus_id)
70
- available_slugs = {element_type["slug"] for element_type in corpus["types"]}
71
- missing_slugs = set(type_slugs) - available_slugs
84
+ if not self.corpus_types:
85
+ self.list_corpus_types()
72
86
 
87
+ missing_slugs = set(type_slugs) - set(self.corpus_types)
73
88
  if missing_slugs:
74
89
  if create_missing:
75
90
  self.create_required_types(
@@ -79,7 +94,7 @@ class ElementMixin:
79
94
  )
80
95
  else:
81
96
  raise MissingTypeError(
82
- f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in the {corpus["name"]} corpus ({corpus["id"]}).'
97
+ f'Element type(s) {", ".join(sorted(missing_slugs))} were not found in corpus ({self.corpus_id}).'
83
98
  )
84
99
 
85
100
  return True
@@ -145,7 +160,7 @@ class ElementMixin:
145
160
  logger.warning("Cannot create element as this worker is in read-only mode")
146
161
  return
147
162
 
148
- sub_element = self.request(
163
+ sub_element = self.api_client.request(
149
164
  "CreateElement",
150
165
  body={
151
166
  "type": type,
@@ -243,7 +258,7 @@ class ElementMixin:
243
258
  logger.warning("Cannot create elements as this worker is in read-only mode")
244
259
  return
245
260
 
246
- created_ids = self.request(
261
+ created_ids = self.api_client.request(
247
262
  "CreateElements",
248
263
  id=parent.id,
249
264
  body={
@@ -311,7 +326,7 @@ class ElementMixin:
311
326
  logger.warning("Cannot link elements as this worker is in read-only mode")
312
327
  return
313
328
 
314
- return self.request(
329
+ return self.api_client.request(
315
330
  "CreateElementParent",
316
331
  parent=parent.id,
317
332
  child=child.id,
@@ -383,7 +398,7 @@ class ElementMixin:
383
398
  logger.warning("Cannot update element as this worker is in read-only mode")
384
399
  return
385
400
 
386
- updated_element = self.request(
401
+ updated_element = self.api_client.request(
387
402
  "PartialUpdateElement",
388
403
  id=element.id,
389
404
  body=kwargs,
@@ -48,6 +48,7 @@ class EntityMixin:
48
48
  if not self.entity_types:
49
49
  # Load entity_types of corpus
50
50
  self.list_corpus_entity_types()
51
+
51
52
  for entity_type in entity_types:
52
53
  # Do nothing if type already exists
53
54
  if entity_type in self.entity_types:
@@ -60,7 +61,7 @@ class EntityMixin:
60
61
  )
61
62
 
62
63
  # Create type if non-existent
63
- self.entity_types[entity_type] = self.request(
64
+ self.entity_types[entity_type] = self.api_client.request(
64
65
  "CreateEntityType",
65
66
  body={
66
67
  "name": entity_type,
@@ -106,7 +107,7 @@ class EntityMixin:
106
107
  entity_type_id = self.entity_types.get(type)
107
108
  assert entity_type_id, f"Entity type `{type}` not found in the corpus."
108
109
 
109
- entity = self.request(
110
+ entity = self.api_client.request(
110
111
  "CreateEntity",
111
112
  body={
112
113
  "name": name,
@@ -188,7 +189,7 @@ class EntityMixin:
188
189
  if confidence is not None:
189
190
  body["confidence"] = confidence
190
191
 
191
- transcription_ent = self.request(
192
+ transcription_ent = self.api_client.request(
192
193
  "CreateTranscriptionEntity",
193
194
  id=transcription.id,
194
195
  body=body,
@@ -289,7 +290,7 @@ class EntityMixin:
289
290
  )
290
291
  return
291
292
 
292
- created_ids = self.request(
293
+ created_ids = self.api_client.request(
293
294
  "CreateTranscriptionEntities",
294
295
  id=transcription.id,
295
296
  body={
@@ -385,9 +386,7 @@ class EntityMixin:
385
386
  f'Loaded {count} entit{"ies" if count > 1 else "y"} in corpus ({self.corpus_id})'
386
387
  )
387
388
 
388
- def list_corpus_entity_types(
389
- self,
390
- ):
389
+ def list_corpus_entity_types(self):
391
390
  """
392
391
  Loads available entity types in corpus.
393
392
  """
@@ -93,7 +93,7 @@ class MetaDataMixin:
93
93
  logger.warning("Cannot create metadata as this worker is in read-only mode")
94
94
  return
95
95
 
96
- metadata = self.request(
96
+ metadata = self.api_client.request(
97
97
  "CreateMetaData",
98
98
  id=element.id,
99
99
  body={
@@ -168,7 +168,7 @@ class MetaDataMixin:
168
168
  logger.warning("Cannot create metadata as this worker is in read-only mode")
169
169
  return
170
170
 
171
- created_metadata_list = self.request(
171
+ created_metadata_list = self.api_client.request(
172
172
  "CreateMetaDataBulk",
173
173
  id=element.id,
174
174
  body={
@@ -22,7 +22,7 @@ class TaskMixin:
22
22
  task_id, uuid.UUID
23
23
  ), "task_id shouldn't be null and should be an UUID"
24
24
 
25
- results = self.request("ListArtifacts", id=task_id)
25
+ results = self.api_client.request("ListArtifacts", id=task_id)
26
26
 
27
27
  return map(Artifact, results)
28
28
 
@@ -43,4 +43,6 @@ class TaskMixin:
43
43
  artifact, Artifact
44
44
  ), "artifact shouldn't be null and should be an Artifact"
45
45
 
46
- return self.request("DownloadArtifact", id=task_id, path=artifact.path)
46
+ return self.api_client.request(
47
+ "DownloadArtifact", id=task_id, path=artifact.path
48
+ )
@@ -185,7 +185,7 @@ class TrainingMixin:
185
185
  assert not self.model_version, "A model version has already been created."
186
186
 
187
187
  configuration = configuration or {}
188
- self.model_version = self.request(
188
+ self.model_version = self.api_client.request(
189
189
  "CreateModelVersion",
190
190
  id=model_id,
191
191
  body=build_clean_payload(
@@ -217,7 +217,7 @@ class TrainingMixin:
217
217
  :param parent: ID of the parent model version
218
218
  """
219
219
  assert self.model_version, "No model version has been created yet."
220
- self.model_version = self.request(
220
+ self.model_version = self.api_client.request(
221
221
  "UpdateModelVersion",
222
222
  id=self.model_version["id"],
223
223
  body=build_clean_payload(
@@ -273,41 +273,44 @@ class TrainingMixin:
273
273
  """
274
274
  assert self.model_version, "You must create the model version and upload its archive before validating it."
275
275
  try:
276
- self.model_version = self.request(
277
- "ValidateModelVersion",
276
+ self.model_version = self.api_client.request(
277
+ "PartialUpdateModelVersion",
278
278
  id=self.model_version["id"],
279
279
  body={
280
+ "state": "available",
280
281
  "size": size,
281
282
  "hash": hash,
282
283
  "archive_hash": archive_hash,
283
284
  },
284
285
  )
285
286
  except ErrorResponse as e:
286
- # Temporary fix while waiting for `ValidateModelVersion` refactoring as it can
287
- # return errors even when the model version is properly validated
288
- if e.status_code in [403, 500]:
289
- logger.warning(
290
- f'An error occurred while validating model version {self.model_version["id"]}, please check its status.'
291
- )
292
- return
293
-
294
- if e.status_code != 409:
287
+ model_version = e.content
288
+ if not model_version or "id" not in model_version:
295
289
  raise e
296
290
 
297
291
  logger.warning(
298
292
  f"An available model version exists with hash {hash}, using it instead of the pending version."
299
293
  )
300
294
  pending_version_id = self.model_version["id"]
301
- self.model_version = getattr(e, "content", None)
302
- assert self.model_version is not None, "An unexpected error occurred."
303
-
304
295
  logger.warning("Removing the pending model version.")
305
296
  try:
306
- self.request("DestroyModelVersion", id=pending_version_id)
297
+ self.api_client.request("DestroyModelVersion", id=pending_version_id)
307
298
  except ErrorResponse as e:
308
299
  msg = getattr(e, "content", str(e))
309
300
  logger.error(
310
301
  f"An error occurred removing the pending version {pending_version_id}: {msg}."
311
302
  )
312
303
 
304
+ logger.info("Retrieving the existing model version.")
305
+ existing_version_id = model_version["id"].pop()
306
+ try:
307
+ self.model_version = self.api_client.request(
308
+ "RetrieveModelVersion", id=existing_version_id
309
+ )
310
+ except ErrorResponse as e:
311
+ logger.error(
312
+ f"An error occurred retrieving the existing version {existing_version_id}: {e.status_code} - {e.content}."
313
+ )
314
+ raise
315
+
313
316
  logger.info(f"Model version {self.model_version['id']} is now available.")
@@ -77,7 +77,7 @@ class TranscriptionMixin:
77
77
  )
78
78
  return
79
79
 
80
- created = self.request(
80
+ created = self.api_client.request(
81
81
  "CreateTranscription",
82
82
  id=element.id,
83
83
  body={
@@ -171,7 +171,7 @@ class TranscriptionMixin:
171
171
  )
172
172
  return
173
173
 
174
- created_trs = self.request(
174
+ created_trs = self.api_client.request(
175
175
  "CreateTranscriptions",
176
176
  body={
177
177
  "worker_run_id": self.worker_run_id,
@@ -291,7 +291,7 @@ class TranscriptionMixin:
291
291
  )
292
292
  return
293
293
 
294
- annotations = self.request(
294
+ annotations = self.api_client.request(
295
295
  "CreateElementTranscriptions",
296
296
  id=element.id,
297
297
  body={
@@ -34,7 +34,9 @@ class WorkerVersionMixin:
34
34
  if worker_version_id in self._worker_version_cache:
35
35
  return self._worker_version_cache[worker_version_id]
36
36
 
37
- worker_version = self.request("RetrieveWorkerVersion", id=worker_version_id)
37
+ worker_version = self.api_client.request(
38
+ "RetrieveWorkerVersion", id=worker_version_id
39
+ )
38
40
  self._worker_version_cache[worker_version_id] = worker_version
39
41
 
40
42
  return worker_version
tests/test_base_worker.py CHANGED
@@ -658,7 +658,7 @@ def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error
658
658
  def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_path):
659
659
  responses.add(
660
660
  responses.GET,
661
- "http://testserver/api/v1/task/my_task/from-agent/",
661
+ "http://testserver/api/v1/task/my_task/",
662
662
  status=200,
663
663
  json={"parents": ["first", "second", "third"]},
664
664
  )
@@ -22,6 +22,24 @@ from tests import CORPUS_ID
22
22
  from . import BASE_API_CALLS
23
23
 
24
24
 
25
+ def test_list_corpus_types(responses, mock_elements_worker):
26
+ responses.add(
27
+ responses.GET,
28
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
29
+ json={
30
+ "id": CORPUS_ID,
31
+ "types": [{"slug": "folder"}, {"slug": "page"}],
32
+ },
33
+ )
34
+
35
+ mock_elements_worker.list_corpus_types()
36
+
37
+ assert mock_elements_worker.corpus_types == {
38
+ "folder": {"slug": "folder"},
39
+ "page": {"slug": "page"},
40
+ }
41
+
42
+
25
43
  def test_check_required_types_argument_types(mock_elements_worker):
26
44
  with pytest.raises(
27
45
  AssertionError, match="At least one element type slug is required."
@@ -32,17 +50,11 @@ def test_check_required_types_argument_types(mock_elements_worker):
32
50
  mock_elements_worker.check_required_types("lol", 42)
33
51
 
34
52
 
35
- def test_check_required_types(responses, mock_elements_worker):
36
- responses.add(
37
- responses.GET,
38
- f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
39
- json={
40
- "id": CORPUS_ID,
41
- "name": "Some Corpus",
42
- "types": [{"slug": "folder"}, {"slug": "page"}],
43
- },
44
- )
45
- mock_elements_worker.setup_api_client()
53
+ def test_check_required_types(mock_elements_worker):
54
+ mock_elements_worker.corpus_types = {
55
+ "folder": {"slug": "folder"},
56
+ "page": {"slug": "page"},
57
+ }
46
58
 
47
59
  assert mock_elements_worker.check_required_types("page")
48
60
  assert mock_elements_worker.check_required_types("page", "folder")
@@ -50,22 +62,18 @@ def test_check_required_types(responses, mock_elements_worker):
50
62
  with pytest.raises(
51
63
  MissingTypeError,
52
64
  match=re.escape(
53
- "Element type(s) act, text_line were not found in the Some Corpus corpus (11111111-1111-1111-1111-111111111111)."
65
+ "Element type(s) act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
54
66
  ),
55
67
  ):
56
68
  assert mock_elements_worker.check_required_types("page", "text_line", "act")
57
69
 
58
70
 
59
71
  def test_create_missing_types(responses, mock_elements_worker):
60
- responses.add(
61
- responses.GET,
62
- f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
63
- json={
64
- "id": CORPUS_ID,
65
- "name": "Some Corpus",
66
- "types": [{"slug": "folder"}, {"slug": "page"}],
67
- },
68
- )
72
+ mock_elements_worker.corpus_types = {
73
+ "folder": {"slug": "folder"},
74
+ "page": {"slug": "page"},
75
+ }
76
+
69
77
  responses.add(
70
78
  responses.POST,
71
79
  "http://testserver/api/v1/elements/type/",
@@ -94,7 +102,6 @@ def test_create_missing_types(responses, mock_elements_worker):
94
102
  )
95
103
  ],
96
104
  )
97
- mock_elements_worker.setup_api_client()
98
105
 
99
106
  assert mock_elements_worker.check_required_types(
100
107
  "page", "text_line", "act", create_missing=True
@@ -179,44 +179,12 @@ def test_validate_model_version_not_created(mock_training_worker):
179
179
  mock_training_worker.validate_model_version(hash="a", size=1, archive_hash="b")
180
180
 
181
181
 
182
- @pytest.mark.parametrize("status_code", [403, 500])
183
- def test_validate_model_version_catch_errors(
184
- mocker, mock_training_worker, caplog, status_code
185
- ):
186
- mocker.patch(
187
- "arkindex.client.ArkindexClient.request.retry.retry", return_value=False
188
- )
189
-
190
- mock_training_worker.model_version = {"id": "model_version_id"}
191
- args = {
192
- "hash": "hash",
193
- "archive_hash": "archive_hash",
194
- "size": 30,
195
- }
196
- mock_training_worker.api_client.add_error_response(
197
- "ValidateModelVersion",
198
- id="model_version_id",
199
- status_code=status_code,
200
- body=args,
201
- )
202
-
203
- mock_training_worker.validate_model_version(**args)
204
- assert mock_training_worker.model_version == {"id": "model_version_id"}
205
- assert [
206
- (level, message)
207
- for module, level, message in caplog.record_tuples
208
- if module == "arkindex_worker"
209
- ] == [
210
- (
211
- logging.WARNING,
212
- "An error occurred while validating model version model_version_id, please check its status.",
213
- ),
214
- ]
215
-
216
-
217
182
  @pytest.mark.parametrize("deletion_failed", [True, False])
218
183
  def test_validate_model_version_hash_conflict(
219
- mock_training_worker, default_model_version, caplog, deletion_failed
184
+ mock_training_worker,
185
+ default_model_version,
186
+ caplog,
187
+ deletion_failed,
220
188
  ):
221
189
  mock_training_worker.model_version = {"id": "another_id"}
222
190
  args = {
@@ -225,11 +193,11 @@ def test_validate_model_version_hash_conflict(
225
193
  "size": 30,
226
194
  }
227
195
  mock_training_worker.api_client.add_error_response(
228
- "ValidateModelVersion",
196
+ "PartialUpdateModelVersion",
229
197
  id="another_id",
230
198
  status_code=409,
231
- body=args,
232
- content=default_model_version,
199
+ body={"state": "available", **args},
200
+ content={"id": ["model_version_id"]},
233
201
  )
234
202
  if deletion_failed:
235
203
  mock_training_worker.api_client.add_error_response(
@@ -244,6 +212,11 @@ def test_validate_model_version_hash_conflict(
244
212
  id="another_id",
245
213
  response="No content",
246
214
  )
215
+ mock_training_worker.api_client.add_response(
216
+ "RetrieveModelVersion",
217
+ id="model_version_id",
218
+ response=default_model_version,
219
+ )
247
220
 
248
221
  mock_training_worker.validate_model_version(**args)
249
222
  assert mock_training_worker.model_version == default_model_version
@@ -266,6 +239,7 @@ def test_validate_model_version_hash_conflict(
266
239
  ),
267
240
  (logging.WARNING, "Removing the pending model version."),
268
241
  *error_msg,
242
+ (logging.INFO, "Retrieving the existing model version."),
269
243
  (logging.INFO, "Model version model_version_id is now available."),
270
244
  ]
271
245
 
@@ -278,9 +252,9 @@ def test_validate_model_version(mock_training_worker, default_model_version, cap
278
252
  "size": 30,
279
253
  }
280
254
  mock_training_worker.api_client.add_response(
281
- "ValidateModelVersion",
255
+ "PartialUpdateModelVersion",
282
256
  id="model_version_id",
283
- body=args,
257
+ body={"state": "available", **args},
284
258
  response=default_model_version,
285
259
  )
286
260
 
tests/test_merge.py CHANGED
@@ -161,7 +161,7 @@ def test_merge_from_worker(
161
161
  """
162
162
  responses.add(
163
163
  responses.GET,
164
- "http://testserver/api/v1/task/my_task/from-agent/",
164
+ "http://testserver/api/v1/task/my_task/",
165
165
  status=200,
166
166
  json={"parents": ["first", "second"]},
167
167
  )