arkindex-base-worker 0.3.6rc4__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. arkindex_base_worker-0.3.7.dist-info/LICENSE +21 -0
  2. arkindex_base_worker-0.3.7.dist-info/METADATA +77 -0
  3. arkindex_base_worker-0.3.7.dist-info/RECORD +47 -0
  4. {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/WHEEL +1 -1
  5. {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/top_level.txt +2 -0
  6. arkindex_worker/cache.py +14 -0
  7. arkindex_worker/image.py +29 -19
  8. arkindex_worker/models.py +14 -2
  9. arkindex_worker/utils.py +17 -3
  10. arkindex_worker/worker/__init__.py +122 -125
  11. arkindex_worker/worker/base.py +24 -24
  12. arkindex_worker/worker/classification.py +18 -25
  13. arkindex_worker/worker/dataset.py +24 -18
  14. arkindex_worker/worker/element.py +100 -19
  15. arkindex_worker/worker/entity.py +35 -4
  16. arkindex_worker/worker/metadata.py +21 -11
  17. arkindex_worker/worker/training.py +13 -0
  18. arkindex_worker/worker/transcription.py +45 -5
  19. arkindex_worker/worker/version.py +22 -0
  20. hooks/pre_gen_project.py +3 -0
  21. tests/conftest.py +16 -8
  22. tests/test_base_worker.py +0 -6
  23. tests/test_dataset_worker.py +291 -409
  24. tests/test_elements_worker/test_classifications.py +365 -539
  25. tests/test_elements_worker/test_cli.py +1 -1
  26. tests/test_elements_worker/test_dataset.py +97 -116
  27. tests/test_elements_worker/test_elements.py +354 -76
  28. tests/test_elements_worker/test_entities.py +22 -2
  29. tests/test_elements_worker/test_metadata.py +53 -27
  30. tests/test_elements_worker/test_training.py +35 -0
  31. tests/test_elements_worker/test_transcriptions.py +149 -16
  32. tests/test_elements_worker/test_worker.py +19 -6
  33. tests/test_image.py +37 -0
  34. tests/test_utils.py +23 -1
  35. worker-demo/tests/__init__.py +0 -0
  36. worker-demo/tests/conftest.py +32 -0
  37. worker-demo/tests/test_worker.py +12 -0
  38. worker-demo/worker_demo/__init__.py +6 -0
  39. worker-demo/worker_demo/worker.py +19 -0
  40. arkindex_base_worker-0.3.6rc4.dist-info/METADATA +0 -47
  41. arkindex_base_worker-0.3.6rc4.dist-info/RECORD +0 -40
arkindex_worker/worker/__init__.py
@@ -1,31 +1,31 @@
 """
 Base classes to implement Arkindex workers.
 """
+
 import contextlib
 import json
 import os
 import sys
 import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterable, Iterator
 from enum import Enum
-from itertools import groupby
-from operator import itemgetter
 from pathlib import Path

 from apistar.exceptions import ErrorResponse

 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 from arkindex_worker.worker.base import BaseWorker
 from arkindex_worker.worker.classification import ClassificationMixin
 from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
 from arkindex_worker.worker.element import ElementMixin
-from arkindex_worker.worker.entity import EntityMixin  # noqa: F401
+from arkindex_worker.worker.entity import EntityMixin
 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType  # noqa: F401
 from arkindex_worker.worker.task import TaskMixin
 from arkindex_worker.worker.transcription import TranscriptionMixin
-from arkindex_worker.worker.version import WorkerVersionMixin  # noqa: F401
+from arkindex_worker.worker.version import WorkerVersionMixin


 class ActivityState(Enum):
@@ -159,6 +159,16 @@ class ElementsWorker(
         super().configure()
         super().configure_cache()

+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
     def run(self):
         """
         Implements an Arkindex worker that goes through each element returned by
@@ -229,12 +239,13 @@ class ElementsWorker(
                     with contextlib.suppress(Exception):
                         self.update_activity(element.id, ActivityState.Error)

+        message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} elements: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)

     def process_element(self, element: Element | CachedElement):
         """
@@ -299,6 +310,21 @@ class ElementsWorker(
         return True


+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class MissingDatasetArchive(Exception):
     """
     Exception raised when the compressed archive associated to
@@ -308,7 +334,7 @@ class MissingDatasetArchive(Exception):

 class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
     """
-    Base class for ML workers that operate on Arkindex datasets.
+    Base class for ML workers that operate on Arkindex dataset sets.

     This class inherits from numerous mixin classes found in other modules of
     ``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
@@ -318,24 +344,28 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         self,
         description: str = "Arkindex Dataset Worker",
         support_cache: bool = False,
-        generator: bool = False,
     ):
         """
         :param description: The worker's description.
         :param support_cache: Whether the worker supports cache.
-        :param generator: Whether the worker generates the dataset archive artifact.
         """
         super().__init__(description, support_cache)

+        # Path to the dataset compressed archive (containing images and a SQLite database)
+        # Set as an instance variable as dataset workers might use it to easily extract its content
+        self.downloaded_dataset_artifact: Path | None = None
+
         self.parser.add_argument(
-            "--dataset",
-            type=uuid.UUID,
+            "--set",
+            type=check_dataset_set,
             nargs="+",
-            help="One or more Arkindex dataset ID",
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
         )

-        self.generator = generator
-
     def configure(self):
         """
         Setup the worker using CLI arguments and environment variables.
@@ -349,163 +379,130 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         super().configure()
         super().configure_cache()

-    def download_dataset_artifact(self, dataset: Dataset) -> Path:
+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
+    def cleanup_downloaded_artifact(self) -> None:
+        """
+        Cleanup the downloaded dataset artifact if any
+        """
+        if not self.downloaded_dataset_artifact:
+            return
+
+        self.downloaded_dataset_artifact.unlink(missing_ok=True)
+
+    def download_dataset_artifact(self, dataset: Dataset) -> None:
         """
         Find and download the compressed archive artifact describing a dataset using
         the [list_artifacts][arkindex_worker.worker.task.TaskMixin.list_artifacts] and
         [download_artifact][arkindex_worker.worker.task.TaskMixin.download_artifact] methods.

         :param dataset: The dataset to retrieve the compressed archive artifact for.
-        :returns: A path to the downloaded artifact.
         :raises MissingDatasetArchive: When the dataset artifact is not found.
         """
+        extra_dir = self.find_extras_directory()
+        archive = extra_dir / dataset.filepath
+        if archive.exists():
+            return

-        task_id = uuid.UUID(dataset.task_id)
+        # Cleanup the dataset artifact that was downloaded previously
+        self.cleanup_downloaded_artifact()

+        logger.info(f"Downloading artifact for {dataset}")
+        task_id = uuid.UUID(dataset.task_id)
         for artifact in self.list_artifacts(task_id):
             if artifact.path != dataset.filepath:
                 continue

-            extra_dir = self.find_extras_directory()
-            archive = extra_dir / dataset.filepath
             archive.write_bytes(self.download_artifact(task_id, artifact).read())
-            return archive
+            self.downloaded_dataset_artifact = archive
+            return

         raise MissingDatasetArchive(
             "The dataset compressed archive artifact was not found."
         )

-    def list_dataset_elements_per_split(
-        self, dataset: Dataset
-    ) -> Iterator[tuple[str, list[Element]]]:
-        """
-        List the elements in the dataset, grouped by split, using the
-        [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
-
-        :param dataset: The dataset to retrieve elements from.
-        :returns: An iterator of tuples containing the split name and the list of its elements.
-        """
-
-        def format_split(
-            split: tuple[str, Iterator[tuple[str, Element]]],
-        ) -> tuple[str, list[Element]]:
-            return (split[0], list(map(itemgetter(1), list(split[1]))))
-
-        return map(
-            format_split,
-            groupby(
-                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                key=itemgetter(0),
-            ),
-        )
-
-    def process_dataset(self, dataset: Dataset):
+    def process_set(self, set: Set):
         """
-        Override this method to implement your worker and process a single Arkindex dataset at once.
+        Override this method to implement your worker and process a single Arkindex dataset set at once.

-        :param dataset: The dataset to process.
+        :param set: The set to process.
         """

-    def list_datasets(self) -> Iterator[Dataset] | Iterator[str]:
+    def list_sets(self) -> Iterator[Set]:
         """
-        List the datasets to be processed, either from the CLI arguments or using the
-        [list_process_datasets][arkindex_worker.worker.dataset.DatasetMixin.list_process_datasets] method.
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.

-        :returns: An iterator of strings if the worker is in read-only mode,
-            else an iterator of ``Dataset`` objects.
+        :returns: An iterator of ``Set`` objects.
         """
-        if self.is_read_only:
-            return map(str, self.args.dataset)
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information is not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.request("RetrieveDataset", id=dataset_id)
+                )

-        return self.list_process_datasets()
+            yield Set(name=set_name, dataset=datasets[dataset_id])

     def run(self):
         """
-        Implements an Arkindex worker that goes through each dataset returned by
-        [list_datasets][arkindex_worker.worker.DatasetWorker.list_datasets].
+        Implements an Arkindex worker that goes through each dataset set returned by
+        [list_sets][arkindex_worker.worker.DatasetWorker.list_sets].

-        It calls [process_dataset][arkindex_worker.worker.DatasetWorker.process_dataset],
-        catching exceptions, and handles updating the [DatasetState][arkindex_worker.worker.dataset.DatasetState]
-        when the worker is a generator.
+        It calls [process_set][arkindex_worker.worker.DatasetWorker.process_set],
+        catching exceptions.
         """
         self.configure()

-        datasets: list[Dataset] | list[str] = list(self.list_datasets())
-        if not datasets:
-            logger.warning("No datasets to process, stopping.")
+        dataset_sets: list[Set] = list(self.list_sets())
+        if not dataset_sets:
+            logger.warning("No sets to process, stopping.")
             sys.exit(1)

-        # Process every dataset
-        count = len(datasets)
+        # Process every set
+        count = len(dataset_sets)
         failed = 0
-        for i, item in enumerate(datasets, start=1):
-            dataset = None
-            dataset_artifact = None
-
+        for i, dataset_set in enumerate(dataset_sets, start=1):
             try:
-                if not self.is_read_only:
-                    # Just use the result of list_datasets as the dataset
-                    dataset = item
-                else:
-                    # Load dataset using the Arkindex API
-                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
-
-                if self.generator:
-                    assert (
-                        dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open."
-                else:
-                    assert (
-                        dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete."
-
-                logger.info(f"Processing {dataset} ({i}/{count})")
-
-                if self.generator:
-                    # Update the dataset state to Building
-                    logger.info(f"Building {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Building)
-                else:
-                    logger.info(f"Downloading data for {dataset} ({i}/{count})")
-                    dataset_artifact = self.download_dataset_artifact(dataset)
+                assert (
+                    dataset_set.dataset.state == DatasetState.Complete.value
+                ), "When processing a set, its dataset state should be Complete."

-                # Process the dataset
-                self.process_dataset(dataset)
+                logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
+                self.download_dataset_artifact(dataset_set.dataset)

-                if self.generator:
-                    # Update the dataset state to Complete
-                    logger.info(f"Completed {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Complete)
+                logger.info(f"Processing {dataset_set} ({i}/{count})")
+                self.process_set(dataset_set)
             except Exception as e:
-                # Handle errors occurring while retrieving, processing or patching the state for this dataset.
+                # Handle errors occurring while retrieving or processing this dataset set
                 failed += 1

-                # Handle the case where we failed retrieving the dataset
-                dataset_id = dataset.id if dataset else item
-
                 if isinstance(e, ErrorResponse):
-                    message = f"An API error occurred while processing dataset {dataset_id}: {e.title} - {e.content}"
+                    message = f"An API error occurred while processing {dataset_set}: {e.title} - {e.content}"
                 else:
-                    message = (
-                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
-                    )
+                    message = f"Failed running worker on {dataset_set}: {repr(e)}"

-                logger.warning(
-                    message,
-                    exc_info=e if self.args.verbose else None,
-                )
-                if dataset and self.generator:
-                    # Try to update the state to Error regardless of the response
-                    with contextlib.suppress(Exception):
-                        self.update_dataset_state(dataset, DatasetState.Error)
-            finally:
-                # Cleanup the dataset artifact if it was downloaded, no matter what
-                if dataset_artifact:
-                    dataset_artifact.unlink(missing_ok=True)
+                logger.warning(message, exc_info=e if self.args.verbose else None)
+
+            # Cleanup the latest downloaded dataset artifact
+            self.cleanup_downloaded_artifact()

+        message = f'Ran on {count} set{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)
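
The worker/__init__.py changes above replace dataset-level processing (process_dataset, list_datasets and the generator flag) with set-level processing. As a rough illustration only, here is a minimal sketch of what a worker built against the new API might look like; the subclass name, the element counting and the log line are invented for the example, while DatasetWorker, Set, process_set, list_set_elements and the --set CLI format all come from the diff above.

from arkindex_worker import logger
from arkindex_worker.models import Set
from arkindex_worker.worker import DatasetWorker


class DemoSetWorker(DatasetWorker):
    def process_set(self, dataset_set: Set):
        # Iterate over the elements of this single dataset set
        count = sum(1 for _ in self.list_set_elements(dataset_set))
        logger.info(f"{dataset_set} contains {count} elements")


def main():
    DemoSetWorker(description="Demo dataset set worker").run()


if __name__ == "__main__":
    main()

Such a worker would then be launched with one or more sets, e.g. --set 12341234-1234-1234-1234-123412341234:train, instead of the former --dataset argument.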
arkindex_worker/worker/base.py
@@ -1,6 +1,7 @@
 """
 The base class for all Arkindex workers.
 """
+
 import argparse
 import json
 import logging
@@ -20,7 +21,6 @@ from tenacity import (
     wait_exponential,
 )

-from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
 from arkindex_worker.cache import (
     check_version,
@@ -30,18 +30,7 @@ from arkindex_worker.cache import (
     merge_parents_cache,
 )
 from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive
-
-
-def _is_500_error(exc: Exception) -> bool:
-    """
-    Check if an Arkindex API error has a HTTP 5xx error code.
-    Used to retry most API calls in [BaseWorker][arkindex_worker.worker.base.BaseWorker].
-    :param exc: Exception to check
-    """
-    if not isinstance(exc, ErrorResponse):
-        return False
-
-    return 500 <= exc.status_code < 600
+from teklia_toolbox.requests import _get_arkindex_client, _is_500_error


 class ExtrasDirNotFoundError(Exception):
@@ -72,7 +61,7 @@ class BaseWorker:
         self.parser.add_argument(
             "-c",
             "--config",
-            help="Alternative configuration file when running without a Worker Version ID",
+            help="Alternative configuration file when running without a Worker Run ID",
             type=open,
         )
         self.parser.add_argument(
@@ -94,7 +83,7 @@ class BaseWorker:
             "--dev",
             help=(
                 "Run worker in developer mode. "
-                "Worker will be in read-only state even if a worker_version is supplied. "
+                "Worker will be in read-only state even if a worker run is supplied. "
             ),
             action="store_true",
             default=False,
@@ -148,6 +137,13 @@ class BaseWorker:
         # there is at least one available sqlite database either given or in the parent tasks
         self.use_cache = False

+        # model_version_id will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_version_id = None
+        # model_details will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_details = {}
+
         # task_parents will be updated in configure_cache() if the cache is supported,
         # if the task ID is set and if no database is passed as argument
         self.task_parents = []
@@ -176,12 +172,20 @@
         """
         return self.args.dev or self.worker_run_id is None

+    @property
+    def worker_version_id(self):
+        """Deprecated property previously used to retrieve the current WorkerVersion ID.
+
+        :raises DeprecationWarning: Whenever `worker_version_id` is used.
+        """
+        raise DeprecationWarning("`worker_version_id` usage is deprecated")
+
     def setup_api_client(self):
         """
         Create an ArkindexClient to make API requests towards Arkindex instances.
         """
         # Build Arkindex API client from environment variables
-        self.api_client = ArkindexClient(**options_from_env())
+        self.api_client = _get_arkindex_client()
         logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")

     def configure_for_developers(self):
@@ -243,25 +247,21 @@

         # Load worker version information
         worker_version = worker_run["worker_version"]
-
-        # Store worker version id
-        self.worker_version_id = worker_version["id"]
-
         self.worker_details = worker_version["worker"]

         logger.info(f"Loaded {worker_run['summary']} from API")

         # Load model version configuration when available
         model_version = worker_run.get("model_version")
-        if model_version and model_version.get("configuration"):
+        if model_version:
             logger.info("Loaded model version configuration from WorkerRun")
-            self.model_configuration.update(model_version.get("configuration"))
+            self.model_configuration.update(model_version["configuration"])

             # Set model_version ID as worker attribute
-            self.model_version_id = model_version.get("id")
+            self.model_version_id = model_version["id"]

             # Set model details as worker attribute
-            self.model_details = model_version.get("model")
+            self.model_details = model_version["model"]

         # Retrieve initial configuration from API
         self.config = worker_version["configuration"].get("configuration", {})
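
Since base.py now raises on any worker_version_id access, downstream code that still reads that attribute needs a guard. Below is a small, hypothetical helper (the function name and the warning handling are not part of the package) sketching the migration, assuming a BaseWorker instance is available:

import warnings

from arkindex_worker.worker.base import BaseWorker


def legacy_worker_version_id(worker: BaseWorker):
    """Show that reading worker_version_id now raises DeprecationWarning."""
    try:
        return worker.worker_version_id
    except DeprecationWarning as e:
        # Surface the deprecation and fall back to worker_run_id based logic
        warnings.warn(str(e), DeprecationWarning)
        return None

Model information is instead exposed through the model_version_id and model_details attributes initialised in BaseWorker.__init__ and filled in configure().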
arkindex_worker/worker/classification.py
@@ -2,8 +2,6 @@
 ElementsWorker methods for classifications and ML classes.
 """

-from uuid import UUID
-
 from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError

@@ -154,13 +152,6 @@ class ClassificationMixin:
             # Detect already existing classification
             if e.status_code == 400 and "non_field_errors" in e.content:
                 if (
-                    "The fields element, worker_version, ml_class must make a unique set."
-                    in e.content["non_field_errors"]
-                ):
-                    logger.warning(
-                        f"This worker version has already set {ml_class} on element {element.id}"
-                    )
-                elif (
                     "The fields element, worker_run, ml_class must make a unique set."
                     in e.content["non_field_errors"]
                 ):
@@ -185,10 +176,14 @@
         Create multiple classifications at once on the given element through the API.

         :param element: The element to create classifications on.
-        :param classifications: The classifications to create, a list of dicts. Each of them contains
-            a **ml_class_id** (str), the ID of the MLClass for this classification;
-            a **confidence** (float), the confidence score, between 0 and 1;
-            a **high_confidence** (bool), the high confidence state of the classification.
+        :param classifications: A list of dicts representing a classification each, with the following keys:
+
+            ml_class (str)
+                Required. Name of the MLClass to use.
+            confidence (float)
+                Required. Confidence score for the classification. Must be between 0 and 1.
+            high_confidence (bool)
+                Optional. Whether or not the classification is of high confidence.

         :returns: List of created classifications, as returned in the ``classifications`` field by
             the ``CreateClassifications`` API endpoint.
@@ -201,18 +196,10 @@
         ), "classifications shouldn't be null and should be of type list"

         for index, classification in enumerate(classifications):
-            ml_class_id = classification.get("ml_class_id")
+            ml_class = classification.get("ml_class")
             assert (
-                ml_class_id and isinstance(ml_class_id, str)
-            ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
-
-            # Make sure it's a valid UUID
-            try:
-                UUID(ml_class_id)
-            except ValueError as e:
-                raise ValueError(
-                    f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                ) from e
+                ml_class and isinstance(ml_class, str)
+            ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"

             confidence = classification.get("confidence")
             assert (
@@ -238,7 +225,13 @@
             body={
                 "parent": str(element.id),
                 "worker_run_id": self.worker_run_id,
-                "classifications": classifications,
+                "classifications": [
+                    {
+                        **classification,
+                        "ml_class": self.get_ml_class_id(classification["ml_class"]),
+                    }
+                    for classification in classifications
+                ],
             },
         )["classifications"]

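
To make the ml_class rename above concrete, here is a hedged sketch of how create_classifications is now called from an ElementsWorker subclass; the subclass name, the class names "table" and "text" and the confidence values are placeholders, while the keys follow the docstring in the diff above.

from arkindex_worker.models import Element
from arkindex_worker.worker import ElementsWorker


class DemoClassifierWorker(ElementsWorker):
    def process_element(self, element: Element):
        # Classifications now reference the MLClass by name; the helper resolves
        # each name to an ID through get_ml_class_id before calling the API.
        self.create_classifications(
            element,
            classifications=[
                {"ml_class": "table", "confidence": 0.95, "high_confidence": True},
                {"ml_class": "text", "confidence": 0.42},
            ],
        )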
arkindex_worker/worker/dataset.py
@@ -6,7 +6,8 @@ from collections.abc import Iterator
 from enum import Enum

 from arkindex_worker import logger
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.cache import unsupported_cache
+from arkindex_worker.models import Dataset, Element, Set


 class DatasetState(Enum):
@@ -36,38 +37,43 @@


 class DatasetMixin:
-    def list_process_datasets(self) -> Iterator[Dataset]:
+    def list_process_sets(self) -> Iterator[Set]:
         """
-        List datasets associated to the worker's process. This helper is not available in developer mode.
+        List dataset sets associated to the worker's process. This helper is not available in developer mode.

-        :returns: An iterator of ``Dataset`` objects built from the ``ListProcessDatasets`` API endpoint.
+        :returns: An iterator of ``Set`` objects built from the ``ListProcessSets`` API endpoint.
         """
         assert not self.is_read_only, "This helper is not available in read-only mode."

         results = self.api_client.paginate(
-            "ListProcessDatasets", id=self.process_information["id"]
+            "ListProcessSets", id=self.process_information["id"]
         )

-        return map(Dataset, list(results))
+        return map(
+            lambda result: Set(
+                name=result["set_name"], dataset=Dataset(**result["dataset"])
+            ),
+            results,
+        )

-    def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
+    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
         """
-        List elements in a dataset.
+        List elements in a dataset set.

-        :param dataset: Dataset to find elements in.
-        :returns: An iterator of tuples built from the ``ListDatasetElements`` API endpoint.
+        :param dataset_set: Set to find elements in.
+        :returns: An iterator of Element built from the ``ListDatasetElements`` API endpoint.
         """
-        assert dataset and isinstance(
-            dataset, Dataset
-        ), "dataset shouldn't be null and should be a Dataset"
+        assert dataset_set and isinstance(
+            dataset_set, Set
+        ), "dataset_set shouldn't be null and should be a Set"

-        results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
-
-        def format_result(result):
-            return (result["set"], Element(**result["element"]))
+        results = self.api_client.paginate(
+            "ListDatasetElements", id=dataset_set.dataset.id, set=dataset_set.name
+        )

-        return map(format_result, list(results))
+        return map(lambda result: Element(**result["element"]), results)

+    @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
         """
         Partially updates a dataset state through the API.
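
As with the other mixins, the renamed helpers are easiest to see in use. A short sketch follows (the function below is illustrative, not part of the package) that pairs list_process_sets with list_set_elements, assuming a configured, non-read-only worker that exposes DatasetMixin:

from arkindex_worker.models import Element
from arkindex_worker.worker.dataset import DatasetMixin


def first_elements(worker: DatasetMixin, limit: int = 5) -> list[Element]:
    """Collect a few elements from every set attached to the worker's process."""
    elements: list[Element] = []
    for dataset_set in worker.list_process_sets():
        for i, element in enumerate(worker.list_set_elements(dataset_set)):
            if i >= limit:
                break
            elements.append(element)
    return elements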