arkindex-base-worker 0.3.6rc5__py3-none-any.whl → 0.3.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/METADATA +14 -16
  2. arkindex_base_worker-0.3.7.post1.dist-info/RECORD +47 -0
  3. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +14 -0
  6. arkindex_worker/image.py +29 -19
  7. arkindex_worker/models.py +14 -2
  8. arkindex_worker/utils.py +17 -3
  9. arkindex_worker/worker/__init__.py +122 -125
  10. arkindex_worker/worker/base.py +25 -45
  11. arkindex_worker/worker/classification.py +18 -25
  12. arkindex_worker/worker/dataset.py +24 -18
  13. arkindex_worker/worker/element.py +45 -6
  14. arkindex_worker/worker/entity.py +35 -4
  15. arkindex_worker/worker/metadata.py +21 -11
  16. arkindex_worker/worker/training.py +16 -0
  17. arkindex_worker/worker/transcription.py +45 -5
  18. arkindex_worker/worker/version.py +22 -0
  19. hooks/pre_gen_project.py +3 -0
  20. tests/conftest.py +15 -7
  21. tests/test_base_worker.py +0 -6
  22. tests/test_dataset_worker.py +292 -410
  23. tests/test_elements_worker/test_classifications.py +365 -539
  24. tests/test_elements_worker/test_cli.py +1 -1
  25. tests/test_elements_worker/test_dataset.py +97 -116
  26. tests/test_elements_worker/test_elements.py +227 -61
  27. tests/test_elements_worker/test_entities.py +22 -2
  28. tests/test_elements_worker/test_metadata.py +53 -27
  29. tests/test_elements_worker/test_training.py +35 -0
  30. tests/test_elements_worker/test_transcriptions.py +149 -16
  31. tests/test_elements_worker/test_worker.py +19 -6
  32. tests/test_image.py +37 -0
  33. tests/test_utils.py +23 -1
  34. worker-demo/tests/__init__.py +0 -0
  35. worker-demo/tests/conftest.py +32 -0
  36. worker-demo/tests/test_worker.py +12 -0
  37. worker-demo/worker_demo/__init__.py +6 -0
  38. worker-demo/worker_demo/worker.py +19 -0
  39. arkindex_base_worker-0.3.6rc5.dist-info/RECORD +0 -41
  40. {arkindex_base_worker-0.3.6rc5.dist-info → arkindex_base_worker-0.3.7.post1.dist-info}/LICENSE +0 -0
arkindex_worker/worker/transcription.py CHANGED
@@ -4,6 +4,7 @@ ElementsWorker methods for transcriptions.
4
4
 
5
5
  from collections.abc import Iterable
6
6
  from enum import Enum
7
+ from warnings import warn
7
8
 
8
9
  from peewee import IntegrityError
9
10
 
@@ -366,14 +367,22 @@ class TranscriptionMixin:
366
367
  element_type: str | None = None,
367
368
  recursive: bool | None = None,
368
369
  worker_version: str | bool | None = None,
370
+ worker_run: str | bool | None = None,
369
371
  ) -> Iterable[dict] | Iterable[CachedTranscription]:
370
372
  """
371
373
  List transcriptions on an element.
372
374
 
375
+ Warns:
376
+ ----
377
+ The following parameters are **deprecated**:
378
+
379
+ - `worker_version` in favor of `worker_run`
380
+
373
381
  :param element: The element to list transcriptions on.
374
382
  :param element_type: Restrict to transcriptions whose elements have an element type with this slug.
375
383
  :param recursive: Include transcriptions of any descendant of this element, recursively.
376
- :param worker_version: Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
384
+ :param worker_version: **Deprecated** Restrict to transcriptions created by a worker version with this UUID. Set to False to look for manually created transcriptions.
385
+ :param worker_run: Restrict to transcriptions created by a worker run with this UUID. Set to False to look for manually created transcriptions.
377
386
  :returns: An iterable of dicts representing each transcription,
378
387
  or an iterable of CachedTranscription when cache support is enabled.
379
388
  """
@@ -388,6 +397,11 @@ class TranscriptionMixin:
388
397
  assert isinstance(recursive, bool), "recursive should be of type bool"
389
398
  query_params["recursive"] = recursive
390
399
  if worker_version is not None:
400
+ warn(
401
+ "`worker_version` usage is deprecated. Consider using `worker_run` instead.",
402
+ DeprecationWarning,
403
+ stacklevel=1,
404
+ )
391
405
  assert isinstance(
392
406
  worker_version, str | bool
393
407
  ), "worker_version should be of type str or bool"
@@ -396,6 +410,15 @@ class TranscriptionMixin:
396
410
  worker_version is False
397
411
  ), "if of type bool, worker_version can only be set to False"
398
412
  query_params["worker_version"] = worker_version
413
+ if worker_run is not None:
414
+ assert isinstance(
415
+ worker_run, str | bool
416
+ ), "worker_run should be of type str or bool"
417
+ if isinstance(worker_run, bool):
418
+ assert (
419
+ worker_run is False
420
+ ), "if of type bool, worker_run can only be set to False"
421
+ query_params["worker_run"] = worker_run
399
422
 
400
423
  if self.use_cache:
401
424
  if not recursive:
@@ -427,10 +450,27 @@ class TranscriptionMixin:
427
450
 
428
451
  if worker_version is not None:
429
452
  # If worker_version=False, filter by manual worker_version e.g. None
430
- worker_version_id = worker_version if worker_version else None
431
- transcriptions = transcriptions.where(
432
- CachedTranscription.worker_version_id == worker_version_id
433
- )
453
+ worker_version_id = worker_version or None
454
+ if worker_version_id:
455
+ transcriptions = transcriptions.where(
456
+ CachedTranscription.worker_version_id == worker_version_id
457
+ )
458
+ else:
459
+ transcriptions = transcriptions.where(
460
+ CachedTranscription.worker_version_id.is_null()
461
+ )
462
+
463
+ if worker_run is not None:
464
+ # If worker_run=False, filter by manual worker_run e.g. None
465
+ worker_run_id = worker_run or None
466
+ if worker_run_id:
467
+ transcriptions = transcriptions.where(
468
+ CachedTranscription.worker_run_id == worker_run_id
469
+ )
470
+ else:
471
+ transcriptions = transcriptions.where(
472
+ CachedTranscription.worker_run_id.is_null()
473
+ )
434
474
  else:
435
475
  transcriptions = self.api_client.paginate(
436
476
  "ListTranscriptions", id=element.id, **query_params
arkindex_worker/worker/version.py CHANGED
@@ -2,10 +2,27 @@
2
2
  ElementsWorker methods for worker versions.
3
3
  """
4
4
 
5
+ import functools
6
+ from warnings import warn
7
+
8
+
9
+ def worker_version_deprecation(func):
10
+ @functools.wraps(func)
11
+ def wrapper(self, *args, **kwargs):
12
+ warn("WorkerVersion usage is deprecated.", DeprecationWarning, stacklevel=2)
13
+ return func(self, *args, **kwargs)
14
+
15
+ return wrapper
16
+
5
17
 
6
18
  class WorkerVersionMixin:
19
+ @worker_version_deprecation
7
20
  def get_worker_version(self, worker_version_id: str) -> dict:
8
21
  """
22
+ Warns:
23
+ ----
24
+ This method is **deprecated**.
25
+
9
26
  Retrieve a worker version, using the [ElementsWorker][arkindex_worker.worker.ElementsWorker]'s internal cache when possible.
10
27
 
11
28
  :param worker_version_id: ID of the worker version to retrieve.
@@ -22,8 +39,13 @@ class WorkerVersionMixin:
22
39
 
23
40
  return worker_version
24
41
 
42
+ @worker_version_deprecation
25
43
  def get_worker_version_slug(self, worker_version_id: str) -> str:
26
44
  """
45
+ Warns:
46
+ ----
47
+ This method is **deprecated**.
48
+
27
49
  Retrieve the slug of the worker of a worker version, from a worker version UUID.
28
50
  Uses a worker version from the internal cache if possible, otherwise makes an API request.
29
51
 
hooks/pre_gen_project.py ADDED
@@ -0,0 +1,3 @@
1
+ # Normalize the slug to generate __package and __module private variables
2
+ {{cookiecutter.update({"__package": cookiecutter.slug.lower().replace("_", "-")})}} # noqa: F821
3
+ {{cookiecutter.update({"__module": cookiecutter.slug.lower().replace("-", "_")})}} # noqa: F821
tests/conftest.py CHANGED
@@ -22,7 +22,7 @@ from arkindex_worker.cache import (
22
22
  create_version_table,
23
23
  init_cache_db,
24
24
  )
25
- from arkindex_worker.models import Artifact, Dataset
25
+ from arkindex_worker.models import Artifact, Dataset, Set
26
26
  from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker
27
27
  from arkindex_worker.worker.dataset import DatasetState
28
28
  from arkindex_worker.worker.transcription import TextOrientation
@@ -93,7 +93,7 @@ def _setup_api(responses, monkeypatch, _cache_yaml):
93
93
 
94
94
  # Fallback to prod environment
95
95
  if schema_url is None:
96
- schema_url = "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json"
96
+ schema_url = "https://arkindex.teklia.com/api/v1/openapi/?format=json"
97
97
  monkeypatch.setenv("ARKINDEX_API_SCHEMA_URL", schema_url)
98
98
 
99
99
  # Allow accessing remote API schemas
@@ -466,6 +466,7 @@ def _mock_cached_transcriptions(mock_cache_db):
466
466
  confidence=0.42,
467
467
  orientation=TextOrientation.HorizontalLeftToRight,
468
468
  worker_version_id=UUID("56785678-5678-5678-5678-567856785678"),
469
+ worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
469
470
  )
470
471
  CachedTranscription.create(
471
472
  id=UUID("22222222-2222-2222-2222-222222222222"),
@@ -506,6 +507,7 @@ def _mock_cached_transcriptions(mock_cache_db):
506
507
  confidence=0.42,
507
508
  orientation=TextOrientation.HorizontalLeftToRight,
508
509
  worker_version_id=None,
510
+ worker_run_id=None,
509
511
  )
510
512
 
511
513
 
@@ -595,11 +597,11 @@ def mock_databases(tmp_path):
595
597
  @pytest.fixture()
596
598
  def default_dataset():
597
599
  return Dataset(
598
- **{
600
+ {
599
601
  "id": "dataset_id",
600
602
  "name": "My dataset",
601
603
  "description": "A super dataset built by me",
602
- "sets": ["set_1", "set_2", "set_3"],
604
+ "sets": ["set_1", "set_2", "set_3", "set_4"],
603
605
  "state": DatasetState.Open.value,
604
606
  "corpus_id": "corpus_id",
605
607
  "creator": "creator@teklia.com",
@@ -610,6 +612,11 @@ def default_dataset():
610
612
  )
611
613
 
612
614
 
615
+ @pytest.fixture()
616
+ def default_train_set(default_dataset):
617
+ return Set(name="train", dataset=default_dataset)
618
+
619
+
613
620
  @pytest.fixture()
614
621
  def mock_dataset_worker(monkeypatch, mocker, _mock_worker_run_api):
615
622
  monkeypatch.setenv("PONOS_TASK", "my_task")
@@ -632,9 +639,10 @@ def mock_dev_dataset_worker(mocker):
632
639
  [
633
640
  "worker",
634
641
  "--dev",
635
- "--dataset",
636
- "11111111-1111-1111-1111-111111111111",
637
- "22222222-2222-2222-2222-222222222222",
642
+ "--set",
643
+ "11111111-1111-1111-1111-111111111111:train",
644
+ "11111111-1111-1111-1111-111111111111:val",
645
+ "22222222-2222-2222-2222-222222222222:my_set",
638
646
  ],
639
647
  )
640
648
 
tests/test_base_worker.py CHANGED
@@ -86,7 +86,6 @@ def test_cli_default(mocker):
86
86
  assert logger.level == logging.NOTSET
87
87
  assert worker.api_client
88
88
  assert worker.config == {"someKey": "someValue"} # from API
89
- assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
90
89
 
91
90
  logger.setLevel(logging.NOTSET)
92
91
 
@@ -106,7 +105,6 @@ def test_cli_arg_verbose_given(mocker):
106
105
  assert logger.level == logging.DEBUG
107
106
  assert worker.api_client
108
107
  assert worker.config == {"someKey": "someValue"} # from API
109
- assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
110
108
 
111
109
  logger.setLevel(logging.NOTSET)
112
110
 
@@ -126,7 +124,6 @@ def test_cli_envvar_debug_given(mocker, monkeypatch):
126
124
  assert logger.level == logging.DEBUG
127
125
  assert worker.api_client
128
126
  assert worker.config == {"someKey": "someValue"} # from API
129
- assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
130
127
 
131
128
  logger.setLevel(logging.NOTSET)
132
129
 
@@ -142,7 +139,6 @@ def test_configure_dev_mode(mocker):
142
139
 
143
140
  assert worker.args.dev is True
144
141
  assert worker.process_information is None
145
- assert worker.worker_run_id == "56785678-5678-5678-5678-567856785678"
146
142
  assert worker.is_read_only is True
147
143
  assert worker.user_configuration == {}
148
144
 
@@ -212,7 +208,6 @@ def test_configure_worker_run(mocker, responses, caplog):
212
208
  ("arkindex_worker", logging.INFO, "Loaded user configuration from WorkerRun"),
213
209
  ]
214
210
 
215
- assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
216
211
  assert worker.user_configuration == {"a": "b"}
217
212
 
218
213
 
@@ -482,7 +477,6 @@ def test_configure_load_model_configuration(mocker, responses):
482
477
 
483
478
  worker.configure()
484
479
 
485
- assert worker.worker_version_id == "12341234-1234-1234-1234-123412341234"
486
480
  assert worker.model_configuration == {
487
481
  "param1": "value1",
488
482
  "param2": 2,