arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ from playhouse.shortcuts import model_to_dict
8
8
 
9
9
  from arkindex_worker.cache import CachedElement, CachedTranscription
10
10
  from arkindex_worker.models import Element
11
+ from arkindex_worker.utils import DEFAULT_BATCH_SIZE
11
12
  from arkindex_worker.worker.transcription import TextOrientation
12
13
 
13
14
  from . import BASE_API_CALLS
@@ -639,9 +640,10 @@ def test_create_transcriptions_api_error(responses, mock_elements_worker):
639
640
  ] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/transcription/bulk/")]
640
641
 
641
642
 
642
- def test_create_transcriptions(responses, mock_elements_worker_with_cache):
643
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
644
+ def test_create_transcriptions(batch_size, responses, mock_elements_worker_with_cache):
643
645
  CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing")
644
- trans = [
646
+ transcriptions = [
645
647
  {
646
648
  "element_id": "11111111-1111-1111-1111-111111111111",
647
649
  "text": "The",
@@ -654,60 +656,110 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
654
656
  },
655
657
  ]
656
658
 
657
- responses.add(
658
- responses.POST,
659
- "http://testserver/api/v1/transcription/bulk/",
660
- status=200,
661
- json={
662
- "worker_run_id": "56785678-5678-5678-5678-567856785678",
663
- "transcriptions": [
664
- {
665
- "id": "00000000-0000-0000-0000-000000000000",
666
- "element_id": "11111111-1111-1111-1111-111111111111",
667
- "text": "The",
668
- "orientation": "horizontal-lr",
669
- "confidence": 0.75,
670
- },
671
- {
672
- "id": "11111111-1111-1111-1111-111111111111",
673
- "element_id": "11111111-1111-1111-1111-111111111111",
674
- "text": "word",
675
- "orientation": "horizontal-lr",
676
- "confidence": 0.42,
677
- },
659
+ if batch_size > 1:
660
+ responses.add(
661
+ responses.POST,
662
+ "http://testserver/api/v1/transcription/bulk/",
663
+ status=200,
664
+ json={
665
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
666
+ "transcriptions": [
667
+ {
668
+ "id": "00000000-0000-0000-0000-000000000000",
669
+ "element_id": "11111111-1111-1111-1111-111111111111",
670
+ "text": "The",
671
+ "orientation": "horizontal-lr",
672
+ "confidence": 0.75,
673
+ },
674
+ {
675
+ "id": "11111111-1111-1111-1111-111111111111",
676
+ "element_id": "11111111-1111-1111-1111-111111111111",
677
+ "text": "word",
678
+ "orientation": "horizontal-lr",
679
+ "confidence": 0.42,
680
+ },
681
+ ],
682
+ },
683
+ )
684
+ else:
685
+ for tr, tr_id in zip(
686
+ transcriptions,
687
+ [
688
+ "00000000-0000-0000-0000-000000000000",
689
+ "11111111-1111-1111-1111-111111111111",
678
690
  ],
679
- },
680
- )
691
+ strict=False,
692
+ ):
693
+ responses.add(
694
+ responses.POST,
695
+ "http://testserver/api/v1/transcription/bulk/",
696
+ status=200,
697
+ json={
698
+ "worker_run_id": "56785678-5678-5678-5678-567856785678",
699
+ "transcriptions": [
700
+ {
701
+ "id": tr_id,
702
+ "element_id": tr["element_id"],
703
+ "text": tr["text"],
704
+ "orientation": "horizontal-lr",
705
+ "confidence": tr["confidence"],
706
+ }
707
+ ],
708
+ },
709
+ )
681
710
 
682
711
  mock_elements_worker_with_cache.create_transcriptions(
683
- transcriptions=trans,
712
+ transcriptions=transcriptions,
713
+ batch_size=batch_size,
684
714
  )
685
715
 
686
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
716
+ bulk_api_calls = [
717
+ (
718
+ "POST",
719
+ "http://testserver/api/v1/transcription/bulk/",
720
+ )
721
+ ]
722
+ if batch_size != DEFAULT_BATCH_SIZE:
723
+ bulk_api_calls.append(
724
+ (
725
+ "POST",
726
+ "http://testserver/api/v1/transcription/bulk/",
727
+ )
728
+ )
729
+
730
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
687
731
  assert [
688
732
  (call.request.method, call.request.url) for call in responses.calls
689
- ] == BASE_API_CALLS + [
690
- ("POST", "http://testserver/api/v1/transcription/bulk/"),
691
- ]
733
+ ] == BASE_API_CALLS + bulk_api_calls
692
734
 
693
- assert json.loads(responses.calls[-1].request.body) == {
735
+ first_tr = {
736
+ **transcriptions[0],
737
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
738
+ }
739
+ second_tr = {
740
+ **transcriptions[1],
741
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
742
+ }
743
+ empty_payload = {
744
+ "transcriptions": [],
694
745
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
695
- "transcriptions": [
696
- {
697
- "element_id": "11111111-1111-1111-1111-111111111111",
698
- "text": "The",
699
- "confidence": 0.75,
700
- "orientation": TextOrientation.HorizontalLeftToRight.value,
701
- },
702
- {
703
- "element_id": "11111111-1111-1111-1111-111111111111",
704
- "text": "word",
705
- "confidence": 0.42,
706
- "orientation": TextOrientation.HorizontalLeftToRight.value,
707
- },
708
- ],
709
746
  }
710
747
 
748
+ bodies = []
749
+ first_call_idx = None
750
+ if batch_size > 1:
751
+ first_call_idx = -1
752
+ bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
753
+ else:
754
+ first_call_idx = -2
755
+ bodies.append({**empty_payload, "transcriptions": [first_tr]})
756
+ bodies.append({**empty_payload, "transcriptions": [second_tr]})
757
+
758
+ assert [
759
+ json.loads(bulk_call.request.body)
760
+ for bulk_call in responses.calls[first_call_idx:]
761
+ ] == bodies
762
+
711
763
  # Check that created transcriptions were properly stored in SQLite cache
712
764
  assert list(CachedTranscription.select()) == [
713
765
  CachedTranscription(
@@ -1281,70 +1333,119 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
1281
1333
  ]
1282
1334
 
1283
1335
 
1284
- def test_create_element_transcriptions(responses, mock_elements_worker):
1336
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 2])
1337
+ def test_create_element_transcriptions(batch_size, responses, mock_elements_worker):
1285
1338
  elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
1286
- responses.add(
1287
- responses.POST,
1288
- f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1289
- status=200,
1290
- json=[
1291
- {
1292
- "id": "56785678-5678-5678-5678-567856785678",
1293
- "element_id": "11111111-1111-1111-1111-111111111111",
1294
- "created": True,
1295
- },
1296
- {
1297
- "id": "67896789-6789-6789-6789-678967896789",
1298
- "element_id": "22222222-2222-2222-2222-222222222222",
1299
- "created": False,
1300
- },
1301
- {
1302
- "id": "78907890-7890-7890-7890-789078907890",
1303
- "element_id": "11111111-1111-1111-1111-111111111111",
1304
- "created": True,
1305
- },
1306
- ],
1307
- )
1339
+
1340
+ if batch_size > 2:
1341
+ responses.add(
1342
+ responses.POST,
1343
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1344
+ status=200,
1345
+ json=[
1346
+ {
1347
+ "id": "56785678-5678-5678-5678-567856785678",
1348
+ "element_id": "11111111-1111-1111-1111-111111111111",
1349
+ "created": True,
1350
+ },
1351
+ {
1352
+ "id": "67896789-6789-6789-6789-678967896789",
1353
+ "element_id": "22222222-2222-2222-2222-222222222222",
1354
+ "created": False,
1355
+ },
1356
+ {
1357
+ "id": "78907890-7890-7890-7890-789078907890",
1358
+ "element_id": "11111111-1111-1111-1111-111111111111",
1359
+ "created": True,
1360
+ },
1361
+ ],
1362
+ )
1363
+ else:
1364
+ for transcriptions in [
1365
+ [
1366
+ ("56785678-5678-5678-5678-567856785678", True),
1367
+ ("67896789-6789-6789-6789-678967896789", False),
1368
+ ],
1369
+ [("78907890-7890-7890-7890-789078907890", True)],
1370
+ ]:
1371
+ responses.add(
1372
+ responses.POST,
1373
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1374
+ status=200,
1375
+ json=[
1376
+ {
1377
+ "id": tr_id,
1378
+ "element_id": "11111111-1111-1111-1111-111111111111"
1379
+ if created
1380
+ else "22222222-2222-2222-2222-222222222222",
1381
+ "created": created,
1382
+ }
1383
+ for tr_id, created in transcriptions
1384
+ ],
1385
+ )
1308
1386
 
1309
1387
  annotations = mock_elements_worker.create_element_transcriptions(
1310
1388
  element=elt,
1311
1389
  sub_element_type="page",
1312
1390
  transcriptions=TRANSCRIPTIONS_SAMPLE,
1391
+ batch_size=batch_size,
1313
1392
  )
1314
1393
 
1315
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1394
+ bulk_api_calls = [
1395
+ (
1396
+ "POST",
1397
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1398
+ )
1399
+ ]
1400
+ if batch_size != DEFAULT_BATCH_SIZE:
1401
+ bulk_api_calls.append(
1402
+ (
1403
+ "POST",
1404
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1405
+ )
1406
+ )
1407
+
1408
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1316
1409
  assert [
1317
1410
  (call.request.method, call.request.url) for call in responses.calls
1318
- ] == BASE_API_CALLS + [
1319
- ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
1320
- ]
1411
+ ] == BASE_API_CALLS + bulk_api_calls
1321
1412
 
1322
- assert json.loads(responses.calls[-1].request.body) == {
1413
+ first_tr = {
1414
+ **TRANSCRIPTIONS_SAMPLE[0],
1415
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1416
+ }
1417
+ second_tr = {
1418
+ **TRANSCRIPTIONS_SAMPLE[1],
1419
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1420
+ }
1421
+ third_tr = {
1422
+ **TRANSCRIPTIONS_SAMPLE[2],
1423
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1424
+ }
1425
+ empty_payload = {
1323
1426
  "element_type": "page",
1324
1427
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1325
- "transcriptions": [
1326
- {
1327
- "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
1328
- "confidence": 0.5,
1329
- "text": "The",
1330
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1331
- },
1332
- {
1333
- "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
1334
- "confidence": 0.75,
1335
- "text": "first",
1336
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1337
- "element_confidence": 0.75,
1338
- },
1339
- {
1340
- "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
1341
- "confidence": 0.9,
1342
- "text": "line",
1343
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1344
- },
1345
- ],
1428
+ "transcriptions": [],
1346
1429
  "return_elements": True,
1347
1430
  }
1431
+
1432
+ bodies = []
1433
+ first_call_idx = None
1434
+ if batch_size > 2:
1435
+ first_call_idx = -1
1436
+ bodies.append(
1437
+ {**empty_payload, "transcriptions": [first_tr, second_tr, third_tr]}
1438
+ )
1439
+ else:
1440
+ first_call_idx = -2
1441
+ bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
1442
+ bodies.append({**empty_payload, "transcriptions": [third_tr]})
1443
+
1444
+ assert [
1445
+ json.loads(bulk_call.request.body)
1446
+ for bulk_call in responses.calls[first_call_idx:]
1447
+ ] == bodies
1448
+
1348
1449
  assert annotations == [
1349
1450
  {
1350
1451
  "id": "56785678-5678-5678-5678-567856785678",
@@ -1364,73 +1465,122 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
1364
1465
  ]
1365
1466
 
1366
1467
 
1468
+ @pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 2])
1367
1469
  def test_create_element_transcriptions_with_cache(
1368
- responses, mock_elements_worker_with_cache
1470
+ batch_size, responses, mock_elements_worker_with_cache
1369
1471
  ):
1370
1472
  elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
1371
1473
 
1372
- responses.add(
1373
- responses.POST,
1374
- f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1375
- status=200,
1376
- json=[
1377
- {
1378
- "id": "56785678-5678-5678-5678-567856785678",
1379
- "element_id": "11111111-1111-1111-1111-111111111111",
1380
- "created": True,
1381
- },
1382
- {
1383
- "id": "67896789-6789-6789-6789-678967896789",
1384
- "element_id": "22222222-2222-2222-2222-222222222222",
1385
- "created": False,
1386
- },
1387
- {
1388
- "id": "78907890-7890-7890-7890-789078907890",
1389
- "element_id": "11111111-1111-1111-1111-111111111111",
1390
- "created": True,
1391
- },
1392
- ],
1393
- )
1474
+ if batch_size > 2:
1475
+ responses.add(
1476
+ responses.POST,
1477
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1478
+ status=200,
1479
+ json=[
1480
+ {
1481
+ "id": "56785678-5678-5678-5678-567856785678",
1482
+ "element_id": "11111111-1111-1111-1111-111111111111",
1483
+ "created": True,
1484
+ },
1485
+ {
1486
+ "id": "67896789-6789-6789-6789-678967896789",
1487
+ "element_id": "22222222-2222-2222-2222-222222222222",
1488
+ "created": False,
1489
+ },
1490
+ {
1491
+ "id": "78907890-7890-7890-7890-789078907890",
1492
+ "element_id": "11111111-1111-1111-1111-111111111111",
1493
+ "created": True,
1494
+ },
1495
+ ],
1496
+ )
1497
+ else:
1498
+ for transcriptions in [
1499
+ [
1500
+ ("56785678-5678-5678-5678-567856785678", True),
1501
+ ("67896789-6789-6789-6789-678967896789", False),
1502
+ ],
1503
+ [("78907890-7890-7890-7890-789078907890", True)],
1504
+ ]:
1505
+ responses.add(
1506
+ responses.POST,
1507
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1508
+ status=200,
1509
+ json=[
1510
+ {
1511
+ "id": tr_id,
1512
+ "element_id": "11111111-1111-1111-1111-111111111111"
1513
+ if created
1514
+ else "22222222-2222-2222-2222-222222222222",
1515
+ "created": created,
1516
+ }
1517
+ for tr_id, created in transcriptions
1518
+ ],
1519
+ )
1394
1520
 
1395
1521
  annotations = mock_elements_worker_with_cache.create_element_transcriptions(
1396
1522
  element=elt,
1397
1523
  sub_element_type="page",
1398
1524
  transcriptions=TRANSCRIPTIONS_SAMPLE,
1525
+ batch_size=batch_size,
1399
1526
  )
1400
1527
 
1401
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
1528
+ bulk_api_calls = [
1529
+ (
1530
+ "POST",
1531
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1532
+ )
1533
+ ]
1534
+ if batch_size != DEFAULT_BATCH_SIZE:
1535
+ bulk_api_calls.append(
1536
+ (
1537
+ "POST",
1538
+ f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
1539
+ )
1540
+ )
1541
+
1542
+ assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
1402
1543
  assert [
1403
1544
  (call.request.method, call.request.url) for call in responses.calls
1404
- ] == BASE_API_CALLS + [
1405
- ("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
1406
- ]
1545
+ ] == BASE_API_CALLS + bulk_api_calls
1407
1546
 
1408
- assert json.loads(responses.calls[-1].request.body) == {
1547
+ first_tr = {
1548
+ **TRANSCRIPTIONS_SAMPLE[0],
1549
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1550
+ }
1551
+ second_tr = {
1552
+ **TRANSCRIPTIONS_SAMPLE[1],
1553
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1554
+ "element_confidence": 0.75,
1555
+ }
1556
+ third_tr = {
1557
+ **TRANSCRIPTIONS_SAMPLE[2],
1558
+ "orientation": TextOrientation.HorizontalLeftToRight.value,
1559
+ }
1560
+ empty_payload = {
1409
1561
  "element_type": "page",
1410
1562
  "worker_run_id": "56785678-5678-5678-5678-567856785678",
1411
- "transcriptions": [
1412
- {
1413
- "polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
1414
- "confidence": 0.5,
1415
- "text": "The",
1416
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1417
- },
1418
- {
1419
- "polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
1420
- "confidence": 0.75,
1421
- "text": "first",
1422
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1423
- "element_confidence": 0.75,
1424
- },
1425
- {
1426
- "polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
1427
- "confidence": 0.9,
1428
- "text": "line",
1429
- "orientation": TextOrientation.HorizontalLeftToRight.value,
1430
- },
1431
- ],
1563
+ "transcriptions": [],
1432
1564
  "return_elements": True,
1433
1565
  }
1566
+
1567
+ bodies = []
1568
+ first_call_idx = None
1569
+ if batch_size > 2:
1570
+ first_call_idx = -1
1571
+ bodies.append(
1572
+ {**empty_payload, "transcriptions": [first_tr, second_tr, third_tr]}
1573
+ )
1574
+ else:
1575
+ first_call_idx = -2
1576
+ bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
1577
+ bodies.append({**empty_payload, "transcriptions": [third_tr]})
1578
+
1579
+ assert [
1580
+ json.loads(bulk_call.request.body)
1581
+ for bulk_call in responses.calls[first_call_idx:]
1582
+ ] == bodies
1583
+
1434
1584
  assert annotations == [
1435
1585
  {
1436
1586
  "id": "56785678-5678-5678-5678-567856785678",
tests/test_utils.py CHANGED
@@ -3,6 +3,8 @@ from pathlib import Path
3
3
  import pytest
4
4
 
5
5
  from arkindex_worker.utils import (
6
+ DEFAULT_BATCH_SIZE,
7
+ batch_publication,
6
8
  close_delete_file,
7
9
  extract_tar_zst_archive,
8
10
  parse_source_id,
@@ -55,3 +57,29 @@ def test_close_delete_file(tmp_path):
55
57
  close_delete_file(archive_fd, archive_path)
56
58
 
57
59
  assert not archive_path.exists()
60
+
61
+
62
+ class TestMixin:
63
+ @batch_publication
64
+ def custom_publication_in_batches(self, batch_size: int = DEFAULT_BATCH_SIZE):
65
+ return batch_size
66
+
67
+
68
+ def test_batch_publication_decorator_no_parameter():
69
+ assert TestMixin().custom_publication_in_batches() == DEFAULT_BATCH_SIZE
70
+
71
+
72
+ @pytest.mark.parametrize("wrong_batch_size", [None, "not an int", 0])
73
+ def test_batch_publication_decorator_wrong_parameter(wrong_batch_size):
74
+ with pytest.raises(
75
+ AssertionError,
76
+ match="batch_size shouldn't be null and should be a strictly positive integer",
77
+ ):
78
+ TestMixin().custom_publication_in_batches(batch_size=wrong_batch_size)
79
+
80
+
81
+ @pytest.mark.parametrize("batch_size", [1, 10, DEFAULT_BATCH_SIZE])
82
+ def test_batch_publication_decorator_right_parameter(batch_size):
83
+ assert (
84
+ TestMixin().custom_publication_in_batches(batch_size=batch_size) == batch_size
85
+ )