arkindex-base-worker 0.3.6rc4__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arkindex_base_worker-0.3.7.dist-info/LICENSE +21 -0
- arkindex_base_worker-0.3.7.dist-info/METADATA +77 -0
- arkindex_base_worker-0.3.7.dist-info/RECORD +47 -0
- {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/WHEEL +1 -1
- {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/top_level.txt +2 -0
- arkindex_worker/cache.py +14 -0
- arkindex_worker/image.py +29 -19
- arkindex_worker/models.py +14 -2
- arkindex_worker/utils.py +17 -3
- arkindex_worker/worker/__init__.py +122 -125
- arkindex_worker/worker/base.py +24 -24
- arkindex_worker/worker/classification.py +18 -25
- arkindex_worker/worker/dataset.py +24 -18
- arkindex_worker/worker/element.py +100 -19
- arkindex_worker/worker/entity.py +35 -4
- arkindex_worker/worker/metadata.py +21 -11
- arkindex_worker/worker/training.py +13 -0
- arkindex_worker/worker/transcription.py +45 -5
- arkindex_worker/worker/version.py +22 -0
- hooks/pre_gen_project.py +3 -0
- tests/conftest.py +16 -8
- tests/test_base_worker.py +0 -6
- tests/test_dataset_worker.py +291 -409
- tests/test_elements_worker/test_classifications.py +365 -539
- tests/test_elements_worker/test_cli.py +1 -1
- tests/test_elements_worker/test_dataset.py +97 -116
- tests/test_elements_worker/test_elements.py +354 -76
- tests/test_elements_worker/test_entities.py +22 -2
- tests/test_elements_worker/test_metadata.py +53 -27
- tests/test_elements_worker/test_training.py +35 -0
- tests/test_elements_worker/test_transcriptions.py +149 -16
- tests/test_elements_worker/test_worker.py +19 -6
- tests/test_image.py +37 -0
- tests/test_utils.py +23 -1
- worker-demo/tests/__init__.py +0 -0
- worker-demo/tests/conftest.py +32 -0
- worker-demo/tests/test_worker.py +12 -0
- worker-demo/worker_demo/__init__.py +6 -0
- worker-demo/worker_demo/worker.py +19 -0
- arkindex_base_worker-0.3.6rc4.dist-info/METADATA +0 -47
- arkindex_base_worker-0.3.6rc4.dist-info/RECORD +0 -40
arkindex_worker/worker/__init__.py
CHANGED
@@ -1,31 +1,31 @@
 """
 Base classes to implement Arkindex workers.
 """
+
 import contextlib
 import json
 import os
 import sys
 import uuid
+from argparse import ArgumentTypeError
 from collections.abc import Iterable, Iterator
 from enum import Enum
-from itertools import groupby
-from operator import itemgetter
 from pathlib import Path

 from apistar.exceptions import ErrorResponse

 from arkindex_worker import logger
 from arkindex_worker.cache import CachedElement
-from arkindex_worker.models import Dataset, Element
+from arkindex_worker.models import Dataset, Element, Set
 from arkindex_worker.worker.base import BaseWorker
 from arkindex_worker.worker.classification import ClassificationMixin
 from arkindex_worker.worker.dataset import DatasetMixin, DatasetState
 from arkindex_worker.worker.element import ElementMixin
-from arkindex_worker.worker.entity import EntityMixin
+from arkindex_worker.worker.entity import EntityMixin
 from arkindex_worker.worker.metadata import MetaDataMixin, MetaType  # noqa: F401
 from arkindex_worker.worker.task import TaskMixin
 from arkindex_worker.worker.transcription import TranscriptionMixin
-from arkindex_worker.worker.version import WorkerVersionMixin
+from arkindex_worker.worker.version import WorkerVersionMixin


 class ActivityState(Enum):
@@ -159,6 +159,16 @@ class ElementsWorker(
         super().configure()
         super().configure_cache()

+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
     def run(self):
         """
         Implements an Arkindex worker that goes through each element returned by
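The two `update()` calls above give the configuration a simple precedence chain: worker version defaults first, then the model version configuration, then the user configuration. A minimal sketch of the resulting behaviour, using plain dicts in place of the worker's attributes (the keys and values are made up):

config = {"batch_size": 8, "lang": "fr"}   # worker version defaults
model_configuration = {"batch_size": 16}   # from the model version
user_configuration = {"lang": "en"}        # set by the user on the process

# Same order as in configure(): model config overrides the defaults,
# then user config overrides both.
config.update(model_configuration)
config.update(user_configuration)

assert config == {"batch_size": 16, "lang": "en"}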
@@ -229,12 +239,13 @@ class ElementsWorker(
                 with contextlib.suppress(Exception):
                     self.update_activity(element.id, ActivityState.Error)

+        message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} elements: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)

     def process_element(self, element: Element | CachedElement):
         """
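The `"s"[:count>1]` idiom in the new message builds the plural suffix without a conditional: `bool` is an `int` subclass, so the slice evaluates to the empty string for a single element and to `"s"` otherwise. For example:

for count in (1, 2):
    # "s"[:False] == "s"[:0] == ""; "s"[:True] == "s"[:1] == "s"
    print(f'Ran on {count} element{"s"[:count > 1]}')
# Ran on 1 element
# Ran on 2 elements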
@@ -299,6 +310,21 @@ class ElementsWorker(
         return True


+def check_dataset_set(value: str) -> tuple[uuid.UUID, str]:
+    values = value.split(":")
+    if len(values) != 2:
+        raise ArgumentTypeError(
+            f"'{value}' is not in the correct format `<dataset_id>:<set_name>`"
+        )
+
+    dataset_id, set_name = values
+    try:
+        dataset_id = uuid.UUID(dataset_id)
+        return (dataset_id, set_name)
+    except (TypeError, ValueError) as e:
+        raise ArgumentTypeError(f"'{dataset_id}' should be a valid UUID") from e
+
+
 class MissingDatasetArchive(Exception):
     """
     Exception raised when the compressed archive associated to
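`check_dataset_set` is an argparse type callable, so malformed values surface as clean usage errors rather than tracebacks. A sketch of its behaviour, wired to a standalone parser mirroring the `--set` argument added further down (the parser itself is illustrative; only `check_dataset_set` comes from this release):

from argparse import ArgumentParser

from arkindex_worker.worker import check_dataset_set

parser = ArgumentParser()
parser.add_argument("--set", type=check_dataset_set, nargs="+", default=[])

args = parser.parse_args(["--set", "12341234-1234-1234-1234-123412341234:train"])
print(args.set)  # [(UUID('12341234-1234-1234-1234-123412341234'), 'train')]

# A missing ":" or a malformed UUID raises ArgumentTypeError, which argparse
# reports as an "invalid check_dataset_set value" usage error and exits.
parser.parse_args(["--set", "not-a-uuid:train"])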
@@ -308,7 +334,7 @@ class MissingDatasetArchive(Exception):

 class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
     """
-    Base class for ML workers that operate on Arkindex
+    Base class for ML workers that operate on Arkindex dataset sets.

     This class inherits from numerous mixin classes found in other modules of
     ``arkindex.worker``, which provide helpers to read and write to the Arkindex API.
@@ -318,24 +344,28 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         self,
         description: str = "Arkindex Dataset Worker",
         support_cache: bool = False,
-        generator: bool = False,
     ):
         """
         :param description: The worker's description.
         :param support_cache: Whether the worker supports cache.
-        :param generator: Whether the worker generates the dataset archive artifact.
         """
         super().__init__(description, support_cache)

+        # Path to the dataset compressed archive (containing images and a SQLite database)
+        # Set as an instance variable as dataset workers might use it to easily extract its content
+        self.downloaded_dataset_artifact: Path | None = None
+
         self.parser.add_argument(
-            "--
-            type=
+            "--set",
+            type=check_dataset_set,
             nargs="+",
-            help="
+            help="""
+            One or more Arkindex dataset sets, format is <dataset_uuid>:<set_name>
+            (e.g.: "12341234-1234-1234-1234-123412341234:train")
+            """,
+            default=[],
         )

-        self.generator = generator
-
     def configure(self):
         """
         Setup the worker using CLI arguments and environment variables.
@@ -349,163 +379,130 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
         super().configure()
         super().configure_cache()

-    def download_dataset_artifact(self, dataset: Dataset) -> Path:
+        # Retrieve the model configuration
+        if self.model_configuration:
+            self.config.update(self.model_configuration)
+            logger.info("Model version configuration retrieved")
+
+        # Retrieve the user configuration
+        if self.user_configuration:
+            self.config.update(self.user_configuration)
+            logger.info("User configuration retrieved")
+
+    def cleanup_downloaded_artifact(self) -> None:
+        """
+        Cleanup the downloaded dataset artifact if any
+        """
+        if not self.downloaded_dataset_artifact:
+            return
+
+        self.downloaded_dataset_artifact.unlink(missing_ok=True)
+
+    def download_dataset_artifact(self, dataset: Dataset) -> None:
         """
         Find and download the compressed archive artifact describing a dataset using
         the [list_artifacts][arkindex_worker.worker.task.TaskMixin.list_artifacts] and
         [download_artifact][arkindex_worker.worker.task.TaskMixin.download_artifact] methods.

         :param dataset: The dataset to retrieve the compressed archive artifact for.
-        :returns: A path to the downloaded artifact.
         :raises MissingDatasetArchive: When the dataset artifact is not found.
         """
+        extra_dir = self.find_extras_directory()
+        archive = extra_dir / dataset.filepath
+        if archive.exists():
+            return

-        task_id = uuid.UUID(dataset.task_id)
+        # Cleanup the dataset artifact that was downloaded previously
+        self.cleanup_downloaded_artifact()

+        logger.info(f"Downloading artifact for {dataset}")
+        task_id = uuid.UUID(dataset.task_id)
         for artifact in self.list_artifacts(task_id):
             if artifact.path != dataset.filepath:
                 continue

-            extra_dir = self.find_extras_directory()
-            archive = extra_dir / dataset.filepath
             archive.write_bytes(self.download_artifact(task_id, artifact).read())
-            return archive
+            self.downloaded_dataset_artifact = archive
+            return

         raise MissingDatasetArchive(
             "The dataset compressed archive artifact was not found."
         )

-    def
-        self, dataset: Dataset
-    ) -> Iterator[tuple[str, list[Element]]]:
-        """
-        List the elements in the dataset, grouped by split, using the
-        [list_dataset_elements][arkindex_worker.worker.dataset.DatasetMixin.list_dataset_elements] method.
-
-        :param dataset: The dataset to retrieve elements from.
-        :returns: An iterator of tuples containing the split name and the list of its elements.
-        """
-
-        def format_split(
-            split: tuple[str, Iterator[tuple[str, Element]]],
-        ) -> tuple[str, list[Element]]:
-            return (split[0], list(map(itemgetter(1), list(split[1]))))
-
-        return map(
-            format_split,
-            groupby(
-                sorted(self.list_dataset_elements(dataset), key=itemgetter(0)),
-                key=itemgetter(0),
-            ),
-        )
-
-    def process_dataset(self, dataset: Dataset):
+    def process_set(self, set: Set):
         """
-        Override this method to implement your worker and process a single Arkindex dataset at once.
+        Override this method to implement your worker and process a single Arkindex dataset set at once.

-        :param
+        :param set: The set to process.
         """

-    def
+    def list_sets(self) -> Iterator[Set]:
         """
-        List the
-        [
+        List the sets to be processed, either from the CLI arguments or using the
+        [list_process_sets][arkindex_worker.worker.dataset.DatasetMixin.list_process_sets] method.

-        :returns: An iterator of
-        else an iterator of ``Dataset`` objects.
+        :returns: An iterator of ``Set`` objects.
         """
-        if self.is_read_only:
-
+        if not self.is_read_only:
+            yield from self.list_process_sets()
+
+        datasets: dict[uuid.UUID, Dataset] = {}
+        for dataset_id, set_name in self.args.set:
+            # Retrieving dataset information is not already cached
+            if dataset_id not in datasets:
+                datasets[dataset_id] = Dataset(
+                    **self.request("RetrieveDataset", id=dataset_id)
+                )

-
+            yield Set(name=set_name, dataset=datasets[dataset_id])

     def run(self):
         """
-        Implements an Arkindex worker that goes through each dataset returned by
-        [
+        Implements an Arkindex worker that goes through each dataset set returned by
+        [list_sets][arkindex_worker.worker.DatasetWorker.list_sets].

-        It calls [
-        catching exceptions
-        when the worker is a generator.
+        It calls [process_set][arkindex_worker.worker.DatasetWorker.process_set],
+        catching exceptions.
         """
         self.configure()

-
-        if not
-            logger.warning("No
+        dataset_sets: list[Set] = list(self.list_sets())
+        if not dataset_sets:
+            logger.warning("No sets to process, stopping.")
             sys.exit(1)

-        # Process every
-        count = len(
+        # Process every set
+        count = len(dataset_sets)
         failed = 0
-        for i,
-            dataset = None
-            dataset_artifact = None
-
+        for i, dataset_set in enumerate(dataset_sets, start=1):
             try:
-
-
-
-                else:
-                    # Load dataset using the Arkindex API
-                    dataset = Dataset(**self.request("RetrieveDataset", id=item))
-
-                if self.generator:
-                    assert (
-                        dataset.state == DatasetState.Open.value
-                    ), "When generating a new dataset, its state should be Open."
-                else:
-                    assert (
-                        dataset.state == DatasetState.Complete.value
-                    ), "When processing an existing dataset, its state should be Complete."
-
-                logger.info(f"Processing {dataset} ({i}/{count})")
-
-                if self.generator:
-                    # Update the dataset state to Building
-                    logger.info(f"Building {dataset} ({i}/{count})")
-                    self.update_dataset_state(dataset, DatasetState.Building)
-                else:
-                    logger.info(f"Downloading data for {dataset} ({i}/{count})")
-                    dataset_artifact = self.download_dataset_artifact(dataset)
+                assert (
+                    dataset_set.dataset.state == DatasetState.Complete.value
+                ), "When processing a set, its dataset state should be Complete."

-
-                self.
+                logger.info(f"Retrieving data for {dataset_set} ({i}/{count})")
+                self.download_dataset_artifact(dataset_set.dataset)

-
-
-                logger.info(f"Completed {dataset} ({i}/{count})")
-                self.update_dataset_state(dataset, DatasetState.Complete)
+                logger.info(f"Processing {dataset_set} ({i}/{count})")
+                self.process_set(dataset_set)
             except Exception as e:
-                # Handle errors occurring while retrieving
+                # Handle errors occurring while retrieving or processing this dataset set
                 failed += 1

-                # Handle the case where we failed retrieving the dataset
-                dataset_id = dataset.id if dataset else item
-
                 if isinstance(e, ErrorResponse):
-                    message = f"An API error occurred while processing
+                    message = f"An API error occurred while processing {dataset_set}: {e.title} - {e.content}"
                 else:
-                    message = (
-                        f"Failed running worker on dataset {dataset_id}: {repr(e)}"
-                    )
+                    message = f"Failed running worker on {dataset_set}: {repr(e)}"

-                logger.warning(
-
-
-
-                if dataset and self.generator:
-                    # Try to update the state to Error regardless of the response
-                    with contextlib.suppress(Exception):
-                        self.update_dataset_state(dataset, DatasetState.Error)
-            finally:
-                # Cleanup the dataset artifact if it was downloaded, no matter what
-                if dataset_artifact:
-                    dataset_artifact.unlink(missing_ok=True)
+                logger.warning(message, exc_info=e if self.args.verbose else None)
+
+        # Cleanup the latest downloaded dataset artifact
+        self.cleanup_downloaded_artifact()

+        message = f'Ran on {count} set{"s"[:count>1]}: {count - failed} completed, {failed} failed'
         if failed:
-            logger.error(
-                f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
-            )
+            logger.error(message)
             if failed >= count:  # Everything failed!
                 sys.exit(1)
+        else:
+            logger.info(message)
arkindex_worker/worker/base.py
CHANGED
@@ -1,6 +1,7 @@
 """
 The base class for all Arkindex workers.
 """
+
 import argparse
 import json
 import logging
@@ -20,7 +21,6 @@ from tenacity import (
     wait_exponential,
 )

-from arkindex import ArkindexClient, options_from_env
 from arkindex_worker import logger
 from arkindex_worker.cache import (
     check_version,
@@ -30,18 +30,7 @@ from arkindex_worker.cache import (
     merge_parents_cache,
 )
 from arkindex_worker.utils import close_delete_file, extract_tar_zst_archive
-
-
-def _is_500_error(exc: Exception) -> bool:
-    """
-    Check if an Arkindex API error has a HTTP 5xx error code.
-    Used to retry most API calls in [BaseWorker][arkindex_worker.worker.base.BaseWorker].
-    :param exc: Exception to check
-    """
-    if not isinstance(exc, ErrorResponse):
-        return False
-
-    return 500 <= exc.status_code < 600
+from teklia_toolbox.requests import _get_arkindex_client, _is_500_error


 class ExtrasDirNotFoundError(Exception):
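`_is_500_error` keeps the same contract after moving to `teklia_toolbox.requests`: a plain `Exception -> bool` predicate that tenacity can use to retry only on server errors. A sketch of that wiring with a simplified stand-in predicate (the real helper also checks that the exception is an apistar `ErrorResponse`):

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential


def _is_500_error(exc: Exception) -> bool:
    # Simplified stand-in: retry only on HTTP 5xx status codes
    return 500 <= getattr(exc, "status_code", 0) < 600


@retry(
    retry=retry_if_exception(_is_500_error),
    wait=wait_exponential(multiplier=2),
    stop=stop_after_attempt(5),
)
def flaky_api_call():
    ...  # any API request that should survive transient server errors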
@@ -72,7 +61,7 @@ class BaseWorker:
         self.parser.add_argument(
             "-c",
             "--config",
-            help="Alternative configuration file when running without a Worker
+            help="Alternative configuration file when running without a Worker Run ID",
             type=open,
         )
         self.parser.add_argument(
@@ -94,7 +83,7 @@ class BaseWorker:
             "--dev",
             help=(
                 "Run worker in developer mode. "
-                "Worker will be in read-only state even if a
+                "Worker will be in read-only state even if a worker run is supplied. "
             ),
             action="store_true",
             default=False,
@@ -148,6 +137,13 @@ class BaseWorker:
         # there is at least one available sqlite database either given or in the parent tasks
         self.use_cache = False

+        # model_version_id will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_version_id = None
+        # model_details will be updated in configure() using the worker_run's model version
+        # or in configure_for_developers() from the environment
+        self.model_details = {}
+
         # task_parents will be updated in configure_cache() if the cache is supported,
         # if the task ID is set and if no database is passed as argument
         self.task_parents = []
@@ -176,12 +172,20 @@ class BaseWorker:
         """
         return self.args.dev or self.worker_run_id is None

+    @property
+    def worker_version_id(self):
+        """Deprecated property previously used to retrieve the current WorkerVersion ID.
+
+        :raises DeprecationWarning: Whenever `worker_version_id` is used.
+        """
+        raise DeprecationWarning("`worker_version_id` usage is deprecated")
+
     def setup_api_client(self):
         """
         Create an ArkindexClient to make API requests towards Arkindex instances.
         """
         # Build Arkindex API client from environment variables
-        self.api_client =
+        self.api_client = _get_arkindex_client()
         logger.debug(f"Setup Arkindex API client on {self.api_client.document.url}")

     def configure_for_developers(self):
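Since the property raises a `DeprecationWarning` (an `Exception` subclass) rather than emitting a warning, any leftover read of `worker_version_id` now fails immediately. Assuming a `worker` instance is in scope:

try:
    worker.worker_version_id
except DeprecationWarning as e:
    print(e)  # `worker_version_id` usage is deprecated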
@@ -243,25 +247,21 @@ class BaseWorker:

         # Load worker version information
         worker_version = worker_run["worker_version"]
-
-        # Store worker version id
-        self.worker_version_id = worker_version["id"]
-
         self.worker_details = worker_version["worker"]

         logger.info(f"Loaded {worker_run['summary']} from API")

         # Load model version configuration when available
         model_version = worker_run.get("model_version")
-        if model_version
+        if model_version:
             logger.info("Loaded model version configuration from WorkerRun")
-            self.model_configuration.update(model_version
+            self.model_configuration.update(model_version["configuration"])

             # Set model_version ID as worker attribute
-            self.model_version_id = model_version
+            self.model_version_id = model_version["id"]

             # Set model details as worker attribute
-            self.model_details = model_version
+            self.model_details = model_version["model"]

         # Retrieve initial configuration from API
         self.config = worker_version["configuration"].get("configuration", {})
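For reference, the shape of the WorkerRun payload this method consumes, reduced to the keys read above; the values are placeholders reconstructed from the diff, not the full API schema:

worker_run = {
    "summary": "Demo worker @ version 1",
    "worker_version": {
        "worker": {"name": "demo"},                        # -> self.worker_details
        "configuration": {"configuration": {"param": 1}},  # -> self.config
    },
    # Optional: falsy when no model version is attached to the run
    "model_version": {
        "id": "...",                                       # -> self.model_version_id
        "configuration": {"weights": "best"},              # -> self.model_configuration
        "model": {"name": "demo-model"},                   # -> self.model_details
    },
}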
arkindex_worker/worker/classification.py
CHANGED
@@ -2,8 +2,6 @@
 ElementsWorker methods for classifications and ML classes.
 """

-from uuid import UUID
-
 from apistar.exceptions import ErrorResponse
 from peewee import IntegrityError

@@ -154,13 +152,6 @@ class ClassificationMixin:
             # Detect already existing classification
             if e.status_code == 400 and "non_field_errors" in e.content:
                 if (
-                    "The fields element, worker_version, ml_class must make a unique set."
-                    in e.content["non_field_errors"]
-                ):
-                    logger.warning(
-                        f"This worker version has already set {ml_class} on element {element.id}"
-                    )
-                elif (
                     "The fields element, worker_run, ml_class must make a unique set."
                     in e.content["non_field_errors"]
                 ):
@@ -185,10 +176,14 @@ class ClassificationMixin:
         Create multiple classifications at once on the given element through the API.

         :param element: The element to create classifications on.
-        :param classifications:
-
-
-
+        :param classifications: A list of dicts representing a classification each, with the following keys:
+
+            ml_class (str)
+                Required. Name of the MLClass to use.
+            confidence (float)
+                Required. Confidence score for the classification. Must be between 0 and 1.
+            high_confidence (bool)
+                Optional. Whether or not the classification is of high confidence.

         :returns: List of created classifications, as returned in the ``classifications`` field by
             the ``CreateClassifications`` API endpoint.
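Under the new format, callers pass MLClass names directly and no longer deal with UUIDs; name-to-ID resolution happens inside the helper, as the hunks below show. A sketch of a call, assuming a `worker` instance and an `element` in scope, with illustrative class names:

created = worker.create_classifications(
    element,
    classifications=[
        {"ml_class": "handwritten", "confidence": 0.93, "high_confidence": True},
        {"ml_class": "printed", "confidence": 0.07},
    ],
)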
@@ -201,18 +196,10 @@ class ClassificationMixin:
         ), "classifications shouldn't be null and should be of type list"

         for index, classification in enumerate(classifications):
-
+            ml_class = classification.get("ml_class")
             assert (
-
-            ), f"Classification at index {index} in classifications:
-
-            # Make sure it's a valid UUID
-            try:
-                UUID(ml_class_id)
-            except ValueError as e:
-                raise ValueError(
-                    f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                ) from e
+                ml_class and isinstance(ml_class, str)
+            ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"

             confidence = classification.get("confidence")
             assert (
@@ -238,7 +225,13 @@ class ClassificationMixin:
             body={
                 "parent": str(element.id),
                 "worker_run_id": self.worker_run_id,
-                "classifications":
+                "classifications": [
+                    {
+                        **classification,
+                        "ml_class": self.get_ml_class_id(classification["ml_class"]),
+                    }
+                    for classification in classifications
+                ],
             },
         )["classifications"]

arkindex_worker/worker/dataset.py
CHANGED
@@ -6,7 +6,8 @@ from collections.abc import Iterator
 from enum import Enum

 from arkindex_worker import logger
-from arkindex_worker.
+from arkindex_worker.cache import unsupported_cache
+from arkindex_worker.models import Dataset, Element, Set


 class DatasetState(Enum):
@@ -36,38 +37,43 @@ class DatasetState(Enum):


 class DatasetMixin:
-    def
+    def list_process_sets(self) -> Iterator[Set]:
         """
-        List
+        List dataset sets associated to the worker's process. This helper is not available in developer mode.

-        :returns: An iterator of ``
+        :returns: An iterator of ``Set`` objects built from the ``ListProcessSets`` API endpoint.
         """
         assert not self.is_read_only, "This helper is not available in read-only mode."

         results = self.api_client.paginate(
-            "
+            "ListProcessSets", id=self.process_information["id"]
         )

-        return map(
+        return map(
+            lambda result: Set(
+                name=result["set_name"], dataset=Dataset(**result["dataset"])
+            ),
+            results,
+        )

-    def
+    def list_set_elements(self, dataset_set: Set) -> Iterator[Element]:
         """
-        List elements in a dataset.
+        List elements in a dataset set.

-        :param
-        :returns: An iterator of
+        :param dataset_set: Set to find elements in.
+        :returns: An iterator of Element built from the ``ListDatasetElements`` API endpoint.
         """
-        assert
-
-        ), "
+        assert dataset_set and isinstance(
+            dataset_set, Set
+        ), "dataset_set shouldn't be null and should be a Set"

-        results = self.api_client.paginate(
-
-
-            return (result["set"], Element(**result["element"]))
+        results = self.api_client.paginate(
+            "ListDatasetElements", id=dataset_set.dataset.id, set=dataset_set.name
+        )

-        return map(
+        return map(lambda result: Element(**result["element"]), results)

+    @unsupported_cache
     def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
         """
         Partially updates a dataset state through the API.
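Together, the two helpers replace the old split-grouping logic: a `Set` is a named slice of a dataset, and elements are listed per set. A sketch, assuming a `worker` with API access; the dataset ID is a placeholder:

from arkindex_worker.models import Dataset, Set

dataset = Dataset(
    **worker.request("RetrieveDataset", id="12341234-1234-1234-1234-123412341234")
)
train_set = Set(name="train", dataset=dataset)

for element in worker.list_set_elements(train_set):
    print(element.id)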