TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,157 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any
16
+
17
+ from transfer_queue.sampler import BaseSampler
18
+
19
+
20
class GRPOGroupNSampler(BaseSampler):
    """Group-based sampler for reinforcement learning and multi-sample generation workflows.

    This sampler implements grouped sampling without replacement, specifically designed
    for scenarios where multiple samples are generated from the same input prompt
    or where grouped sampling is required. It ensures that all samples belonging to the
    same prompt are either selected together or not at all, maintaining the integrity
    of prompt groups throughout the training process.

    The sampler is commonly used in GRPO (Group Relative Policy Optimization)
    training, where multiple responses are generated from the same prompt and the
    policy is trained on all of them together.

    The sampler is configured through TransferQueueController and receives parameters
    via the sampling_config in get_meta calls:

    ```python
    # Initialize controller with GRPO sampler
    from transfer_queue import TransferQueueController, GRPOGroupNSampler, AsyncTransferQueueClient

    controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
    # process_zmq_server_info is a TransferQueue helper; its import is omitted here.
    controller_info = process_zmq_server_info(controller)

    client = AsyncTransferQueueClient(
        client_id="rl_client",
        controller_info=controller_info,
    )

    # Get metadata with grouped sampling configuration
    meta = await client.async_get_meta(
        data_fields=["input_ids", "attention_mask", "generated_text", "reward"],
        batch_size=16,  # Total samples requested
        partition_id="train_0",
        task_name="rl_training",
        sampling_config={"n_samples_per_prompt": 4},  # 4 samples per prompt
    )
    # This will return 16 samples organized as 4 groups of 4 samples each
    ```

    Data Organization:
        The sampler assumes that all samples of one prompt occupy consecutive
        global indices, for example:
        ```
        ready_indexes = [prompt1_sample1, prompt1_sample2, prompt1_sample3, prompt1_sample4,
                         prompt2_sample1, prompt2_sample2, prompt2_sample3, prompt2_sample4, ...]
        ```
        Any run of ``n_samples_per_prompt`` consecutive indices is treated as a
        complete group; the sampler does not check that a run starts on a group
        boundary, so callers must keep each prompt's samples contiguous.
    """

    def __init__(
        self,
    ):
        """Initialize the GRPOGroupNSampler.

        The sampler keeps no configuration of its own; everything it needs is
        passed to :meth:`sample` at call time.
        """
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        n_samples_per_prompt: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Sample groups of indices from the ready indices.

        Only complete groups are sampled. If fewer than
        ``batch_size // n_samples_per_prompt`` complete groups are available,
        nothing is sampled and two empty lists are returned.

        Groups are found greedily in ascending index order. A group is a run of
        ``n_samples_per_prompt`` strictly consecutive indices. After a group is
        taken, the next group starts fresh at the following index.

        Args:
            ready_indexes: List of global indices for which all required fields have been
                produced and samples are not labeled as consumed. These should be organized
                such that consecutive indices belong to the same prompt group.
            batch_size: Total number of samples to select. Must be non-negative and
                divisible by n_samples_per_prompt.
            n_samples_per_prompt: Number of samples per prompt group. Must be > 0.
            *args: Additional positional arguments (ignored in current implementation)
            **kwargs: Additional keyword arguments (ignored in current implementation)

        Returns:
            Tuple of (sampled_indexes, consumed_indexes):
                - sampled_indexes: List of selected global indices, length = batch_size or empty
                - consumed_indexes: Separate list with the same contents as sampled_indexes
                  (without replacement semantics)

        Raises:
            ValueError: If n_samples_per_prompt is not positive, batch_size is negative,
                or batch_size is not a multiple of n_samples_per_prompt.

        Examples:
            >>> sampler = GRPOGroupNSampler()
            >>> ready_indexes = [0, 1, 3, 4, 6, 7]  # No complete groups after sorting
            >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3)
            >>> sampled
            []
            >>> consumed
            []

            >>> ready_indexes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11]  # Has complete groups after sorting
            >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3)
            >>> sampled
            [3, 4, 5, 9, 10, 11]
            >>> consumed
            [3, 4, 5, 9, 10, 11]
        """
        # Basic validation
        if n_samples_per_prompt <= 0:
            raise ValueError(f"n_samples_per_prompt must be positive, got {n_samples_per_prompt}")

        if batch_size < 0:
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")

        if batch_size % n_samples_per_prompt != 0:
            raise ValueError(
                f"batch_size ({batch_size}) must be a multiple of n_samples_per_prompt ({n_samples_per_prompt})"
            )

        if batch_size == 0:
            return [], []

        selected: list[int] = []
        run: list[int] = []  # current run of strictly consecutive indices
        for index in sorted(ready_indexes):
            if run and index != run[-1] + 1:
                # A gap (or duplicate) breaks the run. Any window spanning the gap cannot
                # be a complete group, so scanning restarts at this index. This is a single
                # pass instead of re-checking every overlapping window.
                run = []
            run.append(index)
            if len(run) == n_samples_per_prompt:
                selected.extend(run)
                # The next group starts fresh after this one, even if the next index is adjacent.
                run = []
                if len(selected) == batch_size:
                    return selected, selected.copy()

        # Not enough complete groups: sample nothing rather than a partial batch.
        return [], []
@@ -0,0 +1,75 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any
16
+
17
+ from transfer_queue.sampler import BaseSampler
18
+
19
+
20
class SequentialSampler(BaseSampler):
    """Sequential sampler for basic data consumption patterns.

    This sampler implements sequential sampling without replacement, selecting samples
    from the beginning of the ready_indexes list in order. It's the default sampling
    strategy for TransferQueueController and provides simple, deterministic data consumption
    with minimal overhead.

    The sampler is ideal for standard supervised learning scenarios, data preprocessing
    pipelines, and any use case where ordered, predictable data consumption is preferred.
    It ensures each sample is consumed exactly once, maintaining a clean progression through
    the available data.

    This sampler is typically used as the default sampler in TransferQueueController:

    ```python
    # Default usage (SequentialSampler is the default)
    controller = TransferQueueController.remote()
    # or explicitly:
    controller = TransferQueueController.remote(sampler=SequentialSampler)
    ```
    """

    def __init__(
        self,
    ):
        """Initialize the SequentialSampler.

        SequentialSampler takes no initialization parameters and keeps no
        state of its own.
        """
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Select the first batch_size elements from ready_indexes.

        Args:
            ready_indexes: Available sample indices.
            batch_size: Number of samples to select (must be non-negative). If larger than
                the number of available ready samples, all available samples are returned.
            *args: Additional positional arguments (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple of (sampled_indexes, consumed_indexes). consumed_indexes has the same
            contents as sampled_indexes but is a separate list.

        Raises:
            ValueError: If batch_size is negative.
        """
        if batch_size < 0:
            # A negative slice bound would silently drop items from the end instead.
            raise ValueError(f"batch_size must be non-negative, got {batch_size}")

        sampled_indexes = ready_indexes[:batch_size]
        # Independent copy, so that mutating one returned list cannot change the other.
        consumed_indexes = list(sampled_indexes)

        return sampled_indexes, consumed_indexes
@@ -0,0 +1,25 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .managers import AsyncSimpleStorageManager, TransferQueueStorageManager, TransferQueueStorageManagerFactory
16
+ from .simple_backend import SimpleStorageUnit, StorageMetaGroup, StorageUnitData
17
+
18
# Public API of the storage package: names re-exported from .simple_backend and .managers.
__all__ = [
    "SimpleStorageUnit",
    "StorageUnitData",
    "StorageMetaGroup",
    "TransferQueueStorageManager",
    "TransferQueueStorageManagerFactory",
    "AsyncSimpleStorageManager",
]
@@ -0,0 +1,24 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Public storage-client API: the abstract KV client interface, its factory, and the YuanRong client.
16
+ from .base import TransferQueueStorageKVClient
17
+ from .factory import StorageClientFactory
18
+ from .yuanrong_client import YRStorageClient
19
+
20
# Names re-exported from .base, .factory and .yuanrong_client.
__all__ = [
    "TransferQueueStorageKVClient",
    "StorageClientFactory",
    "YRStorageClient",
]
@@ -0,0 +1,22 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from torch import Tensor
4
+
5
+
6
class TransferQueueStorageKVClient(ABC):
    """Abstract key/value interface that every storage client implements.

    A concrete client must override all three operations (``put``, ``get`` and
    ``clear``). The base versions exist only to raise ``NotImplementedError`` if
    a subclass delegates to them via ``super()``.
    """

    @abstractmethod
    def put(self, keys: list[str], values: list[Tensor]) -> None:
        """Write ``values`` to the backend under ``keys``."""
        raise NotImplementedError("Subclasses must implement put")

    @abstractmethod
    def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
        """Read back the tensors stored under ``keys``.

        ``shapes`` and ``dtypes`` are optional hints; some backends need them
        to pre-allocate receive buffers.
        """
        raise NotImplementedError("Subclasses must implement get")

    @abstractmethod
    def clear(self, keys: list[str]) -> None:
        """Delete the entries stored under ``keys``."""
        raise NotImplementedError("Subclasses must implement clear")
@@ -0,0 +1,55 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
15
+
16
+
17
class StorageClientFactory:
    """
    Factory class for creating storage client instances.
    Uses a decorator-based registration mechanism to map client names to classes.
    """

    # Class variable: maps client names to their registered client *classes* (not instances).
    _registry: dict[str, type[TransferQueueStorageKVClient]] = {}

    @classmethod
    def register(cls, client_type: str):
        """
        Decorator to register a concrete client class with the factory.

        Registering a name that is already present replaces the earlier class.

        Args:
            client_type (str): The name used to identify the client
        Returns:
            Callable: The decorator function, which records the class and returns it unchanged
        """

        def decorator(
            client_class: type[TransferQueueStorageKVClient],
        ) -> type[TransferQueueStorageKVClient]:
            cls._registry[client_type] = client_class
            return client_class

        return decorator

    @classmethod
    def create(cls, client_type: str, config: dict) -> TransferQueueStorageKVClient:
        """
        Create and return an instance of the storage client by name.

        Args:
            client_type (str): The registered name of the client
            config (dict): Passed unchanged as the single argument to the client's constructor
        Returns:
            TransferQueueStorageKVClient: A new instance of the requested client class
        Raises:
            ValueError: If no client is registered with the given name
        """
        if client_type not in cls._registry:
            raise ValueError(f"Unknown StorageClient: {client_type}")
        client_class = cls._registry[client_type]
        return client_class(config)
@@ -0,0 +1,118 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
7
+ from transfer_queue.storage.clients.factory import StorageClientFactory
8
+
9
# Optional backends: record whether they imported instead of failing at module import,
# so YRStorageClient can raise a clear ImportError only when it is actually constructed.
YUANRONG_DATASYSTEM_IMPORTED: bool
TORCH_NPU_IMPORTED: bool

try:
    import datasystem
except ImportError:
    YUANRONG_DATASYSTEM_IMPORTED = False
else:
    YUANRONG_DATASYSTEM_IMPORTED = True

try:
    import torch_npu
except ImportError:
    TORCH_NPU_IMPORTED = False
else:
    TORCH_NPU_IMPORTED = True
19
+
20
+
21
+ # TODO: DsTensorClient.dev_mget currently misbehaves; it may require a stricter runtime environment to work correctly
22
@StorageClientFactory.register("Yuanrong")
class YRStorageClient(TransferQueueStorageKVClient):
    """
    Storage client for YuanRong DataSystem.
    Communicates with the remote tensor storage service via DsTensorClient.
    All tensors must reside on NPU device.
    """

    # Largest number of keys accepted by a single put/get call; larger batches are not supported yet.
    MAX_KEYS_PER_CALL: int = 10000
    # Timeout handed to DsTensorClient.dev_mget, in milliseconds.
    GET_TIMEOUT_MS: int = 2000

    def __init__(self, config: dict[str, Any]):
        """
        Bind this process to the configured NPU and connect a DsTensorClient.
        Args:
            config (dict): Provides "host", "port" and "device_id" (missing keys are read as None)
        Raises:
            ImportError: If datasystem or torch_npu could not be imported
        """
        if not YUANRONG_DATASYSTEM_IMPORTED:
            raise ImportError("YuanRong DataSystem not installed.")
        if not TORCH_NPU_IMPORTED:
            raise ImportError("Torch_npu not installed.")

        self.host = config.get("host")
        self.port = config.get("port")
        self.device_id = config.get("device_id")
        torch_npu.npu.set_device(f"npu:{self.device_id}")  # set npu_device
        self._ds_client = datasystem.DsTensorClient(self.host, self.port, self.device_id)
        self._ds_client.init()

    def _create_empty_tensorlist(self, shapes, dtypes):
        """
        Create a list of empty NPU tensors with given shapes and dtypes.
        Args:
            shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)])
            dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64])
        Returns:
            list: List of uninitialized NPU tensors
        Raises:
            ValueError: If shapes and dtypes differ in length
        """
        if len(dtypes) != len(shapes):
            raise ValueError("Length of dtypes must equal length of shapes")

        device = f"npu:{self.device_id}"
        # Allocate directly on the NPU rather than allocating on CPU and copying each buffer over.
        return [torch.empty(shape, dtype=dtype, device=device) for dtype, shape in zip(dtypes, shapes, strict=True)]

    def _check_key_count(self, keys: list[str]) -> None:
        """Raise NotImplementedError when a batch exceeds MAX_KEYS_PER_CALL."""
        # TODO: Support the situation when the number of keys is greater than MAX_KEYS_PER_CALL
        if len(keys) > self.MAX_KEYS_PER_CALL:
            raise NotImplementedError(
                f"We will support the number of keys greater than {self.MAX_KEYS_PER_CALL} in the future"
            )

    def put(self, keys: list[str], values: list[Tensor]):
        """
        Store tensors to remote storage.
        Args:
            keys (list): List of string keys
            values (list): List of torch.Tensor on NPU, one per key
        Raises:
            ValueError: If inputs are not lists, lengths differ, or a value is not an NPU tensor
            NotImplementedError: If more than MAX_KEYS_PER_CALL keys are given
        """
        if not isinstance(keys, list) or not isinstance(values, list):
            raise ValueError("keys and values must be lists")
        if len(keys) != len(values):
            raise ValueError("Number of keys must match number of values")

        self._check_key_count(keys)

        for value in values:
            if not isinstance(value, torch.Tensor):
                raise ValueError(f"Expected torch.Tensor, got {type(value)}")
            if value.device.type != "npu":
                raise ValueError(f"Tensor is on {value.device}, not on NPU")

        self._ds_client.dev_mset(keys, values)

    def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
        """
        Retrieve tensors from remote storage.
        Args:
            keys (list): List of keys to fetch
            shapes (list): Expected shapes of returned tensors, one per key
            dtypes (list): Expected dtypes of returned tensors, one per key
        Returns:
            list: List of retrieved NPU tensors
        Raises:
            ValueError: If shapes/dtypes are missing or their lengths do not match keys
            NotImplementedError: If more than MAX_KEYS_PER_CALL keys are given
        """
        if shapes is None:
            raise ValueError("Yuanrong storage client needs Expected shapes of returned tensors")
        if dtypes is None:
            raise ValueError("Yuanrong storage client needs Expected dtypes of returned tensors")
        if len(dtypes) != len(shapes):
            raise ValueError("Length of dtypes must equal length of shapes")
        if len(keys) != len(shapes):
            # Each key is received into exactly one pre-allocated buffer.
            raise ValueError("Number of keys must match number of shapes and dtypes")

        # Check the batch limit before allocating any device memory.
        self._check_key_count(keys)

        values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes)
        self._ds_client.dev_mget(keys, values, self.GET_TIMEOUT_MS)
        return values

    def clear(self, keys: list[str]):
        """
        Delete entries from storage by keys.
        Args:
            keys (list): List of keys to delete
        """
        self._ds_client.dev_delete(keys)
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .base import TransferQueueStorageManager
16
+ from .factory import TransferQueueStorageManagerFactory
17
+ from .simple_backend_manager import AsyncSimpleStorageManager
18
+
19
# Names re-exported from .base, .factory and .simple_backend_manager.
__all__ = [
    "TransferQueueStorageManager",
    "TransferQueueStorageManagerFactory",
    "AsyncSimpleStorageManager",
]