datago 2025.6.5.tar.gz → 2025.8.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {datago-2025.6.5 → datago-2025.8.1}/Cargo.lock +1 -1
  2. {datago-2025.6.5 → datago-2025.8.1}/Cargo.toml +2 -1
  3. {datago-2025.6.5 → datago-2025.8.1}/PKG-INFO +2 -1
  4. {datago-2025.6.5 → datago-2025.8.1}/pyproject.toml +8 -6
  5. {datago-2025.6.5 → datago-2025.8.1}/python/test_datago_edge_cases.py +3 -4
  6. {datago-2025.6.5 → datago-2025.8.1}/src/client.rs +73 -9
  7. {datago-2025.6.5 → datago-2025.8.1}/src/generator_files.rs +105 -64
  8. {datago-2025.6.5 → datago-2025.8.1}/src/generator_http.rs +8 -9
  9. {datago-2025.6.5 → datago-2025.8.1}/src/generator_wds.rs +20 -29
  10. {datago-2025.6.5 → datago-2025.8.1}/src/image_processing.rs +8 -16
  11. {datago-2025.6.5 → datago-2025.8.1}/src/main.rs +5 -8
  12. {datago-2025.6.5 → datago-2025.8.1}/src/structs.rs +1 -0
  13. {datago-2025.6.5 → datago-2025.8.1}/src/worker_files.rs +7 -8
  14. {datago-2025.6.5 → datago-2025.8.1}/src/worker_http.rs +12 -16
  15. {datago-2025.6.5 → datago-2025.8.1}/src/worker_wds.rs +7 -10
  16. {datago-2025.6.5 → datago-2025.8.1}/.github/workflows/ci-cd.yml +0 -0
  17. {datago-2025.6.5 → datago-2025.8.1}/.github/workflows/rust.yml +0 -0
  18. {datago-2025.6.5 → datago-2025.8.1}/.gitignore +0 -0
  19. {datago-2025.6.5 → datago-2025.8.1}/.pre-commit-config.yaml +0 -0
  20. {datago-2025.6.5 → datago-2025.8.1}/LICENSE +0 -0
  21. {datago-2025.6.5 → datago-2025.8.1}/README.md +0 -0
  22. {datago-2025.6.5 → datago-2025.8.1}/assets/447175851-2277afcb-8abf-4d17-b2db-dae27c6056d0.png +0 -0
  23. {datago-2025.6.5 → datago-2025.8.1}/python/benchmark_db.py +0 -0
  24. {datago-2025.6.5 → datago-2025.8.1}/python/benchmark_filesystem.py +0 -0
  25. {datago-2025.6.5 → datago-2025.8.1}/python/benchmark_webdataset.py +0 -0
  26. {datago-2025.6.5 → datago-2025.8.1}/python/dataset.py +0 -0
  27. {datago-2025.6.5 → datago-2025.8.1}/python/raw_types.py +0 -0
  28. {datago-2025.6.5 → datago-2025.8.1}/python/test_datago_client.py +0 -0
  29. {datago-2025.6.5 → datago-2025.8.1}/python/test_datago_db.py +0 -0
  30. {datago-2025.6.5 → datago-2025.8.1}/python/test_datago_filesystem.py +0 -0
  31. {datago-2025.6.5 → datago-2025.8.1}/requirements-tests.txt +0 -0
  32. {datago-2025.6.5 → datago-2025.8.1}/requirements.txt +0 -0
  33. {datago-2025.6.5 → datago-2025.8.1}/src/lib.rs +0 -0
{datago-2025.6.5 → datago-2025.8.1}/Cargo.lock
@@ -613,7 +613,7 @@ dependencies = [
 
 [[package]]
 name = "datago"
-version = "2025.6.5"
+version = "2025.8.1"
 dependencies = [
  "async-compression",
  "async-tar",
{datago-2025.6.5 → datago-2025.8.1}/Cargo.toml
@@ -1,7 +1,8 @@
 [package]
 name = "datago"
 edition = "2021"
-version = "2025.6.5"
+version = "2025.8.1"
+readme = "README.md"
 
 [lib]
 # exposed by pyo3
{datago-2025.6.5 → datago-2025.8.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: datago
-Version: 2025.6.5
+Version: 2025.8.1
 Classifier: Programming Language :: Rust
 Classifier: Programming Language :: Python :: Implementation :: CPython
 Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -8,6 +8,7 @@ Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: MIT License
 License-File: LICENSE
 Summary: A high performance dataloader for Python, written in Rust
+Author: Benjamin Lefaudeux
 Author-email: Photoroom <team@photoroom.com>
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
{datago-2025.6.5 → datago-2025.8.1}/pyproject.toml
@@ -1,17 +1,19 @@
 [project]
 name = "datago"
+dynamic = ["version"]
 authors = [
-  { name="Photoroom", email="team@photoroom.com" },
+  { name = "Benjamin Lefaudeux" },
+  { name = "Photoroom", email = "team@photoroom.com" }
 ]
 description = "A high performance dataloader for Python, written in Rust"
 readme = "README.md"
 requires-python = ">=3.8"
 classifiers = [
-  "Programming Language :: Rust",
-  "Programming Language :: Python :: Implementation :: CPython",
-  "Programming Language :: Python :: Implementation :: PyPy",
-  "Programming Language :: Python :: 3",
-  "License :: OSI Approved :: MIT License",
+    "Programming Language :: Rust",
+    "Programming Language :: Python :: Implementation :: CPython",
+    "Programming Language :: Python :: Implementation :: PyPy",
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
 ]
 dependencies = []
 
{datago-2025.6.5 → datago-2025.8.1}/python/test_datago_edge_cases.py
@@ -238,11 +238,10 @@ class TestDatagoEdgeCases:
             "samples_buffer_size": 10,
         }
 
+        # Should flag that the config is not correct
         client = DatagoClient(json.dumps(config))
-        _sample = client.get_sample()
-
-        # Should handle gracefully (might return None or work with adjusted parameters)
-        # The exact behavior depends on implementation
+        sample = client.get_sample()
+        assert sample is None
 
    def test_very_large_buffer_sizes(self):
        """Test with very large buffer sizes."""
{datago-2025.6.5 → datago-2025.8.1}/src/client.rs
@@ -26,14 +26,58 @@ pub struct DatagoClient {
 
     // Holds all the variables related to a running engine
     engine: Option<DatagoEngine>,
+
+    is_valid: bool,
+}
+
+fn check_config(str_config: &str) -> Option<DatagoClientConfig> {
+    match serde_json::from_str::<DatagoClientConfig>(str_config) {
+        Ok(config) => {
+            if config.samples_buffer_size == 0 {
+                error!("Samples buffer size must be greater than 0");
+                return None;
+            }
+
+            if config.limit == 0 {
+                error!("Limit must be greater than 0");
+                return None;
+            }
+
+            // Check that a distributed config is valid, and error out early if not
+            let world_size = config
+                .source_config
+                .get("world_size")
+                .and_then(|v| v.as_u64())
+                .unwrap_or(1) as usize;
+            let rank = config
+                .source_config
+                .get("rank")
+                .and_then(|v| v.as_u64())
+                .unwrap_or(0) as usize;
+            if world_size == 0 {
+                error!("World size must be greater than 0");
+                return None;
+            }
+
+            if rank >= world_size {
+                error!("Rank must be less than world size");
+                return None;
+            }
+            Some(config)
+        }
+        Err(e) => {
+            error!("Failed to parse config: {e}");
+            None
+        }
+    }
 }
 
 #[pymethods]
 impl DatagoClient {
     #[new]
     pub fn new(str_config: String) -> Self {
-        match serde_json::from_str::<DatagoClientConfig>(&str_config) {
-            Ok(config) => {
+        match check_config(&str_config) {
+            Some(config) => {
                 let mut image_transform: Option<ARAwareTransform> = None;
                 let mut encode_images = false;
                 let mut image_to_rgb8 = false;
@@ -45,8 +89,6 @@ impl DatagoClient {
                     image_to_rgb8 = image_config.image_to_rgb8;
                 }
 
-                assert!(config.limit > 0, "Limit must be greater than 0");
-
                 DatagoClient {
                     is_started: false,
                     source_type: config.source_type,
@@ -58,10 +100,24 @@ impl DatagoClient {
                     encode_images,
                     image_to_rgb8,
                     engine: None,
+                    is_valid: true,
                 }
             }
-            Err(e) => {
-                panic!("Failed to parse config: {}", e);
+            None => {
+                error!("Failed to parse config");
+                DatagoClient {
+                    is_started: false,
+                    source_type: SourceType::Invalid,
+                    source_config: serde_json::Value::Null,
+                    samples_buffer: 0,
+                    limit: 0,
+                    max_connections: 0,
+                    image_transform: None,
+                    encode_images: false,
+                    image_to_rgb8: false,
+                    engine: None,
+                    is_valid: false,
+                }
             }
         }
    }
@@ -87,12 +143,20 @@ impl DatagoClient {
                 warn!("WebDataset source type is new and experimental, use with caution!\nPlease report any issues you encounter to https://github.com/Photoroom/datago/issues.");
                 self.engine = Some(generator_wds::orchestrate(self));
             }
+            SourceType::Invalid => {
+                error!("Client ill-defined, probably a config error. Cannot start");
+                return;
+            }
         }
 
         self.is_started = true;
     }
 
     pub fn get_sample(&mut self) -> Option<Sample> {
+        if !self.is_valid {
+            return None;
+        }
+
         if !self.is_started {
             self.start();
         }
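For orientation, a minimal caller-side sketch (not part of the diff) of the new non-panicking path. The JSON keys shown are the ones read by check_config above; the rest of the config schema and its defaults are assumed.

    // Hypothetical usage sketch: a distributed config with rank >= world_size
    // no longer panics at construction; the client is marked invalid and
    // get_sample() simply returns None.
    let config = serde_json::json!({
        "source_config": { "rank": 2, "world_size": 2 },
        "limit": 10,
        "samples_buffer_size": 10
    });
    let mut client = DatagoClient::new(config.to_string());
    assert!(client.get_sample().is_none());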
@@ -120,7 +184,7 @@ impl DatagoClient {
                 }
             },
             Err(e) => {
-                warn!("Timeout waiting for sample, stopping the client. {}", e);
+                warn!("Timeout waiting for sample, stopping the client. {e}");
                 self.stop();
                 None
             }
@@ -514,7 +578,7 @@ mod tests {
         let mut config = get_test_config();
         let tag1 = "photo";
         let tag2 = "graphic";
-        config["source_config"]["tags_ne_all"] = format!("{},{}", tag1, tag2).into();
+        config["source_config"]["tags_ne_all"] = format!("{tag1},{tag2}").into();
         let mut client = DatagoClient::new(config.to_string());
 
         let sample = client.get_sample();
@@ -599,7 +663,7 @@ mod tests {
         config["source_config"]["sources_ne"] = "LAION_ART".into();
         config["limit"] = json!(limit);
 
-        debug!("{}", config);
+        debug!("{config}");
         let mut client = DatagoClient::new(config.to_string());
 
        for _ in 0..limit {
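Most of the remaining mechanical edits in this diff follow the pattern of the hunks above: positional format arguments are replaced by inline captured identifiers, a form Rust has supported since 1.58 (presumably prompted by clippy's uninlined_format_args lint). The rendered output is identical, as a quick sketch shows:

    // Both forms render the same string; only the call-site syntax changes.
    let e = "channel closed";
    assert_eq!(
        format!("Timeout waiting for sample, stopping the client. {}", e),
        format!("Timeout waiting for sample, stopping the client. {e}")
    );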
{datago-2025.6.5 → datago-2025.8.1}/src/generator_files.rs
@@ -5,8 +5,6 @@ use kanal::bounded;
 use log::{debug, info};
 use rand::seq::SliceRandom;
 use serde::{Deserialize, Serialize};
-use std::collections::hash_map::DefaultHasher;
-use std::hash::{Hash, Hasher};
 use std::thread;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -23,17 +21,24 @@ pub struct SourceFileConfig {
     pub world_size: usize,
 }
 
-// Hash function to be able to dispatch the samples to the correct rank
+fn get_data_slice_multirank(quorum: usize, rank: usize, world_size: usize) -> (usize, usize) {
+    assert!(rank < world_size, "Rank must be less than world size");
 
-// The seed ensures consistent hashing across different runs,
-// essentially acting as a deterministic salt
-const HASH_SEED: u64 = 0x51_73_b3_c3_7f_d9_2e_a1;
+    let chunk_size = quorum / world_size; // This floors by default
+    let remainder = quorum % world_size;
 
-fn hash<T: Hash>(t: &T) -> u64 {
-    let mut hasher = DefaultHasher::new();
-    HASH_SEED.hash(&mut hasher); // Add seed first
-    t.hash(&mut hasher); // Then hash the actual data
-    hasher.finish()
+    let start = if rank < remainder {
+        rank * (chunk_size + 1)
+    } else {
+        remainder * (chunk_size + 1) + (rank - remainder) * chunk_size
+    };
+
+    let end = if (rank + 1) <= remainder {
+        (rank + 1) * (chunk_size + 1)
+    } else {
+        remainder * (chunk_size + 1) + (rank + 1 - remainder) * chunk_size
+    };
+    (start, end)
 }
 
 fn enumerate_files(
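As a quick sanity check of the remainder handling (mirroring the new unit tests further down), a minimal sketch of how 13 entries are split across 3 ranks with the helper added above:

    // Ranks below the remainder (13 % 3 = 1) get one extra entry.
    for rank in 0..3 {
        let (start, end) = get_data_slice_multirank(13, rank, 3);
        println!("rank {rank} -> files[{start}..{end}]"); // 0..5, 5..9, 9..13
    }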
@@ -49,6 +54,7 @@ fn enumerate_files(
         .into_iter()
         .filter_map(|e| e.ok());
 
+    // We need to materialize the file list to be able to shuffle it
     let mut files_list: Vec<walkdir::DirEntry> = files
         .filter_map(|entry| {
             let path = entry.path();
@@ -65,13 +71,18 @@ fn enumerate_files(
         .collect();
 
     // If shuffle is set, shuffle the files
-    let files_iter = if source_config.random_sampling {
+    if source_config.random_sampling {
         let mut rng = rand::rng(); // Get a random number generator, thread local. We don´t seed, so typically won't be reproducible
-        files_list.shuffle(&mut rng);
-        files_list.into_iter()
-    } else {
-        files_list.into_iter()
-    };
+        files_list.shuffle(&mut rng); // This happens in place
+    }
+
+    // If world_size > 1, we need to split the files list into chunks and only process the chunk corresponding to the rank
+    if source_config.world_size > 1 {
+        let quorum = files_list.len();
+        let (start, end) =
+            get_data_slice_multirank(quorum, source_config.rank, source_config.world_size);
+        files_list = files_list[start..end].to_vec();
+    }
 
     // Iterate over the files and send the paths as they come
     let mut count = 0;
@@ -81,17 +92,8 @@ fn enumerate_files(
     let max_submitted_samples = (1.1 * (limit as f64)).ceil() as usize;
 
     // Build a page from the files iterator
-    for entry in files_iter {
-        let file_name = entry.path().to_str().unwrap().to_string();
-
-        // If world_size is not 0, we need to dispatch the samples to the correct rank
-        if source_config.world_size > 1 {
-            let hash = hash(&file_name);
-            let target_rank = (hash % source_config.world_size as u64) as usize;
-            if target_rank != source_config.rank {
-                continue;
-            }
-        }
+    for entry in files_list.iter() {
+        let file_name: String = entry.path().to_str().unwrap().to_string();
 
         if samples_metadata_tx
             .send(serde_json::Value::String(file_name))
@@ -110,10 +112,7 @@ fn enumerate_files(
     }
 
     // Either we don't have any more samples or we have reached the limit
-    debug!(
-        "ping_pages: total samples requested: {}. page samples served {}",
-        limit, count
-    );
+    debug!("ping_pages: total samples requested: {limit}. page samples served {count}");
 
     // Send an empty value to signal the end of the stream
     match samples_metadata_tx.send(serde_json::Value::Null) {
@@ -183,19 +182,55 @@ mod tests {
     use tempfile::TempDir;
 
     #[test]
-    fn test_hash_function() {
-        let str1 = "test_string1";
-        let str2 = "test_string2";
-        let str3 = "test_string1"; // Same as str1
-
-        let hash1 = hash(&str1);
-        let hash2 = hash(&str2);
-        let hash3 = hash(&str3);
-
-        // Same input should produce same hash
-        assert_eq!(hash1, hash3);
-        // Different inputs should likely produce different hashes
-        assert_ne!(hash1, hash2);
+    fn test_get_data_slice_multirank() {
+        // Test case 1: Equal distribution with no remainder
+        let (start, end) = get_data_slice_multirank(10, 0, 2);
+        assert_eq!(start, 0);
+        assert_eq!(end, 5);
+
+        let (start, end) = get_data_slice_multirank(10, 1, 2);
+        assert_eq!(start, 5);
+        assert_eq!(end, 10);
+
+        // Test case 2: Unequal distribution with remainder
+        let (start, end) = get_data_slice_multirank(11, 0, 2);
+        assert_eq!(start, 0);
+        assert_eq!(end, 6);
+
+        let (start, end) = get_data_slice_multirank(11, 1, 2);
+        assert_eq!(start, 6);
+        assert_eq!(end, 11);
+
+        // Test case 3: Multiple ranks with remainder
+        let (start, end) = get_data_slice_multirank(13, 0, 3);
+        assert_eq!(start, 0);
+        assert_eq!(end, 5);
+
+        let (start, end) = get_data_slice_multirank(13, 1, 3);
+        assert_eq!(start, 5);
+        assert_eq!(end, 9);
+
+        let (start, end) = get_data_slice_multirank(13, 2, 3);
+        assert_eq!(start, 9);
+        assert_eq!(end, 13);
+
+        // Test case 4: Single rank
+        let (start, end) = get_data_slice_multirank(10, 0, 1);
+        assert_eq!(start, 0);
+        assert_eq!(end, 10);
+
+        // Test case 5: Edge case with zero quorum
+        let (start, end) = get_data_slice_multirank(0, 0, 1);
+        assert_eq!(start, 0);
+        assert_eq!(end, 0);
+
+        // Test case 6: Edge case with zero world size (should panic or handle gracefully)
+        // Note: This test assumes the function should panic or handle the zero division gracefully
+        // You may need to adjust the test based on your actual error handling
+        let result = std::panic::catch_unwind(|| {
+            get_data_slice_multirank(10, 0, 0);
+        });
+        assert!(result.is_err());
     }
 
     #[test]
@@ -227,14 +262,19 @@ mod tests {
         assert_eq!(config.world_size, 4);
     }
 
-    fn create_test_images(dir: &Path) -> Vec<String> {
+    fn create_test_images(dir: &Path, min_num_files: usize) -> Vec<String> {
         let extensions = ["jpg", "png", "bmp", "gif", "JPEG"];
         let mut files = Vec::new();
-        for (i, ext) in extensions.iter().enumerate() {
-            let filename = format!("test_image_{}.{}", i, ext);
-            let filepath = dir.join(&filename);
-            fs::write(&filepath, "fake_image_data").unwrap();
-            files.push(filepath.to_string_lossy().to_string());
+        let mut n_files = 0;
+
+        while n_files < min_num_files {
+            for (i, ext) in extensions.iter().enumerate() {
+                let filename = format!("test_image_{n_files}_{i}.{ext}");
+                let filepath = dir.join(&filename);
+                fs::write(&filepath, "fake_image_data").unwrap();
+                files.push(filepath.to_string_lossy().to_string());
+                n_files += 1;
+            }
         }
 
         // Create a non-image file that should be ignored
@@ -248,8 +288,8 @@ mod tests {
     fn test_enumerate_files_basic() {
         let temp_dir = TempDir::new().unwrap();
         let temp_path = temp_dir.path();
-
-        let created_files = create_test_images(temp_path);
+        let limit = 10;
+        let created_files = create_test_images(temp_path, limit);
 
         let (tx, rx) = kanal::bounded(100);
         let config = SourceFileConfig {
@@ -284,8 +324,8 @@ mod tests {
     fn test_enumerate_files_with_limit() {
         let temp_dir = TempDir::new().unwrap();
         let temp_path = temp_dir.path();
-
-        create_test_images(temp_path);
+        let limit = 10;
+        create_test_images(temp_path, limit);
 
         let (tx, rx) = kanal::bounded(100);
         let config = SourceFileConfig {
@@ -295,7 +335,6 @@ mod tests {
             world_size: 1,
         };
 
-        let limit = 2;
         std::thread::spawn(move || {
             enumerate_files(tx, config, limit);
         });
@@ -318,8 +357,9 @@ mod tests {
     fn test_enumerate_files_with_world_size() {
         let temp_dir = TempDir::new().unwrap();
         let temp_path = temp_dir.path();
+        let limit = 10;
 
-        create_test_images(temp_path);
+        create_test_images(temp_path, limit * 2); // We'll check that each rank has "limit" files
 
         // Test rank 0 of world_size 2
         let (tx1, rx1) = kanal::bounded(100);
@@ -340,11 +380,11 @@ mod tests {
         };
 
         std::thread::spawn(move || {
-            enumerate_files(tx1, config1, 10);
+            enumerate_files(tx1, config1, limit);
         });
 
         std::thread::spawn(move || {
-            enumerate_files(tx2, config2, 10);
+            enumerate_files(tx2, config2, limit);
         });
 
         let mut files_rank0 = Vec::new();
@@ -373,16 +413,17 @@ mod tests {
         }
 
         // Both ranks should have some files
-        assert!(!files_rank0.is_empty());
-        assert!(!files_rank1.is_empty());
+        assert!(files_rank0.len() >= limit);
+        assert!(files_rank1.len() >= limit);
     }
 
     #[test]
     fn test_enumerate_files_random_sampling() {
         let temp_dir = TempDir::new().unwrap();
         let temp_path = temp_dir.path();
+        let limit = 10;
 
-        create_test_images(temp_path);
+        create_test_images(temp_path, limit);
 
         // Run twice with random sampling to see if order changes
         let (tx1, rx1) = kanal::bounded(100);
@@ -402,11 +443,11 @@ mod tests {
         };
 
         std::thread::spawn(move || {
-            enumerate_files(tx1, config1, 10);
+            enumerate_files(tx1, config1, limit);
         });
 
         std::thread::spawn(move || {
-            enumerate_files(tx2, config2, 10);
+            enumerate_files(tx2, config2, limit);
         });
 
        let mut files1 = Vec::new();
{datago-2025.6.5 → datago-2025.8.1}/src/generator_http.rs
@@ -130,9 +130,9 @@ struct DbRequest {
 impl DbRequest {
     async fn get_http_request(&self, api_url: &str, api_key: &str) -> reqwest::Request {
         let mut url = if self.random_sampling {
-            Url::parse(&format!("{}images/random/", api_url))
+            Url::parse(&format!("{api_url}images/random/"))
         } else {
-            Url::parse(&format!("{}images/", api_url))
+            Url::parse(&format!("{api_url}images/"))
         }
         .unwrap(); // Cannot survive without the URL, that's a panic
 
@@ -183,7 +183,7 @@ impl DbRequest {
         let mut req = reqwest::Request::new(reqwest::Method::GET, url);
         req.headers_mut().append(
             AUTHORIZATION,
-            HeaderValue::from_str(&format!("Token {}", api_key))
+            HeaderValue::from_str(&format!("Token {api_key}"))
                 .expect("Couldn't parse the provided API key"),
         );
 
@@ -276,7 +276,7 @@ fn build_request(source_config: SourceDBConfig) -> DbRequest {
         "Rank cannot be greater than or equal to world size"
     );
 
-    debug!("Fields: {}", fields);
+    debug!("Fields: {fields}");
     debug!(
         "Rank: {}, World size: {}",
         source_config.rank, source_config.world_size
@@ -363,7 +363,7 @@ async fn async_pull_and_dispatch_pages(
     let mut headers = HeaderMap::new();
     headers.insert(
         AUTHORIZATION,
-        HeaderValue::from_str(&format!("Token {}", api_key)).unwrap(),
+        HeaderValue::from_str(&format!("Token {api_key}")).unwrap(),
     );
 
     let db_request = build_request(source_config.clone());
@@ -380,7 +380,7 @@ async fn async_pull_and_dispatch_pages(
             if let Some(next) = response_json.get("next") {
                 next_url = next;
             } else {
-                debug!("No next URL in the response {:?}", response_json);
+                debug!("No next URL in the response {response_json:?}");
             }
         }
         Err(e) => {
@@ -420,7 +420,7 @@ async fn async_pull_and_dispatch_pages(
                 }
             }
             None => {
-                debug!("No results in the response: {:?}", response_json);
+                debug!("No results in the response: {response_json:?}");
             }
         }
 
@@ -455,8 +455,7 @@ async fn async_pull_and_dispatch_pages(
 
     // Either we don't have any more samples or we have reached the limit
     debug!(
-        "pull_and_dispatch_pages: total samples requested: {}. page samples served {}",
-        limit, count
+        "pull_and_dispatch_pages: total samples requested: {limit}. page samples served {count}"
     );
 
    // Send an empty value to signal the end of the stream
{datago-2025.6.5 → datago-2025.8.1}/src/generator_wds.rs
@@ -95,9 +95,7 @@ async fn pull_tarballs(
     // Convert the byte stream to an AsyncRead
     let byte_stream = response.bytes_stream();
     let stream_reader =
-        StreamReader::new(byte_stream.map(|res_bytes| {
-            res_bytes.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
-        }));
+        StreamReader::new(byte_stream.map(|res_bytes| res_bytes.map_err(std::io::Error::other)));
 
     // Wrap in BufReader for the async Tar reader
    let buf_reader = BufReader::new(stream_reader);
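The std::io::Error::other shorthand used here (and in the similar clean-ups in image_processing.rs, worker_files.rs and worker_http.rs below) is the standard-library helper stabilized in Rust 1.74; it is equivalent to the longer ErrorKind::Other form it replaces, as a small sketch shows:

    // Equivalent constructions, before and after this diff:
    let verbose = std::io::Error::new(std::io::ErrorKind::Other, "Failed to fetch image bytes");
    let concise = std::io::Error::other("Failed to fetch image bytes");
    assert_eq!(verbose.kind(), concise.kind());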
@@ -107,18 +105,18 @@ async fn pull_tarballs(
 
     let mut entries = archive
         .entries()
-        .map_err(|e| format!("Failed to fetch TarballSample: {}", e))?; // This returns a stream
+        .map_err(|e| format!("Failed to fetch TarballSample: {e}"))?; // This returns a stream
 
     let mut current_sample_key: Option<String> = None;
     let mut current_files_for_sample = TarballSample::new(url.to_string());
 
     while let Some(entry_result) = entries.next().await {
         let mut entry =
-            entry_result.map_err(|e| format!("Failed to read TarballSample entry: {}", e))?;
+            entry_result.map_err(|e| format!("Failed to read TarballSample entry: {e}"))?;
 
         let header_path = entry
             .path()
-            .map_err(|e| format!("Error considering TarballSample content {}", e))?
+            .map_err(|e| format!("Error considering TarballSample content {e}"))?
             .into_owned();
         let filename = header_path.to_string_lossy().into_owned();
 
@@ -172,7 +170,7 @@ async fn pull_tarballs(
         entry
             .read_to_end(&mut buffer)
             .await
-            .map_err(|e| format!("Failed to read TarballSample {}", e))?; // Read the content of the current file
+            .map_err(|e| format!("Failed to read TarballSample {e}"))?; // Read the content of the current file
 
         current_files_for_sample.add(BinaryFile { filename, buffer });
         debug!(
@@ -191,10 +189,7 @@ async fn pull_tarballs(
         return Err("Channel closed".into());
     }
 
-    debug!(
-        "dispatch_shards (streaming): finished processing TarballSample {}",
-        url
-    );
+    debug!("dispatch_shards (streaming): finished processing TarballSample {url}");
     Ok(())
 }
 
@@ -221,10 +216,7 @@ async fn pull_tarballs_task(
             }
             Err(e) => {
                 attempt += 1;
-                debug!(
-                    "Error pulling TarballSample: {}. Attempt {}/{}",
-                    e, attempt, retries
-                );
+                debug!("Error pulling TarballSample: {e}. Attempt {attempt}/{retries}");
                 if samples_metadata_tx.is_closed() {
                     debug!(
                         "dispatch_shards: samples_metadata_tx channel closed, stopping retries."
@@ -235,8 +227,7 @@ async fn pull_tarballs_task(
         }
     }
     Err(format!(
-        "Failed to pull TarballSample after {} attempts",
-        retries
+        "Failed to pull TarballSample after {retries} attempts"
     ))
 }
 
@@ -266,19 +257,19 @@ async fn get_url_list(
     // Given the url, list all the available webdataset files
     let request = reqwest::Request::new(
         reqwest::Method::GET,
-        Url::parse(&config.url).map_err(|e| format!("Failed parsing url: {}", e))?,
+        Url::parse(&config.url).map_err(|e| format!("Failed parsing url: {e}"))?,
     );
 
     let response = shared_client
         .client
         .execute(request)
         .await
-        .map_err(|e| format!("Failed parsing reply: {}", e))?;
+        .map_err(|e| format!("Failed parsing reply: {e}"))?;
 
     let response_text = response
         .text()
         .await
-        .map_err(|e| format!("Failed parsing reply: {}", e))?;
+        .map_err(|e| format!("Failed parsing reply: {e}"))?;
     let response_json: serde_json::Value =
         serde_json::from_str(&response_text).unwrap_or(serde_json::Value::Null);
 
@@ -336,7 +327,7 @@ async fn tasks_from_shards(
             }
             Err(e) => {
                 // Logging as debug, could be that channels are closed
-                debug!("dispatch_shards: task returned error: {:?}", e);
+                debug!("dispatch_shards: task returned error: {e:?}");
                 join_error = Some(e);
                 break;
             }
@@ -361,7 +352,7 @@ async fn tasks_from_shards(
                 count += 1;
             }
             Err(e) => {
-                debug!("dispatch_shards: task returned error: {:?}", e);
+                debug!("dispatch_shards: task returned error: {e}");
                 // Note that we only keep the first error, which is probably the most relevant
                 if join_error.is_none() {
                     join_error = Some(e);
@@ -372,7 +363,7 @@ async fn tasks_from_shards(
 
     if join_error.is_some() {
         // If we had an error, we log it and return an error
-        warn!("dispatch_shards: one of the tasks failed: {:?}", join_error);
+        warn!("dispatch_shards: one of the tasks failed: {join_error:?}");
         return Err(join_error.unwrap().to_string());
     }
 
@@ -380,7 +371,7 @@ async fn tasks_from_shards(
     if count == 0 {
         warn!("No items found in the response");
     }
-    debug!("Served {} items from the bucket", count);
+    debug!("Served {count} items from the bucket");
 
     // Send an empty value to signal the end of the stream
     if samples_metadata_tx
@@ -393,7 +384,7 @@ async fn tasks_from_shards(
             Ok(response_json)
         }
         Err(e) => {
-            warn!("Failed to get URL list: {}", e);
+            warn!("Failed to get URL list: {e}");
            Err(e) // Return a JoinError with the error message
        }
    }
@@ -421,7 +412,7 @@ fn query_shards_and_dispatch(
             debug!("query_shards_and_dispatch: finished processing all shards");
         }
         Err(e) => {
-            debug!("query_shards_and_dispatch: ended with : {:?}", e);
+            debug!("query_shards_and_dispatch: ended with : {e:?}");
         }
     }
     });
@@ -542,7 +533,7 @@ mod tests {
             }
         }
 
-        debug!("Received {} items", count);
+        debug!("Received {count} items");
         let _ = samples_meta_rx.close();
         feeder.join().expect("Feeder thread panicked");
 
@@ -593,7 +584,7 @@ mod tests {
                 break;
             }
         }
-        info!("Received {} items", count);
+        info!("Received {count} items");
        assert!(count >= limit, "Not enough items found in the bucket");
        client.stop();
 
@@ -650,7 +641,7 @@ mod tests {
                 break;
             }
         }
-        info!("Received {} items", count);
+        info!("Received {count} items");
        client.stop();
 
        samples
{datago-2025.6.5 → datago-2025.8.1}/src/image_processing.rs
@@ -59,10 +59,7 @@ impl ImageTransformConfig {
             self.max_aspect_ratio,
         );
 
-        debug!(
-            "Cropping and resizing images. Target image sizes:\n{:?}\n",
-            target_image_sizes
-        );
+        debug!("Cropping and resizing images. Target image sizes:\n{target_image_sizes:?}\n");
 
         let mut aspect_ratio_to_size = std::collections::HashMap::new();
         for img_size in &target_image_sizes {
@@ -325,12 +322,7 @@ pub async fn image_to_payload(
             image.height(),
             image.color().into(),
         )
-        .map_err(|e| {
-            image::ImageError::IoError(std::io::Error::new(
-                std::io::ErrorKind::Other,
-                e.to_string(),
-            ))
-        })?;
+        .map_err(std::io::Error::other)?;
 
         channels = -1; // Signal the fact that the image is encoded
     } else {
@@ -687,8 +679,8 @@ mod tests {
 
         // Fill with some test data
         let buffer = img.buffer_mut();
-        for i in 0..buffer.len() {
-            buffer[i] = (i % 256) as u8;
+        for (i, item) in buffer.iter_mut().enumerate() {
+            *item = (i % 256) as u8;
         }
 
         let dyn_img = image_to_dyn_image(&img);
@@ -704,8 +696,8 @@ mod tests {
         let mut img = Image::new(width, height, fr::PixelType::U8x4);
 
         let buffer = img.buffer_mut();
-        for i in 0..buffer.len() {
-            buffer[i] = ((i * 63) % 256) as u8;
+        for (i, item) in buffer.iter_mut().enumerate() {
+            *item = ((i * 63) % 256) as u8;
         }
 
         let dyn_img = image_to_dyn_image(&img);
@@ -721,8 +713,8 @@ mod tests {
         let mut img = Image::new(width, height, fr::PixelType::U8);
 
         let buffer = img.buffer_mut();
-        for i in 0..buffer.len() {
-            buffer[i] = (i % 256) as u8;
+        for (i, item) in buffer.iter_mut().enumerate() {
+            *item = (i % 256) as u8;
        }
 
        let dyn_img = image_to_dyn_image(&img);
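The test hunks above are the clippy-style rewrite of an indexed loop into iter_mut().enumerate(); both fill the buffer identically, as this small self-contained sketch (not from the diff) illustrates:

    // Equivalent buffer fills, before and after the rewrite:
    let mut a = vec![0u8; 8];
    for i in 0..a.len() {
        a[i] = (i % 256) as u8;
    }
    let mut b = vec![0u8; 8];
    for (i, item) in b.iter_mut().enumerate() {
        *item = (i % 256) as u8;
    }
    assert_eq!(a, b);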
{datago-2025.6.5 → datago-2025.8.1}/src/main.rs
@@ -119,7 +119,7 @@ fn main() {
         "samples_buffer_size": samples_buffer_size
     });
 
-    info!("{}", config);
+    info!("{config}");
 
     let mut client = client::DatagoClient::new(config.to_string());
 
@@ -141,7 +141,7 @@ fn main() {
         }
         if save_samples {
             let img = image::load_from_memory(&sample.image.data).unwrap();
-            let filename = format!("sample_{:?}.jpg", num_samples_received);
+            let filename = format!("sample_{num_samples_received:?}.jpg");
             img.save(filename).unwrap();
         }
         num_samples_received += 1;
@@ -160,18 +160,15 @@ fn main() {
         }
     }
     client.stop();
-    info!(
-        "All samples processed. Got {:?} samples\n",
-        num_samples_received
-    );
+    info!("All samples processed. Got {num_samples_received:?} samples\n");
 
     // Report the per-bucket occupancy, good sanity check
     if crop_and_resize {
         let mut size_buckets_str = String::from("Size buckets:\n");
         for (size, count) in size_buckets.iter() {
-            size_buckets_str.push_str(&format!("{}: {}\n", size, count));
+            size_buckets_str.push_str(&format!("{size}: {count}\n"));
         }
-        info!("{}", size_buckets_str);
+        info!("{size_buckets_str}");
     }
 
    let elapsed_secs = start_time.elapsed().as_secs_f64();
{datago-2025.6.5 → datago-2025.8.1}/src/structs.rs
@@ -13,6 +13,7 @@ pub enum SourceType {
     Db,
     File,
     WebDataset,
+    Invalid,
 }
 
 fn default_source_type() -> SourceType {
{datago-2025.6.5 → datago-2025.8.1}/src/worker_files.rs
@@ -6,9 +6,8 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 async fn image_from_path(path: &str) -> Result<image::DynamicImage, image::ImageError> {
-    let bytes = std::fs::read(path).map_err(|e| {
-        image::ImageError::IoError(std::io::Error::new(std::io::ErrorKind::Other, e))
-    })?;
+    let bytes =
+        std::fs::read(path).map_err(|e| image::ImageError::IoError(std::io::Error::other(e)))?;
 
     image::load_from_memory(&bytes)
 }
@@ -70,7 +69,7 @@ async fn pull_sample(
             Ok(())
         }
         Err(e) => {
-            error!("Failed to load image from path {} {}", sample_json, e);
+            error!("Failed to load image from path {sample_json} {e}");
             Err(())
         }
     }
@@ -125,7 +124,7 @@ async fn async_pull_samples(
             debug!("file_worker: task failed or was cancelled");
         }
     });
-    debug!("file_worker: total samples sent: {}\n", count);
+    debug!("file_worker: total samples sent: {count}\n");
 
     // Signal the end of the stream
     if samples_tx.send(None).is_ok() {};
@@ -349,7 +348,7 @@ mod tests {
         // Create multiple test images
         let mut image_paths = Vec::new();
         for i in 0..3 {
-            let image_path = temp_dir.path().join(format!("test_{}.png", i));
+            let image_path = temp_dir.path().join(format!("test_{i}.png"));
             create_test_image(&image_path);
             image_paths.push(image_path.to_str().unwrap().to_string());
         }
@@ -391,7 +390,7 @@ mod tests {
 
         // Create more images than the limit
         for i in 0..10 {
-            let image_path = temp_dir.path().join(format!("test_{}.png", i));
+            let image_path = temp_dir.path().join(format!("test_{i}.png"));
             create_test_image(&image_path);
         }
 
@@ -400,7 +399,7 @@ mod tests {
 
         // Send more paths than the limit
         for i in 0..10 {
-            let path = temp_dir.path().join(format!("test_{}.png", i));
+            let path = temp_dir.path().join(format!("test_{i}.png"));
             metadata_tx
                 .send(serde_json::Value::String(
                    path.to_str().unwrap().to_string(),
{datago-2025.6.5 → datago-2025.8.1}/src/worker_http.rs
@@ -64,14 +64,13 @@ async fn image_from_url(
             match image::load_from_memory(&bytes) {
                 Ok(image) => return Ok(image),
                 Err(e) => {
-                    warn!("Failed to decode image from URL: {}. Retrying", url);
-                    warn!("Error: {:?}", e);
+                    warn!("Failed to decode image from URL: {url}. Retrying");
+                    warn!("Error: {e:?}");
                 }
             }
         }
     }
-    Err(image::ImageError::IoError(std::io::Error::new(
-        std::io::ErrorKind::Other,
+    Err(image::ImageError::IoError(std::io::Error::other(
         "Failed to fetch image bytes",
     )))
 }
@@ -88,14 +87,11 @@ async fn payload_from_url(
                 return Ok(bytes);
             }
             None => {
-                warn!("Failed to get bytes from URL: {}. Retrying", url);
+                warn!("Failed to get bytes from URL: {url}. Retrying");
             }
         }
     }
-    Err(std::io::Error::new(
-        std::io::ErrorKind::Other,
-        "Failed to fetch bytes buffer",
-    ))
+    Err(std::io::Error::other("Failed to fetch bytes buffer"))
 }
 
 async fn image_payload_from_url(
@@ -157,8 +153,8 @@ async fn pull_sample(
             Some(payload)
         }
         Err(e) => {
-            error!("Failed to get image from URL: {}\n {:?}", image_url, e);
-            error!("Error: {:?}", e);
+            error!("Failed to get image from URL: {image_url}\n {e:?}");
+            error!("Error: {e:?}");
             return Err(());
         }
     };
@@ -282,7 +278,7 @@ async fn async_pull_samples(
     // We use async-await here, to better use IO stalls
     // We'll keep a pool of N async tasks in parallel
     let max_tasks = min(num_cpus::get(), limit);
-    debug!("Using {} tasks in the async threadpool", max_tasks);
+    debug!("Using {max_tasks} tasks in the async threadpool");
     let mut tasks = tokio::task::JoinSet::new();
     let mut count = 0;
     let shareable_channel_tx: Arc<kanal::Sender<Option<Sample>>> = Arc::new(samples_tx);
@@ -314,7 +310,7 @@ async fn async_pull_samples(
             }
             Some(Err(e)) => {
                 // Task failed, log the error
-                error!("file_worker: task failed with error: {:?}", e);
+                error!("file_worker: task failed with error: {e}");
                 join_error = Some(e);
                 break;
             }
@@ -337,7 +333,7 @@ async fn async_pull_samples(
                 count += 1;
             }
             Err(e) => {
-                error!("dispatch_shards: task failed with error: {:?}", e);
+                error!("dispatch_shards: task failed with error: {e}");
                 if join_error.is_none() {
                     join_error = Some(e);
                 }
@@ -345,7 +341,7 @@ async fn async_pull_samples(
         }
     }
 
-    debug!("http_worker: total samples sent: {}\n", count);
+    debug!("http_worker: total samples sent: {count}\n");
 
     // Signal the end of the stream
     let _ = shareable_channel_tx.send(None); // Channel could have been closed by a .stop() call
@@ -387,7 +383,7 @@ pub fn pull_samples(
             debug!("http_worker: all samples pulled successfully");
         }
         Err(e) => {
-            error!("http_worker: error pulling samples: {:?}", e);
+            error!("http_worker: error pulling samples: {e}");
        }
    }
    });
{datago-2025.6.5 → datago-2025.8.1}/src/worker_wds.rs
@@ -102,7 +102,7 @@ async fn process_sample(
                 debug!("wds_worker: unpacked {}", item.filename);
             }
             Err(e) => {
-                debug!("wds_worker: error loading image: {}", e);
+                debug!("wds_worker: error loading image: {e}");
                 continue;
             }
         }
@@ -142,7 +142,7 @@ async fn async_deserialize_samples(
     // We use async-await here, to better use IO stalls
     // We'll keep a pool of N async tasks in parallel
     let max_tasks = min(num_cpus::get(), limit);
-    info!("Using {} tasks in the async threadpool", max_tasks);
+    info!("Using {max_tasks} tasks in the async threadpool");
     let mut tasks = tokio::task::JoinSet::new();
     let mut count = 0;
     let shareable_channel_tx: Arc<kanal::Sender<Option<Sample>>> = Arc::new(samples_tx);
@@ -172,7 +172,7 @@ async fn async_deserialize_samples(
         match result {
             Ok(_) => count += 1,
             Err(e) => {
-                join_error = Some(format!("Task failed: {}", e));
+                join_error = Some(format!("Task failed: {e}"));
                 break;
             }
         }
@@ -192,7 +192,7 @@ async fn async_deserialize_samples(
                 count += 1;
             }
             Err(e) => {
-                error!("dispatch_shards: task failed with error: {:?}", e);
+                error!("dispatch_shards: task failed with error: {e}");
                 if join_error.is_none() {
                     join_error = Some(e.to_string());
                 }
@@ -200,16 +200,13 @@ async fn async_deserialize_samples(
         }
     }
 
-    info!("wds_worker: total samples sent: {}\n", count);
+    info!("wds_worker: total samples sent: {count}\n");
 
     // Signal the end of the stream
     let _ = shareable_channel_tx.send(None); // Channel could have been closed by a .stop() call
 
     if let Some(error) = join_error {
-        error!(
-            "wds_worker: encountered an error while processing samples: {}",
-            error
-        );
+        error!("wds_worker: encountered an error while processing samples: {error}");
         return Err(error);
     }
     Ok(())
@@ -242,7 +239,7 @@ pub fn deserialize_samples(
         .await
     {
         Ok(_) => debug!("wds_worker: all samples processed successfully"),
-        Err(e) => error!("wds_worker: error processing samples : {:?}", e),
+        Err(e) => error!("wds_worker: error processing samples : {e}"),
    }
    });
 }