TransferQueue 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,772 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import math
17
+ import os
18
+ import threading
19
+ import time
20
+ from threading import Thread
21
+ from typing import Any, Optional
22
+ from uuid import uuid4
23
+
24
+ import numpy as np
25
+ import ray
26
+ import torch
27
+ import zmq
28
+ from ray.util import get_node_ip_address
29
+
30
+ from transfer_queue.metadata import (
31
+ BatchMeta,
32
+ FieldMeta,
33
+ SampleMeta,
34
+ )
35
+ from transfer_queue.utils.utils import (
36
+ ProductionStatus,
37
+ TransferQueueRole,
38
+ sequential_sampler,
39
+ )
40
+ from transfer_queue.utils.zmq_utils import (
41
+ ZMQMessage,
42
+ ZMQRequestType,
43
+ ZMQServerInfo,
44
+ create_zmq_socket,
45
+ get_free_port,
46
+ )
47
+
48
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Max seconds _get_metadata will poll for enough ready samples before raising TimeoutError.
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300))
# Sleep (seconds) between readiness re-scans while waiting in _get_metadata.
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1))
# Initial column capacity of the production-status matrix; grown on demand as fields register.
TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10))
54
+
55
+
56
+ @ray.remote(num_cpus=1)
57
+ class TransferQueueController:
58
+ def __init__(
59
+ self,
60
+ num_storage_units: int,
61
+ global_batch_size: int,
62
+ num_global_batch: int = 1,
63
+ num_n_samples: int = 1,
64
+ ) -> None:
65
+ """Initialize the TransferQueueController.
66
+
67
+ Args:
68
+ num_storage_units: Number of storage units in the system
69
+ global_batch_size: Size of each global batch
70
+ num_global_batch: Number of global batches to maintain in storage
71
+ num_n_samples: For each prompt, sample n responses
72
+ """
73
+ self.controller_id = f"TQ_CONTROLLER_{uuid4()}"
74
+
75
+ self._init_zmq_socket() # Initialize ZMQ sockets for data communication
76
+
77
+ self.num_storage_units = num_storage_units
78
+ self.global_batch_size = (
79
+ global_batch_size # Used as offset for global index to identify corresponding global step
80
+ )
81
+ self.num_global_batch = num_global_batch
82
+ self.num_n_samples = num_n_samples
83
+ self.total_storage_size = self.global_batch_size * self.num_global_batch * self.num_n_samples
84
+
85
+ self.data_production_status = torch.zeros(
86
+ self.total_storage_size, TQ_INIT_FIELD_NUM, dtype=torch.int8
87
+ ) # Initialize with default number of fields, dynamically extensible
88
+ # task_name -> consumption_status mapping
89
+ self.data_consumption_status: dict[str, torch.Tensor] = {}
90
+ self.field_name_mapping: dict[
91
+ str, int
92
+ ] = {} # Mapping table from field_name to the column indices in self.data_production_status tables
93
+ # Per-field dtype and shape storage: {global_index: {field_name: {'dtype': dtype, 'shape': shape}}}
94
+ self.per_tensor_dtype_mapping: dict[int, dict[str, Any]] = {}
95
+ self.per_tensor_shape_mapping: dict[int, dict[str, Any]] = {}
96
+
97
+ self._build_index_storage_mapping()
98
+
99
+ self._start_process_handshake()
100
+ self._start_process_update_data_status()
101
+ self._start_process_request()
102
+
103
+ def _get_consumption_status(self, task_name: str) -> torch.Tensor:
104
+ """
105
+ Get or create the consumption status tensor for a specific task.
106
+ The consumption status is a binary, 1D tensor that records whether the corresponding sample has been consumed
107
+ by the task.
108
+
109
+ Args:
110
+ task_name: Name of the consumer task
111
+
112
+ Returns:
113
+ Consumption status tensor for the specified task
114
+ """
115
+ # Retrieve or create the consumption state tensor for a specified consumer
116
+ if task_name not in self.data_consumption_status:
117
+ # Initialize state for a new consumer
118
+ self.data_consumption_status[task_name] = torch.zeros(self.total_storage_size, dtype=torch.int8)
119
+ return self.data_consumption_status[task_name]
120
+
121
+ def _get_per_field_dtype(self, global_index: int, field_name: str) -> Optional[torch.dtype]:
122
+ """Get dtype for a specific sample and field.
123
+
124
+ Args:
125
+ global_index: Global index of the sample
126
+ field_name: Name of the field
127
+
128
+ Returns:
129
+ dtype of the specified field for the sample, or None if not found
130
+ """
131
+ return self.per_tensor_dtype_mapping.get(global_index, {}).get(field_name)
132
+
133
+ def _get_per_field_shape(self, global_index: int, field_name: str) -> Optional[torch.Size]:
134
+ """Get shape for a specific sample and field.
135
+
136
+ Args:
137
+ global_index: Global index of the sample
138
+ field_name: Name of the field
139
+
140
+ Returns:
141
+ Shape of the specified field for the sample, or None if not found
142
+ """
143
+ return self.per_tensor_shape_mapping.get(global_index, {}).get(field_name)
144
+
145
+ def _step_to_global_index_range(self, global_step: int) -> tuple[int, int]:
146
+ """Convert global step to corresponding global index range.
147
+
148
+ Args:
149
+ global_step: The global step to convert
150
+
151
+ Returns:
152
+ Tuple of (start_index, end_index) for the given global step
153
+ """
154
+ start_idx = (global_step % self.num_global_batch) * self.global_batch_size * self.num_n_samples
155
+ end_idx = start_idx + self.global_batch_size * self.num_n_samples
156
+
157
+ return start_idx, end_idx
158
+
159
+ def generate_data_status_mask(
160
+ self, data_fields: list[str], global_step: int, task_name: str
161
+ ) -> tuple[torch.Tensor, torch.Tensor]:
162
+ """
163
+ Generate mask matrix for filtering data based on field availability and consumption status.
164
+
165
+ This function is called within _get_meta and generates a mask matrix based on
166
+ user-specified fields and the current step. The mask matrix selects the required
167
+ rows and columns from self.data_production_status while inversely selecting from
168
+ self.data_consumption_status to support automated vectorization.
169
+
170
+ Args:
171
+ data_fields: List of field names to include in the mask
172
+ global_step: Current global step for row selection
173
+ task_name: Name of the consumer task for consumption status
174
+
175
+ Returns:
176
+ Tuple of (row_mask, col_mask) tensors for filtering data status matrices
177
+ """
178
+
179
+ # Check if all requested fields are registered
180
+ for col in data_fields:
181
+ if col not in self.field_name_mapping:
182
+ # Return empty mask indicating no available data for unregistered columns
183
+ empty_row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool)
184
+ empty_col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool)
185
+ return empty_row_mask, empty_col_mask
186
+
187
+ # Map steps to global indices
188
+ start_idx, end_idx = self._step_to_global_index_range(global_step)
189
+ row_mask = torch.zeros(self.data_production_status.shape[0], dtype=torch.bool)
190
+ row_mask[start_idx:end_idx] = True
191
+
192
+ # Invert selection based on consumption status
193
+ consumer_status = self._get_consumption_status(task_name)
194
+ unconsumed_mask = consumer_status == 0
195
+ row_mask &= unconsumed_mask
196
+
197
+ # Select the specified fields
198
+ col_mask = torch.zeros(self.data_production_status.shape[1], dtype=torch.bool)
199
+ valid_fields = [self.field_name_mapping[col] for col in data_fields]
200
+ if valid_fields:
201
+ col_mask[valid_fields] = True
202
+
203
+ return row_mask, col_mask
204
+
205
+ def _build_index_storage_mapping(self):
206
+ """
207
+ Build mappings between global indices and storage locations.
208
+
209
+ Distributes samples across storage units based on total storage space and
210
+ maintains mappings between global index and local index within each storage.
211
+ """
212
+ # Assign each sample to a storage node. Here we scatter the samples in each GBS to different storage nodes
213
+ # Samples are arranged sequentially, similar to generate_data_status_mask
214
+ real_global_batch_size = self.global_batch_size * self.num_n_samples
215
+ global_batch_per_storage_unit = math.ceil(real_global_batch_size / self.num_storage_units)
216
+
217
+ # Build mapping between global index and storage unit for locating each data sample
218
+ batch_storage_indices = np.repeat(np.arange(self.num_storage_units), global_batch_per_storage_unit)[
219
+ :real_global_batch_size
220
+ ]
221
+ self._global_index_storage_rank_mapping = np.tile(batch_storage_indices, self.num_global_batch)
222
+
223
+ # Build mapping between global index and local index within each storage unit
224
+ indices = np.arange(self.total_storage_size)
225
+ pos_in_batch = indices % real_global_batch_size
226
+ g = indices // real_global_batch_size
227
+ pos_in_block = pos_in_batch % global_batch_per_storage_unit
228
+ self.global_index_local_index_mapping = g * global_batch_per_storage_unit + pos_in_block
229
+
230
    def get_data_production_status(self) -> torch.Tensor:
        """
        Get the current data production status matrix. The data production status is a 2D matrix that records whether
        the corresponding data is ready for each field of each sample.

        Rows correspond to global sample indices; columns correspond to the
        registered fields (see ``field_name_mapping``).

        Returns:
            Tensor representing production status of all data fields
        """
        return self.data_production_status
239
+
240
    def get_field_name_mapping(self) -> dict[str, int]:
        """Get the field name to column index mapping.

        Returns:
            Dictionary mapping field names to their column indices in the
            production status matrix.
        """
        return self.field_name_mapping
247
+
248
    def get_data_consumption_status(self) -> dict[str, torch.Tensor]:
        """Get consumption status for all tasks.

        Returns:
            Dictionary mapping task names to their per-sample consumption
            status tensors (1 = consumed).
        """
        return self.data_consumption_status
255
+
256
    def get_global_index_mapping(self) -> tuple[np.ndarray, np.ndarray]:
        """Get global index to storage mapping information.

        Returns:
            Tuple of (storage rank per global index, local index within the
            assigned storage unit per global index).
        """
        return self._global_index_storage_rank_mapping, self.global_index_local_index_mapping
263
+
264
+ def _get_metadata(
265
+ self,
266
+ data_fields: list[str],
267
+ batch_size: int,
268
+ global_step: int,
269
+ mode: str = "fetch",
270
+ task_name: str | None = None,
271
+ get_n_samples=False,
272
+ *args,
273
+ **kwargs,
274
+ ) -> BatchMeta:
275
+ """
276
+ Retrieve metadata with support for three modes.
277
+
278
+ Args:
279
+ data_fields: List of field names to include in metadata
280
+ batch_size: Number of samples to retrieve
281
+ global_step: Global step for which to retrieve metadata
282
+ mode: Operation mode - 'insert', 'fetch', or 'force_fetch'
283
+ - mode="insert": Insert metadata for new rows (without checking data status)
284
+ - mode="fetch": Retrieve metadata for ready data (check data status and sample)
285
+ - mode="force_fetch": Directly return metadata (without checking data status)
286
+ task_name: Name of the consumer task (required for fetch modes)
287
+ get_n_samples: Whether to retrieve n_samples as groups
288
+ *args: Additional positional arguments
289
+ **kwargs: Additional keyword arguments
290
+
291
+ Returns:
292
+ BatchMeta object containing the requested metadata
293
+
294
+ Raises:
295
+ TimeoutError: If waiting for sufficient data times out in fetch mode
296
+ """
297
+ if mode == "insert":
298
+ # TODO: Currently we only supports put the entire GBS data in one time
299
+ assert batch_size == self.global_batch_size * self.num_n_samples, (
300
+ f"batch_size {batch_size} must equal "
301
+ f"global_batch_size * num_n_samples {self.global_batch_size * self.num_n_samples}"
302
+ )
303
+ start_idx, end_idx = self._step_to_global_index_range(global_step)
304
+ batch_global_indexes = list(range(start_idx, end_idx))
305
+ return self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode)
306
+
307
+ assert task_name is not None
308
+ if mode == "fetch":
309
+ # Find consumable samples within current batch and package into BatchMeta when reading
310
+
311
+ start_time = time.time()
312
+ while True:
313
+ ready_for_consume_idx = self._scan_data_status(data_fields, global_step, task_name, get_n_samples)
314
+
315
+ if len(ready_for_consume_idx) >= batch_size:
316
+ break
317
+
318
+ if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
319
+ raise TimeoutError(
320
+ f"Timeout while waiting for sufficient data. "
321
+ f"Required: {batch_size}, Available: {len(ready_for_consume_idx)}"
322
+ )
323
+
324
+ logger.warning(
325
+ f"Insufficient data available. Required: {batch_size}, "
326
+ f"Available: {len(ready_for_consume_idx)}. Retrying in "
327
+ f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
328
+ )
329
+ time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
330
+ logger.debug(f"ready for consume idx: {ready_for_consume_idx}")
331
+
332
+ batch_global_indexes = sequential_sampler(
333
+ ready_for_consume_idx, batch_size, get_n_samples, self.num_n_samples
334
+ )
335
+ elif mode == "force_fetch":
336
+ start_idx, end_idx = self._step_to_global_index_range(global_step)
337
+ consumer_status = self._get_consumption_status(task_name)
338
+ not_consumed_idx = [i for i in range(start_idx, end_idx) if consumer_status[i] == 0]
339
+ batch_global_indexes = sequential_sampler(not_consumed_idx, batch_size, get_n_samples, self.num_n_samples)
340
+
341
+ # Mark this batch of data as consumed
342
+ consumer_status = self._get_consumption_status(task_name)
343
+ consumer_status[batch_global_indexes] = 1
344
+ # Package into metadata
345
+ metadata = self._generate_batch_meta(global_step, batch_global_indexes, data_fields, mode)
346
+ logger.debug(f"_get_metadata: {metadata}")
347
+
348
+ return metadata
349
+
350
+ def _scan_data_status(
351
+ self, data_fields: list[str], global_step: int, task_name: str, get_n_samples: bool
352
+ ) -> list[int]:
353
+ """
354
+ Scan data status to find samples ready for consumption.
355
+
356
+ Args:
357
+ data_fields: List of field names to check
358
+ global_step: Global step to scan
359
+ task_name: Name of the consumer task
360
+ get_n_samples: Whether to return n_samples as groups
361
+
362
+ Returns:
363
+ List of global indices that are ready for consumption
364
+ """
365
+ # Get row and column masks
366
+ row_mask, col_mask = self.generate_data_status_mask(data_fields, global_step, task_name)
367
+ logger.debug(f"row_mask, col_mask: {row_mask, col_mask}")
368
+
369
+ if not row_mask.any() or not col_mask.any():
370
+ return []
371
+
372
+ # Extract subset of data status for relevant fields
373
+ logger.debug(f"self.data_production_status: {self.data_production_status}")
374
+ data_status_of_interest = self.data_production_status[:, col_mask]
375
+ logger.debug(f"data_status_of_interest: {data_status_of_interest}")
376
+
377
+ # Use torch.all for vectorized check instead of sum comparison
378
+ all_fields_ready = torch.all(data_status_of_interest, dim=1)
379
+
380
+ # Filter samples that meet criteria combined with row mask
381
+ ready_mask = all_fields_ready & row_mask
382
+
383
+ if get_n_samples and self.num_n_samples > 1:
384
+ # Reshape to group view and check group completeness
385
+ group_all_ready = torch.all(ready_mask.view(-1, self.num_n_samples), dim=1)
386
+
387
+ # Get indices of fully ready groups
388
+ ready_group_indices = group_all_ready.nonzero(as_tuple=False).flatten()
389
+
390
+ # Calculate all sample indices
391
+ sample_offset = torch.arange(self.num_n_samples)
392
+ ready_for_consume_idx = (
393
+ (ready_group_indices.unsqueeze(1) * self.num_n_samples + sample_offset).flatten().tolist()
394
+ )
395
+
396
+ return ready_for_consume_idx
397
+ else:
398
+ ready_for_consume_idx = torch.nonzero(ready_mask, as_tuple=False).flatten().tolist()
399
+ logger.debug(f"ready_for_consume_idx: {ready_for_consume_idx}")
400
+
401
+ return ready_for_consume_idx
402
+
403
+ def _generate_batch_meta(
404
+ self, global_step: int, global_indexes: list[int], data_fields: list[str], mode: str
405
+ ) -> BatchMeta:
406
+ """
407
+ Generate BatchMeta by resolving storage locations for given global indexes.
408
+
409
+ For each global index, looks up the corresponding storage node address using:
410
+ - global_index_local_index_mapping: Maps to local index within storage
411
+ - _global_index_storage_id_mapping: Maps to storage node identifier
412
+
413
+ Args:
414
+ global_step: Current global step
415
+ global_indexes: List of global indexes to process
416
+ data_fields: List of data field names
417
+ mode: Operation mode ('fetch', 'insert', or 'force_fetch')
418
+
419
+ Returns:
420
+ BatchMeta object containing sample metadata with resolved storage locations
421
+ """
422
+ global_arr = np.array(global_indexes)
423
+ storage_ids = self.global_index_storage_id_mapping[global_arr]
424
+ local_indexes = self.global_index_local_index_mapping[global_arr]
425
+
426
+ samples = []
427
+
428
+ # Create samples from the flattened BatchMeta data
429
+ # TODO: Optimize this
430
+ for i, global_index in enumerate(global_indexes):
431
+ local_index = local_indexes[i]
432
+ storage_id = storage_ids[i]
433
+
434
+ # Create FieldMeta objects for each field
435
+ fields = []
436
+ for field_name in data_fields:
437
+ if mode == "fetch":
438
+ production_status = ProductionStatus.READY_FOR_CONSUME # Since we filtered by ready status
439
+ # Get per-field dtype and shape for this specific global_index and field
440
+ dtype = self._get_per_field_dtype(global_index, field_name)
441
+ shape = self._get_per_field_shape(global_index, field_name)
442
+ elif mode == "insert":
443
+ production_status = ProductionStatus.NOT_PRODUCED # FIXME: not real-time
444
+ dtype = None
445
+ shape = None
446
+ elif mode == "force_fetch":
447
+ col_index = self.field_name_mapping.get(field_name)
448
+ if col_index is not None and self.data_production_status[global_index, col_index] == 1:
449
+ production_status = ProductionStatus.READY_FOR_CONSUME
450
+ dtype = self._get_per_field_dtype(global_index, field_name)
451
+ shape = self._get_per_field_shape(global_index, field_name)
452
+ else:
453
+ production_status = ProductionStatus.NOT_PRODUCED
454
+ dtype = None
455
+ shape = None
456
+ field_meta = FieldMeta(
457
+ name=field_name,
458
+ dtype=dtype,
459
+ shape=shape,
460
+ production_status=production_status,
461
+ )
462
+ fields.append(field_meta)
463
+
464
+ sample = SampleMeta(
465
+ global_step=global_step,
466
+ global_index=global_index,
467
+ storage_id=storage_id,
468
+ local_index=local_index,
469
+ fields={field.name: field for field in fields},
470
+ )
471
+ samples.append(sample)
472
+
473
+ return BatchMeta(samples=samples)
474
+
475
+ def _update_production_status(self, indexes: list[int], fields: list[str]) -> None:
476
+ """
477
+ Update production status for specified indexes and fields.
478
+
479
+ Args:
480
+ indexes: List of global indexes to update
481
+ fields: List of field names to update
482
+ """
483
+ # TODO: Replace self.data_production_status == 0 or ==1 operations with ProductionStatus enum
484
+ # Update data production status matrix
485
+ new_fields = [field for field in fields if field not in self.field_name_mapping]
486
+ if new_fields:
487
+ needed_fields = len(new_fields)
488
+ current_fields = self.data_production_status.shape[1]
489
+ # Expand data status matrix if needed
490
+ if len(self.field_name_mapping) + needed_fields > current_fields:
491
+ add_fields = max(TQ_INIT_FIELD_NUM, needed_fields + 1)
492
+ new_matrix = torch.zeros((self.total_storage_size, add_fields), dtype=torch.int8)
493
+ self.data_production_status = torch.cat([self.data_production_status, new_matrix], dim=1)
494
+
495
+ for field in fields:
496
+ if field not in self.field_name_mapping.keys():
497
+ self.field_name_mapping[field] = len(self.field_name_mapping)
498
+ self.data_production_status[
499
+ torch.tensor(indexes)[:, None], torch.tensor([self.field_name_mapping.get(field) for field in fields])
500
+ ] = 1
501
+
502
+ def _update_field_info(
503
+ self,
504
+ fields: list[str],
505
+ per_tensor_dtypes: dict[int, dict[str, Any]],
506
+ per_tensor_shapes: dict[int, dict[str, Any]],
507
+ global_indexes: list[int],
508
+ ) -> None:
509
+ """
510
+ Store per-field dtype and shape information.
511
+
512
+ Args:
513
+ fields: List of field names
514
+ per_tensor_dtypes: Dict mapping global_index to field dtypes {global_index: {field: dtype}}
515
+ per_tensor_shapes: Dict mapping global_index to field shapes {global_index: {field: shape}}
516
+ global_indexes: List of global indexes corresponding to the samples
517
+ """
518
+ for global_idx in global_indexes:
519
+ if global_idx not in self.per_tensor_dtype_mapping:
520
+ self.per_tensor_dtype_mapping[global_idx] = {}
521
+ if global_idx not in self.per_tensor_shape_mapping:
522
+ self.per_tensor_shape_mapping[global_idx] = {}
523
+
524
+ for field in fields:
525
+ if global_idx in per_tensor_dtypes and field in per_tensor_dtypes[global_idx]:
526
+ self.per_tensor_dtype_mapping[global_idx][field] = per_tensor_dtypes[global_idx][field]
527
+ if global_idx in per_tensor_shapes and field in per_tensor_shapes[global_idx]:
528
+ self.per_tensor_shape_mapping[global_idx][field] = per_tensor_shapes[global_idx][field]
529
+
530
+ def _init_zmq_socket(self):
531
+ """
532
+ Initialize ZMQ sockets for communication.
533
+
534
+ Sets up three ZMQ service ports for:
535
+ 1. Receiving handshake requests from storage
536
+ 2. Handling client data read/write requests
537
+ 3. Receiving status update signals from storage
538
+ """
539
+ self.zmq_context = zmq.Context()
540
+
541
+ self._node_ip = get_node_ip_address()
542
+ self._handshake_socket_port = get_free_port()
543
+ self._request_handle_socket_port = get_free_port()
544
+ self._data_status_update_socket_port = get_free_port()
545
+
546
+ self.handshake_socket = create_zmq_socket(
547
+ ctx=self.zmq_context,
548
+ socket_type=zmq.ROUTER,
549
+ )
550
+ self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}")
551
+
552
+ self.request_handle_socket = create_zmq_socket(
553
+ ctx=self.zmq_context,
554
+ socket_type=zmq.ROUTER,
555
+ )
556
+ self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}")
557
+
558
+ self.data_status_update_socket = create_zmq_socket(
559
+ ctx=self.zmq_context,
560
+ socket_type=zmq.ROUTER,
561
+ )
562
+ self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}")
563
+
564
+ self.zmq_server_info = ZMQServerInfo.create(
565
+ role=TransferQueueRole.CONTROLLER,
566
+ id=self.controller_id,
567
+ ip=self._node_ip,
568
+ ports={
569
+ "handshake_socket": self._handshake_socket_port,
570
+ "request_handle_socket": self._request_handle_socket_port,
571
+ "data_status_update_socket": self._data_status_update_socket_port,
572
+ },
573
+ )
574
+
575
    def _wait_connection(self):
        """Wait for all storage instances to complete handshake.

        Clients don't need handshake to support dynamic scaling. Continuously
        listens for handshake messages until all expected storage units connect,
        then resolves the per-sample storage-id mapping and signals
        ``handshake_done``.
        """
        # TODO(zjj): Consider if retransmission is needed (assuming cases where Storage doesn't receive ACK)
        connected_storage_units = set()
        while len(connected_storage_units) < self.num_storage_units:
            identity, serialized_msg = self.handshake_socket.recv_multipart()
            request_msg = ZMQMessage.deserialize(serialized_msg)
            if request_msg.request_type == ZMQRequestType.HANDSHAKE:
                connected_storage_units.add(request_msg.sender_id)
                response_msg = ZMQMessage.create(
                    request_type=ZMQRequestType.HANDSHAKE_ACK,
                    sender_id=self.controller_id,
                    body={},
                ).serialize()
                self.handshake_socket.send_multipart([identity, response_msg])
                logger.info("Controller sent handshake ack successfully!")
        # Sort ids for a deterministic rank -> storage-id assignment, then expand
        # each sample's storage rank into the concrete storage id.
        self.global_index_storage_id_mapping = np.array(sorted(list(connected_storage_units)))[
            self._global_index_storage_rank_mapping
        ]
        self.handshake_done.set()
599
+
600
+ def _start_process_handshake(self):
601
+ """Start the handshake process thread."""
602
+ self.handshake_done = threading.Event()
603
+ self.wait_connection_thread = Thread(
604
+ target=self._wait_connection, name="TransferQueueControllerWaitConnectionThread", daemon=True
605
+ )
606
+ self.wait_connection_thread.start()
607
+
608
+ def _start_process_update_data_status(self):
609
+ """Start the data status update processing thread."""
610
+ self.process_update_data_status_thread = Thread(
611
+ target=self._update_data_status, name="TransferQueueControllerProcessUpdateDataStatusThread", daemon=True
612
+ )
613
+ self.process_update_data_status_thread.start()
614
+
615
+ def _start_process_request(self):
616
+ """Start the request processing thread."""
617
+ self.process_request_thread = Thread(
618
+ target=self._process_request, name="TransferQueueControllerProcessRequestThread", daemon=True
619
+ )
620
+ self.process_request_thread.start()
621
+
622
+ def _process_request(self):
623
+ """Main request processing loop.
624
+
625
+ Handles various request types including metadata retrieval,
626
+ consumption status checks, and clear operations.
627
+ """
628
+ self.handshake_done.wait()
629
+ while True:
630
+ # ROUTER socket receives multi-part messages
631
+ identity, serialized_msg = self.request_handle_socket.recv_multipart()
632
+ request_msg = ZMQMessage.deserialize(serialized_msg)
633
+
634
+ if request_msg.request_type == ZMQRequestType.GET_META:
635
+ params = request_msg.body
636
+ logger.info("Controller preparing to get metadata...")
637
+ metadata = self._get_metadata(
638
+ data_fields=params["data_fields"],
639
+ batch_size=params["batch_size"],
640
+ global_step=params["global_step"],
641
+ mode=params.get("mode", "fetch"),
642
+ task_name=params.get("task_name", None),
643
+ get_n_samples=params.get("get_n_samples", False),
644
+ )
645
+ response_msg = ZMQMessage.create(
646
+ request_type=ZMQRequestType.GET_META_RESPONSE,
647
+ sender_id=self.controller_id,
648
+ receiver_id=request_msg.sender_id,
649
+ body={"metadata": metadata},
650
+ )
651
+ elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
652
+ params = request_msg.body
653
+ metadata = self._get_metadata(
654
+ data_fields=[],
655
+ batch_size=self.global_batch_size * self.num_n_samples,
656
+ global_step=params["global_step"],
657
+ mode="insert",
658
+ )
659
+ response_msg = ZMQMessage.create(
660
+ request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
661
+ sender_id=self.controller_id,
662
+ receiver_id=request_msg.sender_id,
663
+ body={"metadata": metadata},
664
+ )
665
+ elif request_msg.request_type == ZMQRequestType.CLEAR_META:
666
+ params = request_msg.body
667
+ self.clear(global_step=params["global_step"])
668
+ response_msg = ZMQMessage.create(
669
+ request_type=ZMQRequestType.CLEAR_META_RESPONSE,
670
+ sender_id=self.controller_id,
671
+ receiver_id=request_msg.sender_id,
672
+ body={"message": f"Clear operation completed by controller {self.controller_id}"},
673
+ )
674
+ elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
675
+ # Check consumption status
676
+ params = request_msg.body
677
+ global_step = params["global_step"]
678
+
679
+ consumer_status = self._get_consumption_status(params["task_name"])
680
+ start_idx, end_idx = self._step_to_global_index_range(global_step)
681
+ batch_status = consumer_status[start_idx:end_idx]
682
+ consumed = torch.all(batch_status == 1).item()
683
+
684
+ # Build response message
685
+ response_msg = ZMQMessage.create(
686
+ request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
687
+ sender_id=self.controller_id,
688
+ receiver_id=request_msg.sender_id,
689
+ body={
690
+ "global_step": global_step,
691
+ "consumed": consumed,
692
+ },
693
+ )
694
+ self.request_handle_socket.send_multipart([identity, response_msg.serialize()])
695
+ logger.debug("Controller request_handle_socket sent multipart successfully!")
696
+
697
+ def _update_data_status(self):
698
+ """Process data status update messages from storage units.
699
+
700
+ Continuously listens for data update notifications and updates
701
+ internal production status and field information accordingly.
702
+ """
703
+ # Receive data status update information from storage
704
+ while True:
705
+ logger.debug("Preparing _update_data_status...")
706
+ identity, serialized_msg = self.data_status_update_socket.recv_multipart()
707
+ logger.debug("Controller received update_data_status request!")
708
+ request_msg = ZMQMessage.deserialize(serialized_msg)
709
+ logger.debug(f"[{self.controller_id}]: Controller received update_data_status request_msg: {request_msg}")
710
+
711
+ if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
712
+ message_data = request_msg.body
713
+
714
+ fields = message_data.get("fields", [])
715
+ global_indexes = message_data.get("global_indexes", [])
716
+ per_tensor_dtypes = message_data.get("dtypes", {}) # Now a dict of lists
717
+ per_tensor_shapes = message_data.get("shapes", {}) # Now a dict of lists
718
+ # Update data production status
719
+ logger.debug(f"global_indexes, fields: {global_indexes, fields}")
720
+ self._update_production_status(global_indexes, fields)
721
+ self._update_field_info(fields, per_tensor_dtypes, per_tensor_shapes, global_indexes)
722
+ logger.info("Controller updated production status successfully!")
723
+
724
+ # Send acknowledgment response
725
+ response_msg = ZMQMessage.create(
726
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
727
+ sender_id=self.controller_id,
728
+ body={
729
+ "controller_id": self.controller_id,
730
+ "message": f"Data update acknowledged from controller {self.controller_id}",
731
+ },
732
+ )
733
+ self.data_status_update_socket.send_multipart([identity, response_msg.serialize()])
734
+ logger.info("Controller sent DATA_UPDATE_ACK successfully!")
735
+ elif request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE_ERROR:
736
+ # Handle data update errors
737
+ error_msg = request_msg.body.get("message", "Unknown error")
738
+ logger.error(f"Data update error from storage: {error_msg}")
739
+
740
+ # Send error acknowledgment response
741
+ response_msg = ZMQMessage.create(
742
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
743
+ sender_id=self.controller_id,
744
+ body={
745
+ "controller_id": self.controller_id,
746
+ "message": f"Error notification acknowledged from controller {self.controller_id}",
747
+ },
748
+ )
749
+ self.data_status_update_socket.send_multipart([identity, response_msg.serialize()])
750
+
751
    def get_zmq_server_info(self) -> ZMQServerInfo:
        """Get ZMQ server connection information.

        Returns:
            ZMQServerInfo with this controller's role, id, ip and the three
            service ports (handshake / request handling / status updates).
        """
        return self.zmq_server_info
758
+
759
+ def clear(self, global_step: int):
760
+ """Clear data for a specific global batch.
761
+
762
+ Resets production and consumption status for all data in the specified
763
+ global step. Currently only supports clearing single GBS at a time.
764
+
765
+ Args:
766
+ global_step: The global step to clear data for
767
+ """
768
+ start_idx, end_idx = self._step_to_global_index_range(global_step)
769
+
770
+ self.data_production_status[start_idx:end_idx, :] = 0
771
+ for task_name in self.data_consumption_status:
772
+ self.data_consumption_status[task_name][start_idx:end_idx] = 0