arkindex-base-worker 0.3.7rc4__py3-none-any.whl → 0.5.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
  2. arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +1 -1
  6. arkindex_worker/image.py +167 -2
  7. arkindex_worker/models.py +18 -0
  8. arkindex_worker/utils.py +98 -4
  9. arkindex_worker/worker/__init__.py +117 -218
  10. arkindex_worker/worker/base.py +39 -46
  11. arkindex_worker/worker/classification.py +45 -29
  12. arkindex_worker/worker/corpus.py +86 -0
  13. arkindex_worker/worker/dataset.py +89 -26
  14. arkindex_worker/worker/element.py +352 -91
  15. arkindex_worker/worker/entity.py +13 -11
  16. arkindex_worker/worker/image.py +21 -0
  17. arkindex_worker/worker/metadata.py +26 -16
  18. arkindex_worker/worker/process.py +92 -0
  19. arkindex_worker/worker/task.py +5 -4
  20. arkindex_worker/worker/training.py +25 -10
  21. arkindex_worker/worker/transcription.py +89 -68
  22. arkindex_worker/worker/version.py +3 -1
  23. hooks/pre_gen_project.py +3 -0
  24. tests/__init__.py +8 -0
  25. tests/conftest.py +47 -58
  26. tests/test_base_worker.py +212 -12
  27. tests/test_dataset_worker.py +294 -437
  28. tests/test_elements_worker/{test_classifications.py → test_classification.py} +313 -200
  29. tests/test_elements_worker/test_cli.py +3 -11
  30. tests/test_elements_worker/test_corpus.py +168 -0
  31. tests/test_elements_worker/test_dataset.py +106 -157
  32. tests/test_elements_worker/test_element.py +427 -0
  33. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  34. tests/test_elements_worker/test_element_create_single.py +528 -0
  35. tests/test_elements_worker/test_element_list_children.py +969 -0
  36. tests/test_elements_worker/test_element_list_parents.py +530 -0
  37. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  38. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  39. tests/test_elements_worker/test_image.py +66 -0
  40. tests/test_elements_worker/test_metadata.py +252 -161
  41. tests/test_elements_worker/test_process.py +89 -0
  42. tests/test_elements_worker/test_task.py +8 -18
  43. tests/test_elements_worker/test_training.py +17 -8
  44. tests/test_elements_worker/test_transcription_create.py +873 -0
  45. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  46. tests/test_elements_worker/test_transcription_list.py +450 -0
  47. tests/test_elements_worker/test_version.py +60 -0
  48. tests/test_elements_worker/test_worker.py +578 -293
  49. tests/test_image.py +542 -209
  50. tests/test_merge.py +1 -2
  51. tests/test_utils.py +89 -4
  52. worker-demo/tests/__init__.py +0 -0
  53. worker-demo/tests/conftest.py +32 -0
  54. worker-demo/tests/test_worker.py +12 -0
  55. worker-demo/worker_demo/__init__.py +6 -0
  56. worker-demo/worker_demo/worker.py +19 -0
  57. arkindex_base_worker-0.3.7rc4.dist-info/RECORD +0 -41
  58. tests/test_elements_worker/test_elements.py +0 -2713
  59. tests/test_elements_worker/test_transcriptions.py +0 -2119
  60. {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
@@ -1,12 +1,14 @@
1
1
  import json
2
2
  import re
3
- from uuid import UUID, uuid4
3
+ from uuid import UUID
4
4
 
5
5
  import pytest
6
- from apistar.exceptions import ErrorResponse
7
6
 
7
+ from arkindex.exceptions import ErrorResponse
8
8
  from arkindex_worker.cache import CachedClassification, CachedElement
9
9
  from arkindex_worker.models import Element
10
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
11
+ from tests import CORPUS_ID
10
12
 
11
13
  from . import BASE_API_CALLS
12
14
 
@@ -15,11 +17,96 @@ from . import BASE_API_CALLS
15
17
  DELETE_PARAMETER = "DELETE_PARAMETER"
16
18
 
17
19
 
20
+ def test_load_corpus_classes_api_error(responses, mock_elements_worker):
21
+ responses.add(
22
+ responses.GET,
23
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
24
+ status=418,
25
+ )
26
+
27
+ assert not mock_elements_worker.classes
28
+ with pytest.raises(
29
+ Exception, match="Stopping pagination as data will be incomplete"
30
+ ):
31
+ mock_elements_worker.load_corpus_classes()
32
+
33
+ assert len(responses.calls) == len(BASE_API_CALLS) + 5
34
+ assert [
35
+ (call.request.method, call.request.url) for call in responses.calls
36
+ ] == BASE_API_CALLS + [
37
+ # We do 5 retries
38
+ (
39
+ "GET",
40
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
41
+ ),
42
+ (
43
+ "GET",
44
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
45
+ ),
46
+ (
47
+ "GET",
48
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
49
+ ),
50
+ (
51
+ "GET",
52
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
53
+ ),
54
+ (
55
+ "GET",
56
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
57
+ ),
58
+ ]
59
+ assert not mock_elements_worker.classes
60
+
61
+
62
+ def test_load_corpus_classes(responses, mock_elements_worker):
63
+ responses.add(
64
+ responses.GET,
65
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
66
+ status=200,
67
+ json={
68
+ "count": 3,
69
+ "next": None,
70
+ "results": [
71
+ {
72
+ "id": "0000",
73
+ "name": "good",
74
+ },
75
+ {
76
+ "id": "1111",
77
+ "name": "average",
78
+ },
79
+ {
80
+ "id": "2222",
81
+ "name": "bad",
82
+ },
83
+ ],
84
+ },
85
+ )
86
+
87
+ assert not mock_elements_worker.classes
88
+ mock_elements_worker.load_corpus_classes()
89
+
90
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
91
+ assert [
92
+ (call.request.method, call.request.url) for call in responses.calls
93
+ ] == BASE_API_CALLS + [
94
+ (
95
+ "GET",
96
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
97
+ ),
98
+ ]
99
+ assert mock_elements_worker.classes == {
100
+ "good": "0000",
101
+ "average": "1111",
102
+ "bad": "2222",
103
+ }
104
+
105
+
18
106
  def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
19
- corpus_id = "11111111-1111-1111-1111-111111111111"
20
107
  responses.add(
21
108
  responses.GET,
22
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
109
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
23
110
  status=200,
24
111
  json={
25
112
  "count": 1,
@@ -42,7 +129,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
42
129
  ] == BASE_API_CALLS + [
43
130
  (
44
131
  "GET",
45
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
132
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
46
133
  ),
47
134
  ]
48
135
  assert mock_elements_worker.classes == {"good": "0000"}
@@ -51,12 +138,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
51
138
 
52
139
  def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
53
140
  # A missing class is now created automatically
54
- corpus_id = "11111111-1111-1111-1111-111111111111"
55
141
  mock_elements_worker.classes = {"good": "0000"}
56
142
 
57
143
  responses.add(
58
144
  responses.POST,
59
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
145
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
60
146
  status=201,
61
147
  json={"id": "new-ml-class-1234"},
62
148
  )
@@ -82,12 +168,10 @@ def test_get_ml_class_id(mock_elements_worker):
82
168
 
83
169
 
84
170
  def test_get_ml_class_reload(responses, mock_elements_worker):
85
- corpus_id = "11111111-1111-1111-1111-111111111111"
86
-
87
171
  # Add some initial classes
88
172
  responses.add(
89
173
  responses.GET,
90
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
174
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
91
175
  json={
92
176
  "count": 1,
93
177
  "next": None,
@@ -103,7 +187,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
103
187
  # Invalid response when trying to create class2
104
188
  responses.add(
105
189
  responses.POST,
106
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
190
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
107
191
  status=400,
108
192
  json={"non_field_errors": "Already exists"},
109
193
  )
@@ -111,7 +195,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
111
195
  # Add both classes (class2 is created by another process)
112
196
  responses.add(
113
197
  responses.GET,
114
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
198
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
115
199
  json={
116
200
  "count": 2,
117
201
  "next": None,
@@ -141,15 +225,15 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
141
225
  ] == BASE_API_CALLS + [
142
226
  (
143
227
  "GET",
144
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
228
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
145
229
  ),
146
230
  (
147
231
  "POST",
148
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
232
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
149
233
  ),
150
234
  (
151
235
  "GET",
152
- f"http://testserver/api/v1/corpus/{corpus_id}/classes/",
236
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
153
237
  ),
154
238
  ]
155
239
 
@@ -169,7 +253,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
169
253
  """
170
254
  responses.add(
171
255
  responses.GET,
172
- f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
256
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
173
257
  status=200,
174
258
  json={
175
259
  "count": 1,
@@ -189,7 +273,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
189
273
  ] == BASE_API_CALLS + [
190
274
  (
191
275
  "GET",
192
- f"http://testserver/api/v1/corpus/{mock_elements_worker.corpus_id}/classes/",
276
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
193
277
  ),
194
278
  ]
195
279
 
@@ -276,7 +360,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
276
360
  responses.add(
277
361
  responses.POST,
278
362
  "http://testserver/api/v1/classifications/",
279
- status=500,
363
+ status=418,
280
364
  )
281
365
 
282
366
  with pytest.raises(ErrorResponse):
@@ -287,17 +371,10 @@ def test_create_classification_api_error(responses, mock_elements_worker):
287
371
  high_confidence=True,
288
372
  )
289
373
 
290
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
374
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
291
375
  assert [
292
376
  (call.request.method, call.request.url) for call in responses.calls
293
- ] == BASE_API_CALLS + [
294
- # We retry 5 times the API call
295
- ("POST", "http://testserver/api/v1/classifications/"),
296
- ("POST", "http://testserver/api/v1/classifications/"),
297
- ("POST", "http://testserver/api/v1/classifications/"),
298
- ("POST", "http://testserver/api/v1/classifications/"),
299
- ("POST", "http://testserver/api/v1/classifications/"),
300
- ]
377
+ ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classifications/")]
301
378
 
302
379
 
303
380
  def test_create_classification_create_ml_class(mock_elements_worker, responses):
@@ -306,7 +383,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
306
383
  # Automatically create a missing class!
307
384
  responses.add(
308
385
  responses.POST,
309
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
386
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
310
387
  status=201,
311
388
  json={"id": "new-ml-class-1234"},
312
389
  )
@@ -325,15 +402,12 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
325
402
  )
326
403
 
327
404
  # Check a class & classification has been created
328
- for call in responses.calls:
329
- print(call.request.url, call.request.body)
330
-
331
405
  assert [
332
406
  (call.request.url, json.loads(call.request.body))
333
407
  for call in responses.calls[-2:]
334
408
  ] == [
335
409
  (
336
- "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
410
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
337
411
  {"name": "a_class"},
338
412
  ),
339
413
  (
@@ -506,12 +580,12 @@ def test_create_classifications_wrong_data(
506
580
  "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
507
581
  "classifications": [
508
582
  {
509
- "ml_class_id": "uuid1",
583
+ "ml_class": "cat",
510
584
  "confidence": 0.75,
511
585
  "high_confidence": False,
512
586
  },
513
587
  {
514
- "ml_class_id": "uuid2",
588
+ "ml_class": "dog",
515
589
  "confidence": 0.25,
516
590
  "high_confidence": False,
517
591
  },
@@ -523,86 +597,71 @@ def test_create_classifications_wrong_data(
523
597
 
524
598
 
525
599
  @pytest.mark.parametrize(
526
- ("arg_name", "data", "error_message", "error_type"),
600
+ ("arg_name", "data", "error_message"),
527
601
  [
528
- # Wrong classifications > ml_class_id
602
+ # Wrong classifications > ml_class
529
603
  (
530
- "ml_class_id",
604
+ "ml_class",
531
605
  DELETE_PARAMETER,
532
- "ml_class_id shouldn't be null and should be of type str",
533
- AssertionError,
534
- ), # Updated
606
+ "ml_class shouldn't be null and should be of type str",
607
+ ),
535
608
  (
536
- "ml_class_id",
609
+ "ml_class",
537
610
  None,
538
- "ml_class_id shouldn't be null and should be of type str",
539
- AssertionError,
611
+ "ml_class shouldn't be null and should be of type str",
540
612
  ),
541
613
  (
542
- "ml_class_id",
614
+ "ml_class",
543
615
  1234,
544
- "ml_class_id shouldn't be null and should be of type str",
545
- AssertionError,
546
- ),
547
- (
548
- "ml_class_id",
549
- "not_an_uuid",
550
- "ml_class_id is not a valid uuid.",
551
- ValueError,
616
+ "ml_class shouldn't be null and should be of type str",
552
617
  ),
553
618
  # Wrong classifications > confidence
554
619
  (
555
620
  "confidence",
556
621
  DELETE_PARAMETER,
557
622
  "confidence shouldn't be null and should be a float in [0..1] range",
558
- AssertionError,
559
623
  ),
560
624
  (
561
625
  "confidence",
562
626
  None,
563
627
  "confidence shouldn't be null and should be a float in [0..1] range",
564
- AssertionError,
565
628
  ),
566
629
  (
567
630
  "confidence",
568
631
  "wrong confidence",
569
632
  "confidence shouldn't be null and should be a float in [0..1] range",
570
- AssertionError,
571
633
  ),
572
634
  (
573
635
  "confidence",
574
636
  0,
575
637
  "confidence shouldn't be null and should be a float in [0..1] range",
576
- AssertionError,
577
638
  ),
578
639
  (
579
640
  "confidence",
580
641
  2.00,
581
642
  "confidence shouldn't be null and should be a float in [0..1] range",
582
- AssertionError,
583
643
  ),
584
644
  # Wrong classifications > high_confidence
585
645
  (
586
646
  "high_confidence",
587
647
  "wrong high_confidence",
588
648
  "high_confidence should be of type bool",
589
- AssertionError,
590
649
  ),
591
650
  ],
592
651
  )
593
652
  def test_create_classifications_wrong_classifications_data(
594
- arg_name, data, error_message, error_type, mock_elements_worker
653
+ arg_name, data, error_message, mock_elements_worker
595
654
  ):
596
655
  all_data = {
597
656
  "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
598
657
  "classifications": [
599
658
  {
600
- "ml_class_id": str(uuid4()),
659
+ "ml_class": "cat",
601
660
  "confidence": 0.75,
602
661
  "high_confidence": False,
603
662
  },
604
663
  {
605
- "ml_class_id": str(uuid4()),
664
+ "ml_class": "dog",
606
665
  "confidence": 0.25,
607
666
  "high_confidence": False,
608
667
  # Overwrite with wrong data
@@ -614,7 +673,7 @@ def test_create_classifications_wrong_classifications_data(
614
673
  del all_data["classifications"][1][arg_name]
615
674
 
616
675
  with pytest.raises(
617
- error_type,
676
+ AssertionError,
618
677
  match=re.escape(
619
678
  f"Classification at index 1 in classifications: {error_message}"
620
679
  ),
@@ -623,20 +682,21 @@ def test_create_classifications_wrong_classifications_data(
623
682
 
624
683
 
625
684
  def test_create_classifications_api_error(responses, mock_elements_worker):
685
+ mock_elements_worker.classes = {"cat": "0000", "dog": "1111"}
626
686
  responses.add(
627
687
  responses.POST,
628
688
  "http://testserver/api/v1/classification/bulk/",
629
- status=500,
689
+ status=418,
630
690
  )
631
691
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
632
692
  classes = [
633
693
  {
634
- "ml_class_id": str(uuid4()),
694
+ "ml_class": "cat",
635
695
  "confidence": 0.75,
636
696
  "high_confidence": False,
637
697
  },
638
698
  {
639
- "ml_class_id": str(uuid4()),
699
+ "ml_class": "dog",
640
700
  "confidence": 0.25,
641
701
  "high_confidence": False,
642
702
  },
@@ -647,192 +707,245 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
647
707
  element=elt, classifications=classes
648
708
  )
649
709
 
650
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
710
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
651
711
  assert [
652
712
  (call.request.method, call.request.url) for call in responses.calls
653
- ] == BASE_API_CALLS + [
654
- # We retry 5 times the API call
655
- ("POST", "http://testserver/api/v1/classification/bulk/"),
656
- ("POST", "http://testserver/api/v1/classification/bulk/"),
657
- ("POST", "http://testserver/api/v1/classification/bulk/"),
658
- ("POST", "http://testserver/api/v1/classification/bulk/"),
659
- ("POST", "http://testserver/api/v1/classification/bulk/"),
660
- ]
661
-
713
+ ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classification/bulk/")]
662
714
 
663
- def test_create_classifications(responses, mock_elements_worker_with_cache):
664
- # Set MLClass in cache
665
- portrait_uuid = str(uuid4())
666
- landscape_uuid = str(uuid4())
667
- mock_elements_worker_with_cache.classes = {
668
- "portrait": portrait_uuid,
669
- "landscape": landscape_uuid,
670
- }
671
715
 
672
- elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
673
- classes = [
674
- {
675
- "ml_class_id": portrait_uuid,
676
- "confidence": 0.75,
677
- "high_confidence": False,
678
- },
679
- {
680
- "ml_class_id": landscape_uuid,
681
- "confidence": 0.25,
682
- "high_confidence": False,
683
- },
684
- ]
716
+ def test_create_classifications_create_ml_class(mock_elements_worker, responses):
717
+ elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
685
718
 
719
+ # Automatically create a missing class!
720
+ responses.add(
721
+ responses.POST,
722
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
723
+ status=201,
724
+ json={"id": "new-ml-class-1234"},
725
+ )
686
726
  responses.add(
687
727
  responses.POST,
688
728
  "http://testserver/api/v1/classification/bulk/",
689
- status=200,
729
+ status=201,
690
730
  json={
691
731
  "parent": str(elt.id),
692
732
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
693
733
  "classifications": [
694
734
  {
695
735
  "id": "00000000-0000-0000-0000-000000000000",
696
- "ml_class": portrait_uuid,
736
+ "ml_class": "new-ml-class-1234",
697
737
  "confidence": 0.75,
698
738
  "high_confidence": False,
699
739
  "state": "pending",
700
740
  },
701
- {
702
- "id": "11111111-1111-1111-1111-111111111111",
703
- "ml_class": landscape_uuid,
704
- "confidence": 0.25,
705
- "high_confidence": False,
706
- "state": "pending",
707
- },
708
741
  ],
709
742
  },
710
743
  )
711
-
712
- mock_elements_worker_with_cache.create_classifications(
713
- element=elt, classifications=classes
744
+ mock_elements_worker.classes = {"another_class": "0000"}
745
+ mock_elements_worker.create_classifications(
746
+ element=elt,
747
+ classifications=[
748
+ {
749
+ "ml_class": "a_class",
750
+ "confidence": 0.75,
751
+ "high_confidence": False,
752
+ }
753
+ ],
714
754
  )
715
755
 
716
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
756
+ # Check a class & classification has been created
757
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
717
758
  assert [
718
759
  (call.request.method, call.request.url) for call in responses.calls
719
760
  ] == BASE_API_CALLS + [
761
+ (
762
+ "POST",
763
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
764
+ ),
720
765
  ("POST", "http://testserver/api/v1/classification/bulk/"),
721
766
  ]
722
767
 
768
+ assert json.loads(responses.calls[-2].request.body) == {"name": "a_class"}
723
769
  assert json.loads(responses.calls[-1].request.body) == {
724
- "parent": str(elt.id),
770
+ "parent": "12341234-1234-1234-1234-123412341234",
725
771
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
726
- "classifications": classes,
772
+ "classifications": [
773
+ {
774
+ "ml_class": "new-ml-class-1234",
775
+ "confidence": 0.75,
776
+ "high_confidence": False,
777
+ }
778
+ ],
727
779
  }
728
780
 
729
- # Check that created classifications were properly stored in SQLite cache
730
- assert list(CachedClassification.select()) == [
731
- CachedClassification(
732
- id=UUID("00000000-0000-0000-0000-000000000000"),
733
- element_id=UUID(elt.id),
734
- class_name="portrait",
735
- confidence=0.75,
736
- state="pending",
737
- worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
738
- ),
739
- CachedClassification(
740
- id=UUID("11111111-1111-1111-1111-111111111111"),
741
- element_id=UUID(elt.id),
742
- class_name="landscape",
743
- confidence=0.25,
744
- state="pending",
745
- worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
746
- ),
747
- ]
748
-
749
-
750
- def test_create_classifications_not_in_cache(
751
- responses, mock_elements_worker_with_cache
752
- ):
753
- """
754
- CreateClassifications using ID that are not in `.classes` attribute.
755
- Will load corpus MLClass to insert the corresponding name in Cache.
756
- """
757
- portrait_uuid = str(uuid4())
758
- landscape_uuid = str(uuid4())
759
- elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
760
- classes = [
761
- {
762
- "ml_class_id": portrait_uuid,
763
- "confidence": 0.75,
764
- "high_confidence": False,
765
- },
766
- {
767
- "ml_class_id": landscape_uuid,
768
- "confidence": 0.25,
769
- "high_confidence": False,
770
- },
771
- ]
772
781
 
782
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
783
+ def test_create_classifications(batch_size, responses, mock_elements_worker):
784
+ mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
785
+ elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
773
786
  responses.add(
774
787
  responses.POST,
775
788
  "http://testserver/api/v1/classification/bulk/",
776
789
  status=200,
777
- json={
778
- "parent": str(elt.id),
779
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
780
- "classifications": [
781
- {
782
- "id": "00000000-0000-0000-0000-000000000000",
783
- "ml_class": portrait_uuid,
784
- "confidence": 0.75,
785
- "high_confidence": False,
786
- "state": "pending",
787
- },
788
- {
789
- "id": "11111111-1111-1111-1111-111111111111",
790
- "ml_class": landscape_uuid,
791
- "confidence": 0.25,
792
- "high_confidence": False,
793
- "state": "pending",
794
- },
795
- ],
796
- },
790
+ json={"classifications": []},
797
791
  )
798
- responses.add(
799
- responses.GET,
800
- f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
801
- status=200,
802
- json={
803
- "count": 2,
804
- "next": None,
805
- "results": [
806
- {
807
- "id": portrait_uuid,
808
- "name": "portrait",
809
- },
810
- {"id": landscape_uuid, "name": "landscape"},
811
- ],
812
- },
792
+
793
+ mock_elements_worker.create_classifications(
794
+ element=elt,
795
+ classifications=[
796
+ {
797
+ "ml_class": "portrait",
798
+ "confidence": 0.75,
799
+ "high_confidence": False,
800
+ },
801
+ {
802
+ "ml_class": "landscape",
803
+ "confidence": 0.25,
804
+ "high_confidence": False,
805
+ },
806
+ ],
807
+ batch_size=batch_size,
813
808
  )
814
809
 
810
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
811
+ if batch_size != DEFAULT_BATCH_SIZE:
812
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
813
+
814
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
815
+ assert [
816
+ (call.request.method, call.request.url) for call in responses.calls
817
+ ] == BASE_API_CALLS + bulk_api_calls
818
+
819
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
820
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
821
+ empty_payload = {
822
+ "parent": str(elt.id),
823
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
824
+ "classifications": [],
825
+ }
826
+
827
+ bodies = []
828
+ first_call_idx = None
829
+ if batch_size > 1:
830
+ first_call_idx = -1
831
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
832
+ else:
833
+ first_call_idx = -2
834
+ bodies.append({**empty_payload, "classifications": [first_cl]})
835
+ bodies.append({**empty_payload, "classifications": [second_cl]})
836
+
837
+ assert [
838
+ json.loads(bulk_call.request.body)
839
+ for bulk_call in responses.calls[first_call_idx:]
840
+ ] == bodies
841
+
842
+
843
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
844
+ def test_create_classifications_with_cache(
845
+ batch_size, responses, mock_elements_worker_with_cache
846
+ ):
847
+ mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
848
+ elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
849
+
850
+ if batch_size > 1:
851
+ responses.add(
852
+ responses.POST,
853
+ "http://testserver/api/v1/classification/bulk/",
854
+ status=200,
855
+ json={
856
+ "parent": str(elt.id),
857
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
858
+ "classifications": [
859
+ {
860
+ "id": "00000000-0000-0000-0000-000000000000",
861
+ "ml_class": "0000",
862
+ "confidence": 0.75,
863
+ "high_confidence": False,
864
+ "state": "pending",
865
+ },
866
+ {
867
+ "id": "11111111-1111-1111-1111-111111111111",
868
+ "ml_class": "1111",
869
+ "confidence": 0.25,
870
+ "high_confidence": False,
871
+ "state": "pending",
872
+ },
873
+ ],
874
+ },
875
+ )
876
+ else:
877
+ for cl_id, cl_class, cl_conf in [
878
+ ("00000000-0000-0000-0000-000000000000", "0000", 0.75),
879
+ ("11111111-1111-1111-1111-111111111111", "1111", 0.25),
880
+ ]:
881
+ responses.add(
882
+ responses.POST,
883
+ "http://testserver/api/v1/classification/bulk/",
884
+ status=200,
885
+ json={
886
+ "parent": str(elt.id),
887
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
888
+ "classifications": [
889
+ {
890
+ "id": cl_id,
891
+ "ml_class": cl_class,
892
+ "confidence": cl_conf,
893
+ "high_confidence": False,
894
+ "state": "pending",
895
+ },
896
+ ],
897
+ },
898
+ )
899
+
815
900
  mock_elements_worker_with_cache.create_classifications(
816
- element=elt, classifications=classes
901
+ element=elt,
902
+ classifications=[
903
+ {
904
+ "ml_class": "portrait",
905
+ "confidence": 0.75,
906
+ "high_confidence": False,
907
+ },
908
+ {
909
+ "ml_class": "landscape",
910
+ "confidence": 0.25,
911
+ "high_confidence": False,
912
+ },
913
+ ],
914
+ batch_size=batch_size,
817
915
  )
818
916
 
819
- assert len(responses.calls) == len(BASE_API_CALLS) + 2
917
+ bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
918
+ if batch_size != DEFAULT_BATCH_SIZE:
919
+ bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
920
+
921
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
820
922
  assert [
821
923
  (call.request.method, call.request.url) for call in responses.calls
822
- ] == BASE_API_CALLS + [
823
- ("POST", "http://testserver/api/v1/classification/bulk/"),
824
- (
825
- "GET",
826
- f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
827
- ),
828
- ]
924
+ ] == BASE_API_CALLS + bulk_api_calls
829
925
 
830
- assert json.loads(responses.calls[-2].request.body) == {
926
+ first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
927
+ second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
928
+ empty_payload = {
831
929
  "parent": str(elt.id),
832
930
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
833
- "classifications": classes,
931
+ "classifications": [],
834
932
  }
835
933
 
934
+ bodies = []
935
+ first_call_idx = None
936
+ if batch_size > 1:
937
+ first_call_idx = -1
938
+ bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
939
+ else:
940
+ first_call_idx = -2
941
+ bodies.append({**empty_payload, "classifications": [first_cl]})
942
+ bodies.append({**empty_payload, "classifications": [second_cl]})
943
+
944
+ assert [
945
+ json.loads(bulk_call.request.body)
946
+ for bulk_call in responses.calls[first_call_idx:]
947
+ ] == bodies
948
+
836
949
  # Check that created classifications were properly stored in SQLite cache
837
950
  assert list(CachedClassification.select()) == [
838
951
  CachedClassification(