arkindex-base-worker 0.4.0b1__py3-none-any.whl → 0.4.0b3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/METADATA +1 -1
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/RECORD +19 -19
- arkindex_worker/image.py +2 -1
- arkindex_worker/utils.py +81 -0
- arkindex_worker/worker/__init__.py +3 -2
- arkindex_worker/worker/classification.py +31 -15
- arkindex_worker/worker/element.py +71 -10
- arkindex_worker/worker/entity.py +25 -11
- arkindex_worker/worker/metadata.py +18 -8
- arkindex_worker/worker/transcription.py +38 -17
- tests/test_elements_worker/test_classifications.py +107 -60
- tests/test_elements_worker/test_elements.py +318 -49
- tests/test_elements_worker/test_entities.py +102 -33
- tests/test_elements_worker/test_metadata.py +223 -98
- tests/test_elements_worker/test_transcriptions.py +293 -143
- tests/test_utils.py +28 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/LICENSE +0 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/WHEEL +0 -0
- {arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/top_level.txt +0 -0
|
@@ -8,6 +8,7 @@ from playhouse.shortcuts import model_to_dict
|
|
|
8
8
|
|
|
9
9
|
from arkindex_worker.cache import CachedElement, CachedTranscription
|
|
10
10
|
from arkindex_worker.models import Element
|
|
11
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
11
12
|
from arkindex_worker.worker.transcription import TextOrientation
|
|
12
13
|
|
|
13
14
|
from . import BASE_API_CALLS
|
|
@@ -639,9 +640,10 @@ def test_create_transcriptions_api_error(responses, mock_elements_worker):
|
|
|
639
640
|
] == BASE_API_CALLS + [("POST", "http://testserver/api/v1/transcription/bulk/")]
|
|
640
641
|
|
|
641
642
|
|
|
642
|
-
|
|
643
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
644
|
+
def test_create_transcriptions(batch_size, responses, mock_elements_worker_with_cache):
|
|
643
645
|
CachedElement.create(id="11111111-1111-1111-1111-111111111111", type="thing")
|
|
644
|
-
|
|
646
|
+
transcriptions = [
|
|
645
647
|
{
|
|
646
648
|
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
647
649
|
"text": "The",
|
|
@@ -654,60 +656,110 @@ def test_create_transcriptions(responses, mock_elements_worker_with_cache):
|
|
|
654
656
|
},
|
|
655
657
|
]
|
|
656
658
|
|
|
657
|
-
|
|
658
|
-
responses.
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
659
|
+
if batch_size > 1:
|
|
660
|
+
responses.add(
|
|
661
|
+
responses.POST,
|
|
662
|
+
"http://testserver/api/v1/transcription/bulk/",
|
|
663
|
+
status=200,
|
|
664
|
+
json={
|
|
665
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
666
|
+
"transcriptions": [
|
|
667
|
+
{
|
|
668
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
669
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
670
|
+
"text": "The",
|
|
671
|
+
"orientation": "horizontal-lr",
|
|
672
|
+
"confidence": 0.75,
|
|
673
|
+
},
|
|
674
|
+
{
|
|
675
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
676
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
677
|
+
"text": "word",
|
|
678
|
+
"orientation": "horizontal-lr",
|
|
679
|
+
"confidence": 0.42,
|
|
680
|
+
},
|
|
681
|
+
],
|
|
682
|
+
},
|
|
683
|
+
)
|
|
684
|
+
else:
|
|
685
|
+
for tr, tr_id in zip(
|
|
686
|
+
transcriptions,
|
|
687
|
+
[
|
|
688
|
+
"00000000-0000-0000-0000-000000000000",
|
|
689
|
+
"11111111-1111-1111-1111-111111111111",
|
|
678
690
|
],
|
|
679
|
-
|
|
680
|
-
|
|
691
|
+
strict=False,
|
|
692
|
+
):
|
|
693
|
+
responses.add(
|
|
694
|
+
responses.POST,
|
|
695
|
+
"http://testserver/api/v1/transcription/bulk/",
|
|
696
|
+
status=200,
|
|
697
|
+
json={
|
|
698
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
699
|
+
"transcriptions": [
|
|
700
|
+
{
|
|
701
|
+
"id": tr_id,
|
|
702
|
+
"element_id": tr["element_id"],
|
|
703
|
+
"text": tr["text"],
|
|
704
|
+
"orientation": "horizontal-lr",
|
|
705
|
+
"confidence": tr["confidence"],
|
|
706
|
+
}
|
|
707
|
+
],
|
|
708
|
+
},
|
|
709
|
+
)
|
|
681
710
|
|
|
682
711
|
mock_elements_worker_with_cache.create_transcriptions(
|
|
683
|
-
transcriptions=
|
|
712
|
+
transcriptions=transcriptions,
|
|
713
|
+
batch_size=batch_size,
|
|
684
714
|
)
|
|
685
715
|
|
|
686
|
-
|
|
716
|
+
bulk_api_calls = [
|
|
717
|
+
(
|
|
718
|
+
"POST",
|
|
719
|
+
"http://testserver/api/v1/transcription/bulk/",
|
|
720
|
+
)
|
|
721
|
+
]
|
|
722
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
723
|
+
bulk_api_calls.append(
|
|
724
|
+
(
|
|
725
|
+
"POST",
|
|
726
|
+
"http://testserver/api/v1/transcription/bulk/",
|
|
727
|
+
)
|
|
728
|
+
)
|
|
729
|
+
|
|
730
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
687
731
|
assert [
|
|
688
732
|
(call.request.method, call.request.url) for call in responses.calls
|
|
689
|
-
] == BASE_API_CALLS +
|
|
690
|
-
("POST", "http://testserver/api/v1/transcription/bulk/"),
|
|
691
|
-
]
|
|
733
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
692
734
|
|
|
693
|
-
|
|
735
|
+
first_tr = {
|
|
736
|
+
**transcriptions[0],
|
|
737
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
738
|
+
}
|
|
739
|
+
second_tr = {
|
|
740
|
+
**transcriptions[1],
|
|
741
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
742
|
+
}
|
|
743
|
+
empty_payload = {
|
|
744
|
+
"transcriptions": [],
|
|
694
745
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
695
|
-
"transcriptions": [
|
|
696
|
-
{
|
|
697
|
-
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
698
|
-
"text": "The",
|
|
699
|
-
"confidence": 0.75,
|
|
700
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
701
|
-
},
|
|
702
|
-
{
|
|
703
|
-
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
704
|
-
"text": "word",
|
|
705
|
-
"confidence": 0.42,
|
|
706
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
707
|
-
},
|
|
708
|
-
],
|
|
709
746
|
}
|
|
710
747
|
|
|
748
|
+
bodies = []
|
|
749
|
+
first_call_idx = None
|
|
750
|
+
if batch_size > 1:
|
|
751
|
+
first_call_idx = -1
|
|
752
|
+
bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
|
|
753
|
+
else:
|
|
754
|
+
first_call_idx = -2
|
|
755
|
+
bodies.append({**empty_payload, "transcriptions": [first_tr]})
|
|
756
|
+
bodies.append({**empty_payload, "transcriptions": [second_tr]})
|
|
757
|
+
|
|
758
|
+
assert [
|
|
759
|
+
json.loads(bulk_call.request.body)
|
|
760
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
761
|
+
] == bodies
|
|
762
|
+
|
|
711
763
|
# Check that created transcriptions were properly stored in SQLite cache
|
|
712
764
|
assert list(CachedTranscription.select()) == [
|
|
713
765
|
CachedTranscription(
|
|
@@ -1281,70 +1333,119 @@ def test_create_element_transcriptions_api_error(responses, mock_elements_worker
|
|
|
1281
1333
|
]
|
|
1282
1334
|
|
|
1283
1335
|
|
|
1284
|
-
|
|
1336
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 2])
|
|
1337
|
+
def test_create_element_transcriptions(batch_size, responses, mock_elements_worker):
|
|
1285
1338
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1339
|
+
|
|
1340
|
+
if batch_size > 2:
|
|
1341
|
+
responses.add(
|
|
1342
|
+
responses.POST,
|
|
1343
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1344
|
+
status=200,
|
|
1345
|
+
json=[
|
|
1346
|
+
{
|
|
1347
|
+
"id": "56785678-5678-5678-5678-567856785678",
|
|
1348
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
1349
|
+
"created": True,
|
|
1350
|
+
},
|
|
1351
|
+
{
|
|
1352
|
+
"id": "67896789-6789-6789-6789-678967896789",
|
|
1353
|
+
"element_id": "22222222-2222-2222-2222-222222222222",
|
|
1354
|
+
"created": False,
|
|
1355
|
+
},
|
|
1356
|
+
{
|
|
1357
|
+
"id": "78907890-7890-7890-7890-789078907890",
|
|
1358
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
1359
|
+
"created": True,
|
|
1360
|
+
},
|
|
1361
|
+
],
|
|
1362
|
+
)
|
|
1363
|
+
else:
|
|
1364
|
+
for transcriptions in [
|
|
1365
|
+
[
|
|
1366
|
+
("56785678-5678-5678-5678-567856785678", True),
|
|
1367
|
+
("67896789-6789-6789-6789-678967896789", False),
|
|
1368
|
+
],
|
|
1369
|
+
[("78907890-7890-7890-7890-789078907890", True)],
|
|
1370
|
+
]:
|
|
1371
|
+
responses.add(
|
|
1372
|
+
responses.POST,
|
|
1373
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1374
|
+
status=200,
|
|
1375
|
+
json=[
|
|
1376
|
+
{
|
|
1377
|
+
"id": tr_id,
|
|
1378
|
+
"element_id": "11111111-1111-1111-1111-111111111111"
|
|
1379
|
+
if created
|
|
1380
|
+
else "22222222-2222-2222-2222-222222222222",
|
|
1381
|
+
"created": created,
|
|
1382
|
+
}
|
|
1383
|
+
for tr_id, created in transcriptions
|
|
1384
|
+
],
|
|
1385
|
+
)
|
|
1308
1386
|
|
|
1309
1387
|
annotations = mock_elements_worker.create_element_transcriptions(
|
|
1310
1388
|
element=elt,
|
|
1311
1389
|
sub_element_type="page",
|
|
1312
1390
|
transcriptions=TRANSCRIPTIONS_SAMPLE,
|
|
1391
|
+
batch_size=batch_size,
|
|
1313
1392
|
)
|
|
1314
1393
|
|
|
1315
|
-
|
|
1394
|
+
bulk_api_calls = [
|
|
1395
|
+
(
|
|
1396
|
+
"POST",
|
|
1397
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1398
|
+
)
|
|
1399
|
+
]
|
|
1400
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
1401
|
+
bulk_api_calls.append(
|
|
1402
|
+
(
|
|
1403
|
+
"POST",
|
|
1404
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1405
|
+
)
|
|
1406
|
+
)
|
|
1407
|
+
|
|
1408
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
1316
1409
|
assert [
|
|
1317
1410
|
(call.request.method, call.request.url) for call in responses.calls
|
|
1318
|
-
] == BASE_API_CALLS +
|
|
1319
|
-
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
|
|
1320
|
-
]
|
|
1411
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
1321
1412
|
|
|
1322
|
-
|
|
1413
|
+
first_tr = {
|
|
1414
|
+
**TRANSCRIPTIONS_SAMPLE[0],
|
|
1415
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1416
|
+
}
|
|
1417
|
+
second_tr = {
|
|
1418
|
+
**TRANSCRIPTIONS_SAMPLE[1],
|
|
1419
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1420
|
+
}
|
|
1421
|
+
third_tr = {
|
|
1422
|
+
**TRANSCRIPTIONS_SAMPLE[2],
|
|
1423
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1424
|
+
}
|
|
1425
|
+
empty_payload = {
|
|
1323
1426
|
"element_type": "page",
|
|
1324
1427
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1325
|
-
"transcriptions": [
|
|
1326
|
-
{
|
|
1327
|
-
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
|
|
1328
|
-
"confidence": 0.5,
|
|
1329
|
-
"text": "The",
|
|
1330
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1331
|
-
},
|
|
1332
|
-
{
|
|
1333
|
-
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
|
|
1334
|
-
"confidence": 0.75,
|
|
1335
|
-
"text": "first",
|
|
1336
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1337
|
-
"element_confidence": 0.75,
|
|
1338
|
-
},
|
|
1339
|
-
{
|
|
1340
|
-
"polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
|
|
1341
|
-
"confidence": 0.9,
|
|
1342
|
-
"text": "line",
|
|
1343
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1344
|
-
},
|
|
1345
|
-
],
|
|
1428
|
+
"transcriptions": [],
|
|
1346
1429
|
"return_elements": True,
|
|
1347
1430
|
}
|
|
1431
|
+
|
|
1432
|
+
bodies = []
|
|
1433
|
+
first_call_idx = None
|
|
1434
|
+
if batch_size > 2:
|
|
1435
|
+
first_call_idx = -1
|
|
1436
|
+
bodies.append(
|
|
1437
|
+
{**empty_payload, "transcriptions": [first_tr, second_tr, third_tr]}
|
|
1438
|
+
)
|
|
1439
|
+
else:
|
|
1440
|
+
first_call_idx = -2
|
|
1441
|
+
bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
|
|
1442
|
+
bodies.append({**empty_payload, "transcriptions": [third_tr]})
|
|
1443
|
+
|
|
1444
|
+
assert [
|
|
1445
|
+
json.loads(bulk_call.request.body)
|
|
1446
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
1447
|
+
] == bodies
|
|
1448
|
+
|
|
1348
1449
|
assert annotations == [
|
|
1349
1450
|
{
|
|
1350
1451
|
"id": "56785678-5678-5678-5678-567856785678",
|
|
@@ -1364,73 +1465,122 @@ def test_create_element_transcriptions(responses, mock_elements_worker):
|
|
|
1364
1465
|
]
|
|
1365
1466
|
|
|
1366
1467
|
|
|
1468
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 2])
|
|
1367
1469
|
def test_create_element_transcriptions_with_cache(
|
|
1368
|
-
responses, mock_elements_worker_with_cache
|
|
1470
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
1369
1471
|
):
|
|
1370
1472
|
elt = CachedElement(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
1371
1473
|
|
|
1372
|
-
|
|
1373
|
-
responses.
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1474
|
+
if batch_size > 2:
|
|
1475
|
+
responses.add(
|
|
1476
|
+
responses.POST,
|
|
1477
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1478
|
+
status=200,
|
|
1479
|
+
json=[
|
|
1480
|
+
{
|
|
1481
|
+
"id": "56785678-5678-5678-5678-567856785678",
|
|
1482
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
1483
|
+
"created": True,
|
|
1484
|
+
},
|
|
1485
|
+
{
|
|
1486
|
+
"id": "67896789-6789-6789-6789-678967896789",
|
|
1487
|
+
"element_id": "22222222-2222-2222-2222-222222222222",
|
|
1488
|
+
"created": False,
|
|
1489
|
+
},
|
|
1490
|
+
{
|
|
1491
|
+
"id": "78907890-7890-7890-7890-789078907890",
|
|
1492
|
+
"element_id": "11111111-1111-1111-1111-111111111111",
|
|
1493
|
+
"created": True,
|
|
1494
|
+
},
|
|
1495
|
+
],
|
|
1496
|
+
)
|
|
1497
|
+
else:
|
|
1498
|
+
for transcriptions in [
|
|
1499
|
+
[
|
|
1500
|
+
("56785678-5678-5678-5678-567856785678", True),
|
|
1501
|
+
("67896789-6789-6789-6789-678967896789", False),
|
|
1502
|
+
],
|
|
1503
|
+
[("78907890-7890-7890-7890-789078907890", True)],
|
|
1504
|
+
]:
|
|
1505
|
+
responses.add(
|
|
1506
|
+
responses.POST,
|
|
1507
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1508
|
+
status=200,
|
|
1509
|
+
json=[
|
|
1510
|
+
{
|
|
1511
|
+
"id": tr_id,
|
|
1512
|
+
"element_id": "11111111-1111-1111-1111-111111111111"
|
|
1513
|
+
if created
|
|
1514
|
+
else "22222222-2222-2222-2222-222222222222",
|
|
1515
|
+
"created": created,
|
|
1516
|
+
}
|
|
1517
|
+
for tr_id, created in transcriptions
|
|
1518
|
+
],
|
|
1519
|
+
)
|
|
1394
1520
|
|
|
1395
1521
|
annotations = mock_elements_worker_with_cache.create_element_transcriptions(
|
|
1396
1522
|
element=elt,
|
|
1397
1523
|
sub_element_type="page",
|
|
1398
1524
|
transcriptions=TRANSCRIPTIONS_SAMPLE,
|
|
1525
|
+
batch_size=batch_size,
|
|
1399
1526
|
)
|
|
1400
1527
|
|
|
1401
|
-
|
|
1528
|
+
bulk_api_calls = [
|
|
1529
|
+
(
|
|
1530
|
+
"POST",
|
|
1531
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1532
|
+
)
|
|
1533
|
+
]
|
|
1534
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
1535
|
+
bulk_api_calls.append(
|
|
1536
|
+
(
|
|
1537
|
+
"POST",
|
|
1538
|
+
f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/",
|
|
1539
|
+
)
|
|
1540
|
+
)
|
|
1541
|
+
|
|
1542
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
1402
1543
|
assert [
|
|
1403
1544
|
(call.request.method, call.request.url) for call in responses.calls
|
|
1404
|
-
] == BASE_API_CALLS +
|
|
1405
|
-
("POST", f"http://testserver/api/v1/element/{elt.id}/transcriptions/bulk/"),
|
|
1406
|
-
]
|
|
1545
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
1407
1546
|
|
|
1408
|
-
|
|
1547
|
+
first_tr = {
|
|
1548
|
+
**TRANSCRIPTIONS_SAMPLE[0],
|
|
1549
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1550
|
+
}
|
|
1551
|
+
second_tr = {
|
|
1552
|
+
**TRANSCRIPTIONS_SAMPLE[1],
|
|
1553
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1554
|
+
"element_confidence": 0.75,
|
|
1555
|
+
}
|
|
1556
|
+
third_tr = {
|
|
1557
|
+
**TRANSCRIPTIONS_SAMPLE[2],
|
|
1558
|
+
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1559
|
+
}
|
|
1560
|
+
empty_payload = {
|
|
1409
1561
|
"element_type": "page",
|
|
1410
1562
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
1411
|
-
"transcriptions": [
|
|
1412
|
-
{
|
|
1413
|
-
"polygon": [[100, 150], [700, 150], [700, 200], [100, 200]],
|
|
1414
|
-
"confidence": 0.5,
|
|
1415
|
-
"text": "The",
|
|
1416
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1417
|
-
},
|
|
1418
|
-
{
|
|
1419
|
-
"polygon": [[0, 0], [2000, 0], [2000, 3000], [0, 3000]],
|
|
1420
|
-
"confidence": 0.75,
|
|
1421
|
-
"text": "first",
|
|
1422
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1423
|
-
"element_confidence": 0.75,
|
|
1424
|
-
},
|
|
1425
|
-
{
|
|
1426
|
-
"polygon": [[1000, 300], [1200, 300], [1200, 500], [1000, 500]],
|
|
1427
|
-
"confidence": 0.9,
|
|
1428
|
-
"text": "line",
|
|
1429
|
-
"orientation": TextOrientation.HorizontalLeftToRight.value,
|
|
1430
|
-
},
|
|
1431
|
-
],
|
|
1563
|
+
"transcriptions": [],
|
|
1432
1564
|
"return_elements": True,
|
|
1433
1565
|
}
|
|
1566
|
+
|
|
1567
|
+
bodies = []
|
|
1568
|
+
first_call_idx = None
|
|
1569
|
+
if batch_size > 2:
|
|
1570
|
+
first_call_idx = -1
|
|
1571
|
+
bodies.append(
|
|
1572
|
+
{**empty_payload, "transcriptions": [first_tr, second_tr, third_tr]}
|
|
1573
|
+
)
|
|
1574
|
+
else:
|
|
1575
|
+
first_call_idx = -2
|
|
1576
|
+
bodies.append({**empty_payload, "transcriptions": [first_tr, second_tr]})
|
|
1577
|
+
bodies.append({**empty_payload, "transcriptions": [third_tr]})
|
|
1578
|
+
|
|
1579
|
+
assert [
|
|
1580
|
+
json.loads(bulk_call.request.body)
|
|
1581
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
1582
|
+
] == bodies
|
|
1583
|
+
|
|
1434
1584
|
assert annotations == [
|
|
1435
1585
|
{
|
|
1436
1586
|
"id": "56785678-5678-5678-5678-567856785678",
|
tests/test_utils.py
CHANGED
|
@@ -3,6 +3,8 @@ from pathlib import Path
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
5
|
from arkindex_worker.utils import (
|
|
6
|
+
DEFAULT_BATCH_SIZE,
|
|
7
|
+
batch_publication,
|
|
6
8
|
close_delete_file,
|
|
7
9
|
extract_tar_zst_archive,
|
|
8
10
|
parse_source_id,
|
|
@@ -55,3 +57,29 @@ def test_close_delete_file(tmp_path):
|
|
|
55
57
|
close_delete_file(archive_fd, archive_path)
|
|
56
58
|
|
|
57
59
|
assert not archive_path.exists()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class TestMixin:
|
|
63
|
+
@batch_publication
|
|
64
|
+
def custom_publication_in_batches(self, batch_size: int = DEFAULT_BATCH_SIZE):
|
|
65
|
+
return batch_size
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def test_batch_publication_decorator_no_parameter():
|
|
69
|
+
assert TestMixin().custom_publication_in_batches() == DEFAULT_BATCH_SIZE
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@pytest.mark.parametrize("wrong_batch_size", [None, "not an int", 0])
|
|
73
|
+
def test_batch_publication_decorator_wrong_parameter(wrong_batch_size):
|
|
74
|
+
with pytest.raises(
|
|
75
|
+
AssertionError,
|
|
76
|
+
match="batch_size shouldn't be null and should be a strictly positive integer",
|
|
77
|
+
):
|
|
78
|
+
TestMixin().custom_publication_in_batches(batch_size=wrong_batch_size)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.mark.parametrize("batch_size", [1, 10, DEFAULT_BATCH_SIZE])
|
|
82
|
+
def test_batch_publication_decorator_right_parameter(batch_size):
|
|
83
|
+
assert (
|
|
84
|
+
TestMixin().custom_publication_in_batches(batch_size=batch_size) == batch_size
|
|
85
|
+
)
|
|
File without changes
|
|
File without changes
|
{arkindex_base_worker-0.4.0b1.dist-info → arkindex_base_worker-0.4.0b3.dist-info}/top_level.txt
RENAMED
|
File without changes
|