TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,460 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import dataclasses
|
|
16
|
+
import itertools
|
|
17
|
+
from collections import ChainMap
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import Any, Optional
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
from tensordict import TensorDict
|
|
23
|
+
|
|
24
|
+
from transfer_queue.utils.utils import ProductionStatus
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class FieldMeta:
    """Metadata describing one data field: its name, dtype, shape and production state."""

    name: str
    dtype: Optional[Any]  # Data type (e.g., torch.float32, numpy.float32)
    shape: Optional[Any]  # Data shape (e.g., torch.Size([3, 224, 224]), (3, 224, 224))
    production_status: ProductionStatus = ProductionStatus.NOT_PRODUCED

    def __str__(self) -> str:
        """Render the four attributes in a compact, constructor-like form."""
        description = ", ".join(
            (
                f"name='{self.name}'",
                f"dtype={self.dtype}",
                f"shape={self.shape}",
                f"production_status={self.production_status}",
            )
        )
        return f"FieldMeta({description})"

    @property
    def is_ready(self) -> bool:
        """Whether this field's production status is READY_FOR_CONSUME."""
        ready_status = ProductionStatus.READY_FOR_CONSUME
        return self.production_status == ready_status
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class SampleMeta:
    """Records the metadata of a single data sample (stored as a row in the data system).

    Readiness is cached in the private ``_is_ready`` attribute. The cache is
    rebuilt on construction and by ``add_fields`` / ``union``; nothing in this
    class refreshes it when a ``FieldMeta`` held in ``fields`` is mutated directly.
    """

    partition_id: str  # Partition id, used for data versioning
    global_index: int  # Global row index, uniquely identifies a data sample
    fields: dict[str, FieldMeta]  # Fields of interest for this sample

    def __post_init__(self):
        """Compute the cached readiness flag from the initial fields."""
        self._refresh_is_ready()

    def _refresh_is_ready(self) -> None:
        """Recompute the cached readiness flag: True only if every field is ready."""
        # all() over an empty ``fields`` dict is True, so a field-less sample counts as ready.
        object.__setattr__(self, "_is_ready", all(field.is_ready for field in self.fields.values()))

    def __str__(self) -> str:
        return f"SampleMeta(partition_id={self.partition_id}, global_index={self.global_index})"

    @property
    def field_names(self) -> list[str]:
        """Get list of field names for this sample"""
        return list(self.fields.keys())

    @property
    def batch_index(self) -> int:
        """Get the batch index of this sample, or -1 if no BatchMeta has assigned one yet."""
        return getattr(self, "_batch_index", -1)

    def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
        """Get FieldMeta by field name, or None if the sample has no such field."""
        return self.fields.get(name)

    def has_field(self, name: str) -> bool:
        """Check if this sample has a specific field"""
        return name in self.fields

    def is_field_ready(self, field_name: str) -> bool:
        """Check if a specific field is ready for consumption (False for unknown fields)."""
        field = self.fields.get(field_name)
        return field.is_ready if field else False

    def add_fields(self, fields: dict[str, FieldMeta]) -> "SampleMeta":
        """
        Merge already-constructed FieldMeta objects into this sample, in place.

        Existing fields with the same name are replaced by the incoming ones.
        The cached readiness flag is recomputed afterwards.

        Args:
            fields: Mapping of field name to FieldMeta to merge into this sample.

        Returns:
            This same SampleMeta instance (not a copy), to allow chaining.
        """
        self.fields = _union_fields(self.fields, fields)
        self._refresh_is_ready()
        return self

    def union(self, other: "SampleMeta", validate: bool = True) -> "SampleMeta":
        """
        Merge another sample's fields into this sample, in place.

        Both samples are expected to share the same global index. If fields
        overlap, the fields in this sample are replaced by the other sample's fields.

        Args:
            other: Another SampleMeta whose fields are merged into this one
            validate: Whether to check that both global indexes match

        Returns:
            This same SampleMeta instance (not a copy), now holding the unioned fields

        Raises:
            ValueError: If ``validate`` is True and the global indexes differ
        """
        if validate and self.global_index != other.global_index:
            raise ValueError(
                f"Error: Global indexes ({self.global_index} and {other.global_index}) do not match for union."
            )

        # Merge fields, then rebuild the readiness cache because new fields may not be ready.
        self.fields = _union_fields(self.fields, other.fields)
        self._refresh_is_ready()
        return self

    @property
    def is_ready(self) -> bool:
        """Return the cached flag telling whether all fields in this sample are ready for consumption."""
        return getattr(self, "_is_ready", False)

    @property
    def production_status(self) -> dict[str, ProductionStatus]:
        """Get production status for all fields (backward compatibility)"""
        return {name: field.production_status for name, field in self.fields.items()}
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@dataclass
class BatchMeta:
    """Records the metadata of a batch of data samples.

    Derived values (size, global indexes, field names, readiness) are computed
    in ``__post_init__`` and cached in private attributes; methods that change
    the samples refresh the caches they affect.
    """

    samples: list[SampleMeta]
    extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        """Initialize all computed properties during initialization"""
        # Basic properties
        object.__setattr__(self, "_size", len(self.samples))
        object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))

        # Pre-compute all list properties for better performance
        if self.samples:
            for idx, sample in enumerate(self.samples):
                object.__setattr__(sample, "_batch_index", idx)  # Ensure batch_index is set correctly

            object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])

            # Field names come from the first sample only; all samples are assumed to have the same fields.
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
        else:
            object.__setattr__(self, "_global_indexes", [])
            object.__setattr__(self, "_field_names", [])

    @property
    def size(self) -> int:
        """Return the number of samples in this batch"""
        return getattr(self, "_size", 0)

    @property
    def global_indexes(self) -> list[int]:
        """Get all global indexes in this batch"""
        return getattr(self, "_global_indexes", [])

    @property
    def field_names(self) -> list[str]:
        """Get the sorted field names of this batch (taken from its first sample)"""
        return getattr(self, "_field_names", [])

    @property
    def is_ready(self) -> bool:
        """Check if all samples in this batch are ready for consumption"""
        # TODO: get ready status from controller realtime
        return getattr(self, "_is_ready", False)

    # Extra info interface methods
    def get_extra_info(self, key: str, default: Any = None) -> Any:
        """Get extra info by key"""
        return self.extra_info.get(key, default)

    def set_extra_info(self, key: str, value: Any) -> None:
        """Set extra info by key"""
        self.extra_info[key] = value

    def update_extra_info(self, info_dict: dict[str, Any]) -> None:
        """Update extra info with multiple key-value pairs"""
        self.extra_info.update(info_dict)

    def remove_extra_info(self, key: str) -> Any:
        """Remove extra info by key and return its value (None if the key is absent)"""
        return self.extra_info.pop(key, None)

    def clear_extra_info(self) -> None:
        """Clear all extra info"""
        self.extra_info.clear()

    def has_extra_info(self, key: str) -> bool:
        """Check if extra info contains a specific key"""
        return key in self.extra_info

    def get_all_extra_info(self) -> dict[str, Any]:
        """Get all extra info as a (shallow-copied) dictionary"""
        return self.extra_info.copy()

    def add_fields(self, tensor_dict: TensorDict, set_all_ready: bool = True) -> "BatchMeta":
        """
        Add new fields from a TensorDict to all samples in this batch.
        This modifies each sample in-place to include the new fields.

        Args:
            tensor_dict (TensorDict): The input TensorDict containing new fields.
            set_all_ready (bool): If True, set all production_status to READY_FOR_CONSUME. Default is True.

        Returns:
            This same BatchMeta instance.

        Raises:
            ValueError: If the number of extracted rows differs from the number of samples.
        """
        fields = _extract_field_metas(tensor_dict, set_all_ready)

        if fields:
            if len(self.samples) != len(fields):
                raise ValueError(f"add_fields length mismatch: samples={len(self.samples)} vs fields={len(fields)}")
            for idx, sample in enumerate(self.samples):
                sample.add_fields(fields=fields[idx])

            # Update batch-level fields cache
            object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
            object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
        return self

    def __len__(self) -> int:
        """Return the number of samples in this batch."""
        return len(self.samples)

    def __getitem__(self, item):
        """
        Return a single-sample BatchMeta for the sample at position ``item``.

        The returned batch shares this batch's ``extra_info`` dict (it is not copied).

        Raises:
            TypeError: If ``item`` is not an integer.
            IndexError: If ``item`` is out of range, which includes any index into an empty batch.
        """
        if not isinstance(item, int | np.integer):
            raise TypeError(f"Indexing with {type(item)} is not supported now!")

        # Plain list indexing raises IndexError for a bad position; previously an empty batch
        # substituted `[]` for the sample and failed later with an unrelated AttributeError.
        sample_meta = self.samples[item]
        return BatchMeta(samples=[sample_meta], extra_info=self.extra_info)

    def chunk(self, chunks: int) -> list["BatchMeta"]:
        """
        Split this batch into smaller chunks.

        Sizes differ by at most one; the first ``len(self) % chunks`` chunks get the
        extra sample. Each chunk receives its own shallow copy of ``extra_info``.

        Args:
            chunks: number of chunks, must be positive

        Return:
            List of smaller BatchMeta chunks

        Raises:
            ValueError: If ``chunks`` is not positive
        """
        # A zero would divide by zero and a negative count would silently drop every sample.
        if chunks <= 0:
            raise ValueError(f"chunks must be a positive integer, got {chunks}")

        chunk_list = []
        n = len(self.samples)

        # Calculate the base size and remainder of each chunk
        base_size = n // chunks
        remainder = n % chunks

        start = 0
        for i in range(chunks):
            # Calculate the size of the current chunk (the first remainder chunks are 1 more than the base size)
            current_chunk_size = base_size + 1 if i < remainder else base_size
            end = start + current_chunk_size
            chunk_samples = self.samples[start:end]
            chunk = BatchMeta(samples=chunk_samples, extra_info=self.extra_info.copy())
            chunk_list.append(chunk)
            start = end
        return chunk_list

    @classmethod
    def concat(cls, data: list["BatchMeta"], validate: bool = True) -> Optional["BatchMeta"]:
        """
        Concatenate multiple BatchMeta chunks into one large batch.

        Args:
            data: List of BatchMeta chunks to concatenate
            validate: Whether to validate concatenation conditions

        Returns:
            Concatenated BatchMeta, or None if ``data`` is empty

        Raises:
            ValueError: If validation fails (e.g., field names do not match)
        """
        if not data:
            return None

        if validate:
            base_fields = data[0].field_names

            for chunk in data:
                if chunk.field_names != base_fields:
                    raise ValueError("Error: Field names do not match for concatenation.")

        # Combine all samples
        all_samples = list(itertools.chain.from_iterable(chunk.samples for chunk in data))
        # Merge all extra_info dictionaries from the chunks.
        # ChainMap looks keys up left to right, so on a key clash the EARLIEST chunk's value wins.
        merged_extra_info = dict(ChainMap(*(chunk.extra_info for chunk in data)))

        return BatchMeta(samples=all_samples, extra_info=merged_extra_info)

    def union(self, other: "BatchMeta", validate: bool = True) -> Optional["BatchMeta"]:
        """
        Create a union of this batch's fields with another batch's fields.
        Assume both batches have the same global indices. If fields overlap, the
        fields in this batch will be replaced by the other batch's fields.

        Note that the SampleMeta objects of this batch are merged in place
        (SampleMeta.union mutates and returns its receiver); the returned BatchMeta
        wraps those same sample objects rather than copies.

        Args:
            other: Another BatchMeta to union with
            validate: Whether to validate union conditions

        Returns:
            New BatchMeta holding the merged samples; on an ``extra_info`` key
            clash the other batch's value wins

        Raises:
            ValueError: If validation fails (e.g., batch sizes or global indexes do not match)
        """
        if validate:
            if self.size != other.size:
                raise ValueError("Error: Batch sizes do not match for union.")

            # Compare as sorted lists so the two batches may list the same samples in different orders.
            self_global_indexes = sorted(self.global_indexes)
            other_global_indexes = sorted(other.global_indexes)
            if self_global_indexes != other_global_indexes:
                raise ValueError("Error: Global indexes do not match for union.")

        # Create a mapping from global_index to SampleMeta in the other batch
        other_sample_map = {sample.global_index: sample for sample in other.samples}

        # Merge samples, keeping this batch's order
        merged_samples = []
        for sample in self.samples:
            if sample.global_index in other_sample_map:
                other_sample = other_sample_map[sample.global_index]
                merged_sample = sample.union(other_sample, validate=validate)
                merged_samples.append(merged_sample)
            else:
                # Only reachable with validate=False: samples without a counterpart are kept unchanged.
                merged_samples.append(sample)

        # Merge extra info dictionaries
        merged_extra_info = {**self.extra_info, **other.extra_info}
        return BatchMeta(samples=merged_samples, extra_info=merged_extra_info)

    def reorder(self, indices: list[int]):
        """
        Reorder the SampleMeta in the BatchMeta according to the given indices.

        The operation is performed in-place, modifying the current BatchMeta's SampleMeta order.

        Args:
            indices : list[int]
                A list of integers specifying the new order of SampleMeta. Each integer
                represents the current index of the SampleMeta in the BatchMeta.
        """
        # Reorder the samples
        reordered_samples = [self.samples[i] for i in indices]
        object.__setattr__(self, "samples", reordered_samples)

        # Update necessary attributes
        self._update_after_reorder()

    def _update_after_reorder(self) -> None:
        """Update related attributes specifically for the reorder operation"""
        # Update batch_index for each sample
        for idx, sample in enumerate(self.samples):
            object.__setattr__(sample, "_batch_index", idx)

        # Update cached index lists
        object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])

        # `indices` is not checked to be a full permutation, so refresh the cached size as well;
        # otherwise `size` and `len()` could disagree after a reorder that drops or repeats samples.
        object.__setattr__(self, "_size", len(self.samples))

        # Note: _field_names and _is_ready are left untouched by reorder

    @classmethod
    def from_samples(
        cls, samples: SampleMeta | list[SampleMeta], extra_info: Optional[dict[str, Any]] = None
    ) -> "BatchMeta":
        """
        Create a BatchMeta from a single SampleMeta or a list of SampleMeta objects.

        Args:
            samples: A single SampleMeta or a list of SampleMeta objects
            extra_info: Optional additional information to store with the batch

        Returns:
            BatchMeta instance containing the provided sample(s)

        Example:
            >>> sample_meta = SampleMeta(...)
            >>> batch_meta = BatchMeta.from_samples(sample_meta)

            >>> sample_metas = [sample1, sample2, sample3]
            >>> batch_meta = BatchMeta.from_samples(sample_metas, extra_info={"source": "training"})
        """
        if extra_info is None:
            extra_info = {}

        if isinstance(samples, SampleMeta):
            samples = [samples]

        return cls(samples=samples, extra_info=extra_info)

    @classmethod
    def empty(cls, extra_info: Optional[dict[str, Any]] = None) -> "BatchMeta":
        """
        Create an empty BatchMeta with no samples.

        Args:
            extra_info: Optional additional information to store with the batch

        Returns:
            Empty BatchMeta instance

        Example:
            >>> empty_batch = BatchMeta.empty()
        """
        if extra_info is None:
            extra_info = {}
        return cls(samples=[], extra_info=extra_info)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _union_fields(fields1: dict[str, FieldMeta], fields2: dict[str, FieldMeta]) -> dict[str, FieldMeta]:
    """Merge ``fields2`` into ``fields1`` and return ``fields1``.

    The merge happens in place on ``fields1``; entries whose names also appear
    in ``fields2`` are overwritten with the ``fields2`` value.
    """
    fields1.update(fields2)
    return fields1
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _extract_field_metas(tensor_dict: TensorDict, set_all_ready: bool = True) -> list[dict[str, FieldMeta]]:
    """
    Build per-sample field metadata from a TensorDict.

    One dictionary is produced for each index along the first batch dimension.
    For every entry of that row a FieldMeta is created; if the entry has no
    ``dtype`` or ``shape`` attribute, the corresponding value is recorded as None.

    Args:
        tensor_dict (TensorDict): The input TensorDict.
        set_all_ready (bool): If True, every field is marked READY_FOR_CONSUME.
            Otherwise every field is marked NOT_PRODUCED. Default is True.

    Returns:
        list[dict[str, FieldMeta]]: One name -> FieldMeta mapping per sample, in row order.
    """
    num_rows = tensor_dict.batch_size[0]

    if set_all_ready:
        status = ProductionStatus.READY_FOR_CONSUME
    else:
        status = ProductionStatus.NOT_PRODUCED

    per_sample_fields = []
    for row in range(num_rows):
        row_fields = {}
        for name, value in tensor_dict[row].items():
            row_fields[name] = FieldMeta(
                name=name,
                dtype=getattr(value, "dtype", None),
                shape=getattr(value, "shape", None),
                production_status=status,
            )
        per_sample_fields.append(row_fields)

    return per_sample_fields
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .base import BaseSampler
|
|
16
|
+
from .grpo_group_n_sampler import GRPOGroupNSampler
|
|
17
|
+
from .sequential_sampler import SequentialSampler
|
|
18
|
+
|
|
19
|
+
# Names re-exported as the public API of the sampler package.
__all__ = ["BaseSampler", "SequentialSampler", "GRPOGroupNSampler"]
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from abc import ABC, abstractmethod
|
|
16
|
+
from typing import Any
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseSampler(ABC):
    """Abstract interface for strategies that decide how data is consumed from TransferQueue.

    A sampler chooses which of the currently available samples to retrieve and
    which of them to label as consumed (consumed samples will never be retrieved
    again). Building on this abstraction, users can implement different
    consumption strategies for different training scenarios, e.g. sequential
    sampling, grouped sampling for reinforcement learning, or custom patterns.

    Data production status is handled by TransferQueueController, while the
    consumption strategy lives in the sampler. Thanks to this separation, users
    can customize consumption behavior without modifying the TransferQueue code.

    Available Samplers:
    - **SequentialSampler**: Default sampler, selects samples sequentially without replacement
    - **GRPOGroupNSampler**: Samples groups of N contiguous samples only when all of them are ready.
    It assumes the N samples associated with the same prompt are stored contiguously
    - **RankAwareSampler**: Rank-aware sampling for distributed scenarios (TODO)

    NOTE: Always return both sampled and consumed indexes (may be identical).
    """

    def __init__(self):
        # Starts empty and is not touched by the base class; presumably for subclass bookkeeping.
        self._states: dict[str, Any] = dict()

    @abstractmethod
    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Select a batch of global indices out of the ready ones.

        Args:
            ready_indexes: Global indices whose required fields have all been
                produced and which are not yet labeled as consumed in the
                corresponding task.
            batch_size: Number of samples to select
            *args: Extra positional arguments for specific sampler implementations
            **kwargs: Extra keyword arguments for specific sampler implementations

        Returns:
            A 2-tuple of lists, each of length batch_size:
            the sampled global indices, and the global indices that should be
            labeled as consumed (will never be retrieved in the future)

        Raises:
            ValueError: If batch_size is invalid or ready_indexes is insufficient
        """
        raise NotImplementedError("Subclasses must implement sample")

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[list[int], list[int]]:
        """Make the sampler callable; forwards every argument to ``sample``."""
        sampled_and_consumed = self.sample(*args, **kwargs)
        return sampled_and_consumed
|