arkindex-base-worker 0.3.7rc4__py3-none-any.whl → 0.5.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
- arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
- arkindex_worker/cache.py +1 -1
- arkindex_worker/image.py +167 -2
- arkindex_worker/models.py +18 -0
- arkindex_worker/utils.py +98 -4
- arkindex_worker/worker/__init__.py +117 -218
- arkindex_worker/worker/base.py +39 -46
- arkindex_worker/worker/classification.py +45 -29
- arkindex_worker/worker/corpus.py +86 -0
- arkindex_worker/worker/dataset.py +89 -26
- arkindex_worker/worker/element.py +352 -91
- arkindex_worker/worker/entity.py +13 -11
- arkindex_worker/worker/image.py +21 -0
- arkindex_worker/worker/metadata.py +26 -16
- arkindex_worker/worker/process.py +92 -0
- arkindex_worker/worker/task.py +5 -4
- arkindex_worker/worker/training.py +25 -10
- arkindex_worker/worker/transcription.py +89 -68
- arkindex_worker/worker/version.py +3 -1
- hooks/pre_gen_project.py +3 -0
- tests/__init__.py +8 -0
- tests/conftest.py +47 -58
- tests/test_base_worker.py +212 -12
- tests/test_dataset_worker.py +294 -437
- tests/test_elements_worker/{test_classifications.py → test_classification.py} +313 -200
- tests/test_elements_worker/test_cli.py +3 -11
- tests/test_elements_worker/test_corpus.py +168 -0
- tests/test_elements_worker/test_dataset.py +106 -157
- tests/test_elements_worker/test_element.py +427 -0
- tests/test_elements_worker/test_element_create_multiple.py +715 -0
- tests/test_elements_worker/test_element_create_single.py +528 -0
- tests/test_elements_worker/test_element_list_children.py +969 -0
- tests/test_elements_worker/test_element_list_parents.py +530 -0
- tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
- tests/test_elements_worker/test_entity_list_and_check.py +160 -0
- tests/test_elements_worker/test_image.py +66 -0
- tests/test_elements_worker/test_metadata.py +252 -161
- tests/test_elements_worker/test_process.py +89 -0
- tests/test_elements_worker/test_task.py +8 -18
- tests/test_elements_worker/test_training.py +17 -8
- tests/test_elements_worker/test_transcription_create.py +873 -0
- tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
- tests/test_elements_worker/test_transcription_list.py +450 -0
- tests/test_elements_worker/test_version.py +60 -0
- tests/test_elements_worker/test_worker.py +578 -293
- tests/test_image.py +542 -209
- tests/test_merge.py +1 -2
- tests/test_utils.py +89 -4
- worker-demo/tests/__init__.py +0 -0
- worker-demo/tests/conftest.py +32 -0
- worker-demo/tests/test_worker.py +12 -0
- worker-demo/worker_demo/__init__.py +6 -0
- worker-demo/worker_demo/worker.py +19 -0
- arkindex_base_worker-0.3.7rc4.dist-info/RECORD +0 -41
- tests/test_elements_worker/test_elements.py +0 -2713
- tests/test_elements_worker/test_transcriptions.py +0 -2119
- {arkindex_base_worker-0.3.7rc4.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import re
|
|
3
|
-
from uuid import UUID
|
|
3
|
+
from uuid import UUID
|
|
4
4
|
|
|
5
5
|
import pytest
|
|
6
|
-
from apistar.exceptions import ErrorResponse
|
|
7
6
|
|
|
7
|
+
from arkindex.exceptions import ErrorResponse
|
|
8
8
|
from arkindex_worker.cache import CachedClassification, CachedElement
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
11
|
+
from tests import CORPUS_ID
|
|
10
12
|
|
|
11
13
|
from . import BASE_API_CALLS
|
|
12
14
|
|
|
@@ -15,11 +17,96 @@ from . import BASE_API_CALLS
|
|
|
15
17
|
DELETE_PARAMETER = "DELETE_PARAMETER"
|
|
16
18
|
|
|
17
19
|
|
|
20
|
+
def test_load_corpus_classes_api_error(responses, mock_elements_worker):
|
|
21
|
+
responses.add(
|
|
22
|
+
responses.GET,
|
|
23
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
24
|
+
status=418,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
assert not mock_elements_worker.classes
|
|
28
|
+
with pytest.raises(
|
|
29
|
+
Exception, match="Stopping pagination as data will be incomplete"
|
|
30
|
+
):
|
|
31
|
+
mock_elements_worker.load_corpus_classes()
|
|
32
|
+
|
|
33
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 5
|
|
34
|
+
assert [
|
|
35
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
36
|
+
] == BASE_API_CALLS + [
|
|
37
|
+
# We do 5 retries
|
|
38
|
+
(
|
|
39
|
+
"GET",
|
|
40
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
41
|
+
),
|
|
42
|
+
(
|
|
43
|
+
"GET",
|
|
44
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
45
|
+
),
|
|
46
|
+
(
|
|
47
|
+
"GET",
|
|
48
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
49
|
+
),
|
|
50
|
+
(
|
|
51
|
+
"GET",
|
|
52
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
53
|
+
),
|
|
54
|
+
(
|
|
55
|
+
"GET",
|
|
56
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
57
|
+
),
|
|
58
|
+
]
|
|
59
|
+
assert not mock_elements_worker.classes
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_load_corpus_classes(responses, mock_elements_worker):
|
|
63
|
+
responses.add(
|
|
64
|
+
responses.GET,
|
|
65
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
66
|
+
status=200,
|
|
67
|
+
json={
|
|
68
|
+
"count": 3,
|
|
69
|
+
"next": None,
|
|
70
|
+
"results": [
|
|
71
|
+
{
|
|
72
|
+
"id": "0000",
|
|
73
|
+
"name": "good",
|
|
74
|
+
},
|
|
75
|
+
{
|
|
76
|
+
"id": "1111",
|
|
77
|
+
"name": "average",
|
|
78
|
+
},
|
|
79
|
+
{
|
|
80
|
+
"id": "2222",
|
|
81
|
+
"name": "bad",
|
|
82
|
+
},
|
|
83
|
+
],
|
|
84
|
+
},
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
assert not mock_elements_worker.classes
|
|
88
|
+
mock_elements_worker.load_corpus_classes()
|
|
89
|
+
|
|
90
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
91
|
+
assert [
|
|
92
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
93
|
+
] == BASE_API_CALLS + [
|
|
94
|
+
(
|
|
95
|
+
"GET",
|
|
96
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
97
|
+
),
|
|
98
|
+
]
|
|
99
|
+
assert mock_elements_worker.classes == {
|
|
100
|
+
"good": "0000",
|
|
101
|
+
"average": "1111",
|
|
102
|
+
"bad": "2222",
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
|
|
18
106
|
def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
19
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
20
107
|
responses.add(
|
|
21
108
|
responses.GET,
|
|
22
|
-
f"http://testserver/api/v1/corpus/{
|
|
109
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
23
110
|
status=200,
|
|
24
111
|
json={
|
|
25
112
|
"count": 1,
|
|
@@ -42,7 +129,7 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
42
129
|
] == BASE_API_CALLS + [
|
|
43
130
|
(
|
|
44
131
|
"GET",
|
|
45
|
-
f"http://testserver/api/v1/corpus/{
|
|
132
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
46
133
|
),
|
|
47
134
|
]
|
|
48
135
|
assert mock_elements_worker.classes == {"good": "0000"}
|
|
@@ -51,12 +138,11 @@ def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
|
|
|
51
138
|
|
|
52
139
|
def test_get_ml_class_id_inexistant_class(mock_elements_worker, responses):
|
|
53
140
|
# A missing class is now created automatically
|
|
54
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
55
141
|
mock_elements_worker.classes = {"good": "0000"}
|
|
56
142
|
|
|
57
143
|
responses.add(
|
|
58
144
|
responses.POST,
|
|
59
|
-
f"http://testserver/api/v1/corpus/{
|
|
145
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
60
146
|
status=201,
|
|
61
147
|
json={"id": "new-ml-class-1234"},
|
|
62
148
|
)
|
|
@@ -82,12 +168,10 @@ def test_get_ml_class_id(mock_elements_worker):
|
|
|
82
168
|
|
|
83
169
|
|
|
84
170
|
def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
85
|
-
corpus_id = "11111111-1111-1111-1111-111111111111"
|
|
86
|
-
|
|
87
171
|
# Add some initial classes
|
|
88
172
|
responses.add(
|
|
89
173
|
responses.GET,
|
|
90
|
-
f"http://testserver/api/v1/corpus/{
|
|
174
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
91
175
|
json={
|
|
92
176
|
"count": 1,
|
|
93
177
|
"next": None,
|
|
@@ -103,7 +187,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
103
187
|
# Invalid response when trying to create class2
|
|
104
188
|
responses.add(
|
|
105
189
|
responses.POST,
|
|
106
|
-
f"http://testserver/api/v1/corpus/{
|
|
190
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
107
191
|
status=400,
|
|
108
192
|
json={"non_field_errors": "Already exists"},
|
|
109
193
|
)
|
|
@@ -111,7 +195,7 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
111
195
|
# Add both classes (class2 is created by another process)
|
|
112
196
|
responses.add(
|
|
113
197
|
responses.GET,
|
|
114
|
-
f"http://testserver/api/v1/corpus/{
|
|
198
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
115
199
|
json={
|
|
116
200
|
"count": 2,
|
|
117
201
|
"next": None,
|
|
@@ -141,15 +225,15 @@ def test_get_ml_class_reload(responses, mock_elements_worker):
|
|
|
141
225
|
] == BASE_API_CALLS + [
|
|
142
226
|
(
|
|
143
227
|
"GET",
|
|
144
|
-
f"http://testserver/api/v1/corpus/{
|
|
228
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
145
229
|
),
|
|
146
230
|
(
|
|
147
231
|
"POST",
|
|
148
|
-
f"http://testserver/api/v1/corpus/{
|
|
232
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
149
233
|
),
|
|
150
234
|
(
|
|
151
235
|
"GET",
|
|
152
|
-
f"http://testserver/api/v1/corpus/{
|
|
236
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
153
237
|
),
|
|
154
238
|
]
|
|
155
239
|
|
|
@@ -169,7 +253,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
169
253
|
"""
|
|
170
254
|
responses.add(
|
|
171
255
|
responses.GET,
|
|
172
|
-
f"http://testserver/api/v1/corpus/{
|
|
256
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
173
257
|
status=200,
|
|
174
258
|
json={
|
|
175
259
|
"count": 1,
|
|
@@ -189,7 +273,7 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
|
|
|
189
273
|
] == BASE_API_CALLS + [
|
|
190
274
|
(
|
|
191
275
|
"GET",
|
|
192
|
-
f"http://testserver/api/v1/corpus/{
|
|
276
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
193
277
|
),
|
|
194
278
|
]
|
|
195
279
|
|
|
@@ -276,7 +360,7 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
276
360
|
responses.add(
|
|
277
361
|
responses.POST,
|
|
278
362
|
"http://testserver/api/v1/classifications/",
|
|
279
|
-
status=
|
|
363
|
+
status=418,
|
|
280
364
|
)
|
|
281
365
|
|
|
282
366
|
with pytest.raises(ErrorResponse):
|
|
@@ -287,17 +371,10 @@ def test_create_classification_api_error(responses, mock_elements_worker):
|
|
|
287
371
|
high_confidence=True,
|
|
288
372
|
)
|
|
289
373
|
|
|
290
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
374
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
291
375
|
assert [
|
|
292
376
|
(call.request.method, call.request.url) for call in responses.calls
|
|
293
|
-
] == BASE_API_CALLS + [
|
|
294
|
-
# We retry 5 times the API call
|
|
295
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
296
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
297
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
298
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
299
|
-
("POST", "http://testserver/api/v1/classifications/"),
|
|
300
|
-
]
|
|
377
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classifications/")]
|
|
301
378
|
|
|
302
379
|
|
|
303
380
|
def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
@@ -306,7 +383,7 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
306
383
|
# Automatically create a missing class!
|
|
307
384
|
responses.add(
|
|
308
385
|
responses.POST,
|
|
309
|
-
"http://testserver/api/v1/corpus/
|
|
386
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
310
387
|
status=201,
|
|
311
388
|
json={"id": "new-ml-class-1234"},
|
|
312
389
|
)
|
|
@@ -325,15 +402,12 @@ def test_create_classification_create_ml_class(mock_elements_worker, responses):
|
|
|
325
402
|
)
|
|
326
403
|
|
|
327
404
|
# Check a class & classification has been created
|
|
328
|
-
for call in responses.calls:
|
|
329
|
-
print(call.request.url, call.request.body)
|
|
330
|
-
|
|
331
405
|
assert [
|
|
332
406
|
(call.request.url, json.loads(call.request.body))
|
|
333
407
|
for call in responses.calls[-2:]
|
|
334
408
|
] == [
|
|
335
409
|
(
|
|
336
|
-
"http://testserver/api/v1/corpus/
|
|
410
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
337
411
|
{"name": "a_class"},
|
|
338
412
|
),
|
|
339
413
|
(
|
|
@@ -506,12 +580,12 @@ def test_create_classifications_wrong_data(
|
|
|
506
580
|
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
|
|
507
581
|
"classifications": [
|
|
508
582
|
{
|
|
509
|
-
"
|
|
583
|
+
"ml_class": "cat",
|
|
510
584
|
"confidence": 0.75,
|
|
511
585
|
"high_confidence": False,
|
|
512
586
|
},
|
|
513
587
|
{
|
|
514
|
-
"
|
|
588
|
+
"ml_class": "dog",
|
|
515
589
|
"confidence": 0.25,
|
|
516
590
|
"high_confidence": False,
|
|
517
591
|
},
|
|
@@ -523,86 +597,71 @@ def test_create_classifications_wrong_data(
|
|
|
523
597
|
|
|
524
598
|
|
|
525
599
|
@pytest.mark.parametrize(
|
|
526
|
-
("arg_name", "data", "error_message"
|
|
600
|
+
("arg_name", "data", "error_message"),
|
|
527
601
|
[
|
|
528
|
-
# Wrong classifications >
|
|
602
|
+
# Wrong classifications > ml_class
|
|
529
603
|
(
|
|
530
|
-
"
|
|
604
|
+
"ml_class",
|
|
531
605
|
DELETE_PARAMETER,
|
|
532
|
-
"
|
|
533
|
-
|
|
534
|
-
), # Updated
|
|
606
|
+
"ml_class shouldn't be null and should be of type str",
|
|
607
|
+
),
|
|
535
608
|
(
|
|
536
|
-
"
|
|
609
|
+
"ml_class",
|
|
537
610
|
None,
|
|
538
|
-
"
|
|
539
|
-
AssertionError,
|
|
611
|
+
"ml_class shouldn't be null and should be of type str",
|
|
540
612
|
),
|
|
541
613
|
(
|
|
542
|
-
"
|
|
614
|
+
"ml_class",
|
|
543
615
|
1234,
|
|
544
|
-
"
|
|
545
|
-
AssertionError,
|
|
546
|
-
),
|
|
547
|
-
(
|
|
548
|
-
"ml_class_id",
|
|
549
|
-
"not_an_uuid",
|
|
550
|
-
"ml_class_id is not a valid uuid.",
|
|
551
|
-
ValueError,
|
|
616
|
+
"ml_class shouldn't be null and should be of type str",
|
|
552
617
|
),
|
|
553
618
|
# Wrong classifications > confidence
|
|
554
619
|
(
|
|
555
620
|
"confidence",
|
|
556
621
|
DELETE_PARAMETER,
|
|
557
622
|
"confidence shouldn't be null and should be a float in [0..1] range",
|
|
558
|
-
AssertionError,
|
|
559
623
|
),
|
|
560
624
|
(
|
|
561
625
|
"confidence",
|
|
562
626
|
None,
|
|
563
627
|
"confidence shouldn't be null and should be a float in [0..1] range",
|
|
564
|
-
AssertionError,
|
|
565
628
|
),
|
|
566
629
|
(
|
|
567
630
|
"confidence",
|
|
568
631
|
"wrong confidence",
|
|
569
632
|
"confidence shouldn't be null and should be a float in [0..1] range",
|
|
570
|
-
AssertionError,
|
|
571
633
|
),
|
|
572
634
|
(
|
|
573
635
|
"confidence",
|
|
574
636
|
0,
|
|
575
637
|
"confidence shouldn't be null and should be a float in [0..1] range",
|
|
576
|
-
AssertionError,
|
|
577
638
|
),
|
|
578
639
|
(
|
|
579
640
|
"confidence",
|
|
580
641
|
2.00,
|
|
581
642
|
"confidence shouldn't be null and should be a float in [0..1] range",
|
|
582
|
-
AssertionError,
|
|
583
643
|
),
|
|
584
644
|
# Wrong classifications > high_confidence
|
|
585
645
|
(
|
|
586
646
|
"high_confidence",
|
|
587
647
|
"wrong high_confidence",
|
|
588
648
|
"high_confidence should be of type bool",
|
|
589
|
-
AssertionError,
|
|
590
649
|
),
|
|
591
650
|
],
|
|
592
651
|
)
|
|
593
652
|
def test_create_classifications_wrong_classifications_data(
|
|
594
|
-
arg_name, data, error_message,
|
|
653
|
+
arg_name, data, error_message, mock_elements_worker
|
|
595
654
|
):
|
|
596
655
|
all_data = {
|
|
597
656
|
"element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
|
|
598
657
|
"classifications": [
|
|
599
658
|
{
|
|
600
|
-
"
|
|
659
|
+
"ml_class": "cat",
|
|
601
660
|
"confidence": 0.75,
|
|
602
661
|
"high_confidence": False,
|
|
603
662
|
},
|
|
604
663
|
{
|
|
605
|
-
"
|
|
664
|
+
"ml_class": "dog",
|
|
606
665
|
"confidence": 0.25,
|
|
607
666
|
"high_confidence": False,
|
|
608
667
|
# Overwrite with wrong data
|
|
@@ -614,7 +673,7 @@ def test_create_classifications_wrong_classifications_data(
|
|
|
614
673
|
del all_data["classifications"][1][arg_name]
|
|
615
674
|
|
|
616
675
|
with pytest.raises(
|
|
617
|
-
|
|
676
|
+
AssertionError,
|
|
618
677
|
match=re.escape(
|
|
619
678
|
f"Classification at index 1 in classifications: {error_message}"
|
|
620
679
|
),
|
|
@@ -623,20 +682,21 @@ def test_create_classifications_wrong_classifications_data(
|
|
|
623
682
|
|
|
624
683
|
|
|
625
684
|
def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
685
|
+
mock_elements_worker.classes = {"cat": "0000", "dog": "1111"}
|
|
626
686
|
responses.add(
|
|
627
687
|
responses.POST,
|
|
628
688
|
"http://testserver/api/v1/classification/bulk/",
|
|
629
|
-
status=
|
|
689
|
+
status=418,
|
|
630
690
|
)
|
|
631
691
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
632
692
|
classes = [
|
|
633
693
|
{
|
|
634
|
-
"
|
|
694
|
+
"ml_class": "cat",
|
|
635
695
|
"confidence": 0.75,
|
|
636
696
|
"high_confidence": False,
|
|
637
697
|
},
|
|
638
698
|
{
|
|
639
|
-
"
|
|
699
|
+
"ml_class": "dog",
|
|
640
700
|
"confidence": 0.25,
|
|
641
701
|
"high_confidence": False,
|
|
642
702
|
},
|
|
@@ -647,192 +707,245 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
|
|
|
647
707
|
element=elt, classifications=classes
|
|
648
708
|
)
|
|
649
709
|
|
|
650
|
-
assert len(responses.calls) == len(BASE_API_CALLS) +
|
|
710
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 1
|
|
651
711
|
assert [
|
|
652
712
|
(call.request.method, call.request.url) for call in responses.calls
|
|
653
|
-
] == BASE_API_CALLS + [
|
|
654
|
-
# We retry 5 times the API call
|
|
655
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
656
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
657
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
658
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
659
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
660
|
-
]
|
|
661
|
-
|
|
713
|
+
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
662
714
|
|
|
663
|
-
def test_create_classifications(responses, mock_elements_worker_with_cache):
|
|
664
|
-
# Set MLClass in cache
|
|
665
|
-
portrait_uuid = str(uuid4())
|
|
666
|
-
landscape_uuid = str(uuid4())
|
|
667
|
-
mock_elements_worker_with_cache.classes = {
|
|
668
|
-
"portrait": portrait_uuid,
|
|
669
|
-
"landscape": landscape_uuid,
|
|
670
|
-
}
|
|
671
715
|
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
{
|
|
675
|
-
"ml_class_id": portrait_uuid,
|
|
676
|
-
"confidence": 0.75,
|
|
677
|
-
"high_confidence": False,
|
|
678
|
-
},
|
|
679
|
-
{
|
|
680
|
-
"ml_class_id": landscape_uuid,
|
|
681
|
-
"confidence": 0.25,
|
|
682
|
-
"high_confidence": False,
|
|
683
|
-
},
|
|
684
|
-
]
|
|
716
|
+
def test_create_classifications_create_ml_class(mock_elements_worker, responses):
|
|
717
|
+
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
685
718
|
|
|
719
|
+
# Automatically create a missing class!
|
|
720
|
+
responses.add(
|
|
721
|
+
responses.POST,
|
|
722
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
723
|
+
status=201,
|
|
724
|
+
json={"id": "new-ml-class-1234"},
|
|
725
|
+
)
|
|
686
726
|
responses.add(
|
|
687
727
|
responses.POST,
|
|
688
728
|
"http://testserver/api/v1/classification/bulk/",
|
|
689
|
-
status=
|
|
729
|
+
status=201,
|
|
690
730
|
json={
|
|
691
731
|
"parent": str(elt.id),
|
|
692
732
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
693
733
|
"classifications": [
|
|
694
734
|
{
|
|
695
735
|
"id": "00000000-0000-0000-0000-000000000000",
|
|
696
|
-
"ml_class":
|
|
736
|
+
"ml_class": "new-ml-class-1234",
|
|
697
737
|
"confidence": 0.75,
|
|
698
738
|
"high_confidence": False,
|
|
699
739
|
"state": "pending",
|
|
700
740
|
},
|
|
701
|
-
{
|
|
702
|
-
"id": "11111111-1111-1111-1111-111111111111",
|
|
703
|
-
"ml_class": landscape_uuid,
|
|
704
|
-
"confidence": 0.25,
|
|
705
|
-
"high_confidence": False,
|
|
706
|
-
"state": "pending",
|
|
707
|
-
},
|
|
708
741
|
],
|
|
709
742
|
},
|
|
710
743
|
)
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
element=elt,
|
|
744
|
+
mock_elements_worker.classes = {"another_class": "0000"}
|
|
745
|
+
mock_elements_worker.create_classifications(
|
|
746
|
+
element=elt,
|
|
747
|
+
classifications=[
|
|
748
|
+
{
|
|
749
|
+
"ml_class": "a_class",
|
|
750
|
+
"confidence": 0.75,
|
|
751
|
+
"high_confidence": False,
|
|
752
|
+
}
|
|
753
|
+
],
|
|
714
754
|
)
|
|
715
755
|
|
|
716
|
-
|
|
756
|
+
# Check a class & classification has been created
|
|
757
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + 2
|
|
717
758
|
assert [
|
|
718
759
|
(call.request.method, call.request.url) for call in responses.calls
|
|
719
760
|
] == BASE_API_CALLS + [
|
|
761
|
+
(
|
|
762
|
+
"POST",
|
|
763
|
+
f"http://testserver/api/v1/corpus/{CORPUS_ID}/classes/",
|
|
764
|
+
),
|
|
720
765
|
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
721
766
|
]
|
|
722
767
|
|
|
768
|
+
assert json.loads(responses.calls[-2].request.body) == {"name": "a_class"}
|
|
723
769
|
assert json.loads(responses.calls[-1].request.body) == {
|
|
724
|
-
"parent":
|
|
770
|
+
"parent": "12341234-1234-1234-1234-123412341234",
|
|
725
771
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
726
|
-
"classifications":
|
|
772
|
+
"classifications": [
|
|
773
|
+
{
|
|
774
|
+
"ml_class": "new-ml-class-1234",
|
|
775
|
+
"confidence": 0.75,
|
|
776
|
+
"high_confidence": False,
|
|
777
|
+
}
|
|
778
|
+
],
|
|
727
779
|
}
|
|
728
780
|
|
|
729
|
-
# Check that created classifications were properly stored in SQLite cache
|
|
730
|
-
assert list(CachedClassification.select()) == [
|
|
731
|
-
CachedClassification(
|
|
732
|
-
id=UUID("00000000-0000-0000-0000-000000000000"),
|
|
733
|
-
element_id=UUID(elt.id),
|
|
734
|
-
class_name="portrait",
|
|
735
|
-
confidence=0.75,
|
|
736
|
-
state="pending",
|
|
737
|
-
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
738
|
-
),
|
|
739
|
-
CachedClassification(
|
|
740
|
-
id=UUID("11111111-1111-1111-1111-111111111111"),
|
|
741
|
-
element_id=UUID(elt.id),
|
|
742
|
-
class_name="landscape",
|
|
743
|
-
confidence=0.25,
|
|
744
|
-
state="pending",
|
|
745
|
-
worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
|
|
746
|
-
),
|
|
747
|
-
]
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
def test_create_classifications_not_in_cache(
|
|
751
|
-
responses, mock_elements_worker_with_cache
|
|
752
|
-
):
|
|
753
|
-
"""
|
|
754
|
-
CreateClassifications using ID that are not in `.classes` attribute.
|
|
755
|
-
Will load corpus MLClass to insert the corresponding name in Cache.
|
|
756
|
-
"""
|
|
757
|
-
portrait_uuid = str(uuid4())
|
|
758
|
-
landscape_uuid = str(uuid4())
|
|
759
|
-
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
760
|
-
classes = [
|
|
761
|
-
{
|
|
762
|
-
"ml_class_id": portrait_uuid,
|
|
763
|
-
"confidence": 0.75,
|
|
764
|
-
"high_confidence": False,
|
|
765
|
-
},
|
|
766
|
-
{
|
|
767
|
-
"ml_class_id": landscape_uuid,
|
|
768
|
-
"confidence": 0.25,
|
|
769
|
-
"high_confidence": False,
|
|
770
|
-
},
|
|
771
|
-
]
|
|
772
781
|
|
|
782
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
783
|
+
def test_create_classifications(batch_size, responses, mock_elements_worker):
|
|
784
|
+
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
|
|
785
|
+
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
773
786
|
responses.add(
|
|
774
787
|
responses.POST,
|
|
775
788
|
"http://testserver/api/v1/classification/bulk/",
|
|
776
789
|
status=200,
|
|
777
|
-
json={
|
|
778
|
-
"parent": str(elt.id),
|
|
779
|
-
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
780
|
-
"classifications": [
|
|
781
|
-
{
|
|
782
|
-
"id": "00000000-0000-0000-0000-000000000000",
|
|
783
|
-
"ml_class": portrait_uuid,
|
|
784
|
-
"confidence": 0.75,
|
|
785
|
-
"high_confidence": False,
|
|
786
|
-
"state": "pending",
|
|
787
|
-
},
|
|
788
|
-
{
|
|
789
|
-
"id": "11111111-1111-1111-1111-111111111111",
|
|
790
|
-
"ml_class": landscape_uuid,
|
|
791
|
-
"confidence": 0.25,
|
|
792
|
-
"high_confidence": False,
|
|
793
|
-
"state": "pending",
|
|
794
|
-
},
|
|
795
|
-
],
|
|
796
|
-
},
|
|
790
|
+
json={"classifications": []},
|
|
797
791
|
)
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
792
|
+
|
|
793
|
+
mock_elements_worker.create_classifications(
|
|
794
|
+
element=elt,
|
|
795
|
+
classifications=[
|
|
796
|
+
{
|
|
797
|
+
"ml_class": "portrait",
|
|
798
|
+
"confidence": 0.75,
|
|
799
|
+
"high_confidence": False,
|
|
800
|
+
},
|
|
801
|
+
{
|
|
802
|
+
"ml_class": "landscape",
|
|
803
|
+
"confidence": 0.25,
|
|
804
|
+
"high_confidence": False,
|
|
805
|
+
},
|
|
806
|
+
],
|
|
807
|
+
batch_size=batch_size,
|
|
813
808
|
)
|
|
814
809
|
|
|
810
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
811
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
812
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
813
|
+
|
|
814
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
815
|
+
assert [
|
|
816
|
+
(call.request.method, call.request.url) for call in responses.calls
|
|
817
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
818
|
+
|
|
819
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
820
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
821
|
+
empty_payload = {
|
|
822
|
+
"parent": str(elt.id),
|
|
823
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
824
|
+
"classifications": [],
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
bodies = []
|
|
828
|
+
first_call_idx = None
|
|
829
|
+
if batch_size > 1:
|
|
830
|
+
first_call_idx = -1
|
|
831
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
832
|
+
else:
|
|
833
|
+
first_call_idx = -2
|
|
834
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
835
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
836
|
+
|
|
837
|
+
assert [
|
|
838
|
+
json.loads(bulk_call.request.body)
|
|
839
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
840
|
+
] == bodies
|
|
841
|
+
|
|
842
|
+
|
|
843
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
844
|
+
def test_create_classifications_with_cache(
|
|
845
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
846
|
+
):
|
|
847
|
+
mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
|
|
848
|
+
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
849
|
+
|
|
850
|
+
if batch_size > 1:
|
|
851
|
+
responses.add(
|
|
852
|
+
responses.POST,
|
|
853
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
854
|
+
status=200,
|
|
855
|
+
json={
|
|
856
|
+
"parent": str(elt.id),
|
|
857
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
858
|
+
"classifications": [
|
|
859
|
+
{
|
|
860
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
861
|
+
"ml_class": "0000",
|
|
862
|
+
"confidence": 0.75,
|
|
863
|
+
"high_confidence": False,
|
|
864
|
+
"state": "pending",
|
|
865
|
+
},
|
|
866
|
+
{
|
|
867
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
868
|
+
"ml_class": "1111",
|
|
869
|
+
"confidence": 0.25,
|
|
870
|
+
"high_confidence": False,
|
|
871
|
+
"state": "pending",
|
|
872
|
+
},
|
|
873
|
+
],
|
|
874
|
+
},
|
|
875
|
+
)
|
|
876
|
+
else:
|
|
877
|
+
for cl_id, cl_class, cl_conf in [
|
|
878
|
+
("00000000-0000-0000-0000-000000000000", "0000", 0.75),
|
|
879
|
+
("11111111-1111-1111-1111-111111111111", "1111", 0.25),
|
|
880
|
+
]:
|
|
881
|
+
responses.add(
|
|
882
|
+
responses.POST,
|
|
883
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
884
|
+
status=200,
|
|
885
|
+
json={
|
|
886
|
+
"parent": str(elt.id),
|
|
887
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
888
|
+
"classifications": [
|
|
889
|
+
{
|
|
890
|
+
"id": cl_id,
|
|
891
|
+
"ml_class": cl_class,
|
|
892
|
+
"confidence": cl_conf,
|
|
893
|
+
"high_confidence": False,
|
|
894
|
+
"state": "pending",
|
|
895
|
+
},
|
|
896
|
+
],
|
|
897
|
+
},
|
|
898
|
+
)
|
|
899
|
+
|
|
815
900
|
mock_elements_worker_with_cache.create_classifications(
|
|
816
|
-
element=elt,
|
|
901
|
+
element=elt,
|
|
902
|
+
classifications=[
|
|
903
|
+
{
|
|
904
|
+
"ml_class": "portrait",
|
|
905
|
+
"confidence": 0.75,
|
|
906
|
+
"high_confidence": False,
|
|
907
|
+
},
|
|
908
|
+
{
|
|
909
|
+
"ml_class": "landscape",
|
|
910
|
+
"confidence": 0.25,
|
|
911
|
+
"high_confidence": False,
|
|
912
|
+
},
|
|
913
|
+
],
|
|
914
|
+
batch_size=batch_size,
|
|
817
915
|
)
|
|
818
916
|
|
|
819
|
-
|
|
917
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
918
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
919
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
920
|
+
|
|
921
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
820
922
|
assert [
|
|
821
923
|
(call.request.method, call.request.url) for call in responses.calls
|
|
822
|
-
] == BASE_API_CALLS +
|
|
823
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
824
|
-
(
|
|
825
|
-
"GET",
|
|
826
|
-
f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
|
|
827
|
-
),
|
|
828
|
-
]
|
|
924
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
829
925
|
|
|
830
|
-
|
|
926
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
927
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
928
|
+
empty_payload = {
|
|
831
929
|
"parent": str(elt.id),
|
|
832
930
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
833
|
-
"classifications":
|
|
931
|
+
"classifications": [],
|
|
834
932
|
}
|
|
835
933
|
|
|
934
|
+
bodies = []
|
|
935
|
+
first_call_idx = None
|
|
936
|
+
if batch_size > 1:
|
|
937
|
+
first_call_idx = -1
|
|
938
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
939
|
+
else:
|
|
940
|
+
first_call_idx = -2
|
|
941
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
942
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
943
|
+
|
|
944
|
+
assert [
|
|
945
|
+
json.loads(bulk_call.request.body)
|
|
946
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
947
|
+
] == bodies
|
|
948
|
+
|
|
836
949
|
# Check that created classifications were properly stored in SQLite cache
|
|
837
950
|
assert list(CachedClassification.select()) == [
|
|
838
951
|
CachedClassification(
|