TransferQueue 0.1.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. recipe/simple_use_case/async_demo.py +331 -0
  2. recipe/simple_use_case/sync_demo.py +220 -0
  3. tests/test_async_simple_storage_manager.py +339 -0
  4. tests/test_client.py +423 -0
  5. tests/test_controller.py +274 -0
  6. tests/test_controller_data_partitions.py +513 -0
  7. tests/test_kv_storage_manager.py +92 -0
  8. tests/test_put.py +327 -0
  9. tests/test_samplers.py +492 -0
  10. tests/test_serial_utils_on_cpu.py +202 -0
  11. tests/test_simple_storage_unit.py +443 -0
  12. tests/test_storage_client_factory.py +45 -0
  13. transfer_queue/__init__.py +48 -0
  14. transfer_queue/client.py +611 -0
  15. transfer_queue/controller.py +1187 -0
  16. transfer_queue/metadata.py +460 -0
  17. transfer_queue/sampler/__init__.py +19 -0
  18. transfer_queue/sampler/base.py +74 -0
  19. transfer_queue/sampler/grpo_group_n_sampler.py +157 -0
  20. transfer_queue/sampler/sequential_sampler.py +75 -0
  21. transfer_queue/storage/__init__.py +25 -0
  22. transfer_queue/storage/clients/__init__.py +24 -0
  23. transfer_queue/storage/clients/base.py +22 -0
  24. transfer_queue/storage/clients/factory.py +55 -0
  25. transfer_queue/storage/clients/yuanrong_client.py +118 -0
  26. transfer_queue/storage/managers/__init__.py +23 -0
  27. transfer_queue/storage/managers/base.py +460 -0
  28. transfer_queue/storage/managers/factory.py +43 -0
  29. transfer_queue/storage/managers/simple_backend_manager.py +611 -0
  30. transfer_queue/storage/managers/yuanrong_manager.py +18 -0
  31. transfer_queue/storage/simple_backend.py +451 -0
  32. transfer_queue/utils/__init__.py +13 -0
  33. transfer_queue/utils/serial_utils.py +240 -0
  34. transfer_queue/utils/utils.py +132 -0
  35. transfer_queue/utils/zmq_utils.py +170 -0
  36. transfer_queue/version/version +1 -0
  37. transferqueue-0.1.1.dev0.dist-info/METADATA +327 -0
  38. transferqueue-0.1.1.dev0.dist-info/RECORD +41 -0
  39. transferqueue-0.1.1.dev0.dist-info/WHEEL +5 -0
  40. transferqueue-0.1.1.dev0.dist-info/licenses/LICENSE +202 -0
  41. transferqueue-0.1.1.dev0.dist-info/top_level.txt +4 -0
@@ -0,0 +1,1187 @@
1
+ # Copyright 2025 The TransferQueue Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ import time
18
+ from collections import defaultdict
19
+ from dataclasses import dataclass, field
20
+ from threading import Thread
21
+ from typing import Any, Optional
22
+ from uuid import uuid4
23
+
24
+ import ray
25
+ import torch
26
+ import zmq
27
+ from ray.util import get_node_ip_address
28
+
29
+ from transfer_queue.metadata import (
30
+ BatchMeta,
31
+ FieldMeta,
32
+ SampleMeta,
33
+ )
34
+ from transfer_queue.sampler import BaseSampler, SequentialSampler
35
+ from transfer_queue.utils.utils import (
36
+ ProductionStatus,
37
+ TransferQueueRole,
38
+ )
39
+ from transfer_queue.utils.zmq_utils import (
40
+ ZMQMessage,
41
+ ZMQRequestType,
42
+ ZMQServerInfo,
43
+ create_zmq_socket,
44
+ get_free_port,
45
+ )
46
+
47
+ logger = logging.getLogger(__name__)
48
+ logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
49
+
50
+ TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 300))
51
+ TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 1))
52
+ TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL", 2))
53
+
54
+ TQ_INIT_SAMPLE_NUM = int(os.environ.get("TQ_INIT_SAMPLE_NUM", 10)) # Initial number of samples
55
+ TQ_INIT_FIELD_NUM = int(os.environ.get("TQ_INIT_FIELD_NUM", 10))
56
+
57
+ # Expansion configuration - Unified approach using minimum expansion sizes
58
+ TQ_SAMPLE_MIN_EXPANSION_SIZE = int(
59
+ os.environ.get("TQ_SAMPLE_MIN_EXPANSION_SIZE", 10)
60
+ ) # Minimum expansion size for samples (rows)
61
+ TQ_FIELD_MIN_EXPANSION_SIZE = int(
62
+ os.environ.get("TQ_FIELD_MIN_EXPANSION_SIZE", 5)
63
+ ) # Minimum expansion size for fields (columns)
64
+
65
+
66
+ class PartitionIndexManager:
67
+ """
68
+ Manages the mapping relationship between partitions and global indexes,
69
+ responsible for index allocation and reuse.
70
+ """
71
+
72
+ def __init__(self):
73
+ # Records the set of global_indexes used by each partition
74
+ self.partition_to_indexes = defaultdict(set)
75
+
76
+ # Reusable global_index pool - stored using list
77
+ self.reusable_indexes = []
78
+
79
+ # Global index counter for allocating new indexes
80
+ self.global_index_counter = 0
81
+
82
+ # Track all active indexes
83
+ self.allocated_indexes = set()
84
+
85
+ def allocate_indexes(self, partition_id, count=1) -> list:
86
+ """
87
+ Allocate global_indexes for the specified partition.
88
+ Prioritizes obtaining from reusable pool, allocates new indexes when insufficient.
89
+
90
+ Args:
91
+ partition_id: Partition ID
92
+ count: Number of indexes needed
93
+
94
+ Returns:
95
+ list: List of allocated global_indexes
96
+ """
97
+ if count <= 0:
98
+ raise ValueError(f"Number of indexes needed must larger than 0, but got {count}")
99
+ indexes = []
100
+
101
+ # Get indexes from reusable pool
102
+ if self.reusable_indexes:
103
+ # Calculate number of indexes needed from reusable pool
104
+ num_reuse = min(count, len(self.reusable_indexes))
105
+
106
+ # Use slice operation to get multiple elements at once (FIFO principle)
107
+ indexes.extend(self.reusable_indexes[:num_reuse])
108
+ del self.reusable_indexes[:num_reuse]
109
+
110
+ # If reusable pool doesn't have enough indexes, allocate new ones
111
+ if len(indexes) < count:
112
+ # Ensure newly allocated indexes don't conflict with existing ones
113
+ needed = count - len(indexes)
114
+ # Batch allocate consecutive index ranges
115
+ start_index = self.global_index_counter
116
+ end_index = start_index + needed
117
+
118
+ # Directly generate consecutive index list
119
+ new_indexes = list(range(start_index, end_index))
120
+
121
+ # Batch update status
122
+ self.allocated_indexes.update(new_indexes)
123
+ self.global_index_counter = end_index
124
+
125
+ indexes.extend(new_indexes)
126
+
127
+ # Record partition-index relationship
128
+ self.partition_to_indexes[partition_id].update(indexes)
129
+
130
+ return indexes
131
+
132
+ def release_indexes(self, partition_id):
133
+ """
134
+ Release all global_indexes of the specified partition, adding them to reusable pool.
135
+
136
+ Args:
137
+ partition_id: Partition ID
138
+
139
+ Returns:
140
+ list: List of released global_indexes
141
+ """
142
+ if partition_id in self.partition_to_indexes:
143
+ indexes = self.partition_to_indexes.pop(partition_id)
144
+
145
+ # Add released indexes to reusable pool
146
+ self.reusable_indexes.extend(indexes)
147
+
148
+ # Remove these indexes from allocated_indexes
149
+ for idx in indexes:
150
+ self.allocated_indexes.discard(idx)
151
+
152
+ return indexes
153
+ return []
154
+
155
+ def get_indexes_for_partition(self, partition_id):
156
+ """
157
+ Get all global_indexes for the specified partition.
158
+
159
+ Args:
160
+ partition_id: Partition ID
161
+
162
+ Returns:
163
+ set: Set of global_indexes for this partition
164
+ """
165
+ return self.partition_to_indexes.get(partition_id, set()).copy()
166
+
167
+
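+ # Usage sketch for the index manager (illustrative; partition names are hypothetical).
+ # Released indexes are drawn from the reuse pool before new ids are allocated:
+ #
+ #     mgr = PartitionIndexManager()
+ #     mgr.allocate_indexes("train@global_batch_0", count=4)   # -> [0, 1, 2, 3]
+ #     mgr.release_indexes("train@global_batch_0")             # ids return to the reuse pool
+ #     mgr.allocate_indexes("train@global_batch_1", count=6)   # reuses the 4 released ids, then allocates 2 new ones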
168
+ @dataclass
169
+ class DataPartitionStatus:
170
+ """
171
+ Robust status information for a data partition with dynamic expansion support.
172
+
173
+ This class tracks the production and consumption status of data within a specific
174
+ partition (e.g., "train@global_batch_0", "inference@kv_cache_1") with full support
175
+ for dynamic row and column expansion.
176
+ """
177
+
178
+ partition_id: str
179
+ created_at: float = field(default_factory=time.time)
180
+
181
+ # Production status tensor - dynamically expandable
182
+ # Values: 0 = not produced, 1 = ready for consumption
183
+ production_status: Optional[torch.Tensor] = field(default_factory=lambda: torch.zeros(TQ_INIT_SAMPLE_NUM, TQ_INIT_FIELD_NUM, dtype=torch.int8))
184
+
185
+ # Consumption status per task - task_name -> consumption_tensor
186
+ # Each tensor tracks which samples have been consumed by that task
187
+ consumption_status: dict[str, torch.Tensor] = field(default_factory=dict)
188
+
189
+ # Field metadata
190
+ field_name_mapping: dict[str, int] = field(default_factory=dict) # field_name -> column_index
191
+ field_dtypes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: dtype}
192
+ field_shapes: dict[int, dict[str, Any]] = field(default_factory=dict) # global_idx -> {field: shape}
193
+
194
+ # Dynamic configuration - these are computed from the current state
195
+ @property
196
+ def total_samples_num(self) -> int:
197
+ """Current number of samples (rows) in the partition."""
198
+ return self.production_status.shape[0] if self.production_status is not None else 0
199
+
200
+ @property
201
+ def total_fields_num(self) -> int:
202
+ """Current number of fields (columns) in the partition."""
203
+ return len(self.field_name_mapping)
204
+
205
+ @property
206
+ def allocated_fields_num(self) -> int:
207
+ """Current number of allocated columns in the tensor."""
208
+ return self.production_status.shape[1] if self.production_status is not None else 0
209
+
210
+ @property
211
+ def allocated_samples_num(self) -> int:
212
+ """Current number of allocated rows in the tensor."""
213
+ return self.production_status.shape[0] if self.production_status is not None else 0
214
+
215
+ # ==================== Dynamic Expansion Methods ====================
216
+
217
+ def ensure_samples_capacity(self, required_samples: int) -> None:
218
+ """
219
+ Ensure the production status tensor has enough rows for the required samples.
220
+ Dynamically expands if needed using unified minimum expansion size.
221
+
222
+ Args:
223
+ required_samples: Minimum number of samples needed
224
+ """
225
+ current_samples = self.production_status.shape[0]
226
+ if required_samples > current_samples:
227
+ # Expand rows using minimum expansion size for predictable memory usage
228
+ expansion_needed = required_samples - current_samples
229
+ min_expansion = max(TQ_SAMPLE_MIN_EXPANSION_SIZE, expansion_needed)
230
+ new_samples = current_samples + min_expansion
231
+ new_fields = self.production_status.shape[1]
232
+
233
+ expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
234
+ expanded_tensor[:current_samples, :] = self.production_status
235
+ self.production_status = expanded_tensor
236
+
237
+ # Update consumption tensors for all tasks
238
+ for task_name, consumption_tensor in self.consumption_status.items():
239
+ expanded_consumption = torch.zeros(new_samples, dtype=torch.int8)
240
+ expanded_consumption[:current_samples] = consumption_tensor
241
+ self.consumption_status[task_name] = expanded_consumption
242
+
243
+ logger.debug(
244
+ f"Expanded partition {self.partition_id} from {current_samples} to {new_samples} samples "
245
+ f"(added {min_expansion} samples)"
246
+ )
247
+
248
+ def ensure_fields_capacity(self, required_fields: int) -> None:
249
+ """
250
+ Ensure the production status tensor has enough columns for the required fields.
251
+ Dynamically expands if needed using unified minimum expansion size.
252
+
253
+ Args:
254
+ required_fields: Minimum number of fields needed
255
+ """
256
+ if self.production_status is None:
257
+ # Will be initialized when samples are added
258
+ return
259
+
260
+ current_fields = self.production_status.shape[1]
261
+ if required_fields > current_fields:
262
+ # Expand columns using minimum expansion size for predictable memory usage
263
+ expansion_needed = required_fields - current_fields
264
+ min_expansion = max(TQ_FIELD_MIN_EXPANSION_SIZE, expansion_needed)
265
+ new_fields = current_fields + min_expansion
266
+ new_samples = self.production_status.shape[0]
267
+
268
+ expanded_tensor = torch.zeros(new_samples, new_fields, dtype=torch.int8)
269
+ expanded_tensor[:, :current_fields] = self.production_status
270
+ self.production_status = expanded_tensor
271
+
272
+ logger.debug(
273
+ f"Expanded partition {self.partition_id} from {current_fields} to {new_fields} fields "
274
+ f"(added {min_expansion} fields)"
275
+ )
276
+
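+ # Worked example of the expansion rule above, starting from the default
+ # 10-row tensor with TQ_SAMPLE_MIN_EXPANSION_SIZE = 10:
+ #
+ #     ensure_samples_capacity(12)  # needs 2 more rows -> expands by max(10, 2) = 10 -> 20 rows
+ #     ensure_samples_capacity(35)  # now needs 15 more -> expands by max(10, 15) = 15 -> 35 rows
+ #
+ # ensure_fields_capacity applies the same rule to columns with TQ_FIELD_MIN_EXPANSION_SIZE.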
277
+ # ==================== Production Status Interface ====================
278
+
279
+ def update_production_status(
280
+ self,
281
+ global_indices: list[int],
282
+ field_names: list[str],
283
+ dtypes: Optional[dict[int, dict[str, Any]]] = None,
284
+ shapes: Optional[dict[int, dict[str, Any]]] = None,
285
+ ) -> bool:
286
+ """
287
+ Update production status for specific samples and fields.
288
+ Handles dynamic expansion of both samples and fields.
289
+
290
+ Args:
291
+ global_indices: List of sample indices to update
292
+ field_names: List of field names to mark as produced
293
+ dtypes: Optional per-sample field dtype information
294
+ shapes: Optional per-sample field shape information
295
+
296
+ Returns:
297
+ True if update was successful, False on error
298
+ """
299
+ try:
300
+ # Determine required capacity
301
+ max_sample_idx = max(global_indices) if global_indices else -1
302
+ required_samples = max_sample_idx + 1
303
+
304
+ # Ensure we have enough rows
305
+ self.ensure_samples_capacity(required_samples)
306
+
307
+ # Register new fields if needed
308
+ new_fields = [field for field in field_names if field not in self.field_name_mapping]
309
+ if new_fields:
310
+ # Add new fields to mapping
311
+ for field in new_fields:
312
+ self.field_name_mapping[field] = len(self.field_name_mapping)
313
+
314
+ required_fields = len(self.field_name_mapping)
315
+ self.ensure_fields_capacity(required_fields)
316
+
317
+ # Update production status
318
+ if self.production_status is not None and global_indices and field_names:
319
+ field_indices = [self.field_name_mapping.get(field) for field in field_names]
320
+ self.production_status[torch.tensor(global_indices)[:, None], torch.tensor(field_indices)] = 1
321
+
322
+ # Update field metadata
323
+ self._update_field_metadata(global_indices, field_names, dtypes, shapes)
324
+
325
+ return True
326
+
327
+ except Exception as e:
328
+ logger.error(f"Error updating production status for partition {self.partition_id}: {e}")
329
+ return False
330
+
331
+ def _update_field_metadata(
332
+ self,
333
+ global_indices: list[int],
334
+ field_names: list[str],
335
+ dtypes: Optional[dict[int, dict[str, Any]]] = None,
336
+ shapes: Optional[dict[int, dict[str, Any]]] = None,
337
+ ):
338
+ """Update field dtype and shape metadata."""
339
+ for global_idx in global_indices:
340
+ if global_idx not in self.field_dtypes:
341
+ self.field_dtypes[global_idx] = {}
342
+ if global_idx not in self.field_shapes:
343
+ self.field_shapes[global_idx] = {}
344
+
345
+ for field_name in field_names:
346
+ if dtypes and global_idx in dtypes and field_name in dtypes[global_idx]:
347
+ self.field_dtypes[global_idx][field_name] = dtypes[global_idx][field_name]
348
+ if shapes and global_idx in shapes and field_name in shapes[global_idx]:
349
+ self.field_shapes[global_idx][field_name] = shapes[global_idx][field_name]
350
+
351
+ # ==================== Consumption Status Interface ====================
352
+
353
+ def get_consumption_status(self, task_name: str) -> torch.Tensor:
354
+ """
355
+ Get or create consumption status for a specific task.
356
+ Handles dynamic expansion when new samples are added.
357
+
358
+ Args:
359
+ task_name: Name of the consumer task
360
+
361
+ Returns:
362
+ Consumption status tensor for the specified task
363
+ """
364
+ if task_name not in self.consumption_status:
365
+ if self.production_status is not None:
366
+ self.consumption_status[task_name] = torch.zeros(self.total_samples_num, dtype=torch.int8)
367
+ else:
368
+ self.consumption_status[task_name] = torch.zeros(0, dtype=torch.int8)
369
+
370
+ return self.consumption_status[task_name]
371
+
372
+ # TODO: No need to return a bool here; just raise an error instead. Same for the other status-update methods.
373
+ def mark_consumed(self, task_name: str, global_indices: list[int]) -> bool:
374
+ """
375
+ Mark specific samples as consumed by a task.
376
+
377
+ Args:
378
+ task_name: Name of the consumer task
379
+ global_indices: List of sample indices to mark as consumed
380
+
381
+ Returns:
382
+ True if successful, False on error
383
+ """
384
+ try:
385
+ consumption_status = self.get_consumption_status(task_name)
386
+ if consumption_status.numel() > 0 and global_indices:
387
+ consumption_status[global_indices] = 1
388
+ return True
389
+ except Exception as e:
390
+ logger.error(f"Error marking samples consumed for partition {self.partition_id}, task {task_name}: {e}")
391
+ return False
392
+
393
+ # ==================== Data Scanning and Query Methods ====================
394
+
395
+ def scan_data_status(self, field_names: list[str], task_name: str) -> list[int]:
396
+ """
397
+ Scan data status to find samples ready for consumption.
398
+ This replaces the original _scan_data_status functionality.
399
+
400
+ Args:
401
+ field_names: List of required field names
402
+ task_name: Name of the consumer task
403
+
404
+ Returns:
405
+ List of sample indices that are ready for consumption
406
+ """
407
+ if self.production_status is None:
408
+ return []
409
+
410
+ # Check if all requested fields are registered
411
+ for field_name in field_names:
412
+ if field_name not in self.field_name_mapping:
413
+ return []
414
+
415
+ row_mask = torch.ones(self.total_samples_num, dtype=torch.bool)
416
+
417
+ # Apply consumption filter (exclude already consumed samples)
418
+ consumption_status = self.get_consumption_status(task_name)
419
+ if consumption_status is not None:
420
+ unconsumed_mask = consumption_status == 0
421
+ row_mask &= unconsumed_mask
422
+
423
+ # Create column mask for requested fields
424
+ col_mask = torch.zeros(self.allocated_fields_num, dtype=torch.bool)
425
+ field_indices = [self.field_name_mapping[field] for field in field_names]
426
+ if field_indices:
427
+ col_mask[field_indices] = True
428
+
429
+ # Filter production status by masks
430
+ relevant_status = self.production_status[row_mask][:, col_mask]
431
+
432
+ # Check if all required fields are ready for each sample
433
+ all_fields_ready = torch.all(relevant_status, dim=1)
434
+ ready_indices_in_filtered = torch.nonzero(all_fields_ready, as_tuple=False).flatten()
435
+
436
+ # Map back to original sample indices
437
+ all_indices = torch.where(row_mask)[0]
438
+ ready_sample_indices = all_indices[ready_indices_in_filtered].tolist()
439
+
440
+ return ready_sample_indices
441
+
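+ # Worked example of the masking above (hypothetical fields/task): with
+ # field_name_mapping = {"prompt": 0, "response": 1}, production_status rows
+ # [[1, 1], [1, 0], [1, 1]] and task "train" having consumed sample 0,
+ # scan_data_status(["prompt", "response"], "train") keeps the unconsumed
+ # rows 1 and 2, requires both requested columns to be 1, and returns [2].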
442
+ # ==================== Field Metadata Methods ====================
443
+
444
+ def get_field_dtype(self, global_index: int, field_name: str) -> Optional[Any]:
445
+ """Get dtype for a specific sample and field."""
446
+ return self.field_dtypes.get(global_index, {}).get(field_name)
447
+
448
+ def get_field_shape(self, global_index: int, field_name: str) -> Optional[Any]:
449
+ """Get shape for a specific sample and field."""
450
+ return self.field_shapes.get(global_index, {}).get(field_name)
451
+
452
+ # ==================== Statistics and Monitoring ====================
453
+
454
+ def get_statistics(self) -> dict[str, Any]:
455
+ """Get detailed statistics for this partition."""
456
+ stats = {
457
+ "partition_id": self.partition_id,
458
+ "created_at": self.created_at,
459
+ "total_samples_num": self.total_samples_num,
460
+ "total_fields_num": self.total_fields_num,
461
+ "allocated_fields_num": self.allocated_fields_num,
462
+ "registered_tasks": list(self.consumption_status.keys()),
463
+ }
464
+
465
+ if self.production_status is not None:
466
+ produced_samples = torch.any(self.production_status == 1, dim=1).sum().item()
467
+ stats["produced_samples"] = produced_samples
468
+ stats["production_progress"] = (
469
+ produced_samples / self.total_samples_num if self.total_samples_num > 0 else 0
470
+ )
471
+
472
+ # Field-wise production statistics
473
+ field_stats = {}
474
+ for field_name, field_idx in self.field_name_mapping.items():
475
+ field_produced = (self.production_status[:, field_idx] == 1).sum().item()
476
+ field_stats[field_name] = {
477
+ "produced_samples": field_produced,
478
+ "production_progress": field_produced / self.total_samples_num if self.total_samples_num > 0 else 0,
479
+ }
480
+ stats["field_statistics"] = field_stats
481
+
482
+ # Consumption statistics per task
483
+ consumption_stats = {}
484
+ for task_name, consumption_tensor in self.consumption_status.items():
485
+ consumed_samples = (consumption_tensor == 1).sum().item()
486
+ consumption_stats[task_name] = {
487
+ "consumed_samples": consumed_samples,
488
+ "consumption_progress": consumed_samples / self.total_samples_num if self.total_samples_num > 0 else 0,
489
+ }
490
+ stats["consumption_statistics"] = consumption_stats
491
+
492
+ return stats
493
+
494
+ def clear_data(self, global_indexes_range: list[int], clear_consumption: bool = True) -> bool:
495
+ """Clear all production and optionally consumption data."""
496
+ try:
497
+ if self.production_status is not None:
498
+ self.production_status[global_indexes_range, :] = 0
499
+
500
+ if clear_consumption:
501
+ for consumption_tensor in self.consumption_status.values():
502
+ consumption_tensor[global_indexes_range] = 0
503
+
504
+ return True
505
+ except Exception as e:
506
+ logger.error(f"Error clearing data for partition {self.partition_id}: {e}")
507
+ return False
508
+
509
+
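+ # Minimal production/consumption sketch for DataPartitionStatus
+ # (field and task names are hypothetical):
+ #
+ #     part = DataPartitionStatus(partition_id="train@global_batch_0")
+ #     part.update_production_status([0, 1], ["prompt"])         # mark "prompt" produced for samples 0 and 1
+ #     part.scan_data_status(["prompt"], task_name="generate")   # -> [0, 1]
+ #     part.mark_consumed("generate", [0, 1])
+ #     part.scan_data_status(["prompt"], task_name="generate")   # -> [] (already consumed)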
510
+ @ray.remote(num_cpus=1)
511
+ class TransferQueueController:
512
+ """
513
+ TransferQueue Controller with partition-based data management.
514
+
515
+ This refactored controller manages data through dynamic partitions instead of
516
+ fixed global batches. Each partition represents a logical data container
517
+ (e.g., "train@global_batch_0", "inference@kv_cache_1") that can be created
518
+ on-demand and managed independently.
519
+
520
+ Key improvements:
521
+ - Dynamic partition creation on-demand
522
+ - No dependency on training-specific parameters (global_batch_size, etc.)
523
+ - Support for diverse use cases (KV cache migration, model resharding, etc.)
524
+ - Flexible data organization through partition-based addressing
525
+ """
526
+
527
+ def __init__(self, sampler: BaseSampler | type[BaseSampler] = SequentialSampler) -> None:
528
+ """Initialize the TransferQueue Controller.
529
+
530
+ Args:
531
+ sampler: Sampler instance or sampler class to use for data sampling.
532
+ - If a BaseSampler instance is provided, it will be used directly
533
+ - If a BaseSampler subclass is provided, it will be instantiated
534
+ - Defaults to SequentialSampler for simple sequential sampling
535
+ - Example: sampler=GRPOGroupNSampler() (instance)
536
+ - Example: sampler=GRPOGroupNSampler (class)
537
+ """
538
+ if isinstance(sampler, BaseSampler):
539
+ self.sampler = sampler
540
+ elif isinstance(sampler, type) and issubclass(sampler, BaseSampler):
541
+ self.sampler = sampler()
542
+ else:
543
+ raise TypeError(
544
+ f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
545
+ )
546
+
547
+ self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}"
548
+
549
+ # Initialize ZMQ sockets for communication
550
+ self._init_zmq_socket()
551
+
552
+ # Partition management
553
+ self.partitions: dict[str, DataPartitionStatus] = {} # partition_id -> DataPartitionStatus
554
+
555
+ # Partition-GlobalIndex management
556
+ self.index_manager = PartitionIndexManager() # partition_id -> global_indexes
557
+
558
+ # Connected storage managers tracking
559
+ self._connected_storage_managers: set[str] = set()
560
+
561
+ # Start background processing threads
562
+ self._start_process_handshake()
563
+ self._start_process_update_data_status()
564
+ self._start_process_request()
565
+
566
+ logger.info(f"TransferQueue Controller {self.controller_id} initialized")
567
+
568
+ # ==================== Partition Management API ====================
569
+
570
+ def create_partition(self, partition_id: str) -> bool:
571
+ """
572
+ Create a new data partition.
573
+
574
+ Note: Partitions now dynamically expand as needed, so initial capacity is not required.
575
+
576
+ Args:
577
+ partition_id: Unique identifier for the partition (e.g., "train@global_batch_0")
578
+
579
+ Returns:
580
+ True if partition was created successfully, False if it already exists
581
+ """
582
+ if partition_id in self.partitions:
583
+ logger.warning(f"Partition {partition_id} already exists")
584
+ return False
585
+
586
+ self.partitions[partition_id] = DataPartitionStatus(partition_id=partition_id)
587
+
588
+ logger.info(f"Created partition {partition_id} with dynamic capacity")
589
+ return True
590
+
591
+ def get_partition(self, partition_id: str) -> Optional[DataPartitionStatus]:
592
+ """
593
+ Get partition status information.
594
+
595
+ Args:
596
+ partition_id: ID of the partition to retrieve
597
+
598
+ Returns:
599
+ DataPartitionStatus object if partition exists, None otherwise
600
+ """
601
+ return self.partitions.get(partition_id)
602
+
603
+ def list_partitions(self) -> list[str]:
604
+ """
605
+ List all available partition IDs.
606
+
607
+ Returns:
608
+ List of partition IDs
609
+ """
610
+ return list(self.partitions.keys())
611
+
612
+ def delete_partition(self, partition_id: str) -> bool:
613
+ """
614
+ Delete a partition and all its data.
615
+
616
+ Args:
617
+ partition_id: ID of the partition to delete
618
+
619
+ Returns:
620
+ True if partition was deleted, False if it didn't exist
621
+ """
622
+ if partition_id in self.partitions:
623
+ del self.partitions[partition_id]
624
+ logger.info(f"Deleted partition {partition_id}")
625
+ return True
626
+ return False
627
+
628
+ # ==================== Partition Index Management API ====================
629
+
630
+ def get_partition_index_range(self, partition_id: str) -> set:
631
+ """
632
+ Get all indexes for a specific partition.
633
+
634
+ Args:
635
+ partition_id: Partition identifier
636
+
637
+ Returns:
638
+ Set of indexes allocated to the partition
639
+ """
640
+ return self.index_manager.get_indexes_for_partition(partition_id)
641
+
642
+ # ==================== Data Production API ====================
643
+
644
+ # TODO: Modify dtypes & shapes to be required
645
+ def update_production_status(
646
+ self,
647
+ partition_id: str,
648
+ global_indexes: list[int],
649
+ field_names: list[str],
650
+ dtypes: Optional[dict[int, dict[str, Any]]] = None,
651
+ shapes: Optional[dict[int, dict[str, Any]]] = None,
652
+ ) -> bool:
653
+ """
654
+ Update production status for specific samples and fields in a partition.
655
+ Delegates to the partition's own update_production_status method.
656
+
657
+ Args:
658
+ partition_id: ID of the partition
659
+ global_indexes: List of sample indices to update
660
+ field_names: List of field names to mark as produced
661
+ dtypes: Optional per-sample field dtype information
662
+ shapes: Optional per-sample field shape information
663
+
664
+ Returns:
665
+ True if update was successful, False otherwise
666
+ """
667
+ partition = self.get_partition(partition_id)
668
+ if not partition:
669
+ logger.error(f"Partition {partition_id} not found")
670
+ return False
671
+
672
+ success = partition.update_production_status(global_indexes, field_names, dtypes, shapes)
673
+ if success:
674
+ logger.debug(
675
+ f"Updated production status for partition {partition_id}: samples={global_indexes}, "
676
+ f"fields={field_names}"
677
+ )
678
+ return success
679
+
680
+ # ==================== Data Consumption API ====================
681
+
682
+ def get_consumption_status(self, partition_id: str, task_name: str) -> Optional[torch.Tensor]:
683
+ """
684
+ Get or create consumption status for a specific task and partition.
685
+ Delegates to the partition's own method.
686
+
687
+ Args:
688
+ partition_id: ID of the partition
689
+ task_name: Name of the consumer task
690
+
691
+ Returns:
692
+ Consumption status tensor if partition exists, None otherwise
693
+ """
694
+ partition = self.get_partition(partition_id)
695
+ if not partition:
696
+ return None
697
+
698
+ return partition.get_consumption_status(task_name)
699
+
700
+ def get_metadata(
701
+ self,
702
+ data_fields: list[str],
703
+ partition_id: str,
704
+ mode: str = "fetch",
705
+ task_name: str | None = None,
706
+ batch_size: int | None = None,
707
+ sampling_config: Optional[dict[str, Any]] = None,
708
+ *args,
709
+ **kwargs,
710
+ ) -> BatchMeta:
711
+ """
712
+ Retrieve metadata with support for three modes.
713
+
714
+ Args:
715
+ data_fields: List of field names to include in metadata
716
+ partition_id: Partition id for which to retrieve metadata
717
+ mode: Operation mode - 'insert', 'fetch', or 'force_fetch'
718
+ - mode="insert": Create metadata for new samples (for data insertion)
719
+ - mode="fetch": Get metadata from ready samples using the configured sampler
720
+ - mode="force_fetch": Get metadata for unconsumed samples without sampling
721
+ (excludes already consumed samples)
722
+ task_name: Name of the consumer task (required for fetch modes)
723
+ batch_size: Number of samples to retrieve
+ sampling_config: Optional keyword arguments forwarded to the sampler in fetch mode
724
+ *args: Additional positional arguments
725
+ **kwargs: Additional keyword arguments
726
+
727
+ Returns:
728
+ BatchMeta object containing the requested metadata
729
+
730
+ Raises:
731
+ TimeoutError: If waiting for sufficient data times out in fetch mode
732
+ """
733
+ if partition_id not in self.partitions:
734
+ self.create_partition(partition_id)
735
+
736
+ if mode == "insert":
737
+ if data_fields:
738
+ # First put_data call, get_metadata in insert mode
739
+ batch_global_indexes = self.index_manager.allocate_indexes(partition_id, count=batch_size)
740
+ else:
741
+ # clear metadata call passes empty data_fields
742
+ batch_global_indexes = self.index_manager.get_indexes_for_partition(partition_id)
743
+ return self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
744
+
745
+ assert task_name is not None
746
+ if mode == "fetch":
747
+ # Find ready samples within current data partition and package into BatchMeta when reading
748
+
749
+ start_time = time.time()
750
+ while True:
751
+ # ready_for_consume_indexes: samples where all required fields are produced
752
+ # (production status is ready) and not yet consumed
753
+ ready_for_consume_indexes = self.scan_data_status(partition_id, data_fields, task_name, batch_size)
754
+
755
+ if len(ready_for_consume_indexes) < batch_size:
756
+ continue
757
+
758
+ # Try sampling - if it returns empty lists, retry
759
+ batch_global_indexes, consumed_indexes = self.sampler(
760
+ ready_for_consume_indexes,
761
+ batch_size,
762
+ **(sampling_config or {}),
763
+ )
764
+
765
+ # Check if we got valid results from the sampler
766
+ if len(batch_global_indexes) == batch_size:
767
+ break
768
+
769
+ if time.time() - start_time > TQ_CONTROLLER_GET_METADATA_TIMEOUT:
770
+ raise TimeoutError(
771
+ f"Timeout while waiting for sufficient data. "
772
+ f"Required: {batch_size}, Available: {len(ready_for_consume_indexes)}, "
773
+ f"Sampled: {len(batch_global_indexes)}"
774
+ )
775
+
776
+ logger.warning(
777
+ f"Insufficient complete groups available. Required: {batch_size}, "
778
+ f"Available: {len(ready_for_consume_indexes)}, "
779
+ f"Sampled: {len(batch_global_indexes)}. Retrying in "
780
+ f"{TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
781
+ )
782
+ time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
783
+ logger.debug(f"ready for consume idx: {ready_for_consume_indexes}")
784
+ logger.debug(f"sampled idx: {batch_global_indexes}")
785
+ elif mode == "force_fetch":
786
+ global_indexes_range = self.index_manager.get_indexes_for_partition(partition_id)
787
+ consumer_status = self.get_consumption_status(partition_id, task_name)
788
+ not_consumed_idx = [i for i in global_indexes_range if consumer_status[i] == 0]
789
+ batch_global_indexes = not_consumed_idx
790
+ consumed_indexes = []
791
+
792
+ # Package into metadata
793
+ metadata = self.generate_batch_meta(partition_id, batch_global_indexes, data_fields, mode)
794
+
795
+ # Mark samples as consumed if in fetch mode
796
+ if mode == "fetch" and consumed_indexes:
797
+ partition = self.partitions[partition_id]
798
+ partition.mark_consumed(task_name, consumed_indexes)
799
+
800
+ logger.debug(f"get_metadata: {metadata}")
801
+
802
+ return metadata
803
+
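+ # Illustrative call patterns for the three modes (field/task names are
+ # hypothetical; remote callers reach this via GET_META requests, inside the
+ # actor these are plain method calls):
+ #
+ #     # reserve metadata for 4 new samples before writing data
+ #     meta = self.get_metadata(["prompt"], "train@global_batch_0", mode="insert", batch_size=4)
+ #
+ #     # block until 4 samples have all requested fields produced, then sample them
+ #     meta = self.get_metadata(["prompt", "response"], "train@global_batch_0",
+ #                              mode="fetch", task_name="train", batch_size=4)
+ #
+ #     # return all still-unconsumed samples without sampling or blocking
+ #     meta = self.get_metadata(["response"], "train@global_batch_0",
+ #                              mode="force_fetch", task_name="train")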
804
+ def scan_data_status(
805
+ self,
806
+ partition_id: str,
807
+ data_fields: list[str],
808
+ task_name: str,
809
+ batch_size: int,
810
+ sample_filter: Optional[list[int]] = None,
811
+ timeout: float = TQ_CONTROLLER_GET_METADATA_TIMEOUT,
812
+ ) -> list[int]:
813
+ """
814
+ Find samples that are ready for consumption in a specific partition.
815
+ Delegates scanning functionality to the partition's own method.
816
+
817
+ Args:
818
+ partition_id: ID of the partition
819
+ data_fields: List of required field names
820
+ task_name: Name of the consumer task
821
+ batch_size: Number of samples needed
822
+ sample_filter: Optional list of specific sample indices to consider
823
+ timeout: Maximum time to wait for sufficient data
824
+
825
+ Returns:
826
+ List of sample indices that are ready for consumption
827
+
828
+ Raises:
829
+ TimeoutError: If sufficient data is not available within timeout
830
+ """
831
+ start_time = time.time()
832
+
833
+ while True:
834
+ partition = self.get_partition(partition_id)
835
+ if not partition:
836
+ if time.time() - start_time > timeout:
837
+ raise TimeoutError(f"Partition {partition_id} not found")
838
+ time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
839
+ continue
840
+
841
+ # Use partition's own scanning method
842
+ ready_sample_indices = partition.scan_data_status(data_fields, task_name)
843
+
844
+ if len(ready_sample_indices) >= batch_size:
845
+ return ready_sample_indices[:batch_size]
846
+
847
+ if time.time() - start_time > timeout:
848
+ raise TimeoutError(
849
+ f"Timeout waiting for sufficient data in partition {partition_id}. "
850
+ f"Required: {batch_size}, Available: {len(ready_sample_indices)}"
851
+ )
852
+
853
+ logger.warning(
854
+ f"Insufficient data in partition {partition_id}. Required: {batch_size}, "
855
+ f"Available: {len(ready_sample_indices)}. Retrying in {TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL}s..."
856
+ )
857
+ time.sleep(TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL)
858
+
859
+ # ==================== Metadata Generation API ====================
860
+
861
+ def generate_batch_meta(
862
+ self,
863
+ partition_id: str,
864
+ batch_global_indexes: list[int],
865
+ data_fields: list[str],
866
+ mode: str = "fetch",
867
+ ) -> BatchMeta:
868
+ """
869
+ Generate BatchMeta for specific samples in a partition.
870
+
871
+ This function is responsible only for metadata generation and does not
872
+ modify consumption state. State management is handled by the calling function.
873
+
874
+ Args:
875
+ partition_id: ID of the partition
876
+ batch_global_indexes: List of sample indices to include in the batch
877
+ data_fields: List of field names to include
878
+ mode: Operation mode - 'fetch', 'insert', or 'force_fetch'
879
+
880
+ Returns:
881
+ BatchMeta object containing sample metadata
882
+
883
+ Raises:
884
+ ValueError: If partition doesn't exist or invalid mode
885
+ """
886
+ partition = self.get_partition(partition_id)
887
+ if not partition:
888
+ raise ValueError(f"Partition {partition_id} not found")
889
+
890
+ if mode not in ["fetch", "insert", "force_fetch"]:
891
+ raise ValueError(f"Invalid mode: {mode}")
892
+
893
+ # Generate sample metadata
894
+ samples = []
895
+ for global_index in batch_global_indexes:
896
+ fields = {}
897
+ for field_name in data_fields:
898
+ # Determine production status
899
+ if mode == "fetch":
900
+ production_status = ProductionStatus.READY_FOR_CONSUME
901
+ dtype = partition.get_field_dtype(global_index, field_name)
902
+ shape = partition.get_field_shape(global_index, field_name)
903
+ elif mode == "insert":
904
+ production_status = ProductionStatus.NOT_PRODUCED
905
+ dtype = None
906
+ shape = None
907
+ elif mode == "force_fetch":
908
+ field_index = partition.field_name_mapping.get(field_name)
909
+ if (
910
+ field_index is not None
911
+ and partition.production_status is not None
912
+ and partition.production_status[global_index, field_index] == 1
913
+ ):
914
+ production_status = ProductionStatus.READY_FOR_CONSUME
915
+ dtype = partition.get_field_dtype(global_index, field_name)
916
+ shape = partition.get_field_shape(global_index, field_name)
917
+ else:
918
+ production_status = ProductionStatus.NOT_PRODUCED
919
+ dtype = None
920
+ shape = None
921
+
922
+ fields[field_name] = FieldMeta(
923
+ name=field_name,
924
+ dtype=dtype,
925
+ shape=shape,
926
+ production_status=production_status,
927
+ )
928
+
929
+ sample = SampleMeta(
930
+ partition_id=partition_id,
931
+ global_index=global_index,
932
+ fields=fields,
933
+ )
934
+ samples.append(sample)
935
+
936
+ return BatchMeta(samples=samples)
937
+
938
+ def clear(self, partition_id: str, clear_consumption: bool = True) -> bool:
939
+ """
940
+ Clear data for a specific partition.
941
+
942
+ Args:
943
+ partition_id: ID of the partition to clear
944
+ clear_consumption: Whether to also clear consumption status
945
+
946
+ Returns:
947
+ True if cleared successfully, False otherwise
948
+ """
949
+ partition = self.get_partition(partition_id)
950
+ if not partition:
951
+ raise ValueError(f"Partition {partition_id} not found")
952
+
953
+ global_indexes_range = list(self.index_manager.get_indexes_for_partition(partition_id))
954
+ success = partition.clear_data(global_indexes_range, clear_consumption)
955
+ self.index_manager.release_indexes(partition_id)
956
+ if success:
957
+ logger.info(f"Cleared data for partition {partition_id}")
958
+ return success
959
+
960
+ def _init_zmq_socket(self):
961
+ """Initialize ZMQ sockets for communication."""
962
+ self.zmq_context = zmq.Context()
963
+ self._node_ip = get_node_ip_address()
964
+ self._handshake_socket_port = get_free_port()
965
+ self._request_handle_socket_port = get_free_port()
966
+ self._data_status_update_socket_port = get_free_port()
967
+
968
+ self.handshake_socket = create_zmq_socket(
969
+ ctx=self.zmq_context,
970
+ socket_type=zmq.ROUTER,
971
+ )
972
+ self.handshake_socket.bind(f"tcp://{self._node_ip}:{self._handshake_socket_port}")
973
+
974
+ self.request_handle_socket = create_zmq_socket(
975
+ ctx=self.zmq_context,
976
+ socket_type=zmq.ROUTER,
977
+ )
978
+ self.request_handle_socket.bind(f"tcp://{self._node_ip}:{self._request_handle_socket_port}")
979
+
980
+ self.data_status_update_socket = create_zmq_socket(
981
+ ctx=self.zmq_context,
982
+ socket_type=zmq.ROUTER,
983
+ )
984
+ self.data_status_update_socket.bind(f"tcp://{self._node_ip}:{self._data_status_update_socket_port}")
985
+
986
+ self.zmq_server_info = ZMQServerInfo(
987
+ role=TransferQueueRole.CONTROLLER,
988
+ id=self.controller_id,
989
+ ip=self._node_ip,
990
+ ports={
991
+ "handshake_socket": self._handshake_socket_port,
992
+ "request_handle_socket": self._request_handle_socket_port,
993
+ "data_status_update_socket": self._data_status_update_socket_port,
994
+ },
995
+ )
996
+
997
+ def _wait_connection(self):
998
+ """Wait for storage instances to complete handshake with retransmission support."""
999
+ poller = zmq.Poller()
1000
+ poller.register(self.handshake_socket, zmq.POLLIN)
1001
+
1002
+ logger.info(f"Dynamic Controller {self.controller_id} started waiting for storage connections...")
1003
+
1004
+ while True:
1005
+ socks = dict(poller.poll(TQ_CONTROLLER_CONNECTION_CHECK_INTERVAL * 1000))
1006
+
1007
+ if self.handshake_socket in socks:
1008
+ try:
1009
+ identity, serialized_msg = self.handshake_socket.recv_multipart()
1010
+ request_msg = ZMQMessage.deserialize(serialized_msg)
1011
+
1012
+ if request_msg.request_type == ZMQRequestType.HANDSHAKE:
1013
+ storage_manager_id = request_msg.sender_id
1014
+
1015
+ # Always send ACK for HANDSHAKE
1016
+ response_msg = ZMQMessage.create(
1017
+ request_type=ZMQRequestType.HANDSHAKE_ACK,
1018
+ sender_id=self.controller_id,
1019
+ body={},
1020
+ ).serialize()
1021
+ self.handshake_socket.send_multipart([identity, response_msg])
1022
+
1023
+ # Track new connections
1024
+ if storage_manager_id not in self._connected_storage_managers:
1025
+ self._connected_storage_managers.add(storage_manager_id)
1026
+ storage_manager_type = request_msg.body.get("storage_manager_type", "Unknown")
1027
+ logger.info(
1028
+ f"Dynamic Controller {self.controller_id} received handshake from "
1029
+ f"storage manager {storage_manager_id} (type: {storage_manager_type}). "
1030
+ f"Total connected: {len(self._connected_storage_managers)}"
1031
+ )
1032
+ else:
1033
+ logger.debug(
1034
+ f"Dynamic Controller {self.controller_id} received duplicate handshake from "
1035
+ f"storage manager {storage_manager_id}. Resending ACK."
1036
+ )
1037
+
1038
+ except Exception as e:
1039
+ logger.error(f"Dynamic Controller {self.controller_id} error processing handshake: {e}")
1040
+
1041
+ def _start_process_handshake(self):
1042
+ """Start the handshake process thread."""
1043
+ self.wait_connection_thread = Thread(
1044
+ target=self._wait_connection, name="DynamicTransferQueueControllerWaitConnectionThread", daemon=True
1045
+ )
1046
+ self.wait_connection_thread.start()
1047
+
1048
+ def _start_process_update_data_status(self):
1049
+ """Start the data status update processing thread."""
1050
+ self.process_update_data_status_thread = Thread(
1051
+ target=self._update_data_status,
1052
+ name="DynamicTransferQueueControllerProcessUpdateDataStatusThread",
1053
+ daemon=True,
1054
+ )
1055
+ self.process_update_data_status_thread.start()
1056
+
1057
+ def _start_process_request(self):
1058
+ """Start the request processing thread."""
1059
+ self.process_request_thread = Thread(
1060
+ target=self._process_request, name="DynamicTransferQueueControllerProcessRequestThread", daemon=True
1061
+ )
1062
+ self.process_request_thread.start()
1063
+
1064
+ def _process_request(self):
1065
+ """Main request processing loop - adapted for partition-based operations."""
1066
+ while True:
1067
+ identity, serialized_msg = self.request_handle_socket.recv_multipart()
1068
+ request_msg = ZMQMessage.deserialize(serialized_msg)
1069
+
1070
+ if request_msg.request_type == ZMQRequestType.GET_META:
1071
+ params = request_msg.body
1072
+
1073
+ metadata = self.get_metadata(
1074
+ data_fields=params["data_fields"],
1075
+ batch_size=params["batch_size"],
1076
+ partition_id=params["partition_id"],
1077
+ mode=params.get("mode", "fetch"),
1078
+ task_name=params.get("task_name"),
1079
+ sampling_config=params.get("sampling_config"),
1080
+ )
1081
+
1082
+ response_msg = ZMQMessage.create(
1083
+ request_type=ZMQRequestType.GET_META_RESPONSE,
1084
+ sender_id=self.controller_id,
1085
+ receiver_id=request_msg.sender_id,
1086
+ body={"metadata": metadata},
1087
+ )
1088
+
1089
+ elif request_msg.request_type == ZMQRequestType.GET_CLEAR_META:
1090
+ params = request_msg.body
1091
+ partition_id = params["partition_id"]
1092
+
1093
+ metadata = self.get_metadata(
1094
+ data_fields=[],
1095
+ partition_id=partition_id,
1096
+ mode="insert",
1097
+ )
1098
+ response_msg = ZMQMessage.create(
1099
+ request_type=ZMQRequestType.GET_CLEAR_META_RESPONSE,
1100
+ sender_id=self.controller_id,
1101
+ receiver_id=request_msg.sender_id,
1102
+ body={"metadata": metadata},
1103
+ )
1104
+ elif request_msg.request_type == ZMQRequestType.CLEAR_META:
1105
+ params = request_msg.body
1106
+ partition_id = params["partition_id"]
1107
+
1108
+ clear_success = self.clear(partition_id)
1109
+ if clear_success:
1110
+ response_msg = ZMQMessage.create(
1111
+ request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1112
+ sender_id=self.controller_id,
1113
+ receiver_id=request_msg.sender_id,
1114
+ body={"message": f"Clear operation completed by controller {self.controller_id}"},
1115
+ )
1116
+ else:
1117
+ response_msg = ZMQMessage.create(
1118
+ request_type=ZMQRequestType.CLEAR_META_RESPONSE,
1119
+ sender_id=self.controller_id,
1120
+ receiver_id=request_msg.sender_id,
1121
+ body={"error": f"Clear operation failed for partition {partition_id}"},
1122
+ )
1123
+
1124
+ elif request_msg.request_type == ZMQRequestType.CHECK_CONSUMPTION:
1125
+ # Handle consumption status checks
1126
+ params = request_msg.body
1127
+
1128
+ consumption_status = self.get_consumption_status(params["partition_id"], params["task_name"])
1129
+ sample_filter = params.get("sample_filter")
1130
+
1131
+ if consumption_status is not None and sample_filter:
1132
+ batch_status = consumption_status[sample_filter]
1133
+ consumed = torch.all(batch_status == 1).item()
1134
+ elif consumption_status is not None:
1135
+ batch_status = consumption_status
1136
+ consumed = torch.all(batch_status == 1).item()
1137
+ else:
1138
+ consumed = False
1139
+
1140
+ response_msg = ZMQMessage.create(
1141
+ request_type=ZMQRequestType.CONSUMPTION_RESPONSE,
1142
+ sender_id=self.controller_id,
1143
+ receiver_id=request_msg.sender_id,
1144
+ body={
1145
+ "partition_id": params["partition_id"],
1146
+ "consumed": consumed,
1147
+ },
1148
+ )
1149
+ self.request_handle_socket.send_multipart([identity, response_msg.serialize()])
1150
+
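+ # Sketch of a GET_META request body as consumed above (values are illustrative):
+ #
+ #     {
+ #         "data_fields": ["prompt", "response"],
+ #         "batch_size": 8,
+ #         "partition_id": "train@global_batch_0",
+ #         "mode": "fetch",
+ #         "task_name": "train",
+ #         "sampling_config": None,
+ #     }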
1151
+ def _update_data_status(self):
1152
+ """Process data status update messages from storage units - adapted for partitions."""
1153
+ while True:
1154
+ identity, serialized_msg = self.data_status_update_socket.recv_multipart()
1155
+ request_msg = ZMQMessage.deserialize(serialized_msg)
1156
+
1157
+ if request_msg.request_type == ZMQRequestType.NOTIFY_DATA_UPDATE:
1158
+ message_data = request_msg.body
1159
+ partition_id = message_data.get("partition_id")
1160
+
1161
+ # Update production status
1162
+ success = self.update_production_status(
1163
+ partition_id=partition_id,
1164
+ global_indexes=message_data.get("global_indexes", []),
1165
+ field_names=message_data.get("fields", []),
1166
+ dtypes=message_data.get("dtypes", {}),
1167
+ shapes=message_data.get("shapes", {}),
1168
+ )
1169
+
1170
+ if success:
1171
+ logger.info(f"Updated production status for partition {partition_id}")
1172
+
1173
+ # Send acknowledgment
1174
+ response_msg = ZMQMessage.create(
1175
+ request_type=ZMQRequestType.NOTIFY_DATA_UPDATE_ACK,
1176
+ sender_id=self.controller_id,
1177
+ body={
1178
+ "controller_id": self.controller_id,
1179
+ "partition_id": partition_id,
1180
+ "success": success,
1181
+ },
1182
+ )
1183
+ self.data_status_update_socket.send_multipart([identity, response_msg.serialize()])
1184
+
1185
+ def get_zmq_server_info(self) -> ZMQServerInfo:
1186
+ """Get ZMQ server connection information."""
1187
+ return self.zmq_server_info
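+
+ # Wiring sketch (illustrative; assumes an initialized Ray cluster):
+ #
+ #     controller = TransferQueueController.remote(sampler=SequentialSampler)
+ #     server_info = ray.get(controller.get_zmq_server_info.remote())
+ #     # storage managers handshake on server_info.ports["handshake_socket"];
+ #     # clients send GET_META requests to ports["request_handle_socket"];
+ #     # storage units report production via ports["data_status_update_socket"].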