datago 2025.12.2__tar.gz → 2026.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {datago-2025.12.2 → datago-2026.1.2}/.github/workflows/ci-cd.yml +2 -2
  2. {datago-2025.12.2 → datago-2026.1.2}/.github/workflows/rust.yml +1 -1
  3. {datago-2025.12.2 → datago-2026.1.2}/Cargo.lock +1 -1
  4. {datago-2025.12.2 → datago-2026.1.2}/Cargo.toml +1 -1
  5. {datago-2025.12.2 → datago-2026.1.2}/PKG-INFO +36 -8
  6. {datago-2025.12.2 → datago-2026.1.2}/README.md +35 -7
  7. datago-2026.1.2/assets/epyc_wds_pd12m.png +0 -0
  8. datago-2026.1.2/assets/zen3_wds_fakein.png.png +0 -0
  9. datago-2026.1.2/assets/zen3_wds_pd12m.png +0 -0
  10. datago-2026.1.2/assets/zen3_wds_pd12m_processing.png +0 -0
  11. {datago-2025.12.2 → datago-2026.1.2}/python/benchmark_webdataset.py +101 -20
  12. datago-2026.1.2/python/test_datago_wds.py +250 -0
  13. datago-2026.1.2/requirements-dev.txt +3 -0
  14. {datago-2025.12.2 → datago-2026.1.2}/src/client.rs +6 -4
  15. {datago-2025.12.2 → datago-2026.1.2}/src/generator_wds.rs +42 -19
  16. {datago-2025.12.2 → datago-2026.1.2}/src/worker_wds.rs +69 -33
  17. datago-2025.12.2/assets/epyc_wds.png +0 -0
  18. {datago-2025.12.2 → datago-2026.1.2}/.gitignore +0 -0
  19. {datago-2025.12.2 → datago-2026.1.2}/.pre-commit-config.yaml +0 -0
  20. {datago-2025.12.2 → datago-2026.1.2}/LICENSE +0 -0
  21. {datago-2025.12.2 → datago-2026.1.2}/assets/447175851-2277afcb-8abf-4d17-b2db-dae27c6056d0.png +0 -0
  22. {datago-2025.12.2 → datago-2026.1.2}/assets/epyc_vast.png +0 -0
  23. {datago-2025.12.2 → datago-2026.1.2}/assets/zen3_ssd.png +0 -0
  24. {datago-2025.12.2 → datago-2026.1.2}/pyproject.toml +0 -0
  25. {datago-2025.12.2 → datago-2026.1.2}/python/benchmark_db.py +0 -0
  26. {datago-2025.12.2 → datago-2026.1.2}/python/benchmark_defaults.py +0 -0
  27. {datago-2025.12.2 → datago-2026.1.2}/python/benchmark_filesystem.py +0 -0
  28. {datago-2025.12.2 → datago-2026.1.2}/python/dataset.py +0 -0
  29. {datago-2025.12.2 → datago-2026.1.2}/python/raw_types.py +0 -0
  30. {datago-2025.12.2 → datago-2026.1.2}/python/test_datago_client.py +0 -0
  31. {datago-2025.12.2 → datago-2026.1.2}/python/test_datago_db.py +0 -0
  32. {datago-2025.12.2 → datago-2026.1.2}/python/test_datago_edge_cases.py +0 -0
  33. {datago-2025.12.2 → datago-2026.1.2}/python/test_datago_filesystem.py +0 -0
  34. {datago-2025.12.2 → datago-2026.1.2}/python/test_pil_implicit_conversion.py +0 -0
  35. {datago-2025.12.2 → datago-2026.1.2}/requirements-tests.txt +0 -0
  36. {datago-2025.12.2 → datago-2026.1.2}/requirements.txt +0 -0
  37. {datago-2025.12.2 → datago-2026.1.2}/src/generator_files.rs +0 -0
  38. {datago-2025.12.2 → datago-2026.1.2}/src/generator_http.rs +0 -0
  39. {datago-2025.12.2 → datago-2026.1.2}/src/image_processing.rs +0 -0
  40. {datago-2025.12.2 → datago-2026.1.2}/src/lib.rs +0 -0
  41. {datago-2025.12.2 → datago-2026.1.2}/src/main.rs +0 -0
  42. {datago-2025.12.2 → datago-2026.1.2}/src/structs.rs +0 -0
  43. {datago-2025.12.2 → datago-2026.1.2}/src/worker_files.rs +0 -0
  44. {datago-2025.12.2 → datago-2026.1.2}/src/worker_http.rs +0 -0
{datago-2025.12.2 → datago-2026.1.2}/.github/workflows/ci-cd.yml
@@ -22,7 +22,7 @@ jobs:
  platform:
  - runner: ubuntu-latest
  target: x86_64
- python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+ python-version: ['3.10', '3.11', '3.12', '3.13']

  environment:
  name: release
@@ -48,7 +48,7 @@ jobs:

  - name: Build the package
  run: |
- maturin build -i python${{ matrix.python-version }} --release --out dist --target "x86_64-unknown-linux-gnu" --manylinux 2014 --zig
+ maturin build -i python${{ matrix.python-version }} --release --out dist --target "x86_64-unknown-linux-gnu" --manylinux 2_31 --zig

  - name: Test package
  env:
{datago-2025.12.2 → datago-2026.1.2}/.github/workflows/rust.yml
@@ -51,4 +51,4 @@ jobs:
  DATAROOM_TEST_SOURCE: ${{ secrets.DATAROOM_TEST_SOURCE }}
  DATAROOM_API_URL: ${{ secrets.DATAROOM_API_URL }}

- run: cargo test --verbose
+ run: RUST_BACKTRACE=1 cargo test --verbose
{datago-2025.12.2 → datago-2026.1.2}/Cargo.lock
@@ -623,7 +623,7 @@ dependencies = [

  [[package]]
  name = "datago"
- version = "2025.12.2"
+ version = "2026.1.2"
  dependencies = [
  "async-compression",
  "async-tar",
{datago-2025.12.2 → datago-2026.1.2}/Cargo.toml
@@ -1,7 +1,7 @@
  [package]
  name = "datago"
  edition = "2021"
- version = "2025.12.2"
+ version = "2026.1.2"
  readme = "README.md"

  [lib]
{datago-2025.12.2 → datago-2026.1.2}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: datago
- Version: 2025.12.2
+ Version: 2026.1.2
  Classifier: Programming Language :: Rust
  Classifier: Programming Language :: Python :: Implementation :: CPython
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -133,7 +133,7 @@ client_config = {
  "source_config": {
  "url": url,
  "random_sampling": False,
- "max_concurrency": 8, # The number of TarballSamples which should be handled concurrently
+ "concurrent_downloads": 8, # The number of TarballSamples which should be handled concurrently
  "rank": 0,
  "world_size": 1,
  },
@@ -267,18 +267,46 @@ Create a new tag and a new release in this repo, a new package will be pushed au
  <details> <summary><strong>Benchmarks</strong></summary>
  As usual, benchmarks are a tricky game, and you shouldn't read too much into the following plots but do your own tests. Some python benchmark examples are provided in the [python](./python/) folder.

- In general, Datago will be impactful if you want to load a lot of images very fast, but if you consume them as you go at a more leisury pace then it's not really needed. The more CPU work there is with the images and the higher quality they are, the more Datago will shine. The following benchmarks are using ImageNet 1k, which is very low resolution and thus kind of a worst case scenario. Data is served from cache (i.e. the OS cache) and the images are not pre-processed. In this case the receiving python process is typically the bottleneck, and caps at around 3000 images per second.
+ In general, Datago will be impactful if you want to load a lot of images very fast, but if you consume them as you go at a more leisury pace then it's not really needed. The more CPU work there is with the images and the higher quality they are, the more Datago will shine.

- ### AMD Zen3 laptop - IN1k - disk
+ ## From disk: ImageNet
+
+ The following benchmarks are using ImageNet 1k, which is very low resolution and thus kind of a worst case scenario. Data is served from cache (i.e. the OS cache) and the images are not pre-processed. In this case the receiving python process is typically the bottleneck, and caps at around 3000 images per second.
+
+ ### AMD Zen3 laptop - IN1k - disk - no processing
  ![AMD Zen3 laptop & M2 SSD](assets/zen3_ssd.png)

- ### AMD EPYC 9454 - IN1k - disk
+ ### AMD EPYC 9454 - IN1k - disk - no processing
  ![AMD EPYC 9454](assets/epyc_vast.png)

- This benchmark is using the PD12M dataset, which hosts high resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), so given enough time the two results would look closer.
+ ## Webdataset: FakeIN
+
+ This benchmark is using low resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), which emphasizes throughput differences depending on how long you test it for.
+
+ Of note is also that this can be bottlenecked by your external bandwidth to the remote storage where WDS is hosted, in which case both solution would yield comparable numbers.
+
+ ### AMD Zen3 laptop - webdataset - no processing
+ ![AMD EPYC 9454](assets/zen3_wds_fakein.png)
+
+
+ ## Webdataset: PD12M
+
+ This benchmark is using high resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), which emphasizes throughput differences depending on how long you test it for.
+
+ Of note is also that this can be bottlenecked by your external bandwidth to the remote storage where WDS is hosted, in which case both solution would yield comparable numbers.
+
+ ### AMD Zen3 laptop - webdataset - no processing
+ ![AMD Zen3 laptop](assets/zen3_wds_pd12m.png)
+
+
+ ### AMD EPYC 9454 - pd12m - webdataset - no processing
+ ![AMD EPYC 9454](assets/epyc_wds_pd12m.png)
+
+
+ ### AMD Zen3 laptop - webdataset - processing
+ Adding image processing (crop and resize to Transformer compatible size buckets) to the equation changes the picture, as the work spread becomes more important. If you're training a diffusion model or an image encoder from a diverse set of images, this is likely to be the most realistic micro-benchmark.

- ### AMD EPYC 9454 - pd12m - webdataset
- ![AMD EPYC 9454](assets/epyc_wds.png)
+ ![AMD Zen3 laptop](assets/zen3_wds_pd12m_processing.png)

  </details>

{datago-2025.12.2 → datago-2026.1.2}/README.md
@@ -116,7 +116,7 @@ client_config = {
  "source_config": {
  "url": url,
  "random_sampling": False,
- "max_concurrency": 8, # The number of TarballSamples which should be handled concurrently
+ "concurrent_downloads": 8, # The number of TarballSamples which should be handled concurrently
  "rank": 0,
  "world_size": 1,
  },
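For reference, the user-visible change in this hunk is the key rename; below is a minimal sketch of an updated webdataset source block, with field values copied from the snippet above and `url` left as a placeholder:

```python
# Webdataset source config for datago >= 2026.1.2: "max_concurrency" is now
# "concurrent_downloads"; the other fields shown in the diff are unchanged.
client_config = {
    "source_type": "webdataset",
    "source_config": {
        "url": url,                 # e.g. "https://.../{000000..001281}.tar"
        "random_sampling": False,
        "concurrent_downloads": 8,  # was "max_concurrency" in 2025.12.2
        "rank": 0,
        "world_size": 1,
    },
}
```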
@@ -250,18 +250,46 @@ Create a new tag and a new release in this repo, a new package will be pushed au
  <details> <summary><strong>Benchmarks</strong></summary>
  As usual, benchmarks are a tricky game, and you shouldn't read too much into the following plots but do your own tests. Some python benchmark examples are provided in the [python](./python/) folder.

- In general, Datago will be impactful if you want to load a lot of images very fast, but if you consume them as you go at a more leisury pace then it's not really needed. The more CPU work there is with the images and the higher quality they are, the more Datago will shine. The following benchmarks are using ImageNet 1k, which is very low resolution and thus kind of a worst case scenario. Data is served from cache (i.e. the OS cache) and the images are not pre-processed. In this case the receiving python process is typically the bottleneck, and caps at around 3000 images per second.
+ In general, Datago will be impactful if you want to load a lot of images very fast, but if you consume them as you go at a more leisury pace then it's not really needed. The more CPU work there is with the images and the higher quality they are, the more Datago will shine.

- ### AMD Zen3 laptop - IN1k - disk
+ ## From disk: ImageNet
+
+ The following benchmarks are using ImageNet 1k, which is very low resolution and thus kind of a worst case scenario. Data is served from cache (i.e. the OS cache) and the images are not pre-processed. In this case the receiving python process is typically the bottleneck, and caps at around 3000 images per second.
+
+ ### AMD Zen3 laptop - IN1k - disk - no processing
  ![AMD Zen3 laptop & M2 SSD](assets/zen3_ssd.png)

- ### AMD EPYC 9454 - IN1k - disk
+ ### AMD EPYC 9454 - IN1k - disk - no processing
  ![AMD EPYC 9454](assets/epyc_vast.png)

- This benchmark is using the PD12M dataset, which hosts high resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), so given enough time the two results would look closer.
+ ## Webdataset: FakeIN
+
+ This benchmark is using low resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), which emphasizes throughput differences depending on how long you test it for.
+
+ Of note is also that this can be bottlenecked by your external bandwidth to the remote storage where WDS is hosted, in which case both solution would yield comparable numbers.
+
+ ### AMD Zen3 laptop - webdataset - no processing
+ ![AMD EPYC 9454](assets/zen3_wds_fakein.png)
+
+
+ ## Webdataset: PD12M
+
+ This benchmark is using high resolution images. It's accessed through the webdataset front end, datago is compared with the popular python webdataset library. Note that datago will start streaming the images faster here (almost instantly !), which emphasizes throughput differences depending on how long you test it for.
+
+ Of note is also that this can be bottlenecked by your external bandwidth to the remote storage where WDS is hosted, in which case both solution would yield comparable numbers.
+
+ ### AMD Zen3 laptop - webdataset - no processing
+ ![AMD Zen3 laptop](assets/zen3_wds_pd12m.png)
+
+
+ ### AMD EPYC 9454 - pd12m - webdataset - no processing
+ ![AMD EPYC 9454](assets/epyc_wds_pd12m.png)
+
+
+ ### AMD Zen3 laptop - webdataset - processing
+ Adding image processing (crop and resize to Transformer compatible size buckets) to the equation changes the picture, as the work spread becomes more important. If you're training a diffusion model or an image encoder from a diverse set of images, this is likely to be the most realistic micro-benchmark.

- ### AMD EPYC 9454 - pd12m - webdataset
- ![AMD EPYC 9454](assets/epyc_wds.png)
+ ![AMD Zen3 laptop](assets/zen3_wds_pd12m_processing.png)

  </details>

{datago-2025.12.2 → datago-2026.1.2}/python/benchmark_webdataset.py
@@ -1,6 +1,7 @@
  import json
  import os
  import time
+ from typing import Any

  import typer
  from benchmark_defaults import IMAGE_CONFIG
@@ -9,45 +10,123 @@ from tqdm import tqdm


  def benchmark(
- limit: int = typer.Option(10, help="The number of samples to test on"),
+ limit: int = typer.Option(1000, help="The number of samples to test on"),
  crop_and_resize: bool = typer.Option(
- True, help="Crop and resize the images on the fly"
+ False, help="Crop and resize the images on the fly"
  ),
  compare_wds: bool = typer.Option(True, help="Compare against torch dataloader"),
+ num_downloads: int = typer.Option(
+ 32,
+ help="Number of concurrent downloads",
+ ),
  num_workers: int = typer.Option(
- 16,
- help="Number of processes to use",
+ 8,
+ help="Number of CPU workers",
+ ),
+ sweep: bool = typer.Option(False, help="Sweep over the number of workers"),
+ plot: bool = typer.Option(
+ False, help="Whether to save a plot at the end of the run"
  ),
- sweep: bool = typer.Option(False, help="Sweep over the number of processes"),
  ):
- if sweep:
- results = {}
- for num_workers in range(2, max(64, (os.cpu_count() or 1)), 8):
- results[num_workers] = benchmark(limit, crop_and_resize, compare_wds, num_workers, False)
-
- # Save results to a json file
- with open("benchmark_results_wds.json", "w") as f:
- json.dump(results, f, indent=2)
-
- return results
+ results: dict[Any, Any] = {}
+ if plot and not sweep:
+ print("Plot option only makes sense if we sweeped results, will not be used since sweep is False")
+ plot = False

  # URL of the test bucket
  # bucket = "https://storage.googleapis.com/webdataset/fake-imagenet"
  # dataset = "/imagenet-train-{000000..001281}.tar"
+ # source = "FakeIN"

  bucket = "https://huggingface.co/datasets/sayakpaul/pd12m-full/resolve/"
  dataset = "main/{00155..02480}.tar"
+ source = "PD12M"
+
  url = bucket + dataset

  print(
- f"Benchmarking Datago WDS path on {url}.\nRunning benchmark for {limit} samples"
+ f"Benchmarking Datago WDS path on {url}.\nRunning benchmark for {limit} samples. Source {source}"
  )
+
+ if sweep:
+ max_cpus = os.cpu_count() or 16
+
+ num_workers = 1
+ while num_workers < max_cpus:
+ results[num_workers] = benchmark(
+ limit,
+ crop_and_resize,
+ compare_wds,
+ num_downloads,
+ num_workers,
+ False,
+ False,
+ )
+ num_workers *= 2
+
+ # Save results to a json file
+ with open("benchmark_results_wds.json", "w") as f:
+ json.dump(results, f, indent=2)
+
+ if plot:
+ import matplotlib.pyplot as plt
+ import pandas as pd
+
+ # Convert to a DataFrame for plotting
+ df = pd.DataFrame(
+ {
+ "Thread Count": [int(k) for k in results.keys()],
+ "Datago FPS": [results[k]["datago"]["fps"] for k in results.keys()],
+ "Webdataset FPS": [
+ results[k]["webdataset"]["fps"] for k in results.keys()
+ ],
+ }
+ )
+
+ # Plotting with vertical axis starting at 0
+ plt.figure(figsize=(10, 6))
+ plt.plot(
+ df["Thread Count"],
+ df["Datago FPS"],
+ marker="o",
+ label="Datago",
+ )
+ plt.plot(
+ df["Thread Count"],
+ df["Webdataset FPS"],
+ marker="o",
+ label="Webdataset",
+ )
+ plt.xlabel("Thread Count")
+ plt.ylabel("Frames Per Second (FPS)")
+ plt.title(f"Throughput: Datago vs Webdataset. Source: {source}")
+ plt.ylim(
+ 0,
+ max(df["Datago FPS"].max(), df["Webdataset FPS"].max()) + 20,
+ )
+ plt.legend()
+ plt.grid(True)
+ plt.xticks(df["Thread Count"])
+ plt.tight_layout()
+ plt.savefig(
+ "bench_datago_webdataset.png",
+ format="PNG",
+ dpi=200,
+ bbox_inches="tight",
+ )
+ plt.close()
+
+ return results
+
+ # This setting is not exposed in the config, but an env variable can be used instead
+ os.environ["DATAGO_MAX_TASKS"] = str(num_workers)
+
  client_config = {
  "source_type": "webdataset",
  "source_config": {
  "url": url,
  "shuffle": True,
- "max_concurrency": num_workers, # Number of concurrent TarballSample downloads and dispatch
+ "concurrent_downloads": num_downloads, # Number of concurrent TarballSample downloads and dispatch
  "auth_token": os.environ.get("HF_TOKEN", default=""),
  },
  "prefetch_buffer_size": 256,
@@ -98,7 +177,7 @@ def benchmark(
  ]
  )
  if crop_and_resize
- else None
+ else lambda x: x
  )

  def custom_transform(sample):
@@ -117,16 +196,17 @@ def benchmark(
  # .to_tuple("png", "cls") # Map keys to output tuple
  )

- dataloader = DataLoader(
+ dataloader = DataLoader( # type:ignore
  dataset,
  batch_size=1,
  num_workers=num_workers,
- prefetch_factor=2,
+ prefetch_factor=8, # Didn't sweep on that, but probably not super impactful
  collate_fn=lambda x: x,
  )

  # Iterate over the DataLoader
  start = time.time()
+ n_images = 0
  for n_images, _ in enumerate(tqdm(dataloader, desc="WDS", dynamic_ncols=True)):
  if n_images > limit:
  break
@@ -136,5 +216,6 @@ def benchmark(
  results["webdataset"] = {"fps": fps, "count": n_images}
  return results

+
  if __name__ == "__main__":
  typer.run(benchmark)
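For reference, the reworked benchmark now splits network-side concurrency (`num_downloads`, forwarded to `concurrent_downloads`) from CPU-side decoding (`num_workers`, exported through the `DATAGO_MAX_TASKS` environment variable). A hedged sketch of driving it programmatically rather than through the typer CLI, assuming the repository's `python/` folder and its requirements (typer, tqdm, webdataset, matplotlib, pandas) are importable:

```python
# Hypothetical programmatic call into the updated benchmark; the keyword
# names mirror the typer options added in this release.
from benchmark_webdataset import benchmark

results = benchmark(
    limit=1000,             # samples to time
    crop_and_resize=False,  # leave images as-is (new default)
    compare_wds=True,       # also time the python webdataset baseline
    num_downloads=32,       # concurrent tarball downloads (network bound)
    num_workers=8,          # CPU decoding tasks, exported as DATAGO_MAX_TASKS
    sweep=False,            # set True to sweep num_workers in powers of two
    plot=False,             # with sweep=True, saves bench_datago_webdataset.png
)
print(results)  # per-backend dicts of {"fps": ..., "count": ...}
```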
datago-2026.1.2/python/test_datago_wds.py
@@ -0,0 +1,250 @@
+ """
+ Test suite for WebDataset (WDS) functionality in Datago.
+
+ This module tests that Datago correctly serves images and attributes from WebDataset sources.
+ """
+
+ import os
+
+ from dataset import DatagoIterDataset
+ from PIL import Image
+
+ # Test buckets - using the same ones as benchmark_webdataset.py
+ TEST_BUCKETS = {
+ "pd12m": {
+ "url": "https://huggingface.co/datasets/sayakpaul/pd12m-full/resolve/main/{00155..02480}.tar",
+ "source": "PD12M",
+ },
+ "fakein": {
+ "url": "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar",
+ "source": "FakeIN",
+ },
+ }
+
+
+ def test_wds_basic_functionality():
+ """Test basic WDS functionality - that we can get samples with proper structure."""
+ limit = 5 # Small limit for quick testing
+
+ # Use the PD12M bucket for testing
+ bucket_config = TEST_BUCKETS["pd12m"]
+
+ client_config = {
+ "source_type": "webdataset",
+ "source_config": {
+ "url": bucket_config["url"],
+ "shuffle": True,
+ "concurrent_downloads": 4, # Reduced for testing
+ "auth_token": os.environ.get("HF_TOKEN", default=""),
+ },
+ "prefetch_buffer_size": 32,
+ "samples_buffer_size": 32,
+ "limit": limit,
+ }
+
+ # Test with return_python_types=True to get proper Python objects
+ dataset = DatagoIterDataset(client_config, return_python_types=True)
+
+ count = 0
+ for sample in dataset:
+ count += 1
+
+ # Basic structure checks
+ assert "id" in sample, "Sample should contain 'id' field"
+ assert sample["id"] != "", "Sample ID should not be empty"
+
+ # Check that we have an image
+ assert "image" in sample, "Sample should contain 'image' field"
+ assert sample["image"] is not None, "Image should not be None"
+
+ # If it's a PIL Image, check its properties
+ if isinstance(sample["image"], Image.Image):
+ assert sample["image"].width > 0, "Image should have positive width"
+ assert sample["image"].height > 0, "Image should have positive height"
+ assert sample["image"].mode in ["RGB", "RGBA", "L"], (
+ f"Image should have valid mode, got {sample['image'].mode}"
+ )
+
+ # Check for attributes if present
+ if "attributes" in sample:
+ assert isinstance(sample["attributes"], dict), "Attributes should be a dictionary"
+ # Attributes should be non-empty if present
+ if sample["attributes"]:
+ assert len(sample["attributes"]) > 0, "Attributes dictionary should not be empty"
+
+ # We should get at least the basic fields
+ assert len(sample) >= 2, "Sample should contain at least id and image"
+
+ if count >= limit:
+ break
+
+ assert count == limit, f"Expected {limit} samples, got {count}"
+
+
+ def test_wds_image_properties():
+ """Test that images from WDS have proper properties and can be processed."""
+ limit = 3
+
+ bucket_config = TEST_BUCKETS["pd12m"]
+
+ client_config = {
+ "source_type": "webdataset",
+ "source_config": {
+ "url": bucket_config["url"],
+ "shuffle": True,
+ "concurrent_downloads": 4,
+ "auth_token": os.environ.get("HF_TOKEN", default=""),
+ },
+ "prefetch_buffer_size": 32,
+ "samples_buffer_size": 32,
+ "limit": limit,
+ }
+
+ dataset = DatagoIterDataset(client_config, return_python_types=True)
+
+ for sample in dataset:
+ if "image" in sample and sample["image"] is not None:
+ image = sample["image"]
+
+ # Test that we can get image properties
+ if isinstance(image, Image.Image):
+ width, height = image.size
+ assert width > 0 and height > 0, "Image should have valid dimensions"
+
+ # Test that we can convert to different modes
+ rgb_image = image.convert("RGB")
+ assert rgb_image.mode == "RGB", "Image should convert to RGB mode"
+
+ # Test that we can get thumbnail
+ thumbnail = image.copy()
+ thumbnail.thumbnail((100, 100))
+ assert thumbnail.size[0] <= 100 and thumbnail.size[1] <= 100, "Thumbnail should be resized"
+
+ # Test that image data is valid by trying to get pixel data
+ pixels = image.get_flattened_data()
+ assert len(pixels) > 0, "Image should have pixel data"
+
+ break # Just test one image
+
+
+ def test_wds_with_image_processing():
+ """Test WDS with image processing configuration (crop and resize)."""
+ limit = 3
+
+ bucket_config = TEST_BUCKETS["pd12m"]
+
+ client_config = {
+ "source_type": "webdataset",
+ "source_config": {
+ "url": bucket_config["url"],
+ "shuffle": True,
+ "concurrent_downloads": 4,
+ "auth_token": os.environ.get("HF_TOKEN", default=""),
+ },
+ "image_config": {
+ "crop_and_resize": True,
+ "default_image_size": 256,
+ "downsampling_ratio": 16,
+ "min_aspect_ratio": 0.5,
+ "max_aspect_ratio": 2.0,
+ },
+ "prefetch_buffer_size": 32,
+ "samples_buffer_size": 32,
+ "limit": limit,
+ }
+
+ dataset = DatagoIterDataset(client_config, return_python_types=True)
+
+ for sample in dataset:
+ if "image" in sample and sample["image"] is not None:
+ image = sample["image"]
+
+ if isinstance(image, Image.Image):
+ # With crop_and_resize=True, images should be processed
+ width, height = image.size
+ assert width > 0 and height > 0, "Processed image should have valid dimensions"
+
+ # The processed image should be in RGB mode
+ assert image.mode == "RGB", f"Processed image should be RGB, got {image.mode}"
+
+ break # Just test one image
+
+
+ def test_wds_attributes_structure():
+ """Test that WDS attributes are properly structured when present."""
+ limit = 5
+
+ bucket_config = TEST_BUCKETS["pd12m"]
+
+ client_config = {
+ "source_type": "webdataset",
+ "source_config": {
+ "url": bucket_config["url"],
+ "shuffle": True,
+ "concurrent_downloads": 4,
+ "auth_token": os.environ.get("HF_TOKEN", default=""),
+ },
+ "prefetch_buffer_size": 32,
+ "samples_buffer_size": 32,
+ "limit": limit,
+ }
+
+ dataset = DatagoIterDataset(client_config, return_python_types=True)
+
+ for sample in dataset:
+ if "attributes" in sample and sample["attributes"]:
+ attributes = sample["attributes"]
+
+ # Attributes should be a dictionary
+ assert isinstance(attributes, dict), "Attributes should be a dictionary"
+
+ # Check that we can access attribute values
+ for key, value in attributes.items():
+ # Values should be JSON-serializable types
+ assert isinstance(key, str), "Attribute keys should be strings"
+ assert isinstance(value, (str, int, float, bool, list, dict)), (
+ f"Attribute values should be JSON-serializable, got {type(value)}"
+ )
+
+ break # Just test one sample with attributes
+
+ # Note: Not all samples may have attributes
+
+
+ def test_wds_sample_consistency():
+ """Test that WDS samples have consistent structure across multiple samples."""
+ limit = 10
+
+ bucket_config = TEST_BUCKETS["pd12m"]
+
+ client_config = {
+ "source_type": "webdataset",
+ "source_config": {
+ "url": bucket_config["url"],
+ "shuffle": True,
+ "concurrent_downloads": 4,
+ "auth_token": os.environ.get("HF_TOKEN", default=""),
+ },
+ "prefetch_buffer_size": 32,
+ "samples_buffer_size": 32,
+ "limit": limit,
+ }
+
+ dataset = DatagoIterDataset(client_config, return_python_types=True)
+
+ first_sample = True
+
+ for sample in dataset:
+ current_keys = set(sample.keys())
+
+ if first_sample:
+ first_sample = False
+ else:
+ # All samples should have at least the core fields (id, image)
+ required_keys = {"id", "image"}
+ assert required_keys.issubset(current_keys), \
+ f"Sample missing required keys. Expected at least {required_keys}, got {current_keys}"
+
+ # Check that we don't have any unexpected None values for core fields
+ assert sample.get("id") != "", "Sample ID should not be empty"
+ assert sample.get("image") is not None, "Sample image should not be None"
datago-2026.1.2/requirements-dev.txt
@@ -0,0 +1,3 @@
+ -r requirements.txt
+ matplotlib
+ pandas
{datago-2025.12.2 → datago-2026.1.2}/src/client.rs
@@ -217,14 +217,16 @@ impl DatagoClient {
  debug!("Sample pipe closed...");

  if let Some(feeder) = engine.feeder.take() {
- if feeder.join().is_err() {
- error!("Failed to join feeder thread");
+ match feeder.join() {
+ Ok(_) => debug!("Feeder thread joined successfully"),
+ Err(e) => error!("Failed to join feeder thread: {:?}", e),
  }
  }

  if let Some(worker) = engine.worker.take() {
- if worker.join().is_err() {
- error!("Failed to join worker thread");
+ match worker.join() {
+ Ok(_) => debug!("Worker thread joined successfully"),
+ Err(e) => error!("Failed to join worker thread: {:?}", e),
  }
  }
  self.is_started = false;
{datago-2025.12.2 → datago-2026.1.2}/src/generator_wds.rs
@@ -35,7 +35,7 @@ pub struct SourceWebDatasetConfig {
  pub random_sampling: bool,

  #[serde(default)]
- pub max_concurrency: usize,
+ pub concurrent_downloads: usize,

  #[serde(default)]
  pub auth_token: String,
@@ -81,13 +81,23 @@ async fn pull_tarballs(

  // Grab an async byte stream from the request, we'll try to untar the results on the fly
  let response = request_builder.send().await;
- if response.is_err() {
- return Err("Failed to send request".into());
- }
- let response = response.unwrap();
+ let response = match response {
+ Ok(resp) => resp,
+ Err(e) => {
+ error!("Failed to send request for {}: {}", url, e);
+ return Err(format!("Failed to send request: {}", e));
+ }
+ };
+
  if !response.status().is_success() {
+ error!(
+ "Failed to download TarballSample {}: HTTP {}",
+ url,
+ response.status()
+ );
  return Err(format!(
- "Failed to download TarballSample: {}",
+ "Failed to download TarballSample {}: HTTP {}",
+ url,
  response.status()
  ));
  }
@@ -158,7 +168,7 @@ async fn pull_tarballs(
  if samples_metadata_tx.send(current_files_for_sample).is_err() {
  debug!("dispatch_shards (streaming): samples_metadata_tx channel closed.");
  let _ = samples_metadata_tx.close(); // Make sure that we close on both ends
- return Err("Channel closed".into());
+ return Ok(());
  }

  // Start a new sample
@@ -184,9 +194,9 @@ async fn pull_tarballs(
  // Send the last collected sample if any
  if !current_files_for_sample.content.is_empty()
  && samples_metadata_tx.send(current_files_for_sample).is_err()
+ && !samples_metadata_tx.is_closed()
  {
- debug!("dispatch_shards (streaming): samples_metadata_tx channel closed for last sample.");
- return Err("Channel closed".into());
+ return Err("Failed to send last sample".into());
  }

  debug!("dispatch_shards (streaming): finished processing TarballSample {url}");
@@ -308,7 +318,18 @@ async fn tasks_from_shards(
  let mut count = 0;
  let mut join_error: Option<String> = None;

+ info!("WDS: Using {} download tasks", config.concurrent_downloads);
+
  for url in task_list {
+ // Escape out if the channel is closed
+ if samples_metadata_tx.is_closed() {
+ debug!(
+ "dispatch_shards: channel is closed, enough samples probably. Bailing out"
+ );
+ break;
+ }
+
+ // All good, submit a new async task
  tasks.spawn(pull_tarballs_task(
  shared_client.clone(),
  url,
@@ -319,7 +340,7 @@ async fn tasks_from_shards(

  // Some bookkeeping, to limit the number of tasks in flight
  // we'll wait for the first one to finish before adding a new one
- if tasks.len() >= config.max_concurrency {
+ if tasks.len() >= config.concurrent_downloads {
  match tasks.join_next().await {
  Some(res) => {
  match res.unwrap() {
@@ -334,7 +355,6 @@ async fn tasks_from_shards(
  break;
  }
  }
- debug!("dispatch_shards: task completed successfully");
  }
  None => {
  // Task was cancelled or panicked
@@ -407,8 +427,10 @@ fn query_shards_and_dispatch(
  .and_then(|v| v.parse::<u8>().ok())
  .unwrap_or(3);

+ // Use more threads for the download runtime to handle increased concurrency
+ let download_threads = std::cmp::max(4, source_config.concurrent_downloads);
  tokio::runtime::Builder::new_multi_thread()
- .worker_threads(source_config.max_concurrency)
+ .worker_threads(download_threads)
  .enable_all()
  .build()
  .unwrap()
@@ -434,7 +456,8 @@ fn query_shards_and_dispatch(
  // ---- Global orchestration ---------
  pub fn orchestrate(client: &DatagoClient) -> DatagoEngine {
  // Allocate all the message passing pipes
- let (samples_metadata_tx, samples_metadata_rx) = bounded::<TarballSample>(32);
+ let metadata_buffer_size = std::cmp::max(128, client.samples_buffer * 2);
+ let (samples_metadata_tx, samples_metadata_rx) = bounded::<TarballSample>(metadata_buffer_size);
  let (samples_tx, samples_rx) = bounded(client.samples_buffer);

  info!("Using webdataset as source");
@@ -444,9 +467,9 @@ pub fn orchestrate(client: &DatagoClient) -> DatagoEngine {
  serde_json::from_value(client.source_config.clone()).unwrap();
  let extension_reference_image_type: String = source_config.reference_image_type.clone();

- if source_config.max_concurrency == 0 {
- info!("WDS: Defaulting to 8 max_concurrency");
- source_config.max_concurrency = 8;
+ if source_config.concurrent_downloads == 0 {
+ info!("WDS: Defaulting to 8 concurrent_downloads");
+ source_config.concurrent_downloads = 8;
  }

  // List the contents of the bucket and feed the workers
@@ -518,7 +541,7 @@ mod tests {
  auth_token: "".into(),
  reference_image_type: "jpg".into(),
  random_sampling: s,
- max_concurrency: 2,
+ concurrent_downloads: 2,
  rank: 0,
  world_size: 1,
  };
@@ -569,7 +592,7 @@ mod tests {
  "source_config": {
  "url": "https://storage.googleapis.com/storage/v1/b/webdataset/o?prefix=fake-imagenet/",
  "random_sampling": do_random_sampling,
- "max_concurrency": 2
+ "concurrent_downloads": 2
  },
  "limit": n_samples,
  "num_threads": 1,
@@ -632,7 +655,7 @@ mod tests {
  "source_config": {
  "url": "https://storage.googleapis.com/storage/v1/b/webdataset/o?prefix=fake-imagenet/",
  "random_sampling": false,
- "max_concurrency": 2,
+ "concurrent_downloads": 2,
  "rank": rank,
  "world_size": world_size,
  },
{datago-2025.12.2 → datago-2026.1.2}/src/worker_wds.rs
@@ -1,6 +1,6 @@
  use crate::image_processing;
  use crate::structs::{to_python_image_payload, ImagePayload, Sample, TarballSample};
- use log::{debug, error, info, warn};
+ use log::{debug, error, info};
  use std::cmp::min;
  use std::collections::HashMap;
  use std::path::Path;
@@ -77,30 +77,56 @@ async fn process_sample(

  if ext == extension_reference_image {
  // If this is the reference image, we store it in the main image field
- final_sample = Some(Sample {
- id: String::from(sample_id.to_str().unwrap_or("unknown")),
- source: sample.name.clone(),
- image,
- attributes: attributes.clone(),
- coca_embedding: vec![],
- tags: vec![],
- masks: HashMap::new(),
- latents: HashMap::new(),
- additional_images: HashMap::new(),
- duplicate_state: 0,
- });
+ if let Some(mut_final_sample) = &mut final_sample {
+ mut_final_sample.image = image;
+ } else {
+ // Init the sample
+ final_sample = Some(Sample {
+ id: String::from(
+ sample_id.to_str().unwrap_or("unknown"),
+ ),
+ source: sample.name.clone(),
+ image,
+ attributes: HashMap::new(),
+ coca_embedding: vec![],
+ tags: vec![],
+ masks: HashMap::new(),
+ latents: HashMap::new(),
+ additional_images: HashMap::new(),
+ duplicate_state: 0,
+ });
+ }
  } else {
  // Otherwise, we store it in the additional images
- match final_sample {
- Some(ref mut final_sample_ref) => {
- final_sample_ref
- .additional_images
- .insert(item.filename.clone(), image.clone());
- }
- None => {
- // If final_sample is not initialized, we create it
- panic!( "Final sample should be initialized before adding additional images");
- }
+ if let Some(mut_final_sample) = &mut final_sample {
+ mut_final_sample
+ .additional_images
+ .insert(item.filename.clone(), image);
+ } else {
+ // Init the sample
+ final_sample = Some(Sample {
+ id: String::from(
+ sample_id.to_str().unwrap_or("unknown"),
+ ),
+ source: sample.name.clone(),
+ image: to_python_image_payload(ImagePayload {
+ data: vec![],
+ width: 0,
+ height: 0,
+ original_height: 0,
+ original_width: 0,
+ bit_depth: 0,
+ channels: 0,
+ is_encoded: false,
+ }),
+ attributes: HashMap::new(),
+ coca_embedding: vec![],
+ tags: vec![],
+ masks: HashMap::new(),
+ latents: HashMap::new(),
+ additional_images: HashMap::new(),
+ duplicate_state: 0,
+ });
  }
  }
  debug!("wds_worker: unpacked {}", item.filename);
@@ -114,15 +140,25 @@ async fn process_sample(
  // Load the file in to a string
  let class_file = String::from_utf8_lossy(&item.buffer).to_string();
  attributes.insert(ext.to_string(), serde_json::json!(class_file));
- debug!("wds_worker: unpacked {}", item.filename);
+ debug!("wds_worker: unpacked {} {}", item.filename, class_file);
  }
  }

- if samples_tx.send(final_sample).is_err() {
- debug!("wds_worker: stream already closed, wrapping up");
- return Err(());
+ // Make sure that the sample has the attributes we decoded
+ if let Some(ref mut final_sample_ref) = final_sample {
+ final_sample_ref.attributes = attributes;
+ match samples_tx.send(final_sample) {
+ Ok(_) => (),
+ Err(e) => {
+ if !samples_tx.is_closed() {
+ debug!("wds_worker: error dispatching sample: {e}");
+ return Err(());
+ }
+ }
+ }
+ return Ok(());
  }
- return Ok(());
+ return Err(());
  }
  None => {
  debug!("wds_worker: unpacking sample with no ID");
@@ -147,10 +183,10 @@ async fn async_deserialize_samples(
  let default_max_tasks = std::env::var("DATAGO_MAX_TASKS")
  .unwrap_or_else(|_| "0".to_string())
  .parse::<usize>()
- .unwrap_or(num_cpus::get() * 4);
- let max_tasks = min(default_max_tasks, limit);
+ .unwrap_or(num_cpus::get());
+ let max_tasks = min(num_cpus::get() * 4, default_max_tasks); // Ensure minimum of 8 processing tasks

- info!("Using {max_tasks} tasks in the async threadpool");
+ info!("WDS: Using {max_tasks} processing tasks in worker threadpool");
  let mut tasks = tokio::task::JoinSet::new();
  let mut count = 0;
  let shareable_channel_tx: Arc<kanal::Sender<Option<Sample>>> = Arc::new(samples_tx);
@@ -159,7 +195,7 @@ async fn async_deserialize_samples(

  while let Ok(sample) = samples_metadata_rx.recv() {
  if sample.is_empty() {
- warn!("wds_worker: end of stream received, stopping there");
+ info!("wds_worker: end of stream received, stopping there");
  let _ = samples_metadata_rx.close();
  break;
  }
@@ -228,7 +264,7 @@ pub fn deserialize_samples(
  extension_reference_image: String,
  ) {
  tokio::runtime::Builder::new_multi_thread()
- .worker_threads(num_cpus::get())
+ .worker_threads(num_cpus::get()) // Tasks in flight are limited by DATAGO_MAX_TASKS env
  .enable_all()
  .build()
  .unwrap()
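For reference, the worker now derives its processing-task budget from the `DATAGO_MAX_TASKS` environment variable (capped at 4x the CPU count), and the variable is read when the worker runtime starts, so it has to be set before the client is created; this is what the updated benchmark does. A minimal sketch, assuming the `DatagoIterDataset` helper shipped in the repository's `python/` folder and the public FakeIN shards used in the tests above:

```python
import os

# Set before the dataset/client is constructed: the Rust worker reads it once
# when its runtime spins up (see the DATAGO_MAX_TASKS handling above).
os.environ["DATAGO_MAX_TASKS"] = "8"   # CPU-side decoding tasks

from dataset import DatagoIterDataset  # helper from python/dataset.py

client_config = {
    "source_type": "webdataset",
    "source_config": {
        "url": "https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar",
        "shuffle": True,
        "concurrent_downloads": 4,  # network-side concurrency, independent knob
    },
    "prefetch_buffer_size": 32,
    "samples_buffer_size": 32,
    "limit": 10,
}

for sample in DatagoIterDataset(client_config, return_python_types=True):
    print(sample["id"])
```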