caption-flow 0.2.4__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
@@ -1,22 +1,24 @@
1
- """HuggingFace Datasets processor implementation."""
1
+ """HuggingFace Datasets processor implementation - Memory Optimized Version."""
2
2
 
3
3
  import logging
4
4
  import threading
5
5
  import re
6
+ import queue
6
7
  import requests
8
+ import json
9
+ import io
10
+ import os
11
+ import gc
12
+ import psutil
13
+ from concurrent.futures import ThreadPoolExecutor, Future
7
14
  from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
8
15
  from collections import deque, defaultdict
9
16
  from pathlib import Path
10
- import json
11
- import io
12
17
  from datetime import datetime
13
18
  from PIL import Image
14
- from datasets import (
15
- Dataset,
16
- get_dataset_config_names,
17
- get_dataset_split_names,
18
- load_dataset_builder,
19
- )
19
+ import pyarrow as pa
20
+ import pyarrow.parquet as pq
21
+ from datasets import get_dataset_config_names, get_dataset_split_names
20
22
  from huggingface_hub import hf_hub_download, get_token
21
23
  from caption_flow.storage import StorageManager
22
24
 
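
The new imports above support two of the 0.3.x memory optimizations: process-level RSS logging via psutil and footer-only Parquet inspection via pyarrow.parquet. A minimal, self-contained sketch of both (the temporary file and column name are illustrative, not part of the package):

```python
import os
import tempfile

import psutil
import pyarrow as pa
import pyarrow.parquet as pq

def rss_mb() -> float:
    # Resident set size of the current process in MB, the figure log_memory() reports.
    return psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024

# Write a small throwaway Parquet file so the sketch runs on its own.
path = os.path.join(tempfile.mkdtemp(), "shard-00000.parquet")
pq.write_table(pa.table({"idx": list(range(100_000))}), path)

print(f"RSS before: {rss_mb():.1f} MB")
# read_metadata() touches only the Parquet footer, so row counts come almost for free.
print(f"rows={pq.read_metadata(path).num_rows}, RSS after: {rss_mb():.1f} MB")
```
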
@@ -28,11 +30,82 @@ logger = logging.getLogger(__name__)
28
30
  logger.setLevel(logging.DEBUG)
29
31
 
30
32
 
33
+ def log_memory(location: str):
34
+ """Log memory usage at specific location."""
35
+ process = psutil.Process(os.getpid())
36
+ mem_info = process.memory_info()
37
+ logger.info(
38
+ f"Memory at {location}: RSS={mem_info.rss/1024/1024:.1f}MB, VMS={mem_info.vms/1024/1024:.1f}MB"
39
+ )
40
+ # Force garbage collection
41
+ gc.collect()
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+ logger.setLevel(logging.DEBUG)
46
+
47
+
48
+ class NonBlockingQueueHandler:
49
+ """Handles non-blocking retrieval from queues using concurrent futures."""
50
+
51
+ def __init__(self, max_workers: int = 1):
52
+ self.executor = ThreadPoolExecutor(max_workers=max_workers)
53
+ self.pending_futures: Dict[int, Future] = {} # queue_id -> Future
54
+
55
+ def get_from_queue_async(self, response_queue: queue.Queue, timeout: Optional[float] = None) -> Future:
56
+ """Start an async queue retrieval."""
57
+ queue_id = id(response_queue)
58
+
59
+ # Check if we already have a pending future for this queue
60
+ if queue_id in self.pending_futures and not self.pending_futures[queue_id].done():
61
+ return self.pending_futures[queue_id]
62
+
63
+ # Start new async retrieval
64
+ future = self.executor.submit(response_queue.get, timeout=timeout)
65
+ self.pending_futures[queue_id] = future
66
+ return future
67
+
68
+ def check_response(self, response_queue: queue.Queue, timeout: Optional[float] = None) -> Optional[Any]:
69
+ """Non-blocking check for queue response."""
70
+ queue_id = id(response_queue)
71
+
72
+ # Start async retrieval if needed
73
+ future = self.get_from_queue_async(response_queue, timeout)
74
+
75
+ # Check if result is ready (non-blocking)
76
+ if future.done():
77
+ try:
78
+ result = future.result(timeout=0)
79
+ # Clear future for next retrieval
80
+ if queue_id in self.pending_futures:
81
+ del self.pending_futures[queue_id]
82
+ return result
83
+ except queue.Empty:
84
+ # Queue was empty, clear future
85
+ if queue_id in self.pending_futures:
86
+ del self.pending_futures[queue_id]
87
+ return None
88
+ except Exception as e:
89
+ logger.error(f"Error retrieving from queue: {e}")
90
+ if queue_id in self.pending_futures:
91
+ del self.pending_futures[queue_id]
92
+ return None
93
+
94
+ # Result not ready yet
95
+ return None
96
+
97
+ def shutdown(self):
98
+ """Shutdown the executor."""
99
+ self.executor.shutdown(wait=True)
100
+
101
+
31
102
  class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
32
- """Orchestrator processor for HuggingFace datasets."""
103
+ """Memory-optimized orchestrator processor for HuggingFace datasets with non-blocking operations."""
33
104
 
34
105
  def __init__(self):
35
- logger.debug("Initializing HuggingFaceDatasetOrchestratorProcessor")
106
+ logger.debug(
107
+ "Initializing HuggingFaceDatasetOrchestratorProcessor (Optimized + Non-blocking)"
108
+ )
36
109
  self.dataset_name: Optional[str] = None
37
110
  self.config: Optional[str] = None
38
111
  self.split: Optional[str] = None
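
For context on the NonBlockingQueueHandler added above: a minimal sketch of the same Future-based pattern, reduced to a plain queue and executor (the producer and names here are made up for illustration, not part of the package):

```python
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor

def poll_queue_without_blocking() -> None:
    q: queue.Queue = queue.Queue()
    executor = ThreadPoolExecutor(max_workers=1)

    # Illustrative producer: delivers a result half a second from now.
    threading.Timer(0.5, lambda: q.put("work-result")).start()

    # The blocking q.get() runs on the executor; the caller only polls future.done().
    future = executor.submit(q.get, timeout=5)
    while not future.done():
        time.sleep(0.05)          # the caller stays free to do other work here
    print(future.result())        # -> "work-result"
    executor.shutdown(wait=True)

if __name__ == "__main__":
    poll_queue_without_blocking()
```
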
@@ -44,19 +117,33 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
44
117
  self.shard_info: Dict[int, Dict[str, Any]] = {}
45
118
  self.total_items: int = 0
46
119
 
47
- # Work unit management
48
- self.work_units: Dict[str, WorkUnit] = {}
120
+ # Work unit management - only store active units
49
121
  self.pending_units: Deque[str] = deque()
50
- self.assigned_units: Dict[str, Set[str]] = defaultdict(set) # worker_id -> unit_ids
122
+ self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
51
123
  self.lock = threading.Lock()
52
124
 
125
+ # Track current chunk index for on-demand creation
126
+ self.current_chunk_index = 0
127
+
128
+ # Cache data files info instead of loading builder repeatedly
129
+ self.data_files: List[str] = []
130
+
53
131
  # Background thread for creating work units
54
132
  self.unit_creation_thread: Optional[threading.Thread] = None
55
133
  self.stop_creation = threading.Event()
56
134
 
135
+ # Non-blocking queue handler
136
+ self.queue_handler = NonBlockingQueueHandler()
137
+
138
+ # Response processing state
139
+ self.last_maintenance_time = datetime.now()
140
+ self.maintenance_interval = 30 # seconds
141
+
57
142
  def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
58
143
  """Initialize HuggingFace dataset processor."""
59
144
  logger.debug("Initializing orchestrator with config: %s", config.config)
145
+ log_memory("start of initialize")
146
+
60
147
  cfg = config.config
61
148
 
62
149
  # Dataset configuration
@@ -83,12 +170,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
83
170
  self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
84
171
 
85
172
  # Initialize chunk tracking
86
- checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
87
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
88
- self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
173
+ self.checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
174
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
175
+ self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")
89
176
 
90
- # Discover shards
91
- self._discover_shards()
177
+ # Discover shards (optimized)
178
+ self._discover_shards_optimized()
92
179
 
93
180
  # Restore existing state
94
181
  self._restore_state(storage=storage)
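
For orientation, a sketch of the configuration keys this diff touches, with illustrative values; other keys consumed elsewhere in the processors (chunk sizing, dataset selection, auth) are omitted and the dataset id shown is hypothetical:

```python
# Orchestrator-side keys seen in this diff (defaults as in the code).
orchestrator_cfg = {
    "checkpoint_dir": "./checkpoints",      # holds chunks.json and the cached shard info
    "chunk_buffer_multiplier": 3,           # buffer factor for pre-created work units
}

# Worker-side keys seen in this diff, nested under "dataset".
worker_cfg = {
    "dataset": {
        "dataset_path": "user/dataset",     # illustrative HF dataset id
        "dataset_image_column": "image",    # column holding binary images
        "dataset_url_column": "image_url",  # or a column holding image URLs
        "mock_results": False,              # new in 0.3.x: emit dummy images for testing
    }
}
```
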
@@ -98,7 +185,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
98
185
  target=self._create_units_background, daemon=True
99
186
  )
100
187
  self.unit_creation_thread.start()
101
- logger.debug("Unit creation thread started")
188
+
189
+ log_memory("end of initialize")
102
190
 
103
191
  def _detect_config(self, provided_config: Optional[str]) -> str:
104
192
  """Auto-detect config if not provided."""
@@ -110,14 +198,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
110
198
  if not configs:
111
199
  return "default"
112
200
 
113
- # Prefer common config names
114
201
  preferred = ["default", "en", "train", "main"]
115
202
  for pref in preferred:
116
203
  if pref in configs:
117
204
  logger.info(f"Auto-selected config: {pref}")
118
205
  return pref
119
206
 
120
- # Otherwise use first available
121
207
  logger.info(f"Auto-selected first available config: {configs[0]}")
122
208
  return configs[0]
123
209
  except Exception as e:
@@ -134,17 +220,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
134
220
  self.dataset_name, config_name=self.config, token=self.token
135
221
  )
136
222
  if not splits:
137
- logger.warning("No splits found, using 'train'")
138
223
  return "train"
139
224
 
140
- # Prefer training splits
141
225
  preferred = ["train", "training", "test", "validation", "dev"]
142
226
  for pref in preferred:
143
227
  if pref in splits:
144
228
  logger.info(f"Auto-selected split: {pref}")
145
229
  return pref
146
230
 
147
- # Otherwise use first available
148
231
  logger.info(f"Auto-selected first available split: {splits[0]}")
149
232
  return splits[0]
150
233
  except Exception as e:
@@ -153,18 +236,16 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
153
236
 
154
237
  def _extract_filename_from_url(self, url: str) -> str:
155
238
  """Extract filename from HF URL format."""
156
- # Format: hf://datasets/user/dataset@hash/filename
157
239
  match = re.search(r"@[a-f0-9]+/(.+)$", url)
158
240
  if match:
159
241
  return match.group(1)
160
- # Fallback: just get last part
161
242
  return url.split("/")[-1]
162
243
 
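
The removed comment above documents the URL shape this helper parses: hf://datasets/user/dataset@hash/filename. A tiny standalone sketch of the same pattern on a made-up URL:

```python
import re

def extract_filename(url: str) -> str:
    # Same pattern as _extract_filename_from_url: strip everything up to "@<hex revision>/".
    match = re.search(r"@[a-f0-9]+/(.+)$", url)
    return match.group(1) if match else url.split("/")[-1]

# Hypothetical data-file URL in the hf:// form noted above.
print(extract_filename("hf://datasets/user/dataset@0123abcd/data/train-00000-of-00002.parquet"))
# -> data/train-00000-of-00002.parquet
```
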
163
- def _discover_shards(self):
164
- """Discover all shards and their sizes."""
165
- logger.info("Discovering shards...")
244
+ def _get_data_files_from_builder(self) -> List[str]:
245
+ """Get data files using dataset builder with minimal memory usage."""
246
+ # Load builder to get correct file structure
247
+ from datasets import load_dataset_builder
166
248
 
167
- # Load dataset builder to get file info
168
249
  builder = load_dataset_builder(self.dataset_name, self.config)
169
250
 
170
251
  # Get data files for our split
@@ -176,81 +257,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
176
257
  files = [files]
177
258
  data_files = files
178
259
 
179
- if not data_files:
260
+ # Explicitly delete builder to free memory
261
+ del builder
262
+ gc.collect()
263
+
264
+ return data_files
265
+
266
+ def _discover_shards_optimized(self):
267
+ """Discover all shards using dataset builder but release memory immediately."""
268
+ logger.info("Discovering shards...")
269
+
270
+ # Try to load cached shard info first
271
+ shard_info_cache_path = (
272
+ self.checkpoint_dir / f"{self.dataset_name}_{self.config}_{self.split}_shard_info.json"
273
+ )
274
+
275
+ if shard_info_cache_path.exists():
276
+ try:
277
+ with open(shard_info_cache_path, "r") as f:
278
+ cached_info = json.load(f)
279
+ if (
280
+ cached_info.get("dataset") == self.dataset_name
281
+ and cached_info.get("config") == self.config
282
+ and cached_info.get("split") == self.split
283
+ ):
284
+ self.shard_info = {int(k): v for k, v in cached_info["shards"].items()}
285
+ self.total_items = cached_info["total_items"]
286
+ self.data_files = cached_info.get("data_files", [])
287
+ logger.info(
288
+ f"Loaded cached shard info: {len(self.shard_info)} shards, {self.total_items} total items"
289
+ )
290
+ return
291
+ except Exception as e:
292
+ logger.warning(f"Failed to load cached shard info: {e}")
293
+
294
+ # Get data files using dataset builder
295
+ self.data_files = self._get_data_files_from_builder()
296
+
297
+ if not self.data_files:
180
298
  raise ValueError(f"No data files found for split '{self.split}'")
181
299
 
182
- logger.info(f"Found {len(data_files)} data files")
300
+ logger.info(f"Found {len(self.data_files)} data files")
183
301
 
184
- # Get info about each shard
302
+ # Get metadata for each shard
185
303
  cumulative_offset = 0
186
- for i, file_url in enumerate(data_files):
304
+ for i, file_url in enumerate(self.data_files):
187
305
  filename = self._extract_filename_from_url(file_url)
188
306
  logger.info(f"Discovering shard {i}: {filename}")
189
307
 
190
- # We don't download shards here - workers will do that
191
- # For now, store the info we have
192
- self.shard_info[i] = {
193
- "shard_id": i,
194
- "file_url": file_url,
195
- "filename": filename,
196
- "start_offset": cumulative_offset,
197
- # Size will be determined when first worker needs it
198
- "size": None,
199
- "end_offset": None,
200
- }
201
-
202
- # Try to get size from builder info if available
203
- if hasattr(builder.info, "splits") and self.split in builder.info.splits:
204
- split_info = builder.info.splits[self.split]
205
- if split_info.num_examples and len(data_files) == 1:
206
- # Single shard case
207
- self.shard_info[i]["size"] = split_info.num_examples
208
- self.shard_info[i]["end_offset"] = (
209
- cumulative_offset + split_info.num_examples - 1
210
- )
211
- cumulative_offset += split_info.num_examples
212
-
213
- # If we couldn't get sizes, we'll need to load shards on demand
214
- if self.shard_info[0]["size"] is None:
215
- logger.warning("Shard sizes not available from metadata, will load on demand")
216
- else:
217
- self.total_items = cumulative_offset
218
- logger.info(f"Total items across all shards: {self.total_items}")
219
-
220
- def _get_shard_size(self, shard_id: int) -> int:
221
- """Get size of a shard, loading it if necessary."""
222
- if self.shard_info[shard_id]["size"] is not None:
223
- return self.shard_info[shard_id]["size"]
308
+ try:
309
+ # Download file to get metadata
310
+ local_path = hf_hub_download(
311
+ repo_id=self.dataset_name,
312
+ filename=filename,
313
+ repo_type="dataset",
314
+ token=self.token,
315
+ )
224
316
 
225
- # Need to load the shard to get its size
226
- logger.info(f"Loading shard {shard_id} to determine size...")
227
- filename = self.shard_info[shard_id]["filename"]
317
+ # Read only metadata
318
+ metadata = pq.read_metadata(local_path)
319
+ size = metadata.num_rows
320
+
321
+ self.shard_info[i] = {
322
+ "shard_id": i,
323
+ "file_url": file_url,
324
+ "filename": filename,
325
+ "start_offset": cumulative_offset,
326
+ "size": size,
327
+ "end_offset": cumulative_offset + size - 1,
328
+ }
228
329
 
229
- local_path = hf_hub_download(
230
- repo_id=self.dataset_name, filename=filename, repo_type="dataset", token=self.token
231
- )
330
+ cumulative_offset += size
331
+ logger.info(f"Shard {i} ({filename}): {size} rows")
232
332
 
233
- # Load just to get size
234
- dataset = Dataset.from_parquet(local_path)
235
- size = len(dataset)
333
+ except Exception as e:
334
+ logger.error(f"Failed to discover shard {i}: {e}")
335
+ # Skip this shard
336
+ continue
236
337
 
237
- # Update shard info
238
- self.shard_info[shard_id]["size"] = size
338
+ self.total_items = cumulative_offset
339
+ logger.info(f"Total items across all shards: {self.total_items}")
239
340
 
240
- # Update offsets for this and subsequent shards
241
- for sid in range(shard_id, len(self.shard_info)):
242
- if sid > shard_id:
243
- self.shard_info[sid]["start_offset"] = self.shard_info[sid - 1]["end_offset"] + 1
244
- self.shard_info[sid]["end_offset"] = (
245
- self.shard_info[sid]["start_offset"] + self.shard_info[sid]["size"] - 1
246
- )
341
+ # Cache shard info
342
+ try:
343
+ cache_data = {
344
+ "dataset": self.dataset_name,
345
+ "config": self.config,
346
+ "split": self.split,
347
+ "shards": self.shard_info,
348
+ "total_items": self.total_items,
349
+ "data_files": self.data_files,
350
+ }
351
+ with open(shard_info_cache_path, "w") as f:
352
+ json.dump(cache_data, f)
353
+ logger.info(f"Cached shard info to {shard_info_cache_path}")
354
+ except Exception as e:
355
+ logger.warning(f"Failed to cache shard info: {e}")
247
356
 
248
- # Update total items
249
- if all(s["size"] is not None for s in self.shard_info.values()):
250
- self.total_items = sum(s["size"] for s in self.shard_info.values())
251
- logger.info(f"Total items: {self.total_items}")
357
+ # Force garbage collection
358
+ gc.collect()
359
+ log_memory("after discovering shards")
252
360
 
253
- return size
361
+ def _get_shard_for_index(self, global_index: int) -> Tuple[int, int]:
362
+ """Get shard ID and local index for a global index."""
363
+ for shard_id, sinfo in self.shard_info.items():
364
+ if sinfo["start_offset"] <= global_index <= sinfo["end_offset"]:
365
+ local_index = global_index - sinfo["start_offset"]
366
+ return shard_id, local_index
367
+ raise ValueError(f"Global index {global_index} not found in any shard")
254
368
 
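
A worked sketch of the offset bookkeeping that _discover_shards_optimized and _get_shard_for_index rely on, with made-up shard sizes:

```python
from typing import Dict, Tuple

def build_offsets(shard_sizes: Dict[int, int]) -> Dict[int, Dict[str, int]]:
    # Lay shards out back-to-back: each gets an inclusive [start_offset, end_offset] range.
    info: Dict[int, Dict[str, int]] = {}
    cursor = 0
    for shard_id in sorted(shard_sizes):
        size = shard_sizes[shard_id]
        info[shard_id] = {"start_offset": cursor, "end_offset": cursor + size - 1, "size": size}
        cursor += size
    return info

def locate(global_index: int, info: Dict[int, Dict[str, int]]) -> Tuple[int, int]:
    # Map a dataset-wide row index to (shard_id, index within that shard).
    for shard_id, s in info.items():
        if s["start_offset"] <= global_index <= s["end_offset"]:
            return shard_id, global_index - s["start_offset"]
    raise ValueError(f"index {global_index} is past the end of the dataset")

info = build_offsets({0: 1000, 1: 250})   # illustrative sizes
print(locate(1100, info))                 # -> (1, 100)
```
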
255
369
  def _restore_state(self, storage: StorageManager) -> None:
256
370
  """Restore state from chunk tracker."""
@@ -258,73 +372,84 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
258
372
  if not self.chunk_tracker:
259
373
  return
260
374
 
261
- all_processed_jobs = storage.get_all_processed_job_ids()
262
-
263
375
  with self.lock:
376
+ max_chunk_index = -1
377
+
264
378
  for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
265
- # Calculate actual unprocessed ranges
266
- chunk_range = (
267
- chunk_state.start_index,
268
- chunk_state.start_index + chunk_state.chunk_size - 1,
269
- )
379
+ chunk_index = chunk_state.start_index // self.chunk_size
380
+ max_chunk_index = max(max_chunk_index, chunk_index)
381
+
382
+ # Only add incomplete chunks to pending
383
+ if chunk_state.status != "completed":
384
+ self.pending_units.append(chunk_id)
385
+ elif chunk_state.status == "completed" and chunk_state.processed_ranges:
386
+ logger.warning(
387
+ f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
388
+ )
389
+
390
+ self.current_chunk_index = max_chunk_index + 1
391
+ logger.info(f"Resuming from chunk index {self.current_chunk_index}")
270
392
 
271
- # Get processed indices for this chunk
272
- processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
273
- chunk_id, all_processed_jobs
393
+ def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
394
+ """Create a single work unit for a chunk index."""
395
+ current_index = chunk_index * self.chunk_size
396
+
397
+ if current_index >= self.total_items:
398
+ return None
399
+
400
+ chunk_size = min(self.chunk_size, self.total_items - current_index)
401
+
402
+ # Find shard for this chunk
403
+ shard_id, _ = self._get_shard_for_index(current_index)
404
+ shard_name = Path(self.shard_info[shard_id]["filename"]).stem
405
+
406
+ job_id_obj = JobId(shard_id=shard_name, chunk_id=chunk_index, sample_id=current_index)
407
+ unit_id = job_id_obj.get_chunk_str()
408
+
409
+ # Calculate unprocessed ranges based on existing chunk state
410
+ unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
411
+
412
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
413
+ chunk_state = self.chunk_tracker.chunks[unit_id]
414
+ if chunk_state.processed_ranges:
415
+ # Subtract processed ranges from total range
416
+ unprocessed_ranges = self._subtract_ranges(
417
+ [(current_index, current_index + chunk_size - 1)], chunk_state.processed_ranges
274
418
  )
275
419
 
276
- # Calculate unprocessed ranges
277
- unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
278
-
279
- if unprocessed_ranges:
280
- # Find which shard(s) this chunk belongs to
281
- shard_ids = []
282
- for sid, sinfo in self.shard_info.items():
283
- # Need size to check
284
- if sinfo["size"] is None:
285
- self._get_shard_size(sid)
286
-
287
- if (
288
- sinfo["start_offset"]
289
- <= chunk_state.start_index + chunk_state.chunk_size - 1
290
- and sinfo["end_offset"] >= chunk_state.start_index
291
- ):
292
- shard_ids.append(sid)
293
- logger.info(f"Found shard {sid} for chunk {chunk_id}: {sinfo}")
294
-
295
- chunk_index = chunk_state.start_index // self.chunk_size
296
- shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
297
- unit = WorkUnit(
298
- unit_id=chunk_id,
299
- chunk_id=chunk_id,
300
- source_id=shard_name,
301
- data={
302
- "dataset_name": self.dataset_name,
303
- "config": self.config,
304
- "split": self.split,
305
- "start_index": chunk_state.start_index,
306
- "chunk_size": chunk_state.chunk_size,
307
- "unprocessed_ranges": unprocessed_ranges,
308
- "shard_ids": shard_ids,
309
- },
310
- metadata={
311
- "dataset": self.dataset_name,
312
- "shard_name": shard_name,
313
- "chunk_index": chunk_index,
314
- },
315
- )
420
+ # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
421
+ if not unprocessed_ranges:
422
+ return None
423
+
424
+ unit = WorkUnit(
425
+ unit_id=unit_id,
426
+ chunk_id=unit_id,
427
+ source_id=shard_name,
428
+ unit_size=chunk_size,
429
+ data={
430
+ "dataset_name": self.dataset_name,
431
+ "config": self.config,
432
+ "split": self.split,
433
+ "start_index": current_index,
434
+ "chunk_size": chunk_size,
435
+ "unprocessed_ranges": unprocessed_ranges, # Use calculated ranges
436
+ "shard_ids": [shard_id],
437
+ "data_files": self.data_files,
438
+ },
439
+ metadata={
440
+ "dataset": self.dataset_name,
441
+ "shard_name": shard_name,
442
+ "chunk_index": chunk_index,
443
+ },
444
+ )
316
445
 
317
- self.work_units[unit.unit_id] = unit
318
- self.pending_units.append(unit.unit_id)
446
+ return unit
319
447
 
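
The chunk arithmetic _create_work_unit starts from, as a small standalone sketch (the row counts are illustrative):

```python
from typing import Optional, Tuple

def chunk_bounds(chunk_index: int, chunk_size: int, total_items: int) -> Optional[Tuple[int, int]]:
    # Start index and actual size of a chunk, clamped at the end of the dataset.
    start = chunk_index * chunk_size
    if start >= total_items:
        return None                                 # past the end: no unit to create
    return start, min(chunk_size, total_items - start)

# With 2,500 rows and chunks of 1,000, the last chunk is short and the 4th does not exist.
for i in range(4):
    print(i, chunk_bounds(i, 1000, 2500))
# 0 (0, 1000)   1 (1000, 1000)   2 (2000, 500)   3 None
```
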
320
448
  def _create_units_background(self) -> None:
321
449
  """Background thread to create work units on demand."""
322
450
  logger.info("Starting work unit creation thread")
323
451
 
324
- current_index = 0
325
-
326
452
  while not self.stop_creation.is_set():
327
- # Check if we need more units
328
453
  with self.lock:
329
454
  pending_count = len(self.pending_units)
330
455
  assigned_count = sum(len(units) for units in self.assigned_units.values())
@@ -337,127 +462,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
337
462
  threading.Event().wait(5)
338
463
  continue
339
464
 
340
- # Make sure we know total items
341
- if self.total_items == 0:
342
- # Load all shard sizes
343
- for sid in range(len(self.shard_info)):
344
- self._get_shard_size(sid)
345
-
346
465
  # Create units as needed
347
466
  units_created = 0
348
467
 
349
- while units_created < units_needed and current_index < self.total_items:
350
- chunk_size = min(self.chunk_size, self.total_items - current_index)
351
- chunk_id = current_index // self.chunk_size
352
-
353
- with self.lock:
354
- shard_ids = []
355
- for sid, sinfo in self.shard_info.items():
356
- if (
357
- sinfo["start_offset"] <= current_index + chunk_size - 1
358
- and sinfo["end_offset"] >= current_index
359
- ):
360
- shard_ids.append(sid)
361
- shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
468
+ while units_created < units_needed:
469
+ logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
470
+ if self.current_chunk_index * self.chunk_size >= self.total_items:
471
+ threading.Event().wait(30)
472
+ break
473
+ # Get shard info for proper unit_id
474
+ current_index = self.current_chunk_index * self.chunk_size
475
+ if current_index < self.total_items:
476
+ shard_id, _ = self._get_shard_for_index(current_index)
477
+ shard_name = Path(self.shard_info[shard_id]["filename"]).stem
362
478
 
363
479
  job_id_obj = JobId(
364
- shard_id=shard_name, chunk_id=chunk_id, sample_id=current_index
480
+ shard_id=shard_name,
481
+ chunk_id=self.current_chunk_index,
482
+ sample_id=current_index,
365
483
  )
366
- unit_id = (
367
- job_id_obj.get_chunk_str()
368
- ) # just the chunk part, eg pixel-images:chunk:0
369
- if unit_id in self.work_units:
370
- current_index += self.chunk_size
371
- continue
372
-
373
- # Check if chunk is already completed
374
- if self.chunk_tracker:
375
- chunk_state = self.chunk_tracker.chunks.get(unit_id)
376
- if chunk_state and chunk_state.status == "completed":
377
- current_index += self.chunk_size
378
- continue
484
+ unit_id = job_id_obj.get_chunk_str()
379
485
 
380
- # Find which shard(s) this chunk belongs to
381
-
382
- unit = WorkUnit(
383
- unit_id=unit_id,
384
- chunk_id=unit_id,
385
- source_id=shard_name,
386
- data={
387
- "dataset_name": self.dataset_name,
388
- "config": self.config,
389
- "split": self.split,
390
- "start_index": current_index,
391
- "chunk_size": chunk_size,
392
- "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
393
- "shard_ids": shard_ids,
394
- },
395
- metadata={
396
- "dataset": self.dataset_name,
397
- "shard_name": shard_name,
398
- "chunk_index": chunk_id,
399
- },
400
- )
401
- logger.debug(f"Created WorkUnit: {unit}")
486
+ with self.lock:
487
+ # Check if already tracked
488
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
489
+ chunk_state = self.chunk_tracker.chunks[unit_id]
490
+ if chunk_state.status == "completed":
491
+ self.current_chunk_index += 1
492
+ continue
402
493
 
403
- self.work_units[unit_id] = unit
494
+ # Add to pending
404
495
  self.pending_units.append(unit_id)
405
496
 
497
+ # Track in chunk tracker
406
498
  if self.chunk_tracker:
499
+ start_index = self.current_chunk_index * self.chunk_size
500
+ chunk_size = min(self.chunk_size, self.total_items - start_index)
407
501
  self.chunk_tracker.add_chunk(
408
502
  unit_id,
409
503
  self.dataset_name,
410
- "", # No shard URL
411
- current_index,
504
+ "",
505
+ start_index,
412
506
  chunk_size,
413
507
  )
414
508
 
415
509
  units_created += 1
416
-
417
- current_index += self.chunk_size
510
+ self.current_chunk_index += 1
418
511
 
419
512
  if units_created > 0:
420
- logger.debug(f"Created {units_created} work units")
421
-
422
- def _subtract_ranges(
423
- self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
424
- ) -> List[Tuple[int, int]]:
425
- """Subtract processed ranges from total ranges."""
426
- if not processed_ranges:
427
- return total_ranges
428
-
429
- # Create a set of all processed indices
430
- processed_indices = set()
431
- for start, end in processed_ranges:
432
- processed_indices.update(range(start, end + 1))
433
-
434
- # Find unprocessed ranges
435
- unprocessed_ranges = []
436
- for start, end in total_ranges:
437
- current_start = None
438
- for i in range(start, end + 1):
439
- if i not in processed_indices:
440
- if current_start is None:
441
- current_start = i
442
- else:
443
- if current_start is not None:
444
- unprocessed_ranges.append((current_start, i - 1))
445
- current_start = None
513
+ logger.debug(f"Created {units_created} work unit IDs")
514
+
515
+ logger.info("Thread for creating units has completed. Exiting thread.")
516
+
517
+ def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
518
+ """
519
+ Non-blocking method to process responses from workers.
520
+ Returns a WorkResult if one is available, None otherwise.
521
+ """
522
+ # Check for response without blocking
523
+ response = self.queue_handler.check_response(response_queue, timeout=0.1)
524
+
525
+ if response is not None:
526
+ # Process the response
527
+ if isinstance(response, WorkResult):
528
+ logger.debug(f"Processing response for unit {response.unit_id}")
529
+ return response
530
+ else:
531
+ logger.warning(f"Unexpected response type: {type(response)}")
532
+
533
+ # Perform periodic maintenance tasks
534
+ now = datetime.now()
535
+ if (now - self.last_maintenance_time).total_seconds() > self.maintenance_interval:
536
+ self._perform_maintenance()
537
+ self.last_maintenance_time = now
538
+
539
+ return None
540
+
541
+ def _perform_maintenance(self):
542
+ """Perform periodic maintenance tasks."""
543
+ with self.lock:
544
+ # Log current state
545
+ pending_count = len(self.pending_units)
546
+ assigned_count = sum(len(units) for units in self.assigned_units.values())
547
+ logger.debug(f"Maintenance: {pending_count} pending, {assigned_count} assigned units")
446
548
 
447
- if current_start is not None:
448
- unprocessed_ranges.append((current_start, end))
549
+ # Check for stale assignments (workers that might have disconnected)
550
+ # This would be implemented based on your worker heartbeat mechanism
449
551
 
450
- return unprocessed_ranges
552
+ # Force checkpoint save if needed
553
+ if self.chunk_tracker:
554
+ self.chunk_tracker.save_checkpoint()
451
555
 
452
556
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
453
557
  """Get available work units for a worker."""
454
- logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
455
- assigned = []
456
558
 
559
+ logger.debug(
560
+ "get_work_units called: count=%d worker_id=%s, pending: %d",
561
+ count,
562
+ worker_id,
563
+ len(self.pending_units),
564
+ )
565
+ assigned = []
457
566
  with self.lock:
458
567
  while len(assigned) < count and self.pending_units:
459
568
  unit_id = self.pending_units.popleft()
460
- unit = self.work_units.get(unit_id)
569
+
570
+ # Create work unit on demand
571
+ chunk_index = int(unit_id.split(":")[-1])
572
+ unit = self._create_work_unit(chunk_index)
461
573
 
462
574
  if unit:
463
575
  self.assigned_units[worker_id].add(unit_id)
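
How the rewritten get_work_units hands out work without keeping WorkUnit objects in memory: unit ids wait in a deque, the chunk index is parsed back out of the id, and the unit is rebuilt on demand. A reduced sketch of that bookkeeping (the id format follows the `<shard>:chunk:<index>` example noted in the removed code; the rest is illustrative):

```python
from collections import defaultdict, deque

pending = deque(["pixel-images:chunk:0", "pixel-images:chunk:1", "pixel-images:chunk:2"])
assigned = defaultdict(set)

def assign(worker_id: str, count: int):
    # Pop up to `count` ids, recover each chunk index from the id, and record the assignment.
    handed_out = []
    while len(handed_out) < count and pending:
        unit_id = pending.popleft()
        chunk_index = int(unit_id.split(":")[-1])
        assigned[worker_id].add(unit_id)
        handed_out.append((unit_id, chunk_index))
    return handed_out

print(assign("worker-a", 2))   # -> [('pixel-images:chunk:0', 0), ('pixel-images:chunk:1', 1)]
print(len(pending))            # -> 1
```
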
@@ -474,22 +586,26 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
474
586
  """Mark a work unit as completed."""
475
587
  logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
476
588
  with self.lock:
477
- if unit_id in self.work_units:
478
- self.assigned_units[worker_id].discard(unit_id)
589
+ self.assigned_units[worker_id].discard(unit_id)
590
+
591
+ if self.chunk_tracker:
592
+ self.chunk_tracker.mark_completed(unit_id)
479
593
 
480
- if self.chunk_tracker:
481
- self.chunk_tracker.mark_completed(unit_id)
594
+ # remove from pending deque if it's there.
595
+ try:
596
+ self.pending_units.remove(unit_id)
597
+ except ValueError:
598
+ pass
482
599
 
483
600
  def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
484
601
  """Mark a work unit as failed."""
485
- logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
602
+ logger.error("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
486
603
  with self.lock:
487
- if unit_id in self.work_units:
488
- self.assigned_units[worker_id].discard(unit_id)
489
- self.pending_units.append(unit_id)
604
+ self.assigned_units[worker_id].discard(unit_id)
605
+ self.pending_units.append(unit_id)
490
606
 
491
- if self.chunk_tracker:
492
- self.chunk_tracker.mark_failed(unit_id)
607
+ if self.chunk_tracker:
608
+ self.chunk_tracker.mark_failed(unit_id)
493
609
 
494
610
  def release_assignments(self, worker_id: str) -> None:
495
611
  """Release all assignments for a disconnected worker."""
@@ -498,8 +614,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
498
614
  unit_ids = list(self.assigned_units.get(worker_id, []))
499
615
 
500
616
  for unit_id in unit_ids:
501
- if unit_id in self.work_units:
502
- self.pending_units.append(unit_id)
617
+ logger.debug(f"Adding {unit_id} to pending queue")
618
+ self.pending_units.append(unit_id)
503
619
 
504
620
  if worker_id in self.assigned_units:
505
621
  del self.assigned_units[worker_id]
@@ -509,57 +625,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
509
625
 
510
626
  def update_from_storage(self, processed_job_ids: Set[str]) -> None:
511
627
  """Update work units based on what's been processed."""
512
- logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
513
-
514
- with self.lock:
515
- for unit_id, unit in self.work_units.items():
516
- # Extract chunk info from unit
517
- logger.debug(f"Checking unit {unit_id} for updates")
518
- logger.debug(f"Unit data: {unit.data}")
519
- logger.debug(f"Unit metadata: {unit.metadata}")
520
- start_index = unit.data["start_index"]
521
- chunk_size = unit.data["chunk_size"]
522
- shard_name = unit.metadata["shard_name"]
523
- chunk_index = unit.metadata["chunk_index"]
524
-
525
- # Find processed indices for this chunk
526
- processed_indices = []
527
- for job_id in processed_job_ids:
528
- # Parse job_id format: "data-0000:chunk:0:idx:42"
529
- job_id = JobId.from_str(job_id=job_id)
530
- if job_id.shard_id == shard_name and int(job_id.chunk_id) == chunk_index:
531
- idx = int(job_id.sample_id)
532
- if start_index <= idx < start_index + chunk_size:
533
- processed_indices.append(idx)
534
-
535
- if processed_indices:
536
- # Convert to ranges
537
- processed_indices.sort()
538
- processed_ranges = []
539
- start = processed_indices[0]
540
- end = processed_indices[0]
541
-
542
- for idx in processed_indices[1:]:
543
- if idx == end + 1:
544
- end = idx
545
- else:
546
- processed_ranges.append((start, end))
547
- start = idx
548
- end = idx
549
-
550
- processed_ranges.append((start, end))
551
-
552
- # Calculate unprocessed ranges
553
- total_range = [(start_index, start_index + chunk_size - 1)]
554
- unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
555
-
556
- # Update unit
557
- unit.data["unprocessed_ranges"] = unprocessed_ranges
558
-
559
- logger.debug(
560
- f"Updated unit {unit_id}: {len(processed_indices)} processed, "
561
- f"unprocessed ranges: {unprocessed_ranges}"
562
- )
628
+ logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
629
+ # No need to update in-memory work units since we create on demand
563
630
 
564
631
  def get_stats(self) -> Dict[str, Any]:
565
632
  """Get processor statistics."""
@@ -568,12 +635,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
568
635
  "dataset": self.dataset_name,
569
636
  "config": self.config,
570
637
  "split": self.split,
571
- "total_units": len(self.work_units),
572
638
  "pending_units": len(self.pending_units),
573
639
  "assigned_units": sum(len(units) for units in self.assigned_units.values()),
574
640
  "total_shards": len(self.shard_info),
575
641
  "total_items": self.total_items,
576
642
  "workers": len(self.assigned_units),
643
+ "current_chunk_index": self.current_chunk_index,
577
644
  }
578
645
  return stats
579
646
 
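
handle_result in the next hunk collapses the processed item indices into contiguous ranges before handing them to the chunk tracker, and _subtract_ranges does the inverse bookkeeping when rebuilding a unit. The range-building step in isolation, as a sketch with an illustrative input list:

```python
from typing import List, Tuple

def to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
    # Collapse a list of item indices into inclusive (start, end) runs.
    if not indices:
        return []
    indices = sorted(indices)
    ranges, start, end = [], indices[0], indices[0]
    for idx in indices[1:]:
        if idx == end + 1:
            end = idx
        else:
            ranges.append((start, end))
            start = end = idx
    ranges.append((start, end))
    return ranges

print(to_ranges([4, 5, 6, 10, 11, 42]))   # -> [(4, 6), (10, 11), (42, 42)]
```
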
@@ -581,71 +648,111 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
581
648
  """Handle result processing."""
582
649
  base_result = super().handle_result(result)
583
650
 
584
- # Track processed items
585
651
  if self.chunk_tracker:
586
- if "item_indices" not in result.metadata:
587
- result.metadata["item_indices"] = [result.metadata.get("_item_index")]
588
- indices = result.metadata["item_indices"]
652
+ if "item_indices" in result.metadata:
653
+ indices = result.metadata["item_indices"]
654
+ if indices:
655
+ # Convert to ranges for efficient tracking
656
+ indices.sort()
657
+ ranges = []
658
+ start = indices[0]
659
+ end = indices[0]
660
+
661
+ for i in range(1, len(indices)):
662
+ if indices[i] == end + 1:
663
+ end = indices[i]
664
+ else:
665
+ ranges.append((start, end))
666
+ start = indices[i]
667
+ end = indices[i]
589
668
 
590
- if indices:
591
- indices.sort()
592
- ranges = []
593
- start = indices[0]
594
- end = indices[0]
669
+ ranges.append((start, end))
595
670
 
596
- for i in range(1, len(indices)):
597
- if indices[i] == end + 1:
598
- end = indices[i]
599
- else:
600
- ranges.append((start, end))
601
- start = indices[i]
602
- end = indices[i]
671
+ for start_idx, end_idx in ranges:
672
+ self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
603
673
 
604
- ranges.append((start, end))
674
+ return base_result
605
675
 
606
- for start_idx, end_idx in ranges:
607
- self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
676
+ def _subtract_ranges(
677
+ self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
678
+ ) -> List[Tuple[int, int]]:
679
+ """Subtract processed ranges from total ranges."""
680
+ if not processed_ranges:
681
+ return total_ranges
608
682
 
609
- return base_result
683
+ # Create a set of all processed indices
684
+ processed_indices = set()
685
+ for start, end in processed_ranges:
686
+ processed_indices.update(range(start, end + 1))
687
+
688
+ # Find unprocessed ranges
689
+ unprocessed_ranges = []
690
+ for start, end in total_ranges:
691
+ current_start = None
692
+ for i in range(start, end + 1):
693
+ if i not in processed_indices:
694
+ if current_start is None:
695
+ current_start = i
696
+ else:
697
+ if current_start is not None:
698
+ unprocessed_ranges.append((current_start, i - 1))
699
+ current_start = None
700
+
701
+ if current_start is not None:
702
+ unprocessed_ranges.append((current_start, end))
703
+
704
+ return unprocessed_ranges
705
+
706
+ def cleanup(self):
707
+ """Clean up resources."""
708
+ logger.info("Cleaning up orchestrator resources")
709
+
710
+ # Stop background threads
711
+ self.stop_creation.set()
712
+ if self.unit_creation_thread:
713
+ self.unit_creation_thread.join(timeout=5)
714
+
715
+ # Shutdown queue handler
716
+ self.queue_handler.shutdown()
717
+
718
+ # Save final state
719
+ if self.chunk_tracker:
720
+ self.chunk_tracker.save_checkpoint()
610
721
 
611
722
 
612
723
  class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
613
- """Worker processor for HuggingFace datasets."""
724
+ """Memory-optimized worker processor for HuggingFace datasets."""
614
725
 
615
726
  def __init__(self):
616
- logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
727
+ logger.debug("Initializing HuggingFaceDatasetWorkerProcessor (Optimized)")
617
728
  self.dataset_config: Dict[str, Any] = {}
618
729
  self.token = get_token()
619
- self.shard_cache: Dict[int, Dataset] = {} # Cache loaded shards
620
730
  self.image_column: Optional[str] = None
621
731
  self.url_column: Optional[str] = None
622
732
 
733
+ # Thread-local storage for shard info to avoid repeated builder loading
734
+ self._thread_local = threading.local()
735
+
623
736
  def initialize(self, config: ProcessorConfig) -> None:
624
737
  """Initialize processor."""
625
738
  logger.debug("Initializing worker with config: %s", config.config)
626
739
  self.dataset_config = config.config.get("dataset", {})
627
740
 
628
- # Determine if this is an image URL dataset or binary image dataset
629
741
  self.image_column = self.dataset_config.get("dataset_image_column", "image")
630
742
  self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
631
743
  self.dataset_path = self.dataset_config.get("dataset_path", None)
632
744
 
633
- def _load_shard(self, dataset_name: str, shard_filename: str, shard_id: int) -> Dataset:
634
- """Load a shard if not already cached."""
635
- if shard_id in self.shard_cache:
636
- return self.shard_cache[shard_id]
637
-
638
- logger.info(f"Loading shard {shard_id}: {shard_filename}")
745
+ # Add mock results flag
746
+ self.mock_results = self.dataset_config.get("mock_results", False)
747
+ if self.mock_results:
748
+ logger.info("Mock results mode enabled - will generate dummy images")
639
749
 
640
- local_path = hf_hub_download(
750
+ def _get_shard_path(self, dataset_name: str, shard_filename: str) -> str:
751
+ """Get local path for a shard, downloading if needed."""
752
+ return hf_hub_download(
641
753
  repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
642
754
  )
643
755
 
644
- dataset = Dataset.from_parquet(local_path)
645
- self.shard_cache[shard_id] = dataset
646
-
647
- return dataset
648
-
649
756
  def _extract_filename_from_url(self, url: str) -> str:
650
757
  """Extract filename from HF URL format."""
651
758
  match = re.search(r"@[a-f0-9]+/(.+)$", url)
@@ -653,161 +760,227 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
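
The hunk below rewrites process_unit around selective Parquet reads: the needed local indices are mapped to row groups via the file metadata, and only those row groups are read. A self-contained sketch of that access pattern (the file, column, and indices are made up):

```python
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

# Illustrative shard: 10,000 rows split into row groups of 1,000.
path = os.path.join(tempfile.mkdtemp(), "shard.parquet")
pq.write_table(pa.table({"idx": list(range(10_000))}), path, row_group_size=1_000)

wanted = [5, 4_242, 4_250, 9_999]            # local indices we actually need

meta = pq.read_metadata(path)
groups, first_row = set(), 0
for rg in range(meta.num_row_groups):
    rows = meta.row_group(rg).num_rows
    if any(first_row <= i < first_row + rows for i in wanted):
        groups.add(rg)
    first_row += rows

table = pq.ParquetFile(path).read_row_groups(sorted(groups))
print(len(groups), table.num_rows)           # -> 3 row groups, 3000 rows instead of 10000
```
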
653
760
  return match.group(1)
654
761
  return url.split("/")[-1]
655
762
 
763
+ def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
764
+ """Create a dummy image"""
765
+ color = (0, 0, 0)
766
+ width, height = 128, 128
767
+ image = Image.new("RGB", (width, height), color=color)
768
+
769
+ return image
770
+
656
771
  def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
657
772
  """Process a work unit, yielding items to be captioned."""
658
- logger.debug("Processing unit: %s", unit.unit_id)
773
+ logger.debug("Processing unit: %s (mock_results=%s)", unit.unit_id, self.mock_results)
774
+ log_memory(f"start processing unit {unit.unit_id}")
659
775
 
660
776
  dataset_name = unit.data["dataset_name"]
661
- config = unit.data["config"]
662
- split = unit.data["split"]
663
777
  start_index = unit.data["start_index"]
664
778
  chunk_size = unit.data["chunk_size"]
665
779
  unprocessed_ranges = unit.data.get(
666
780
  "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
667
781
  )
668
782
  shard_ids = unit.data.get("shard_ids", [])
783
+ data_files = unit.data.get("data_files", [])
669
784
 
670
785
  logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
671
786
 
672
- # Need to get shard info - should be passed in unit data
673
- # For now, we'll need to load dataset builder to get file info
674
- from datasets import load_dataset_builder
675
-
676
- builder = load_dataset_builder(dataset_name, config)
677
-
678
- data_files = []
679
- if hasattr(builder.config, "data_files"):
680
- if isinstance(builder.config.data_files, dict):
681
- files = builder.config.data_files.get(split, [])
682
- if isinstance(files, str):
683
- files = [files]
684
- data_files = files
685
-
686
- # Build shard info
787
+ # Build shard info from provided data files (no dataset builder needed)
687
788
  shard_info = {}
688
- cumulative_offset = 0
689
-
690
- for i, file_url in enumerate(data_files):
691
- if i not in shard_ids:
692
- # Skip loading this shard, but we need its size for offsets
693
- # This is inefficient - in real implementation, orchestrator should pass this info
694
- filename = self._extract_filename_from_url(file_url)
695
- dataset = self._load_shard(dataset_name, filename, i)
696
- size = len(dataset)
697
- cumulative_offset += size
698
- continue
699
789
 
700
- filename = self._extract_filename_from_url(file_url)
701
- dataset = self._load_shard(dataset_name, filename, i)
790
+ if data_files:
791
+ # Use provided data files
792
+ for i, file_url in enumerate(data_files):
793
+ if i in shard_ids:
794
+ filename = self._extract_filename_from_url(file_url)
795
+ shard_path = self._get_shard_path(dataset_name, filename)
796
+
797
+ # Get size from metadata
798
+ metadata = pq.read_metadata(shard_path)
799
+ size = metadata.num_rows
800
+
801
+ shard_info[i] = {
802
+ "path": shard_path,
803
+ "start_offset": 0, # Will be set below
804
+ "end_offset": 0, # Will be set below
805
+ "size": size,
806
+ "metadata": metadata,
807
+ }
702
808
 
703
- shard_info[i] = {
704
- "dataset": dataset,
705
- "start_offset": cumulative_offset,
706
- "end_offset": cumulative_offset + len(dataset) - 1,
707
- "columns": dataset.column_names,
708
- }
709
- cumulative_offset += len(dataset)
809
+ # Calculate offsets
810
+ cumulative_offset = 0
811
+ for i in range(max(shard_info.keys()) + 1):
812
+ if i in shard_info:
813
+ shard_info[i]["start_offset"] = cumulative_offset
814
+ shard_info[i]["end_offset"] = cumulative_offset + shard_info[i]["size"] - 1
815
+ cumulative_offset += shard_info[i]["size"]
816
+ else:
817
+ # Need to get size for offset calculation
818
+ filename = self._extract_filename_from_url(data_files[i])
819
+ shard_path = self._get_shard_path(dataset_name, filename)
820
+ metadata = pq.read_metadata(shard_path)
821
+ cumulative_offset += metadata.num_rows
822
+ else:
823
+ # This should never happen with the new orchestrator
824
+ raise ValueError("No data files provided in work unit")
710
825
 
711
826
  # Create set of indices to process
712
827
  indices_to_process = set()
713
828
  for start, end in unprocessed_ranges:
714
829
  indices_to_process.update(range(start, end + 1))
715
830
 
716
- processed_indices = []
717
-
718
- # Process items
719
- for global_idx in sorted(indices_to_process):
720
- # Find which shard contains this index
721
- shard_id = None
722
- local_idx = None
723
-
724
- for sid, sinfo in shard_info.items():
831
+ # Group indices by shard
832
+ indices_by_shard = defaultdict(list)
833
+ for global_idx in indices_to_process:
834
+ for shard_id, sinfo in shard_info.items():
725
835
  if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
726
- shard_id = sid
727
836
  local_idx = global_idx - sinfo["start_offset"]
837
+ indices_by_shard[shard_id].append((global_idx, local_idx))
728
838
  break
729
839
 
730
- if shard_id is None:
731
- logger.warning(f"Could not find shard for global index {global_idx}")
732
- continue
733
-
734
- try:
735
- # Get item from shard
736
- item = shard_info[shard_id]["dataset"][local_idx]
737
-
738
- # Check if this is a URL dataset or binary image dataset
739
- image = None
740
- image_url = None
741
-
742
- # Try URL column first
743
- if self.url_column and self.url_column in item:
744
- image_url = item[self.url_column]
745
- # Download image from URL
746
- try:
747
- response = requests.get(image_url, timeout=30)
748
- response.raise_for_status()
749
- image = Image.open(io.BytesIO(response.content))
750
- except Exception as e:
751
- logger.error(f"Error downloading image from {image_url}: {e}")
752
- continue
753
-
754
- # Try binary image column
755
- elif self.image_column and self.image_column in item:
756
- image_data = item[self.image_column]
757
- if isinstance(image_data, Image.Image):
758
- image = image_data
759
- elif isinstance(image_data, dict) and "bytes" in image_data:
760
- # Handle datasets Image feature
761
- image = Image.open(io.BytesIO(image_data["bytes"]))
762
- elif isinstance(image_data, bytes):
763
- image = Image.open(io.BytesIO(image_data))
764
-
765
- if image is None:
766
- logger.warning(f"No image found for item at index {global_idx}")
767
- continue
768
-
769
- # Build job ID
770
- chunk_index = unit.metadata["chunk_index"]
771
- shard_name = unit.metadata["shard_name"]
772
- job_id_obj = JobId(
773
- shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
774
- )
775
- job_id = job_id_obj.get_sample_str()
776
-
777
- # Clean metadata
778
- clean_metadata = {
779
- k: v
780
- for k, v in item.items()
781
- if k not in [self.image_column, self.url_column] and not k.startswith("_")
782
- }
783
-
784
- clean_metadata.update(
785
- {
786
- "_item_index": global_idx,
787
- "_chunk_relative_index": global_idx - start_index,
788
- "_job_id": job_id,
789
- "_shard_id": shard_id,
790
- "_local_index": local_idx,
791
- "_url": image_url,
792
- }
793
- )
794
-
795
- yield {
796
- "image": image,
797
- "item_key": str(global_idx),
798
- "item_index": global_idx,
799
- "metadata": clean_metadata,
800
- "job_id": job_id,
801
- }
802
-
803
- processed_indices.append(global_idx)
840
+ processed_indices = []
804
841
 
805
- except Exception as e:
806
- logger.error(f"Error processing item at index {global_idx}: {e}")
842
+ # Process items shard by shard
843
+ for shard_id, idx_pairs in indices_by_shard.items():
844
+ shard_path = shard_info[shard_id]["path"]
845
+
846
+ # Process in batches to avoid loading entire table
847
+ batch_size = 100
848
+ for batch_start in range(0, len(idx_pairs), batch_size):
849
+ batch_pairs = idx_pairs[batch_start : batch_start + batch_size]
850
+ local_indices = [local_idx for _, local_idx in batch_pairs]
851
+
852
+ # Read only specific rows using PyArrow
853
+ try:
854
+ # Create row group filters based on metadata
855
+ metadata = shard_info[shard_id]["metadata"]
856
+ row_groups_to_read = set()
857
+
858
+ # Find which row groups contain our indices
859
+ current_row = 0
860
+ for rg_idx in range(metadata.num_row_groups):
861
+ rg_metadata = metadata.row_group(rg_idx)
862
+ rg_num_rows = rg_metadata.num_rows
863
+
864
+ # Check if any of our indices are in this row group
865
+ for local_idx in local_indices:
866
+ if current_row <= local_idx < current_row + rg_num_rows:
867
+ row_groups_to_read.add(rg_idx)
868
+
869
+ current_row += rg_num_rows
870
+
871
+ # Read only necessary row groups
872
+ parquet_file = pq.ParquetFile(shard_path)
873
+ table = parquet_file.read_row_groups(list(row_groups_to_read))
874
+
875
+ # Process items
876
+ for global_idx, local_idx in batch_pairs:
877
+ try:
878
+ # Get item as dictionary (efficient row extraction)
879
+ row_dict = table.slice(local_idx, 1).to_pydict()
880
+ item = {k: v[0] for k, v in row_dict.items()}
881
+
882
+ # Process image
883
+ image = None
884
+ image_url = None
885
+
886
+ if self.mock_results:
887
+ # In mock mode, create a dummy image
888
+ logger.debug(f"Creating mock image for index {global_idx}")
889
+
890
+ # Still extract URL if available for metadata
891
+ if self.url_column and self.url_column in item:
892
+ image_url = item[self.url_column]
893
+
894
+ # Create dummy image with metadata context
895
+ image = self._create_dummy_image(
896
+ global_idx,
897
+ {
898
+ "_shard_id": shard_id,
899
+ "_local_index": local_idx,
900
+ },
901
+ )
902
+ else:
903
+ # Normal processing - load real images
904
+ if self.url_column and self.url_column in item:
905
+ image_url = item[self.url_column]
906
+ try:
907
+ response = requests.get(image_url, timeout=30)
908
+ response.raise_for_status()
909
+ image = Image.open(io.BytesIO(response.content))
910
+ except Exception as e:
911
+ logger.error(
912
+ f"Error downloading image from {image_url}: {e}"
913
+ )
914
+ continue
915
+
916
+ elif self.image_column and self.image_column in item:
917
+ image_data = item[self.image_column]
918
+ if isinstance(image_data, dict) and "bytes" in image_data:
919
+ image = Image.open(io.BytesIO(image_data["bytes"]))
920
+ elif isinstance(image_data, bytes):
921
+ image = Image.open(io.BytesIO(image_data))
922
+
923
+ if image is None:
924
+ logger.warning(f"No image found for item at index {global_idx}")
925
+ continue
926
+
927
+ # Build job ID
928
+ chunk_index = unit.metadata["chunk_index"]
929
+ shard_name = unit.metadata["shard_name"]
930
+ job_id_obj = JobId(
931
+ shard_id=shard_name,
932
+ chunk_id=str(chunk_index),
933
+ sample_id=str(global_idx),
934
+ )
935
+ job_id = job_id_obj.get_sample_str()
936
+
937
+ # Clean metadata
938
+ clean_metadata = {
939
+ k: v
940
+ for k, v in item.items()
941
+ if k not in [self.image_column, self.url_column]
942
+ and not k.startswith("_")
943
+ }
944
+
945
+ clean_metadata.update(
946
+ {
947
+ "_item_index": global_idx,
948
+ "_chunk_relative_index": global_idx - start_index,
949
+ "_job_id": job_id,
950
+ "_shard_id": shard_id,
951
+ "_local_index": local_idx,
952
+ "_url": image_url,
953
+ "_mock": self.mock_results, # Add flag to indicate mock data
954
+ }
955
+ )
956
+
957
+ yield {
958
+ "image": image,
959
+ "item_key": str(global_idx),
960
+ "item_index": global_idx,
961
+ "metadata": clean_metadata,
962
+ "job_id": job_id,
963
+ "_processed_indices": processed_indices,
964
+ }
965
+
966
+ processed_indices.append(global_idx)
967
+
968
+ except Exception as e:
969
+ logger.error(f"Error processing item at index {global_idx}: {e}")
970
+
971
+ # Explicitly delete table to free memory
972
+ del table
973
+ gc.collect()
974
+
975
+ except Exception as e:
976
+ logger.error(f"Error reading batch from shard {shard_id}: {e}")
807
977
 
808
978
  # Store processed indices in context
809
979
  context["_processed_indices"] = processed_indices
810
- logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
980
+ logger.debug(
981
+ f"Processed {len(processed_indices)} indices for unit {unit.unit_id}: {processed_indices}, {context}"
982
+ )
983
+ log_memory(f"end processing unit {unit.unit_id}")
811
984
 
812
985
  def prepare_result(
813
986
  self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float