arkindex-base-worker 0.3.6rc4__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. arkindex_base_worker-0.3.7.dist-info/LICENSE +21 -0
  2. arkindex_base_worker-0.3.7.dist-info/METADATA +77 -0
  3. arkindex_base_worker-0.3.7.dist-info/RECORD +47 -0
  4. {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/WHEEL +1 -1
  5. {arkindex_base_worker-0.3.6rc4.dist-info → arkindex_base_worker-0.3.7.dist-info}/top_level.txt +2 -0
  6. arkindex_worker/cache.py +14 -0
  7. arkindex_worker/image.py +29 -19
  8. arkindex_worker/models.py +14 -2
  9. arkindex_worker/utils.py +17 -3
  10. arkindex_worker/worker/__init__.py +122 -125
  11. arkindex_worker/worker/base.py +24 -24
  12. arkindex_worker/worker/classification.py +18 -25
  13. arkindex_worker/worker/dataset.py +24 -18
  14. arkindex_worker/worker/element.py +100 -19
  15. arkindex_worker/worker/entity.py +35 -4
  16. arkindex_worker/worker/metadata.py +21 -11
  17. arkindex_worker/worker/training.py +13 -0
  18. arkindex_worker/worker/transcription.py +45 -5
  19. arkindex_worker/worker/version.py +22 -0
  20. hooks/pre_gen_project.py +3 -0
  21. tests/conftest.py +16 -8
  22. tests/test_base_worker.py +0 -6
  23. tests/test_dataset_worker.py +291 -409
  24. tests/test_elements_worker/test_classifications.py +365 -539
  25. tests/test_elements_worker/test_cli.py +1 -1
  26. tests/test_elements_worker/test_dataset.py +97 -116
  27. tests/test_elements_worker/test_elements.py +354 -76
  28. tests/test_elements_worker/test_entities.py +22 -2
  29. tests/test_elements_worker/test_metadata.py +53 -27
  30. tests/test_elements_worker/test_training.py +35 -0
  31. tests/test_elements_worker/test_transcriptions.py +149 -16
  32. tests/test_elements_worker/test_worker.py +19 -6
  33. tests/test_image.py +37 -0
  34. tests/test_utils.py +23 -1
  35. worker-demo/tests/__init__.py +0 -0
  36. worker-demo/tests/conftest.py +32 -0
  37. worker-demo/tests/test_worker.py +12 -0
  38. worker-demo/worker_demo/__init__.py +6 -0
  39. worker-demo/worker_demo/worker.py +19 -0
  40. arkindex_base_worker-0.3.6rc4.dist-info/METADATA +0 -47
  41. arkindex_base_worker-0.3.6rc4.dist-info/RECORD +0 -40
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  import re
3
- from uuid import UUID, uuid4
3
+ from uuid import UUID
4
4
 
5
5
  import pytest
6
6
  from apistar.exceptions import ErrorResponse
@@ -10,6 +10,10 @@ from arkindex_worker.models import Element
10
10
 
11
11
  from . import BASE_API_CALLS
12
12
 
13
+ # Special string used to know if the `arg_name` passed in
14
+ # `pytest.mark.parametrize` should be removed from the payload
15
+ DELETE_PARAMETER = "DELETE_PARAMETER"
16
+
13
17
 
14
18
  def test_get_ml_class_id_load_classes(responses, mock_elements_worker):
15
19
  corpus_id = "11111111-1111-1111-1111-111111111111"
@@ -190,54 +194,116 @@ def test_retrieve_ml_class_not_in_cache(responses, mock_elements_worker):
190
194
  ]
191
195
 
192
196
 
193
- def test_create_classification_wrong_element(mock_elements_worker):
194
- with pytest.raises(
195
- AssertionError,
196
- match="element shouldn't be null and should be an Element or CachedElement",
197
- ):
198
- mock_elements_worker.create_classification(
199
- element=None,
200
- ml_class="a_class",
201
- confidence=0.42,
202
- high_confidence=True,
203
- )
204
-
205
- with pytest.raises(
206
- AssertionError,
207
- match="element shouldn't be null and should be an Element or CachedElement",
208
- ):
197
+ @pytest.mark.parametrize(
198
+ ("arg_name", "data", "error_message"),
199
+ [
200
+ # Wrong element
201
+ (
202
+ "element",
203
+ None,
204
+ "element shouldn't be null and should be an Element or CachedElement",
205
+ ),
206
+ (
207
+ "element",
208
+ "not element type",
209
+ "element shouldn't be null and should be an Element or CachedElement",
210
+ ),
211
+ # Wrong ml_class
212
+ (
213
+ "ml_class",
214
+ None,
215
+ "ml_class shouldn't be null and should be of type str",
216
+ ),
217
+ (
218
+ "ml_class",
219
+ 1234,
220
+ "ml_class shouldn't be null and should be of type str",
221
+ ),
222
+ # Wrong confidence
223
+ (
224
+ "confidence",
225
+ None,
226
+ "confidence shouldn't be null and should be a float in [0..1] range",
227
+ ),
228
+ (
229
+ "confidence",
230
+ "wrong confidence",
231
+ "confidence shouldn't be null and should be a float in [0..1] range",
232
+ ),
233
+ (
234
+ "confidence",
235
+ 0,
236
+ "confidence shouldn't be null and should be a float in [0..1] range",
237
+ ),
238
+ (
239
+ "confidence",
240
+ 2.00,
241
+ "confidence shouldn't be null and should be a float in [0..1] range",
242
+ ),
243
+ # Wrong high_confidence
244
+ (
245
+ "high_confidence",
246
+ None,
247
+ "high_confidence shouldn't be null and should be of type bool",
248
+ ),
249
+ (
250
+ "high_confidence",
251
+ "wrong high_confidence",
252
+ "high_confidence shouldn't be null and should be of type bool",
253
+ ),
254
+ ],
255
+ )
256
+ def test_create_classification_wrong_data(
257
+ arg_name, data, error_message, mock_elements_worker
258
+ ):
259
+ mock_elements_worker.classes = {"a_class": "0000"}
260
+ with pytest.raises(AssertionError, match=re.escape(error_message)):
209
261
  mock_elements_worker.create_classification(
210
- element="not element type",
211
- ml_class="a_class",
212
- confidence=0.42,
213
- high_confidence=True,
262
+ **{
263
+ "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
264
+ "ml_class": "a_class",
265
+ "confidence": 0.42,
266
+ "high_confidence": True,
267
+ # Overwrite with wrong data
268
+ arg_name: data,
269
+ }
214
270
  )
215
271
 
216
272
 
217
- def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
273
+ def test_create_classification_api_error(responses, mock_elements_worker):
274
+ mock_elements_worker.classes = {"a_class": "0000"}
218
275
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
276
+ responses.add(
277
+ responses.POST,
278
+ "http://testserver/api/v1/classifications/",
279
+ status=500,
280
+ )
219
281
 
220
- with pytest.raises(
221
- AssertionError, match="ml_class shouldn't be null and should be of type str"
222
- ):
282
+ with pytest.raises(ErrorResponse):
223
283
  mock_elements_worker.create_classification(
224
284
  element=elt,
225
- ml_class=None,
285
+ ml_class="a_class",
226
286
  confidence=0.42,
227
287
  high_confidence=True,
228
288
  )
229
289
 
230
- with pytest.raises(
231
- AssertionError, match="ml_class shouldn't be null and should be of type str"
232
- ):
233
- mock_elements_worker.create_classification(
234
- element=elt,
235
- ml_class=1234,
236
- confidence=0.42,
237
- high_confidence=True,
238
- )
290
+ assert len(responses.calls) == len(BASE_API_CALLS) + 5
291
+ assert [
292
+ (call.request.method, call.request.url) for call in responses.calls
293
+ ] == BASE_API_CALLS + [
294
+ # We retry 5 times the API call
295
+ ("POST", "http://testserver/api/v1/classifications/"),
296
+ ("POST", "http://testserver/api/v1/classifications/"),
297
+ ("POST", "http://testserver/api/v1/classifications/"),
298
+ ("POST", "http://testserver/api/v1/classifications/"),
299
+ ("POST", "http://testserver/api/v1/classifications/"),
300
+ ]
301
+
302
+
303
+ def test_create_classification_create_ml_class(mock_elements_worker, responses):
304
+ elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
239
305
 
240
- # Automatically create a missing class !
306
+ # Automatically create a missing class!
241
307
  responses.add(
242
308
  responses.POST,
243
309
  "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
@@ -259,9 +325,6 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
259
325
  )
260
326
 
261
327
  # Check a class & classification has been created
262
- for call in responses.calls:
263
- print(call.request.url, call.request.body)
264
-
265
328
  assert [
266
329
  (call.request.url, json.loads(call.request.body))
267
330
  for call in responses.calls[-2:]
@@ -283,119 +346,6 @@ def test_create_classification_wrong_ml_class(mock_elements_worker, responses):
283
346
  ]
284
347
 
285
348
 
286
- def test_create_classification_wrong_confidence(mock_elements_worker):
287
- mock_elements_worker.classes = {"a_class": "0000"}
288
- elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
289
- with pytest.raises(
290
- AssertionError,
291
- match=re.escape(
292
- "confidence shouldn't be null and should be a float in [0..1] range"
293
- ),
294
- ):
295
- mock_elements_worker.create_classification(
296
- element=elt,
297
- ml_class="a_class",
298
- confidence=None,
299
- high_confidence=True,
300
- )
301
-
302
- with pytest.raises(
303
- AssertionError,
304
- match=re.escape(
305
- "confidence shouldn't be null and should be a float in [0..1] range"
306
- ),
307
- ):
308
- mock_elements_worker.create_classification(
309
- element=elt,
310
- ml_class="a_class",
311
- confidence="wrong confidence",
312
- high_confidence=True,
313
- )
314
-
315
- with pytest.raises(
316
- AssertionError,
317
- match=re.escape(
318
- "confidence shouldn't be null and should be a float in [0..1] range"
319
- ),
320
- ):
321
- mock_elements_worker.create_classification(
322
- element=elt,
323
- ml_class="a_class",
324
- confidence=0,
325
- high_confidence=True,
326
- )
327
-
328
- with pytest.raises(
329
- AssertionError,
330
- match=re.escape(
331
- "confidence shouldn't be null and should be a float in [0..1] range"
332
- ),
333
- ):
334
- mock_elements_worker.create_classification(
335
- element=elt,
336
- ml_class="a_class",
337
- confidence=2.00,
338
- high_confidence=True,
339
- )
340
-
341
-
342
- def test_create_classification_wrong_high_confidence(mock_elements_worker):
343
- mock_elements_worker.classes = {"a_class": "0000"}
344
- elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
345
-
346
- with pytest.raises(
347
- AssertionError,
348
- match="high_confidence shouldn't be null and should be of type bool",
349
- ):
350
- mock_elements_worker.create_classification(
351
- element=elt,
352
- ml_class="a_class",
353
- confidence=0.42,
354
- high_confidence=None,
355
- )
356
-
357
- with pytest.raises(
358
- AssertionError,
359
- match="high_confidence shouldn't be null and should be of type bool",
360
- ):
361
- mock_elements_worker.create_classification(
362
- element=elt,
363
- ml_class="a_class",
364
- confidence=0.42,
365
- high_confidence="wrong high_confidence",
366
- )
367
-
368
-
369
- def test_create_classification_api_error(responses, mock_elements_worker):
370
- mock_elements_worker.classes = {"a_class": "0000"}
371
- elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
372
- responses.add(
373
- responses.POST,
374
- "http://testserver/api/v1/classifications/",
375
- status=500,
376
- )
377
-
378
- with pytest.raises(ErrorResponse):
379
- mock_elements_worker.create_classification(
380
- element=elt,
381
- ml_class="a_class",
382
- confidence=0.42,
383
- high_confidence=True,
384
- )
385
-
386
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
387
- assert [
388
- (call.request.method, call.request.url) for call in responses.calls
389
- ] == BASE_API_CALLS + [
390
- # We retry 5 times the API call
391
- ("POST", "http://testserver/api/v1/classifications/"),
392
- ("POST", "http://testserver/api/v1/classifications/"),
393
- ("POST", "http://testserver/api/v1/classifications/"),
394
- ("POST", "http://testserver/api/v1/classifications/"),
395
- ("POST", "http://testserver/api/v1/classifications/"),
396
- ]
397
-
398
-
399
349
  def test_create_classification(responses, mock_elements_worker):
400
350
  mock_elements_worker.classes = {"a_class": "0000"}
401
351
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
@@ -519,303 +469,165 @@ def test_create_classification_duplicate_worker_run(responses, mock_elements_wor
519
469
  }
520
470
 
521
471
 
522
- def test_create_classifications_wrong_element(mock_elements_worker):
523
- with pytest.raises(
524
- AssertionError,
525
- match="element shouldn't be null and should be an Element or CachedElement",
526
- ):
472
+ @pytest.mark.parametrize(
473
+ ("arg_name", "data", "error_message"),
474
+ [
475
+ (
476
+ "element",
477
+ None,
478
+ "element shouldn't be null and should be an Element or CachedElement",
479
+ ),
480
+ (
481
+ "element",
482
+ "not element type",
483
+ "element shouldn't be null and should be an Element or CachedElement",
484
+ ),
485
+ (
486
+ "classifications",
487
+ None,
488
+ "classifications shouldn't be null and should be of type list",
489
+ ),
490
+ (
491
+ "classifications",
492
+ 1234,
493
+ "classifications shouldn't be null and should be of type list",
494
+ ),
495
+ ],
496
+ )
497
+ def test_create_classifications_wrong_data(
498
+ arg_name, data, error_message, mock_elements_worker
499
+ ):
500
+ with pytest.raises(AssertionError, match=error_message):
527
501
  mock_elements_worker.create_classifications(
528
- element=None,
529
- classifications=[
530
- {
531
- "ml_class_id": "uuid1",
532
- "confidence": 0.75,
533
- "high_confidence": False,
534
- },
535
- {
536
- "ml_class_id": "uuid2",
537
- "confidence": 0.25,
538
- "high_confidence": False,
539
- },
540
- ],
502
+ **{
503
+ "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
504
+ "classifications": [
505
+ {
506
+ "ml_class": "cat",
507
+ "confidence": 0.75,
508
+ "high_confidence": False,
509
+ },
510
+ {
511
+ "ml_class": "dog",
512
+ "confidence": 0.25,
513
+ "high_confidence": False,
514
+ },
515
+ ],
516
+ # Overwrite with wrong data
517
+ arg_name: data,
518
+ },
541
519
  )
542
520
 
521
+
522
+ @pytest.mark.parametrize(
523
+ ("arg_name", "data", "error_message"),
524
+ [
525
+ # Wrong classifications > ml_class
526
+ (
527
+ "ml_class",
528
+ DELETE_PARAMETER,
529
+ "ml_class shouldn't be null and should be of type str",
530
+ ),
531
+ (
532
+ "ml_class",
533
+ None,
534
+ "ml_class shouldn't be null and should be of type str",
535
+ ),
536
+ (
537
+ "ml_class",
538
+ 1234,
539
+ "ml_class shouldn't be null and should be of type str",
540
+ ),
541
+ # Wrong classifications > confidence
542
+ (
543
+ "confidence",
544
+ DELETE_PARAMETER,
545
+ "confidence shouldn't be null and should be a float in [0..1] range",
546
+ ),
547
+ (
548
+ "confidence",
549
+ None,
550
+ "confidence shouldn't be null and should be a float in [0..1] range",
551
+ ),
552
+ (
553
+ "confidence",
554
+ "wrong confidence",
555
+ "confidence shouldn't be null and should be a float in [0..1] range",
556
+ ),
557
+ (
558
+ "confidence",
559
+ 0,
560
+ "confidence shouldn't be null and should be a float in [0..1] range",
561
+ ),
562
+ (
563
+ "confidence",
564
+ 2.00,
565
+ "confidence shouldn't be null and should be a float in [0..1] range",
566
+ ),
567
+ # Wrong classifications > high_confidence
568
+ (
569
+ "high_confidence",
570
+ "wrong high_confidence",
571
+ "high_confidence should be of type bool",
572
+ ),
573
+ ],
574
+ )
575
+ def test_create_classifications_wrong_classifications_data(
576
+ arg_name, data, error_message, mock_elements_worker
577
+ ):
578
+ all_data = {
579
+ "element": Element({"id": "12341234-1234-1234-1234-123412341234"}),
580
+ "classifications": [
581
+ {
582
+ "ml_class": "cat",
583
+ "confidence": 0.75,
584
+ "high_confidence": False,
585
+ },
586
+ {
587
+ "ml_class": "dog",
588
+ "confidence": 0.25,
589
+ "high_confidence": False,
590
+ # Overwrite with wrong data
591
+ arg_name: data,
592
+ },
593
+ ],
594
+ }
595
+ if data == DELETE_PARAMETER:
596
+ del all_data["classifications"][1][arg_name]
597
+
543
598
  with pytest.raises(
544
599
  AssertionError,
545
- match="element shouldn't be null and should be an Element or CachedElement",
600
+ match=re.escape(
601
+ f"Classification at index 1 in classifications: {error_message}"
602
+ ),
546
603
  ):
547
- mock_elements_worker.create_classifications(
548
- element="not element type",
549
- classifications=[
550
- {
551
- "ml_class_id": "uuid1",
552
- "confidence": 0.75,
553
- "high_confidence": False,
554
- },
555
- {
556
- "ml_class_id": "uuid2",
557
- "confidence": 0.25,
558
- "high_confidence": False,
559
- },
560
- ],
561
- )
604
+ mock_elements_worker.create_classifications(**all_data)
562
605
 
563
606
 
564
- def test_create_classifications_wrong_classifications(mock_elements_worker):
607
+ def test_create_classifications_api_error(responses, mock_elements_worker):
608
+ mock_elements_worker.classes = {"cat": "0000", "dog": "1111"}
609
+ responses.add(
610
+ responses.POST,
611
+ "http://testserver/api/v1/classification/bulk/",
612
+ status=500,
613
+ )
565
614
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
615
+ classes = [
616
+ {
617
+ "ml_class": "cat",
618
+ "confidence": 0.75,
619
+ "high_confidence": False,
620
+ },
621
+ {
622
+ "ml_class": "dog",
623
+ "confidence": 0.25,
624
+ "high_confidence": False,
625
+ },
626
+ ]
566
627
 
567
- with pytest.raises(
568
- AssertionError,
569
- match="classifications shouldn't be null and should be of type list",
570
- ):
628
+ with pytest.raises(ErrorResponse):
571
629
  mock_elements_worker.create_classifications(
572
- element=elt,
573
- classifications=None,
574
- )
575
-
576
- with pytest.raises(
577
- AssertionError,
578
- match="classifications shouldn't be null and should be of type list",
579
- ):
580
- mock_elements_worker.create_classifications(
581
- element=elt,
582
- classifications=1234,
583
- )
584
-
585
- with pytest.raises(
586
- AssertionError,
587
- match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str",
588
- ):
589
- mock_elements_worker.create_classifications(
590
- element=elt,
591
- classifications=[
592
- {
593
- "ml_class_id": str(uuid4()),
594
- "confidence": 0.75,
595
- "high_confidence": False,
596
- },
597
- {
598
- "ml_class_id": 0.25,
599
- "high_confidence": False,
600
- },
601
- ],
602
- )
603
-
604
- with pytest.raises(
605
- AssertionError,
606
- match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str",
607
- ):
608
- mock_elements_worker.create_classifications(
609
- element=elt,
610
- classifications=[
611
- {
612
- "ml_class_id": str(uuid4()),
613
- "confidence": 0.75,
614
- "high_confidence": False,
615
- },
616
- {
617
- "ml_class_id": None,
618
- "confidence": 0.25,
619
- "high_confidence": False,
620
- },
621
- ],
622
- )
623
-
624
- with pytest.raises(
625
- AssertionError,
626
- match="Classification at index 1 in classifications: ml_class_id shouldn't be null and should be of type str",
627
- ):
628
- mock_elements_worker.create_classifications(
629
- element=elt,
630
- classifications=[
631
- {
632
- "ml_class_id": str(uuid4()),
633
- "confidence": 0.75,
634
- "high_confidence": False,
635
- },
636
- {
637
- "ml_class_id": 1234,
638
- "confidence": 0.25,
639
- "high_confidence": False,
640
- },
641
- ],
642
- )
643
-
644
- with pytest.raises(
645
- ValueError,
646
- match="Classification at index 1 in classifications: ml_class_id is not a valid uuid.",
647
- ):
648
- mock_elements_worker.create_classifications(
649
- element=elt,
650
- classifications=[
651
- {
652
- "ml_class_id": str(uuid4()),
653
- "confidence": 0.75,
654
- "high_confidence": False,
655
- },
656
- {
657
- "ml_class_id": "not_an_uuid",
658
- "confidence": 0.25,
659
- "high_confidence": False,
660
- },
661
- ],
662
- )
663
-
664
- with pytest.raises(
665
- AssertionError,
666
- match=re.escape(
667
- "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
668
- ),
669
- ):
670
- mock_elements_worker.create_classifications(
671
- element=elt,
672
- classifications=[
673
- {
674
- "ml_class_id": str(uuid4()),
675
- "confidence": 0.75,
676
- "high_confidence": False,
677
- },
678
- {
679
- "ml_class_id": str(uuid4()),
680
- "high_confidence": False,
681
- },
682
- ],
683
- )
684
-
685
- with pytest.raises(
686
- AssertionError,
687
- match=re.escape(
688
- "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
689
- ),
690
- ):
691
- mock_elements_worker.create_classifications(
692
- element=elt,
693
- classifications=[
694
- {
695
- "ml_class_id": str(uuid4()),
696
- "confidence": 0.75,
697
- "high_confidence": False,
698
- },
699
- {
700
- "ml_class_id": str(uuid4()),
701
- "confidence": None,
702
- "high_confidence": False,
703
- },
704
- ],
705
- )
706
-
707
- with pytest.raises(
708
- AssertionError,
709
- match=re.escape(
710
- "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
711
- ),
712
- ):
713
- mock_elements_worker.create_classifications(
714
- element=elt,
715
- classifications=[
716
- {
717
- "ml_class_id": str(uuid4()),
718
- "confidence": 0.75,
719
- "high_confidence": False,
720
- },
721
- {
722
- "ml_class_id": str(uuid4()),
723
- "confidence": "wrong confidence",
724
- "high_confidence": False,
725
- },
726
- ],
727
- )
728
-
729
- with pytest.raises(
730
- AssertionError,
731
- match=re.escape(
732
- "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
733
- ),
734
- ):
735
- mock_elements_worker.create_classifications(
736
- element=elt,
737
- classifications=[
738
- {
739
- "ml_class_id": str(uuid4()),
740
- "confidence": 0.75,
741
- "high_confidence": False,
742
- },
743
- {
744
- "ml_class_id": str(uuid4()),
745
- "confidence": 0,
746
- "high_confidence": False,
747
- },
748
- ],
749
- )
750
-
751
- with pytest.raises(
752
- AssertionError,
753
- match=re.escape(
754
- "Classification at index 1 in classifications: confidence shouldn't be null and should be a float in [0..1] range"
755
- ),
756
- ):
757
- mock_elements_worker.create_classifications(
758
- element=elt,
759
- classifications=[
760
- {
761
- "ml_class_id": str(uuid4()),
762
- "confidence": 0.75,
763
- "high_confidence": False,
764
- },
765
- {
766
- "ml_class_id": str(uuid4()),
767
- "confidence": 2.00,
768
- "high_confidence": False,
769
- },
770
- ],
771
- )
772
-
773
- with pytest.raises(
774
- AssertionError,
775
- match=re.escape(
776
- "Classification at index 1 in classifications: high_confidence should be of type bool"
777
- ),
778
- ):
779
- mock_elements_worker.create_classifications(
780
- element=elt,
781
- classifications=[
782
- {
783
- "ml_class_id": str(uuid4()),
784
- "confidence": 0.75,
785
- "high_confidence": False,
786
- },
787
- {
788
- "ml_class_id": str(uuid4()),
789
- "confidence": 0.25,
790
- "high_confidence": "wrong high_confidence",
791
- },
792
- ],
793
- )
794
-
795
-
796
- def test_create_classifications_api_error(responses, mock_elements_worker):
797
- responses.add(
798
- responses.POST,
799
- "http://testserver/api/v1/classification/bulk/",
800
- status=500,
801
- )
802
- elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
803
- classes = [
804
- {
805
- "ml_class_id": str(uuid4()),
806
- "confidence": 0.75,
807
- "high_confidence": False,
808
- },
809
- {
810
- "ml_class_id": str(uuid4()),
811
- "confidence": 0.25,
812
- "high_confidence": False,
813
- },
814
- ]
815
-
816
- with pytest.raises(ErrorResponse):
817
- mock_elements_worker.create_classifications(
818
- element=elt, classifications=classes
630
+ element=elt, classifications=classes
819
631
  )
820
632
 
821
633
  assert len(responses.calls) == len(BASE_API_CALLS) + 5
@@ -831,57 +643,96 @@ def test_create_classifications_api_error(responses, mock_elements_worker):
831
643
  ]
832
644
 
833
645
 
834
- def test_create_classifications(responses, mock_elements_worker_with_cache):
835
- # Set MLClass in cache
836
- portrait_uuid = str(uuid4())
837
- landscape_uuid = str(uuid4())
838
- mock_elements_worker_with_cache.classes = {
839
- "portrait": portrait_uuid,
840
- "landscape": landscape_uuid,
841
- }
842
-
843
- elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
844
- classes = [
845
- {
846
- "ml_class_id": portrait_uuid,
847
- "confidence": 0.75,
848
- "high_confidence": False,
849
- },
850
- {
851
- "ml_class_id": landscape_uuid,
852
- "confidence": 0.25,
853
- "high_confidence": False,
854
- },
855
- ]
646
+ def test_create_classifications_create_ml_class(mock_elements_worker, responses):
647
+ elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
856
648
 
649
+ # Automatically create a missing class!
650
+ responses.add(
651
+ responses.POST,
652
+ "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
653
+ status=201,
654
+ json={"id": "new-ml-class-1234"},
655
+ )
857
656
  responses.add(
858
657
  responses.POST,
859
658
  "http://testserver/api/v1/classification/bulk/",
860
- status=200,
659
+ status=201,
861
660
  json={
862
661
  "parent": str(elt.id),
863
662
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
864
663
  "classifications": [
865
664
  {
866
665
  "id": "00000000-0000-0000-0000-000000000000",
867
- "ml_class": portrait_uuid,
666
+ "ml_class": "new-ml-class-1234",
868
667
  "confidence": 0.75,
869
668
  "high_confidence": False,
870
669
  "state": "pending",
871
670
  },
872
- {
873
- "id": "11111111-1111-1111-1111-111111111111",
874
- "ml_class": landscape_uuid,
875
- "confidence": 0.25,
876
- "high_confidence": False,
877
- "state": "pending",
878
- },
879
671
  ],
880
672
  },
881
673
  )
674
+ mock_elements_worker.classes = {"another_class": "0000"}
675
+ mock_elements_worker.create_classifications(
676
+ element=elt,
677
+ classifications=[
678
+ {
679
+ "ml_class": "a_class",
680
+ "confidence": 0.75,
681
+ "high_confidence": False,
682
+ }
683
+ ],
684
+ )
882
685
 
883
- mock_elements_worker_with_cache.create_classifications(
884
- element=elt, classifications=classes
686
+ # Check a class & classification has been created
687
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
688
+ assert [
689
+ (call.request.method, call.request.url) for call in responses.calls
690
+ ] == BASE_API_CALLS + [
691
+ (
692
+ "POST",
693
+ "http://testserver/api/v1/corpus/11111111-1111-1111-1111-111111111111/classes/",
694
+ ),
695
+ ("POST", "http://testserver/api/v1/classification/bulk/"),
696
+ ]
697
+
698
+ assert json.loads(responses.calls[-2].request.body) == {"name": "a_class"}
699
+ assert json.loads(responses.calls[-1].request.body) == {
700
+ "parent": "12341234-1234-1234-1234-123412341234",
701
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
702
+ "classifications": [
703
+ {
704
+ "ml_class": "new-ml-class-1234",
705
+ "confidence": 0.75,
706
+ "high_confidence": False,
707
+ }
708
+ ],
709
+ }
710
+
711
+
712
+ def test_create_classifications(responses, mock_elements_worker):
713
+ mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
714
+ elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
715
+ responses.add(
716
+ responses.POST,
717
+ "http://testserver/api/v1/classification/bulk/",
718
+ status=200,
719
+ json={"classifications": []},
720
+ )
721
+
722
+ mock_elements_worker.create_classifications(
723
+ element=elt,
724
+ classifications=[
725
+ {
726
+ "ml_class": "portrait",
727
+ "confidence": 0.75,
728
+ "high_confidence": False,
729
+ },
730
+ {
731
+ "ml_class": "landscape",
732
+ "confidence": 0.25,
733
+ "high_confidence": False,
734
+ },
735
+ ],
885
736
  )
886
737
 
887
738
  assert len(responses.calls) == len(BASE_API_CALLS) + 1
@@ -894,52 +745,24 @@ def test_create_classifications(responses, mock_elements_worker_with_cache):
894
745
  assert json.loads(responses.calls[-1].request.body) == {
895
746
  "parent": str(elt.id),
896
747
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
897
- "classifications": classes,
748
+ "classifications": [
749
+ {
750
+ "confidence": 0.75,
751
+ "high_confidence": False,
752
+ "ml_class": "0000",
753
+ },
754
+ {
755
+ "confidence": 0.25,
756
+ "high_confidence": False,
757
+ "ml_class": "1111",
758
+ },
759
+ ],
898
760
  }
899
761
 
900
- # Check that created classifications were properly stored in SQLite cache
901
- assert list(CachedClassification.select()) == [
902
- CachedClassification(
903
- id=UUID("00000000-0000-0000-0000-000000000000"),
904
- element_id=UUID(elt.id),
905
- class_name="portrait",
906
- confidence=0.75,
907
- state="pending",
908
- worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
909
- ),
910
- CachedClassification(
911
- id=UUID("11111111-1111-1111-1111-111111111111"),
912
- element_id=UUID(elt.id),
913
- class_name="landscape",
914
- confidence=0.25,
915
- state="pending",
916
- worker_run_id=UUID("56785678-5678-5678-5678-567856785678"),
917
- ),
918
- ]
919
-
920
762
 
921
- def test_create_classifications_not_in_cache(
922
- responses, mock_elements_worker_with_cache
923
- ):
924
- """
925
- CreateClassifications using ID that are not in `.classes` attribute.
926
- Will load corpus MLClass to insert the corresponding name in Cache.
927
- """
928
- portrait_uuid = str(uuid4())
929
- landscape_uuid = str(uuid4())
763
+ def test_create_classifications_with_cache(responses, mock_elements_worker_with_cache):
764
+ mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
930
765
  elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
931
- classes = [
932
- {
933
- "ml_class_id": portrait_uuid,
934
- "confidence": 0.75,
935
- "high_confidence": False,
936
- },
937
- {
938
- "ml_class_id": landscape_uuid,
939
- "confidence": 0.25,
940
- "high_confidence": False,
941
- },
942
- ]
943
766
 
944
767
  responses.add(
945
768
  responses.POST,
@@ -951,14 +774,14 @@ def test_create_classifications_not_in_cache(
951
774
  "classifications": [
952
775
  {
953
776
  "id": "00000000-0000-0000-0000-000000000000",
954
- "ml_class": portrait_uuid,
777
+ "ml_class": "0000",
955
778
  "confidence": 0.75,
956
779
  "high_confidence": False,
957
780
  "state": "pending",
958
781
  },
959
782
  {
960
783
  "id": "11111111-1111-1111-1111-111111111111",
961
- "ml_class": landscape_uuid,
784
+ "ml_class": "1111",
962
785
  "confidence": 0.25,
963
786
  "high_confidence": False,
964
787
  "state": "pending",
@@ -966,42 +789,45 @@ def test_create_classifications_not_in_cache(
966
789
  ],
967
790
  },
968
791
  )
969
- responses.add(
970
- responses.GET,
971
- f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
972
- status=200,
973
- json={
974
- "count": 2,
975
- "next": None,
976
- "results": [
977
- {
978
- "id": portrait_uuid,
979
- "name": "portrait",
980
- },
981
- {"id": landscape_uuid, "name": "landscape"},
982
- ],
983
- },
984
- )
985
792
 
986
793
  mock_elements_worker_with_cache.create_classifications(
987
- element=elt, classifications=classes
794
+ element=elt,
795
+ classifications=[
796
+ {
797
+ "ml_class": "portrait",
798
+ "confidence": 0.75,
799
+ "high_confidence": False,
800
+ },
801
+ {
802
+ "ml_class": "landscape",
803
+ "confidence": 0.25,
804
+ "high_confidence": False,
805
+ },
806
+ ],
988
807
  )
989
808
 
990
- assert len(responses.calls) == len(BASE_API_CALLS) + 2
809
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
991
810
  assert [
992
811
  (call.request.method, call.request.url) for call in responses.calls
993
812
  ] == BASE_API_CALLS + [
994
813
  ("POST", "http://testserver/api/v1/classification/bulk/"),
995
- (
996
- "GET",
997
- f"http://testserver/api/v1/corpus/{mock_elements_worker_with_cache.corpus_id}/classes/",
998
- ),
999
814
  ]
1000
815
 
1001
- assert json.loads(responses.calls[-2].request.body) == {
816
+ assert json.loads(responses.calls[-1].request.body) == {
1002
817
  "parent": str(elt.id),
1003
818
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1004
- "classifications": classes,
819
+ "classifications": [
820
+ {
821
+ "confidence": 0.75,
822
+ "high_confidence": False,
823
+ "ml_class": "0000",
824
+ },
825
+ {
826
+ "confidence": 0.25,
827
+ "high_confidence": False,
828
+ "ml_class": "1111",
829
+ },
830
+ ],
1005
831
  }
1006
832
 
1007
833
  # Check that created classifications were properly stored in SQLite cache