arkindex-base-worker 0.4.0__py3-none-any.whl → 0.4.0a2__py3-none-any.whl

This diff shows the changes between two publicly available versions of this package that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects each package version exactly as it appears in its public registry.
Files changed (51):
  1. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/METADATA +13 -15
  2. arkindex_base_worker-0.4.0a2.dist-info/RECORD +51 -0
  3. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/WHEEL +1 -1
  4. arkindex_worker/cache.py +1 -1
  5. arkindex_worker/image.py +1 -120
  6. arkindex_worker/utils.py +0 -82
  7. arkindex_worker/worker/__init__.py +161 -46
  8. arkindex_worker/worker/base.py +11 -36
  9. arkindex_worker/worker/classification.py +18 -34
  10. arkindex_worker/worker/corpus.py +4 -21
  11. arkindex_worker/worker/dataset.py +1 -71
  12. arkindex_worker/worker/element.py +91 -352
  13. arkindex_worker/worker/entity.py +11 -11
  14. arkindex_worker/worker/metadata.py +9 -19
  15. arkindex_worker/worker/task.py +4 -5
  16. arkindex_worker/worker/training.py +6 -6
  17. arkindex_worker/worker/transcription.py +68 -89
  18. arkindex_worker/worker/version.py +1 -3
  19. tests/__init__.py +1 -1
  20. tests/conftest.py +45 -33
  21. tests/test_base_worker.py +3 -204
  22. tests/test_dataset_worker.py +4 -7
  23. tests/test_elements_worker/{test_classification.py → test_classifications.py} +61 -194
  24. tests/test_elements_worker/test_corpus.py +1 -32
  25. tests/test_elements_worker/test_dataset.py +1 -1
  26. tests/test_elements_worker/test_elements.py +2734 -0
  27. tests/test_elements_worker/{test_entity_create.py → test_entities.py} +160 -26
  28. tests/test_elements_worker/test_image.py +1 -2
  29. tests/test_elements_worker/test_metadata.py +99 -224
  30. tests/test_elements_worker/test_task.py +1 -1
  31. tests/test_elements_worker/test_training.py +2 -2
  32. tests/test_elements_worker/test_transcriptions.py +2102 -0
  33. tests/test_elements_worker/test_worker.py +280 -563
  34. tests/test_image.py +204 -429
  35. tests/test_merge.py +2 -1
  36. tests/test_utils.py +3 -66
  37. arkindex_base_worker-0.4.0.dist-info/RECORD +0 -61
  38. arkindex_worker/worker/process.py +0 -92
  39. tests/test_elements_worker/test_element.py +0 -427
  40. tests/test_elements_worker/test_element_create_multiple.py +0 -715
  41. tests/test_elements_worker/test_element_create_single.py +0 -528
  42. tests/test_elements_worker/test_element_list_children.py +0 -969
  43. tests/test_elements_worker/test_element_list_parents.py +0 -530
  44. tests/test_elements_worker/test_entity_list_and_check.py +0 -160
  45. tests/test_elements_worker/test_process.py +0 -89
  46. tests/test_elements_worker/test_transcription_create.py +0 -873
  47. tests/test_elements_worker/test_transcription_create_with_elements.py +0 -951
  48. tests/test_elements_worker/test_transcription_list.py +0 -450
  49. tests/test_elements_worker/test_version.py +0 -60
  50. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/LICENSE +0 -0
  51. {arkindex_base_worker-0.4.0.dist-info → arkindex_base_worker-0.4.0a2.dist-info}/top_level.txt +0 -0
@@ -1,572 +1,90 @@
1
1
  import json
2
2
  import sys
3
- from argparse import Namespace
4
- from uuid import UUID
5
3
 
6
4
  import pytest
5
+ from apistar.exceptions import ErrorResponse
7
6
 
8
- from arkindex.exceptions import ErrorResponse
9
- from arkindex_worker.cache import (
10
- SQL_VERSION,
11
- CachedElement,
12
- create_version_table,
13
- init_cache_db,
14
- )
15
- from arkindex_worker.models import Element
7
+ from arkindex_worker.cache import CachedElement
16
8
  from arkindex_worker.worker import ActivityState, ElementsWorker
17
- from arkindex_worker.worker.dataset import DatasetState
18
- from arkindex_worker.worker.process import ProcessMode
19
- from tests import PROCESS_ID
9
+ from tests import CORPUS_ID
20
10
 
21
11
  from . import BASE_API_CALLS
22
12
 
13
+ TEST_VERSION_ID = "test_123"
14
+ TEST_SLUG = "some_slug"
23
15
 
24
- def test_database_arg(mocker, mock_elements_worker, tmp_path):
25
- database_path = tmp_path / "my_database.sqlite"
26
- init_cache_db(database_path)
27
- create_version_table()
28
-
29
- mocker.patch(
30
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
31
- return_value=Namespace(
32
- element=["volumeid", "pageid"],
33
- verbose=False,
34
- elements_list=None,
35
- database=database_path,
36
- dev=False,
37
- set=[],
38
- ),
39
- )
40
-
41
- worker = ElementsWorker(support_cache=True)
42
- worker.configure()
43
-
44
- assert worker.use_cache is True
45
- assert worker.cache_path == database_path
46
-
47
-
48
- def test_database_arg_cache_missing_version_table(
49
- mocker, mock_elements_worker, tmp_path
50
- ):
51
- database_path = tmp_path / "my_database.sqlite"
52
- database_path.touch()
53
-
54
- mocker.patch(
55
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
56
- return_value=Namespace(
57
- element=["volumeid", "pageid"],
58
- verbose=False,
59
- elements_list=None,
60
- database=database_path,
61
- dev=False,
62
- set=[],
63
- ),
64
- )
65
-
66
- worker = ElementsWorker(support_cache=True)
67
- with pytest.raises(
68
- AssertionError,
69
- match=f"The SQLite database {database_path} does not have the correct cache version, it should be {SQL_VERSION}",
70
- ):
71
- worker.configure()
72
-
73
-
74
- def test_readonly(responses, mock_elements_worker):
75
- """Test readonly worker does not trigger any API calls"""
76
-
77
- # Setup the worker as read-only
78
- mock_elements_worker.worker_run_id = None
79
- assert mock_elements_worker.is_read_only is True
80
-
81
- out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
82
-
83
- # update_activity returns False in very specific cases
84
- assert out is True
85
- assert len(responses.calls) == len(BASE_API_CALLS)
86
- assert [
87
- (call.request.method, call.request.url) for call in responses.calls
88
- ] == BASE_API_CALLS
89
-
90
-
91
- def test_get_elements_elements_list_arg_wrong_type(
92
- monkeypatch, tmp_path, mock_elements_worker
93
- ):
94
- elements_path = tmp_path / "elements.json"
95
- elements_path.write_text("{}")
96
-
97
- monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
98
- worker = ElementsWorker()
99
- worker.configure()
100
-
101
- with pytest.raises(AssertionError, match="Elements list must be a list"):
102
- worker.get_elements()
103
16
 
17
+ def test_get_worker_version(fake_dummy_worker):
18
+ api_client = fake_dummy_worker.api_client
104
19
 
105
- def test_get_elements_elements_list_arg_empty_list(
106
- monkeypatch, tmp_path, mock_elements_worker
107
- ):
108
- elements_path = tmp_path / "elements.json"
109
- elements_path.write_text("[]")
20
+ response = {"worker": {"slug": TEST_SLUG}}
110
21
 
111
- monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
112
- worker = ElementsWorker()
113
- worker.configure()
22
+ api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
114
23
 
115
- with pytest.raises(AssertionError, match="No elements in elements list"):
116
- worker.get_elements()
24
+ with pytest.deprecated_call(match="WorkerVersion usage is deprecated."):
25
+ res = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
117
26
 
27
+ assert res == response
28
+ assert fake_dummy_worker._worker_version_cache[TEST_VERSION_ID] == response
118
29
 
119
- def test_get_elements_elements_list_arg_missing_id(
120
- monkeypatch, tmp_path, mock_elements_worker
121
- ):
122
- elements_path = tmp_path / "elements.json"
123
- elements_path.write_text(json.dumps([{"type": "volume"}]))
124
-
125
- monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
126
- worker = ElementsWorker()
127
- worker.configure()
128
30
 
129
- elt_list = worker.get_elements()
31
+ def test_get_worker_version__uses_cache(fake_dummy_worker):
32
+ api_client = fake_dummy_worker.api_client
130
33
 
131
- assert elt_list == []
34
+ response = {"worker": {"slug": TEST_SLUG}}
132
35
 
36
+ api_client.add_response("RetrieveWorkerVersion", response, id=TEST_VERSION_ID)
133
37
 
134
- def test_get_elements_elements_list_arg_not_uuid(
135
- monkeypatch, tmp_path, mock_elements_worker
136
- ):
137
- elements_path = tmp_path / "elements.json"
138
- elements_path.write_text(
139
- json.dumps(
140
- [
141
- {"id": "volumeid", "type": "volume"},
142
- {"id": "pageid", "type": "page"},
143
- {"id": "actid", "type": "act"},
144
- {"id": "surfaceid", "type": "surface"},
145
- ]
146
- )
147
- )
38
+ with pytest.deprecated_call(match="WorkerVersion usage is deprecated."):
39
+ response_1 = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
148
40
 
149
- monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
150
- worker = ElementsWorker()
151
- worker.configure()
41
+ with pytest.deprecated_call(match="WorkerVersion usage is deprecated."):
42
+ response_2 = fake_dummy_worker.get_worker_version(TEST_VERSION_ID)
152
43
 
153
- with pytest.raises(
154
- Exception,
155
- match="These element IDs are invalid: volumeid, pageid, actid, surfaceid",
156
- ):
157
- worker.get_elements()
158
-
159
-
160
- def test_get_elements_elements_list_arg(monkeypatch, tmp_path, mock_elements_worker):
161
- elements_path = tmp_path / "elements.json"
162
- elements_path.write_text(
163
- json.dumps(
164
- [
165
- {"id": "11111111-1111-1111-1111-111111111111", "type": "volume"},
166
- {"id": "22222222-2222-2222-2222-222222222222", "type": "page"},
167
- {"id": "33333333-3333-3333-3333-333333333333", "type": "act"},
168
- ]
169
- )
170
- )
44
+ assert response_1 == response
45
+ assert response_1 == response_2
171
46
 
172
- monkeypatch.setenv("TASK_ELEMENTS", str(elements_path))
173
- worker = ElementsWorker()
174
- worker.configure()
47
+ # assert that only one call to the API
48
+ assert len(api_client.history) == 1
49
+ assert not api_client.responses
175
50
 
176
- elt_list = worker.get_elements()
177
51
 
178
- assert elt_list == [
179
- "11111111-1111-1111-1111-111111111111",
180
- "22222222-2222-2222-2222-222222222222",
181
- "33333333-3333-3333-3333-333333333333",
182
- ]
183
-
184
-
185
- def test_get_elements_element_arg_not_uuid(mocker, mock_elements_worker):
186
- mocker.patch(
187
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
188
- return_value=Namespace(
189
- element=["volumeid", "pageid"],
190
- config={},
191
- verbose=False,
192
- elements_list=None,
193
- database=None,
194
- dev=True,
195
- set=[],
196
- ),
197
- )
198
-
199
- worker = ElementsWorker()
200
- worker.configure()
201
-
202
- with pytest.raises(
203
- Exception, match="These element IDs are invalid: volumeid, pageid"
204
- ):
205
- worker.get_elements()
206
-
207
-
208
- def test_get_elements_element_arg(mocker, mock_elements_worker):
209
- mocker.patch(
210
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
211
- return_value=Namespace(
212
- element=[
213
- "11111111-1111-1111-1111-111111111111",
214
- "22222222-2222-2222-2222-222222222222",
215
- ],
216
- config={},
217
- verbose=False,
218
- elements_list=None,
219
- database=None,
220
- dev=True,
221
- set=[],
222
- ),
223
- )
224
-
225
- worker = ElementsWorker()
226
- worker.configure()
227
-
228
- elt_list = worker.get_elements()
229
-
230
- assert elt_list == [
231
- "11111111-1111-1111-1111-111111111111",
232
- "22222222-2222-2222-2222-222222222222",
233
- ]
234
-
235
-
236
- def test_get_elements_dataset_set_arg(responses, mocker, mock_elements_worker):
237
- mocker.patch(
238
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
239
- return_value=Namespace(
240
- element=[],
241
- config={},
242
- verbose=False,
243
- elements_list=None,
244
- database=None,
245
- dev=True,
246
- set=[(UUID("11111111-1111-1111-1111-111111111111"), "train")],
247
- ),
248
- )
249
-
250
- # Mock RetrieveDataset call
251
- responses.add(
252
- responses.GET,
253
- "http://testserver/api/v1/datasets/11111111-1111-1111-1111-111111111111/",
254
- status=200,
255
- json={
256
- "id": "11111111-1111-1111-1111-111111111111",
257
- "name": "My dataset",
258
- "description": "A dataset about cats.",
259
- "sets": ["train", "dev", "test"],
260
- "state": DatasetState.Complete.value,
261
- },
262
- content_type="application/json",
263
- )
264
-
265
- # Mock ListSetElements call
266
- element = {
267
- "id": "22222222-2222-2222-2222-222222222222",
268
- "type": "page",
269
- "name": "1",
270
- "corpus": {
271
- "id": "11111111-1111-1111-1111-111111111111",
272
- },
273
- "thumbnail_url": "http://example.com",
274
- "zone": {
275
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
276
- "polygon": [[0, 0], [0, 0], [0, 0]],
277
- "image": {
278
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
279
- "path": "string",
280
- "width": 0,
281
- "height": 0,
282
- "url": "http://example.com",
283
- "s3_url": "string",
284
- "status": "checked",
285
- "server": {
286
- "display_name": "string",
287
- "url": "http://example.com",
288
- "max_width": 2147483647,
289
- "max_height": 2147483647,
290
- },
291
- },
292
- "url": "http://example.com",
293
- },
294
- "rotation_angle": 0,
295
- "mirrored": False,
296
- "created": "2019-08-24T14:15:22Z",
297
- "classes": [
298
- {
299
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
300
- "ml_class": {
301
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
302
- "name": "string",
303
- },
304
- "state": "pending",
305
- "confidence": 0,
306
- "high_confidence": True,
307
- "worker_run": {
308
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
309
- "summary": "string",
310
- },
311
- }
312
- ],
313
- "metadata": [
314
- {
315
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
316
- "type": "text",
317
- "name": "string",
318
- "value": "string",
319
- "dates": [{"type": "exact", "year": 0, "month": 1, "day": 1}],
320
- }
321
- ],
322
- "transcriptions": [
323
- {
324
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
325
- "text": "string",
326
- "confidence": 0,
327
- "orientation": "horizontal-lr",
328
- "worker_run": {
329
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
330
- "summary": "string",
331
- },
332
- }
333
- ],
334
- "has_children": True,
335
- "worker_run": {
336
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
337
- "summary": "string",
338
- },
339
- "confidence": 1,
52
+ def test_get_worker_version_slug(mocker, fake_dummy_worker):
53
+ fake_dummy_worker.get_worker_version = mocker.MagicMock()
54
+ fake_dummy_worker.get_worker_version.return_value = {
55
+ "id": TEST_VERSION_ID,
56
+ "worker": {"slug": "mock_slug"},
340
57
  }
341
- responses.add(
342
- responses.GET,
343
- "http://testserver/api/v1/datasets/11111111-1111-1111-1111-111111111111/elements/?set=train&with_count=true",
344
- status=200,
345
- json={
346
- "next": None,
347
- "previous": None,
348
- "results": [
349
- {
350
- "set": "train",
351
- "element": element,
352
- }
353
- ],
354
- "count": 1,
355
- },
356
- content_type="application/json",
357
- )
358
-
359
- worker = ElementsWorker()
360
- worker.configure()
361
58
 
362
- elt_list = worker.get_elements()
363
-
364
- assert elt_list == [
365
- Element(**element),
366
- ]
59
+ with pytest.deprecated_call(match="WorkerVersion usage is deprecated."):
60
+ slug = fake_dummy_worker.get_worker_version_slug(TEST_VERSION_ID)
61
+ assert slug == "mock_slug"
367
62
 
368
63
 
369
- def test_get_elements_dataset_set_api(responses, mocker, mock_elements_worker):
370
- # Mock ListProcessSets call
371
- responses.add(
372
- responses.GET,
373
- "http://testserver/api/v1/process/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/sets/",
374
- status=200,
375
- json={
376
- "next": None,
377
- "previous": None,
378
- "results": [
379
- {
380
- "id": "33333333-3333-3333-3333-333333333333",
381
- "dataset": {"id": "11111111-1111-1111-1111-111111111111"},
382
- "set_name": "train",
383
- }
384
- ],
385
- "count": 1,
386
- },
387
- content_type="application/json",
388
- )
389
-
390
- # Mock ListSetElements call
391
- element = {
392
- "id": "22222222-2222-2222-2222-222222222222",
393
- "type": "page",
394
- "name": "1",
395
- "corpus": {
396
- "id": "11111111-1111-1111-1111-111111111111",
397
- },
398
- "thumbnail_url": "http://example.com",
399
- "zone": {
400
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
401
- "polygon": [[0, 0], [0, 0], [0, 0]],
402
- "image": {
403
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
404
- "path": "string",
405
- "width": 0,
406
- "height": 0,
407
- "url": "http://example.com",
408
- "s3_url": "string",
409
- "status": "checked",
410
- "server": {
411
- "display_name": "string",
412
- "url": "http://example.com",
413
- "max_width": 2147483647,
414
- "max_height": 2147483647,
415
- },
416
- },
417
- "url": "http://example.com",
418
- },
419
- "rotation_angle": 0,
420
- "mirrored": False,
421
- "created": "2019-08-24T14:15:22Z",
422
- "classes": [
423
- {
424
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
425
- "ml_class": {
426
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
427
- "name": "string",
428
- },
429
- "state": "pending",
430
- "confidence": 0,
431
- "high_confidence": True,
432
- "worker_run": {
433
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
434
- "summary": "string",
435
- },
436
- }
437
- ],
438
- "metadata": [
439
- {
440
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
441
- "type": "text",
442
- "name": "string",
443
- "value": "string",
444
- "dates": [{"type": "exact", "year": 0, "month": 1, "day": 1}],
445
- }
446
- ],
447
- "transcriptions": [
448
- {
449
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
450
- "text": "string",
451
- "confidence": 0,
452
- "orientation": "horizontal-lr",
453
- "worker_run": {
454
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
455
- "summary": "string",
456
- },
457
- }
458
- ],
459
- "has_children": True,
460
- "worker_run": {
461
- "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
462
- "summary": "string",
463
- },
464
- "confidence": 1,
465
- }
466
- responses.add(
467
- responses.GET,
468
- "http://testserver/api/v1/datasets/11111111-1111-1111-1111-111111111111/elements/?set=train&with_count=true",
469
- status=200,
470
- json={
471
- "next": None,
472
- "previous": None,
473
- "results": [
474
- {
475
- "set": "train",
476
- "element": element,
477
- }
478
- ],
479
- "count": 1,
480
- },
481
- content_type="application/json",
482
- )
483
-
484
- # Update ProcessMode to Dataset
485
- mock_elements_worker.process_information["mode"] = ProcessMode.Dataset
486
-
487
- elt_list = mock_elements_worker.get_elements()
488
-
489
- assert elt_list == [
490
- Element(**element),
491
- ]
492
-
64
+ def test_get_worker_version_slug_none(fake_dummy_worker):
65
+ # WARNING: pytest.deprecated_call must be placed BEFORE pytest.raises, otherwise `match` argument won't be checked
66
+ with (
67
+ pytest.deprecated_call(match="WorkerVersion usage is deprecated."),
68
+ pytest.raises(ValueError, match="No worker version ID"),
69
+ ):
70
+ fake_dummy_worker.get_worker_version_slug(None)
493
71
 
494
- def test_get_elements_both_args_error(mocker, mock_elements_worker, tmp_path):
495
- elements_path = tmp_path / "elements.json"
496
- elements_path.write_text(
497
- json.dumps(
498
- [
499
- {"id": "volumeid", "type": "volume"},
500
- {"id": "pageid", "type": "page"},
501
- {"id": "actid", "type": "act"},
502
- {"id": "surfaceid", "type": "surface"},
503
- ]
504
- )
505
- )
506
- mocker.patch(
507
- "arkindex_worker.worker.base.argparse.ArgumentParser.parse_args",
508
- return_value=Namespace(
509
- element=["anotherid", "againanotherid"],
510
- verbose=False,
511
- elements_list=elements_path.open(),
512
- database=None,
513
- dev=False,
514
- set=[],
515
- ),
516
- )
517
72
 
518
- worker = ElementsWorker()
519
- worker.configure()
73
+ def test_readonly(responses, mock_elements_worker):
74
+ """Test readonly worker does not trigger any API calls"""
520
75
 
521
- with pytest.raises(
522
- AssertionError, match="elements-list and element CLI args shouldn't be both set"
523
- ):
524
- worker.get_elements()
76
+ # Setup the worker as read-only
77
+ mock_elements_worker.worker_run_id = None
78
+ assert mock_elements_worker.is_read_only is True
525
79
 
80
+ out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
526
81
 
527
- def test_get_elements_export_process(mock_elements_worker, responses):
528
- responses.add(
529
- responses.GET,
530
- f"http://testserver/api/v1/process/{PROCESS_ID}/elements/?page_size=500&with_count=true&with_image=False",
531
- status=200,
532
- json={
533
- "count": 2,
534
- "next": None,
535
- "results": [
536
- {
537
- "id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
538
- "type_id": "baaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
539
- "name": "element 1",
540
- "confidence": 1,
541
- "image_id": None,
542
- "image_width": None,
543
- "image_height": None,
544
- "image_url": None,
545
- "polygon": None,
546
- "rotation_angle": 0,
547
- "mirrored": False,
548
- },
549
- {
550
- "id": "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaa0",
551
- "type_id": "baaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
552
- "name": "element 2",
553
- "confidence": 1,
554
- "image_id": None,
555
- "image_width": None,
556
- "image_height": None,
557
- "image_url": None,
558
- "polygon": None,
559
- "rotation_angle": 0,
560
- "mirrored": False,
561
- },
562
- ],
563
- },
564
- )
565
- mock_elements_worker.process_information["mode"] = "export"
566
- assert set(mock_elements_worker.get_elements()) == {
567
- "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
568
- "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaa0",
569
- }
82
+ # update_activity returns False in very specific cases
83
+ assert out is True
84
+ assert len(responses.calls) == len(BASE_API_CALLS)
85
+ assert [
86
+ (call.request.method, call.request.url) for call in responses.calls
87
+ ] == BASE_API_CALLS
570
88
 
571
89
 
572
90
  @pytest.mark.usefixtures("_mock_worker_run_api")
@@ -597,6 +115,43 @@ def test_activities_dev_mode(mocker):
597
115
  assert worker.store_activity is False
598
116
 
599
117
 
118
+ @pytest.mark.usefixtures("_mock_worker_run_api")
119
+ def test_update_call(responses, mock_elements_worker):
120
+ """Test an update call with feature enabled triggers an API call"""
121
+ responses.add(
122
+ responses.PUT,
123
+ "http://testserver/api/v1/workers/versions/56785678-5678-5678-5678-567856785678/activity/",
124
+ status=200,
125
+ json={
126
+ "element_id": "1234-deadbeef",
127
+ "process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
128
+ "state": "processed",
129
+ },
130
+ )
131
+
132
+ out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
133
+
134
+ # Check the response received by worker
135
+ assert out is True
136
+
137
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
138
+ assert [
139
+ (call.request.method, call.request.url) for call in responses.calls
140
+ ] == BASE_API_CALLS + [
141
+ (
142
+ "PUT",
143
+ "http://testserver/api/v1/workers/versions/56785678-5678-5678-5678-567856785678/activity/",
144
+ ),
145
+ ]
146
+
147
+ # Check the request sent by worker
148
+ assert json.loads(responses.calls[-1].request.body) == {
149
+ "element_id": "1234-deadbeef",
150
+ "process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
151
+ "state": "processed",
152
+ }
153
+
154
+
600
155
  @pytest.mark.usefixtures("_mock_activity_calls")
601
156
  @pytest.mark.parametrize(
602
157
  ("process_exception", "final_state"),
@@ -760,38 +315,200 @@ def test_start_activity_error(
760
315
  ]
761
316
 
762
317
 
763
- @pytest.mark.usefixtures("_mock_worker_run_api")
764
- def test_update_activity(responses, mock_elements_worker):
765
- """Test an update call with feature enabled triggers an API call"""
318
+ @pytest.mark.parametrize(
319
+ (
320
+ "wk_version_config",
321
+ "wk_version_user_config",
322
+ "frontend_user_config",
323
+ "model_config",
324
+ "expected_config",
325
+ ),
326
+ [
327
+ ({}, {}, {}, {}, {}),
328
+ # Keep parameters from worker version configuration
329
+ ({"parameter": 0}, {}, {}, {}, {"parameter": 0}),
330
+ # Keep parameters from worker version configuration + user_config defaults
331
+ (
332
+ {"parameter": 0},
333
+ {
334
+ "parameter2": {
335
+ "type": "int",
336
+ "title": "Lambda",
337
+ "default": 0,
338
+ "required": False,
339
+ }
340
+ },
341
+ {},
342
+ {},
343
+ {"parameter": 0, "parameter2": 0},
344
+ ),
345
+ # Keep parameters from worker version configuration + user_config no defaults
346
+ (
347
+ {"parameter": 0},
348
+ {
349
+ "parameter2": {
350
+ "type": "int",
351
+ "title": "Lambda",
352
+ "required": False,
353
+ }
354
+ },
355
+ {},
356
+ {},
357
+ {"parameter": 0},
358
+ ),
359
+ # Keep parameters from worker version configuration but user_config defaults overrides
360
+ (
361
+ {"parameter": 0},
362
+ {
363
+ "parameter": {
364
+ "type": "int",
365
+ "title": "Lambda",
366
+ "default": 1,
367
+ "required": False,
368
+ }
369
+ },
370
+ {},
371
+ {},
372
+ {"parameter": 1},
373
+ ),
374
+ # Keep parameters from worker version configuration + frontend config
375
+ (
376
+ {"parameter": 0},
377
+ {},
378
+ {"parameter2": 0},
379
+ {},
380
+ {"parameter": 0, "parameter2": 0},
381
+ ),
382
+ # Keep parameters from worker version configuration + frontend config overrides
383
+ ({"parameter": 0}, {}, {"parameter": 1}, {}, {"parameter": 1}),
384
+ # Keep parameters from worker version configuration + model config
385
+ (
386
+ {"parameter": 0},
387
+ {},
388
+ {},
389
+ {"parameter2": 0},
390
+ {"parameter": 0, "parameter2": 0},
391
+ ),
392
+ # Keep parameters from worker version configuration + model config overrides
393
+ ({"parameter": 0}, {}, {}, {"parameter": 1}, {"parameter": 1}),
394
+ # Keep parameters from worker version configuration + user_config default + model config overrides
395
+ (
396
+ {"parameter": 0},
397
+ {
398
+ "parameter": {
399
+ "type": "int",
400
+ "title": "Lambda",
401
+ "default": 1,
402
+ "required": False,
403
+ }
404
+ },
405
+ {},
406
+ {"parameter": 2},
407
+ {"parameter": 2},
408
+ ),
409
+ # Keep parameters from worker version configuration + model config + frontend config overrides
410
+ ({"parameter": 0}, {}, {"parameter": 2}, {"parameter": 1}, {"parameter": 2}),
411
+ # Keep parameters from worker version configuration + user_config default + model config + frontend config overrides all
412
+ (
413
+ {"parameter": 0},
414
+ {
415
+ "parameter": {
416
+ "type": "int",
417
+ "title": "Lambda",
418
+ "default": 1,
419
+ "required": False,
420
+ }
421
+ },
422
+ {"parameter": 3},
423
+ {"parameter": 2},
424
+ {"parameter": 3},
425
+ ),
426
+ ],
427
+ )
428
+ def test_worker_config_multiple_source(
429
+ monkeypatch,
430
+ responses,
431
+ wk_version_config,
432
+ wk_version_user_config,
433
+ frontend_user_config,
434
+ model_config,
435
+ expected_config,
436
+ ):
437
+ # Compute WorkerRun info
438
+ payload = {
439
+ "id": "56785678-5678-5678-5678-567856785678",
440
+ "parents": [],
441
+ "worker_version": {
442
+ "id": "12341234-1234-1234-1234-123412341234",
443
+ "configuration": {
444
+ "docker": {"image": "python:3"},
445
+ "configuration": wk_version_config,
446
+ "secrets": [],
447
+ "user_configuration": wk_version_user_config,
448
+ },
449
+ "revision": {
450
+ "hash": "deadbeef1234",
451
+ "name": "some git revision",
452
+ },
453
+ "docker_image": "python:3",
454
+ "docker_image_name": "python:3",
455
+ "state": "created",
456
+ "worker": {
457
+ "id": "deadbeef-1234-5678-1234-worker",
458
+ "name": "Fake worker",
459
+ "slug": "fake_worker",
460
+ "type": "classifier",
461
+ },
462
+ },
463
+ "configuration": {
464
+ "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
465
+ "name": "Configuration entered by user",
466
+ "configuration": frontend_user_config,
467
+ },
468
+ "model_version": {
469
+ "id": "12341234-1234-1234-1234-123412341234",
470
+ "name": "Model version 1337",
471
+ "configuration": model_config,
472
+ "model": {
473
+ "id": "hahahaha-haha-haha-haha-hahahahahaha",
474
+ "name": "My model",
475
+ },
476
+ },
477
+ "process": {
478
+ "name": None,
479
+ "id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
480
+ "state": "running",
481
+ "mode": "workers",
482
+ "corpus": CORPUS_ID,
483
+ "use_cache": False,
484
+ "activity_state": "ready",
485
+ "model_id": None,
486
+ "train_folder_id": None,
487
+ "validation_folder_id": None,
488
+ "test_folder_id": None,
489
+ },
490
+ "summary": "Worker Fake worker @ 123412",
491
+ }
492
+
766
493
  responses.add(
767
- responses.PUT,
768
- "http://testserver/api/v1/workers/versions/56785678-5678-5678-5678-567856785678/activity/",
494
+ responses.GET,
495
+ "http://testserver/api/v1/process/workers/56785678-5678-5678-5678-567856785678/",
769
496
  status=200,
770
- json={
771
- "element_id": "1234-deadbeef",
772
- "process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
773
- "state": "processed",
774
- },
497
+ body=json.dumps(payload),
498
+ content_type="application/json",
775
499
  )
776
500
 
777
- out = mock_elements_worker.update_activity("1234-deadbeef", ActivityState.Processed)
501
+ # Create and configure a worker
502
+ monkeypatch.setattr(sys, "argv", ["worker"])
503
+ worker = ElementsWorker()
504
+ worker.configure()
778
505
 
779
- # Check the response received by worker
780
- assert out is True
506
+ # Do what people do with a model configuration
507
+ if worker.model_configuration:
508
+ worker.config.update(worker.model_configuration)
781
509
 
782
- assert len(responses.calls) == len(BASE_API_CALLS) + 1
783
- assert [
784
- (call.request.method, call.request.url) for call in responses.calls
785
- ] == BASE_API_CALLS + [
786
- (
787
- "PUT",
788
- "http://testserver/api/v1/workers/versions/56785678-5678-5678-5678-567856785678/activity/",
789
- ),
790
- ]
510
+ if worker.user_configuration:
511
+ worker.config.update(worker.user_configuration)
791
512
 
792
- # Check the request sent by worker
793
- assert json.loads(responses.calls[-1].request.body) == {
794
- "element_id": "1234-deadbeef",
795
- "process_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff",
796
- "state": "processed",
797
- }
513
+ # Check final config
514
+ assert worker.config == expected_config