arkindex-base-worker 0.3.7rc9__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/METADATA +16 -20
  2. arkindex_base_worker-0.4.0.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/WHEEL +1 -1
  4. arkindex_worker/cache.py +1 -1
  5. arkindex_worker/image.py +120 -1
  6. arkindex_worker/models.py +6 -0
  7. arkindex_worker/utils.py +85 -4
  8. arkindex_worker/worker/__init__.py +68 -162
  9. arkindex_worker/worker/base.py +39 -34
  10. arkindex_worker/worker/classification.py +34 -18
  11. arkindex_worker/worker/corpus.py +86 -0
  12. arkindex_worker/worker/dataset.py +71 -1
  13. arkindex_worker/worker/element.py +352 -91
  14. arkindex_worker/worker/entity.py +11 -11
  15. arkindex_worker/worker/image.py +21 -0
  16. arkindex_worker/worker/metadata.py +19 -9
  17. arkindex_worker/worker/process.py +92 -0
  18. arkindex_worker/worker/task.py +5 -4
  19. arkindex_worker/worker/training.py +25 -10
  20. arkindex_worker/worker/transcription.py +89 -68
  21. arkindex_worker/worker/version.py +3 -1
  22. tests/__init__.py +8 -0
  23. tests/conftest.py +36 -52
  24. tests/test_base_worker.py +212 -12
  25. tests/test_dataset_worker.py +21 -45
  26. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  27. tests/test_elements_worker/test_cli.py +3 -11
  28. tests/test_elements_worker/test_corpus.py +168 -0
  29. tests/test_elements_worker/test_dataset.py +7 -12
  30. tests/test_elements_worker/test_element.py +427 -0
  31. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  32. tests/test_elements_worker/test_element_create_single.py +528 -0
  33. tests/test_elements_worker/test_element_list_children.py +969 -0
  34. tests/test_elements_worker/test_element_list_parents.py +530 -0
  35. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  36. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  37. tests/test_elements_worker/test_image.py +66 -0
  38. tests/test_elements_worker/test_metadata.py +230 -139
  39. tests/test_elements_worker/test_process.py +89 -0
  40. tests/test_elements_worker/test_task.py +8 -18
  41. tests/test_elements_worker/test_training.py +17 -8
  42. tests/test_elements_worker/test_transcription_create.py +873 -0
  43. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  44. tests/test_elements_worker/test_transcription_list.py +450 -0
  45. tests/test_elements_worker/test_version.py +60 -0
  46. tests/test_elements_worker/test_worker.py +563 -279
  47. tests/test_image.py +432 -209
  48. tests/test_merge.py +1 -2
  49. tests/test_utils.py +66 -3
  50. arkindex_base_worker-0.3.7rc9.dist-info/RECORD +0 -47
  51. tests/test_elements_worker/test_elements.py +0 -2713
  52. tests/test_elements_worker/test_transcriptions.py +0 -2119
  53. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/LICENSE +0 -0
  54. {arkindex_base_worker-0.3.7rc9.dist-info → arkindex_base_worker-0.4.0.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@ import json
2
2
  import sys
3
3
  import tempfile
4
4
  from pathlib import Path
5
- from uuid import UUID
6
5
 
7
6
  import pytest
8
7
 
@@ -58,13 +57,6 @@ def test_cli_arg_elements_list_given(mocker):
58
57
  path.unlink()
59
58
 
60
59
 
61
- def test_cli_arg_element_one_given_not_uuid(mocker):
62
- mocker.patch.object(sys, "argv", ["worker", "--element", "1234"])
63
- worker = ElementsWorker()
64
- with pytest.raises(SystemExit):
65
- worker.configure()
66
-
67
-
68
60
  @pytest.mark.usefixtures("_mock_worker_run_api")
69
61
  def test_cli_arg_element_one_given(mocker):
70
62
  mocker.patch.object(
@@ -73,7 +65,7 @@ def test_cli_arg_element_one_given(mocker):
73
65
  worker = ElementsWorker()
74
66
  worker.configure()
75
67
 
76
- assert worker.args.element == [UUID("12341234-1234-1234-1234-123412341234")]
68
+ assert worker.args.element == ["12341234-1234-1234-1234-123412341234"]
77
69
  # elements_list is None because TASK_ELEMENTS environment variable isn't set
78
70
  assert not worker.args.elements_list
79
71
 
@@ -94,8 +86,8 @@ def test_cli_arg_element_many_given(mocker):
94
86
  worker.configure()
95
87
 
96
88
  assert worker.args.element == [
97
- UUID("12341234-1234-1234-1234-123412341234"),
98
- UUID("43214321-4321-4321-4321-432143214321"),
89
+ "12341234-1234-1234-1234-123412341234",
90
+ "43214321-4321-4321-4321-432143214321",
99
91
  ]
100
92
  # elements_list is None because TASK_ELEMENTS environment variable isn't set
101
93
  assert not worker.args.elements_list
@@ -0,0 +1,168 @@
1
+ import re
2
+ import uuid
3
+
4
+ import pytest
5
+
6
+ from arkindex.exceptions import ErrorResponse
7
+ from arkindex_worker.worker.corpus import CorpusExportState
8
+ from tests import CORPUS_ID
9
+ from tests.test_elements_worker import BASE_API_CALLS
10
+
11
+
12
def test_download_export_not_a_uuid(responses, mock_elements_worker):
    """An export ID that is not a valid UUID is rejected with a ValueError."""
    invalid_export_id = "mon export"

    with pytest.raises(ValueError, match="export_id is not a valid uuid."):
        mock_elements_worker.download_export(invalid_export_id)
15
+
16
+
17
def test_download_export(responses, mock_elements_worker):
    """Downloading an export by ID issues one GET and names the file after the ID."""
    export_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"
    export_url = f"http://testserver/api/v1/export/{export_id}/"
    responses.add(
        responses.GET,
        export_url,
        status=302,
        body=b"some SQLite export",
        content_type="application/x-sqlite3",
        stream=True,
    )

    export = mock_elements_worker.download_export(export_id)
    assert export.name == f"/tmp/{export_id}"

    expected_calls = [*BASE_API_CALLS, ("GET", export_url)]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
41
+
42
+
43
def mock_list_exports_call(responses, export_id):
    """Register a ListExports response containing one export per export state.

    Only the export in the ``Done`` state carries ``export_id``; every other
    export gets a freshly generated random UUID.
    """

    def build_export(state):
        # One fake export payload for the given state
        return {
            "id": export_id if state == CorpusExportState.Done else str(uuid.uuid4()),
            "created": "2019-08-24T14:15:22Z",
            "updated": "2019-08-24T14:15:22Z",
            "corpus_id": CORPUS_ID,
            "user": {
                "id": 0,
                "email": "user@example.com",
                "display_name": "User",
            },
            "state": state.value,
            "source": "default",
        }

    exports = [build_export(state) for state in CorpusExportState]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/",
        status=200,
        json={
            "count": len(CorpusExportState),
            "next": None,
            "results": exports,
        },
    )
71
+
72
+
73
def test_download_latest_export_list_error(responses, mock_elements_worker):
    """A failing ListExports call is retried, then pagination is aborted."""
    list_url = f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"
    responses.add(responses.GET, list_url, status=418)

    with pytest.raises(
        Exception, match="Stopping pagination as data will be incomplete"
    ):
        mock_elements_worker.download_latest_export()

    # The API call is retried 5 times
    expected_calls = BASE_API_CALLS + [("GET", list_url)] * 5
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
96
+
97
+
98
def test_download_latest_export_no_available_exports(responses, mock_elements_worker):
    """An empty export list makes ``download_latest_export`` fail an assertion."""
    list_url = f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"
    responses.add(
        responses.GET,
        list_url,
        status=200,
        json={"count": 0, "next": None, "results": []},
    )

    expected_message = (
        f'No available exports found for the corpus ({CORPUS_ID}) with state "Done".'
    )
    with pytest.raises(AssertionError, match=re.escape(expected_message)):
        mock_elements_worker.download_latest_export()

    expected_calls = [*BASE_API_CALLS, ("GET", list_url)]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
124
+
125
+
126
def test_download_latest_export_download_error(responses, mock_elements_worker):
    """An error while downloading the latest ``Done`` export propagates as ErrorResponse."""
    export_id = str(uuid.uuid4())
    mock_list_exports_call(responses, export_id)

    export_url = f"http://testserver/api/v1/export/{export_id}/"
    responses.add(responses.GET, export_url, status=418)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.download_latest_export()

    expected_calls = [
        *BASE_API_CALLS,
        ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
        ("GET", export_url),
    ]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
145
+
146
+
147
def test_download_latest_export(responses, mock_elements_worker):
    """The export in the ``Done`` state is listed, then downloaded to a file named after it."""
    export_id = str(uuid.uuid4())
    mock_list_exports_call(responses, export_id)

    export_url = f"http://testserver/api/v1/export/{export_id}/"
    responses.add(
        responses.GET,
        export_url,
        status=302,
        body=b"some SQLite export",
        content_type="application/x-sqlite3",
        stream=True,
    )

    export = mock_elements_worker.download_latest_export()
    assert export.name == f"/tmp/{export_id}"

    expected_calls = [
        *BASE_API_CALLS,
        ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
        ("GET", export_url),
    ]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
@@ -2,11 +2,11 @@ import json
2
2
  import logging
3
3
 
4
4
  import pytest
5
- from apistar.exceptions import ErrorResponse
6
5
 
6
+ from arkindex.exceptions import ErrorResponse
7
7
  from arkindex_worker.models import Dataset, Element, Set
8
8
  from arkindex_worker.worker.dataset import DatasetState
9
- from tests.conftest import PROCESS_ID
9
+ from tests import PROCESS_ID
10
10
  from tests.test_elements_worker import BASE_API_CALLS
11
11
 
12
12
 
@@ -25,7 +25,7 @@ def test_list_process_sets_api_error(responses, mock_dataset_worker):
25
25
  responses.add(
26
26
  responses.GET,
27
27
  f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
28
- status=500,
28
+ status=418,
29
29
  )
30
30
 
31
31
  with pytest.raises(
@@ -152,7 +152,7 @@ def test_list_set_elements_api_error(
152
152
  responses.add(
153
153
  responses.GET,
154
154
  f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
155
- status=500,
155
+ status=418,
156
156
  )
157
157
 
158
158
  with pytest.raises(
@@ -321,7 +321,7 @@ def test_update_dataset_state_api_error(
321
321
  responses.add(
322
322
  responses.PATCH,
323
323
  f"http://testserver/api/v1/datasets/{default_dataset.id}/",
324
- status=500,
324
+ status=418,
325
325
  )
326
326
 
327
327
  with pytest.raises(ErrorResponse):
@@ -330,16 +330,11 @@ def test_update_dataset_state_api_error(
330
330
  state=DatasetState.Building,
331
331
  )
332
332
 
333
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
333
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
334
334
  assert [
335
335
  (call.request.method, call.request.url) for call in responses.calls
336
336
  ] == BASE_API_CALLS + [
337
- # We retry 5 times the API call
338
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
339
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
340
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
341
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
342
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
337
+ ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
343
338
  ]
344
339
 
345
340
 
@@ -0,0 +1,427 @@
1
+ import json
2
+ import re
3
+
4
+ import pytest
5
+ from responses import matchers
6
+
7
+ from arkindex.exceptions import ErrorResponse
8
+ from arkindex_worker.cache import (
9
+ CachedElement,
10
+ CachedImage,
11
+ )
12
+ from arkindex_worker.models import Element
13
+ from arkindex_worker.worker.element import MissingTypeError
14
+ from tests import CORPUS_ID
15
+
16
+ from . import BASE_API_CALLS
17
+
18
+
19
def test_list_corpus_types(responses, mock_elements_worker):
    """Corpus element types are fetched and stored keyed by their slug."""
    type_slugs = ["folder", "page"]
    responses.add(
        responses.GET,
        f"http://testserver/api/v1/corpus/{CORPUS_ID}/",
        json={
            "id": CORPUS_ID,
            "types": [{"slug": slug} for slug in type_slugs],
        },
    )

    mock_elements_worker.list_corpus_types()

    assert mock_elements_worker.corpus_types == {
        slug: {"slug": slug} for slug in type_slugs
    }
35
+
36
+
37
def test_check_required_types_argument_types(mock_elements_worker):
    """Missing or non-string slugs are rejected with an AssertionError."""
    invalid_calls = (
        ((), "At least one element type slug is required."),
        (("lol", 42), "Element type slugs must be strings."),
    )
    for slugs, message in invalid_calls:
        with pytest.raises(AssertionError, match=message):
            mock_elements_worker.check_required_types(*slugs)
45
+
46
+
47
def test_check_required_types(mock_elements_worker):
    """Check the result of ``check_required_types`` against the known corpus types.

    The call returns a truthy value when every requested slug is known.
    Otherwise it raises MissingTypeError, whose message lists the missing slugs.
    """
    mock_elements_worker.corpus_types = {
        "folder": {"slug": "folder"},
        "page": {"slug": "page"},
    }

    assert mock_elements_worker.check_required_types("page")
    assert mock_elements_worker.check_required_types("page", "folder")

    # The call raises, so there is no return value to assert on here.
    with pytest.raises(
        MissingTypeError,
        match=re.escape(
            "Element types act, text_line were not found in corpus (11111111-1111-1111-1111-111111111111)."
        ),
    ):
        mock_elements_worker.check_required_types("page", "text_line", "act")
63
+
64
+
65
def test_check_required_types_create_missing(responses, mock_elements_worker):
    """With ``create_missing=True``, unknown element types are created via the API.

    Each missing slug is expected to be POSTed once, as a non-folder type
    whose display name is the slug itself.
    """
    mock_elements_worker.corpus_types = {
        "folder": {"slug": "folder"},
        "page": {"slug": "page"},
    }

    create_type_url = "http://testserver/api/v1/elements/type/"
    for slug in ("text_line", "act"):
        responses.add(
            responses.POST,
            create_type_url,
            match=[
                matchers.json_params_matcher(
                    {
                        "slug": slug,
                        "display_name": slug,
                        "folder": False,
                        "corpus": CORPUS_ID,
                    }
                )
            ],
        )

    assert mock_elements_worker.check_required_types(
        "page", "text_line", "act", create_missing=True
    )

    # Without this check the test would also pass if no type were created at all.
    created_slugs = [
        json.loads(call.request.body)["slug"]
        for call in responses.calls
        if call.request.method == "POST" and call.request.url == create_type_url
    ]
    assert sorted(created_slugs) == ["act", "text_line"]
103
+
104
+
105
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Element
        (
            {"element": None},
            "element shouldn't be null and should be an Element or CachedElement",
        ),
        (
            {"element": "not element type"},
            "element shouldn't be null and should be an Element or CachedElement",
        ),
    ],
)
def test_partial_update_element_wrong_param_element(
    mock_elements_worker, payload, error
):
    """An invalid ``element`` argument fails validation with an AssertionError."""
    # The payload overrides the default valid element
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
131
+
132
+
133
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Type
        ({"type": 1234}, "type should be a str"),
        ({"type": None}, "type should be a str"),
    ],
)
def test_partial_update_element_wrong_param_type(mock_elements_worker, payload, error):
    """A non-string ``type`` fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
151
+
152
+
153
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Name
        ({"name": 1234}, "name should be a str"),
        ({"name": None}, "name should be a str"),
    ],
)
def test_partial_update_element_wrong_param_name(mock_elements_worker, payload, error):
    """A non-string ``name`` fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
171
+
172
+
173
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Polygon
        ({"polygon": "not a polygon"}, "polygon should be a list"),
        ({"polygon": None}, "polygon should be a list"),
        ({"polygon": [[1, 1], [2, 2]]}, "polygon should have at least three points"),
        (
            {"polygon": [[1, 1, 1], [2, 2, 1], [2, 1, 1], [1, 2, 1]]},
            "polygon points should be lists of two items",
        ),
        (
            {"polygon": [[1], [2], [2], [1]]},
            "polygon points should be lists of two items",
        ),
        (
            {"polygon": [["not a coord", 1], [2, 2], [2, 1], [1, 2]]},
            "polygon points should be lists of two numbers",
        ),
    ],
)
def test_partial_update_element_wrong_param_polygon(
    mock_elements_worker, payload, error
):
    """A malformed ``polygon`` fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
206
+
207
+
208
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Confidence
        ({"confidence": "lol"}, "confidence should be None or a float in [0..1] range"),
        ({"confidence": "0.2"}, "confidence should be None or a float in [0..1] range"),
        ({"confidence": -1.0}, "confidence should be None or a float in [0..1] range"),
        ({"confidence": 1.42}, "confidence should be None or a float in [0..1] range"),
        (
            {"confidence": float("inf")},
            "confidence should be None or a float in [0..1] range",
        ),
    ],
)
def test_partial_update_element_wrong_param_conf(mock_elements_worker, payload, error):
    """A non-float or out-of-range ``confidence`` fails validation."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    # The message contains brackets, so it must be escaped to match literally
    expected_pattern = re.escape(error)
    with pytest.raises(AssertionError, match=expected_pattern):
        mock_elements_worker.partial_update_element(**kwargs)
232
+
233
+
234
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Rotation angle
        ({"rotation_angle": "lol"}, "rotation_angle should be a positive integer"),
        ({"rotation_angle": -1}, "rotation_angle should be a positive integer"),
        ({"rotation_angle": 0.5}, "rotation_angle should be a positive integer"),
        ({"rotation_angle": None}, "rotation_angle should be a positive integer"),
    ],
)
def test_partial_update_element_wrong_param_rota(mock_elements_worker, payload, error):
    """An invalid ``rotation_angle`` fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
254
+
255
+
256
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Mirrored
        ({"mirrored": "lol"}, "mirrored should be a boolean"),
        ({"mirrored": 1234}, "mirrored should be a boolean"),
        ({"mirrored": None}, "mirrored should be a boolean"),
    ],
)
def test_partial_update_element_wrong_param_mir(mock_elements_worker, payload, error):
    """A non-boolean ``mirrored`` fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
275
+
276
+
277
@pytest.mark.parametrize(
    ("payload", "error"),
    [
        # Image
        ({"image": "lol"}, "image should be a UUID"),
        ({"image": 1234}, "image should be a UUID"),
        ({"image": None}, "image should be a UUID"),
    ],
)
def test_partial_update_element_wrong_param_image(mock_elements_worker, payload, error):
    """An ``image`` that is not a UUID fails validation with an AssertionError."""
    kwargs = {"element": Element({"zone": None})}
    kwargs.update(payload)

    with pytest.raises(AssertionError, match=error):
        mock_elements_worker.partial_update_element(**kwargs)
296
+
297
+
298
def test_partial_update_element_api_error(responses, mock_elements_worker):
    """An API error on the PATCH call is raised to the caller as ErrorResponse."""
    elt = Element({"id": "12341234-1234-1234-1234-123412341234"})
    element_url = f"http://testserver/api/v1/element/{elt.id}/"
    responses.add(responses.PATCH, element_url, status=418)

    with pytest.raises(ErrorResponse):
        mock_elements_worker.partial_update_element(
            element=elt,
            type="something",
            name="0",
            polygon=[[1, 1], [2, 2], [2, 1], [1, 2]],
        )

    # Exactly one PATCH is expected after the base calls
    expected_calls = [*BASE_API_CALLS, ("PATCH", element_url)]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
318
+
319
+
320
@pytest.mark.usefixtures("_mock_cached_elements", "_mock_cached_images")
@pytest.mark.parametrize(
    "payload",
    [
        {
            "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]],
            "confidence": None,
        },
        {
            "rotation_angle": 45,
            "mirrored": False,
        },
        {
            "polygon": [[10, 10], [20, 20], [20, 10], [10, 20]],
            "confidence": None,
            "rotation_angle": 45,
            "mirrored": False,
        },
    ],
)
def test_partial_update_element(responses, mock_elements_worker_with_cache, payload):
    """A successful PATCH is returned to the caller and mirrored in the local cache."""
    elt = CachedElement.select().first()
    new_image = CachedImage.select().first()
    element_url = f"http://testserver/api/v1/element/{elt.id}/"

    # The image ID is a string here because a UUID is not allowed in JSON
    elt_response = {"image": str(new_image.id), **payload}
    responses.add(responses.PATCH, element_url, status=200, json=elt_response)

    element_update_response = mock_elements_worker_with_cache.partial_update_element(
        element=elt,
        **{**elt_response, "image": new_image.id},
    )

    assert len(responses.calls) == len(BASE_API_CALLS) + 1
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == BASE_API_CALLS + [("PATCH", element_url)]
    assert json.loads(responses.calls[-1].request.body) == elt_response
    assert element_update_response == elt_response

    cached_element = CachedElement.get(CachedElement.id == elt.id)
    # The image is always part of the payload
    assert str(cached_element.image_id) == elt_response["image"]
    # Optional params: the polygon is compared as a string, the only
    # difference from the model; other fields are compared as-is
    for param in payload:
        expected_value = elt_response[param]
        if param == "polygon":
            expected_value = str(expected_value)
        assert getattr(cached_element, param) == expected_value
389
+
390
+
391
@pytest.mark.usefixtures("_mock_cached_elements")
@pytest.mark.parametrize("confidence", [None, 0.42])
def test_partial_update_element_confidence(
    responses, mock_elements_worker_with_cache, confidence
):
    """Both a null and a float confidence are sent, returned and cached unchanged."""
    elt = CachedElement.select().first()
    element_url = f"http://testserver/api/v1/element/{elt.id}/"
    polygon = [[10, 10], [20, 20], [20, 10], [10, 20]]
    elt_response = {"polygon": polygon, "confidence": confidence}
    responses.add(responses.PATCH, element_url, status=200, json=elt_response)

    element_update_response = mock_elements_worker_with_cache.partial_update_element(
        element=elt,
        **elt_response,
    )

    expected_calls = [*BASE_API_CALLS, ("PATCH", element_url)]
    assert len(responses.calls) == len(expected_calls)
    assert [
        (call.request.method, call.request.url) for call in responses.calls
    ] == expected_calls
    assert json.loads(responses.calls[-1].request.body) == elt_response
    assert element_update_response == elt_response

    cached_element = CachedElement.get(CachedElement.id == elt.id)
    # The cache stores the polygon as its string representation
    assert cached_element.polygon == str(polygon)
    assert cached_element.confidence == confidence