TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from transfer_queue.sampler import BaseSampler
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class GRPOGroupNSampler(BaseSampler):
    """Grouped, without-replacement sampler for GRPO-style multi-sample workflows.

    In GRPO (Group Relative Policy Optimization) and similar recipes, several
    responses are generated from the same prompt and the policy is trained on all
    of them together. This sampler therefore only hands out whole prompt groups:
    the samples belonging to one prompt are either selected together or not at
    all, so prompt groups stay intact throughout training.

    The sampler is configured through TransferQueueController and receives its
    parameters via the ``sampling_config`` of ``get_meta`` calls:

    ```python
    # Initialize controller with GRPO sampler
    from transfer_queue import TransferQueueController, GRPOGroupNSampler, AsyncTransferQueueClient

    controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
    controller_info = process_zmq_server_info(controller)

    client = AsyncTransferQueueClient(
        client_id="rl_client",
        controller_info=controller_info,
    )

    # Get metadata with grouped sampling configuration
    meta = await client.async_get_meta(
        data_fields=["input_ids", "attention_mask", "generated_text", "reward"],
        batch_size=16,  # Total samples requested
        partition_id="train_0",
        task_name="rl_training",
        sampling_config={"n_samples_per_prompt": 4},  # 4 samples per prompt
    )
    # This will return 16 samples organized as 4 groups of 4 samples each
    ```

    Data organization:
        The sampler assumes that all samples of one prompt were put with
        consecutive global indices, i.e. in the order
        ``prompt1_sample1, prompt1_sample2, ..., prompt2_sample1, prompt2_sample2, ...``:

        ```
        ready_indexes = [prompt1_sample1, prompt1_sample2, prompt1_sample3, prompt1_sample4,
                         prompt2_sample1, prompt2_sample2, prompt2_sample3, prompt2_sample4, ...]
        ```
    """

    def __init__(self):
        """Create the sampler.

        No state is kept here; everything the sampler needs arrives per call
        through ``sample``'s arguments.
        """
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        n_samples_per_prompt: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Select ``batch_size // n_samples_per_prompt`` complete prompt groups.

        ``ready_indexes`` is sorted and scanned left to right with a window of
        ``n_samples_per_prompt`` elements. When every neighbouring pair inside the
        window differs by exactly 1, the window is accepted as a group and the scan
        jumps past it; otherwise the window advances by a single position. Scanning
        stops once enough groups are collected. Selection is all-or-nothing: if
        fewer groups than required are found, two empty lists are returned.

        NOTE(review): a window only has to be a run of consecutive integers; its
        start is not checked against a multiple of ``n_samples_per_prompt``. If
        prompt groups are aligned that way, a run that straddles two groups (e.g.
        ``[1, 2, 3]`` with ``n_samples_per_prompt=3``) would still be accepted.
        Confirm that index allocation in the controller rules this out.

        Args:
            ready_indexes: Global indices whose required fields have all been
                produced and that are not yet marked as consumed. Samples of one
                prompt are expected to have consecutive indices.
            batch_size: Total number of samples to select. Must be divisible by
                ``n_samples_per_prompt``.
            n_samples_per_prompt: Number of samples per prompt group. Must be > 0.
            *args: Additional positional arguments (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            ``(sampled_indexes, consumed_indexes)``: the selected global indices
            (``batch_size`` of them, or none) and an equal but separate list of the
            indices to mark as consumed (sampling without replacement).

        Raises:
            ValueError: If ``n_samples_per_prompt`` is not positive or
                ``batch_size`` is not a multiple of it.

        Examples:
            >>> sampler = GRPOGroupNSampler()
            >>> ready_indexes = [0, 1, 3, 4, 6, 7]  # No complete groups after sorting
            >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3)
            >>> sampled
            []
            >>> consumed
            []

            >>> ready_indexes = [0, 1, 3, 4, 5, 6, 7, 9, 10, 11]  # Has complete groups after sorting
            >>> sampled, consumed = sampler.sample(ready_indexes, 6, n_samples_per_prompt=3)
            >>> sampled
            [3, 4, 5, 9, 10, 11]
            >>> consumed
            [3, 4, 5, 9, 10, 11]
        """
        # Argument validation.
        if n_samples_per_prompt <= 0:
            raise ValueError(f"n_samples_per_prompt must be positive, got {n_samples_per_prompt}")
        if batch_size % n_samples_per_prompt:
            raise ValueError(
                f"batch_size ({batch_size}) must be a multiple of n_samples_per_prompt ({n_samples_per_prompt})"
            )

        group_size = n_samples_per_prompt
        groups_needed = batch_size // group_size
        ordered = sorted(ready_indexes)
        # Largest start position at which a full window still fits.
        last_window_start = len(ordered) - group_size

        picked: list[int] = []
        groups_taken = 0
        cursor = 0
        while groups_taken < groups_needed and cursor <= last_window_start:
            window = ordered[cursor : cursor + group_size]
            # Every neighbour pair must step by exactly 1 (duplicates break the run).
            if all(right - left == 1 for left, right in zip(window, window[1:])):
                picked.extend(window)
                groups_taken += 1
                cursor += group_size
            else:
                cursor += 1

        if groups_taken < groups_needed:
            # All-or-nothing: never hand out a partial batch.
            return [], []

        return picked, list(picked)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
from transfer_queue.sampler import BaseSampler
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SequentialSampler(BaseSampler):
    """Sequential sampler for basic data consumption patterns.

    This sampler implements sequential sampling without replacement, selecting samples
    from the beginning of the ready_indexes list in order. It's the default sampling
    strategy for TransferQueueController and provides simple, deterministic data consumption
    with minimal overhead.

    The sampler is ideal for standard supervised learning scenarios, data preprocessing
    pipelines, and any use case where ordered, predictable data consumption is preferred.
    It ensures each sample is consumed exactly once, maintaining a clean progression through
    the available data.

    This sampler is typically used as the default sampler in TransferQueueController:

    ```python
    # Default usage (SequentialSampler is the default)
    controller = TransferQueueController.remote()
    # or explicitly:
    controller = TransferQueueController.remote(sampler=SequentialSampler)
    ```
    """

    def __init__(
        self,
    ):
        """Initialize the SequentialSampler.

        SequentialSampler requires no initialization parameters and keeps no
        per-call state.
        """
        super().__init__()

    def sample(
        self,
        ready_indexes: list[int],
        batch_size: int,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[list[int], list[int]]:
        """Select the first ``batch_size`` elements of ``ready_indexes``, in order.

        Args:
            ready_indexes: Available sample indices.
            batch_size: Number of samples to select. If larger than the number of
                ready samples, all available samples are returned. Values <= 0
                select nothing.
            *args: Additional positional arguments (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            Tuple of (sampled_indexes, consumed_indexes). Both lists hold the same
            indices but are distinct objects, so mutating one does not affect the other.
        """
        # Clamp at 0: a negative slice stop would count from the end
        # (e.g. ``[:-1]`` selects everything except the last element).
        take = max(batch_size, 0)
        sampled_indexes = list(ready_indexes[:take])
        # Separate list rather than an alias, consistent with GRPOGroupNSampler's ``.copy()``.
        consumed_indexes = sampled_indexes.copy()

        return sampled_indexes, consumed_indexes
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .managers import AsyncSimpleStorageManager, TransferQueueStorageManager, TransferQueueStorageManagerFactory
|
|
16
|
+
from .simple_backend import SimpleStorageUnit, StorageMetaGroup, StorageUnitData
|
|
17
|
+
|
|
18
|
+
# Public API of the storage package (order preserved).
__all__ = [
    # Simple storage backend: unit, its data container and metadata grouping.
    "SimpleStorageUnit", "StorageUnitData", "StorageMetaGroup",
    # Storage managers and the factory that builds them.
    "TransferQueueStorageManager", "TransferQueueStorageManagerFactory", "AsyncSimpleStorageManager",
]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Re-export the storage client interface, the client factory, and the built-in client implementations.
|
|
16
|
+
from .base import TransferQueueStorageKVClient
|
|
17
|
+
from .factory import StorageClientFactory
|
|
18
|
+
from .yuanrong_client import YRStorageClient
|
|
19
|
+
|
|
20
|
+
# Public API of the storage clients package (order preserved).
__all__ = [
    "TransferQueueStorageKVClient",  # abstract interface
    "StorageClientFactory",  # name -> client class registry
    "YRStorageClient",  # YuanRong DataSystem implementation
]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TransferQueueStorageKVClient(ABC):
|
|
7
|
+
"""
|
|
8
|
+
Abstract base class for storage client.
|
|
9
|
+
Subclasses must implement the core methods: put, get, and clear.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
@abstractmethod
|
|
13
|
+
def put(self, keys: list[str], values: list[Tensor]) -> None:
|
|
14
|
+
raise NotImplementedError("Subclasses must implement put")
|
|
15
|
+
|
|
16
|
+
@abstractmethod
|
|
17
|
+
def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
|
|
18
|
+
raise NotImplementedError("Subclasses must implement get")
|
|
19
|
+
|
|
20
|
+
@abstractmethod
|
|
21
|
+
def clear(self, keys: list[str]) -> None:
|
|
22
|
+
raise NotImplementedError("Subclasses must implement clear")
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class StorageClientFactory:
    """
    Factory class for creating storage client instances.

    Concrete client classes register themselves under a name with the
    ``register`` class decorator; ``create`` then instantiates them by that name.
    """

    # Class variable shared by all users of the factory: maps a client type name
    # to the client *class* (not an instance) registered under it.
    _registry: dict[str, type[TransferQueueStorageKVClient]] = {}

    @classmethod
    def register(cls, client_type: str):
        """
        Decorator to register a concrete client class with the factory.

        Registering another class under an existing name replaces the earlier one.

        Args:
            client_type (str): The name used to identify the client in ``create``.

        Returns:
            Callable: A decorator that records the class and returns it unchanged.
        """

        def decorator(client_class: type[TransferQueueStorageKVClient]) -> type[TransferQueueStorageKVClient]:
            cls._registry[client_type] = client_class
            return client_class

        return decorator

    @classmethod
    def create(cls, client_type: str, config: dict) -> TransferQueueStorageKVClient:
        """
        Create and return an instance of the storage client registered by name.

        Args:
            client_type (str): The registered name of the client.
            config (dict): Passed as the sole constructor argument of the client class.

        Returns:
            TransferQueueStorageKVClient: A new instance of the requested client class.

        Raises:
            ValueError: If no client is registered with the given name.
        """
        try:
            client_class = cls._registry[client_type]
        except KeyError:
            raise ValueError(f"Unknown StorageClient: {client_type}") from None
        return client_class(config)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
|
|
7
|
+
from transfer_queue.storage.clients.factory import StorageClientFactory
|
|
8
|
+
|
|
9
|
+
YUANRONG_DATASYSTEM_IMPORTED: bool = True
|
|
10
|
+
TORCH_NPU_IMPORTED: bool = True
|
|
11
|
+
try:
|
|
12
|
+
import datasystem
|
|
13
|
+
except ImportError:
|
|
14
|
+
YUANRONG_DATASYSTEM_IMPORTED = False
|
|
15
|
+
try:
|
|
16
|
+
import torch_npu
|
|
17
|
+
except ImportError:
|
|
18
|
+
TORCH_NPU_IMPORTED = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# TODO: DSTensorClient.dev_mget has wrong behavior: it may require stricter environment to execute
|
|
22
|
+
@StorageClientFactory.register("Yuanrong")
class YRStorageClient(TransferQueueStorageKVClient):
    """
    Storage client for YuanRong DataSystem.

    Communicates with the remote tensor storage service via DsTensorClient.
    All tensors must reside on NPU device.
    """

    # Upper bound on the number of keys accepted by a single put/get call.
    # TODO: Support the situation when the number of keys is greater than 10000
    _MAX_KEYS_PER_CALL: int = 10000
    # Timeout in milliseconds passed to DsTensorClient.dev_mget.
    _GET_TIMEOUT_MS: int = 2000

    def __init__(self, config: dict[str, Any]):
        """Bind this process to the configured NPU and connect to the DataSystem.

        Args:
            config: Reads the "host", "port" and "device_id" entries
                (missing entries become None).

        Raises:
            ImportError: If ``datasystem`` or ``torch_npu`` could not be imported.
        """
        if not YUANRONG_DATASYSTEM_IMPORTED:
            raise ImportError("YuanRong DataSystem not installed.")
        if not TORCH_NPU_IMPORTED:
            raise ImportError("Torch_npu not installed.")

        self.host = config.get("host")
        self.port = config.get("port")
        self.device_id = config.get("device_id")
        torch_npu.npu.set_device(f"npu:{self.device_id}")  # set npu_device
        self._ds_client = datasystem.DsTensorClient(self.host, self.port, self.device_id)
        self._ds_client.init()

    def _check_key_count(self, keys) -> None:
        """Raise NotImplementedError if ``keys`` exceeds the per-call key limit."""
        if len(keys) > self._MAX_KEYS_PER_CALL:
            raise NotImplementedError(
                f"We will support the number of keys greater than {self._MAX_KEYS_PER_CALL} in the future"
            )

    def _create_empty_tensorlist(self, shapes, dtypes):
        """
        Create a list of uninitialized NPU tensors with the given shapes and dtypes.

        Args:
            shapes (list): List of tensor shapes (e.g., [(3,), (2, 4)])
            dtypes (list): List of torch dtypes (e.g., [torch.float32, torch.int64])

        Returns:
            list: List of uninitialized tensors on ``npu:<device_id>``

        Raises:
            ValueError: If ``shapes`` and ``dtypes`` differ in length.
        """
        if len(dtypes) != len(shapes):
            raise ValueError("Length of dtypes must equal length of shapes")

        device = f"npu:{self.device_id}"
        # Allocate directly on the NPU instead of allocating on CPU and copying over.
        return [torch.empty(shape, dtype=dtype, device=device) for dtype, shape in zip(dtypes, shapes, strict=True)]

    def put(self, keys: list[str], values: list[Tensor]) -> None:
        """
        Store tensors to remote storage.

        Args:
            keys (list): List of string keys
            values (list): List of torch.Tensor on NPU, one per key

        Raises:
            ValueError: If the arguments are not lists, differ in length, or a value
                is not a tensor on an NPU device.
            NotImplementedError: If more keys than the per-call limit are given.
        """
        if not isinstance(keys, list) or not isinstance(values, list):
            raise ValueError("keys and values must be lists")
        if len(keys) != len(values):
            raise ValueError("Number of keys must match number of values")

        self._check_key_count(keys)

        for value in values:
            if not isinstance(value, torch.Tensor):
                raise ValueError(f"Expected torch.Tensor, got {type(value)}")
            if value.device.type != "npu":
                raise ValueError(f"Tensor is on {value.device}, not on NPU")

        self._ds_client.dev_mset(keys, values)

    def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
        """
        Retrieve tensors from remote storage.

        Args:
            keys (list): List of keys to fetch
            shapes (list): Expected shapes of returned tensors, one per key
            dtypes (list): Expected dtypes of returned tensors, one per key

        Returns:
            list: List of retrieved NPU tensors, in the order of ``keys``

        Raises:
            ValueError: If shapes/dtypes are missing or their lengths do not match ``keys``.
            NotImplementedError: If more keys than the per-call limit are requested.
        """
        if shapes is None:
            raise ValueError("Yuanrong storage client needs Expected shapes of returned tensors")
        if dtypes is None:
            raise ValueError("Yuanrong storage client needs Expected dtypes of returned tensors")
        if len(dtypes) != len(shapes):
            raise ValueError("Length of dtypes must equal length of shapes")
        # One pre-allocated output buffer is needed per key.
        if len(keys) != len(shapes):
            raise ValueError("Number of keys must match number of shapes and dtypes")

        # Reject oversized requests before allocating any device memory.
        self._check_key_count(keys)

        values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes)
        # dev_mget fills the pre-allocated buffers in place.
        self._ds_client.dev_mget(keys, values, self._GET_TIMEOUT_MS)
        return values

    def clear(self, keys: list[str]) -> None:
        """
        Delete entries from storage by keys.

        Args:
            keys (list): List of keys to delete
        """
        self._ds_client.dev_delete(keys)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .base import TransferQueueStorageManager
|
|
16
|
+
from .factory import TransferQueueStorageManagerFactory
|
|
17
|
+
from .simple_backend_manager import AsyncSimpleStorageManager
|
|
18
|
+
|
|
19
|
+
# Public API of the storage managers package (order preserved).
__all__ = [
    "TransferQueueStorageManager",  # abstract base manager
    "TransferQueueStorageManagerFactory",  # builds managers by type name
    "AsyncSimpleStorageManager",  # manager for the simple storage backend
]
|