TransferQueue 0.0.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +307 -0
- recipe/simple_use_case/sync_demo.py +223 -0
- tests/test_client.py +390 -0
- tests/test_controller.py +268 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +479 -0
- transfer_queue/__init__.py +42 -0
- transfer_queue/client.py +663 -0
- transfer_queue/controller.py +772 -0
- transfer_queue/metadata.py +603 -0
- transfer_queue/storage.py +515 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +98 -0
- transfer_queue/utils/zmq_utils.py +175 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.0.1.dev0.dist-info/METADATA +15 -0
- transferqueue-0.0.1.dev0.dist-info/RECORD +21 -0
- transferqueue-0.0.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.0.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.0.1.dev0.dist-info/top_level.txt +4 -0
|
@@ -0,0 +1,772 @@
|
|
|
1
|
+
# Copyright 2025 The TransferQueue Team
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import math
|
|
17
|
+
import os
|
|
18
|
+
import threading
|
|
19
|
+
import time
|
|
20
|
+
from threading import Thread
|
|
21
|
+
from typing import Any, Optional
|
|
22
|
+
from uuid import uuid4
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import ray
|
|
26
|
+
import torch
|
|
27
|
+
import zmq
|
|
28
|
+
from ray.util import get_node_ip_address
|
|
29
|
+
|
|
30
|
+
from transfer_queue.metadata import (
|
|
31
|
+
BatchMeta,
|
|
32
|
+
FieldMeta,
|
|
33
|
+
SampleMeta,
|
|
34
|
+
)
|
|
35
|
+
from transfer_queue.utils.utils import (
|
|
36
|
+
ProductionStatus,
|
|
37
|
+
TransferQueueRole,
|
|
38
|
+
sequential_sampler,
|
|
39
|
+
)
|
|
40
|
+
from transfer_queue.utils.zmq_utils import (
|
|
41
|
+
ZMQMessage,
|
|
42
|
+
ZMQRequestType,
|
|
43
|
+
ZMQServerInfo,
|
|
44
|
+
create_zmq_socket,
|
|
45
|
+
get_free_port,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# Module-level logger; verbosity is controlled via the TQ_LOGGING_LEVEL env var.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Max seconds _get_metadata waits for enough ready samples before raising TimeoutError.
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300))
# Seconds between readiness re-scans while waiting in _get_metadata.
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1))
# Initial column capacity of the production-status matrix; grown on demand.
TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@ray.remote(num_cpus=1)
|
|
57
|
+
class TransferQueueController:
|
|
58
|
+
def __init__(
    self,
    num_storage_units: int,
    global_batch_size: int,
    num_global_batch: int = 1,
    num_n_samples: int = 1,
) -> None:
    """Initialize the TransferQueueController.

    Binds the controller's ZMQ sockets, allocates the production/consumption
    status structures, precomputes the global-index -> storage mappings, and
    finally starts the background handshake / status-update / request threads.

    Args:
        num_storage_units: Number of storage units in the system
        global_batch_size: Size of each global batch
        num_global_batch: Number of global batches to maintain in storage
        num_n_samples: For each prompt, sample n responses
    """
    # Unique id; also used as the sender id in ZMQ messages.
    self.controller_id = f"TQ_CONTROLLER_{uuid4()}"

    self._init_zmq_socket()  # Initialize ZMQ sockets for data communication

    self.num_storage_units = num_storage_units
    self.global_batch_size = (
        global_batch_size  # Used as offset for global index to identify corresponding global step
    )
    self.num_global_batch = num_global_batch
    self.num_n_samples = num_n_samples
    # Total number of sample slots tracked by this controller.
    self.total_storage_size = self.global_batch_size * self.num_global_batch * self.num_n_samples

    # Rows = sample slots, columns = fields; extended column-wise on demand
    # (see _update_production_status).
    self.data_production_status = torch.zeros(
        self.total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8
    )  # Initialize with default number of fields, dynamically extensible
    # task_name -> consumption_status mapping
    self.data_consumption_status: dict[str, torch.Tensor] = {}
    self.field_name_mapping: dict[
        str, int
    ] = {}  # Mapping table from field_name to the column indices in self.data_production_status tables
    # Per-field dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}}
    self.per_tensor_dtype_mapping: dict[int, dict[str, Any]] = {}
    self.per_tensor_shape_mapping: dict[int, dict[str, Any]] = {}

    self._build_index_storage_mapping()

    # Background daemon threads; request processing waits on handshake completion.
    self._start_process_handshake()
    self._start_process_update_data_status()
    self._start_process_request()
|
|
102
|
+
|
|
103
|
+
def _get_consumption_status(self, task_name: str) -> torch.Tensor:
|
|
104
|
+
"""
|
|
105
|
+
Get or create the consumption status tensor for a specific task.
|
|
106
|
+
The consumption status is a binary, 1D tensor that records whether the corresponding sample has been consumed
|
|
107
|
+
by the task.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
task_name: Name of the consumer task
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
Consumption status tensor for the specified task
|
|
114
|
+
"""
|
|
115
|
+
# Retrieve or create the consumption state tensor for a specified consumer
|
|
116
|
+
if task_name not in self.data_consumption_status:
|
|
117
|
+
# Initialize state for a new consumer
|
|
118
|
+
self.data_consumption_status[task_name] = torch.zeros(self.total_storage_size, dtype=torch.int8)
|
|
119
|
+
return self.data_consumption_status[task_name]
|
|
120
|
+
|
|
121
|
+
def _get_per_field_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]:
|
|
122
|
+
"""Get dtype for a specific sample and field.
|
|
123
|
+
|
|
124
|
+
Args:
|
|
125
|
+
global_index: Global index of the sample
|
|
126
|
+
field_name: Name of the field
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
dtype of the specified field for the sample, or None if not found
|
|
130
|
+
"""
|
|
131
|
+
return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name)
|
|
132
|
+
|
|
133
|
+
def _get_per_field_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]:
|
|
134
|
+
"""Get shape for a specific sample and field.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
global_index: Global index of the sample
|
|
138
|
+
field_name: Name of the field
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
Shape of the specified field for the sample, or None if not found
|
|
142
|
+
"""
|
|
143
|
+
return self.per_tensor_shape_mapping.get(global_index, {}).get(field_name)
|
|
144
|
+
|
|
145
|
+
def _step_to_global_index_range(self, global_step: int) -> tuple[int, int]:
|
|
146
|
+
"""Convert global step to corresponding global index range.
|
|
147
|
+
|
|
148
|
+
Args:
|
|
149
|
+
global_step: The global step to convert
|
|
150
|
+
|
|
151
|
+
Returns:
|
|
152
|
+
Tuple of (start_index, end_index) for the given global step
|
|
153
|
+
"""
|
|
154
|
+
start_idx = (global_step % self.num_global_batch) * self.global_batch_size * self.num_n_samples
|
|
155
|
+
end_idx = start_idx + self.global_batch_size * self.num_n_samples
|
|
156
|
+
|
|
157
|
+
return start_idx, end_idx
|
|
158
|
+
|
|
159
|
+
def generate_data_status_mask(
    self, data_fields: list[str], global_step: int, task_name: str
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build boolean (row_mask, col_mask) selectors over the production-status matrix.

    Rows select the samples of *global_step* that *task_name* has not consumed
    yet; columns select the requested fields. If any requested field is not
    registered, both masks come back all-False (no data can be ready).

    Args:
        data_fields: List of field names to include in the mask
        global_step: Current global step for row selection
        task_name: Name of the consumer task for consumption status

    Returns:
        Tuple of (row_mask, col_mask) tensors for filtering data status matrices
    """
    n_rows = self.data_production_status.shape[0]
    n_cols = self.data_production_status.shape[1]

    # An unregistered field means nothing is consumable yet: return empty masks.
    if any(field not in self.field_name_mapping for field in data_fields):
        return (
            torch.zeros(n_rows, dtype=torch.bool),
            torch.zeros(n_cols, dtype=torch.bool),
        )

    # Rows: slots belonging to this step ...
    lo, hi = self._step_to_global_index_range(global_step)
    row_mask = torch.zeros(n_rows, dtype=torch.bool)
    row_mask[lo:hi] = True
    # ... that the task has not consumed yet.
    row_mask &= self._get_consumption_status(task_name) == 0

    # Columns: only the requested fields.
    col_mask = torch.zeros(n_cols, dtype=torch.bool)
    field_columns = [self.field_name_mapping[field] for field in data_fields]
    if field_columns:
        col_mask[field_columns] = True

    return row_mask, col_mask
|
|
204
|
+
|
|
205
|
+
def _build_index_storage_mapping(self):
|
|
206
|
+
"""
|
|
207
|
+
Build mappings between global indices and storage locations.
|
|
208
|
+
|
|
209
|
+
Distributes samples across storage units based on total storage space and
|
|
210
|
+
maintains mappings between global index and local index within each storage.
|
|
211
|
+
"""
|
|
212
|
+
# Assign each sample to a storage node. Here we scatter the samples in each GBS to different storage nodes
|
|
213
|
+
# Samples are arranged sequentially, similar to generate_data_status_mask
|
|
214
|
+
real_global_batch_size = self.global_batch_size * self.num_n_samples
|
|
215
|
+
global_batch_per_storage_unit = math.ceil(real_global_batch_size / self.num_storage_units)
|
|
216
|
+
|
|
217
|
+
# Build mapping between global index and storage unit for locating each data sample
|
|
218
|
+
batch_storage_indices = np.repeat(np.arange(self.num_storage_units), global_batch_per_storage_unit)[
|
|
219
|
+
:real_global_batch_size
|
|
220
|
+
]
|
|
221
|
+
self._global_index_storage_rank_mapping = np.tile(batch_storage_indices, self.num_global_batch)
|
|
222
|
+
|
|
223
|
+
# Build mapping between global index and local index within each storage unit
|
|
224
|
+
indices = np.arange(self.total_storage_size)
|
|
225
|
+
pos_in_batch = indices % real_global_batch_size
|
|
226
|
+
g = indices // real_global_batch_size
|
|
227
|
+
pos_in_block = pos_in_batch % global_batch_per_storage_unit
|
|
228
|
+
self.global_index_local_index_mapping = g * global_batch_per_storage_unit + pos_in_block
|
|
229
|
+
|
|
230
|
+
def get_data_production_status(self) -> torch.Tensor:
    """Return the 2-D production-status matrix.

    Rows are sample slots and columns are fields; a 1 means that field of
    that sample has been produced.

    Returns:
        Tensor representing production status of all data fields
    """
    status_matrix = self.data_production_status
    return status_matrix
|
|
239
|
+
|
|
240
|
+
def get_field_name_mapping(self) -> dict[str, Any]:
    """Return the mapping from field name to its column index.

    Returns:
        Dictionary mapping field names to their column indices
    """
    mapping = self.field_name_mapping
    return mapping
|
|
247
|
+
|
|
248
|
+
def get_data_consumption_status(self) -> dict[str, torch.Tensor]:
    """Return the per-task consumption status tensors.

    Returns:
        Dictionary mapping task names to their consumption status tensors
    """
    per_task_status = self.data_consumption_status
    return per_task_status
|
|
255
|
+
|
|
256
|
+
def get_global_index_mapping(self):
    """Return the global-index location mappings.

    Returns:
        Tuple of (storage-rank mapping, local-index mapping) arrays
    """
    rank_mapping = self._global_index_storage_rank_mapping
    local_mapping = self.global_index_local_index_mapping
    return rank_mapping, local_mapping
|
|
263
|
+
|
|
264
|
+
def _get_metadata(
    self,
    data_fields: list[str],
    batch_size: int,
    global_step: int,
    mode: str = "fetch",
    task_name: str | None = None,
    get_n_samples=False,
    *args,
    **kwargs,
) -> BatchMeta:
    """
    Retrieve metadata with support for three modes.

    Args:
        data_fields: List of field names to include in metadata
        batch_size: Number of samples to retrieve
        global_step: Global step for which to retrieve metadata
        mode: Operation mode - 'insert', 'fetch', or 'force_fetch'
            - mode="insert": Insert metadata for new rows (without checking data status)
            - mode="fetch": Retrieve metadata for ready data (check data status and sample)
            - mode="force_fetch": Directly return metadata (without checking data status)
        task_name: Name of the consumer task (required for fetch modes)
        get_n_samples: Whether to retrieve n_samples as groups
        *args: Additional positional arguments
        **kwargs: Additional keyword arguments

    Returns:
        BatchMeta object containing the requested metadata

    Raises:
        TimeoutError: If waiting for sufficient data times out in fetch mode
        ValueError: If ``mode`` is not 'insert', 'fetch' or 'force_fetch'
    """
    if mode == "insert":
        # TODO: Currently we only support putting the entire GBS data in one call
        assert batch_size == self.global_batch_size * self.num_n_samples, (
            f"batch_size {batch_size} must equal "
            f"global_batch_size * num_n_samples {self.global_batch_size * self.num_n_samples}"
        )
        start_idx, end_idx = self._step_to_global_index_range(global_step)
        batch_global_indexes = list(range(start_idx, end_idx))
        return self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode)

    assert task_name is not None
    # Fetch once and reuse below (the returned tensor is shared, so the
    # consumed-marking write is visible to later calls).
    consumer_status = self._get_consumption_status(task_name)

    if mode == "fetch":
        # Find consumable samples within current batch and package into BatchMeta when reading
        start_time = time.time()
        while True:
            ready_for_consume_idx = self._scan_data_status(data_fields, global_step, task_name, get_n_samples)

            if len(ready_for_consume_idx) >= batch_size:
                break

            if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
                raise TimeoutError(
                    f"Timeout while waiting for sufficient data. "
                    f"Required: {batch_size}, Available: {len(ready_for_consume_idx)}"
                )

            logger.warning(
                f"Insufficient data available. Required: {batch_size}, "
                f"Available: {len(ready_for_consume_idx)}. Retrying in "
                f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
            )
            time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
        logger.debug(f"ready for consume idx: {ready_for_consume_idx}")

        batch_global_indexes = sequential_sampler(
            ready_for_consume_idx, batch_size, get_n_samples, self.num_n_samples
        )
    elif mode == "force_fetch":
        start_idx, end_idx = self._step_to_global_index_range(global_step)
        not_consumed_idx = [i for i in range(start_idx, end_idx) if consumer_status[i] == 0]
        batch_global_indexes = sequential_sampler(not_consumed_idx, batch_size, get_n_samples, self.num_n_samples)
    else:
        # Fail fast: previously an unknown mode fell through and raised a
        # confusing UnboundLocalError on batch_global_indexes below.
        raise ValueError(f"Unknown mode {mode!r}; expected 'insert', 'fetch' or 'force_fetch'")

    # Mark this batch of data as consumed
    consumer_status[batch_global_indexes] = 1
    # Package into metadata
    metadata = self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode)
    logger.debug(f"_get_metadata: {metadata}")

    return metadata
|
|
349
|
+
|
|
350
|
+
def _scan_data_status(
    self, data_fields: list[str], global_step: int, task_name: str, get_n_samples: bool
) -> list[int]:
    """
    Scan data status to find samples ready for consumption.

    A sample is "ready" when every requested field has been produced and the
    sample has not yet been consumed by *task_name* within this global step.

    Args:
        data_fields: List of field names to check
        global_step: Global step to scan
        task_name: Name of the consumer task
        get_n_samples: Whether to return n_samples as groups

    Returns:
        List of global indices that are ready for consumption
    """
    # Get row and column masks
    row_mask, col_mask = self.generate_data_status_mask(data_fields, global_step, task_name)
    logger.debug(f"row_mask, col_mask: {row_mask, col_mask}")

    # Short-circuit: nothing selectable (unregistered field or no candidate rows).
    if not row_mask.any() or not col_mask.any():
        return []

    # Extract subset of data status for relevant fields
    logger.debug(f"self.data_production_status: {self.data_production_status}")
    data_status_of_interest = self.data_production_status[:, col_mask]
    logger.debug(f"data_status_of_interest: {data_status_of_interest}")

    # Use torch.all for vectorized check instead of sum comparison
    all_fields_ready = torch.all(data_status_of_interest, dim=1)

    # Filter samples that meet criteria combined with row mask
    ready_mask = all_fields_ready & row_mask

    if get_n_samples and self.num_n_samples > 1:
        # Reshape to group view and check group completeness. The view is safe:
        # total_storage_size is a multiple of num_n_samples by construction.
        group_all_ready = torch.all(ready_mask.view(-1, self.num_n_samples), dim=1)

        # Get indices of fully ready groups
        ready_group_indices = group_all_ready.nonzero(as_tuple=False).flatten()

        # Expand each ready group back into its individual sample indices
        sample_offset = torch.arange(self.num_n_samples)
        ready_for_consume_idx = (
            (ready_group_indices.unsqueeze(1) * self.num_n_samples + sample_offset).flatten().tolist()
        )

        return ready_for_consume_idx
    else:
        ready_for_consume_idx = torch.nonzero(ready_mask, as_tuple=False).flatten().tolist()
        logger.debug(f"ready_for_consume_idx: {ready_for_consume_idx}")

        return ready_for_consume_idx
|
|
402
|
+
|
|
403
|
+
def _generate_batch_meta(
    self, global_step: int, global_indexes: list[int], data_fields: list[str], mode: str
) -> BatchMeta:
    """
    Generate BatchMeta by resolving storage locations for given global indexes.

    For each global index, looks up the corresponding storage node address using:
    - global_index_local_index_mapping: Maps to local index within storage
    - global_index_storage_id_mapping: Maps to storage node identifier

    NOTE(review): global_index_storage_id_mapping is only assigned at the end
    of _wait_connection, so this must not run before the handshake completes —
    confirm callers guarantee that ordering.

    Args:
        global_step: Current global step
        global_indexes: List of global indexes to process
        data_fields: List of data field names
        mode: Operation mode ('fetch', 'insert', or 'force_fetch')

    Returns:
        BatchMeta object containing sample metadata with resolved storage locations
    """
    global_arr = np.array(global_indexes)
    storage_ids = self.global_index_storage_id_mapping[global_arr]
    local_indexes = self.global_index_local_index_mapping[global_arr]

    samples = []

    # Create samples from the flattened BatchMeta data
    # TODO: Optimize this
    for i, global_index in enumerate(global_indexes):
        local_index = local_indexes[i]
        storage_id = storage_ids[i]

        # Create FieldMeta objects for each field
        fields = []
        for field_name in data_fields:
            if mode == "fetch":
                production_status = ProductionStatus.READY_FOR_CONSUME  # Since we filtered by ready status
                # Get per-field dtype and shape for this specific global_index and field
                dtype = self._get_per_field_dtype(global_index, field_name)
                shape = self._get_per_field_shape(global_index, field_name)
            elif mode == "insert":
                production_status = ProductionStatus.NOT_PRODUCED  # FIXME: not real-time
                dtype = None
                shape = None
            elif mode == "force_fetch":
                # force_fetch does not pre-filter, so check readiness per field here.
                col_index = self.field_name_mapping.get(field_name)
                if col_index is not None and self.data_production_status[global_index, col_index] == 1:
                    production_status = ProductionStatus.READY_FOR_CONSUME
                    dtype = self._get_per_field_dtype(global_index, field_name)
                    shape = self._get_per_field_shape(global_index, field_name)
                else:
                    production_status = ProductionStatus.NOT_PRODUCED
                    dtype = None
                    shape = None
            field_meta = FieldMeta(
                name=field_name,
                dtype=dtype,
                shape=shape,
                production_status=production_status,
            )
            fields.append(field_meta)

        sample = SampleMeta(
            global_step=global_step,
            global_index=global_index,
            storage_id=storage_id,
            local_index=local_index,
            fields={field.name: field for field in fields},
        )
        samples.append(sample)

    return BatchMeta(samples=samples)
|
|
474
|
+
|
|
475
|
+
def _update_production_status(self, indexes: list[int], fields: list[str]) -> None:
|
|
476
|
+
"""
|
|
477
|
+
Update production status for specified indexes and fields.
|
|
478
|
+
|
|
479
|
+
Args:
|
|
480
|
+
indexes: List of global indexes to update
|
|
481
|
+
fields: List of field names to update
|
|
482
|
+
"""
|
|
483
|
+
# TODO: Replace self.data_production_status == 0 or ==1 operations with ProductionStatus enum
|
|
484
|
+
# Update data production status matrix
|
|
485
|
+
new_fields = [field for field in fields if field not in self.field_name_mapping]
|
|
486
|
+
if new_fields:
|
|
487
|
+
needed_fields = len(new_fields)
|
|
488
|
+
current_fields = self.data_production_status.shape[1]
|
|
489
|
+
# Expand data status matrix if needed
|
|
490
|
+
if len(self.field_name_mapping) + needed_fields > current_fields:
|
|
491
|
+
add_fields = max(TQ_INIT_FIELD_NUM, needed_fields + 1)
|
|
492
|
+
new_matrix = torch.zeros((self.total_storage_size, add_fields), dtype=torch.int8)
|
|
493
|
+
self.data_production_status = torch.cat([self.data_production_status, new_matrix], dim=1)
|
|
494
|
+
|
|
495
|
+
for field in fields:
|
|
496
|
+
if field not in self.field_name_mapping.keys():
|
|
497
|
+
self.field_name_mapping[field] = len(self.field_name_mapping)
|
|
498
|
+
self.data_production_status[
|
|
499
|
+
torch.tensor(indexes)[:, None], torch.tensor([self.field_name_mapping.get(field) for field in fields])
|
|
500
|
+
] = 1
|
|
501
|
+
|
|
502
|
+
def _update_field_info(
|
|
503
|
+
self,
|
|
504
|
+
fields: list[str],
|
|
505
|
+
per_tensor_dtypes: dict[int, dict[str, Any]],
|
|
506
|
+
per_tensor_shapes: dict[int, dict[str, Any]],
|
|
507
|
+
global_indexes: list[int],
|
|
508
|
+
) -> None:
|
|
509
|
+
"""
|
|
510
|
+
Store per-field dtype and shape information.
|
|
511
|
+
|
|
512
|
+
Args:
|
|
513
|
+
fields: List of field names
|
|
514
|
+
per_tensor_dtypes: Dict mapping global_index to field dtypes {global_index: {field: dtype}}
|
|
515
|
+
per_tensor_shapes: Dict mapping global_index to field shapes {global_index: {field: shape}}
|
|
516
|
+
global_indexes: List of global indexes corresponding to the samples
|
|
517
|
+
"""
|
|
518
|
+
for global_idx in global_indexes:
|
|
519
|
+
if global_idx not in self.per_tensor_dtype_mapping:
|
|
520
|
+
self.per_tensor_dtype_mapping[global_idx] = {}
|
|
521
|
+
if global_idx not in self.per_tensor_shape_mapping:
|
|
522
|
+
self.per_tensor_shape_mapping[global_idx] = {}
|
|
523
|
+
|
|
524
|
+
for field in fields:
|
|
525
|
+
if global_idx in per_tensor_dtypes and field in per_tensor_dtypes[global_idx]:
|
|
526
|
+
self.per_tensor_dtype_mapping[global_idx][field] = per_tensor_dtypes[global_idx][field]
|
|
527
|
+
if global_idx in per_tensor_shapes and field in per_tensor_shapes[global_idx]:
|
|
528
|
+
self.per_tensor_shape_mapping[global_idx][field] = per_tensor_shapes[global_idx][field]
|
|
529
|
+
|
|
530
|
+
def _init_zmq_socket(self):
    """
    Initialize ZMQ sockets for communication.

    Sets up three ZMQ service ports for:
    1. Receiving handshake requests from storage
    2. Handling client data read/write requests
    3. Receiving status update signals from storage

    Also publishes the resulting addresses via ``self.zmq_server_info``.
    """
    self.zmq_context = zmq.Context()

    self._node_ip = get_node_ip_address()
    # NOTE(review): get_free_port() followed by a later bind() is racy if
    # another process grabs the port in between — confirm this is acceptable.
    self._handshake_socket_port = get_free_port()
    self._request_handle_socket_port = get_free_port()
    self._data_status_update_socket_port = get_free_port()

    # ROUTER sockets: each peer is addressed by its ZMQ identity frame.
    self.handshake_socket = create_zmq_socket(
        ctx=self.zmq_context,
        socket_type=zmq.ROUTER,
    )
    self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}")

    self.request_handle_socket = create_zmq_socket(
        ctx=self.zmq_context,
        socket_type=zmq.ROUTER,
    )
    self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}")

    self.data_status_update_socket = create_zmq_socket(
        ctx=self.zmq_context,
        socket_type=zmq.ROUTER,
    )
    self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}")

    # Advertised connection info (role/id/ip/ports) handed out to peers.
    self.zmq_server_info = ZMQServerInfo.create(
        role=TransferQueueRole.CONTROLLER,
        id=self.controller_id,
        ip=self._node_ip,
        ports={
            "handshake_socket": self._handshake_socket_port,
            "request_handle_socket": self._request_handle_socket_port,
            "data_status_update_socket": self._data_status_update_socket_port,
        },
    )
|
|
574
|
+
|
|
575
|
+
def _wait_connection(self):
    """Wait for all storage instances to complete handshake.

    Clients don't need handshake to support dynamic scaling. Continuously
    listens for handshake messages until all expected storage units connect,
    ACKing each one; then builds the global-index -> storage-id mapping and
    signals ``self.handshake_done``.
    """
    # TODO(zjj): Consider if retransmission is needed (assuming cases where Storage doesn't receive ACK)
    connected_storage_units = set()
    while len(connected_storage_units) < self.num_storage_units:
        # Blocking receive; ROUTER prepends the sender's identity frame.
        identity, serialized_msg = self.handshake_socket.recv_multipart()
        request_msg = ZMQMessage.deserialize(serialized_msg)
        if request_msg.request_type == ZMQRequestType.HANDSHAKE:
            # A set makes duplicate handshakes from the same unit harmless.
            connected_storage_units.add(request_msg.sender_id)
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.HANDSHAKE_ACK,
                sender_id=self.controller_id,
                body={},
            ).serialize()
            self.handshake_socket.send_multipart([identity, response_msg])
            logger.info("Controller sent handshake ack successfully!")
    # Sort the ids so every process derives the same rank -> storage-id order,
    # then translate per-slot storage ranks into concrete storage ids.
    self.global_index_storage_id_mapping = np.array(sorted(list(connected_storage_units)))[
        self._global_index_storage_rank_mapping
    ]
    self.handshake_done.set()
|
|
599
|
+
|
|
600
|
+
def _start_process_handshake(self):
|
|
601
|
+
"""Start the handshake process thread."""
|
|
602
|
+
self.handshake_done = threading.Event()
|
|
603
|
+
self.wait_connection_thread = Thread(
|
|
604
|
+
target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True
|
|
605
|
+
)
|
|
606
|
+
self.wait_connection_thread.start()
|
|
607
|
+
|
|
608
|
+
def _start_process_update_data_status(self):
|
|
609
|
+
"""Start the data status update processing thread."""
|
|
610
|
+
self.process_update_data_status_thread = Thread(
|
|
611
|
+
target=self._update_data_status, name="TransferQueueControllerProcessUpdateDataStatusThread", daemon=True
|
|
612
|
+
)
|
|
613
|
+
self.process_update_data_status_thread.start()
|
|
614
|
+
|
|
615
|
+
def _start_process_request(self):
|
|
616
|
+
"""Start the request processing thread."""
|
|
617
|
+
self.process_request_thread = Thread(
|
|
618
|
+
target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True
|
|
619
|
+
)
|
|
620
|
+
self.process_request_thread.start()
|
|
621
|
+
|
|
622
|
+
def _process_request(self):
    """Main request processing loop.

    Waits for the storage handshake to complete, then serves client
    requests arriving on the ROUTER ``request_handle_socket``. Supported
    request types:

    - ``GET_META``: build and return batch metadata via ``_get_metadata``.
    - ``GET_CLEAR_META``: return insert-mode metadata covering one full
      global batch (``global_batch_size * num_n_samples`` samples).
    - ``CLEAR_META``: reset production/consumption status for a step.
    - ``CHECK_CONSUMPTION``: report whether every sample of a step has
      been consumed by the given task.

    Runs forever; intended as the target of a daemon thread.
    """
    self.handshake_done.wait()
    while True:
        # ROUTER socket receives multi-part messages: [identity, payload]
        identity, serialized_msg = self.request_handle_socket.recv_multipart()
        request_msg = ZMQMessage.deserialize(serialized_msg)

        # BUGFIX: previously an unrecognized request type left `response_msg`
        # unbound (NameError on the first message) or silently re-sent the
        # previous iteration's response. Reset it each iteration and only
        # send when a branch actually produced a reply.
        response_msg = None

        if request_msg.request_type == ZMQRequestType.GET_META:
            params = request_msg.body
            logger.info("Controller preparing to get metadata...")
            metadata = self._get_metadata(
                data_fields=params["data_fields"],
                batch_size=params["batch_size"],
                global_step=params["global_step"],
                mode=params.get("mode", "fetch"),
                task_name=params.get("task_name", None),
                get_n_samples=params.get("get_n_samples", False),
            )
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_META_RESPONSE,
                sender_id=self.controller_id,
                receiver_id=request_msg.sender_id,
                body={"metadata": metadata},
            )
        elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
            params = request_msg.body
            # Insert-mode metadata spanning the whole global batch so the
            # caller can address every sample of the step.
            metadata = self._get_metadata(
                data_fields=[],
                batch_size=self.global_batch_size * self.num_n_samples,
                global_step=params["global_step"],
                mode="insert",
            )
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
                sender_id=self.controller_id,
                receiver_id=request_msg.sender_id,
                body={"metadata": metadata},
            )
        elif request_msg.request_type == ZMQRequestType.CLEAR_META:
            params = request_msg.body
            self.clear(global_step=params["global_step"])
            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CLEAR_META_RESPONSE,
                sender_id=self.controller_id,
                receiver_id=request_msg.sender_id,
                body={"message": f"Clear operation completed by controller {self.controller_id}"},
            )
        elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
            # Check consumption status for one global step of a task.
            params = request_msg.body
            global_step = params["global_step"]

            consumer_status = self._get_consumption_status(params["task_name"])
            start_idx, end_idx = self._step_to_global_index_range(global_step)
            batch_status = consumer_status[start_idx:end_idx]
            # Consumed only when every sample in the step is marked 1.
            consumed = torch.all(batch_status == 1).item()

            response_msg = ZMQMessage.create(
                request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
                sender_id=self.controller_id,
                receiver_id=request_msg.sender_id,
                body={
                    "global_step": global_step,
                    "consumed": consumed,
                },
            )

        if response_msg is None:
            logger.warning(f"Controller received unsupported request type: {request_msg.request_type}")
            continue

        self.request_handle_socket.send_multipart([identity, response_msg.serialize()])
        logger.debug("Controller request_handle_socket sent multipart successfully!")
|
|
696
|
+
|
|
697
|
+
def _update_data_status(self):
    """Process data status update messages from storage units.

    Runs forever, listening on ``data_status_update_socket`` for update
    notifications (or error reports) from storage units, recording
    production status and field info, and acknowledging each message.
    """

    def _make_ack(text):
        # Both success and error notifications are answered with the
        # same NOTIFY_DATA_UPDATE_ACK envelope.
        return ZMQMessage.create(
            request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
            sender_id=self.controller_id,
            body={
                "controller_id": self.controller_id,
                "message": text,
            },
        )

    # Receive data status update information from storage
    while True:
        logger.debug("Preparing _update_data_status...")
        identity, serialized_msg = self.data_status_update_socket.recv_multipart()
        logger.debug("Controller received update_data_status request!")
        request_msg = ZMQMessage.deserialize(serialized_msg)
        logger.debug(f"[{self.controller_id}]: Controller received update_data_status request_msg: {request_msg}")

        msg_type = request_msg.request_type
        if msg_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
            payload = request_msg.body

            fields = payload.get("fields", [])
            global_indexes = payload.get("global_indexes", [])
            per_tensor_dtypes = payload.get("dtypes", {})  # dict of per-field lists
            per_tensor_shapes = payload.get("shapes", {})  # dict of per-field lists

            # Record which (sample, field) cells are now produced, plus
            # their dtype/shape details.
            logger.debug(f"global_indexes, fields: {global_indexes, fields}")
            self._update_production_status(global_indexes, fields)
            self._update_field_info(fields, per_tensor_dtypes, per_tensor_shapes, global_indexes)
            logger.info("Controller updated production status successfully!")

            # Send acknowledgment response
            ack = _make_ack(f"Data update acknowledged from controller {self.controller_id}")
            self.data_status_update_socket.send_multipart([identity, ack.serialize()])
            logger.info("Controller sent DATA_UPDATE_ACK successfully!")
        elif msg_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR:
            # Handle data update errors
            error_msg = request_msg.body.get("message", "Unknown error")
            logger.error(f"Data update error from storage: {error_msg}")

            # Send error acknowledgment response
            ack = _make_ack(f"Error notification acknowledged from controller {self.controller_id}")
            self.data_status_update_socket.send_multipart([identity, ack.serialize()])
|
|
750
|
+
|
|
751
|
+
def get_zmq_server_info(self) -> ZMQServerInfo:
    """Get ZMQ server connection information.

    Returns:
        ZMQServerInfo object containing connection details
    """
    server_info = self.zmq_server_info
    return server_info
|
|
758
|
+
|
|
759
|
+
def clear(self, global_step: int):
    """Reset production and consumption status for one global batch.

    Zeroes the production-status rows and every task's consumption
    vector over the index range belonging to ``global_step``. Only a
    single GBS can be cleared per call.

    Args:
        global_step: The global step to clear data for
    """
    lo, hi = self._step_to_global_index_range(global_step)

    self.data_production_status[lo:hi, :] = 0
    # In-place zeroing, so iterating values (not keys) is equivalent.
    for status in self.data_consumption_status.values():
        status[lo:hi] = 0
|