arkindex-base-worker 0.3.7rc5__py3-none-any.whl → 0.3.7rc7__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
@@ -4,13 +4,13 @@ import logging
 import pytest
 from apistar.exceptions import ErrorResponse

-from arkindex_worker.models import Dataset
+from arkindex_worker.models import Dataset, Element, Set
 from arkindex_worker.worker.dataset import DatasetState
 from tests.conftest import PROCESS_ID
 from tests.test_elements_worker import BASE_API_CALLS


-def test_list_process_datasets_readonly_error(mock_dataset_worker):
+def test_list_process_sets_readonly_error(mock_dataset_worker):
     # Set worker in read_only mode
     mock_dataset_worker.worker_run_id = None
     assert mock_dataset_worker.is_read_only
@@ -18,85 +18,91 @@ def test_list_process_datasets_readonly_error(mock_dataset_worker):
     with pytest.raises(
         AssertionError, match="This helper is not available in read-only mode."
     ):
-        mock_dataset_worker.list_process_datasets()
+        mock_dataset_worker.list_process_sets()


-def test_list_process_datasets_api_error(responses, mock_dataset_worker):
+def test_list_process_sets_api_error(responses, mock_dataset_worker):
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
         status=500,
     )

     with pytest.raises(
         Exception, match="Stopping pagination as data will be incomplete"
     ):
-        next(mock_dataset_worker.list_process_datasets())
+        next(mock_dataset_worker.list_process_sets())

     assert len(responses.calls) == len(BASE_API_CALLS) + 5
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
         # The API call is retried 5 times
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]


-def test_list_process_datasets(
+def test_list_process_sets(
     responses,
     mock_dataset_worker,
 ):
     expected_results = [
         {
-            "id": "process_dataset_1",
+            "id": "set_1",
             "dataset": {
                 "id": "dataset_1",
                 "name": "Dataset 1",
                 "description": "My first great dataset",
-                "sets": ["train", "val", "test"],
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
                 "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
                 "task_id": "task_id_1",
             },
-            "sets": ["test"],
+            "set_name": "train",
         },
         {
-            "id": "process_dataset_2",
+            "id": "set_2",
             "dataset": {
-                "id": "dataset_2",
-                "name": "Dataset 2",
-                "description": "My second great dataset",
-                "sets": ["train", "val"],
-                "state": "complete",
+                "id": "dataset_1",
+                "name": "Dataset 1",
+                "description": "My first great dataset",
+                "sets": [
+                    {"id": "set_1", "name": "train"},
+                    {"id": "set_2", "name": "val"},
+                ],
+                "state": "open",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_2",
+                "task_id": "task_id_1",
             },
-            "sets": ["train", "val"],
+            "set_name": "val",
         },
         {
-            "id": "process_dataset_3",
+            "id": "set_3",
             "dataset": {
-                "id": "dataset_3",
-                "name": "Dataset 3 (TRASHME)",
-                "description": "My third dataset, in error",
-                "sets": ["nonsense", "random set"],
-                "state": "error",
+                "id": "dataset_2",
+                "name": "Dataset 2",
+                "description": "My second great dataset",
+                "sets": [{"id": "set_3", "name": "my_set"}],
+                "state": "complete",
                 "corpus_id": "corpus_id",
                 "creator": "test@teklia.com",
-                "task_id": "task_id_3",
+                "task_id": "task_id_2",
             },
-            "sets": ["random set"],
+            "set_name": "my_set",
         },
     ]
     responses.add(
         responses.GET,
-        f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
+        f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
         status=200,
         json={
             "count": 3,
@@ -105,57 +111,44 @@ def test_list_process_datasets(
         },
     )

-    for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
-        assert isinstance(dataset, Dataset)
-        assert dataset == {
-            **expected_results[idx]["dataset"],
-            "selected_sets": expected_results[idx]["sets"],
-        }
+    for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
+        assert isinstance(dataset_set, Set)
+        assert dataset_set.name == expected_results[idx]["set_name"]
+
+        assert isinstance(dataset_set.dataset, Dataset)
+        assert dataset_set.dataset == expected_results[idx]["dataset"]

     assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
     ] == BASE_API_CALLS + [
-        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
+        ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
     ]


 @pytest.mark.parametrize(
     ("payload", "error"),
     [
-        # Dataset
+        # Set
         (
-            {"dataset": None},
-            "dataset shouldn't be null and should be a Dataset",
+            {"dataset_set": None},
+            "dataset_set shouldn't be null and should be a Set",
         ),
         (
-            {"dataset": "not Dataset type"},
-            "dataset shouldn't be null and should be a Dataset",
+            {"dataset_set": "not Set type"},
+            "dataset_set shouldn't be null and should be a Set",
         ),
     ],
 )
-def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload, error):
+def test_list_set_elements_wrong_param_dataset_set(mock_dataset_worker, payload, error):
     with pytest.raises(AssertionError, match=error):
-        mock_dataset_worker.list_dataset_elements(**payload)
+        mock_dataset_worker.list_set_elements(**payload)


-@pytest.mark.parametrize(
-    "sets",
-    [
-        ["set_1"],
-        ["set_1", "set_2", "set_3"],
-        ["set_1", "set_2", "set_3", "set_4"],
-    ],
-)
-def test_list_dataset_elements_api_error(
-    responses, mock_dataset_worker, sets, default_dataset
+def test_list_set_elements_api_error(
+    responses, mock_dataset_worker, default_dataset, default_train_set
 ):
-    default_dataset.selected_sets = sets
-    query_params = (
-        "?with_count=true"
-        if sets == default_dataset.sets
-        else "?set=set_1&with_count=true"
-    )
+    query_params = f"?set={default_train_set.name}&with_count=true"
     responses.add(
         responses.GET,
         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
@@ -165,7 +158,7 @@ def test_list_dataset_elements_api_error(
     with pytest.raises(
         Exception, match="Stopping pagination as data will be incomplete"
     ):
-        next(mock_dataset_worker.list_dataset_elements(dataset=default_dataset))
+        next(mock_dataset_worker.list_set_elements(dataset_set=default_train_set))

     assert len(responses.calls) == len(BASE_API_CALLS) + 5
     assert [
@@ -195,99 +188,60 @@ def test_list_dataset_elements_api_error(
     ]


-@pytest.mark.parametrize(
-    "sets",
-    [
-        ["set_1"],
-        ["set_1", "set_2", "set_3"],
-        ["set_1", "set_2", "set_3", "set_4"],
-    ],
-)
-def test_list_dataset_elements(
+def test_list_set_elements(
     responses,
     mock_dataset_worker,
-    sets,
     default_dataset,
+    default_train_set,
 ):
-    default_dataset.selected_sets = sets
-
-    dataset_elements = []
-    for split in default_dataset.sets:
-        index = split[-1]
-        dataset_elements.append(
-            {
-                "set": split,
-                "element": {
-                    "id": str(index) * 4,
-                    "type": "page",
-                    "name": f"Test {index}",
-                    "corpus": {},
-                    "thumbnail_url": None,
-                    "zone": {},
-                    "best_classes": None,
-                    "has_children": None,
-                    "worker_version_id": None,
-                    "worker_run_id": None,
-                },
-            }
-        )
-        if split == "set_1":
-            dataset_elements.append({**dataset_elements[-1]})
-            dataset_elements[-1]["element"]["name"] = f"Test {index} (bis)"
-
-    # All sets are selected, we call the unfiltered endpoint once
-    if default_dataset.sets == default_dataset.selected_sets:
-        expected_results = dataset_elements
-        responses.add(
-            responses.GET,
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
-            status=200,
-            json={
-                "count": len(expected_results),
-                "next": None,
-                "results": expected_results,
+    expected_results = [
+        {
+            "set": "train",
+            "element": {
+                "id": "element_1",
+                "type": "page",
+                "name": "1",
+                "corpus": {},
+                "thumbnail_url": None,
+                "zone": {},
+                "best_classes": None,
+                "has_children": None,
+                "worker_version_id": None,
+                "worker_run_id": None,
             },
-        )
-        expected_calls = [
-            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true"
-        ]
-
-    # Not all sets are selected, we call the filtered endpoint multiple times, once per set
-    else:
-        expected_results, expected_calls = [], []
-        for selected_set in default_dataset.selected_sets:
-            partial_results = [
-                element
-                for element in dataset_elements
-                if element["set"] == selected_set
-            ]
-            expected_results += partial_results
-            responses.add(
-                responses.GET,
-                f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
-                status=200,
-                json={
-                    "count": len(partial_results),
-                    "next": None,
-                    "results": partial_results,
-                },
-            )
-            expected_calls += [
-                f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true"
-            ]
+        }
+    ]
+    expected_results.append({**expected_results[-1]})
+    expected_results[-1]["element"]["id"] = "element_2"
+    expected_results[-1]["element"]["name"] = "2"
+
+    query_params = f"?set={default_train_set.name}&with_count=true"
+    responses.add(
+        responses.GET,
+        f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
+        status=200,
+        json={
+            "count": 2,
+            "next": None,
+            "results": expected_results,
+        },
+    )

     for idx, element in enumerate(
-        mock_dataset_worker.list_dataset_elements(dataset=default_dataset)
+        mock_dataset_worker.list_set_elements(dataset_set=default_train_set)
     ):
-        assert element == (
-            expected_results[idx]["set"],
-            expected_results[idx]["element"],
-        )
+        assert isinstance(element, Element)
+        assert element == expected_results[idx]["element"]

-    assert len(responses.calls) == len(BASE_API_CALLS) + len(expected_calls)
+    assert len(responses.calls) == len(BASE_API_CALLS) + 1
     assert [
         (call.request.method, call.request.url) for call in responses.calls
-    ] == BASE_API_CALLS + [("GET", expected_call) for expected_call in expected_calls]
+    ] == BASE_API_CALLS + [
+        (
+            "GET",
+            f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
+        )
+    ]


 @pytest.mark.parametrize(
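
The hunks above complete the move from Dataset-centric to Set-centric helpers: list_process_sets() yields Set objects (each exposing its name and parent Dataset) from a single paginated call to /api/v1/process/{id}/sets/, and list_set_elements(dataset_set=...) yields plain Element objects instead of (set_name, element) tuples. A minimal usage sketch under those assumptions; worker stands for a configured dataset worker instance (like the mock_dataset_worker fixture) and log_sets is a purely illustrative helper name:

from arkindex_worker.models import Dataset, Element, Set


def log_sets(worker) -> None:
    # Each result of list_process_sets() is a Set carrying its parent Dataset
    for dataset_set in worker.list_process_sets():
        assert isinstance(dataset_set, Set)
        assert isinstance(dataset_set.dataset, Dataset)
        print(f"{dataset_set.dataset.name} / {dataset_set.name}")

        # list_set_elements() filters on the set name server-side and
        # yields Element objects directly, no more (set, element) tuples
        for element in worker.list_set_elements(dataset_set=dataset_set):
            assert isinstance(element, Element)
            print(f"  {element.type} {element.id}")
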
@@ -259,7 +259,7 @@ def test_create_metadata_cached_element(responses, mock_elements_worker_with_cac
         ],
     ],
 )
-def test_create_metadatas(responses, mock_elements_worker, metadata_list):
+def test_create_metadata_bulk(responses, mock_elements_worker, metadata_list):
     element = Element({"id": "12341234-1234-1234-1234-123412341234"})
     responses.add(
         responses.POST,
@@ -280,7 +280,7 @@ def test_create_metadatas(responses, mock_elements_worker, metadata_list):
         },
     )

-    created_metadata_list = mock_elements_worker.create_metadatas(
+    created_metadata_list = mock_elements_worker.create_metadata_bulk(
         element, metadata_list
     )

@@ -327,7 +327,7 @@ def test_create_metadatas(responses, mock_elements_worker, metadata_list):
         ],
     ],
 )
-def test_create_metadatas_cached_element(
+def test_create_metadata_bulk_cached_element(
     responses, mock_elements_worker_with_cache, metadata_list
 ):
     element = CachedElement.create(
@@ -352,7 +352,7 @@ def test_create_metadatas_cached_element(
         },
     )

-    created_metadata_list = mock_elements_worker_with_cache.create_metadatas(
+    created_metadata_list = mock_elements_worker_with_cache.create_metadata_bulk(
         element, metadata_list
     )

@@ -386,7 +386,7 @@ def test_create_metadatas_cached_element(


 @pytest.mark.parametrize("wrong_element", [None, "not_element_type", 1234, 12.5])
-def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element):
+def test_create_metadata_bulk_wrong_element(mock_elements_worker, wrong_element):
     wrong_metadata_list = [
         {"type": MetaType.Text, "name": "fake_name", "value": "fake_value"}
     ]
@@ -394,13 +394,13 @@ def test_create_metadatas_wrong_element(mock_elements_worker, wrong_element):
         AssertionError,
         match="element shouldn't be null and should be of type Element or CachedElement",
     ):
-        mock_elements_worker.create_metadatas(
-            element=wrong_element, metadatas=wrong_metadata_list
+        mock_elements_worker.create_metadata_bulk(
+            element=wrong_element, metadata_list=wrong_metadata_list
         )


 @pytest.mark.parametrize("wrong_type", [None, "not_metadata_type", 1234, 12.5])
-def test_create_metadatas_wrong_type(mock_elements_worker, wrong_type):
+def test_create_metadata_bulk_wrong_type(mock_elements_worker, wrong_type):
     element = Element({"id": "12341234-1234-1234-1234-123412341234"})
     wrong_metadata_list = [
         {"type": wrong_type, "name": "fake_name", "value": "fake_value"}
@@ -408,13 +408,13 @@ def test_create_metadatas_wrong_type(mock_elements_worker, wrong_type):
     with pytest.raises(
         AssertionError, match="type shouldn't be null and should be of type MetaType"
     ):
-        mock_elements_worker.create_metadatas(
-            element=element, metadatas=wrong_metadata_list
+        mock_elements_worker.create_metadata_bulk(
+            element=element, metadata_list=wrong_metadata_list
         )


 @pytest.mark.parametrize("wrong_name", [None, 1234, 12.5, [1, 2, 3, 4]])
-def test_create_metadatas_wrong_name(mock_elements_worker, wrong_name):
+def test_create_metadata_bulk_wrong_name(mock_elements_worker, wrong_name):
     element = Element({"id": "fake_element_id"})
     wrong_metadata_list = [
         {"type": MetaType.Text, "name": wrong_name, "value": "fake_value"}
@@ -422,13 +422,13 @@ def test_create_metadatas_wrong_name(mock_elements_worker, wrong_name):
     with pytest.raises(
         AssertionError, match="name shouldn't be null and should be of type str"
     ):
-        mock_elements_worker.create_metadatas(
-            element=element, metadatas=wrong_metadata_list
+        mock_elements_worker.create_metadata_bulk(
+            element=element, metadata_list=wrong_metadata_list
         )


 @pytest.mark.parametrize("wrong_value", [None, [1, 2, 3, 4]])
-def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value):
+def test_create_metadata_bulk_wrong_value(mock_elements_worker, wrong_value):
     element = Element({"id": "fake_element_id"})
     wrong_metadata_list = [
         {"type": MetaType.Text, "name": "fake_name", "value": wrong_value}
@@ -439,13 +439,13 @@ def test_create_metadatas_wrong_value(mock_elements_worker, wrong_value):
             "value shouldn't be null and should be of type (str or float or int)"
         ),
     ):
-        mock_elements_worker.create_metadatas(
-            element=element, metadatas=wrong_metadata_list
+        mock_elements_worker.create_metadata_bulk(
+            element=element, metadata_list=wrong_metadata_list
         )


 @pytest.mark.parametrize("wrong_entity", [[1, 2, 3, 4], 1234, 12.5])
-def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity):
+def test_create_metadata_bulk_wrong_entity(mock_elements_worker, wrong_entity):
     element = Element({"id": "fake_element_id"})
     wrong_metadata_list = [
         {
@@ -456,12 +456,12 @@ def test_create_metadatas_wrong_entity(mock_elements_worker, wrong_entity):
         }
     ]
     with pytest.raises(AssertionError, match="entity_id should be None or a str"):
-        mock_elements_worker.create_metadatas(
-            element=element, metadatas=wrong_metadata_list
+        mock_elements_worker.create_metadata_bulk(
+            element=element, metadata_list=wrong_metadata_list
         )


-def test_create_metadatas_api_error(responses, mock_elements_worker):
+def test_create_metadata_bulk_api_error(responses, mock_elements_worker):
     element = Element({"id": "12341234-1234-1234-1234-123412341234"})
     metadata_list = [
         {
@@ -478,7 +478,7 @@ def test_create_metadatas_api_error(responses, mock_elements_worker):
     )

     with pytest.raises(ErrorResponse):
-        mock_elements_worker.create_metadatas(element, metadata_list)
+        mock_elements_worker.create_metadata_bulk(element, metadata_list)

     assert len(responses.calls) == len(BASE_API_CALLS) + 5
     assert [
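
The metadata hunks rename create_metadatas to create_metadata_bulk and its metadatas keyword to metadata_list; the accepted payload is unchanged (a list of dicts with type, name, value and an optional entity_id), and both Element and CachedElement targets still work. A hedged sketch of the renamed call; worker stands for an ElementsWorker instance, and the MetaType import path is an assumption not shown in this diff:

from arkindex_worker.models import Element
from arkindex_worker.worker.metadata import MetaType  # import path assumed

element = Element({"id": "12341234-1234-1234-1234-123412341234"})
created = worker.create_metadata_bulk(
    element=element,
    metadata_list=[
        {"type": MetaType.Text, "name": "fake_name", "value": "fake_value"},
        # entity_id stays optional and must be None or a str
        {"type": MetaType.Text, "name": "number", "value": 1, "entity_id": None},
    ],
)
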
@@ -0,0 +1,32 @@
+import os
+
+import pytest
+
+from arkindex.mock import MockApiClient
+from arkindex_worker.worker.base import BaseWorker
+
+
+@pytest.fixture(autouse=True)
+def _setup_environment(responses, monkeypatch) -> None:
+    """Setup needed environment variables"""
+
+    # Allow accessing remote API schemas
+    # defaulting to the prod environment
+    schema_url = os.environ.get(
+        "ARKINDEX_API_SCHEMA_URL",
+        "https://demo.arkindex.org/api/v1/openapi/?format=openapi-json",
+    )
+    responses.add_passthru(schema_url)
+
+    # Set schema url in environment
+    os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
+    # Setup a fake worker run ID
+    os.environ["ARKINDEX_WORKER_RUN_ID"] = "1234-demo"
+    # Setup a fake corpus ID
+    os.environ["ARKINDEX_CORPUS_ID"] = "1234-corpus-id"
+
+    # Setup a mock api client instead of using a real one
+    def mock_setup_api_client(self):
+        self.api_client = MockApiClient()
+
+    monkeypatch.setattr(BaseWorker, "setup_api_client", mock_setup_api_client)
@@ -0,0 +1,12 @@
+import importlib
+
+
+def test_dummy():
+    assert True
+
+
+def test_import():
+    """Import our newly created module, through importlib to avoid parsing issues"""
+    worker = importlib.import_module("worker_demo.worker")
+    assert hasattr(worker, "Demo")
+    assert hasattr(worker.Demo, "process_element")
@@ -0,0 +1,6 @@
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
+)
@@ -0,0 +1,19 @@
+from logging import Logger, getLogger
+
+from arkindex_worker.models import Element
+from arkindex_worker.worker import ElementsWorker
+
+logger: Logger = getLogger(__name__)
+
+
+class Demo(ElementsWorker):
+    def process_element(self, element: Element) -> None:
+        logger.info(f"Demo processing element ({element.id})")
+
+
+def main() -> None:
+    Demo(description="Demo ML worker for Arkindex").run()
+
+
+if __name__ == "__main__":
+    main()
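
The new files above add a pytest setup (fake environment plus a mocked API client) and a minimal worker_demo template. A hypothetical local smoke test of that template, assuming the worker can be instantiated outside a real Arkindex process; the element payload below is made up for illustration:

from arkindex_worker.models import Element

from worker_demo.worker import Demo

# Call the processing hook directly instead of Demo(...).run(),
# which would need a live Arkindex API and worker run
worker = Demo(description="Demo ML worker for Arkindex")
worker.process_element(Element({"id": "1234-demo-element", "type": "page"}))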