arkindex-base-worker 0.4.0a2__py3-none-any.whl → 0.4.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/METADATA +7 -7
- arkindex_base_worker-0.4.0b2.dist-info/RECORD +51 -0
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/WHEEL +1 -1
- arkindex_worker/image.py +2 -1
- arkindex_worker/utils.py +76 -0
- arkindex_worker/worker/__init__.py +24 -14
- arkindex_worker/worker/base.py +3 -9
- arkindex_worker/worker/classification.py +33 -17
- arkindex_worker/worker/corpus.py +3 -1
- arkindex_worker/worker/dataset.py +1 -1
- arkindex_worker/worker/element.py +45 -16
- arkindex_worker/worker/entity.py +30 -17
- arkindex_worker/worker/metadata.py +19 -9
- arkindex_worker/worker/task.py +4 -2
- arkindex_worker/worker/training.py +5 -5
- arkindex_worker/worker/transcription.py +39 -18
- arkindex_worker/worker/version.py +3 -1
- tests/test_base_worker.py +1 -1
- tests/test_elements_worker/test_classifications.py +107 -60
- tests/test_elements_worker/test_elements.py +213 -70
- tests/test_elements_worker/test_entities.py +102 -33
- tests/test_elements_worker/test_metadata.py +223 -98
- tests/test_elements_worker/test_transcriptions.py +293 -143
- tests/test_merge.py +1 -1
- tests/test_utils.py +28 -0
- arkindex_base_worker-0.4.0a2.dist-info/RECORD +0 -51
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/LICENSE +0 -0
- {arkindex_base_worker-0.4.0a2.dist-info → arkindex_base_worker-0.4.0b2.dist-info}/top_level.txt +0 -0
arkindex_worker/worker/entity.py
CHANGED
|
@@ -15,6 +15,12 @@ from arkindex_worker.cache import (
|
|
|
15
15
|
unsupported_cache,
|
|
16
16
|
)
|
|
17
17
|
from arkindex_worker.models import Element, Transcription
|
|
18
|
+
from arkindex_worker.utils import (
|
|
19
|
+
DEFAULT_BATCH_SIZE,
|
|
20
|
+
batch_publication,
|
|
21
|
+
make_batches,
|
|
22
|
+
pluralize,
|
|
23
|
+
)
|
|
18
24
|
|
|
19
25
|
|
|
20
26
|
class Entity(TypedDict):
|
|
@@ -48,6 +54,7 @@ class EntityMixin:
|
|
|
48
54
|
if not self.entity_types:
|
|
49
55
|
# Load entity_types of corpus
|
|
50
56
|
self.list_corpus_entity_types()
|
|
57
|
+
|
|
51
58
|
for entity_type in entity_types:
|
|
52
59
|
# Do nothing if type already exists
|
|
53
60
|
if entity_type in self.entity_types:
|
|
@@ -60,7 +67,7 @@ class EntityMixin:
|
|
|
60
67
|
)
|
|
61
68
|
|
|
62
69
|
# Create type if non-existent
|
|
63
|
-
self.entity_types[entity_type] = self.request(
|
|
70
|
+
self.entity_types[entity_type] = self.api_client.request(
|
|
64
71
|
"CreateEntityType",
|
|
65
72
|
body={
|
|
66
73
|
"name": entity_type,
|
|
@@ -106,7 +113,7 @@ class EntityMixin:
|
|
|
106
113
|
entity_type_id = self.entity_types.get(type)
|
|
107
114
|
assert entity_type_id, f"Entity type `{type}` not found in the corpus."
|
|
108
115
|
|
|
109
|
-
entity = self.request(
|
|
116
|
+
entity = self.api_client.request(
|
|
110
117
|
"CreateEntity",
|
|
111
118
|
body={
|
|
112
119
|
"name": name,
|
|
@@ -188,7 +195,7 @@ class EntityMixin:
|
|
|
188
195
|
if confidence is not None:
|
|
189
196
|
body["confidence"] = confidence
|
|
190
197
|
|
|
191
|
-
transcription_ent = self.request(
|
|
198
|
+
transcription_ent = self.api_client.request(
|
|
192
199
|
"CreateTranscriptionEntity",
|
|
193
200
|
id=transcription.id,
|
|
194
201
|
body=body,
|
|
@@ -212,10 +219,12 @@ class EntityMixin:
|
|
|
212
219
|
return transcription_ent
|
|
213
220
|
|
|
214
221
|
@unsupported_cache
|
|
222
|
+
@batch_publication
|
|
215
223
|
def create_transcription_entities(
|
|
216
224
|
self,
|
|
217
225
|
transcription: Transcription,
|
|
218
226
|
entities: list[Entity],
|
|
227
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
219
228
|
) -> list[dict[str, str]]:
|
|
220
229
|
"""
|
|
221
230
|
Create multiple entities attached to a transcription in a single API request.
|
|
@@ -238,6 +247,8 @@ class EntityMixin:
|
|
|
238
247
|
confidence (float or None)
|
|
239
248
|
Optional confidence score, between 0.0 and 1.0.
|
|
240
249
|
|
|
250
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
251
|
+
|
|
241
252
|
:return: List of dicts, with each dict having two keys, `transcription_entity_id` and `entity_id`, holding the UUID of each created object.
|
|
242
253
|
"""
|
|
243
254
|
assert transcription and isinstance(
|
|
@@ -289,16 +300,20 @@ class EntityMixin:
|
|
|
289
300
|
)
|
|
290
301
|
return
|
|
291
302
|
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
"
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
303
|
+
created_entities = [
|
|
304
|
+
created_entity
|
|
305
|
+
for batch in make_batches(entities, "entities", batch_size)
|
|
306
|
+
for created_entity in self.api_client.request(
|
|
307
|
+
"CreateTranscriptionEntities",
|
|
308
|
+
id=transcription.id,
|
|
309
|
+
body={
|
|
310
|
+
"worker_run_id": self.worker_run_id,
|
|
311
|
+
"entities": batch,
|
|
312
|
+
},
|
|
313
|
+
)["entities"]
|
|
314
|
+
]
|
|
300
315
|
|
|
301
|
-
return
|
|
316
|
+
return created_entities
|
|
302
317
|
|
|
303
318
|
def list_transcription_entities(
|
|
304
319
|
self,
|
|
@@ -382,12 +397,10 @@ class EntityMixin:
|
|
|
382
397
|
}
|
|
383
398
|
count = len(self.entities)
|
|
384
399
|
logger.info(
|
|
385
|
-
f'Loaded {count}
|
|
400
|
+
f'Loaded {count} {pluralize("entity", count)} in corpus ({self.corpus_id})'
|
|
386
401
|
)
|
|
387
402
|
|
|
388
|
-
def list_corpus_entity_types(
|
|
389
|
-
self,
|
|
390
|
-
):
|
|
403
|
+
def list_corpus_entity_types(self):
|
|
391
404
|
"""
|
|
392
405
|
Loads available entity types in corpus.
|
|
393
406
|
"""
|
|
@@ -399,5 +412,5 @@ class EntityMixin:
|
|
|
399
412
|
}
|
|
400
413
|
count = len(self.entity_types)
|
|
401
414
|
logger.info(
|
|
402
|
-
f'Loaded {count} entity
|
|
415
|
+
f'Loaded {count} entity {pluralize("type", count)} in corpus ({self.corpus_id}).'
|
|
403
416
|
)
|
|
@@ -7,6 +7,7 @@ from enum import Enum
|
|
|
7
7
|
from arkindex_worker import logger
|
|
8
8
|
from arkindex_worker.cache import CachedElement, unsupported_cache
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class MetaType(Enum):
|
|
@@ -93,7 +94,7 @@ class MetaDataMixin:
|
|
|
93
94
|
logger.warning("Cannot create metadata as this worker is in read-only mode")
|
|
94
95
|
return
|
|
95
96
|
|
|
96
|
-
metadata = self.request(
|
|
97
|
+
metadata = self.api_client.request(
|
|
97
98
|
"CreateMetaData",
|
|
98
99
|
id=element.id,
|
|
99
100
|
body={
|
|
@@ -108,10 +109,12 @@ class MetaDataMixin:
|
|
|
108
109
|
return metadata["id"]
|
|
109
110
|
|
|
110
111
|
@unsupported_cache
|
|
112
|
+
@batch_publication
|
|
111
113
|
def create_metadata_bulk(
|
|
112
114
|
self,
|
|
113
115
|
element: Element | CachedElement,
|
|
114
116
|
metadata_list: list[dict[str, MetaType | str | int | float | None]],
|
|
117
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
115
118
|
) -> list[dict[str, str]]:
|
|
116
119
|
"""
|
|
117
120
|
Create multiple metadata on an existing element.
|
|
@@ -123,6 +126,9 @@ class MetaDataMixin:
|
|
|
123
126
|
- name: str
|
|
124
127
|
- value: str | int | float
|
|
125
128
|
- entity_id: str | None
|
|
129
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
130
|
+
|
|
131
|
+
:returns: A list of dicts as returned in the ``metadata_list`` field by the ``CreateMetaDataBulk`` API endpoint.
|
|
126
132
|
"""
|
|
127
133
|
assert element and isinstance(
|
|
128
134
|
element, Element | CachedElement
|
|
@@ -168,14 +174,18 @@ class MetaDataMixin:
|
|
|
168
174
|
logger.warning("Cannot create metadata as this worker is in read-only mode")
|
|
169
175
|
return
|
|
170
176
|
|
|
171
|
-
created_metadata_list =
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
"
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
177
|
+
created_metadata_list = [
|
|
178
|
+
created_metadata
|
|
179
|
+
for batch in make_batches(metas, "metadata", batch_size)
|
|
180
|
+
for created_metadata in self.api_client.request(
|
|
181
|
+
"CreateMetaDataBulk",
|
|
182
|
+
id=element.id,
|
|
183
|
+
body={
|
|
184
|
+
"worker_run_id": self.worker_run_id,
|
|
185
|
+
"metadata_list": batch,
|
|
186
|
+
},
|
|
187
|
+
)["metadata_list"]
|
|
188
|
+
]
|
|
179
189
|
|
|
180
190
|
return created_metadata_list
|
|
181
191
|
|
arkindex_worker/worker/task.py
CHANGED
|
@@ -22,7 +22,7 @@ class TaskMixin:
|
|
|
22
22
|
task_id, uuid.UUID
|
|
23
23
|
), "task_id shouldn't be null and should be an UUID"
|
|
24
24
|
|
|
25
|
-
results = self.request("ListArtifacts", id=task_id)
|
|
25
|
+
results = self.api_client.request("ListArtifacts", id=task_id)
|
|
26
26
|
|
|
27
27
|
return map(Artifact, results)
|
|
28
28
|
|
|
@@ -43,4 +43,6 @@ class TaskMixin:
|
|
|
43
43
|
artifact, Artifact
|
|
44
44
|
), "artifact shouldn't be null and should be an Artifact"
|
|
45
45
|
|
|
46
|
-
return self.request(
|
|
46
|
+
return self.api_client.request(
|
|
47
|
+
"DownloadArtifact", id=task_id, path=artifact.path
|
|
48
|
+
)
|
|
@@ -185,7 +185,7 @@ class TrainingMixin:
|
|
|
185
185
|
assert not self.model_version, "A model version has already been created."
|
|
186
186
|
|
|
187
187
|
configuration = configuration or {}
|
|
188
|
-
self.model_version = self.request(
|
|
188
|
+
self.model_version = self.api_client.request(
|
|
189
189
|
"CreateModelVersion",
|
|
190
190
|
id=model_id,
|
|
191
191
|
body=build_clean_payload(
|
|
@@ -217,7 +217,7 @@ class TrainingMixin:
|
|
|
217
217
|
:param parent: ID of the parent model version
|
|
218
218
|
"""
|
|
219
219
|
assert self.model_version, "No model version has been created yet."
|
|
220
|
-
self.model_version = self.request(
|
|
220
|
+
self.model_version = self.api_client.request(
|
|
221
221
|
"UpdateModelVersion",
|
|
222
222
|
id=self.model_version["id"],
|
|
223
223
|
body=build_clean_payload(
|
|
@@ -273,7 +273,7 @@ class TrainingMixin:
|
|
|
273
273
|
"""
|
|
274
274
|
assert self.model_version, "You must create the model version and upload its archive before validating it."
|
|
275
275
|
try:
|
|
276
|
-
self.model_version = self.request(
|
|
276
|
+
self.model_version = self.api_client.request(
|
|
277
277
|
"PartialUpdateModelVersion",
|
|
278
278
|
id=self.model_version["id"],
|
|
279
279
|
body={
|
|
@@ -294,7 +294,7 @@ class TrainingMixin:
|
|
|
294
294
|
pending_version_id = self.model_version["id"]
|
|
295
295
|
logger.warning("Removing the pending model version.")
|
|
296
296
|
try:
|
|
297
|
-
self.request("DestroyModelVersion", id=pending_version_id)
|
|
297
|
+
self.api_client.request("DestroyModelVersion", id=pending_version_id)
|
|
298
298
|
except ErrorResponse as e:
|
|
299
299
|
msg = getattr(e, "content", str(e))
|
|
300
300
|
logger.error(
|
|
@@ -304,7 +304,7 @@ class TrainingMixin:
|
|
|
304
304
|
logger.info("Retrieving the existing model version.")
|
|
305
305
|
existing_version_id = model_version["id"].pop()
|
|
306
306
|
try:
|
|
307
|
-
self.model_version = self.request(
|
|
307
|
+
self.model_version = self.api_client.request(
|
|
308
308
|
"RetrieveModelVersion", id=existing_version_id
|
|
309
309
|
)
|
|
310
310
|
except ErrorResponse as e:
|
|
@@ -11,6 +11,7 @@ from peewee import IntegrityError
|
|
|
11
11
|
from arkindex_worker import logger
|
|
12
12
|
from arkindex_worker.cache import CachedElement, CachedTranscription
|
|
13
13
|
from arkindex_worker.models import Element
|
|
14
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE, batch_publication, make_batches
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class TextOrientation(Enum):
|
|
@@ -77,7 +78,7 @@ class TranscriptionMixin:
|
|
|
77
78
|
)
|
|
78
79
|
return
|
|
79
80
|
|
|
80
|
-
created = self.request(
|
|
81
|
+
created = self.api_client.request(
|
|
81
82
|
"CreateTranscription",
|
|
82
83
|
id=element.id,
|
|
83
84
|
body={
|
|
@@ -109,9 +110,11 @@ class TranscriptionMixin:
|
|
|
109
110
|
|
|
110
111
|
return created
|
|
111
112
|
|
|
113
|
+
@batch_publication
|
|
112
114
|
def create_transcriptions(
|
|
113
115
|
self,
|
|
114
116
|
transcriptions: list[dict[str, str | float | TextOrientation | None]],
|
|
117
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
115
118
|
) -> list[dict[str, str | float]]:
|
|
116
119
|
"""
|
|
117
120
|
Create multiple transcriptions at once on existing elements through the API,
|
|
@@ -128,6 +131,8 @@ class TranscriptionMixin:
|
|
|
128
131
|
orientation (TextOrientation)
|
|
129
132
|
Optional. Orientation of the transcription's text.
|
|
130
133
|
|
|
134
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
135
|
+
|
|
131
136
|
:returns: A list of dicts as returned in the ``transcriptions`` field by the ``CreateTranscriptions`` API endpoint.
|
|
132
137
|
"""
|
|
133
138
|
|
|
@@ -171,13 +176,19 @@ class TranscriptionMixin:
|
|
|
171
176
|
)
|
|
172
177
|
return
|
|
173
178
|
|
|
174
|
-
created_trs =
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
"
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
179
|
+
created_trs = [
|
|
180
|
+
created_tr
|
|
181
|
+
for batch in make_batches(
|
|
182
|
+
transcriptions_payload, "transcription", batch_size
|
|
183
|
+
)
|
|
184
|
+
for created_tr in self.api_client.request(
|
|
185
|
+
"CreateTranscriptions",
|
|
186
|
+
body={
|
|
187
|
+
"worker_run_id": self.worker_run_id,
|
|
188
|
+
"transcriptions": batch,
|
|
189
|
+
},
|
|
190
|
+
)["transcriptions"]
|
|
191
|
+
]
|
|
181
192
|
|
|
182
193
|
if self.use_cache:
|
|
183
194
|
# Store transcriptions in local cache
|
|
@@ -201,11 +212,13 @@ class TranscriptionMixin:
|
|
|
201
212
|
|
|
202
213
|
return created_trs
|
|
203
214
|
|
|
215
|
+
@batch_publication
|
|
204
216
|
def create_element_transcriptions(
|
|
205
217
|
self,
|
|
206
218
|
element: Element | CachedElement,
|
|
207
219
|
sub_element_type: str,
|
|
208
220
|
transcriptions: list[dict[str, str | float]],
|
|
221
|
+
batch_size: int = DEFAULT_BATCH_SIZE,
|
|
209
222
|
) -> dict[str, str | bool]:
|
|
210
223
|
"""
|
|
211
224
|
Create multiple elements and transcriptions at once on a single parent element through the API.
|
|
@@ -225,6 +238,8 @@ class TranscriptionMixin:
|
|
|
225
238
|
element_confidence (float)
|
|
226
239
|
Optional. Confidence score of the element between 0 and 1.
|
|
227
240
|
|
|
241
|
+
:param batch_size: The size of each batch, which will be used to split the publication to avoid API errors.
|
|
242
|
+
|
|
228
243
|
:returns: A list of dicts as returned by the ``CreateElementTranscriptions`` API endpoint.
|
|
229
244
|
"""
|
|
230
245
|
assert element and isinstance(
|
|
@@ -291,16 +306,22 @@ class TranscriptionMixin:
|
|
|
291
306
|
)
|
|
292
307
|
return
|
|
293
308
|
|
|
294
|
-
annotations =
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
"
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
309
|
+
annotations = [
|
|
310
|
+
annotation
|
|
311
|
+
for batch in make_batches(
|
|
312
|
+
transcriptions_payload, "transcription", batch_size
|
|
313
|
+
)
|
|
314
|
+
for annotation in self.api_client.request(
|
|
315
|
+
"CreateElementTranscriptions",
|
|
316
|
+
id=element.id,
|
|
317
|
+
body={
|
|
318
|
+
"element_type": sub_element_type,
|
|
319
|
+
"worker_run_id": self.worker_run_id,
|
|
320
|
+
"transcriptions": batch,
|
|
321
|
+
"return_elements": True,
|
|
322
|
+
},
|
|
323
|
+
)
|
|
324
|
+
]
|
|
304
325
|
|
|
305
326
|
for annotation in annotations:
|
|
306
327
|
if annotation["created"]:
|
|
@@ -34,7 +34,9 @@ class WorkerVersionMixin:
|
|
|
34
34
|
if worker_version_id in self._worker_version_cache:
|
|
35
35
|
return self._worker_version_cache[worker_version_id]
|
|
36
36
|
|
|
37
|
-
worker_version = self.request(
|
|
37
|
+
worker_version = self.api_client.request(
|
|
38
|
+
"RetrieveWorkerVersion", id=worker_version_id
|
|
39
|
+
)
|
|
38
40
|
self._worker_version_cache[worker_version_id] = worker_version
|
|
39
41
|
|
|
40
42
|
return worker_version
|
tests/test_base_worker.py
CHANGED
|
@@ -658,7 +658,7 @@ def test_find_extras_directory_not_found(monkeypatch, extras_path, exists, error
|
|
|
658
658
|
def test_find_parents_file_paths(responses, mock_base_worker_with_cache, tmp_path):
|
|
659
659
|
responses.add(
|
|
660
660
|
responses.GET,
|
|
661
|
-
"http://testserver/api/v1/task/my_task/
|
|
661
|
+
"http://testserver/api/v1/task/my_task/",
|
|
662
662
|
status=200,
|
|
663
663
|
json={"parents": ["first", "second", "third"]},
|
|
664
664
|
)
|
|
@@ -7,6 +7,7 @@ from apistar.exceptions import ErrorResponse
|
|
|
7
7
|
|
|
8
8
|
from arkindex_worker.cache import CachedClassification, CachedElement
|
|
9
9
|
from arkindex_worker.models import Element
|
|
10
|
+
from arkindex_worker.utils import DEFAULT_BATCH_SIZE
|
|
10
11
|
from tests import CORPUS_ID
|
|
11
12
|
|
|
12
13
|
from . import BASE_API_CALLS
|
|
@@ -692,7 +693,8 @@ def test_create_classifications_create_ml_class(mock_elements_worker, responses)
|
|
|
692
693
|
}
|
|
693
694
|
|
|
694
695
|
|
|
695
|
-
|
|
696
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
697
|
+
def test_create_classifications(batch_size, responses, mock_elements_worker):
|
|
696
698
|
mock_elements_worker.classes = {"portrait": "0000", "landscape": "1111"}
|
|
697
699
|
elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
|
|
698
700
|
responses.add(
|
|
@@ -716,62 +718,98 @@ def test_create_classifications(responses, mock_elements_worker):
|
|
|
716
718
|
"high_confidence": False,
|
|
717
719
|
},
|
|
718
720
|
],
|
|
721
|
+
batch_size=batch_size,
|
|
719
722
|
)
|
|
720
723
|
|
|
721
|
-
|
|
724
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
725
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
726
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
727
|
+
|
|
728
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
722
729
|
assert [
|
|
723
730
|
(call.request.method, call.request.url) for call in responses.calls
|
|
724
|
-
] == BASE_API_CALLS +
|
|
725
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
726
|
-
]
|
|
731
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
727
732
|
|
|
728
|
-
|
|
733
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
734
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
735
|
+
empty_payload = {
|
|
729
736
|
"parent": str(elt.id),
|
|
730
737
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
731
|
-
"classifications": [
|
|
732
|
-
{
|
|
733
|
-
"confidence": 0.75,
|
|
734
|
-
"high_confidence": False,
|
|
735
|
-
"ml_class": "0000",
|
|
736
|
-
},
|
|
737
|
-
{
|
|
738
|
-
"confidence": 0.25,
|
|
739
|
-
"high_confidence": False,
|
|
740
|
-
"ml_class": "1111",
|
|
741
|
-
},
|
|
742
|
-
],
|
|
738
|
+
"classifications": [],
|
|
743
739
|
}
|
|
744
740
|
|
|
741
|
+
bodies = []
|
|
742
|
+
first_call_idx = None
|
|
743
|
+
if batch_size > 1:
|
|
744
|
+
first_call_idx = -1
|
|
745
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
746
|
+
else:
|
|
747
|
+
first_call_idx = -2
|
|
748
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
749
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
750
|
+
|
|
751
|
+
assert [
|
|
752
|
+
json.loads(bulk_call.request.body)
|
|
753
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
754
|
+
] == bodies
|
|
755
|
+
|
|
745
756
|
|
|
746
|
-
|
|
757
|
+
@pytest.mark.parametrize("batch_size", [DEFAULT_BATCH_SIZE, 1])
|
|
758
|
+
def test_create_classifications_with_cache(
|
|
759
|
+
batch_size, responses, mock_elements_worker_with_cache
|
|
760
|
+
):
|
|
747
761
|
mock_elements_worker_with_cache.classes = {"portrait": "0000", "landscape": "1111"}
|
|
748
762
|
elt = CachedElement.create(id="12341234-1234-1234-1234-123412341234", type="thing")
|
|
749
763
|
|
|
750
|
-
|
|
751
|
-
responses.
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
764
|
+
if batch_size > 1:
|
|
765
|
+
responses.add(
|
|
766
|
+
responses.POST,
|
|
767
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
768
|
+
status=200,
|
|
769
|
+
json={
|
|
770
|
+
"parent": str(elt.id),
|
|
771
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
772
|
+
"classifications": [
|
|
773
|
+
{
|
|
774
|
+
"id": "00000000-0000-0000-0000-000000000000",
|
|
775
|
+
"ml_class": "0000",
|
|
776
|
+
"confidence": 0.75,
|
|
777
|
+
"high_confidence": False,
|
|
778
|
+
"state": "pending",
|
|
779
|
+
},
|
|
780
|
+
{
|
|
781
|
+
"id": "11111111-1111-1111-1111-111111111111",
|
|
782
|
+
"ml_class": "1111",
|
|
783
|
+
"confidence": 0.25,
|
|
784
|
+
"high_confidence": False,
|
|
785
|
+
"state": "pending",
|
|
786
|
+
},
|
|
787
|
+
],
|
|
788
|
+
},
|
|
789
|
+
)
|
|
790
|
+
else:
|
|
791
|
+
for cl_id, cl_class, cl_conf in [
|
|
792
|
+
("00000000-0000-0000-0000-000000000000", "0000", 0.75),
|
|
793
|
+
("11111111-1111-1111-1111-111111111111", "1111", 0.25),
|
|
794
|
+
]:
|
|
795
|
+
responses.add(
|
|
796
|
+
responses.POST,
|
|
797
|
+
"http://testserver/api/v1/classification/bulk/",
|
|
798
|
+
status=200,
|
|
799
|
+
json={
|
|
800
|
+
"parent": str(elt.id),
|
|
801
|
+
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
802
|
+
"classifications": [
|
|
803
|
+
{
|
|
804
|
+
"id": cl_id,
|
|
805
|
+
"ml_class": cl_class,
|
|
806
|
+
"confidence": cl_conf,
|
|
807
|
+
"high_confidence": False,
|
|
808
|
+
"state": "pending",
|
|
809
|
+
},
|
|
810
|
+
],
|
|
771
811
|
},
|
|
772
|
-
|
|
773
|
-
},
|
|
774
|
-
)
|
|
812
|
+
)
|
|
775
813
|
|
|
776
814
|
mock_elements_worker_with_cache.create_classifications(
|
|
777
815
|
element=elt,
|
|
@@ -787,32 +825,41 @@ def test_create_classifications_with_cache(responses, mock_elements_worker_with_
|
|
|
787
825
|
"high_confidence": False,
|
|
788
826
|
},
|
|
789
827
|
],
|
|
828
|
+
batch_size=batch_size,
|
|
790
829
|
)
|
|
791
830
|
|
|
792
|
-
|
|
831
|
+
bulk_api_calls = [("POST", "http://testserver/api/v1/classification/bulk/")]
|
|
832
|
+
if batch_size != DEFAULT_BATCH_SIZE:
|
|
833
|
+
bulk_api_calls.append(("POST", "http://testserver/api/v1/classification/bulk/"))
|
|
834
|
+
|
|
835
|
+
assert len(responses.calls) == len(BASE_API_CALLS) + len(bulk_api_calls)
|
|
793
836
|
assert [
|
|
794
837
|
(call.request.method, call.request.url) for call in responses.calls
|
|
795
|
-
] == BASE_API_CALLS +
|
|
796
|
-
("POST", "http://testserver/api/v1/classification/bulk/"),
|
|
797
|
-
]
|
|
838
|
+
] == BASE_API_CALLS + bulk_api_calls
|
|
798
839
|
|
|
799
|
-
|
|
840
|
+
first_cl = {"confidence": 0.75, "high_confidence": False, "ml_class": "0000"}
|
|
841
|
+
second_cl = {"confidence": 0.25, "high_confidence": False, "ml_class": "1111"}
|
|
842
|
+
empty_payload = {
|
|
800
843
|
"parent": str(elt.id),
|
|
801
844
|
"worker_run_id": "56785678-5678-5678-5678-567856785678",
|
|
802
|
-
"classifications": [
|
|
803
|
-
{
|
|
804
|
-
"confidence": 0.75,
|
|
805
|
-
"high_confidence": False,
|
|
806
|
-
"ml_class": "0000",
|
|
807
|
-
},
|
|
808
|
-
{
|
|
809
|
-
"confidence": 0.25,
|
|
810
|
-
"high_confidence": False,
|
|
811
|
-
"ml_class": "1111",
|
|
812
|
-
},
|
|
813
|
-
],
|
|
845
|
+
"classifications": [],
|
|
814
846
|
}
|
|
815
847
|
|
|
848
|
+
bodies = []
|
|
849
|
+
first_call_idx = None
|
|
850
|
+
if batch_size > 1:
|
|
851
|
+
first_call_idx = -1
|
|
852
|
+
bodies.append({**empty_payload, "classifications": [first_cl, second_cl]})
|
|
853
|
+
else:
|
|
854
|
+
first_call_idx = -2
|
|
855
|
+
bodies.append({**empty_payload, "classifications": [first_cl]})
|
|
856
|
+
bodies.append({**empty_payload, "classifications": [second_cl]})
|
|
857
|
+
|
|
858
|
+
assert [
|
|
859
|
+
json.loads(bulk_call.request.body)
|
|
860
|
+
for bulk_call in responses.calls[first_call_idx:]
|
|
861
|
+
] == bodies
|
|
862
|
+
|
|
816
863
|
# Check that created classifications were properly stored in SQLite cache
|
|
817
864
|
assert list(CachedClassification.select()) == [
|
|
818
865
|
CachedClassification(
|