arkindex-base-worker 0.3.7rc5__py3-none-any.whl → 0.5.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/METADATA +18 -19
  2. arkindex_base_worker-0.5.0a1.dist-info/RECORD +61 -0
  3. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/WHEEL +1 -1
  4. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/top_level.txt +2 -0
  5. arkindex_worker/cache.py +1 -1
  6. arkindex_worker/image.py +167 -2
  7. arkindex_worker/models.py +18 -0
  8. arkindex_worker/utils.py +98 -4
  9. arkindex_worker/worker/__init__.py +117 -218
  10. arkindex_worker/worker/base.py +39 -46
  11. arkindex_worker/worker/classification.py +34 -18
  12. arkindex_worker/worker/corpus.py +86 -0
  13. arkindex_worker/worker/dataset.py +89 -26
  14. arkindex_worker/worker/element.py +352 -91
  15. arkindex_worker/worker/entity.py +13 -11
  16. arkindex_worker/worker/image.py +21 -0
  17. arkindex_worker/worker/metadata.py +26 -16
  18. arkindex_worker/worker/process.py +92 -0
  19. arkindex_worker/worker/task.py +5 -4
  20. arkindex_worker/worker/training.py +25 -10
  21. arkindex_worker/worker/transcription.py +89 -68
  22. arkindex_worker/worker/version.py +3 -1
  23. hooks/pre_gen_project.py +3 -0
  24. tests/__init__.py +8 -0
  25. tests/conftest.py +47 -58
  26. tests/test_base_worker.py +212 -12
  27. tests/test_dataset_worker.py +294 -437
  28. tests/test_elements_worker/{test_classifications.py → test_classification.py} +216 -100
  29. tests/test_elements_worker/test_cli.py +3 -11
  30. tests/test_elements_worker/test_corpus.py +168 -0
  31. tests/test_elements_worker/test_dataset.py +106 -157
  32. tests/test_elements_worker/test_element.py +427 -0
  33. tests/test_elements_worker/test_element_create_multiple.py +715 -0
  34. tests/test_elements_worker/test_element_create_single.py +528 -0
  35. tests/test_elements_worker/test_element_list_children.py +969 -0
  36. tests/test_elements_worker/test_element_list_parents.py +530 -0
  37. tests/test_elements_worker/{test_entities.py → test_entity_create.py} +37 -195
  38. tests/test_elements_worker/test_entity_list_and_check.py +160 -0
  39. tests/test_elements_worker/test_image.py +66 -0
  40. tests/test_elements_worker/test_metadata.py +252 -161
  41. tests/test_elements_worker/test_process.py +89 -0
  42. tests/test_elements_worker/test_task.py +8 -18
  43. tests/test_elements_worker/test_training.py +17 -8
  44. tests/test_elements_worker/test_transcription_create.py +873 -0
  45. tests/test_elements_worker/test_transcription_create_with_elements.py +951 -0
  46. tests/test_elements_worker/test_transcription_list.py +450 -0
  47. tests/test_elements_worker/test_version.py +60 -0
  48. tests/test_elements_worker/test_worker.py +578 -293
  49. tests/test_image.py +542 -209
  50. tests/test_merge.py +1 -2
  51. tests/test_utils.py +89 -4
  52. worker-demo/tests/__init__.py +0 -0
  53. worker-demo/tests/conftest.py +32 -0
  54. worker-demo/tests/test_worker.py +12 -0
  55. worker-demo/worker_demo/__init__.py +6 -0
  56. worker-demo/worker_demo/worker.py +19 -0
  57. arkindex_base_worker-0.3.7rc5.dist-info/RECORD +0 -41
  58. tests/test_elements_worker/test_elements.py +0 -2713
  59. tests/test_elements_worker/test_transcriptions.py +0 -2119
  60. {arkindex_base_worker-0.3.7rc5.dist-info → arkindex_base_worker-0.5.0a1.dist-info}/LICENSE +0 -0
@@ -0,0 +1,168 @@
1
+ import re
2
+ import uuid
3
+
4
+ import pytest
5
+
6
+ from arkindex.exceptions import ErrorResponse
7
+ from arkindex_worker.worker.corpus import CorpusExportState
8
+ from tests import CORPUS_ID
9
+ from tests.test_elements_worker import BASE_API_CALLS
10
+
11
+
12
+ def test_download_export_not_a_uuid(responses, mock_elements_worker):
13
+ with pytest.raises(ValueError, match="export_id is not a valid uuid."):
14
+ mock_elements_worker.download_export("mon export")
15
+
16
+
17
+ def test_download_export(responses, mock_elements_worker):
18
+ responses.add(
19
+ responses.GET,
20
+ "http://testserver/api/v1/export/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
21
+ status=302,
22
+ body=b"some SQLite export",
23
+ content_type="application/x-sqlite3",
24
+ stream=True,
25
+ )
26
+
27
+ export = mock_elements_worker.download_export(
28
+ "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"
29
+ )
30
+ assert export.name == "/tmp/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff"
31
+
32
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
33
+ assert [
34
+ (call.request.method, call.request.url) for call in responses.calls
35
+ ] == BASE_API_CALLS + [
36
+ (
37
+ "GET",
38
+ "http://testserver/api/v1/export/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeffff/",
39
+ ),
40
+ ]
41
+
42
+
43
+ def mock_list_exports_call(responses, export_id):
44
+ responses.add(
45
+ responses.GET,
46
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/",
47
+ status=200,
48
+ json={
49
+ "count": len(CorpusExportState),
50
+ "next": None,
51
+ "results": [
52
+ {
53
+ "id": str(uuid.uuid4())
54
+ if state != CorpusExportState.Done
55
+ else export_id,
56
+ "created": "2019-08-24T14:15:22Z",
57
+ "updated": "2019-08-24T14:15:22Z",
58
+ "corpus_id": CORPUS_ID,
59
+ "user": {
60
+ "id": 0,
61
+ "email": "user@example.com",
62
+ "display_name": "User",
63
+ },
64
+ "state": state.value,
65
+ "source": "default",
66
+ }
67
+ for state in CorpusExportState
68
+ ],
69
+ },
70
+ )
71
+
72
+
73
+ def test_download_latest_export_list_error(responses, mock_elements_worker):
74
+ responses.add(
75
+ responses.GET,
76
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/",
77
+ status=418,
78
+ )
79
+
80
+ with pytest.raises(
81
+ Exception, match="Stopping pagination as data will be incomplete"
82
+ ):
83
+ mock_elements_worker.download_latest_export()
84
+
85
+ assert len(responses.calls) == len(BASE_API_CALLS) + 5
86
+ assert [
87
+ (call.request.method, call.request.url) for call in responses.calls
88
+ ] == BASE_API_CALLS + [
89
+ # The API call is retried 5 times
90
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
91
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
92
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
93
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
94
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
95
+ ]
96
+
97
+
98
+ def test_download_latest_export_no_available_exports(responses, mock_elements_worker):
99
+ responses.add(
100
+ responses.GET,
101
+ f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/",
102
+ status=200,
103
+ json={
104
+ "count": 0,
105
+ "next": None,
106
+ "results": [],
107
+ },
108
+ )
109
+
110
+ with pytest.raises(
111
+ AssertionError,
112
+ match=re.escape(
113
+ f'No available exports found for the corpus ({CORPUS_ID}) with state "Done".'
114
+ ),
115
+ ):
116
+ mock_elements_worker.download_latest_export()
117
+
118
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
119
+ assert [
120
+ (call.request.method, call.request.url) for call in responses.calls
121
+ ] == BASE_API_CALLS + [
122
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
123
+ ]
124
+
125
+
126
+ def test_download_latest_export_download_error(responses, mock_elements_worker):
127
+ export_id = str(uuid.uuid4())
128
+ mock_list_exports_call(responses, export_id)
129
+ responses.add(
130
+ responses.GET,
131
+ f"http://testserver/api/v1/export/{export_id}/",
132
+ status=418,
133
+ )
134
+
135
+ with pytest.raises(ErrorResponse):
136
+ mock_elements_worker.download_latest_export()
137
+
138
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
139
+ assert [
140
+ (call.request.method, call.request.url) for call in responses.calls
141
+ ] == BASE_API_CALLS + [
142
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
143
+ ("GET", f"http://testserver/api/v1/export/{export_id}/"),
144
+ ]
145
+
146
+
147
+ def test_download_latest_export(responses, mock_elements_worker):
148
+ export_id = str(uuid.uuid4())
149
+ mock_list_exports_call(responses, export_id)
150
+ responses.add(
151
+ responses.GET,
152
+ f"http://testserver/api/v1/export/{export_id}/",
153
+ status=302,
154
+ body=b"some SQLite export",
155
+ content_type="application/x-sqlite3",
156
+ stream=True,
157
+ )
158
+
159
+ export = mock_elements_worker.download_latest_export()
160
+ assert export.name == f"/tmp/{export_id}"
161
+
162
+ assert len(responses.calls) == len(BASE_API_CALLS) + 2
163
+ assert [
164
+ (call.request.method, call.request.url) for call in responses.calls
165
+ ] == BASE_API_CALLS + [
166
+ ("GET", f"http://testserver/api/v1/corpus/{CORPUS_ID}/export/"),
167
+ ("GET", f"http://testserver/api/v1/export/{export_id}/"),
168
+ ]
@@ -2,15 +2,15 @@ import json
2
2
  import logging
3
3
 
4
4
  import pytest
5
- from apistar.exceptions import ErrorResponse
6
5
 
7
- from arkindex_worker.models import Dataset
6
+ from arkindex.exceptions import ErrorResponse
7
+ from arkindex_worker.models import Dataset, Element, Set
8
8
  from arkindex_worker.worker.dataset import DatasetState
9
- from tests.conftest import PROCESS_ID
9
+ from tests import PROCESS_ID
10
10
  from tests.test_elements_worker import BASE_API_CALLS
11
11
 
12
12
 
13
- def test_list_process_datasets_readonly_error(mock_dataset_worker):
13
+ def test_list_process_sets_readonly_error(mock_dataset_worker):
14
14
  # Set worker in read_only mode
15
15
  mock_dataset_worker.worker_run_id = None
16
16
  assert mock_dataset_worker.is_read_only
@@ -18,85 +18,91 @@ def test_list_process_datasets_readonly_error(mock_dataset_worker):
18
18
  with pytest.raises(
19
19
  AssertionError, match="This helper is not available in read-only mode."
20
20
  ):
21
- mock_dataset_worker.list_process_datasets()
21
+ mock_dataset_worker.list_process_sets()
22
22
 
23
23
 
24
- def test_list_process_datasets_api_error(responses, mock_dataset_worker):
24
+ def test_list_process_sets_api_error(responses, mock_dataset_worker):
25
25
  responses.add(
26
26
  responses.GET,
27
- f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
28
- status=500,
27
+ f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
28
+ status=418,
29
29
  )
30
30
 
31
31
  with pytest.raises(
32
32
  Exception, match="Stopping pagination as data will be incomplete"
33
33
  ):
34
- next(mock_dataset_worker.list_process_datasets())
34
+ next(mock_dataset_worker.list_process_sets())
35
35
 
36
36
  assert len(responses.calls) == len(BASE_API_CALLS) + 5
37
37
  assert [
38
38
  (call.request.method, call.request.url) for call in responses.calls
39
39
  ] == BASE_API_CALLS + [
40
40
  # The API call is retried 5 times
41
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
42
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
43
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
44
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
45
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
41
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
42
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
43
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
44
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
45
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
46
46
  ]
47
47
 
48
48
 
49
- def test_list_process_datasets(
49
+ def test_list_process_sets(
50
50
  responses,
51
51
  mock_dataset_worker,
52
52
  ):
53
53
  expected_results = [
54
54
  {
55
- "id": "process_dataset_1",
55
+ "id": "set_1",
56
56
  "dataset": {
57
57
  "id": "dataset_1",
58
58
  "name": "Dataset 1",
59
59
  "description": "My first great dataset",
60
- "sets": ["train", "val", "test"],
60
+ "sets": [
61
+ {"id": "set_1", "name": "train"},
62
+ {"id": "set_2", "name": "val"},
63
+ ],
61
64
  "state": "open",
62
65
  "corpus_id": "corpus_id",
63
66
  "creator": "test@teklia.com",
64
67
  "task_id": "task_id_1",
65
68
  },
66
- "sets": ["test"],
69
+ "set_name": "train",
67
70
  },
68
71
  {
69
- "id": "process_dataset_2",
72
+ "id": "set_2",
70
73
  "dataset": {
71
- "id": "dataset_2",
72
- "name": "Dataset 2",
73
- "description": "My second great dataset",
74
- "sets": ["train", "val"],
75
- "state": "complete",
74
+ "id": "dataset_1",
75
+ "name": "Dataset 1",
76
+ "description": "My first great dataset",
77
+ "sets": [
78
+ {"id": "set_1", "name": "train"},
79
+ {"id": "set_2", "name": "val"},
80
+ ],
81
+ "state": "open",
76
82
  "corpus_id": "corpus_id",
77
83
  "creator": "test@teklia.com",
78
- "task_id": "task_id_2",
84
+ "task_id": "task_id_1",
79
85
  },
80
- "sets": ["train", "val"],
86
+ "set_name": "val",
81
87
  },
82
88
  {
83
- "id": "process_dataset_3",
89
+ "id": "set_3",
84
90
  "dataset": {
85
- "id": "dataset_3",
86
- "name": "Dataset 3 (TRASHME)",
87
- "description": "My third dataset, in error",
88
- "sets": ["nonsense", "random set"],
89
- "state": "error",
91
+ "id": "dataset_2",
92
+ "name": "Dataset 2",
93
+ "description": "My second great dataset",
94
+ "sets": [{"id": "set_3", "name": "my_set"}],
95
+ "state": "complete",
90
96
  "corpus_id": "corpus_id",
91
97
  "creator": "test@teklia.com",
92
- "task_id": "task_id_3",
98
+ "task_id": "task_id_2",
93
99
  },
94
- "sets": ["random set"],
100
+ "set_name": "my_set",
95
101
  },
96
102
  ]
97
103
  responses.add(
98
104
  responses.GET,
99
- f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/",
105
+ f"http://testserver/api/v1/process/{PROCESS_ID}/sets/",
100
106
  status=200,
101
107
  json={
102
108
  "count": 3,
@@ -105,67 +111,54 @@ def test_list_process_datasets(
105
111
  },
106
112
  )
107
113
 
108
- for idx, dataset in enumerate(mock_dataset_worker.list_process_datasets()):
109
- assert isinstance(dataset, Dataset)
110
- assert dataset == {
111
- **expected_results[idx]["dataset"],
112
- "selected_sets": expected_results[idx]["sets"],
113
- }
114
+ for idx, dataset_set in enumerate(mock_dataset_worker.list_process_sets()):
115
+ assert isinstance(dataset_set, Set)
116
+ assert dataset_set.name == expected_results[idx]["set_name"]
117
+
118
+ assert isinstance(dataset_set.dataset, Dataset)
119
+ assert dataset_set.dataset == expected_results[idx]["dataset"]
114
120
 
115
121
  assert len(responses.calls) == len(BASE_API_CALLS) + 1
116
122
  assert [
117
123
  (call.request.method, call.request.url) for call in responses.calls
118
124
  ] == BASE_API_CALLS + [
119
- ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/datasets/"),
125
+ ("GET", f"http://testserver/api/v1/process/{PROCESS_ID}/sets/"),
120
126
  ]
121
127
 
122
128
 
123
129
  @pytest.mark.parametrize(
124
130
  ("payload", "error"),
125
131
  [
126
- # Dataset
132
+ # Set
127
133
  (
128
- {"dataset": None},
129
- "dataset shouldn't be null and should be a Dataset",
134
+ {"dataset_set": None},
135
+ "dataset_set shouldn't be null and should be a Set",
130
136
  ),
131
137
  (
132
- {"dataset": "not Dataset type"},
133
- "dataset shouldn't be null and should be a Dataset",
138
+ {"dataset_set": "not Set type"},
139
+ "dataset_set shouldn't be null and should be a Set",
134
140
  ),
135
141
  ],
136
142
  )
137
- def test_list_dataset_elements_wrong_param_dataset(mock_dataset_worker, payload, error):
143
+ def test_list_set_elements_wrong_param_dataset_set(mock_dataset_worker, payload, error):
138
144
  with pytest.raises(AssertionError, match=error):
139
- mock_dataset_worker.list_dataset_elements(**payload)
145
+ mock_dataset_worker.list_set_elements(**payload)
140
146
 
141
147
 
142
- @pytest.mark.parametrize(
143
- "sets",
144
- [
145
- ["set_1"],
146
- ["set_1", "set_2", "set_3"],
147
- ["set_1", "set_2", "set_3", "set_4"],
148
- ],
149
- )
150
- def test_list_dataset_elements_api_error(
151
- responses, mock_dataset_worker, sets, default_dataset
148
+ def test_list_set_elements_api_error(
149
+ responses, mock_dataset_worker, default_dataset, default_train_set
152
150
  ):
153
- default_dataset.selected_sets = sets
154
- query_params = (
155
- "?with_count=true"
156
- if sets == default_dataset.sets
157
- else "?set=set_1&with_count=true"
158
- )
151
+ query_params = f"?set={default_train_set.name}&with_count=true"
159
152
  responses.add(
160
153
  responses.GET,
161
154
  f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
162
- status=500,
155
+ status=418,
163
156
  )
164
157
 
165
158
  with pytest.raises(
166
159
  Exception, match="Stopping pagination as data will be incomplete"
167
160
  ):
168
- next(mock_dataset_worker.list_dataset_elements(dataset=default_dataset))
161
+ next(mock_dataset_worker.list_set_elements(dataset_set=default_train_set))
169
162
 
170
163
  assert len(responses.calls) == len(BASE_API_CALLS) + 5
171
164
  assert [
@@ -195,99 +188,60 @@ def test_list_dataset_elements_api_error(
195
188
  ]
196
189
 
197
190
 
198
- @pytest.mark.parametrize(
199
- "sets",
200
- [
201
- ["set_1"],
202
- ["set_1", "set_2", "set_3"],
203
- ["set_1", "set_2", "set_3", "set_4"],
204
- ],
205
- )
206
- def test_list_dataset_elements(
191
+ def test_list_set_elements(
207
192
  responses,
208
193
  mock_dataset_worker,
209
- sets,
210
194
  default_dataset,
195
+ default_train_set,
211
196
  ):
212
- default_dataset.selected_sets = sets
213
-
214
- dataset_elements = []
215
- for split in default_dataset.sets:
216
- index = split[-1]
217
- dataset_elements.append(
218
- {
219
- "set": split,
220
- "element": {
221
- "id": str(index) * 4,
222
- "type": "page",
223
- "name": f"Test {index}",
224
- "corpus": {},
225
- "thumbnail_url": None,
226
- "zone": {},
227
- "best_classes": None,
228
- "has_children": None,
229
- "worker_version_id": None,
230
- "worker_run_id": None,
231
- },
232
- }
233
- )
234
- if split == "set_1":
235
- dataset_elements.append({**dataset_elements[-1]})
236
- dataset_elements[-1]["element"]["name"] = f"Test {index} (bis)"
237
-
238
- # All sets are selected, we call the unfiltered endpoint once
239
- if default_dataset.sets == default_dataset.selected_sets:
240
- expected_results = dataset_elements
241
- responses.add(
242
- responses.GET,
243
- f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
244
- status=200,
245
- json={
246
- "count": len(expected_results),
247
- "next": None,
248
- "results": expected_results,
197
+ expected_results = [
198
+ {
199
+ "set": "train",
200
+ "element": {
201
+ "id": "element_1",
202
+ "type": "page",
203
+ "name": "1",
204
+ "corpus": {},
205
+ "thumbnail_url": None,
206
+ "zone": {},
207
+ "best_classes": None,
208
+ "has_children": None,
209
+ "worker_version_id": None,
210
+ "worker_run_id": None,
249
211
  },
250
- )
251
- expected_calls = [
252
- f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true"
253
- ]
254
-
255
- # Not all sets are selected, we call the filtered endpoint multiple times, once per set
256
- else:
257
- expected_results, expected_calls = [], []
258
- for selected_set in default_dataset.selected_sets:
259
- partial_results = [
260
- element
261
- for element in dataset_elements
262
- if element["set"] == selected_set
263
- ]
264
- expected_results += partial_results
265
- responses.add(
266
- responses.GET,
267
- f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
268
- status=200,
269
- json={
270
- "count": len(partial_results),
271
- "next": None,
272
- "results": partial_results,
273
- },
274
- )
275
- expected_calls += [
276
- f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true"
277
- ]
212
+ }
213
+ ]
214
+ expected_results.append({**expected_results[-1]})
215
+ expected_results[-1]["element"]["id"] = "element_2"
216
+ expected_results[-1]["element"]["name"] = "2"
217
+
218
+ query_params = f"?set={default_train_set.name}&with_count=true"
219
+ responses.add(
220
+ responses.GET,
221
+ f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
222
+ status=200,
223
+ json={
224
+ "count": 2,
225
+ "next": None,
226
+ "results": expected_results,
227
+ },
228
+ )
278
229
 
279
230
  for idx, element in enumerate(
280
- mock_dataset_worker.list_dataset_elements(dataset=default_dataset)
231
+ mock_dataset_worker.list_set_elements(dataset_set=default_train_set)
281
232
  ):
282
- assert element == (
283
- expected_results[idx]["set"],
284
- expected_results[idx]["element"],
285
- )
233
+ assert isinstance(element, Element)
234
+ assert element == expected_results[idx]["element"]
286
235
 
287
- assert len(responses.calls) == len(BASE_API_CALLS) + len(expected_calls)
236
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
288
237
  assert [
289
238
  (call.request.method, call.request.url) for call in responses.calls
290
- ] == BASE_API_CALLS + [("GET", expected_call) for expected_call in expected_calls]
239
+ ] == BASE_API_CALLS + [
240
+ (
241
+ "GET",
242
+ f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/{query_params}",
243
+ )
244
+ ]
291
245
 
292
246
 
293
247
  @pytest.mark.parametrize(
@@ -367,7 +321,7 @@ def test_update_dataset_state_api_error(
367
321
  responses.add(
368
322
  responses.PATCH,
369
323
  f"http://testserver/api/v1/datasets/{default_dataset.id}/",
370
- status=500,
324
+ status=418,
371
325
  )
372
326
 
373
327
  with pytest.raises(ErrorResponse):
@@ -376,16 +330,11 @@ def test_update_dataset_state_api_error(
376
330
  state=DatasetState.Building,
377
331
  )
378
332
 
379
- assert len(responses.calls) == len(BASE_API_CALLS) + 5
333
+ assert len(responses.calls) == len(BASE_API_CALLS) + 1
380
334
  assert [
381
335
  (call.request.method, call.request.url) for call in responses.calls
382
336
  ] == BASE_API_CALLS + [
383
- # We retry 5 times the API call
384
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
385
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
386
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
387
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
388
- ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/"),
337
+ ("PATCH", f"http://testserver/api/v1/datasets/{default_dataset.id}/")
389
338
  ]
390
339
 
391
340