TransferQueue 0.1.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recipe/simple_use_case/async_demo.py +331 -0
- recipe/simple_use_case/sync_demo.py +220 -0
- tests/test_async_simple_storage_manager.py +339 -0
- tests/test_client.py +423 -0
- tests/test_controller.py +274 -0
- tests/test_controller_data_partitions.py +513 -0
- tests/test_kv_storage_manager.py +92 -0
- tests/test_put.py +327 -0
- tests/test_samplers.py +492 -0
- tests/test_serial_utils_on_cpu.py +202 -0
- tests/test_simple_storage_unit.py +443 -0
- tests/test_storage_client_factory.py +45 -0
- transfer_queue/__init__.py +48 -0
- transfer_queue/client.py +611 -0
- transfer_queue/controller.py +1187 -0
- transfer_queue/metadata.py +460 -0
- transfer_queue/sampler/__init__.py +19 -0
- transfer_queue/sampler/base.py +74 -0
- transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
- transfer_queue/sampler/sequential_sampler.py +75 -0
- transfer_queue/storage/__init__.py +25 -0
- transfer_queue/storage/clients/__init__.py +24 -0
- transfer_queue/storage/clients/base.py +22 -0
- transfer_queue/storage/clients/factory.py +55 -0
- transfer_queue/storage/clients/yuanrong_client.py +118 -0
- transfer_queue/storage/managers/__init__.py +23 -0
- transfer_queue/storage/managers/base.py +460 -0
- transfer_queue/storage/managers/factory.py +43 -0
- transfer_queue/storage/managers/simple_backend_manager.py +611 -0
- transfer_queue/storage/managers/yuanrong_manager.py +18 -0
- transfer_queue/storage/simple_backend.py +451 -0
- transfer_queue/utils/__init__.py +13 -0
- transfer_queue/utils/serial_utils.py +240 -0
- transfer_queue/utils/utils.py +132 -0
- transfer_queue/utils/zmq_utils.py +170 -0
- transfer_queue/version/version +1 -0
- transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
- transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
- transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
- transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
- transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
transfer_queue/controller.py
@@ -0,0 +1,1187 @@
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Optional
from uuid import uuid4

import ray
import torch
import zmq
from ray.util import get_node_ip_address

from transfer_queue.metadata import (
    BatchMeta,
    FieldMeta,
    SampleMeta,
)
from transfer_queue.sampler import BaseSampler, SequentialSampler
from transfer_queue.utils.utils import (
    ProductionStatus,
    TransferQueueRole,
)
from transfer_queue.utils.zmq_utils import (
    ZMQMessage,
    ZMQRequestType,
    ZMQServerInfo,
    create_zmq_socket,
    get_free_port,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300))
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1))
TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL", 2))

TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 10))  # Initial number of samples
TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10))  # Initial number of fields

# Expansion configuration - unified approach using minimum expansion sizes
TQ_SAMPLE_MIN_EXPANSION_SIZE = int(
    os.environ.get("TQ_SAMPLE_MIN_EXPANSION_SIZE", 10)
)  # Minimum expansion size for samples (rows)
TQ_FIELD_MIN_EXPANSION_SIZE = int(
    os.environ.get("TQ_FIELD_MIN_EXPANSION_SIZE", 5)
)  # Minimum expansion size for fields (columns)

class PartitionIndexManager:
    """
    Manages the mapping between partitions and global indexes, and is responsible
    for index allocation and reuse.
    """

    def __init__(self):
        # Records the set of global_indexes used by each partition
        self.partition_to_indexes = defaultdict(set)

        # Reusable global_index pool - stored as a list
        self.reusable_indexes = []

        # Global index counter for allocating new indexes
        self.global_index_counter = 0

        # Track all active indexes
        self.allocated_indexes = set()

    def allocate_indexes(self, partition_id, count=1) -> list:
        """
        Allocate global_indexes for the specified partition.
        Draws from the reusable pool first and allocates new indexes when it is insufficient.

        Args:
            partition_id: Partition ID
            count: Number of indexes needed

        Returns:
            list: List of allocated global_indexes
        """
        if count <= 0:
            raise ValueError(f"Number of indexes needed must be larger than 0, but got {count}")
        indexes = []

        # Get indexes from the reusable pool
        if self.reusable_indexes:
            # Calculate the number of indexes taken from the reusable pool
            num_reuse = min(count, len(self.reusable_indexes))

            # Use a slice operation to get multiple elements at once (FIFO order)
            indexes.extend(self.reusable_indexes[:num_reuse])
            del self.reusable_indexes[:num_reuse]

        # If the reusable pool doesn't have enough indexes, allocate new ones
        if len(indexes) < count:
            # Ensure newly allocated indexes don't conflict with existing ones
            needed = count - len(indexes)
            # Batch-allocate a consecutive index range
            start_index = self.global_index_counter
            end_index = start_index + needed

            # Directly generate the consecutive index list
            new_indexes = list(range(start_index, end_index))

            # Batch-update status
            self.allocated_indexes.update(new_indexes)
            self.global_index_counter = end_index

            indexes.extend(new_indexes)

        # Record the partition-index relationship
        self.partition_to_indexes[partition_id].update(indexes)

        return indexes

    def release_indexes(self, partition_id):
        """
        Release all global_indexes of the specified partition, adding them to the reusable pool.

        Args:
            partition_id: Partition ID

        Returns:
            list: List of released global_indexes
        """
        if partition_id in self.partition_to_indexes:
            indexes = self.partition_to_indexes.pop(partition_id)

            # Add released indexes to the reusable pool
            self.reusable_indexes.extend(indexes)

            # Remove these indexes from allocated_indexes
            for idx in indexes:
                self.allocated_indexes.discard(idx)

            return indexes
        return []

    def get_indexes_for_partition(self, partition_id):
        """
        Get all global_indexes for the specified partition.

        Args:
            partition_id: Partition ID

        Returns:
            set: Set of global_indexes for this partition
        """
        return self.partition_to_indexes.get(partition_id, set()).copy()
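
# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the packaged file): a minimal walk-through of
# how PartitionIndexManager recycles indexes. The partition ids are made up,
# and the printed values assume a fresh manager.
#
# >>> manager = PartitionIndexManager()
# >>> manager.allocate_indexes("train@global_batch_0", count=3)
# [0, 1, 2]
# >>> manager.release_indexes("train@global_batch_0")    # returned to the reusable pool
# {0, 1, 2}
# >>> manager.allocate_indexes("train@global_batch_1", count=2)  # reuse before minting new ids
# [0, 1]
# >>> manager.global_index_counter                        # only 3 ids ever minted
# 3
# ---------------------------------------------------------------------------
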
@dataclass
class DataPartitionStatus:
    """
    Robust status information for a data partition with dynamic expansion support.

    This class tracks the production and consumption status of data within a specific
    partition (e.g., "train@global_batch_0", "inference@kv_cache_1") with full support
    for dynamic row and column expansion.
    """

    partition_id: str
    created_at: float = field(default_factory=time.time)

    # Production status tensor - dynamically expandable
    # Values: 0 = not produced, 1 = ready for consumption
    production_status: Optional[torch.Tensor] = torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8)

    # Consumption status per task - task_name -> consumption_tensor
    # Each tensor tracks which samples have been consumed by that task
    consumption_status: dict[str, torch.Tensor] = field(default_factory=dict)

    # Field metadata
    field_name_mapping: dict[str, int] = field(default_factory=dict)  # field_name -> column_index
    field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict)  # global_idx -> {field: dtype}
    field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict)  # global_idx -> {field: shape}

    # Dynamic configuration - these are computed from the current state
    @property
    def total_samples_num(self) -> int:
        """Current number of samples (rows) in the partition."""
        return self.production_status.shape[0] if self.production_status is not None else 0

    @property
    def total_fields_num(self) -> int:
        """Current number of fields (columns) in the partition."""
        return len(self.field_name_mapping)

    @property
    def allocated_fields_num(self) -> int:
        """Current number of allocated columns in the tensor."""
        return self.production_status.shape[1] if self.production_status is not None else 0

    @property
    def allocated_samples_num(self) -> int:
        """Current number of allocated rows in the tensor."""
        return self.production_status.shape[0] if self.production_status is not None else 0

    # ==================== Dynamic Expansion Methods ====================

    def ensure_samples_capacity(self, required_samples: int) -> bool:
        """
        Ensure the production status tensor has enough rows for the required samples.
        Dynamically expands if needed using unified minimum expansion size.

        Args:
            required_samples: Minimum number of samples needed
        """
        current_samples = self.production_status.shape[0]
        if required_samples > current_samples:
            # Expand rows using minimum expansion size for predictable memory usage
            expansion_needed = required_samples - current_samples
            min_expansion = max(TQ_SAMPLE_MIN_EXPANSION_SIZE, expansion_needed)
            new_samples = current_samples + min_expansion
            new_fields = self.production_status.shape[1]

            expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
            expanded_tensor[:current_samples, :] = self.production_status
            self.production_status = expanded_tensor

            # Update consumption tensors for all tasks
            for task_name, consumption_tensor in self.consumption_status.items():
                expanded_consumption = torch.zeros(new_samples, dtype=torch.int8)
                expanded_consumption[:current_samples] = consumption_tensor
                self.consumption_status[task_name] = expanded_consumption

            logger.debug(
                f"Expanded partition {self.partition_id} from {current_samples} to {new_samples} samples "
                f"(added {min_expansion} samples)"
            )

    def ensure_fields_capacity(self, required_fields: int) -> bool:
        """
        Ensure the production status tensor has enough columns for the required fields.
        Dynamically expands if needed using unified minimum expansion size.

        Args:
            required_fields: Minimum number of fields needed
        """
        if self.production_status is None:
            # Will be initialized when samples are added
            return

        current_fields = self.production_status.shape[1]
        if required_fields > current_fields:
            # Expand columns using minimum expansion size for predictable memory usage
            expansion_needed = required_fields - current_fields
            min_expansion = max(TQ_FIELD_MIN_EXPANSION_SIZE, expansion_needed)
            new_fields = current_fields + min_expansion
            new_samples = self.production_status.shape[0]

            expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
            expanded_tensor[:, :current_fields] = self.production_status
            self.production_status = expanded_tensor

            logger.debug(
                f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields "
                f"(added {min_expansion} fields)"
            )
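
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the expansion rule above
    # grows by at least the configured minimum, so repeated small writes do not
    # trigger repeated reallocations. With the default TQ_SAMPLE_MIN_EXPANSION_SIZE
    # of 10 and a 10-row tensor:
    #
    #   required_samples = 12  ->  expansion_needed = 2,  min_expansion = max(10, 2)  = 10  ->  20 rows
    #   required_samples = 35  ->  expansion_needed = 25, min_expansion = max(10, 25) = 25  ->  35 rows
    #
    # Field (column) expansion follows the same rule with TQ_FIELD_MIN_EXPANSION_SIZE.
    # -----------------------------------------------------------------------
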
    # ==================== Production Status Interface ====================

    def update_production_status(
        self,
        global_indices: list[int],
        field_names: list[str],
        dtypes: Optional[dict[int, dict[str, Any]]] = None,
        shapes: Optional[dict[int, dict[str, Any]]] = None,
    ) -> bool:
        """
        Update production status for specific samples and fields.
        Handles dynamic expansion of both samples and fields.

        Args:
            global_indices: List of sample indices to update
            field_names: List of field names to mark as produced
            dtypes: Optional per-sample field dtype information
            shapes: Optional per-sample field shape information

        Returns:
            True if update was successful, False on error
        """
        try:
            # Determine required capacity
            max_sample_idx = max(global_indices) if global_indices else -1
            required_samples = max_sample_idx + 1

            # Ensure we have enough rows
            self.ensure_samples_capacity(required_samples)

            # Register new fields if needed
            new_fields = [field for field in field_names if field not in self.field_name_mapping]
            if new_fields:
                # Add new fields to mapping
                for field in new_fields:
                    self.field_name_mapping[field] = len(self.field_name_mapping)

                required_fields = len(self.field_name_mapping)
                self.ensure_fields_capacity(required_fields)

            # Update production status
            if self.production_status is not None and global_indices and field_names:
                field_indices = [self.field_name_mapping.get(field) for field in field_names]
                self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1

            # Update field metadata
            self._update_field_metadata(global_indices, field_names, dtypes, shapes)

            return True

        except Exception as e:
            logger.error(f"Error updating production status for partition {self.partition_id}: {e}")
            return False

    def _update_field_metadata(
        self,
        global_indices: list[int],
        field_names: list[str],
        dtypes: Optional[dict[int, dict[str, Any]]] = None,
        shapes: Optional[dict[int, dict[str, Any]]] = None,
    ):
        """Update field dtype and shape metadata."""
        for global_idx in global_indices:
            if global_idx not in self.field_dtypes:
                self.field_dtypes[global_idx] = {}
            if global_idx not in self.field_shapes:
                self.field_shapes[global_idx] = {}

            for field_name in field_names:
                if dtypes and global_idx in dtypes and field_name in dtypes[global_idx]:
                    self.field_dtypes[global_idx][field_name] = dtypes[global_idx][field_name]
                if shapes and global_idx in shapes and field_name in shapes[global_idx]:
                    self.field_shapes[global_idx][field_name] = shapes[global_idx][field_name]

    # ==================== Consumption Status Interface ====================

    def get_consumption_status(self, task_name: str) -> torch.Tensor:
        """
        Get or create consumption status for a specific task.
        Handles dynamic expansion when new samples are added.

        Args:
            task_name: Name of the consumer task

        Returns:
            Consumption status tensor for the specified task
        """
        if task_name not in self.consumption_status:
            if self.production_status is not None:
                self.consumption_status[task_name] = torch.zeros(self.total_samples_num, dtype=torch.int8)
            else:
                self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8)

        return self.consumption_status[task_name]

    # TODO: No need to return a bool; just raise an error. Same with the other functions.
    def mark_consumed(self, task_name: str, global_indices: list[int]) -> bool:
        """
        Mark specific samples as consumed by a task.

        Args:
            task_name: Name of the consumer task
            global_indices: List of sample indices to mark as consumed

        Returns:
            True if successful, False on error
        """
        try:
            consumption_status = self.get_consumption_status(task_name)
            if consumption_status.numel() > 0 and global_indices:
                consumption_status[global_indices] = 1
            return True
        except Exception as e:
            logger.error(f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}")
            return False
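
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): how a storage manager's
    # report is reflected in this bookkeeping. Field names, indices, and the
    # task name are illustrative only, and a fresh partition is assumed.
    #
    # >>> status = DataPartitionStatus(partition_id="train@global_batch_0")
    # >>> status.update_production_status(global_indices=[0, 1], field_names=["prompt", "response"])
    # True
    # >>> status.field_name_mapping
    # {'prompt': 0, 'response': 1}
    # >>> status.production_status[:2, :2]
    # tensor([[1, 1],
    #         [1, 1]], dtype=torch.int8)
    # >>> status.mark_consumed("rollout", [0])
    # True
    # -----------------------------------------------------------------------
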
    # ==================== Data Scanning and Query Methods ====================

    def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
        """
        Scan data status to find samples ready for consumption.
        This replaces the original _scan_data_status functionality.

        Args:
            field_names: List of required field names
            task_name: Name of the consumer task

        Returns:
            List of sample indices that are ready for consumption
        """
        if self.production_status is None:
            return []

        # Check if all requested fields are registered
        for field_name in field_names:
            if field_name not in self.field_name_mapping:
                return []

        row_mask = torch.ones(self.total_samples_num, dtype=torch.bool)

        # Apply consumption filter (exclude already consumed samples)
        consumption_status = self.get_consumption_status(task_name)
        if consumption_status is not None:
            unconsumed_mask = consumption_status == 0
            row_mask &= unconsumed_mask

        # Create column mask for requested fields
        col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
        field_indices = [self.field_name_mapping[field] for field in field_names]
        if field_indices:
            col_mask[field_indices] = True

        # Filter production status by masks
        relevant_status = self.production_status[row_mask][:, col_mask]

        # Check if all required fields are ready for each sample
        all_fields_ready = torch.all(relevant_status, dim=1)
        ready_indices_in_filtered = torch.nonzero(all_fields_ready, as_tuple=False).flatten()

        # Map back to original sample indices
        all_indices = torch.where(row_mask)[0]
        ready_sample_indices = all_indices[ready_indices_in_filtered].tolist()

        return ready_sample_indices

    # ==================== Field Metadata Methods ====================

    def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]:
        """Get dtype for a specific sample and field."""
        return self.field_dtypes.get(global_index, {}).get(field_name)

    def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]:
        """Get shape for a specific sample and field."""
        return self.field_shapes.get(global_index, {}).get(field_name)
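
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the scan above combines
    # a row mask (samples not yet consumed by `task_name`) with a column mask
    # (the requested fields) and keeps only rows whose requested columns are
    # all 1. Continuing the illustrative partition from the previous sketch:
    #
    # >>> status.scan_data_status(["prompt", "response"], task_name="rollout")
    # [1]        # sample 0 was already marked consumed for "rollout"
    # >>> status.scan_data_status(["prompt", "reward"], task_name="rollout")
    # []         # "reward" is not registered yet, so nothing is ready
    # -----------------------------------------------------------------------
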
    # ==================== Statistics and Monitoring ====================

    def get_statistics(self) -> dict[str, Any]:
        """Get detailed statistics for this partition."""
        stats = {
            "partition_id": self.partition_id,
            "created_at": self.created_at,
            "total_samples_num": self.total_samples_num,
            "total_fields_num": self.total_fields_num,
            "allocated_fields_num": self.allocated_fields_num,
            "registered_tasks": list(self.consumption_status.keys()),
        }

        if self.production_status is not None:
            produced_samples = torch.any(self.production_status == 1, dim=1).sum().item()
            stats["produced_samples"] = produced_samples
            stats["production_progress"] = (
                produced_samples / self.total_samples_num if self.total_samples_num > 0 else 0
            )

            # Field-wise production statistics
            field_stats = {}
            for field_name, field_idx in self.field_name_mapping.items():
                field_produced = (self.production_status[:, field_idx] == 1).sum().item()
                field_stats[field_name] = {
                    "produced_samples": field_produced,
                    "production_progress": field_produced / self.total_samples_num if self.total_samples_num > 0 else 0,
                }
            stats["field_statistics"] = field_stats

        # Consumption statistics per task
        consumption_stats = {}
        for task_name, consumption_tensor in self.consumption_status.items():
            consumed_samples = (consumption_tensor == 1).sum().item()
            consumption_stats[task_name] = {
                "consumed_samples": consumed_samples,
                "consumption_progress": consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0,
            }
        stats["consumption_statistics"] = consumption_stats

        return stats

    def clear_data(self, global_indexes_range: list[int], clear_consumption: bool = True) -> bool:
        """Clear production status, and optionally consumption status, for the given global indexes."""
        try:
            if self.production_status is not None:
                self.production_status[global_indexes_range, :] = 0

            if clear_consumption:
                for consumption_tensor in self.consumption_status.values():
                    consumption_tensor[global_indexes_range] = 0

            return True
        except Exception as e:
            logger.error(f"Error clearing data for partition {self.partition_id}: {e}")
            return False
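
# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the packaged file): get_statistics() and
# clear_data() on the same illustrative partition as the sketches above,
# assuming the default 10-row allocation and two produced samples.
#
# >>> stats = status.get_statistics()
# >>> stats["produced_samples"], stats["production_progress"]
# (2, 0.2)
# >>> status.clear_data([0, 1])        # reset production and consumption for these rows
# True
# >>> status.scan_data_status(["prompt", "response"], task_name="rollout")
# []
# ---------------------------------------------------------------------------
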
@ray.remote(num_cpus=1)
class TransferQueueController:
    """
    TransferQueue Controller with partition-based data management.

    This refactored controller manages data through dynamic partitions instead of
    fixed global batches. Each partition represents a logical data container
    (e.g., "train@global_batch_0", "inference@kv_cache_1") that can be created
    on-demand and managed independently.

    Key improvements:
    - Dynamic partition creation on-demand
    - No dependency on training-specific parameters (global_batch_size, etc.)
    - Support for diverse use cases (KV cache migration, model resharding, etc.)
    - Flexible data organization through partition-based addressing
    """

    def __init__(self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler) -> None:
        """Initialize the TransferQueue Controller.

        Args:
            sampler: Sampler instance or sampler class to use for data sampling.
                - If a BaseSampler instance is provided, it will be used directly
                - If a BaseSampler subclass is provided, it will be instantiated
                - Defaults to SequentialSampler for simple sequential sampling
                - Example: sampler=GRPOGroupNSampler() (instance)
                - Example: sampler=GRPOGroupNSampler (class)
        """
        if isinstance(sampler, BaseSampler):
            self.sampler = sampler
        elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
            self.sampler = sampler()
        else:
            raise TypeError(
                f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
            )

        self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}"

        # Initialize ZMQ sockets for communication
        self._init_zmq_socket()

        # Partition management
        self.partitions: dict[str, DataPartitionStatus] = {}  # partition_id -> DataPartitionStatus

        # Partition-GlobalIndex management
        self.index_manager = PartitionIndexManager()  # partition_id -> global_indexes

        # Connected storage managers tracking
        self._connected_storage_managers: set[str] = set()

        # Start background processing threads
        self._start_process_handshake()
        self._start_process_update_data_status()
        self._start_process_request()

        logger.info(f"TransferQueue Controller {self.controller_id} initialized")
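
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): since the class is a Ray
    # actor (num_cpus=1), a driver constructs it with `.remote()`. This assumes
    # a Ray runtime has been initialised; the sampler may be passed as either a
    # class or an instance.
    #
    # >>> import ray
    # >>> ray.init()
    # >>> controller = TransferQueueController.remote(sampler=SequentialSampler)
    # >>> server_info = ray.get(controller.get_zmq_server_info.remote())
    # -----------------------------------------------------------------------
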
    # ==================== Partition Management API ====================

    def create_partition(self, partition_id: str) -> bool:
        """
        Create a new data partition.

        Note: Partitions now dynamically expand as needed, so initial capacity is not required.

        Args:
            partition_id: Unique identifier for the partition (e.g., "train@global_batch_0")

        Returns:
            True if partition was created successfully, False if it already exists
        """
        if partition_id in self.partitions:
            logger.warning(f"Partition {partition_id} already exists")
            return False

        self.partitions[partition_id] = DataPartitionStatus(partition_id=partition_id)

        logger.info(f"Created partition {partition_id} with dynamic capacity")
        return True

    def get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
        """
        Get partition status information.

        Args:
            partition_id: ID of the partition to retrieve

        Returns:
            DataPartitionStatus object if partition exists, None otherwise
        """
        return self.partitions.get(partition_id)

    def list_partitions(self) -> list[str]:
        """
        List all available partition IDs.

        Returns:
            List of partition IDs
        """
        return list(self.partitions.keys())

    def delete_partition(self, partition_id: str) -> bool:
        """
        Delete a partition and all its data.

        Args:
            partition_id: ID of the partition to delete

        Returns:
            True if partition was deleted, False if it didn't exist
        """
        if partition_id in self.partitions:
            del self.partitions[partition_id]
            logger.info(f"Deleted partition {partition_id}")
            return True
        return False

    # ==================== Partition Index Management API ====================

    def get_partition_index_range(self, partition: DataPartitionStatus) -> set:
        """
        Get all indexes for a specific partition.

        Args:
            partition: Partition identifier

        Returns:
            Set of indexes allocated to the partition
        """
        return self.index_manager.get_indexes_for_partition(partition)

    # ==================== Data Production API ====================

    # TODO: Modify dtypes & shapes to be required
    def update_production_status(
        self,
        partition_id: str,
        global_indexes: list[int],
        field_names: list[str],
        dtypes: Optional[dict[int, dict[str, Any]]] = None,
        shapes: Optional[dict[int, dict[str, Any]]] = None,
    ) -> bool:
        """
        Update production status for specific samples and fields in a partition.
        Delegates to the partition's own update_production_status method.

        Args:
            partition_id: ID of the partition
            global_indexes: List of sample indices to update
            field_names: List of field names to mark as produced
            dtypes: Optional per-sample field dtype information
            shapes: Optional per-sample field shape information

        Returns:
            True if update was successful, False otherwise
        """
        partition = self.get_partition(partition_id)
        if not partition:
            logger.error(f"Partition {partition_id} not found")
            return False

        success = partition.update_production_status(global_indexes, field_names, dtypes, shapes)
        if success:
            logger.debug(
                f"Updated production status for partition {partition_id}: samples={global_indexes}, "
                f"fields={field_names}"
            )
        return success
    # ==================== Data Consumption API ====================

    def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[torch.Tensor]:
        """
        Get or create consumption status for a specific task and partition.
        Delegates to the partition's own method.

        Args:
            partition_id: ID of the partition
            task_name: Name of the consumer task

        Returns:
            Consumption status tensor if partition exists, None otherwise
        """
        partition = self.get_partition(partition_id)
        if not partition:
            return None

        return partition.get_consumption_status(task_name)

    def get_metadata(
        self,
        data_fields: list[str],
        partition_id: str,
        mode: str = "fetch",
        task_name: str | None = None,
        batch_size: int | None = None,
        sampling_config: Optional[dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> BatchMeta:
        """
        Retrieve metadata with support for three modes.

        Args:
            data_fields: List of field names to include in metadata
            partition_id: Partition id for which to retrieve metadata
            mode: Operation mode - 'insert', 'fetch', or 'force_fetch'
                - mode="insert": Create metadata for new samples (for data insertion)
                - mode="fetch": Get metadata from ready samples using the configured sampler
                - mode="force_fetch": Get metadata for unconsumed samples without sampling
                  (excludes already consumed samples)
            task_name: Name of the consumer task (required for fetch modes)
            batch_size: Number of samples to retrieve
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            BatchMeta object containing the requested metadata

        Raises:
            TimeoutError: If waiting for sufficient data times out in fetch mode
        """
        if partition_id not in self.partitions:
            self.create_partition(partition_id)

        if mode == "insert":
            if data_fields:
                # First put_data call, get_metadata in insert mode
                batch_global_indexes = self.index_manager.allocate_indexes(partition_id, count=batch_size)
            else:
                # clear metadata call passes empty data_fields
                batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
            return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)

        assert task_name is not None
        if mode == "fetch":
            # Find ready samples within current data partition and package into BatchMeta when reading

            start_time = time.time()
            while True:
                # ready_for_consume_indexes: samples where all required fields are produced
                # (production status is ready) and not yet consumed
                ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name, batch_size)

                if len(ready_for_consume_indexes) < batch_size:
                    continue

                # Try sampling - if it returns empty lists, retry
                batch_global_indexes, consumed_indexes = self.sampler(
                    ready_for_consume_indexes,
                    batch_size,
                    **(sampling_config or {}),
                )

                # Check if we got valid results from the sampler
                if len(batch_global_indexes) == batch_size:
                    break

                if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
                    raise TimeoutError(
                        f"Timeout while waiting for sufficient data. "
                        f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}, "
                        f"Sampled: {len(batch_global_indexes)}"
                    )

                logger.warning(
                    f"Insufficient complete groups available. Required: {batch_size}, "
                    f"Available: {len(ready_for_consume_indexes)}, "
                    f"Sampled: {len(batch_global_indexes)}. Retrying in "
                    f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
                )
                time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
            logger.debug(f"ready for consume idx: {ready_for_consume_indexes}")
            logger.debug(f"sampled idx: {batch_global_indexes}")
        elif mode == "force_fetch":
            global_indexes_range = self.index_manager.get_indexes_for_partition(partition_id)
            consumer_status = self.get_consumption_status(partition_id, task_name)
            not_consumed_idx = [i for i in global_indexes_range if consumer_status[i] == 0]
            batch_global_indexes = not_consumed_idx
            consumed_indexes = []

        # Package into metadata
        metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)

        # Mark samples as consumed if in fetch mode
        if mode == "fetch" and consumed_indexes:
            partition = self.partitions[partition_id]
            partition.mark_consumed(task_name, consumed_indexes)

        logger.debug(f"get_metadata: {metadata}")

        return metadata
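
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the intended flow of the
    # three modes, written as direct calls for brevity. With a Ray actor handle
    # the same calls would go through `.remote()` / `ray.get()`. All ids and
    # field names are illustrative.
    #
    # >>> # 1. insert: reserve global indexes before data is written to storage
    # >>> insert_meta = controller.get_metadata(
    # ...     data_fields=["prompt"], partition_id="train@global_batch_0",
    # ...     mode="insert", batch_size=4)
    # >>> # 2. fetch: block until `batch_size` samples have every requested field
    # >>> #    produced, sample them, and mark them consumed for `task_name`
    # >>> fetch_meta = controller.get_metadata(
    # ...     data_fields=["prompt", "response"], partition_id="train@global_batch_0",
    # ...     mode="fetch", task_name="rollout", batch_size=4)
    # >>> # 3. force_fetch: return whatever is not yet consumed, without sampling
    # >>> peek_meta = controller.get_metadata(
    # ...     data_fields=["prompt"], partition_id="train@global_batch_0",
    # ...     mode="force_fetch", task_name="rollout")
    # -----------------------------------------------------------------------
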
    def scan_data_status(
        self,
        partition_id: str,
        data_fields: list[str],
        task_name: str,
        batch_size: int,
        sample_filter: Optional[list[int]] = None,
        timeout: float = TQ_CONTROLLER_GET_METADATA_TIMEOUT,
    ) -> list[int]:
        """
        Find samples that are ready for consumption in a specific partition.
        Delegates scanning functionality to the partition's own method.

        Args:
            partition_id: ID of the partition
            data_fields: List of required field names
            task_name: Name of the consumer task
            batch_size: Number of samples needed
            sample_filter: Optional list of specific sample indices to consider
            timeout: Maximum time to wait for sufficient data

        Returns:
            List of sample indices that are ready for consumption

        Raises:
            TimeoutError: If sufficient data is not available within timeout
        """
        start_time = time.time()

        while True:
            partition = self.get_partition(partition_id)
            if not partition:
                if time.time() - start_time > timeout:
                    raise TimeoutError(f"Partition {partition_id} not found")
                time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
                continue

            # Use partition's own scanning method
            ready_sample_indices = partition.scan_data_status(data_fields, task_name)

            if len(ready_sample_indices) >= batch_size:
                return ready_sample_indices[:batch_size]

            if time.time() - start_time > timeout:
                raise TimeoutError(
                    f"Timeout waiting for sufficient data in partition {partition_id}. "
                    f"Required: {batch_size}, Available: {len(ready_sample_indices)}"
                )

            logger.warning(
                f"Insufficient data in partition {partition_id}. Required: {batch_size}, "
                f"Available: {len(ready_sample_indices)}. Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
            )
            time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)

    # ==================== Metadata Generation API ====================

    def generate_batch_meta(
        self,
        partition_id: str,
        batch_global_indexes: list[int],
        data_fields: list[str],
        mode: str = "fetch",
    ) -> BatchMeta:
        """
        Generate BatchMeta for specific samples in a partition.

        This function is responsible only for metadata generation and does not
        modify consumption state. State management is handled by the calling function.

        Args:
            partition_id: ID of the partition
            batch_global_indexes: List of sample indices to include in the batch
            data_fields: List of field names to include
            mode: Operation mode - 'fetch', 'insert', or 'force_fetch'

        Returns:
            BatchMeta object containing sample metadata

        Raises:
            ValueError: If partition doesn't exist or invalid mode
        """
        partition = self.get_partition(partition_id)
        if not partition:
            raise ValueError(f"Partition {partition_id} not found")

        if mode not in ["fetch", "insert", "force_fetch"]:
            raise ValueError(f"Invalid mode: {mode}")

        # Generate sample metadata
        samples = []
        for global_index in batch_global_indexes:
            fields = {}
            for field_name in data_fields:
                # Determine production status
                if mode == "fetch":
                    production_status = ProductionStatus.READY_FOR_CONSUME
                    dtype = partition.get_field_dtype(global_index, field_name)
                    shape = partition.get_field_shape(global_index, field_name)
                elif mode == "insert":
                    production_status = ProductionStatus.NOT_PRODUCED
                    dtype = None
                    shape = None
                elif mode == "force_fetch":
                    field_index = partition.field_name_mapping.get(field_name)
                    if (
                        field_index is not None
                        and partition.production_status is not None
                        and partition.production_status[global_index, field_index] == 1
                    ):
                        production_status = ProductionStatus.NOT_PRODUCED
                        dtype = partition.get_field_dtype(global_index, field_name)
                        shape = partition.get_field_shape(global_index, field_name)
                    else:
                        production_status = ProductionStatus.NOT_PRODUCED
                        dtype = None
                        shape = None

                fields[field_name] = FieldMeta(
                    name=field_name,
                    dtype=dtype,
                    shape=shape,
                    production_status=production_status,
                )

            sample = SampleMeta(
                partition_id=partition_id,
                global_index=global_index,
                fields=fields,
            )
            samples.append(sample)

        return BatchMeta(samples=samples)
    def clear(self, partition_id: str, clear_consumption: bool = True) -> bool:
        """
        Clear data for a specific partition.

        Args:
            partition_id: ID of the partition to clear
            clear_consumption: Whether to also clear consumption status

        Returns:
            True if cleared successfully, False otherwise
        """
        partition = self.get_partition(partition_id)
        if not partition:
            raise ValueError(f"Partition {partition_id} not found")

        global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id))
        success = partition.clear_data(global_indexes_range, clear_consumption)
        self.index_manager.release_indexes(partition_id)
        if success:
            logger.info(f"Cleared data for partition {partition_id}")
        return success

    def _init_zmq_socket(self):
        """Initialize ZMQ sockets for communication."""
        self.zmq_context = zmq.Context()
        self._node_ip = get_node_ip_address()
        self._handshake_socket_port = get_free_port()
        self._request_handle_socket_port = get_free_port()
        self._data_status_update_socket_port = get_free_port()

        self.handshake_socket = create_zmq_socket(
            ctx=self.zmq_context,
            socket_type=zmq.ROUTER,
        )
        self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}")

        self.request_handle_socket = create_zmq_socket(
            ctx=self.zmq_context,
            socket_type=zmq.ROUTER,
        )
        self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}")

        self.data_status_update_socket = create_zmq_socket(
            ctx=self.zmq_context,
            socket_type=zmq.ROUTER,
        )
        self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}")

        self.zmq_server_info = ZMQServerInfo(
            role=TransferQueueRole.CONTROLLER,
            id=self.controller_id,
            ip=self._node_ip,
            ports={
                "handshake_socket": self._handshake_socket_port,
                "request_handle_socket": self._request_handle_socket_port,
                "data_status_update_socket": self._data_status_update_socket_port,
            },
        )
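
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the controller exposes
    # three ROUTER sockets, and the ZMQServerInfo built above is what peers use
    # to find them:
    #
    #   handshake_socket          - storage managers register themselves
    #   request_handle_socket     - clients send GET_META / CLEAR_META / CHECK_CONSUMPTION
    #   data_status_update_socket - storage units report NOTIFY_DATA_UPDATE
    #
    # The attribute access and port numbers below are assumptions for
    # illustration only:
    #
    # >>> info = controller.get_zmq_server_info()
    # >>> info.ports
    # {'handshake_socket': 50001, 'request_handle_socket': 50002, 'data_status_update_socket': 50003}
    # -----------------------------------------------------------------------
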
    def _wait_connection(self):
        """Wait for storage instances to complete handshake with retransmission support."""
        poller = zmq.Poller()
        poller.register(self.handshake_socket, zmq.POLLIN)

        logger.info(f"Dynamic Controller {self.controller_id} started waiting for storage connections...")

        while True:
            socks = dict(poller.poll(TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL * 1000))

            if self.handshake_socket in socks:
                try:
                    identity, serialized_msg = self.handshake_socket.recv_multipart()
                    request_msg = ZMQMessage.deserialize(serialized_msg)

                    if request_msg.request_type == ZMQRequestType.HANDSHAKE:
                        storage_manager_id = request_msg.sender_id

                        # Always send ACK for HANDSHAKE
                        response_msg = ZMQMessage.create(
                            request_type=ZMQRequestType.HANDSHAKE_ACK,
                            sender_id=self.controller_id,
                            body={},
                        ).serialize()
                        self.handshake_socket.send_multipart([identity, response_msg])

                        # Track new connections
                        if storage_manager_id not in self._connected_storage_managers:
                            self._connected_storage_managers.add(storage_manager_id)
                            storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
                            logger.info(
                                f"Dynamic Controller {self.controller_id} received handshake from "
                                f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
                                f"Total connected: {len(self._connected_storage_managers)}"
                            )
                        else:
                            logger.debug(
                                f"Dynamic Controller {self.controller_id} received duplicate handshake from "
                                f"storage manager {storage_manager_id}. Resending ACK."
                            )

                except Exception as e:
                    logger.error(f"Dynamic Controller {self.controller_id} error processing handshake: {e}")

    def _start_process_handshake(self):
        """Start the handshake process thread."""
        self.wait_connection_thread = Thread(
            target=self._wait_connection, name="DynamicTransferQueueControllerWaitConnectionThread", daemon=True
        )
        self.wait_connection_thread.start()

    def _start_process_update_data_status(self):
        """Start the data status update processing thread."""
        self.process_update_data_status_thread = Thread(
            target=self._update_data_status,
            name="DynamicTransferQueueControllerProcessUpdateDataStatusThread",
            daemon=True,
        )
        self.process_update_data_status_thread.start()

    def _start_process_request(self):
        """Start the request processing thread."""
        self.process_request_thread = Thread(
            target=self._process_request, name="DynamicTransferQueueControllerProcessRequestThread", daemon=True
        )
        self.process_request_thread.start()
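
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the handshake that
    # _wait_connection expects. The storage-manager side lives elsewhere in the
    # package; this only illustrates the message shape inferred from the fields
    # read above. The sender id and manager type are hypothetical.
    #
    # >>> msg = ZMQMessage.create(
    # ...     request_type=ZMQRequestType.HANDSHAKE,
    # ...     sender_id="TQ_STORAGE_MANAGER_0",
    # ...     body={"storage_manager_type": "SimpleStorageManager"},
    # ... ).serialize()
    # >>> # sent to the controller's handshake_socket port; the controller
    # >>> # replies with a HANDSHAKE_ACK and records the sender id.
    # -----------------------------------------------------------------------
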
    def _process_request(self):
        """Main request processing loop - adapted for partition-based operations."""
        while True:
            identity, serialized_msg = self.request_handle_socket.recv_multipart()
            request_msg = ZMQMessage.deserialize(serialized_msg)

            if request_msg.request_type == ZMQRequestType.GET_META:
                params = request_msg.body

                metadata = self.get_metadata(
                    data_fields=params["data_fields"],
                    batch_size=params["batch_size"],
                    partition_id=params["partition_id"],
                    mode=params.get("mode", "fetch"),
                    task_name=params.get("task_name"),
                    sampling_config=params.get("sampling_config"),
                )

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )

            elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
                params = request_msg.body
                partition_id = params["partition_id"]

                metadata = self.get_metadata(
                    data_fields=[],
                    partition_id=partition_id,
                    mode="insert",
                )
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={"metadata": metadata},
                )
            elif request_msg.request_type == ZMQRequestType.CLEAR_META:
                params = request_msg.body
                partition_id = params["partition_id"]

                clear_success = self.clear(partition_id)
                if clear_success:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.CLEAR_META_RESPONSE,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body={"message": f"Clear operation completed by controller {self.controller_id}"},
                    )
                else:
                    response_msg = ZMQMessage.create(
                        request_type=ZMQRequestType.CLEAR_META_RESPONSE,
                        sender_id=self.controller_id,
                        receiver_id=request_msg.sender_id,
                        body={"error": f"Clear operation failed for partition {partition_id}"},
                    )

            elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
                # Handle consumption status checks
                params = request_msg.body

                consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"])
                sample_filter = params.get("sample_filter")

                if consumption_status is not None and sample_filter:
                    batch_status = consumption_status[sample_filter]
                    consumed = torch.all(batch_status == 1).item()
                elif consumption_status is not None:
                    batch_status = consumption_status
                    consumed = torch.all(batch_status == 1).item()
                else:
                    consumed = False

                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
                    sender_id=self.controller_id,
                    receiver_id=request_msg.sender_id,
                    body={
                        "partition_id": params["partition_id"],
                        "consumed": consumed,
                    },
                )
            self.request_handle_socket.send_multipart([identity, response_msg.serialize()])
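
    # -----------------------------------------------------------------------
    # Reviewer sketch (not part of the packaged file): the request body that the
    # GET_META branch above unpacks. The actual client-side wrapper lives in
    # transfer_queue/client.py; this only mirrors the keys read from `params`,
    # and the sender id is hypothetical.
    #
    # >>> request = ZMQMessage.create(
    # ...     request_type=ZMQRequestType.GET_META,
    # ...     sender_id="TQ_CLIENT_0",
    # ...     body={
    # ...         "data_fields": ["prompt", "response"],
    # ...         "batch_size": 8,
    # ...         "partition_id": "train@global_batch_0",
    # ...         "mode": "fetch",
    # ...         "task_name": "rollout",
    # ...         "sampling_config": {},
    # ...     },
    # ... ).serialize()
    # >>> # the reply is a GET_META_RESPONSE whose body carries {"metadata": BatchMeta(...)}
    # -----------------------------------------------------------------------
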
    def _update_data_status(self):
        """Process data status update messages from storage units - adapted for partitions."""
        while True:
            identity, serialized_msg = self.data_status_update_socket.recv_multipart()
            request_msg = ZMQMessage.deserialize(serialized_msg)

            if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
                message_data = request_msg.body
                partition_id = message_data.get("partition_id")

                # Update production status
                success = self.update_production_status(
                    partition_id=partition_id,
                    global_indexes=message_data.get("global_indexes", []),
                    field_names=message_data.get("fields", []),
                    dtypes=message_data.get("dtypes", {}),
                    shapes=message_data.get("shapes", {}),
                )

                if success:
                    logger.info(f"Updated production status for partition {partition_id}")

                # Send acknowledgment
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
                    sender_id=self.controller_id,
                    body={
                        "controller_id": self.controller_id,
                        "partition_id": partition_id,
                        "success": success,
                    },
                )
                self.data_status_update_socket.send_multipart([identity, response_msg.serialize()])

    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Get ZMQ server connection information."""
        return self.zmq_server_info