arkindex-base-worker 0.4.0b2__py3-none-any.whl → 0.4.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ ElementsWorker methods for elements and element types.
3
3
  """
4
4
 
5
5
  from collections.abc import Iterable
6
+ from operator import attrgetter
6
7
  from typing import NamedTuple
7
8
  from uuid import UUID
8
9
  from warnings import warn
@@ -346,6 +347,52 @@ class ElementMixin:
346
347
  child=child.id,
347
348
  )
348
349
 
350
@unsupported_cache
@batch_publication
def create_element_children(
    self,
    parent: Element,
    children: list[Element],
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> list[str] | None:
    """
    Link multiple elements to a single parent through the API.

    The children are split into batches of at most ``batch_size`` elements.
    Each batch is sent to the ``CreateElementChildren`` endpoint. The child IDs
    returned by every call are concatenated in batch order.

    :param parent: Parent element.
    :param children: A non-empty list of child elements.
    :param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.

    :raises AssertionError: When ``parent`` is missing or not an Element, when ``children``
        is empty or not a list, or when any child is not an Element.
    :returns: A list containing the string UUID of each child linked to the parent,
        or None when the worker is in read-only mode (no API request is made then).
    """
    assert parent and isinstance(
        parent, Element
    ), "parent shouldn't be null and should be of type Element"

    assert children and isinstance(
        children, list
    ), "children shouldn't be null and should be of type list"

    for index, child in enumerate(children):
        assert isinstance(
            child, Element
        ), f"Child at index {index} in children: Should be of type Element"

    if self.is_read_only:
        logger.warning("Cannot link elements as this worker is in read-only mode")
        # Explicit None: nothing was linked, so there are no IDs to report.
        return None

    linked_ids: list[str] = []
    for batch in make_batches(children, "child", batch_size):
        response = self.api_client.request(
            "CreateElementChildren",
            id=parent.id,
            body={
                "children": list(map(attrgetter("id"), batch)),
            },
        )
        # Each response lists the IDs of the children linked by this batch.
        linked_ids.extend(response["children"])
    return linked_ids
395
+
349
396
  def partial_update_element(
350
397
  self, element: Element | CachedElement, **kwargs
351
398
  ) -> dict:
@@ -436,6 +483,178 @@ class ElementMixin:
436
483
 
437
484
  return updated_element
438
485
 
486
def list_elements(
    self,
    folder: bool | None = None,
    name: str | None = None,
    top_level: bool | None = None,
    transcription_worker_version: str | bool | None = None,
    transcription_worker_run: str | bool | None = None,
    type: str | None = None,
    with_classes: bool | None = None,
    with_corpus: bool | None = None,
    with_metadata: bool | None = None,
    with_has_children: bool | None = None,
    with_zone: bool | None = None,
    worker_version: str | bool | None = None,
    worker_run: str | bool | None = None,
) -> Iterable[dict] | Iterable[CachedElement]:
    """
    List elements in the worker's corpus.

    Without the local cache, this paginates the ``ListElements`` API endpoint
    for ``self.corpus_id``. With the local cache enabled, it returns a query on
    [CachedElement][arkindex_worker.cache.CachedElement] instead. Only the
    ``type``, ``worker_version`` and ``worker_run`` filters are supported in that mode.

    Warns:
    ----
    The following parameters are **deprecated**:

    - `transcription_worker_version` in favor of `transcription_worker_run`
    - `worker_version` in favor of `worker_run`

    :param folder: Restrict to or exclude elements with folder types.
        This parameter is not supported when caching is enabled.
    :param name: Restrict to elements whose name contain a substring (case-insensitive).
        This parameter is not supported when caching is enabled.
    :param top_level: Restrict to or exclude folder elements without parent elements (top-level elements).
        This parameter is not supported when caching is enabled.
    :param transcription_worker_version: **Deprecated** Restrict to elements that have a transcription created by a worker version with this UUID. Set to False to look for elements that have a manual transcription.
        This parameter is not supported when caching is enabled.
    :param transcription_worker_run: Restrict to elements that have a transcription created by a worker run with this UUID. Set to False to look for elements that have a manual transcription.
        This parameter is not supported when caching is enabled.
    :param type: Restrict to elements with a specific type slug.
        This parameter is also supported when caching is enabled.
    :param with_classes: Include each element's classifications in the response.
        This parameter is not supported when caching is enabled.
    :param with_corpus: Include each element's corpus in the response.
        This parameter is not supported when caching is enabled.
    :param with_has_children: Include the ``has_children`` attribute in the response,
        indicating if this element has child elements of its own.
        This parameter is not supported when caching is enabled.
    :param with_metadata: Include each element's metadata in the response.
        This parameter is not supported when caching is enabled.
    :param with_zone: Include the ``zone`` attribute in the response,
        holding the element's image and polygon.
        This parameter is not supported when caching is enabled.
    :param worker_version: **Deprecated** Restrict to elements created by a worker version with this UUID.
        Set to False to restrict to elements without a worker version (manual elements).
    :param worker_run: Restrict to elements created by a worker run with this UUID.
        Set to False to restrict to elements without a worker run (manual elements).
    :raises AssertionError: When a parameter has an invalid type or value, or when a
        filter that the local cache cannot handle is used while caching is enabled.
    :return: An iterable of dicts from the ``ListElements`` API endpoint,
        or an iterable of [CachedElement][arkindex_worker.cache.CachedElement] when caching is enabled.
    """
    # Validate each filter and collect the ones that were actually provided.
    query_params = {}
    if folder is not None:
        assert isinstance(folder, bool), "folder should be of type bool"
        query_params["folder"] = folder
    if name:
        assert isinstance(name, str), "name should be of type str"
        query_params["name"] = name
    if top_level is not None:
        assert isinstance(top_level, bool), "top_level should be of type bool"
        query_params["top_level"] = top_level
    if transcription_worker_version is not None:
        warn(
            "`transcription_worker_version` usage is deprecated. Consider using `transcription_worker_run` instead.",
            DeprecationWarning,
            stacklevel=1,
        )
        assert isinstance(
            transcription_worker_version, str | bool
        ), "transcription_worker_version should be of type str or bool"
        if isinstance(transcription_worker_version, bool):
            # True has no meaning here; only False (manual) is accepted.
            assert (
                transcription_worker_version is False
            ), "if of type bool, transcription_worker_version can only be set to False"
        query_params["transcription_worker_version"] = transcription_worker_version
    if transcription_worker_run is not None:
        assert isinstance(
            transcription_worker_run, str | bool
        ), "transcription_worker_run should be of type str or bool"
        if isinstance(transcription_worker_run, bool):
            assert (
                transcription_worker_run is False
            ), "if of type bool, transcription_worker_run can only be set to False"
        query_params["transcription_worker_run"] = transcription_worker_run
    if type:
        assert isinstance(type, str), "type should be of type str"
        query_params["type"] = type
    if with_classes is not None:
        assert isinstance(with_classes, bool), "with_classes should be of type bool"
        query_params["with_classes"] = with_classes
    if with_corpus is not None:
        assert isinstance(with_corpus, bool), "with_corpus should be of type bool"
        query_params["with_corpus"] = with_corpus
    if with_has_children is not None:
        assert isinstance(
            with_has_children, bool
        ), "with_has_children should be of type bool"
        query_params["with_has_children"] = with_has_children
    if with_metadata is not None:
        assert isinstance(
            with_metadata, bool
        ), "with_metadata should be of type bool"
        query_params["with_metadata"] = with_metadata
    if with_zone is not None:
        assert isinstance(with_zone, bool), "with_zone should be of type bool"
        query_params["with_zone"] = with_zone
    if worker_version is not None:
        warn(
            "`worker_version` usage is deprecated. Consider using `worker_run` instead.",
            DeprecationWarning,
            stacklevel=1,
        )
        assert isinstance(
            worker_version, str | bool
        ), "worker_version should be of type str or bool"
        if isinstance(worker_version, bool):
            assert (
                worker_version is False
            ), "if of type bool, worker_version can only be set to False"
        query_params["worker_version"] = worker_version
    if worker_run is not None:
        assert isinstance(
            worker_run, str | bool
        ), "worker_run should be of type str or bool"
        if isinstance(worker_run, bool):
            assert (
                worker_run is False
            ), "if of type bool, worker_run can only be set to False"
        query_params["worker_run"] = worker_run

    if not self.use_cache:
        return self.api_client.paginate(
            "ListElements", corpus=self.corpus_id, **query_params
        )

    # Checking that we only received query_params handled by the cache
    assert (
        set(query_params.keys())
        <= {
            "type",
            "worker_version",
            "worker_run",
        }
    ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"

    # NOTE(review): unlike the API call, this query is not restricted to
    # self.corpus_id; presumably the local cache only holds elements of the
    # current corpus — confirm.
    query = CachedElement.select()
    if type:
        query = query.where(CachedElement.type == type)
    if worker_version is not None:
        # If worker_version=False, filter by manual worker_version e.g. None
        worker_version_id = worker_version or None
        if worker_version_id:
            query = query.where(
                CachedElement.worker_version_id == worker_version_id
            )
        else:
            query = query.where(CachedElement.worker_version_id.is_null())

    if worker_run is not None:
        # If worker_run=False, filter by manual worker_run e.g. None
        worker_run_id = worker_run or None
        if worker_run_id:
            query = query.where(CachedElement.worker_run_id == worker_run_id)
        else:
            query = query.where(CachedElement.worker_run_id.is_null())

    return query
657
+
439
658
  def list_element_children(
440
659
  self,
441
660
  element: Element | CachedElement,
@@ -575,45 +794,43 @@ class ElementMixin:
575
794
  ), "if of type bool, worker_run can only be set to False"
576
795
  query_params["worker_run"] = worker_run
577
796
 
578
- if self.use_cache:
579
- # Checking that we only received query_params handled by the cache
580
- assert (
581
- set(query_params.keys())
582
- <= {
583
- "type",
584
- "worker_version",
585
- "worker_run",
586
- }
587
- ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
588
-
589
- query = CachedElement.select().where(CachedElement.parent_id == element.id)
590
- if type:
591
- query = query.where(CachedElement.type == type)
592
- if worker_version is not None:
593
- # If worker_version=False, filter by manual worker_version e.g. None
594
- worker_version_id = worker_version or None
595
- if worker_version_id:
596
- query = query.where(
597
- CachedElement.worker_version_id == worker_version_id
598
- )
599
- else:
600
- query = query.where(CachedElement.worker_version_id.is_null())
601
-
602
- if worker_run is not None:
603
- # If worker_run=False, filter by manual worker_run e.g. None
604
- worker_run_id = worker_run or None
605
- if worker_run_id:
606
- query = query.where(CachedElement.worker_run_id == worker_run_id)
607
- else:
608
- query = query.where(CachedElement.worker_run_id.is_null())
609
-
610
- return query
611
- else:
612
- children = self.api_client.paginate(
797
+ if not self.use_cache:
798
+ return self.api_client.paginate(
613
799
  "ListElementChildren", id=element.id, **query_params
614
800
  )
615
801
 
616
- return children
802
+ # Checking that we only received query_params handled by the cache
803
+ assert (
804
+ set(query_params.keys())
805
+ <= {
806
+ "type",
807
+ "worker_version",
808
+ "worker_run",
809
+ }
810
+ ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
811
+
812
+ query = CachedElement.select().where(CachedElement.parent_id == element.id)
813
+ if type:
814
+ query = query.where(CachedElement.type == type)
815
+ if worker_version is not None:
816
+ # If worker_version=False, filter by manual worker_version e.g. None
817
+ worker_version_id = worker_version or None
818
+ if worker_version_id:
819
+ query = query.where(
820
+ CachedElement.worker_version_id == worker_version_id
821
+ )
822
+ else:
823
+ query = query.where(CachedElement.worker_version_id.is_null())
824
+
825
+ if worker_run is not None:
826
+ # If worker_run=False, filter by manual worker_run e.g. None
827
+ worker_run_id = worker_run or None
828
+ if worker_run_id:
829
+ query = query.where(CachedElement.worker_run_id == worker_run_id)
830
+ else:
831
+ query = query.where(CachedElement.worker_run_id.is_null())
832
+
833
+ return query
617
834
 
618
835
  def list_element_parents(
619
836
  self,
@@ -754,45 +971,43 @@ class ElementMixin:
754
971
  ), "if of type bool, worker_run can only be set to False"
755
972
  query_params["worker_run"] = worker_run
756
973
 
757
- if self.use_cache:
758
- # Checking that we only received query_params handled by the cache
759
- assert (
760
- set(query_params.keys())
761
- <= {
762
- "type",
763
- "worker_version",
764
- "worker_run",
765
- }
766
- ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
767
-
768
- parent_ids = CachedElement.select(CachedElement.parent_id).where(
769
- CachedElement.id == element.id
770
- )
771
- query = CachedElement.select().where(CachedElement.id.in_(parent_ids))
772
- if type:
773
- query = query.where(CachedElement.type == type)
774
- if worker_version is not None:
775
- # If worker_version=False, filter by manual worker_version e.g. None
776
- worker_version_id = worker_version or None
777
- if worker_version_id:
778
- query = query.where(
779
- CachedElement.worker_version_id == worker_version_id
780
- )
781
- else:
782
- query = query.where(CachedElement.worker_version_id.is_null())
783
-
784
- if worker_run is not None:
785
- # If worker_run=False, filter by manual worker_run e.g. None
786
- worker_run_id = worker_run or None
787
- if worker_run_id:
788
- query = query.where(CachedElement.worker_run_id == worker_run_id)
789
- else:
790
- query = query.where(CachedElement.worker_run_id.is_null())
791
-
792
- return query
793
- else:
794
- parents = self.api_client.paginate(
974
+ if not self.use_cache:
975
+ return self.api_client.paginate(
795
976
  "ListElementParents", id=element.id, **query_params
796
977
  )
797
978
 
798
- return parents
979
+ # Checking that we only received query_params handled by the cache
980
+ assert (
981
+ set(query_params.keys())
982
+ <= {
983
+ "type",
984
+ "worker_version",
985
+ "worker_run",
986
+ }
987
+ ), "When using the local cache, you can only filter by 'type' and/or 'worker_version' and/or 'worker_run'"
988
+
989
+ parent_ids = CachedElement.select(CachedElement.parent_id).where(
990
+ CachedElement.id == element.id
991
+ )
992
+ query = CachedElement.select().where(CachedElement.id.in_(parent_ids))
993
+ if type:
994
+ query = query.where(CachedElement.type == type)
995
+ if worker_version is not None:
996
+ # If worker_version=False, filter by manual worker_version e.g. None
997
+ worker_version_id = worker_version or None
998
+ if worker_version_id:
999
+ query = query.where(
1000
+ CachedElement.worker_version_id == worker_version_id
1001
+ )
1002
+ else:
1003
+ query = query.where(CachedElement.worker_version_id.is_null())
1004
+
1005
+ if worker_run is not None:
1006
+ # If worker_run=False, filter by manual worker_run e.g. None
1007
+ worker_run_id = worker_run or None
1008
+ if worker_run_id:
1009
+ query = query.where(CachedElement.worker_run_id == worker_run_id)
1010
+ else:
1011
+ query = query.where(CachedElement.worker_run_id.is_null())
1012
+
1013
+ return query
@@ -441,60 +441,60 @@ class TranscriptionMixin:
441
441
  ), "if of type bool, worker_run can only be set to False"
442
442
  query_params["worker_run"] = worker_run
443
443
 
444
- if self.use_cache:
445
- if not recursive:
446
- # In this case we don't have to return anything, it's easier to use an
447
- # impossible condition (False) rather than filtering by type for nothing
448
- if element_type and element_type != element.type:
449
- return CachedTranscription.select().where(False)
450
- transcriptions = CachedTranscription.select().where(
451
- CachedTranscription.element_id == element.id
444
+ if not self.use_cache:
445
+ return self.api_client.paginate(
446
+ "ListTranscriptions", id=element.id, **query_params
447
+ )
448
+
449
+ if not recursive:
450
+ # In this case we don't have to return anything, it's easier to use an
451
+ # impossible condition (False) rather than filtering by type for nothing
452
+ if element_type and element_type != element.type:
453
+ return CachedTranscription.select().where(False)
454
+ transcriptions = CachedTranscription.select().where(
455
+ CachedTranscription.element_id == element.id
456
+ )
457
+ else:
458
+ base_case = (
459
+ CachedElement.select()
460
+ .where(CachedElement.id == element.id)
461
+ .cte("base", recursive=True)
462
+ )
463
+ recursive = CachedElement.select().join(
464
+ base_case, on=(CachedElement.parent_id == base_case.c.id)
465
+ )
466
+ cte = base_case.union_all(recursive)
467
+ transcriptions = (
468
+ CachedTranscription.select()
469
+ .join(cte, on=(CachedTranscription.element_id == cte.c.id))
470
+ .with_cte(cte)
471
+ )
472
+
473
+ if element_type:
474
+ transcriptions = transcriptions.where(cte.c.type == element_type)
475
+
476
+ if worker_version is not None:
477
+ # If worker_version=False, filter by manual worker_version e.g. None
478
+ worker_version_id = worker_version or None
479
+ if worker_version_id:
480
+ transcriptions = transcriptions.where(
481
+ CachedTranscription.worker_version_id == worker_version_id
452
482
  )
453
483
  else:
454
- base_case = (
455
- CachedElement.select()
456
- .where(CachedElement.id == element.id)
457
- .cte("base", recursive=True)
484
+ transcriptions = transcriptions.where(
485
+ CachedTranscription.worker_version_id.is_null()
458
486
  )
459
- recursive = CachedElement.select().join(
460
- base_case, on=(CachedElement.parent_id == base_case.c.id)
487
+
488
+ if worker_run is not None:
489
+ # If worker_run=False, filter by manual worker_run e.g. None
490
+ worker_run_id = worker_run or None
491
+ if worker_run_id:
492
+ transcriptions = transcriptions.where(
493
+ CachedTranscription.worker_run_id == worker_run_id
461
494
  )
462
- cte = base_case.union_all(recursive)
463
- transcriptions = (
464
- CachedTranscription.select()
465
- .join(cte, on=(CachedTranscription.element_id == cte.c.id))
466
- .with_cte(cte)
495
+ else:
496
+ transcriptions = transcriptions.where(
497
+ CachedTranscription.worker_run_id.is_null()
467
498
  )
468
499
 
469
- if element_type:
470
- transcriptions = transcriptions.where(cte.c.type == element_type)
471
-
472
- if worker_version is not None:
473
- # If worker_version=False, filter by manual worker_version e.g. None
474
- worker_version_id = worker_version or None
475
- if worker_version_id:
476
- transcriptions = transcriptions.where(
477
- CachedTranscription.worker_version_id == worker_version_id
478
- )
479
- else:
480
- transcriptions = transcriptions.where(
481
- CachedTranscription.worker_version_id.is_null()
482
- )
483
-
484
- if worker_run is not None:
485
- # If worker_run=False, filter by manual worker_run e.g. None
486
- worker_run_id = worker_run or None
487
- if worker_run_id:
488
- transcriptions = transcriptions.where(
489
- CachedTranscription.worker_run_id == worker_run_id
490
- )
491
- else:
492
- transcriptions = transcriptions.where(
493
- CachedTranscription.worker_run_id.is_null()
494
- )
495
- else:
496
- transcriptions = self.api_client.paginate(
497
- "ListTranscriptions", id=element.id, **query_params
498
- )
499
-
500
500
  return transcriptions
tests/conftest.py CHANGED
@@ -26,7 +26,7 @@ from arkindex_worker.models import Artifact, Dataset, Set
26
26
  from arkindex_worker.worker import BaseWorker, DatasetWorker, ElementsWorker
27
27
  from arkindex_worker.worker.dataset import DatasetState
28
28
  from arkindex_worker.worker.transcription import TextOrientation
29
- from tests import CORPUS_ID, FIXTURES_DIR, PROCESS_ID, SAMPLES_DIR
29
+ from tests import CORPUS_ID, PROCESS_ID, SAMPLES_DIR
30
30
 
31
31
  __yaml_cache = {}
32
32
 
@@ -277,9 +277,7 @@ def mock_elements_worker_with_list(monkeypatch, responses, mock_elements_worker)
277
277
  """
278
278
  Mock a worker instance to list and retrieve a single element
279
279
  """
280
- monkeypatch.setattr(
281
- mock_elements_worker, "list_elements", lambda: ["1234-deadbeef"]
282
- )
280
+ monkeypatch.setattr(mock_elements_worker, "get_elements", lambda: ["1234-deadbeef"])
283
281
  responses.add(
284
282
  responses.GET,
285
283
  "http://testserver/api/v1/element/1234-deadbeef/",
@@ -326,23 +324,6 @@ def mock_elements_worker_with_cache(monkeypatch, mock_cache_db, _mock_worker_run
326
324
  return worker
327
325
 
328
326
 
329
- @pytest.fixture()
330
- def fake_page_element():
331
- return json.loads((FIXTURES_DIR / "page_element.json").read_text())
332
-
333
-
334
- @pytest.fixture()
335
- def fake_ufcn_worker_version():
336
- return json.loads(
337
- (FIXTURES_DIR / "ufcn_line_historical_worker_version.json").read_text()
338
- )
339
-
340
-
341
- @pytest.fixture()
342
- def fake_transcriptions_small():
343
- return json.loads((FIXTURES_DIR / "line_transcriptions_small.json").read_text())
344
-
345
-
346
327
@pytest.fixture()
def model_file_dir():
    """Provide the ``model_files`` sub-directory of ``SAMPLES_DIR`` used by tests."""
    sample_models = SAMPLES_DIR / "model_files"
    return sample_models