arkindex-base-worker 0.3.7rc3__tar.gz → 0.3.7rc5__tar.gz

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (49)
  1. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/PKG-INFO +1 -1
  2. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_base_worker.egg-info/PKG-INFO +1 -1
  3. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/image.py +26 -19
  4. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/models.py +2 -2
  5. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/utils.py +4 -3
  6. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/__init__.py +9 -6
  7. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/base.py +1 -0
  8. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/classification.py +18 -18
  9. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/dataset.py +14 -8
  10. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/element.py +1 -0
  11. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/metadata.py +1 -1
  12. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/version.py +1 -0
  13. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/pyproject.toml +5 -3
  14. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_dataset_worker.py +59 -105
  15. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_classifications.py +365 -539
  16. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_dataset.py +97 -103
  17. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_elements.py +26 -14
  18. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_transcriptions.py +15 -8
  19. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_worker.py +5 -4
  20. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_image.py +37 -0
  21. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/LICENSE +0 -0
  22. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/README.md +0 -0
  23. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_base_worker.egg-info/SOURCES.txt +0 -0
  24. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_base_worker.egg-info/dependency_links.txt +0 -0
  25. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_base_worker.egg-info/requires.txt +0 -0
  26. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_base_worker.egg-info/top_level.txt +0 -0
  27. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/__init__.py +0 -0
  28. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/cache.py +0 -0
  29. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/entity.py +0 -0
  30. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/task.py +0 -0
  31. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/training.py +0 -0
  32. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/arkindex_worker/worker/transcription.py +0 -0
  33. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/docs-requirements.txt +0 -0
  34. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/requirements.txt +0 -0
  35. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/setup.cfg +0 -0
  36. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/setup.py +0 -0
  37. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/__init__.py +0 -0
  38. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/conftest.py +0 -0
  39. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_base_worker.py +0 -0
  40. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_cache.py +0 -0
  41. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_element.py +0 -0
  42. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/__init__.py +0 -0
  43. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_cli.py +0 -0
  44. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_entities.py +0 -0
  45. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_metadata.py +0 -0
  46. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_task.py +0 -0
  47. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_elements_worker/test_training.py +0 -0
  48. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_merge.py +0 -0
  49. {arkindex-base-worker-0.3.7rc3 → arkindex-base-worker-0.3.7rc5}/tests/test_utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: arkindex-base-worker
- Version: 0.3.7rc3
+ Version: 0.3.7rc5
  Summary: Base Worker to easily build Arkindex ML workflows
  Author-email: Teklia <contact@teklia.com>
  Maintainer-email: Teklia <contact@teklia.com>
arkindex_base_worker.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: arkindex-base-worker
- Version: 0.3.7rc3
+ Version: 0.3.7rc5
  Summary: Base Worker to easily build Arkindex ML workflows
  Author-email: Teklia <contact@teklia.com>
  Maintainer-email: Teklia <contact@teklia.com>
arkindex_worker/image.py
@@ -1,6 +1,7 @@
  """
  Helper methods to download and open IIIF images, and manage polygons.
  """
+
  import re
  from collections import namedtuple
  from io import BytesIO
@@ -114,32 +115,38 @@ def download_image(url: str) -> Image:
              )
          else:
              raise e
-     except requests.exceptions.SSLError:
-         logger.warning(
-             "An SSLError occurred during image download, retrying with a weaker and unsafe SSL configuration"
-         )
-
-         # Saving current ciphers
-         previous_ciphers = requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS
-
-         # Downgrading ciphers to download the image
-         requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL:@SECLEVEL=1"
-         resp = _retried_request(url)
-
-         # Restoring previous ciphers
-         requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = previous_ciphers

      # Preprocess the image and prepare it for classification
      image = Image.open(BytesIO(resp.content))
      logger.info(
-         "Downloaded image {} - size={}x{} in {}".format(
-             url, image.size[0], image.size[1], resp.elapsed
-         )
+         f"Downloaded image {url} - size={image.size[0]}x{image.size[1]} in {resp.elapsed}"
      )

      return image


+ def upload_image(image: Image, url: str) -> requests.Response:
+     """
+     Upload a Pillow image to a URL.
+
+     :param image: Pillow image to upload.
+     :param url: Destination URL.
+     :returns: The upload response.
+     """
+     assert url.startswith("http"), "Destination URL for the image must be HTTP(S)"
+
+     # Retrieve a binarized version of the image
+     image_bytes = BytesIO()
+     image.save(image_bytes, format="jpeg")
+     image_bytes.seek(0)
+
+     # Upload the image
+     resp = _retried_request(url, method=requests.put, data=image_bytes)
+     logger.info(f"Uploaded image to {url} in {resp.elapsed}")
+
+     return resp
+
+
  def polygon_bounding_box(polygon: list[list[int | float]]) -> BoundingBox:
      """
      Compute the rectangle bounding box of a polygon.
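The two image.py changes above drop the unsafe SSL cipher-downgrade fallback and add upload_image. As a minimal round-trip sketch (the URLs below are placeholders, not real endpoints; note that upload_image always re-encodes the image as JPEG before the PUT):

    from arkindex_worker.image import download_image, upload_image

    # Hypothetical URLs, for illustration only
    iiif_url = "https://iiif.example.com/image/full/full/0/default.jpg"
    destination_url = "https://storage.example.com/processed.jpg"

    image = download_image(iiif_url)  # retried GET, returns a Pillow image
    image = image.convert("L")  # any intermediate processing step
    resp = upload_image(image, destination_url)  # JPEG-encoded, then PUT
    assert resp.ok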
@@ -167,8 +174,8 @@ def _retry_log(retry_state, *args, **kwargs):
      before_sleep=_retry_log,
      reraise=True,
  )
- def _retried_request(url):
-     resp = requests.get(url, timeout=DOWNLOAD_TIMEOUT)
+ def _retried_request(url, *args, method=requests.get, **kwargs):
+     resp = method(url, *args, timeout=DOWNLOAD_TIMEOUT, **kwargs)
      resp.raise_for_status()
      return resp

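The helper keeps a retried GET as its default, but now accepts any requests method plus extra keyword arguments, which is exactly how upload_image issues its PUT above. A sketch of both call shapes (_retried_request is module-internal, and the URLs are placeholders):

    import requests
    from io import BytesIO

    from arkindex_worker.image import _retried_request  # private helper

    # Unchanged default behaviour: a retried GET
    resp = _retried_request("https://example.com/image.jpg")

    # New: any requests method, with its keyword arguments passed through
    resp = _retried_request(
        "https://example.com/upload", method=requests.put, data=BytesIO(b"...")
    )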
arkindex_worker/models.py
@@ -75,10 +75,10 @@ class Element(MagicDict):

      def image_url(self, size: str = "full") -> str | None:
          """
-         Build an URL to access the image.
+         Build a URL to access the image.
          When possible, will return the S3 URL for images, so an ML worker can bypass IIIF servers.
          :param size: Subresolution of the image, following the syntax of the IIIF resize parameter.
-         :returns: An URL to the image, or None if the element does not have an image.
+         :returns: A URL to the image, or None if the element does not have an image.
          """
          if not self.get("zone"):
              return
arkindex_worker/utils.py
@@ -31,9 +31,10 @@ def decompress_zst_archive(compressed_archive: Path) -> tuple[int, Path]:

      logger.debug(f"Uncompressing file to {archive_path}")
      try:
-         with compressed_archive.open("rb") as compressed, archive_path.open(
-             "wb"
-         ) as decompressed:
+         with (
+             compressed_archive.open("rb") as compressed,
+             archive_path.open("wb") as decompressed,
+         ):
              dctx.copy_stream(compressed, decompressed)
          logger.debug(f"Successfully uncompressed archive {compressed_archive}")
      except zstandard.ZstdError as e:
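The rewritten with statement uses parenthesized context managers, a syntax only officially supported from Python 3.10 onwards; behaviour is unchanged. A standalone sketch of the same pattern, with placeholder paths:

    from pathlib import Path

    src, dst = Path("archive.zst"), Path("archive.tar")  # placeholders
    with (
        src.open("rb") as compressed,
        dst.open("wb") as decompressed,
    ):
        decompressed.write(compressed.read())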
arkindex_worker/worker/__init__.py
@@ -1,6 +1,7 @@
  """
  Base classes to implement Arkindex workers.
  """
+
  import contextlib
  import json
  import os
@@ -229,12 +230,13 @@ class ElementsWorker(
                  with contextlib.suppress(Exception):
                      self.update_activity(element.id, ActivityState.Error)

+         message = f'Ran on {count} element{"s"[:count>1]}: {count - failed} completed, {failed} failed'
          if failed:
-             logger.error(
-                 f"Ran on {count} elements: {count - failed} completed, {failed} failed"
-             )
+             logger.error(message)
              if failed >= count:  # Everything failed!
                  sys.exit(1)
+         else:
+             logger.info(message)

      def process_element(self, element: Element | CachedElement):
          """
@@ -504,9 +506,10 @@ class DatasetWorker(BaseWorker, DatasetMixin, TaskMixin):
              if dataset_artifact:
                  dataset_artifact.unlink(missing_ok=True)

+         message = f'Ran on {count} dataset{"s"[:count>1]}: {count - failed} completed, {failed} failed'
          if failed:
-             logger.error(
-                 f"Ran on {count} datasets: {count - failed} completed, {failed} failed"
-             )
+             logger.error(message)
              if failed >= count:  # Everything failed!
                  sys.exit(1)
+         else:
+             logger.info(message)
arkindex_worker/worker/base.py
@@ -1,6 +1,7 @@
  """
  The base class for all Arkindex workers.
  """
+
  import argparse
  import json
  import logging
arkindex_worker/worker/classification.py
@@ -2,8 +2,6 @@
  ElementsWorker methods for classifications and ML classes.
  """

- from uuid import UUID
-
  from apistar.exceptions import ErrorResponse
  from peewee import IntegrityError

@@ -178,10 +176,14 @@ class ClassificationMixin:
          Create multiple classifications at once on the given element through the API.

          :param element: The element to create classifications on.
-         :param classifications: The classifications to create, a list of dicts. Each of them contains
-             a **ml_class_id** (str), the ID of the MLClass for this classification;
-             a **confidence** (float), the confidence score, between 0 and 1;
-             a **high_confidence** (bool), the high confidence state of the classification.
+         :param classifications: A list of dicts representing a classification each, with the following keys:
+
+             ml_class (str)
+                 Required. Name of the MLClass to use.
+             confidence (float)
+                 Required. Confidence score for the classification. Must be between 0 and 1.
+             high_confidence (bool)
+                 Optional. Whether or not the classification is of high confidence.

          :returns: List of created classifications, as returned in the ``classifications`` field by
              the ``CreateClassifications`` API endpoint.
@@ -194,18 +196,10 @@ class ClassificationMixin:
          ), "classifications shouldn't be null and should be of type list"

          for index, classification in enumerate(classifications):
-             ml_class_id = classification.get("ml_class_id")
+             ml_class = classification.get("ml_class")
              assert (
-                 ml_class_id and isinstance(ml_class_id, str)
-             ), f"Classification at index {index} in classifications: ml_class_id shouldn't be null and should be of type str"
-
-             # Make sure it's a valid UUID
-             try:
-                 UUID(ml_class_id)
-             except ValueError as e:
-                 raise ValueError(
-                     f"Classification at index {index} in classifications: ml_class_id is not a valid uuid."
-                 ) from e
+                 ml_class and isinstance(ml_class, str)
+             ), f"Classification at index {index} in classifications: ml_class shouldn't be null and should be of type str"

              confidence = classification.get("confidence")
              assert (
@@ -231,7 +225,13 @@ class ClassificationMixin:
              body={
                  "parent": str(element.id),
                  "worker_run_id": self.worker_run_id,
-                 "classifications": classifications,
+                 "classifications": [
+                     {
+                         **classification,
+                         "ml_class": self.get_ml_class_id(classification["ml_class"]),
+                     }
+                     for classification in classifications
+                 ],
              },
          )["classifications"]

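Callers now pass MLClass names rather than UUIDs, and the mixin resolves each name through get_ml_class_id before calling the CreateClassifications endpoint. A call sketch inside an ElementsWorker subclass; the class names here are made up for the example:

    created = self.create_classifications(
        element,  # comes from the worker's processing loop
        classifications=[
            {"ml_class": "table", "confidence": 0.95, "high_confidence": True},
            {"ml_class": "illustration", "confidence": 0.42},
        ],
    )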
arkindex_worker/worker/dataset.py
@@ -51,7 +51,7 @@ class DatasetMixin:

          return map(
              lambda result: Dataset(**result["dataset"], selected_sets=result["sets"]),
-             list(results),
+             results,
          )

      def list_dataset_elements(self, dataset: Dataset) -> Iterator[tuple[str, Element]]:
@@ -65,14 +65,20 @@ class DatasetMixin:

              dataset, Dataset
          ), "dataset shouldn't be null and should be a Dataset"
-         results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
+         if dataset.sets == dataset.selected_sets:
+             results = self.api_client.paginate("ListDatasetElements", id=dataset.id)
+         else:
+             results = iter(
+                 element
+                 for selected_set in dataset.selected_sets
+                 for element in self.api_client.paginate(
+                     "ListDatasetElements", id=dataset.id, set=selected_set
+                 )
+             )

-         def format_result(result):
-             if result["set"] not in dataset.selected_sets:
-                 return
-             return (result["set"], Element(**result["element"]))
-
-         return filter(None, map(format_result, list(results)))
+         return map(
+             lambda result: (result["set"], Element(**result["element"])), results
+         )

      @unsupported_cache
      def update_dataset_state(self, dataset: Dataset, state: DatasetState) -> Dataset:
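When every set of the dataset is selected, a single paginated call is kept; otherwise filtering moves server-side, with one paginated call per selected set. Consumption is unchanged, since the method still returns a lazy iterator of (set name, element) pairs; worker and dataset below are assumed to already exist:

    for set_name, element in worker.list_dataset_elements(dataset):
        print(set_name, element.id)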
arkindex_worker/worker/element.py
@@ -1,6 +1,7 @@
  """
  ElementsWorker methods for elements and element types.
  """
+
  from collections.abc import Iterable
  from typing import NamedTuple
  from uuid import UUID
arkindex_worker/worker/metadata.py
@@ -50,7 +50,7 @@ class MetaType(Enum):

      URL = "url"
      """
-     A metadata with a string value that should be interpreted as an URL.
+     A metadata with a string value that should be interpreted as a URL.
      Only the ``http`` and ``https`` schemes are allowed.
      """

arkindex_worker/worker/version.py
@@ -1,6 +1,7 @@
  """
  ElementsWorker methods for worker versions.
  """
+
  import functools
  from warnings import warn

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "arkindex-base-worker"
- version = "0.3.7rc3"
+ version = "0.3.7rc5"
  description = "Base Worker to easily build Arkindex ML workflows"
  license = { file = "LICENSE" }
  dynamic = ["dependencies", "optional-dependencies"]
@@ -41,6 +41,8 @@ optional-dependencies = { docs = { file = ["docs-requirements.txt"] } }

  [tool.ruff]
  exclude = [".git", "__pycache__"]
+
+ [tool.ruff.lint]
  ignore = ["E501"]
  select = [
      # pycodestyle
@@ -68,11 +70,11 @@ select = [
      "PTH",
  ]

- [tool.ruff.per-file-ignores]
+ [tool.ruff.lint.per-file-ignores]
  # Ignore `pytest-composite-assertion` rules of `flake8-pytest-style` linter for non-test files
  "arkindex_worker/**/*.py" = ["PT018"]

- [tool.ruff.isort]
+ [tool.ruff.lint.isort]
  known-first-party = ["arkindex", "arkindex_common", "arkindex_worker"]
  known-third-party = [
      "PIL",
tests/test_dataset_worker.py
@@ -195,7 +195,7 @@ def test_list_dataset_elements_per_split_api_error(
  ):
      responses.add(
          responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
+         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          status=500,
      )

@@ -211,23 +211,23 @@ def test_list_dataset_elements_per_split_api_error(
          # The API call is retried 5 times
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          ),
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          ),
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          ),
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          ),
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
          ),
      ]

@@ -235,110 +235,60 @@ def test_list_dataset_elements_per_split_api_error(

  def test_list_dataset_elements_per_split(
      responses, mock_dataset_worker, default_dataset
  ):
-     expected_results = [
-         {
-             "set": "set_1",
-             "element": {
-                 "id": "0000",
-                 "type": "page",
-                 "name": "Test",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
-             },
-         },
-         {
-             "set": "set_1",
-             "element": {
-                 "id": "1111",
-                 "type": "page",
-                 "name": "Test 2",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
-             },
-         },
-         {
-             "set": "set_2",
-             "element": {
-                 "id": "2222",
-                 "type": "page",
-                 "name": "Test 3",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
-             },
-         },
-         {
-             "set": "set_3",
-             "element": {
-                 "id": "3333",
-                 "type": "page",
-                 "name": "Test 4",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
-             },
-         },
-         # `set_4` is not in `default_dataset.selected_sets`
-         {
-             "set": "set_4",
-             "element": {
-                 "id": "4444",
-                 "type": "page",
-                 "name": "Test 5",
-                 "corpus": {},
-                 "thumbnail_url": None,
-                 "zone": {},
-                 "best_classes": None,
-                 "has_children": None,
-                 "worker_version_id": None,
-                 "worker_run_id": None,
+     expected_results = []
+     for selected_set in default_dataset.selected_sets:
+         index = selected_set[-1]
+         expected_results.append(
+             {
+                 "set": selected_set,
+                 "element": {
+                     "id": str(index) * 4,
+                     "type": "page",
+                     "name": f"Test {index}",
+                     "corpus": {},
+                     "thumbnail_url": None,
+                     "zone": {},
+                     "best_classes": None,
+                     "has_children": None,
+                     "worker_version_id": None,
+                     "worker_run_id": None,
+                 },
+             }
+         )
+         responses.add(
+             responses.GET,
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set={selected_set}&with_count=true",
+             status=200,
+             json={
+                 "count": 1,
+                 "next": None,
+                 "results": [expected_results[-1]],
              },
-         },
-     ]
-     responses.add(
-         responses.GET,
-         f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/",
-         status=200,
-         json={
-             "count": 4,
-             "next": None,
-             "results": expected_results,
-         },
-     )
+         )

      assert list(
          mock_dataset_worker.list_dataset_elements_per_split(default_dataset)
      ) == [
-         ("set_1", [expected_results[0]["element"], expected_results[1]["element"]]),
-         ("set_2", [expected_results[2]["element"]]),
-         ("set_3", [expected_results[3]["element"]]),
+         ("set_1", [expected_results[0]["element"]]),
+         ("set_2", [expected_results[1]["element"]]),
+         ("set_3", [expected_results[2]["element"]]),
      ]

-     assert len(responses.calls) == len(BASE_API_CALLS) + 1
+     assert len(responses.calls) == len(BASE_API_CALLS) + 3
      assert [
          (call.request.method, call.request.url) for call in responses.calls
      ] == BASE_API_CALLS + [
          (
              "GET",
-             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?with_count=true",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_1&with_count=true",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_2&with_count=true",
+         ),
+         (
+             "GET",
+             f"http://testserver/api/v1/datasets/{default_dataset.id}/elements/?set=set_3&with_count=true",
          ),
      ]

@@ -360,7 +310,7 @@ def test_list_datasets_api_error(responses, mock_dataset_worker):
      with pytest.raises(
          Exception, match="Stopping pagination as data will be incomplete"
      ):
-         mock_dataset_worker.list_datasets()
+         next(mock_dataset_worker.list_datasets())

      assert len(responses.calls) == len(BASE_API_CALLS) + 5
      assert [
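Since list_datasets now appears to return a lazy iterator (the map over paginated results is no longer forced with list(), per the dataset.py hunk above), no API call happens until the iterator is consumed, hence the added next(...). Schematically:

    datasets = mock_dataset_worker.list_datasets()  # no API call made yet
    next(datasets)  # first page fetched here, so pagination errors surface here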
@@ -512,7 +462,7 @@ def test_run_initial_dataset_state_error(
          if generator
          else []
      ) + [
-         (logging.ERROR, "Ran on 1 datasets: 0 completed, 1 failed"),
+         (logging.ERROR, "Ran on 1 dataset: 0 completed, 1 failed"),
      ]

@@ -577,7 +527,7 @@ def test_run_update_dataset_state_api_error(
          ],
          (
              logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Ran on 1 dataset: 0 completed, 1 failed",
          ),
      ]

@@ -639,7 +589,7 @@ def test_run_download_dataset_artifact_api_error(
          ),
          (
              logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Ran on 1 dataset: 0 completed, 1 failed",
          ),
      ]

@@ -690,7 +640,7 @@ def test_run_no_downloaded_artifact_error(
          ),
          (
              logging.ERROR,
-             "Ran on 1 datasets: 0 completed, 1 failed",
+             "Ran on 1 dataset: 0 completed, 1 failed",
          ),
      ]

@@ -792,7 +742,9 @@ def test_run(
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.INFO, "Loaded Worker Fake worker @ 123412 from API"),
          (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-     ] + extra_logs
+         *extra_logs,
+         (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+     ]


  @pytest.mark.parametrize(
@@ -890,4 +842,6 @@ def test_run_read_only(
      assert [(level, message) for _, level, message in caplog.record_tuples] == [
          (logging.WARNING, "Running without any extra configuration"),
          (logging.INFO, "Processing Dataset (dataset_id) (1/1)"),
-     ] + extra_logs
+         *extra_logs,
+         (logging.INFO, "Ran on 1 dataset: 1 completed, 0 failed"),
+     ]