transferqueue-0.1.1.dev0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0

transfer_queue/metadata.py
@@ -0,0 +1,460 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import dataclasses
+ import itertools
+ from collections import ChainMap
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+ import numpy as np
+ from tensordict import TensorDict
+
+ from transfer_queue.utils.utils import ProductionStatus
+
+
+ @dataclass
+ class FieldMeta:
+     """Records the metadata of a single data field (name, dtype, shape, etc.)."""
+
+     name: str
+     dtype: Optional[Any]  # Data type (e.g., torch.float32, numpy.float32)
+     shape: Optional[Any]  # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224))
+     production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED
+
+     def __str__(self) -> str:
+         return (
+             f"FieldMeta(name='{self.name}', dtype={self.dtype}, "
+             f"shape={self.shape}, production_status={self.production_status})"
+         )
+
+     @property
+     def is_ready(self) -> bool:
+         """Check if this field is ready for consumption"""
+         return self.production_status == ProductionStatus.READY_FOR_CONSUME
+
+
+ @dataclass
+ class SampleMeta:
+     """Records the metadata of a single data sample (stored as a row in the data system)."""
+
+     partition_id: str  # Partition id, used for data versioning
+     global_index: int  # Global row index, uniquely identifies a data sample
+     fields: dict[str, FieldMeta]  # Fields of interest for this sample
+
+     def __post_init__(self):
+         """Initialize is_ready property based on field readiness"""
+         # Check if all fields are ready and update is_ready property
+         object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values()))
+
+     def __str__(self) -> str:
+         return f"SampleMeta(partition_id={self.partition_id}, global_index={self.global_index})"
+
+     @property
+     def field_names(self) -> list[str]:
+         """Get list of field names for this sample"""
+         return list(self.fields.keys())
+
+     @property
+     def batch_index(self) -> int:
+         """Get the batch index of this sample (to be set by BatchMeta)"""
+         return getattr(self, "_batch_index", -1)
+
+     def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
+         """Get FieldMeta by field name"""
+         return self.fields.get(name)
+
+     def has_field(self, name: str) -> bool:
+         """Check if this sample has a specific field"""
+         return name in self.fields
+
+     def is_field_ready(self, field_name: str) -> bool:
+         """Check if a specific field is ready for consumption"""
+         field = self.fields.get(field_name)
+         return field.is_ready if field else False
+
+     def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta":
+         """
+         Add new fields to this sample, merging them with its existing fields.
+         If a field name already exists, its FieldMeta is replaced by the new one.
+         This modifies the sample in-place and refreshes the cached readiness flag.
+         """
+         self.fields = _union_fields(self.fields, fields)
+         # Update is_ready property
+         object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values()))
+         return self
+
+     def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta":
+         """
+         Merge another sample's fields into this sample's fields.
+         Both samples are assumed to share the same global index. If fields overlap,
+         the fields in this sample are replaced by the other sample's fields.
+
+         Args:
+             other: Another SampleMeta to union with
+             validate: Whether to validate union conditions
+
+         Returns:
+             This SampleMeta with the merged fields (modified in-place)
+
+         Raises:
+             ValueError: If validation is enabled and the global indexes do not match
+         """
+         if validate:
+             if self.global_index != other.global_index:
+                 raise ValueError(
+                     f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union."
+                 )
+
+         # Merge fields
+         self.fields = _union_fields(self.fields, other.fields)
+
+         # Update is_ready property
+         object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values()))
+         return self
+
+     @property
+     def is_ready(self) -> bool:
+         """Check if all fields in this sample are ready for consumption"""
+         return getattr(self, "_is_ready", False)
+
+     @property
+     def production_status(self) -> dict[str, ProductionStatus]:
+         """Get production status for all fields (backward compatibility)"""
+         return {name: field.production_status for name, field in self.fields.items()}
+
+
+ @dataclass
+ class BatchMeta:
+     """Records the metadata of a batch of data samples."""
+
+     samples: list[SampleMeta]
+     extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+     def __post_init__(self):
+         """Initialize all computed properties during initialization"""
+         # Basic properties
+         object.__setattr__(self, "_size", len(self.samples))
+         object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
+
+         # Pre-compute all list properties for better performance
+         if self.samples:
+             for idx, sample in enumerate(self.samples):
+                 object.__setattr__(sample, "_batch_index", idx)  # Ensure batch_index is set correctly
+
+             object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
+
+             # assume all samples have the same fields.
+             object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
+         else:
+             object.__setattr__(self, "_global_indexes", [])
+             object.__setattr__(self, "_field_names", [])
+
+     @property
+     def size(self) -> int:
+         """Return the number of samples in this batch"""
+         return getattr(self, "_size", 0)
+
+     @property
+     def global_indexes(self) -> list[int]:
+         """Get all global indexes in this batch"""
+         return getattr(self, "_global_indexes", [])
+
+     @property
+     def field_names(self) -> list[str]:
+         """Get all unique field names in this batch"""
+         return getattr(self, "_field_names", [])
+
+     @property
+     def is_ready(self) -> bool:
+         """Check if all samples in this batch are ready for consumption"""
+         # TODO: get ready status from controller realtime
+         return getattr(self, "_is_ready", False)
+
+     # Extra info interface methods
+     def get_extra_info(self, key: str, default: Any = None) -> Any:
+         """Get extra info by key"""
+         return self.extra_info.get(key, default)
+
+     def set_extra_info(self, key: str, value: Any) -> None:
+         """Set extra info by key"""
+         self.extra_info[key] = value
+
+     def update_extra_info(self, info_dict: dict[str, Any]) -> None:
+         """Update extra info with multiple key-value pairs"""
+         self.extra_info.update(info_dict)
+
+     def remove_extra_info(self, key: str) -> Any:
+         """Remove extra info by key and return its value"""
+         return self.extra_info.pop(key, None)
+
+     def clear_extra_info(self) -> None:
+         """Clear all extra info"""
+         self.extra_info.clear()
+
+     def has_extra_info(self, key: str) -> bool:
+         """Check if extra info contains a specific key"""
+         return key in self.extra_info
+
+     def get_all_extra_info(self) -> dict[str, Any]:
+         """Get all extra info as a dictionary"""
+         return self.extra_info.copy()
+
+     def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta":
+         """
+         Add new fields from a TensorDict to all samples in this batch.
+         This modifies each sample in-place to include the new fields.
+
+         Args:
+             tensor_dict (TensorDict): The input TensorDict containing new fields.
+             set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.
+         """
+         fields = _extract_field_metas(tensor_dict, set_all_ready)
+
+         if fields:
+             if len(self.samples) != len(fields):
+                 raise ValueError(f"add_fields length mismatch: samples={len(self.samples)} vs fields={len(fields)}")
+             for idx, sample in enumerate(self.samples):
+                 sample.add_fields(fields=fields[idx])
+
+         # Update batch-level fields cache
+         object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
+         object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
+         return self
+
+     def __len__(self) -> int:
+         """Return the number of samples in this batch."""
+         return len(self.samples)
+
+     def __getitem__(self, item):
+         if isinstance(item, int | np.integer):
+             # Let an out-of-range index raise IndexError naturally
+             sample_meta = self.samples[item]
+             return BatchMeta(samples=[sample_meta], extra_info=self.extra_info)
+         else:
+             raise TypeError(f"Indexing with {type(item)} is not supported now!")
+
+     def chunk(self, chunks: int) -> list["BatchMeta"]:
+         """
+         Split this batch into smaller chunks.
+
+         Args:
+             chunks: Number of chunks
+
+         Returns:
+             List of smaller BatchMeta chunks
+         """
+         chunk_list = []
+         n = len(self.samples)
+
+         # Calculate the base size of each chunk and the remainder
+         base_size = n // chunks
+         remainder = n % chunks
+
+         start = 0
+         for i in range(chunks):
+             # The first `remainder` chunks each get one sample more than the base size
+             current_chunk_size = base_size + 1 if i < remainder else base_size
+             end = start + current_chunk_size
+             chunk_samples = self.samples[start:end]
+             chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy())
+             chunk_list.append(chunk)
+             start = end
+         return chunk_list
+
+     @classmethod
+     def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]:
+         """
+         Concatenate multiple BatchMeta chunks into one large batch.
+
+         Args:
+             data: List of BatchMeta chunks to concatenate
+             validate: Whether to validate concatenation conditions
+
+         Returns:
+             Concatenated BatchMeta
+
+         Raises:
+             ValueError: If validation fails (e.g., field names do not match)
+         """
+         if not data:
+             return None
+
+         if validate:
+             base_fields = data[0].field_names
+
+             for chunk in data:
+                 if chunk.field_names != base_fields:
+                     raise ValueError("Error: Field names do not match for concatenation.")
+
+         # Combine all samples
+         all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data))
+         # Merge all extra_info dictionaries from the chunks
+         merged_extra_info = dict(ChainMap(*(chunk.extra_info for chunk in data)))
+
+         return BatchMeta(samples=all_samples, extra_info=merged_extra_info)
+
+     def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]:
+         """
+         Create a union of this batch's fields with another batch's fields.
+         Assume both batches have the same global indices. If fields overlap, the
+         fields in this batch will be replaced by the other batch's fields.
+
+         Args:
+             other: Another BatchMeta to union with
+             validate: Whether to validate union conditions
+
+         Returns:
+             New BatchMeta with unioned fields
+
+         Raises:
+             ValueError: If validation fails (e.g., batch sizes or global indexes do not match)
+         """
+         if validate:
+             if self.size != other.size:
+                 raise ValueError("Error: Batch sizes do not match for union.")
+
+             self_global_indexes = sorted(self.global_indexes)
+             other_global_indexes = sorted(other.global_indexes)
+             if self_global_indexes != other_global_indexes:
+                 raise ValueError("Error: Global indexes do not match for union.")
+
+         # Create a mapping from global_index to SampleMeta in the other batch
+         other_sample_map = {sample.global_index: sample for sample in other.samples}
+
+         # Merge samples
+         merged_samples = []
+         for sample in self.samples:
+             if sample.global_index in other_sample_map:
+                 other_sample = other_sample_map[sample.global_index]
+                 merged_sample = sample.union(other_sample, validate=validate)
+                 merged_samples.append(merged_sample)
+             else:
+                 merged_samples.append(sample)
+
+         # Merge extra info dictionaries
+         merged_extra_info = {**self.extra_info, **other.extra_info}
+         return BatchMeta(samples=merged_samples, extra_info=merged_extra_info)
+
+     def reorder(self, indices: list[int]):
+         """
+         Reorder the SampleMeta in the BatchMeta according to the given indices.
+
+         The operation is performed in-place, modifying the current BatchMeta's SampleMeta order.
+
+         Args:
+             indices: A list of integers specifying the new order of SampleMeta. Each integer
+                 represents the current index of the SampleMeta in the BatchMeta.
+         """
+         # Reorder the samples
+         reordered_samples = [self.samples[i] for i in indices]
+         object.__setattr__(self, "samples", reordered_samples)
+
+         # Update necessary attributes
+         self._update_after_reorder()
+
+     def _update_after_reorder(self) -> None:
+         """Update related attributes specifically for the reorder operation"""
+         # Update batch_index for each sample
+         for idx, sample in enumerate(self.samples):
+             object.__setattr__(sample, "_batch_index", idx)
+
+         # Update cached index lists
+         object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
+
+         # Note: No need to update _size, _field_names, _is_ready, etc., as these remain unchanged after reorder
+
+     @classmethod
+     def from_samples(
+         cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None
+     ) -> "BatchMeta":
+         """
+         Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects.
+
+         Args:
+             samples: A single SampleMeta or a list of SampleMeta objects
+             extra_info: Optional additional information to store with the batch
+
+         Returns:
+             BatchMeta instance containing the provided sample(s)
+
+         Example:
+             >>> sample_meta = SampleMeta(...)
+             >>> batch_meta = BatchMeta.from_samples(sample_meta)
+
+             >>> sample_metas = [sample1, sample2, sample3]
+             >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"})
+         """
+         if extra_info is None:
+             extra_info = {}
+
+         if isinstance(samples, SampleMeta):
+             samples = [samples]
+
+         return cls(samples=samples, extra_info=extra_info)
+
+     @classmethod
+     def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta":
+         """
+         Create an empty BatchMeta with no samples.
+
+         Args:
+             extra_info: Optional additional information to store with the batch
+
+         Returns:
+             Empty BatchMeta instance
+
+         Example:
+             >>> empty_batch = BatchMeta.empty()
+         """
+         if extra_info is None:
+             extra_info = {}
+         return cls(samples=[], extra_info=extra_info)
+
+
+ def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
+     """Union two samples' fields in-place. If fields overlap, the fields in fields1 are replaced by fields2."""
+     fields1.update(fields2)
+     return fields1
+
+
+ def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]:
+     """
+     Extract field metas from a TensorDict. If data in tensor_dict does not have a dtype or shape attribute,
+     the corresponding dtype or shape will be set to None.
+
+     Args:
+         tensor_dict (TensorDict): The input TensorDict.
+         set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME.
+             Otherwise, set to NOT_PRODUCED. Default is True.
+
+     Returns:
+         all_fields (list[dict[str, FieldMeta]]): A list of dictionaries containing field metadata,
+             one per sample in the batch.
+     """
+     batch_size = tensor_dict.batch_size[0]
+
+     production_status = ProductionStatus.READY_FOR_CONSUME if set_all_ready else ProductionStatus.NOT_PRODUCED
+
+     all_fields = [
+         {
+             name: FieldMeta(
+                 name=name,
+                 dtype=getattr(value, "dtype", None),
+                 shape=getattr(value, "shape", None),
+                 production_status=production_status,
+             )
+             for name, value in tensor_dict[idx].items()
+         }
+         for idx in range(batch_size)
+     ]
+
+     return all_fields
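
The metadata API above composes as follows. A minimal usage sketch, assuming torch and tensordict are installed alongside the package; the field names, dtypes, shapes, and partition id are invented for illustration and are not part of the wheel:

import torch
from tensordict import TensorDict

from transfer_queue.metadata import BatchMeta, FieldMeta, SampleMeta
from transfer_queue.utils.utils import ProductionStatus

# Two samples whose "prompt" field has already been produced.
samples = [
    SampleMeta(
        partition_id="train_0",  # hypothetical partition id
        global_index=i,
        fields={
            "prompt": FieldMeta(
                name="prompt",
                dtype=torch.int64,
                shape=torch.Size([128]),
                production_status=ProductionStatus.READY_FOR_CONSUME,
            )
        },
    )
    for i in range(2)
]
batch = BatchMeta.from_samples(samples, extra_info={"source": "demo"})
assert batch.is_ready and batch.size == 2

# Attach a newly produced field to every sample via a TensorDict;
# add_fields marks it READY_FOR_CONSUME by default (set_all_ready=True).
batch.add_fields(TensorDict({"response": torch.zeros(2, 16)}, batch_size=[2]))
assert batch.field_names == ["prompt", "response"]

# A chunk/concat round-trip preserves sample order and extra_info.
restored = BatchMeta.concat(batch.chunk(2))
assert restored.global_indexes == batch.global_indexes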

transfer_queue/sampler/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .base import BaseSampler
+ from .grpo_group_n_sampler import GRPOGroupNSampler
+ from .sequential_sampler import SequentialSampler
+
+ __all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"]

transfer_queue/sampler/base.py
@@ -0,0 +1,74 @@
+ # Copyright 2025 The TransferQueue Team
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+
+ class BaseSampler(ABC):
+     """Base class for samplers that control how data is consumed from TransferQueue.
+
+     A sampler defines the logic for selecting which samples to retrieve from the
+     available samples, and which should be marked as consumed (never to be retrieved again).
+     Based on this abstraction, users can implement various data consumption strategies
+     for different training scenarios, such as sequential sampling, grouped sampling for
+     reinforcement learning, or custom sampling patterns.
+
+     The sampler interface provides a clean separation between data production status
+     (handled by TransferQueueController) and data consumption strategy (implemented by samplers).
+     This allows users to customize data consumption behavior without modifying the TransferQueue code.
+
+     Available Samplers:
+         - **SequentialSampler**: Default sampler; selects samples sequentially without replacement
+         - **GRPOGroupNSampler**: Samples a contiguous group of N samples only when all of them are ready.
+           It assumes the N samples associated with the same prompt are stored contiguously
+         - **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO)
+
+     NOTE: Implementations must always return both the sampled indexes and the consumed
+     indexes (the two lists may be identical).
+     """
+
+     def __init__(self):
+         self._states: dict[str, Any] = {}
+
+     @abstractmethod
+     def sample(
+         self,
+         ready_indexes: list[int],
+         batch_size: int,
+         *args: Any,
+         **kwargs: Any,
+     ) -> tuple[list[int], list[int]]:
+         """Sample a batch of indices from the ready indices.
+
+         Args:
+             ready_indexes: Global indices whose required fields have all been produced
+                 and which have not yet been marked as consumed for the current task.
+             batch_size: Number of samples to select
+             *args: Additional positional arguments for specific sampler implementations
+             **kwargs: Additional keyword arguments for specific sampler implementations
+
+         Returns:
+             A tuple of two lists:
+                 - The sampled global indices, of length batch_size
+                 - The global indices that should be marked as consumed
+                   (never to be retrieved again)
+
+         Raises:
+             ValueError: If batch_size is invalid or there are not enough ready indexes
+         """
+         raise NotImplementedError("Subclasses must implement sample")
+
+     def __call__(self, *args: Any, **kwargs: Any) -> tuple[list[int], list[int]]:
+         return self.sample(*args, **kwargs)
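
A sketch of how this interface is meant to be extended; FirstFitSampler is a hypothetical example (not shipped in the wheel) in the spirit of a sequential, without-replacement strategy:

from typing import Any

from transfer_queue.sampler import BaseSampler


class FirstFitSampler(BaseSampler):
    """Hypothetical sampler: take the first batch_size ready indexes."""

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if len(ready_indexes) < batch_size:
            raise ValueError(
                f"Need {batch_size} ready samples, only {len(ready_indexes)} available"
            )
        selected = ready_indexes[:batch_size]
        # Sampling without replacement: the sampled indexes are also the ones
        # to mark as consumed, so the two returned lists are identical.
        return selected, selected


sampled, consumed = FirstFitSampler()(list(range(8)), batch_size=4)
assert sampled == consumed == [0, 1, 2, 3]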