datago 2025.10.2__tar.gz → 2025.12.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {datago-2025.10.2 → datago-2025.12.1}/Cargo.lock +73 -19
  2. {datago-2025.10.2 → datago-2025.12.1}/Cargo.toml +3 -2
  3. {datago-2025.10.2 → datago-2025.12.1}/PKG-INFO +2 -2
  4. datago-2025.12.1/assets/zen3_ssd.png +0 -0
  5. {datago-2025.10.2 → datago-2025.12.1}/pyproject.toml +1 -1
  6. {datago-2025.10.2 → datago-2025.12.1}/python/benchmark_db.py +2 -16
  7. {datago-2025.10.2 → datago-2025.12.1}/python/benchmark_filesystem.py +29 -10
  8. {datago-2025.10.2 → datago-2025.12.1}/python/dataset.py +5 -3
  9. datago-2025.12.1/python/raw_types.py +40 -0
  10. {datago-2025.10.2 → datago-2025.12.1}/python/test_datago_client.py +9 -4
  11. {datago-2025.10.2 → datago-2025.12.1}/python/test_datago_db.py +26 -65
  12. {datago-2025.10.2 → datago-2025.12.1}/python/test_datago_edge_cases.py +6 -4
  13. {datago-2025.10.2 → datago-2025.12.1}/python/test_datago_filesystem.py +7 -5
  14. datago-2025.12.1/python/test_pil_implicit_conversion.py +80 -0
  15. {datago-2025.10.2 → datago-2025.12.1}/src/client.rs +21 -16
  16. {datago-2025.10.2 → datago-2025.12.1}/src/generator_http.rs +2 -2
  17. {datago-2025.10.2 → datago-2025.12.1}/src/generator_wds.rs +13 -12
  18. {datago-2025.10.2 → datago-2025.12.1}/src/image_processing.rs +4 -4
  19. {datago-2025.10.2 → datago-2025.12.1}/src/lib.rs +3 -1
  20. {datago-2025.10.2 → datago-2025.12.1}/src/main.rs +4 -2
  21. {datago-2025.10.2 → datago-2025.12.1}/src/structs.rs +182 -10
  22. {datago-2025.10.2 → datago-2025.12.1}/src/worker_files.rs +11 -8
  23. {datago-2025.10.2 → datago-2025.12.1}/src/worker_http.rs +116 -6
  24. {datago-2025.10.2 → datago-2025.12.1}/src/worker_wds.rs +18 -12
  25. datago-2025.10.2/assets/zen3_ssd.png +0 -0
  26. datago-2025.10.2/python/raw_types.py +0 -87
  27. {datago-2025.10.2 → datago-2025.12.1}/.github/workflows/ci-cd.yml +0 -0
  28. {datago-2025.10.2 → datago-2025.12.1}/.github/workflows/rust.yml +0 -0
  29. {datago-2025.10.2 → datago-2025.12.1}/.gitignore +0 -0
  30. {datago-2025.10.2 → datago-2025.12.1}/.pre-commit-config.yaml +0 -0
  31. {datago-2025.10.2 → datago-2025.12.1}/LICENSE +0 -0
  32. {datago-2025.10.2 → datago-2025.12.1}/README.md +0 -0
  33. {datago-2025.10.2 → datago-2025.12.1}/assets/447175851-2277afcb-8abf-4d17-b2db-dae27c6056d0.png +0 -0
  34. {datago-2025.10.2 → datago-2025.12.1}/assets/epyc_vast.png +0 -0
  35. {datago-2025.10.2 → datago-2025.12.1}/assets/epyc_wds.png +0 -0
  36. {datago-2025.10.2 → datago-2025.12.1}/python/benchmark_defaults.py +0 -0
  37. {datago-2025.10.2 → datago-2025.12.1}/python/benchmark_webdataset.py +0 -0
  38. {datago-2025.10.2 → datago-2025.12.1}/requirements-tests.txt +0 -0
  39. {datago-2025.10.2 → datago-2025.12.1}/requirements.txt +0 -0
  40. {datago-2025.10.2 → datago-2025.12.1}/src/generator_files.rs +0 -0
@@ -111,6 +111,16 @@ version = "0.7.6"
111
111
  source = "registry+https://github.com/rust-lang/crates.io-index"
112
112
  checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
113
113
 
114
+ [[package]]
115
+ name = "assert-json-diff"
116
+ version = "2.0.2"
117
+ source = "registry+https://github.com/rust-lang/crates.io-index"
118
+ checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
119
+ dependencies = [
120
+ "serde",
121
+ "serde_json",
122
+ ]
123
+
114
124
  [[package]]
115
125
  name = "async-channel"
116
126
  version = "1.9.0"
@@ -464,7 +474,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
464
474
  checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02"
465
475
  dependencies = [
466
476
  "smallvec",
467
- "target-lexicon 0.12.16",
477
+ "target-lexicon",
468
478
  ]
469
479
 
470
480
  [[package]]
@@ -613,7 +623,7 @@ dependencies = [
613
623
 
614
624
  [[package]]
615
625
  name = "datago"
616
- version = "2025.10.2"
626
+ version = "2025.12.1"
617
627
  dependencies = [
618
628
  "async-compression",
619
629
  "async-tar",
@@ -644,8 +654,27 @@ dependencies = [
644
654
  "tokio-util",
645
655
  "url",
646
656
  "walkdir",
657
+ "wiremock",
658
+ ]
659
+
660
+ [[package]]
661
+ name = "deadpool"
662
+ version = "0.10.0"
663
+ source = "registry+https://github.com/rust-lang/crates.io-index"
664
+ checksum = "fb84100978c1c7b37f09ed3ce3e5f843af02c2a2c431bae5b19230dad2c1b490"
665
+ dependencies = [
666
+ "async-trait",
667
+ "deadpool-runtime",
668
+ "num_cpus",
669
+ "tokio",
647
670
  ]
648
671
 
672
+ [[package]]
673
+ name = "deadpool-runtime"
674
+ version = "0.1.4"
675
+ source = "registry+https://github.com/rust-lang/crates.io-index"
676
+ checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b"
677
+
649
678
  [[package]]
650
679
  name = "dirs-next"
651
680
  version = "2.0.0"
@@ -1118,6 +1147,12 @@ version = "1.10.1"
1118
1147
  source = "registry+https://github.com/rust-lang/crates.io-index"
1119
1148
  checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
1120
1149
 
1150
+ [[package]]
1151
+ name = "httpdate"
1152
+ version = "1.0.3"
1153
+ source = "registry+https://github.com/rust-lang/crates.io-index"
1154
+ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
1155
+
1121
1156
  [[package]]
1122
1157
  name = "hyper"
1123
1158
  version = "1.6.0"
@@ -1131,6 +1166,7 @@ dependencies = [
1131
1166
  "http",
1132
1167
  "http-body",
1133
1168
  "httparse",
1169
+ "httpdate",
1134
1170
  "itoa",
1135
1171
  "pin-project-lite",
1136
1172
  "smallvec",
@@ -2026,9 +2062,9 @@ dependencies = [
2026
2062
 
2027
2063
  [[package]]
2028
2064
  name = "pyo3"
2029
- version = "0.24.1"
2065
+ version = "0.22.6"
2030
2066
  source = "registry+https://github.com/rust-lang/crates.io-index"
2031
- checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229"
2067
+ checksum = "f402062616ab18202ae8319da13fa4279883a2b8a9d9f83f20dbade813ce1884"
2032
2068
  dependencies = [
2033
2069
  "cfg-if",
2034
2070
  "indoc",
@@ -2044,19 +2080,19 @@ dependencies = [
2044
2080
 
2045
2081
  [[package]]
2046
2082
  name = "pyo3-build-config"
2047
- version = "0.24.1"
2083
+ version = "0.22.6"
2048
2084
  source = "registry+https://github.com/rust-lang/crates.io-index"
2049
- checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1"
2085
+ checksum = "b14b5775b5ff446dd1056212d778012cbe8a0fbffd368029fd9e25b514479c38"
2050
2086
  dependencies = [
2051
2087
  "once_cell",
2052
- "target-lexicon 0.13.2",
2088
+ "target-lexicon",
2053
2089
  ]
2054
2090
 
2055
2091
  [[package]]
2056
2092
  name = "pyo3-ffi"
2057
- version = "0.24.1"
2093
+ version = "0.22.6"
2058
2094
  source = "registry+https://github.com/rust-lang/crates.io-index"
2059
- checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc"
2095
+ checksum = "9ab5bcf04a2cdcbb50c7d6105de943f543f9ed92af55818fd17b660390fc8636"
2060
2096
  dependencies = [
2061
2097
  "libc",
2062
2098
  "pyo3-build-config",
@@ -2064,9 +2100,9 @@ dependencies = [
2064
2100
 
2065
2101
  [[package]]
2066
2102
  name = "pyo3-macros"
2067
- version = "0.24.1"
2103
+ version = "0.22.6"
2068
2104
  source = "registry+https://github.com/rust-lang/crates.io-index"
2069
- checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44"
2105
+ checksum = "0fd24d897903a9e6d80b968368a34e1525aeb719d568dba8b3d4bfa5dc67d453"
2070
2106
  dependencies = [
2071
2107
  "proc-macro2",
2072
2108
  "pyo3-macros-backend",
@@ -2076,9 +2112,9 @@ dependencies = [
2076
2112
 
2077
2113
  [[package]]
2078
2114
  name = "pyo3-macros-backend"
2079
- version = "0.24.1"
2115
+ version = "0.22.6"
2080
2116
  source = "registry+https://github.com/rust-lang/crates.io-index"
2081
- checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855"
2117
+ checksum = "36c011a03ba1e50152b4b394b479826cad97e7a21eb52df179cd91ac411cbfbe"
2082
2118
  dependencies = [
2083
2119
  "heck",
2084
2120
  "proc-macro2",
@@ -2753,12 +2789,6 @@ version = "0.12.16"
2753
2789
  source = "registry+https://github.com/rust-lang/crates.io-index"
2754
2790
  checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
2755
2791
 
2756
- [[package]]
2757
- name = "target-lexicon"
2758
- version = "0.13.2"
2759
- source = "registry+https://github.com/rust-lang/crates.io-index"
2760
- checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a"
2761
-
2762
2792
  [[package]]
2763
2793
  name = "tempfile"
2764
2794
  version = "3.20.0"
@@ -3454,6 +3484,30 @@ dependencies = [
3454
3484
  "memchr",
3455
3485
  ]
3456
3486
 
3487
+ [[package]]
3488
+ name = "wiremock"
3489
+ version = "0.6.4"
3490
+ source = "registry+https://github.com/rust-lang/crates.io-index"
3491
+ checksum = "a2b8b99d4cdbf36b239a9532e31fe4fb8acc38d1897c1761e161550a7dc78e6a"
3492
+ dependencies = [
3493
+ "assert-json-diff",
3494
+ "async-trait",
3495
+ "base64",
3496
+ "deadpool",
3497
+ "futures",
3498
+ "http",
3499
+ "http-body-util",
3500
+ "hyper",
3501
+ "hyper-util",
3502
+ "log",
3503
+ "once_cell",
3504
+ "regex",
3505
+ "serde",
3506
+ "serde_json",
3507
+ "tokio",
3508
+ "url",
3509
+ ]
3510
+
3457
3511
  [[package]]
3458
3512
  name = "wit-bindgen-rt"
3459
3513
  version = "0.33.0"
@@ -1,7 +1,7 @@
1
1
  [package]
2
2
  name = "datago"
3
3
  edition = "2021"
4
- version = "2025.10.2"
4
+ version = "2025.12.1"
5
5
  readme = "README.md"
6
6
 
7
7
  [lib]
@@ -24,7 +24,7 @@ kanal = "0.1"
24
24
  clap = { version = "4.5.27", features = ["derive"] }
25
25
  tokio = { version = "1.43.1", features = ["rt-multi-thread", "macros"] }
26
26
  prettytable-rs = "0.10.0"
27
- pyo3 = { version = "0.24.1", features = ["extension-module"] }
27
+ pyo3 = { version = "0.22", features = ["extension-module"] }
28
28
  threadpool = "1.8.1"
29
29
  openssl = { version = "0.10", features = ["vendored"] }
30
30
  walkdir = "2.5.0"
@@ -46,6 +46,7 @@ fast_image_resize = { version ="5.1.3", features=["image"]}
46
46
 
47
47
  [dev-dependencies]
48
48
  tempfile = "3.13.0"
49
+ wiremock = "0.6.0"
49
50
 
50
51
  [profile.release]
51
52
  opt-level = 3
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datago
3
- Version: 2025.10.2
3
+ Version: 2025.12.1
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -8,7 +8,7 @@ Classifier: Programming Language :: Python :: 3
8
8
  Classifier: License :: OSI Approved :: MIT License
9
9
  License-File: LICENSE
10
10
  Summary: A high performance dataloader for Python, written in Rust
11
- Author: Benjamin Lefaudeux
11
+ Author: Benjamin Lefaudeux, Roman Frigg
12
12
  Author-email: Photoroom <team@photoroom.com>
13
13
  Requires-Python: >=3.8
14
14
  Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
Binary file
@@ -2,7 +2,7 @@
2
2
  name = "datago"
3
3
  dynamic = ["version"]
4
4
  authors = [
5
- { name = "Benjamin Lefaudeux" },
5
+ { name = "Benjamin Lefaudeux, Roman Frigg" },
6
6
  { name = "Photoroom", email = "team@photoroom.com" }
7
7
  ]
8
8
  description = "A high performance dataloader for Python, written in Rust"
@@ -6,7 +6,7 @@ import typer
6
6
  from benchmark_defaults import IMAGE_CONFIG
7
7
  from datago import DatagoClient # type: ignore
8
8
  from PIL import Image
9
- from raw_types import raw_array_to_numpy, raw_array_to_pil_image
9
+ from raw_types import raw_array_to_numpy
10
10
  from tqdm import tqdm
11
11
 
12
12
 
@@ -58,21 +58,7 @@ def benchmark(
58
58
  for _ in tqdm(range(limit), dynamic_ncols=True):
59
59
  sample = client.get_sample()
60
60
  if sample.id:
61
- # Bring the masks and image to PIL
62
- if hasattr(sample, "image"):
63
- img = raw_array_to_pil_image(sample.image)
64
-
65
- if hasattr(sample, "masks"):
66
- for _, mask_buffer in sample.masks.items():
67
- mask = raw_array_to_pil_image(mask_buffer)
68
-
69
- if (
70
- hasattr(sample, "additional_images")
71
- and "masked_image" in sample.additional_images
72
- ):
73
- masked_image = raw_array_to_pil_image(
74
- sample.AdditionalImages["masked_image"]
75
- )
61
+ # Images are already PIL by default
76
62
 
77
63
  # Bring the latents to numpy
78
64
  if hasattr(sample, "latents"):
@@ -9,26 +9,34 @@ from tqdm import tqdm
9
9
 
10
10
 
11
11
  def benchmark(
12
- root_path: str = typer.Option(os.getenv("DATAGO_TEST_FILESYSTEM", ""), help="The source to test out"),
12
+ root_path: str = typer.Option(
13
+ os.getenv("DATAGO_TEST_FILESYSTEM", ""), help="The source to test out"
14
+ ),
13
15
  limit: int = typer.Option(2000, help="The number of samples to test on"),
14
- crop_and_resize: bool = typer.Option(False, help="Crop and resize the images on the fly"),
16
+ crop_and_resize: bool = typer.Option(
17
+ False, help="Crop and resize the images on the fly"
18
+ ),
15
19
  compare_torch: bool = typer.Option(True, help="Compare against torch dataloader"),
16
20
  num_workers: int = typer.Option(os.cpu_count(), help="Number of workers to use"),
17
21
  sweep: bool = typer.Option(False, help="Sweep over the number of workers"),
18
22
  ):
19
23
  if sweep:
20
- results = {}
21
- for num_workers in range(2, (os.cpu_count() or 2), 16):
22
- results[num_workers] = benchmark(root_path, limit, crop_and_resize, compare_torch, num_workers, False)
24
+ results_sweep = {}
25
+ for num_workers in range(2, (os.cpu_count() * 2 or 2), 2):
26
+ results_sweep[num_workers] = benchmark(
27
+ root_path, limit, crop_and_resize, compare_torch, num_workers, False
28
+ )
23
29
 
24
30
  # Save results to a json file
25
31
 
26
32
  with open("benchmark_results_filesystem.json", "w") as f:
27
- json.dump(results, f, indent=2)
33
+ json.dump(results_sweep, f, indent=2)
28
34
 
29
- return results
35
+ return results_sweep
30
36
 
31
- print(f"Running benchmark for {root_path} - {limit} samples - {num_workers} workers")
37
+ print(
38
+ f"Running benchmark for {root_path} - {limit} samples - {num_workers} workers"
39
+ )
32
40
 
33
41
  # This setting is not exposed in the config, but an env variable can be used instead
34
42
  os.environ["DATAGO_MAX_TASKS"] = str(num_workers)
@@ -59,6 +67,11 @@ def benchmark(
59
67
  for sample in tqdm(datago_dataset, desc="Datago", dynamic_ncols=True):
60
68
  assert sample["id"] != ""
61
69
  img = sample["image"]
70
+
71
+ if count < limit - 1:
72
+ del img
73
+ img = None # Help with memory pressure
74
+
62
75
  count += 1
63
76
 
64
77
  assert count == limit, f"Expected {limit} samples, got {count}"
@@ -80,7 +93,9 @@ def benchmark(
80
93
  transform = (
81
94
  transforms.Compose(
82
95
  [
83
- transforms.Resize((1024, 1024), interpolation=transforms.InterpolationMode.LANCZOS),
96
+ transforms.Resize(
97
+ (1024, 1024), interpolation=transforms.InterpolationMode.LANCZOS
98
+ ),
84
99
  ]
85
100
  )
86
101
  if crop_and_resize
@@ -88,7 +103,9 @@ def benchmark(
88
103
  )
89
104
 
90
105
  # Create the ImageFolder dataset
91
- dataset = datasets.ImageFolder(root=root_path, transform=transform, allow_empty=True)
106
+ dataset = datasets.ImageFolder(
107
+ root=root_path, transform=transform, allow_empty=True
108
+ )
92
109
 
93
110
  # Create a DataLoader to allow for multiple workers
94
111
  # Use available CPU count for num_workers
@@ -107,6 +124,8 @@ def benchmark(
107
124
  n_images += len(batch)
108
125
  if n_images > limit:
109
126
  break
127
+
128
+ del batch # Help with memory pressure, same as above
110
129
  fps = n_images / (time.time() - start)
111
130
  results["torch"] = {"fps": fps, "count": n_images}
112
131
  print(f"Torch - FPS {fps:.2f} - workers {num_workers}")
@@ -1,7 +1,7 @@
1
1
  from datago import DatagoClient, initialize_logging
2
2
  import json
3
3
  from typing import Dict, Any
4
- from raw_types import raw_array_to_pil_image, raw_array_to_numpy
4
+ from raw_types import raw_array_to_numpy
5
5
 
6
6
 
7
7
  class DatagoIterDataset:
@@ -29,11 +29,13 @@ class DatagoIterDataset:
29
29
  return json.loads(item)
30
30
 
31
31
  if isinstance(item, dict):
32
- # recurvisely convert the dictionary
32
+ # recursively convert the dictionary
33
33
  return {k: DatagoIterDataset.to_python_types(v, k) for k, v in item.items()}
34
34
 
35
35
  elif "image" in key:
36
- return raw_array_to_pil_image(item)
36
+ # The Rust-side returns PythonImagePayload objects that are callable
37
+ # Call them to get the actual PIL image
38
+ return item()
37
39
  elif "latent" in key:
38
40
  return raw_array_to_numpy(item)
39
41
 
@@ -0,0 +1,40 @@
1
+ from PIL import Image
2
+ from typing import Optional, Union
3
+ import numpy as np
4
+ from datago import ImagePayload
5
+
6
+
7
+ def raw_array_to_numpy(raw_array: ImagePayload) -> Optional[np.ndarray]:
8
+ if len(raw_array.data) == 0:
9
+ return None
10
+
11
+ # Generic numpy-serialized array
12
+ try:
13
+ return np.load(raw_array.data, allow_pickle=False)
14
+ except ValueError:
15
+ # Do not try to handle these, return None and we'll handle it in the caller
16
+ print("Could not deserialize numpy array")
17
+ return None
18
+
19
+
20
+ def decode_image_payload(payload: ImagePayload) -> Image.Image:
21
+ """
22
+ Decode an ImagePayload (encoded image) into a PIL Image.
23
+ This is the proper way to decode encoded images for API users.
24
+ """
25
+ import io
26
+
27
+ return Image.open(io.BytesIO(payload.data))
28
+
29
+
30
+ def get_image_mode(image_or_payload: Union[ImagePayload, Image.Image]) -> str:
31
+ """
32
+ Helper function to get the mode of an image, whether it's a PIL Image or ImagePayload.
33
+ For ImagePayload objects (encoded images), we need to decode them first.
34
+ """
35
+ if hasattr(image_or_payload, "mode"):
36
+ # It's a PIL Image
37
+ return image_or_payload.mode
38
+ else:
39
+ # It's an ImagePayload (encoded image), decode it first
40
+ return decode_image_payload(image_or_payload).mode
@@ -115,7 +115,13 @@ class TestDatagoClient:
115
115
  assert sample.source == "filesystem"
116
116
  assert sample.image.width > 0
117
117
  assert sample.image.height > 0
118
- assert len(sample.image.data) > 0
118
+
119
+ # Check the payload path
120
+ payload = sample.image.get_payload()
121
+ assert len(payload.data) > 0
122
+ assert payload.width == sample.image.width
123
+ assert payload.height == sample.image.height
124
+ assert payload.channels == 3
119
125
 
120
126
  def test_client_with_image_transformations(self):
121
127
  """Test client with image transformation configuration."""
@@ -149,7 +155,7 @@ class TestDatagoClient:
149
155
  assert sample is not None
150
156
  assert sample.image.width <= 64
151
157
  assert sample.image.height <= 64
152
- assert sample.image.channels == 3 # RGB8
158
+ assert sample.image.mode == "RGB"
153
159
 
154
160
  def test_client_with_image_encoding(self):
155
161
  """Test client with image encoding enabled."""
@@ -177,8 +183,7 @@ class TestDatagoClient:
177
183
  sample = client.get_sample()
178
184
 
179
185
  assert sample is not None
180
- assert sample.image.channels == -1 # Encoded images have channels = -1
181
- assert len(sample.image.data) > 0
186
+ assert sample.image.mode == "RGB"
182
187
 
183
188
  def test_random_sampling(self):
184
189
  """Test that random sampling produces different results."""
@@ -3,7 +3,9 @@ import pytest
3
3
  import os
4
4
  import json
5
5
 
6
- from raw_types import raw_array_to_pil_image, raw_array_to_numpy, get_image_mode, decode_image_payload
6
+ from raw_types import (
7
+ decode_image_payload,
8
+ )
7
9
  from dataset import DatagoIterDataset
8
10
 
9
11
 
@@ -62,9 +64,11 @@ def test_caption_and_image():
62
64
  assert img.height > 0
63
65
  assert img.width > 0
64
66
 
65
- assert img.height <= img.original_height
66
- assert img.width <= img.original_width
67
- assert img.channels == channels
67
+ payload = img.get_payload()
68
+ assert img.height <= payload.original_height
69
+ assert img.width <= payload.original_width
70
+ assert img.mode == "RGB" if channels == 3 else "L"
71
+ assert payload.channels == channels
68
72
 
69
73
  for i, sample in enumerate(dataset):
70
74
  assert sample.source != ""
@@ -80,14 +84,11 @@ def test_caption_and_image():
80
84
  check_image(sample.masks["segmentation_mask"], 1)
81
85
 
82
86
  # Check the image decoding
83
- assert raw_array_to_pil_image(sample.image).mode == "RGB", "Image should be RGB"
84
- assert (
85
- raw_array_to_pil_image(sample.additional_images["masked_image"]).mode
86
- == "RGB"
87
- ), "Image should be RGB"
88
- assert raw_array_to_pil_image(sample.masks["segmentation_mask"]).mode == "L", (
89
- "Mask should be L"
87
+ assert sample.image.mode == "RGB", "Image should be RGB"
88
+ assert sample.additional_images["masked_image"].mode == "RGB", (
89
+ "Image should be RGB"
90
90
  )
91
+ assert sample.masks["segmentation_mask"].mode == "L", "Mask should be L"
91
92
 
92
93
  if i > N_SAMPLES:
93
94
  break
@@ -146,36 +147,13 @@ def test_jpeg_compression():
146
147
 
147
148
  sample = next(iter(dataset))
148
149
 
149
- # When images are encoded, channels is set to -1 to signal encoded format
150
- assert sample.image.channels == -1, "Image should be encoded (channels == -1)"
151
- assert (
152
- sample.additional_images["masked_image"].channels == -1
153
- ), "Additional image should be encoded"
154
- assert (
155
- sample.masks["segmentation_mask"].channels == -1
156
- ), "Mask should be encoded"
157
-
158
- # Test that raw_array_to_pil_image returns ImagePayload for encoded images
159
- image_result = raw_array_to_pil_image(sample.image)
160
- assert not hasattr(image_result, 'mode'), "Should return ImagePayload, not PIL Image"
161
- assert hasattr(image_result, 'data'), "Should have data attribute"
162
- assert hasattr(image_result, 'channels'), "Should have channels attribute"
163
- assert image_result.channels == -1, "Should be encoded ImagePayload"
164
-
165
- # Test proper decoding using decode_image_payload
166
- decoded_image = decode_image_payload(image_result)
167
- assert hasattr(decoded_image, 'mode'), "Decoded image should be PIL Image"
168
- assert decoded_image.mode == "RGB", "Image should decode to RGB"
169
- assert decoded_image.size == (sample.image.width, sample.image.height), "Size should match"
170
-
171
- # Test additional images and masks
172
- additional_result = raw_array_to_pil_image(sample.additional_images["masked_image"])
173
- decoded_additional = decode_image_payload(additional_result)
174
- assert decoded_additional.mode == "RGB", "Additional image should decode to RGB"
175
-
176
- mask_result = raw_array_to_pil_image(sample.masks["segmentation_mask"])
177
- decoded_mask = decode_image_payload(mask_result)
178
- assert decoded_mask.mode == "L", "Mask should decode to L"
150
+ # Check that the image is properly accessible through PIL, but that it is encoded
151
+ assert sample.image.mode == "RGB", "Image should be RGB"
152
+ assert sample.additional_images["masked_image"].mode == "RGB", "Image should be RGB"
153
+ assert sample.masks["segmentation_mask"].mode == "L", "Mask should be L"
154
+
155
+ # Check that the image is encoded, as JPG PIL
156
+ # TODO: @blefaudeux
179
157
 
180
158
 
181
159
  def test_png_compression():
@@ -185,23 +163,11 @@ def test_png_compression():
185
163
  # Don't specify encode_format - should default to PNG
186
164
  dataset = DatagoIterDataset(client_config, return_python_types=False)
187
165
 
188
- sample = next(iter(dataset))
189
-
190
- # When images are encoded, channels is set to -1 to signal encoded format
191
- assert sample.image.channels == -1, "Image should be encoded (channels == -1)"
192
-
193
- # Test that raw_array_to_pil_image returns ImagePayload for encoded images
194
- image_result = raw_array_to_pil_image(sample.image)
195
- assert not hasattr(image_result, 'mode'), "Should return ImagePayload, not PIL Image"
196
- assert hasattr(image_result, 'data'), "Should have data attribute"
197
- assert hasattr(image_result, 'channels'), "Should have channels attribute"
198
- assert image_result.channels == -1, "Should be encoded ImagePayload"
166
+ _sample = next(iter(dataset))
199
167
 
200
- # Test proper decoding using decode_image_payload
201
- decoded_image = decode_image_payload(image_result)
202
- assert hasattr(decoded_image, 'mode'), "Decoded image should be PIL Image"
203
- assert decoded_image.mode == "RGB", "Image should decode to RGB"
204
- assert decoded_image.size == (sample.image.width, sample.image.height), "Size should match"
168
+ # Check that the image is encoded, as JPG PIL
169
+ # TODO: @blefaudeux
170
+ # same as above
205
171
 
206
172
 
207
173
  def test_original_image():
@@ -212,14 +178,9 @@ def test_original_image():
212
178
  dataset = DatagoIterDataset(client_config, return_python_types=False)
213
179
 
214
180
  sample = next(iter(dataset))
215
-
216
- assert raw_array_to_pil_image(sample.image).mode == "RGB", "Image should be RGB"
217
- assert (
218
- raw_array_to_pil_image(sample.additional_images["masked_image"]).mode == "RGB"
219
- ), "Image should be RGB"
220
- assert raw_array_to_pil_image(sample.masks["segmentation_mask"]).mode == "L", (
221
- "Mask should be L"
222
- )
181
+ payload = sample.image.get_payload()
182
+ assert payload.original_height == payload.height == sample.image.height
183
+ assert payload.original_width == payload.width == sample.image.width
223
184
 
224
185
 
225
186
  def test_duplicate_state():
@@ -90,11 +90,13 @@ class TestDatagoEdgeCases:
90
90
  sample = client.get_sample()
91
91
 
92
92
  assert sample is not None
93
- assert sample.image.original_width == 2000
94
- assert sample.image.original_height == 2000
93
+ image_payload = sample.image.get_payload()
94
+ assert image_payload is not None
95
+ assert image_payload.original_width == 2000
96
+ assert image_payload.original_height == 2000
95
97
  # Should be resized
96
- assert sample.image.width <= 512
97
- assert sample.image.height <= 512
98
+ assert image_payload.width <= 512
99
+ assert image_payload.height <= 512
98
100
 
99
101
  def test_very_small_images(self):
100
102
  """Test handling of very small images."""
@@ -54,21 +54,23 @@ def test_get_sample_filesystem(pre_encode_images: bool, rgb16: bool, rgba: bool)
54
54
 
55
55
  client = DatagoClient(json.dumps(client_config))
56
56
  count = 0
57
- for i in range(limit):
57
+ for _ in range(limit):
58
58
  data = client.get_sample()
59
59
  if not data:
60
60
  break
61
61
  count += 1
62
62
  assert data.id != ""
63
- assert data.image.width == 100
64
- assert data.image.height == 100
63
+
64
+ image_payload = data.image.get_payload()
65
+ assert image_payload.width == 100
66
+ assert image_payload.height == 100
65
67
 
66
68
  if rgb16:
67
- assert data.image.bit_depth == 8
69
+ assert image_payload.bit_depth == 8
68
70
 
69
71
  # Open the image in python scope and check properties
70
72
  if pre_encode_images:
71
- test_image = Image.open(BytesIO(data.image.data))
73
+ test_image = Image.open(BytesIO(bytes(image_payload.data)))
72
74
  assert test_image.width == 100
73
75
  assert test_image.height == 100
74
76
  assert test_image.mode == "RGB"