datago 2025.3.12__tar.gz → 2025.4.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {datago-2025.3.12 → datago-2025.4.3}/Cargo.lock +25 -19
  2. {datago-2025.3.12 → datago-2025.4.3}/Cargo.toml +3 -3
  3. {datago-2025.3.12 → datago-2025.4.3}/PKG-INFO +18 -1
  4. {datago-2025.3.12 → datago-2025.4.3}/README.md +17 -0
  5. {datago-2025.3.12 → datago-2025.4.3}/src/client.rs +4 -1
  6. {datago-2025.3.12 → datago-2025.4.3}/src/generator_http.rs +71 -1
  7. {datago-2025.3.12 → datago-2025.4.3}/src/main.rs +2 -2
  8. {datago-2025.3.12 → datago-2025.4.3}/src/structs.rs +5 -4
  9. {datago-2025.3.12 → datago-2025.4.3}/src/worker_files.rs +3 -2
  10. {datago-2025.3.12 → datago-2025.4.3}/src/worker_http.rs +31 -11
  11. {datago-2025.3.12 → datago-2025.4.3}/tests/client_test.rs +114 -0
  12. {datago-2025.3.12 → datago-2025.4.3}/.github/workflows/ci-cd.yml +0 -0
  13. {datago-2025.3.12 → datago-2025.4.3}/.github/workflows/rust.yml +0 -0
  14. {datago-2025.3.12 → datago-2025.4.3}/.gitignore +0 -0
  15. {datago-2025.3.12 → datago-2025.4.3}/.pre-commit-config.yaml +0 -0
  16. {datago-2025.3.12 → datago-2025.4.3}/LICENSE +0 -0
  17. {datago-2025.3.12 → datago-2025.4.3}/pyproject.toml +0 -0
  18. {datago-2025.3.12 → datago-2025.4.3}/python/benchmark_db.py +0 -0
  19. {datago-2025.3.12 → datago-2025.4.3}/python/benchmark_filesystem.py +0 -0
  20. {datago-2025.3.12 → datago-2025.4.3}/python/dataset.py +0 -0
  21. {datago-2025.3.12 → datago-2025.4.3}/python/raw_types.py +0 -0
  22. {datago-2025.3.12 → datago-2025.4.3}/python/test_datago_db.py +0 -0
  23. {datago-2025.3.12 → datago-2025.4.3}/python/test_datago_filesystem.py +0 -0
  24. {datago-2025.3.12 → datago-2025.4.3}/requirements-tests.txt +0 -0
  25. {datago-2025.3.12 → datago-2025.4.3}/requirements.txt +0 -0
  26. {datago-2025.3.12 → datago-2025.4.3}/src/generator_files.rs +0 -0
  27. {datago-2025.3.12 → datago-2025.4.3}/src/image_processing.rs +0 -0
  28. {datago-2025.3.12 → datago-2025.4.3}/src/lib.rs +0 -0
@@ -247,7 +247,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
247
247
  checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02"
248
248
  dependencies = [
249
249
  "smallvec",
250
- "target-lexicon",
250
+ "target-lexicon 0.12.16",
251
251
  ]
252
252
 
253
253
  [[package]]
@@ -387,7 +387,7 @@ dependencies = [
387
387
 
388
388
  [[package]]
389
389
  name = "datago"
390
- version = "2025.3.12"
390
+ version = "2025.4.3"
391
391
  dependencies = [
392
392
  "clap",
393
393
  "image",
@@ -1381,9 +1381,9 @@ checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
1381
1381
 
1382
1382
  [[package]]
1383
1383
  name = "openssl"
1384
- version = "0.10.71"
1384
+ version = "0.10.72"
1385
1385
  source = "registry+https://github.com/rust-lang/crates.io-index"
1386
- checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd"
1386
+ checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da"
1387
1387
  dependencies = [
1388
1388
  "bitflags 2.9.0",
1389
1389
  "cfg-if",
@@ -1422,9 +1422,9 @@ dependencies = [
1422
1422
 
1423
1423
  [[package]]
1424
1424
  name = "openssl-sys"
1425
- version = "0.9.106"
1425
+ version = "0.9.107"
1426
1426
  source = "registry+https://github.com/rust-lang/crates.io-index"
1427
- checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd"
1427
+ checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07"
1428
1428
  dependencies = [
1429
1429
  "cc",
1430
1430
  "libc",
@@ -1560,9 +1560,9 @@ dependencies = [
1560
1560
 
1561
1561
  [[package]]
1562
1562
  name = "pyo3"
1563
- version = "0.23.5"
1563
+ version = "0.24.1"
1564
1564
  source = "registry+https://github.com/rust-lang/crates.io-index"
1565
- checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
1565
+ checksum = "17da310086b068fbdcefbba30aeb3721d5bb9af8db4987d6735b2183ca567229"
1566
1566
  dependencies = [
1567
1567
  "cfg-if",
1568
1568
  "indoc",
@@ -1578,19 +1578,19 @@ dependencies = [
1578
1578
 
1579
1579
  [[package]]
1580
1580
  name = "pyo3-build-config"
1581
- version = "0.23.5"
1581
+ version = "0.24.1"
1582
1582
  source = "registry+https://github.com/rust-lang/crates.io-index"
1583
- checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
1583
+ checksum = "e27165889bd793000a098bb966adc4300c312497ea25cf7a690a9f0ac5aa5fc1"
1584
1584
  dependencies = [
1585
1585
  "once_cell",
1586
- "target-lexicon",
1586
+ "target-lexicon 0.13.2",
1587
1587
  ]
1588
1588
 
1589
1589
  [[package]]
1590
1590
  name = "pyo3-ffi"
1591
- version = "0.23.5"
1591
+ version = "0.24.1"
1592
1592
  source = "registry+https://github.com/rust-lang/crates.io-index"
1593
- checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
1593
+ checksum = "05280526e1dbf6b420062f3ef228b78c0c54ba94e157f5cb724a609d0f2faabc"
1594
1594
  dependencies = [
1595
1595
  "libc",
1596
1596
  "pyo3-build-config",
@@ -1598,9 +1598,9 @@ dependencies = [
1598
1598
 
1599
1599
  [[package]]
1600
1600
  name = "pyo3-macros"
1601
- version = "0.23.5"
1601
+ version = "0.24.1"
1602
1602
  source = "registry+https://github.com/rust-lang/crates.io-index"
1603
- checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
1603
+ checksum = "5c3ce5686aa4d3f63359a5100c62a127c9f15e8398e5fdeb5deef1fed5cd5f44"
1604
1604
  dependencies = [
1605
1605
  "proc-macro2",
1606
1606
  "pyo3-macros-backend",
@@ -1610,9 +1610,9 @@ dependencies = [
1610
1610
 
1611
1611
  [[package]]
1612
1612
  name = "pyo3-macros-backend"
1613
- version = "0.23.5"
1613
+ version = "0.24.1"
1614
1614
  source = "registry+https://github.com/rust-lang/crates.io-index"
1615
- checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
1615
+ checksum = "f4cf6faa0cbfb0ed08e89beb8103ae9724eb4750e3a78084ba4017cbe94f3855"
1616
1616
  dependencies = [
1617
1617
  "heck",
1618
1618
  "proc-macro2",
@@ -2211,6 +2211,12 @@ version = "0.12.16"
2211
2211
  source = "registry+https://github.com/rust-lang/crates.io-index"
2212
2212
  checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
2213
2213
 
2214
+ [[package]]
2215
+ name = "target-lexicon"
2216
+ version = "0.13.2"
2217
+ source = "registry+https://github.com/rust-lang/crates.io-index"
2218
+ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a"
2219
+
2214
2220
  [[package]]
2215
2221
  name = "tempfile"
2216
2222
  version = "3.17.1"
@@ -2288,9 +2294,9 @@ dependencies = [
2288
2294
 
2289
2295
  [[package]]
2290
2296
  name = "tokio"
2291
- version = "1.43.0"
2297
+ version = "1.43.1"
2292
2298
  source = "registry+https://github.com/rust-lang/crates.io-index"
2293
- checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
2299
+ checksum = "492a604e2fd7f814268a378409e6c92b5525d747d10db9a229723f55a417958c"
2294
2300
  dependencies = [
2295
2301
  "backtrace",
2296
2302
  "bytes",
@@ -1,7 +1,7 @@
1
1
  [package]
2
2
  name = "datago"
3
3
  edition = "2021"
4
- version = "2025.3.12"
4
+ version = "2025.4.3"
5
5
 
6
6
  [lib]
7
7
  # exposed by pyo3
@@ -21,9 +21,9 @@ serde_json = "1.0"
21
21
  url = "2.5.4"
22
22
  kanal = "0.1"
23
23
  clap = { version = "4.5.27", features = ["derive"] }
24
- tokio = { version = "1.43.0", features = ["rt-multi-thread", "macros"] }
24
+ tokio = { version = "1.43.1", features = ["rt-multi-thread", "macros"] }
25
25
  prettytable-rs = "0.10.0"
26
- pyo3 = { version = "0.23.4", features = ["extension-module"] }
26
+ pyo3 = { version = "0.24.1", features = ["extension-module"] }
27
27
  threadpool = "1.8.1"
28
28
  openssl = { version = "0.10", features = ["vendored"] }
29
29
  walkdir = "2.5.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: datago
3
- Version: 2025.3.12
3
+ Version: 2025.4.3
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: Implementation :: CPython
6
6
  Classifier: Programming Language :: Python :: Implementation :: PyPy
@@ -102,6 +102,23 @@ See helper functions provided in `raw_types.py`, should be self explanatory. Che
102
102
 
103
103
  Just install the rust toolchain via rustup
104
104
 
105
+ ## [Apple Silicon macOS only]
106
+
107
+ If you are using an Apple Silicon macOS machine, create a `.cargo/config.toml` file and paste the following:
108
+ ```toml
109
+ [target.x86_64-apple-darwin]
110
+ rustflags = [
111
+ "-C", "link-arg=-undefined",
112
+ "-C", "link-arg=dynamic_lookup",
113
+ ]
114
+
115
+ [target.aarch64-apple-darwin]
116
+ rustflags = [
117
+ "-C", "link-arg=-undefined",
118
+ "-C", "link-arg=dynamic_lookup",
119
+ ]
120
+ ```
121
+
105
122
  ## Build a benchmark CLI
106
123
  `cargo run --release -- -h` to get all the information, should be fairly straightforward
107
124
 
@@ -86,6 +86,23 @@ See helper functions provided in `raw_types.py`, should be self explanatory. Che
86
86
 
87
87
  Just install the rust toolchain via rustup
88
88
 
89
+ ## [Apple Silicon macOS only]
90
+
91
+ If you are using an Apple Silicon macOS machine, create a `.cargo/config.toml` file and paste the following:
92
+ ```toml
93
+ [target.x86_64-apple-darwin]
94
+ rustflags = [
95
+ "-C", "link-arg=-undefined",
96
+ "-C", "link-arg=dynamic_lookup",
97
+ ]
98
+
99
+ [target.aarch64-apple-darwin]
100
+ rustflags = [
101
+ "-C", "link-arg=-undefined",
102
+ "-C", "link-arg=dynamic_lookup",
103
+ ]
104
+ ```
105
+
89
106
  ## Build a benchmark CLI
90
107
  `cargo run --release -- -h` to get all the information, should be fairly straightforward
91
108
 
@@ -198,7 +198,6 @@ impl DatagoClient {
198
198
  if !self.is_started {
199
199
  self.start();
200
200
  }
201
- const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);
202
201
 
203
202
  // If no more samples and workers are closed, then wrap it up
204
203
  if self.samples_rx.is_closed() {
@@ -208,6 +207,10 @@ impl DatagoClient {
208
207
  }
209
208
 
210
209
  // Try to fetch a new sample from the queue
210
+ // The client times out if no sample is received within 5 minutes;
211
+ // at that point it stops and wraps everything up.
212
+ const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
213
+
211
214
  match self.samples_rx.recv_timeout(TIMEOUT) {
212
215
  Ok(sample) => match sample {
213
216
  Some(sample) => Some(sample),
@@ -24,9 +24,18 @@ pub struct SourceDBConfig {
24
24
  #[serde(default)]
25
25
  pub tags: String,
26
26
 
27
+ #[serde(default)]
28
+ pub tags_all: String,
29
+
27
30
  #[serde(default)]
28
31
  pub tags_ne: String,
29
32
 
33
+ #[serde(default)]
34
+ pub tags_ne_all: String,
35
+
36
+ #[serde(default)]
37
+ pub tags_empty: String,
38
+
30
39
  #[serde(default)]
31
40
  pub has_attributes: String,
32
41
 
@@ -60,6 +69,9 @@ pub struct SourceDBConfig {
60
69
  #[serde(default)]
61
70
  pub duplicate_state: i32,
62
71
 
72
+ #[serde(default)]
73
+ pub attributes: String,
74
+
63
75
  #[serde(default)]
64
76
  pub random_sampling: bool,
65
77
  }
@@ -73,7 +85,10 @@ struct DbRequest {
73
85
  pub page_size: String,
74
86
 
75
87
  pub tags: String,
88
+ pub tags_all: String,
76
89
  pub tags_ne: String,
90
+ pub tags_ne_all: String,
91
+ pub tags_empty: String,
77
92
 
78
93
  pub has_attributes: String,
79
94
  pub lacks_attributes: String,
@@ -92,6 +107,7 @@ struct DbRequest {
92
107
  pub max_pixel_count: String,
93
108
 
94
109
  pub duplicate_state: String,
110
+ pub attributes: String,
95
111
  pub random_sampling: bool,
96
112
 
97
113
  pub partitions_count: String,
@@ -131,7 +147,10 @@ impl DbRequest {
131
147
  maybe_add_field("page_size", &self.page_size);
132
148
 
133
149
  maybe_add_field("tags", &self.tags);
150
+ maybe_add_field("tags__all", &self.tags_all);
134
151
  maybe_add_field("tags__ne", &self.tags_ne);
152
+ maybe_add_field("tags__ne_all", &self.tags_ne_all);
153
+ maybe_add_field("tags__empty", &self.tags_empty);
135
154
  maybe_add_field("has_attributes", &self.has_attributes);
136
155
  maybe_add_field("lacks_attributes", &self.lacks_attributes);
137
156
  maybe_add_field("has_masks", &self.has_masks);
@@ -144,6 +163,7 @@ impl DbRequest {
144
163
  maybe_add_field("pixel_count__gte", &self.min_pixel_count);
145
164
  maybe_add_field("pixel_count__lte", &self.max_pixel_count);
146
165
  maybe_add_field("duplicate_state", &self.duplicate_state);
166
+ maybe_add_field("attributes", &self.attributes);
147
167
  maybe_add_field("partitions_count", &self.partitions_count);
148
168
  maybe_add_field("partition", &self.partition);
149
169
  }
@@ -180,7 +200,53 @@ fn build_request(source_config: SourceDBConfig, rank: usize, world_size: usize)
180
200
 
181
201
  if !source_config.tags.is_empty() {
182
202
  fields.push_str(",tags");
183
- println!("Including some tags: {}", source_config.tags);
203
+ println!(
204
+ "Including some tags, must have any of: {}",
205
+ source_config.tags
206
+ );
207
+ }
208
+
209
+ if !source_config.tags_all.is_empty() {
210
+ fields.push_str(",tags");
211
+ println!(
212
+ "Including tags, must have all of: {}",
213
+ source_config.tags_all
214
+ );
215
+ }
216
+
217
+ if !source_config.tags_ne.is_empty() {
218
+ fields.push_str(",tags");
219
+ println!(
220
+ "Including tags, must not have any of: {}",
221
+ source_config.tags_ne
222
+ );
223
+ }
224
+
225
+ if !source_config.tags_empty.is_empty() {
226
+ fields.push_str(",tags");
227
+ println!(
228
+ "Using filter: Tags must{} be empty",
229
+ if source_config.tags_empty == "true" {
230
+ " not"
231
+ } else {
232
+ ""
233
+ }
234
+ );
235
+ if !source_config.tags_all.is_empty()
236
+ || !source_config.tags.is_empty()
237
+ || !source_config.tags_ne.is_empty()
238
+ || !source_config.tags_ne_all.is_empty()
239
+ {
240
+ println!("WARNING: you've set `tags_empty` in addition to `tags`, `tags_all`, `tags_ne` or `tags_ne_all`. The combination might be incompatible or redundant.");
241
+ }
242
+ }
243
+
244
+ if !source_config.tags_ne_all.is_empty() {
245
+ fields.push_str(",tags");
246
+ println!(
247
+ "Including tags, must not have all of: {}",
248
+ source_config.tags_ne_all
249
+ );
184
250
  }
185
251
 
186
252
  if source_config.require_embeddings {
@@ -213,7 +279,10 @@ fn build_request(source_config: SourceDBConfig, rank: usize, world_size: usize)
213
279
  sources_ne: source_config.sources_ne,
214
280
  page_size: source_config.page_size.to_string(),
215
281
  tags: source_config.tags,
282
+ tags_all: source_config.tags_all,
216
283
  tags_ne: source_config.tags_ne,
284
+ tags_ne_all: source_config.tags_ne_all,
285
+ tags_empty: source_config.tags_empty,
217
286
  has_attributes: source_config.has_attributes,
218
287
  lacks_attributes: source_config.lacks_attributes,
219
288
  has_masks: source_config.has_masks,
@@ -226,6 +295,7 @@ fn build_request(source_config: SourceDBConfig, rank: usize, world_size: usize)
226
295
  min_pixel_count: maybe_add_int(source_config.min_pixel_count),
227
296
  max_pixel_count: maybe_add_int(source_config.max_pixel_count),
228
297
  duplicate_state: maybe_add_int(source_config.duplicate_state),
298
+ attributes: source_config.attributes,
229
299
  random_sampling: source_config.random_sampling,
230
300
  partition: if world_size > 1 {
231
301
  format!("{}", rank)
@@ -1,7 +1,7 @@
1
1
  use clap::{Arg, Command};
2
2
  use prettytable::{row, Table};
3
3
  use serde_json::json;
4
-
4
+ use std::collections::HashMap;
5
5
  mod client;
6
6
  mod generator_files;
7
7
  mod generator_http;
@@ -118,7 +118,7 @@ fn main() {
118
118
  let mut client = client::DatagoClient::new(config.to_string());
119
119
 
120
120
  // -----------------------------------------------------------------
121
- let mut size_buckets: std::collections::HashMap<String, i32> = std::collections::HashMap::new();
121
+ let mut size_buckets: HashMap<String, i32> = HashMap::new();
122
122
  let start_time = std::time::Instant::now();
123
123
  let mut rolling_time = std::time::Instant::now();
124
124
 
@@ -1,6 +1,7 @@
1
1
  use crate::image_processing::ImageTransformConfig;
2
2
  use pyo3::prelude::*;
3
3
  use serde::{Deserialize, Serialize};
4
+ use std::collections::HashMap;
4
5
 
5
6
  #[derive(Deserialize)]
6
7
  #[serde(rename_all = "lowercase")]
@@ -64,7 +65,7 @@ pub struct Sample {
64
65
  pub source: String,
65
66
 
66
67
  #[doc(hidden)]
67
- pub attributes: std::collections::HashMap<String, serde_json::Value>,
68
+ pub attributes: HashMap<String, serde_json::Value>,
68
69
 
69
70
  #[pyo3(get, set)]
70
71
  pub duplicate_state: i32,
@@ -73,13 +74,13 @@ pub struct Sample {
73
74
  pub image: ImagePayload,
74
75
 
75
76
  #[pyo3(get, set)]
76
- pub masks: std::collections::HashMap<String, ImagePayload>,
77
+ pub masks: HashMap<String, ImagePayload>,
77
78
 
78
79
  #[pyo3(get, set)]
79
- pub additional_images: std::collections::HashMap<String, ImagePayload>,
80
+ pub additional_images: HashMap<String, ImagePayload>,
80
81
 
81
82
  #[pyo3(get, set)]
82
- pub latents: std::collections::HashMap<String, LatentPayload>,
83
+ pub latents: HashMap<String, LatentPayload>,
83
84
 
84
85
  #[pyo3(get, set)]
85
86
  pub coca_embedding: Vec<f32>,
@@ -2,6 +2,7 @@ use crate::image_processing;
2
2
  use crate::structs::{ImagePayload, Sample};
3
3
  use std::cmp::min;
4
4
  use std::collections::HashMap;
5
+ use std::collections::VecDeque;
5
6
  use std::sync::Arc;
6
7
 
7
8
  async fn image_from_path(path: &str) -> Result<image::DynamicImage, image::ImageError> {
@@ -78,7 +79,7 @@ async fn pull_sample(
78
79
  }
79
80
 
80
81
  pub async fn consume_oldest_task(
81
- tasks: &mut std::collections::VecDeque<tokio::task::JoinHandle<Result<(), ()>>>,
82
+ tasks: &mut VecDeque<tokio::task::JoinHandle<Result<(), ()>>>,
82
83
  ) -> Result<(), ()> {
83
84
  match tasks.pop_front().unwrap().await {
84
85
  Ok(_) => Ok(()),
@@ -100,7 +101,7 @@ async fn async_pull_samples(
100
101
  // We use async-await here, to better use IO stalls
101
102
  // We'll issue N async tasks in parallel, and wait for them to finish
102
103
  let max_tasks = min(num_cpus::get() * 2, limit);
103
- let mut tasks = std::collections::VecDeque::new();
104
+ let mut tasks = VecDeque::new();
104
105
  let mut count = 0;
105
106
  let shareable_img_tfm = Arc::new(image_transform);
106
107
 
@@ -5,6 +5,7 @@ use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
5
5
  use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
6
6
  use serde::{Deserialize, Serialize};
7
7
  use std::cmp::min;
8
+ use std::collections::HashMap;
8
9
  use std::sync::Arc;
9
10
 
10
11
  // We'll share a single connection pool across all worker threads
@@ -37,7 +38,7 @@ pub fn new_shared_client(max_connections: usize) -> SharedClient {
37
38
  struct SampleMetadata {
38
39
  id: String,
39
40
  source: String,
40
- attributes: std::collections::HashMap<String, serde_json::Value>,
41
+ attributes: HashMap<String, serde_json::Value>,
41
42
  duplicate_state: Option<i32>,
42
43
  image_direct_url: Option<String>,
43
44
  latents: Option<Vec<UrlLatent>>,
@@ -76,6 +77,28 @@ async fn image_from_url(
76
77
  )))
77
78
  }
78
79
 
80
+ async fn payload_from_url(
81
+ client: &SharedClient,
82
+ url: &str,
83
+ retries: i32,
84
+ ) -> Result<Vec<u8>, std::io::Error> {
85
+ // Retry on the fetch and decode a few times, could happen that we get a broken packet
86
+ for _ in 0..retries {
87
+ match bytes_from_url(client, url).await {
88
+ Some(bytes) => {
89
+ return Ok(bytes);
90
+ }
91
+ None => {
92
+ println!("Failed to get bytes from URL: {}. Retrying", url);
93
+ }
94
+ }
95
+ }
96
+ Err(std::io::Error::new(
97
+ std::io::ErrorKind::Other,
98
+ "Failed to fetch bytes buffer",
99
+ ))
100
+ }
101
+
79
102
  async fn image_payload_from_url(
80
103
  client: &SharedClient,
81
104
  url: &str,
@@ -143,12 +166,9 @@ async fn pull_sample(
143
166
  }
144
167
 
145
168
  // Same for the latents, mask and masked images, if they exist
146
- let mut masks: std::collections::HashMap<String, ImagePayload> =
147
- std::collections::HashMap::new();
148
- let mut additional_images: std::collections::HashMap<String, ImagePayload> =
149
- std::collections::HashMap::new();
150
- let mut latents: std::collections::HashMap<String, LatentPayload> =
151
- std::collections::HashMap::new();
169
+ let mut masks: HashMap<String, ImagePayload> = HashMap::new();
170
+ let mut additional_images: HashMap<String, ImagePayload> = HashMap::new();
171
+ let mut latents: HashMap<String, LatentPayload> = HashMap::new();
152
172
 
153
173
  if let Some(exposed_latents) = &sample.latents {
154
174
  for latent in exposed_latents {
@@ -203,8 +223,8 @@ async fn pull_sample(
203
223
  }
204
224
  } else {
205
225
  // Vanilla latents, pure binary payloads
206
- match bytes_from_url(&client, &latent.file_direct_url).await {
207
- Some(latent_payload) => {
226
+ match payload_from_url(&client, &latent.file_direct_url, 5).await {
227
+ Ok(latent_payload) => {
208
228
  latents.insert(
209
229
  latent.latent_type.clone(),
210
230
  LatentPayload {
@@ -213,8 +233,8 @@ async fn pull_sample(
213
233
  },
214
234
  );
215
235
  }
216
- None => {
217
- println!("Error fetching latent: {}", latent.file_direct_url);
236
+ Err(e) => {
237
+ println!("Error fetching latent: {} {}", latent.file_direct_url, e);
218
238
  return Err(());
219
239
  }
220
240
  }
@@ -19,6 +19,9 @@ fn get_test_config() -> serde_json::Value {
19
19
  "require_embeddings": false,
20
20
  "tags": "",
21
21
  "tags_ne": "",
22
+ "tags_all": "",
23
+ "tags_ne_all": "",
24
+ "tags_empty": "",
22
25
  "has_attributes": "",
23
26
  "lacks_attributes": "",
24
27
  "has_masks": "",
@@ -30,6 +33,7 @@ fn get_test_config() -> serde_json::Value {
30
33
  "min_pixel_count": -1,
31
34
  "max_pixel_count": -1,
32
35
  "duplicate_state": -1,
36
+ "attributes": "",
33
37
  "random_sampling": false,
34
38
  "page_size": 10,
35
39
  },
@@ -253,6 +257,116 @@ fn test_tags() {
253
257
  client.stop();
254
258
  }
255
259
 
260
// `tags_all` filter: the returned sample must carry every tag in the comma-separated list.
#[test]
fn test_tags_all() {
    let tags = "v4_trainset_hq,photo";
    let mut config = get_test_config();
    config["source_config"]["tags_all"] = tags.into();

    let mut client = DatagoClient::new(config.to_string());
    let sample = client.get_sample();
    assert!(sample.is_some());
    let sample = sample.unwrap();
    assert!(!sample.id.is_empty());

    // AND semantics: each listed tag has to be present on the sample.
    tags.split(',')
        .for_each(|tag| assert!(sample.tags.contains(&tag.to_string())));

    client.stop();
}
278
+
279
// `tags_ne` filter: the returned sample must carry none of the listed tags.
#[test]
fn test_tags_ne() {
    let mut config = get_test_config();
    let tags = "v4_trainset_hq,photo";
    config["source_config"]["tags_ne"] = tags.into();
    let mut client = DatagoClient::new(config.to_string());

    let sample = client.get_sample();
    assert!(sample.is_some());

    let sample = sample.unwrap();
    assert!(!sample.id.is_empty());
    // Check that sample.tags does not contain any of the tags in the tags string.
    // The sample's tags are reported only on failure, instead of being printed on every run.
    for tag in tags.split(',') {
        assert!(
            !sample.tags.contains(&tag.to_string()),
            "sample {} carries excluded tag {:?}; tags: {:?}",
            sample.id,
            tag,
            sample.tags
        );
    }
    client.stop();
}
298
+
299
// `tags_empty` filter: with "true", the returned sample must have no tags at all.
#[test]
fn test_tags_empty() {
    let mut config = get_test_config();
    config["source_config"]["tags_empty"] = "true".into();
    let mut client = DatagoClient::new(config.to_string());

    let sample = client.get_sample();
    assert!(sample.is_some());

    let sample = sample.unwrap();
    // Same sanity check as the sibling filter tests: a real sample came back.
    assert!(!sample.id.is_empty());
    assert!(
        sample.tags.is_empty(),
        "expected an untagged sample, got tags {:?}",
        sample.tags
    );
    client.stop();
}
312
+
313
// `tags_ne_all` filter: a sample is rejected only when it carries all listed tags at once.
#[test]
fn test_tags_ne_all() {
    let (tag1, tag2) = ("photo", "graphic");
    let mut config = get_test_config();
    config["source_config"]["tags_ne_all"] = format!("{},{}", tag1, tag2).into();
    let mut client = DatagoClient::new(config.to_string());

    let sample = client.get_sample();
    assert!(sample.is_some());
    let sample = sample.unwrap();
    assert!(!sample.id.is_empty());

    // Either tag on its own is fine; only the combination is excluded.
    let has_both = [tag1, tag2]
        .iter()
        .all(|tag| sample.tags.contains(&tag.to_string()));
    assert!(
        !has_both,
        "Sample should not contain both tags at the same time"
    );
    client.stop();
}
335
+
336
// `attributes` filter: "aesthetic_score__gte:0.5" should only yield samples whose
// `aesthetic_score` attribute is at least 0.5.
#[test]
fn test_attributes_filter() {
    let mut config = get_test_config();
    config["source_config"]["attributes"] = "aesthetic_score__gte:0.5".into();
    let mut client = DatagoClient::new(config.to_string());

    let sample = client.get_sample();
    assert!(sample.is_some());

    let sample = sample.unwrap();
    assert!(!sample.id.is_empty());

    // Single lookup. A missing or non-numeric score fails with a descriptive message
    // instead of a bare `unwrap()` panic on `None`.
    let score = sample
        .attributes
        .get("aesthetic_score")
        .unwrap_or_else(|| panic!("sample {} has no aesthetic_score attribute", sample.id))
        .as_f64()
        .unwrap_or_else(|| panic!("aesthetic_score of sample {} is not a number", sample.id));
    assert!(
        score >= 0.5,
        "aesthetic_score {} of sample {} is below the 0.5 filter",
        score,
        sample.id
    );
    client.stop();
}
351
+
352
// `min_pixel_count` / `max_pixel_count` filters: the returned image's
// width * height must fall inside the inclusive [1_000_000, 2_000_000] range.
#[test]
fn test_pixel_count_filter() {
    let mut config = get_test_config();
    config["source_config"]["min_pixel_count"] = 1000000.into();
    config["source_config"]["max_pixel_count"] = 2000000.into();
    config["source_config"]["require_images"] = json!(true);
    let mut client = DatagoClient::new(config.to_string());

    let sample = client.get_sample();
    assert!(sample.is_some());

    let sample = sample.unwrap();
    assert!(!sample.id.is_empty());

    // Compute the pixel count once and report the actual dimensions on failure.
    let pixel_count = sample.image.width * sample.image.height;
    assert!(
        (1_000_000..=2_000_000).contains(&pixel_count),
        "pixel count {} ({}x{}) is outside the requested [1000000, 2000000] range",
        pixel_count,
        sample.image.width,
        sample.image.height
    );
    client.stop();
}
369
+
256
370
  #[test]
257
371
  fn test_multiple_sources() {
258
372
  let limit = 10;
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes