caption-flow 0.2.4__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,22 +1,24 @@
1
- """HuggingFace Datasets processor implementation."""
1
+ """HuggingFace Datasets processor implementation - Memory Optimized Version."""
2
2
 
3
3
  import logging
4
4
  import threading
5
5
  import re
6
+ import queue
6
7
  import requests
8
+ import json
9
+ import io
10
+ import os
11
+ import gc
12
+ import psutil
13
+ from concurrent.futures import ThreadPoolExecutor, Future
7
14
  from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
8
15
  from collections import deque, defaultdict
9
16
  from pathlib import Path
10
- import json
11
- import io
12
17
  from datetime import datetime
13
18
  from PIL import Image
14
- from datasets import (
15
- Dataset,
16
- get_dataset_config_names,
17
- get_dataset_split_names,
18
- load_dataset_builder,
19
- )
19
+ import pyarrow as pa
20
+ import pyarrow.parquet as pq
21
+ from datasets import get_dataset_config_names, get_dataset_split_names
20
22
  from huggingface_hub import hf_hub_download, get_token
21
23
  from caption_flow.storage import StorageManager
22
24
 
@@ -28,11 +30,82 @@ logger = logging.getLogger(__name__)
28
30
  logger.setLevel(logging.DEBUG)
29
31
 
30
32
 
33
+ def log_memory(location: str):
34
+ """Log memory usage at specific location."""
35
+ process = psutil.Process(os.getpid())
36
+ mem_info = process.memory_info()
37
+ logger.info(
38
+ f"Memory at {location}: RSS={mem_info.rss/1024/1024:.1f}MB, VMS={mem_info.vms/1024/1024:.1f}MB"
39
+ )
40
+ # Force garbage collection
41
+ gc.collect()
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+ logger.setLevel(logging.DEBUG)
46
+
47
+
48
+ class NonBlockingQueueHandler:
49
+ """Handles non-blocking retrieval from queues using concurrent futures."""
50
+
51
+ def __init__(self, max_workers: int = 1):
52
+ self.executor = ThreadPoolExecutor(max_workers=max_workers)
53
+ self.pending_futures: Dict[int, Future] = {} # queue_id -> Future
54
+
55
+ def get_from_queue_async(self, response_queue: queue.Queue, timeout: float = None) -> Future:
56
+ """Start an async queue retrieval."""
57
+ queue_id = id(response_queue)
58
+
59
+ # Check if we already have a pending future for this queue
60
+ if queue_id in self.pending_futures and not self.pending_futures[queue_id].done():
61
+ return self.pending_futures[queue_id]
62
+
63
+ # Start new async retrieval
64
+ future = self.executor.submit(response_queue.get, timeout=timeout)
65
+ self.pending_futures[queue_id] = future
66
+ return future
67
+
68
+ def check_response(self, response_queue: queue.Queue, timeout: float = None) -> Optional[Any]:
69
+ """Non-blocking check for queue response."""
70
+ queue_id = id(response_queue)
71
+
72
+ # Start async retrieval if needed
73
+ future = self.get_from_queue_async(response_queue, timeout)
74
+
75
+ # Check if result is ready (non-blocking)
76
+ if future.done():
77
+ try:
78
+ result = future.result(timeout=0)
79
+ # Clear future for next retrieval
80
+ if queue_id in self.pending_futures:
81
+ del self.pending_futures[queue_id]
82
+ return result
83
+ except queue.Empty:
84
+ # Queue was empty, clear future
85
+ if queue_id in self.pending_futures:
86
+ del self.pending_futures[queue_id]
87
+ return None
88
+ except Exception as e:
89
+ logger.error(f"Error retrieving from queue: {e}")
90
+ if queue_id in self.pending_futures:
91
+ del self.pending_futures[queue_id]
92
+ return None
93
+
94
+ # Result not ready yet
95
+ return None
96
+
97
+ def shutdown(self):
98
+ """Shutdown the executor."""
99
+ self.executor.shutdown(wait=True)
100
+
101
+
31
102
  class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
32
- """Orchestrator processor for HuggingFace datasets."""
103
+ """Memory-optimized orchestrator processor for HuggingFace datasets with non-blocking operations."""
33
104
 
34
105
  def __init__(self):
35
- logger.debug("Initializing HuggingFaceDatasetOrchestratorProcessor")
106
+ logger.debug(
107
+ "Initializing HuggingFaceDatasetOrchestratorProcessor (Optimized + Non-blocking)"
108
+ )
36
109
  self.dataset_name: Optional[str] = None
37
110
  self.config: Optional[str] = None
38
111
  self.split: Optional[str] = None
@@ -44,19 +117,33 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
44
117
  self.shard_info: Dict[int, Dict[str, Any]] = {}
45
118
  self.total_items: int = 0
46
119
 
47
- # Work unit management
48
- self.work_units: Dict[str, WorkUnit] = {}
120
+ # Work unit management - only store active units
49
121
  self.pending_units: Deque[str] = deque()
50
- self.assigned_units: Dict[str, Set[str]] = defaultdict(set) # worker_id -> unit_ids
122
+ self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
51
123
  self.lock = threading.Lock()
52
124
 
125
+ # Track current chunk index for on-demand creation
126
+ self.current_chunk_index = 0
127
+
128
+ # Cache data files info instead of loading builder repeatedly
129
+ self.data_files: List[str] = []
130
+
53
131
  # Background thread for creating work units
54
132
  self.unit_creation_thread: Optional[threading.Thread] = None
55
133
  self.stop_creation = threading.Event()
56
134
 
135
+ # Non-blocking queue handler
136
+ self.queue_handler = NonBlockingQueueHandler()
137
+
138
+ # Response processing state
139
+ self.last_maintenance_time = datetime.now()
140
+ self.maintenance_interval = 30 # seconds
141
+
57
142
  def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
58
143
  """Initialize HuggingFace dataset processor."""
59
144
  logger.debug("Initializing orchestrator with config: %s", config.config)
145
+ log_memory("start of initialize")
146
+
60
147
  cfg = config.config
61
148
 
62
149
  # Dataset configuration
@@ -83,12 +170,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
83
170
  self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
84
171
 
85
172
  # Initialize chunk tracking
86
- checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
87
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
88
- self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
173
+ self.checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
174
+ self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
175
+ self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")
89
176
 
90
- # Discover shards
91
- self._discover_shards()
177
+ # Discover shards (optimized)
178
+ self._discover_shards_optimized()
92
179
 
93
180
  # Restore existing state
94
181
  self._restore_state(storage=storage)
@@ -98,7 +185,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
98
185
  target=self._create_units_background, daemon=True
99
186
  )
100
187
  self.unit_creation_thread.start()
101
- logger.debug("Unit creation thread started")
188
+
189
+ log_memory("end of initialize")
102
190
 
103
191
  def _detect_config(self, provided_config: Optional[str]) -> str:
104
192
  """Auto-detect config if not provided."""
@@ -110,14 +198,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
110
198
  if not configs:
111
199
  return "default"
112
200
 
113
- # Prefer common config names
114
201
  preferred = ["default", "en", "train", "main"]
115
202
  for pref in preferred:
116
203
  if pref in configs:
117
204
  logger.info(f"Auto-selected config: {pref}")
118
205
  return pref
119
206
 
120
- # Otherwise use first available
121
207
  logger.info(f"Auto-selected first available config: {configs[0]}")
122
208
  return configs[0]
123
209
  except Exception as e:
@@ -134,17 +220,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
134
220
  self.dataset_name, config_name=self.config, token=self.token
135
221
  )
136
222
  if not splits:
137
- logger.warning("No splits found, using 'train'")
138
223
  return "train"
139
224
 
140
- # Prefer training splits
141
225
  preferred = ["train", "training", "test", "validation", "dev"]
142
226
  for pref in preferred:
143
227
  if pref in splits:
144
228
  logger.info(f"Auto-selected split: {pref}")
145
229
  return pref
146
230
 
147
- # Otherwise use first available
148
231
  logger.info(f"Auto-selected first available split: {splits[0]}")
149
232
  return splits[0]
150
233
  except Exception as e:
@@ -153,18 +236,16 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
153
236
 
154
237
  def _extract_filename_from_url(self, url: str) -> str:
155
238
  """Extract filename from HF URL format."""
156
- # Format: hf://datasets/user/dataset@hash/filename
157
239
  match = re.search(r"@[a-f0-9]+/(.+)$", url)
158
240
  if match:
159
241
  return match.group(1)
160
- # Fallback: just get last part
161
242
  return url.split("/")[-1]
162
243
 
163
- def _discover_shards(self):
164
- """Discover all shards and their sizes."""
165
- logger.info("Discovering shards...")
244
+ def _get_data_files_from_builder(self) -> List[str]:
245
+ """Get data files using dataset builder with minimal memory usage."""
246
+ # Load builder to get correct file structure
247
+ from datasets import load_dataset_builder
166
248
 
167
- # Load dataset builder to get file info
168
249
  builder = load_dataset_builder(self.dataset_name, self.config)
169
250
 
170
251
  # Get data files for our split
@@ -176,81 +257,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
176
257
  files = [files]
177
258
  data_files = files
178
259
 
179
- if not data_files:
260
+ # Explicitly delete builder to free memory
261
+ del builder
262
+ gc.collect()
263
+
264
+ return data_files
265
+
266
+ def _discover_shards_optimized(self):
267
+ """Discover all shards using dataset builder but release memory immediately."""
268
+ logger.info("Discovering shards...")
269
+
270
+ # Try to load cached shard info first
271
+ shard_info_cache_path = (
272
+ self.checkpoint_dir / f"{self.dataset_name}_{self.config}_{self.split}_shard_info.json"
273
+ )
274
+
275
+ if shard_info_cache_path.exists():
276
+ try:
277
+ with open(shard_info_cache_path, "r") as f:
278
+ cached_info = json.load(f)
279
+ if (
280
+ cached_info.get("dataset") == self.dataset_name
281
+ and cached_info.get("config") == self.config
282
+ and cached_info.get("split") == self.split
283
+ ):
284
+ self.shard_info = {int(k): v for k, v in cached_info["shards"].items()}
285
+ self.total_items = cached_info["total_items"]
286
+ self.data_files = cached_info.get("data_files", [])
287
+ logger.info(
288
+ f"Loaded cached shard info: {len(self.shard_info)} shards, {self.total_items} total items"
289
+ )
290
+ return
291
+ except Exception as e:
292
+ logger.warning(f"Failed to load cached shard info: {e}")
293
+
294
+ # Get data files using dataset builder
295
+ self.data_files = self._get_data_files_from_builder()
296
+
297
+ if not self.data_files:
180
298
  raise ValueError(f"No data files found for split '{self.split}'")
181
299
 
182
- logger.info(f"Found {len(data_files)} data files")
300
+ logger.info(f"Found {len(self.data_files)} data files")
183
301
 
184
- # Get info about each shard
302
+ # Get metadata for each shard
185
303
  cumulative_offset = 0
186
- for i, file_url in enumerate(data_files):
304
+ for i, file_url in enumerate(self.data_files):
187
305
  filename = self._extract_filename_from_url(file_url)
188
306
  logger.info(f"Discovering shard {i}: {filename}")
189
307
 
190
- # We don't download shards here - workers will do that
191
- # For now, store the info we have
192
- self.shard_info[i] = {
193
- "shard_id": i,
194
- "file_url": file_url,
195
- "filename": filename,
196
- "start_offset": cumulative_offset,
197
- # Size will be determined when first worker needs it
198
- "size": None,
199
- "end_offset": None,
200
- }
201
-
202
- # Try to get size from builder info if available
203
- if hasattr(builder.info, "splits") and self.split in builder.info.splits:
204
- split_info = builder.info.splits[self.split]
205
- if split_info.num_examples and len(data_files) == 1:
206
- # Single shard case
207
- self.shard_info[i]["size"] = split_info.num_examples
208
- self.shard_info[i]["end_offset"] = (
209
- cumulative_offset + split_info.num_examples - 1
210
- )
211
- cumulative_offset += split_info.num_examples
212
-
213
- # If we couldn't get sizes, we'll need to load shards on demand
214
- if self.shard_info[0]["size"] is None:
215
- logger.warning("Shard sizes not available from metadata, will load on demand")
216
- else:
217
- self.total_items = cumulative_offset
218
- logger.info(f"Total items across all shards: {self.total_items}")
219
-
220
- def _get_shard_size(self, shard_id: int) -> int:
221
- """Get size of a shard, loading it if necessary."""
222
- if self.shard_info[shard_id]["size"] is not None:
223
- return self.shard_info[shard_id]["size"]
308
+ try:
309
+ # Download file to get metadata
310
+ local_path = hf_hub_download(
311
+ repo_id=self.dataset_name,
312
+ filename=filename,
313
+ repo_type="dataset",
314
+ token=self.token,
315
+ )
224
316
 
225
- # Need to load the shard to get its size
226
- logger.info(f"Loading shard {shard_id} to determine size...")
227
- filename = self.shard_info[shard_id]["filename"]
317
+ # Read only metadata
318
+ metadata = pq.read_metadata(local_path)
319
+ size = metadata.num_rows
320
+
321
+ self.shard_info[i] = {
322
+ "shard_id": i,
323
+ "file_url": file_url,
324
+ "filename": filename,
325
+ "start_offset": cumulative_offset,
326
+ "size": size,
327
+ "end_offset": cumulative_offset + size - 1,
328
+ }
228
329
 
229
- local_path = hf_hub_download(
230
- repo_id=self.dataset_name, filename=filename, repo_type="dataset", token=self.token
231
- )
330
+ cumulative_offset += size
331
+ logger.info(f"Shard {i} ({filename}): {size} rows")
232
332
 
233
- # Load just to get size
234
- dataset = Dataset.from_parquet(local_path)
235
- size = len(dataset)
333
+ except Exception as e:
334
+ logger.error(f"Failed to discover shard {i}: {e}")
335
+ # Skip this shard
336
+ continue
236
337
 
237
- # Update shard info
238
- self.shard_info[shard_id]["size"] = size
338
+ self.total_items = cumulative_offset
339
+ logger.info(f"Total items across all shards: {self.total_items}")
239
340
 
240
- # Update offsets for this and subsequent shards
241
- for sid in range(shard_id, len(self.shard_info)):
242
- if sid > shard_id:
243
- self.shard_info[sid]["start_offset"] = self.shard_info[sid - 1]["end_offset"] + 1
244
- self.shard_info[sid]["end_offset"] = (
245
- self.shard_info[sid]["start_offset"] + self.shard_info[sid]["size"] - 1
246
- )
341
+ # Cache shard info
342
+ try:
343
+ cache_data = {
344
+ "dataset": self.dataset_name,
345
+ "config": self.config,
346
+ "split": self.split,
347
+ "shards": self.shard_info,
348
+ "total_items": self.total_items,
349
+ "data_files": self.data_files,
350
+ }
351
+ with open(shard_info_cache_path, "w") as f:
352
+ json.dump(cache_data, f)
353
+ logger.info(f"Cached shard info to {shard_info_cache_path}")
354
+ except Exception as e:
355
+ logger.warning(f"Failed to cache shard info: {e}")
247
356
 
248
- # Update total items
249
- if all(s["size"] is not None for s in self.shard_info.values()):
250
- self.total_items = sum(s["size"] for s in self.shard_info.values())
251
- logger.info(f"Total items: {self.total_items}")
357
+ # Force garbage collection
358
+ gc.collect()
359
+ log_memory("after discovering shards")
252
360
 
253
- return size
361
+ def _get_shard_for_index(self, global_index: int) -> Tuple[int, int]:
362
+ """Get shard ID and local index for a global index."""
363
+ for shard_id, sinfo in self.shard_info.items():
364
+ if sinfo["start_offset"] <= global_index <= sinfo["end_offset"]:
365
+ local_index = global_index - sinfo["start_offset"]
366
+ return shard_id, local_index
367
+ raise ValueError(f"Global index {global_index} not found in any shard")
254
368
 
255
369
  def _restore_state(self, storage: StorageManager) -> None:
256
370
  """Restore state from chunk tracker."""
@@ -258,73 +372,83 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
258
372
  if not self.chunk_tracker:
259
373
  return
260
374
 
261
- all_processed_jobs = storage.get_all_processed_job_ids()
262
-
263
375
  with self.lock:
376
+ max_chunk_index = -1
377
+
264
378
  for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
265
- # Calculate actual unprocessed ranges
266
- chunk_range = (
267
- chunk_state.start_index,
268
- chunk_state.start_index + chunk_state.chunk_size - 1,
269
- )
379
+ chunk_index = chunk_state.start_index // self.chunk_size
380
+ max_chunk_index = max(max_chunk_index, chunk_index)
381
+
382
+ # Only add incomplete chunks to pending
383
+ if chunk_state.status != "completed":
384
+ self.pending_units.append(chunk_id)
385
+ elif chunk_state.status == "completed" and chunk_state.processed_ranges:
386
+ logger.warning(
387
+ f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
388
+ )
389
+
390
+ self.current_chunk_index = max_chunk_index + 1
391
+ logger.info(f"Resuming from chunk index {self.current_chunk_index}")
270
392
 
271
- # Get processed indices for this chunk
272
- processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
273
- chunk_id, all_processed_jobs
393
+ def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
394
+ """Create a single work unit for a chunk index."""
395
+ current_index = chunk_index * self.chunk_size
396
+
397
+ if current_index >= self.total_items:
398
+ return None
399
+
400
+ chunk_size = min(self.chunk_size, self.total_items - current_index)
401
+
402
+ # Find shard for this chunk
403
+ shard_id, _ = self._get_shard_for_index(current_index)
404
+ shard_name = Path(self.shard_info[shard_id]["filename"]).stem
405
+
406
+ job_id_obj = JobId(shard_id=shard_name, chunk_id=chunk_index, sample_id=current_index)
407
+ unit_id = job_id_obj.get_chunk_str()
408
+
409
+ # Calculate unprocessed ranges based on existing chunk state
410
+ unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
411
+
412
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
413
+ chunk_state = self.chunk_tracker.chunks[unit_id]
414
+ if chunk_state.processed_ranges:
415
+ # Subtract processed ranges from total range
416
+ unprocessed_ranges = self._subtract_ranges(
417
+ [(current_index, current_index + chunk_size - 1)], chunk_state.processed_ranges
274
418
  )
275
419
 
276
- # Calculate unprocessed ranges
277
- unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
278
-
279
- if unprocessed_ranges:
280
- # Find which shard(s) this chunk belongs to
281
- shard_ids = []
282
- for sid, sinfo in self.shard_info.items():
283
- # Need size to check
284
- if sinfo["size"] is None:
285
- self._get_shard_size(sid)
286
-
287
- if (
288
- sinfo["start_offset"]
289
- <= chunk_state.start_index + chunk_state.chunk_size - 1
290
- and sinfo["end_offset"] >= chunk_state.start_index
291
- ):
292
- shard_ids.append(sid)
293
- logger.info(f"Found shard {sid} for chunk {chunk_id}: {sinfo}")
294
-
295
- chunk_index = chunk_state.start_index // self.chunk_size
296
- shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
297
- unit = WorkUnit(
298
- unit_id=chunk_id,
299
- chunk_id=chunk_id,
300
- source_id=shard_name,
301
- data={
302
- "dataset_name": self.dataset_name,
303
- "config": self.config,
304
- "split": self.split,
305
- "start_index": chunk_state.start_index,
306
- "chunk_size": chunk_state.chunk_size,
307
- "unprocessed_ranges": unprocessed_ranges,
308
- "shard_ids": shard_ids,
309
- },
310
- metadata={
311
- "dataset": self.dataset_name,
312
- "shard_name": shard_name,
313
- "chunk_index": chunk_index,
314
- },
315
- )
420
+ # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
421
+ if not unprocessed_ranges:
422
+ return None
316
423
 
317
- self.work_units[unit.unit_id] = unit
318
- self.pending_units.append(unit.unit_id)
424
+ unit = WorkUnit(
425
+ unit_id=unit_id,
426
+ chunk_id=unit_id,
427
+ source_id=shard_name,
428
+ data={
429
+ "dataset_name": self.dataset_name,
430
+ "config": self.config,
431
+ "split": self.split,
432
+ "start_index": current_index,
433
+ "chunk_size": chunk_size,
434
+ "unprocessed_ranges": unprocessed_ranges, # Use calculated ranges
435
+ "shard_ids": [shard_id],
436
+ "data_files": self.data_files,
437
+ },
438
+ metadata={
439
+ "dataset": self.dataset_name,
440
+ "shard_name": shard_name,
441
+ "chunk_index": chunk_index,
442
+ },
443
+ )
444
+
445
+ return unit
319
446
 
320
447
  def _create_units_background(self) -> None:
321
448
  """Background thread to create work units on demand."""
322
449
  logger.info("Starting work unit creation thread")
323
450
 
324
- current_index = 0
325
-
326
451
  while not self.stop_creation.is_set():
327
- # Check if we need more units
328
452
  with self.lock:
329
453
  pending_count = len(self.pending_units)
330
454
  assigned_count = sum(len(units) for units in self.assigned_units.values())
@@ -337,127 +461,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
337
461
  threading.Event().wait(5)
338
462
  continue
339
463
 
340
- # Make sure we know total items
341
- if self.total_items == 0:
342
- # Load all shard sizes
343
- for sid in range(len(self.shard_info)):
344
- self._get_shard_size(sid)
345
-
346
464
  # Create units as needed
347
465
  units_created = 0
348
466
 
349
- while units_created < units_needed and current_index < self.total_items:
350
- chunk_size = min(self.chunk_size, self.total_items - current_index)
351
- chunk_id = current_index // self.chunk_size
352
-
353
- with self.lock:
354
- shard_ids = []
355
- for sid, sinfo in self.shard_info.items():
356
- if (
357
- sinfo["start_offset"] <= current_index + chunk_size - 1
358
- and sinfo["end_offset"] >= current_index
359
- ):
360
- shard_ids.append(sid)
361
- shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
467
+ while units_created < units_needed:
468
+ logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
469
+ if self.current_chunk_index * self.chunk_size >= self.total_items:
470
+ threading.Event().wait(30)
471
+ break
472
+ # Get shard info for proper unit_id
473
+ current_index = self.current_chunk_index * self.chunk_size
474
+ if current_index < self.total_items:
475
+ shard_id, _ = self._get_shard_for_index(current_index)
476
+ shard_name = Path(self.shard_info[shard_id]["filename"]).stem
362
477
 
363
478
  job_id_obj = JobId(
364
- shard_id=shard_name, chunk_id=chunk_id, sample_id=current_index
479
+ shard_id=shard_name,
480
+ chunk_id=self.current_chunk_index,
481
+ sample_id=current_index,
365
482
  )
366
- unit_id = (
367
- job_id_obj.get_chunk_str()
368
- ) # just the chunk part, eg pixel-images:chunk:0
369
- if unit_id in self.work_units:
370
- current_index += self.chunk_size
371
- continue
372
-
373
- # Check if chunk is already completed
374
- if self.chunk_tracker:
375
- chunk_state = self.chunk_tracker.chunks.get(unit_id)
376
- if chunk_state and chunk_state.status == "completed":
377
- current_index += self.chunk_size
378
- continue
483
+ unit_id = job_id_obj.get_chunk_str()
379
484
 
380
- # Find which shard(s) this chunk belongs to
381
-
382
- unit = WorkUnit(
383
- unit_id=unit_id,
384
- chunk_id=unit_id,
385
- source_id=shard_name,
386
- data={
387
- "dataset_name": self.dataset_name,
388
- "config": self.config,
389
- "split": self.split,
390
- "start_index": current_index,
391
- "chunk_size": chunk_size,
392
- "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
393
- "shard_ids": shard_ids,
394
- },
395
- metadata={
396
- "dataset": self.dataset_name,
397
- "shard_name": shard_name,
398
- "chunk_index": chunk_id,
399
- },
400
- )
401
- logger.debug(f"Created WorkUnit: {unit}")
485
+ with self.lock:
486
+ # Check if already tracked
487
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
488
+ chunk_state = self.chunk_tracker.chunks[unit_id]
489
+ if chunk_state.status == "completed":
490
+ self.current_chunk_index += 1
491
+ continue
402
492
 
403
- self.work_units[unit_id] = unit
493
+ # Add to pending
404
494
  self.pending_units.append(unit_id)
405
495
 
496
+ # Track in chunk tracker
406
497
  if self.chunk_tracker:
498
+ start_index = self.current_chunk_index * self.chunk_size
499
+ chunk_size = min(self.chunk_size, self.total_items - start_index)
407
500
  self.chunk_tracker.add_chunk(
408
501
  unit_id,
409
502
  self.dataset_name,
410
- "", # No shard URL
411
- current_index,
503
+ "",
504
+ start_index,
412
505
  chunk_size,
413
506
  )
414
507
 
415
508
  units_created += 1
416
-
417
- current_index += self.chunk_size
509
+ self.current_chunk_index += 1
418
510
 
419
511
  if units_created > 0:
420
- logger.debug(f"Created {units_created} work units")
421
-
422
- def _subtract_ranges(
423
- self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
424
- ) -> List[Tuple[int, int]]:
425
- """Subtract processed ranges from total ranges."""
426
- if not processed_ranges:
427
- return total_ranges
428
-
429
- # Create a set of all processed indices
430
- processed_indices = set()
431
- for start, end in processed_ranges:
432
- processed_indices.update(range(start, end + 1))
433
-
434
- # Find unprocessed ranges
435
- unprocessed_ranges = []
436
- for start, end in total_ranges:
437
- current_start = None
438
- for i in range(start, end + 1):
439
- if i not in processed_indices:
440
- if current_start is None:
441
- current_start = i
442
- else:
443
- if current_start is not None:
444
- unprocessed_ranges.append((current_start, i - 1))
445
- current_start = None
512
+ logger.debug(f"Created {units_created} work unit IDs")
513
+
514
+ logger.info("Thread for creating units has completed. Exiting thread.")
515
+
516
+ def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
517
+ """
518
+ Non-blocking method to process responses from workers.
519
+ Returns a WorkResult if one is available, None otherwise.
520
+ """
521
+ # Check for response without blocking
522
+ response = self.queue_handler.check_response(response_queue, timeout=0.1)
523
+
524
+ if response is not None:
525
+ # Process the response
526
+ if isinstance(response, WorkResult):
527
+ logger.debug(f"Processing response for unit {response.unit_id}")
528
+ return response
529
+ else:
530
+ logger.warning(f"Unexpected response type: {type(response)}")
531
+
532
+ # Perform periodic maintenance tasks
533
+ now = datetime.now()
534
+ if (now - self.last_maintenance_time).total_seconds() > self.maintenance_interval:
535
+ self._perform_maintenance()
536
+ self.last_maintenance_time = now
537
+
538
+ return None
539
+
540
+ def _perform_maintenance(self):
541
+ """Perform periodic maintenance tasks."""
542
+ with self.lock:
543
+ # Log current state
544
+ pending_count = len(self.pending_units)
545
+ assigned_count = sum(len(units) for units in self.assigned_units.values())
546
+ logger.debug(f"Maintenance: {pending_count} pending, {assigned_count} assigned units")
446
547
 
447
- if current_start is not None:
448
- unprocessed_ranges.append((current_start, end))
548
+ # Check for stale assignments (workers that might have disconnected)
549
+ # This would be implemented based on your worker heartbeat mechanism
449
550
 
450
- return unprocessed_ranges
551
+ # Force checkpoint save if needed
552
+ if self.chunk_tracker:
553
+ self.chunk_tracker.save_checkpoint()
451
554
 
452
555
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
453
556
  """Get available work units for a worker."""
454
- logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
455
- assigned = []
456
557
 
558
+ logger.debug(
559
+ "get_work_units called: count=%d worker_id=%s, pending: %d",
560
+ count,
561
+ worker_id,
562
+ len(self.pending_units),
563
+ )
564
+ assigned = []
457
565
  with self.lock:
458
566
  while len(assigned) < count and self.pending_units:
459
567
  unit_id = self.pending_units.popleft()
460
- unit = self.work_units.get(unit_id)
568
+
569
+ # Create work unit on demand
570
+ chunk_index = int(unit_id.split(":")[-1])
571
+ unit = self._create_work_unit(chunk_index)
461
572
 
462
573
  if unit:
463
574
  self.assigned_units[worker_id].add(unit_id)
@@ -474,22 +585,26 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
474
585
  """Mark a work unit as completed."""
475
586
  logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
476
587
  with self.lock:
477
- if unit_id in self.work_units:
478
- self.assigned_units[worker_id].discard(unit_id)
588
+ self.assigned_units[worker_id].discard(unit_id)
479
589
 
480
- if self.chunk_tracker:
481
- self.chunk_tracker.mark_completed(unit_id)
590
+ if self.chunk_tracker:
591
+ self.chunk_tracker.mark_completed(unit_id)
592
+
593
+ # remove from pending deque if it's there.
594
+ try:
595
+ self.pending_units.remove(unit_id)
596
+ except:
597
+ pass
482
598
 
483
599
  def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
484
600
  """Mark a work unit as failed."""
485
- logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
601
+ logger.error("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
486
602
  with self.lock:
487
- if unit_id in self.work_units:
488
- self.assigned_units[worker_id].discard(unit_id)
489
- self.pending_units.append(unit_id)
603
+ self.assigned_units[worker_id].discard(unit_id)
604
+ self.pending_units.append(unit_id)
490
605
 
491
- if self.chunk_tracker:
492
- self.chunk_tracker.mark_failed(unit_id)
606
+ if self.chunk_tracker:
607
+ self.chunk_tracker.mark_failed(unit_id)
493
608
 
494
609
  def release_assignments(self, worker_id: str) -> None:
495
610
  """Release all assignments for a disconnected worker."""
@@ -498,8 +613,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
498
613
  unit_ids = list(self.assigned_units.get(worker_id, []))
499
614
 
500
615
  for unit_id in unit_ids:
501
- if unit_id in self.work_units:
502
- self.pending_units.append(unit_id)
616
+ logger.debug(f"Adding {unit_id} to pending queue")
617
+ self.pending_units.append(unit_id)
503
618
 
504
619
  if worker_id in self.assigned_units:
505
620
  del self.assigned_units[worker_id]
@@ -509,57 +624,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
509
624
 
510
625
  def update_from_storage(self, processed_job_ids: Set[str]) -> None:
511
626
  """Update work units based on what's been processed."""
512
- logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
513
-
514
- with self.lock:
515
- for unit_id, unit in self.work_units.items():
516
- # Extract chunk info from unit
517
- logger.debug(f"Checking unit {unit_id} for updates")
518
- logger.debug(f"Unit data: {unit.data}")
519
- logger.debug(f"Unit metadata: {unit.metadata}")
520
- start_index = unit.data["start_index"]
521
- chunk_size = unit.data["chunk_size"]
522
- shard_name = unit.metadata["shard_name"]
523
- chunk_index = unit.metadata["chunk_index"]
524
-
525
- # Find processed indices for this chunk
526
- processed_indices = []
527
- for job_id in processed_job_ids:
528
- # Parse job_id format: "data-0000:chunk:0:idx:42"
529
- job_id = JobId.from_str(job_id=job_id)
530
- if job_id.shard_id == shard_name and int(job_id.chunk_id) == chunk_index:
531
- idx = int(job_id.sample_id)
532
- if start_index <= idx < start_index + chunk_size:
533
- processed_indices.append(idx)
534
-
535
- if processed_indices:
536
- # Convert to ranges
537
- processed_indices.sort()
538
- processed_ranges = []
539
- start = processed_indices[0]
540
- end = processed_indices[0]
541
-
542
- for idx in processed_indices[1:]:
543
- if idx == end + 1:
544
- end = idx
545
- else:
546
- processed_ranges.append((start, end))
547
- start = idx
548
- end = idx
549
-
550
- processed_ranges.append((start, end))
551
-
552
- # Calculate unprocessed ranges
553
- total_range = [(start_index, start_index + chunk_size - 1)]
554
- unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
555
-
556
- # Update unit
557
- unit.data["unprocessed_ranges"] = unprocessed_ranges
558
-
559
- logger.debug(
560
- f"Updated unit {unit_id}: {len(processed_indices)} processed, "
561
- f"unprocessed ranges: {unprocessed_ranges}"
562
- )
627
+ logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
628
+ # No need to update in-memory work units since we create on demand
563
629
 
564
630
  def get_stats(self) -> Dict[str, Any]:
565
631
  """Get processor statistics."""
@@ -568,12 +634,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
568
634
  "dataset": self.dataset_name,
569
635
  "config": self.config,
570
636
  "split": self.split,
571
- "total_units": len(self.work_units),
572
637
  "pending_units": len(self.pending_units),
573
638
  "assigned_units": sum(len(units) for units in self.assigned_units.values()),
574
639
  "total_shards": len(self.shard_info),
575
640
  "total_items": self.total_items,
576
641
  "workers": len(self.assigned_units),
642
+ "current_chunk_index": self.current_chunk_index,
577
643
  }
578
644
  return stats
579
645
 
@@ -581,71 +647,111 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
581
647
  """Handle result processing."""
582
648
  base_result = super().handle_result(result)
583
649
 
584
- # Track processed items
585
650
  if self.chunk_tracker:
586
- if "item_indices" not in result.metadata:
587
- result.metadata["item_indices"] = [result.metadata.get("_item_index")]
588
- indices = result.metadata["item_indices"]
651
+ if "item_indices" in result.metadata:
652
+ indices = result.metadata["item_indices"]
653
+ if indices:
654
+ # Convert to ranges for efficient tracking
655
+ indices.sort()
656
+ ranges = []
657
+ start = indices[0]
658
+ end = indices[0]
659
+
660
+ for i in range(1, len(indices)):
661
+ if indices[i] == end + 1:
662
+ end = indices[i]
663
+ else:
664
+ ranges.append((start, end))
665
+ start = indices[i]
666
+ end = indices[i]
589
667
 
590
- if indices:
591
- indices.sort()
592
- ranges = []
593
- start = indices[0]
594
- end = indices[0]
668
+ ranges.append((start, end))
595
669
 
596
- for i in range(1, len(indices)):
597
- if indices[i] == end + 1:
598
- end = indices[i]
599
- else:
600
- ranges.append((start, end))
601
- start = indices[i]
602
- end = indices[i]
670
+ for start_idx, end_idx in ranges:
671
+ self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
603
672
 
604
- ranges.append((start, end))
673
+ return base_result
605
674
 
606
- for start_idx, end_idx in ranges:
607
- self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
675
+ def _subtract_ranges(
676
+ self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
677
+ ) -> List[Tuple[int, int]]:
678
+ """Subtract processed ranges from total ranges."""
679
+ if not processed_ranges:
680
+ return total_ranges
608
681
 
609
- return base_result
682
+ # Create a set of all processed indices
683
+ processed_indices = set()
684
+ for start, end in processed_ranges:
685
+ processed_indices.update(range(start, end + 1))
686
+
687
+ # Find unprocessed ranges
688
+ unprocessed_ranges = []
689
+ for start, end in total_ranges:
690
+ current_start = None
691
+ for i in range(start, end + 1):
692
+ if i not in processed_indices:
693
+ if current_start is None:
694
+ current_start = i
695
+ else:
696
+ if current_start is not None:
697
+ unprocessed_ranges.append((current_start, i - 1))
698
+ current_start = None
699
+
700
+ if current_start is not None:
701
+ unprocessed_ranges.append((current_start, end))
702
+
703
+ return unprocessed_ranges
704
+
705
+ def cleanup(self):
706
+ """Clean up resources."""
707
+ logger.info("Cleaning up orchestrator resources")
708
+
709
+ # Stop background threads
710
+ self.stop_creation.set()
711
+ if self.unit_creation_thread:
712
+ self.unit_creation_thread.join(timeout=5)
713
+
714
+ # Shutdown queue handler
715
+ self.queue_handler.shutdown()
716
+
717
+ # Save final state
718
+ if self.chunk_tracker:
719
+ self.chunk_tracker.save_checkpoint()
610
720
 
611
721
 
612
722
  class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
613
- """Worker processor for HuggingFace datasets."""
723
+ """Memory-optimized worker processor for HuggingFace datasets."""
614
724
 
615
725
  def __init__(self):
616
- logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
726
+ logger.debug("Initializing HuggingFaceDatasetWorkerProcessor (Optimized)")
617
727
  self.dataset_config: Dict[str, Any] = {}
618
728
  self.token = get_token()
619
- self.shard_cache: Dict[int, Dataset] = {} # Cache loaded shards
620
729
  self.image_column: Optional[str] = None
621
730
  self.url_column: Optional[str] = None
622
731
 
732
+ # Thread-local storage for shard info to avoid repeated builder loading
733
+ self._thread_local = threading.local()
734
+
623
735
  def initialize(self, config: ProcessorConfig) -> None:
624
736
  """Initialize processor."""
625
737
  logger.debug("Initializing worker with config: %s", config.config)
626
738
  self.dataset_config = config.config.get("dataset", {})
627
739
 
628
- # Determine if this is an image URL dataset or binary image dataset
629
740
  self.image_column = self.dataset_config.get("dataset_image_column", "image")
630
741
  self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
631
742
  self.dataset_path = self.dataset_config.get("dataset_path", None)
632
743
 
633
- def _load_shard(self, dataset_name: str, shard_filename: str, shard_id: int) -> Dataset:
634
- """Load a shard if not already cached."""
635
- if shard_id in self.shard_cache:
636
- return self.shard_cache[shard_id]
637
-
638
- logger.info(f"Loading shard {shard_id}: {shard_filename}")
744
+ # Add mock results flag
745
+ self.mock_results = self.dataset_config.get("mock_results", False)
746
+ if self.mock_results:
747
+ logger.info("Mock results mode enabled - will generate dummy images")
639
748
 
640
- local_path = hf_hub_download(
749
+ def _get_shard_path(self, dataset_name: str, shard_filename: str) -> str:
750
+ """Get local path for a shard, downloading if needed."""
751
+ return hf_hub_download(
641
752
  repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
642
753
  )
643
754
 
644
- dataset = Dataset.from_parquet(local_path)
645
- self.shard_cache[shard_id] = dataset
646
-
647
- return dataset
648
-
649
755
  def _extract_filename_from_url(self, url: str) -> str:
650
756
  """Extract filename from HF URL format."""
651
757
  match = re.search(r"@[a-f0-9]+/(.+)$", url)
@@ -653,161 +759,227 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
653
759
  return match.group(1)
654
760
  return url.split("/")[-1]
655
761
 
762
+ def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
763
+ """Create a dummy image"""
764
+ color = (0, 0, 0)
765
+ width, height = 128, 128
766
+ image = Image.new("RGB", (width, height), color=color)
767
+
768
+ return image
769
+
656
770
  def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
657
771
  """Process a work unit, yielding items to be captioned."""
658
- logger.debug("Processing unit: %s", unit.unit_id)
772
+ logger.debug("Processing unit: %s (mock_results=%s)", unit.unit_id, self.mock_results)
773
+ log_memory(f"start processing unit {unit.unit_id}")
659
774
 
660
775
  dataset_name = unit.data["dataset_name"]
661
- config = unit.data["config"]
662
- split = unit.data["split"]
663
776
  start_index = unit.data["start_index"]
664
777
  chunk_size = unit.data["chunk_size"]
665
778
  unprocessed_ranges = unit.data.get(
666
779
  "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
667
780
  )
668
781
  shard_ids = unit.data.get("shard_ids", [])
782
+ data_files = unit.data.get("data_files", [])
669
783
 
670
784
  logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
671
785
 
672
- # Need to get shard info - should be passed in unit data
673
- # For now, we'll need to load dataset builder to get file info
674
- from datasets import load_dataset_builder
675
-
676
- builder = load_dataset_builder(dataset_name, config)
677
-
678
- data_files = []
679
- if hasattr(builder.config, "data_files"):
680
- if isinstance(builder.config.data_files, dict):
681
- files = builder.config.data_files.get(split, [])
682
- if isinstance(files, str):
683
- files = [files]
684
- data_files = files
685
-
686
- # Build shard info
786
+ # Build shard info from provided data files (no dataset builder needed)
687
787
  shard_info = {}
688
- cumulative_offset = 0
689
-
690
- for i, file_url in enumerate(data_files):
691
- if i not in shard_ids:
692
- # Skip loading this shard, but we need its size for offsets
693
- # This is inefficient - in real implementation, orchestrator should pass this info
694
- filename = self._extract_filename_from_url(file_url)
695
- dataset = self._load_shard(dataset_name, filename, i)
696
- size = len(dataset)
697
- cumulative_offset += size
698
- continue
699
788
 
700
- filename = self._extract_filename_from_url(file_url)
701
- dataset = self._load_shard(dataset_name, filename, i)
789
+ if data_files:
790
+ # Use provided data files
791
+ for i, file_url in enumerate(data_files):
792
+ if i in shard_ids:
793
+ filename = self._extract_filename_from_url(file_url)
794
+ shard_path = self._get_shard_path(dataset_name, filename)
795
+
796
+ # Get size from metadata
797
+ metadata = pq.read_metadata(shard_path)
798
+ size = metadata.num_rows
799
+
800
+ shard_info[i] = {
801
+ "path": shard_path,
802
+ "start_offset": 0, # Will be set below
803
+ "end_offset": 0, # Will be set below
804
+ "size": size,
805
+ "metadata": metadata,
806
+ }
702
807
 
703
- shard_info[i] = {
704
- "dataset": dataset,
705
- "start_offset": cumulative_offset,
706
- "end_offset": cumulative_offset + len(dataset) - 1,
707
- "columns": dataset.column_names,
708
- }
709
- cumulative_offset += len(dataset)
808
+ # Calculate offsets
809
+ cumulative_offset = 0
810
+ for i in range(max(shard_info.keys()) + 1):
811
+ if i in shard_info:
812
+ shard_info[i]["start_offset"] = cumulative_offset
813
+ shard_info[i]["end_offset"] = cumulative_offset + shard_info[i]["size"] - 1
814
+ cumulative_offset += shard_info[i]["size"]
815
+ else:
816
+ # Need to get size for offset calculation
817
+ filename = self._extract_filename_from_url(data_files[i])
818
+ shard_path = self._get_shard_path(dataset_name, filename)
819
+ metadata = pq.read_metadata(shard_path)
820
+ cumulative_offset += metadata.num_rows
821
+ else:
822
+ # This should never happen with the new orchestrator
823
+ raise ValueError("No data files provided in work unit")
710
824
 
711
825
  # Create set of indices to process
712
826
  indices_to_process = set()
713
827
  for start, end in unprocessed_ranges:
714
828
  indices_to_process.update(range(start, end + 1))
715
829
 
716
- processed_indices = []
717
-
718
- # Process items
719
- for global_idx in sorted(indices_to_process):
720
- # Find which shard contains this index
721
- shard_id = None
722
- local_idx = None
723
-
724
- for sid, sinfo in shard_info.items():
830
+ # Group indices by shard
831
+ indices_by_shard = defaultdict(list)
832
+ for global_idx in indices_to_process:
833
+ for shard_id, sinfo in shard_info.items():
725
834
  if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
726
- shard_id = sid
727
835
  local_idx = global_idx - sinfo["start_offset"]
836
+ indices_by_shard[shard_id].append((global_idx, local_idx))
728
837
  break
729
838
 
730
- if shard_id is None:
731
- logger.warning(f"Could not find shard for global index {global_idx}")
732
- continue
733
-
734
- try:
735
- # Get item from shard
736
- item = shard_info[shard_id]["dataset"][local_idx]
737
-
738
- # Check if this is a URL dataset or binary image dataset
739
- image = None
740
- image_url = None
741
-
742
- # Try URL column first
743
- if self.url_column and self.url_column in item:
744
- image_url = item[self.url_column]
745
- # Download image from URL
746
- try:
747
- response = requests.get(image_url, timeout=30)
748
- response.raise_for_status()
749
- image = Image.open(io.BytesIO(response.content))
750
- except Exception as e:
751
- logger.error(f"Error downloading image from {image_url}: {e}")
752
- continue
753
-
754
- # Try binary image column
755
- elif self.image_column and self.image_column in item:
756
- image_data = item[self.image_column]
757
- if isinstance(image_data, Image.Image):
758
- image = image_data
759
- elif isinstance(image_data, dict) and "bytes" in image_data:
760
- # Handle datasets Image feature
761
- image = Image.open(io.BytesIO(image_data["bytes"]))
762
- elif isinstance(image_data, bytes):
763
- image = Image.open(io.BytesIO(image_data))
764
-
765
- if image is None:
766
- logger.warning(f"No image found for item at index {global_idx}")
767
- continue
768
-
769
- # Build job ID
770
- chunk_index = unit.metadata["chunk_index"]
771
- shard_name = unit.metadata["shard_name"]
772
- job_id_obj = JobId(
773
- shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
774
- )
775
- job_id = job_id_obj.get_sample_str()
776
-
777
- # Clean metadata
778
- clean_metadata = {
779
- k: v
780
- for k, v in item.items()
781
- if k not in [self.image_column, self.url_column] and not k.startswith("_")
782
- }
783
-
784
- clean_metadata.update(
785
- {
786
- "_item_index": global_idx,
787
- "_chunk_relative_index": global_idx - start_index,
788
- "_job_id": job_id,
789
- "_shard_id": shard_id,
790
- "_local_index": local_idx,
791
- "_url": image_url,
792
- }
793
- )
794
-
795
- yield {
796
- "image": image,
797
- "item_key": str(global_idx),
798
- "item_index": global_idx,
799
- "metadata": clean_metadata,
800
- "job_id": job_id,
801
- }
802
-
803
- processed_indices.append(global_idx)
839
+ processed_indices = []
804
840
 
805
- except Exception as e:
806
- logger.error(f"Error processing item at index {global_idx}: {e}")
841
+ # Process items shard by shard
842
+ for shard_id, idx_pairs in indices_by_shard.items():
843
+ shard_path = shard_info[shard_id]["path"]
844
+
845
+ # Process in batches to avoid loading entire table
846
+ batch_size = 100
847
+ for batch_start in range(0, len(idx_pairs), batch_size):
848
+ batch_pairs = idx_pairs[batch_start : batch_start + batch_size]
849
+ local_indices = [local_idx for _, local_idx in batch_pairs]
850
+
851
+ # Read only specific rows using PyArrow
852
+ try:
853
+ # Create row group filters based on metadata
854
+ metadata = shard_info[shard_id]["metadata"]
855
+ row_groups_to_read = set()
856
+
857
+ # Find which row groups contain our indices
858
+ current_row = 0
859
+ for rg_idx in range(metadata.num_row_groups):
860
+ rg_metadata = metadata.row_group(rg_idx)
861
+ rg_num_rows = rg_metadata.num_rows
862
+
863
+ # Check if any of our indices are in this row group
864
+ for local_idx in local_indices:
865
+ if current_row <= local_idx < current_row + rg_num_rows:
866
+ row_groups_to_read.add(rg_idx)
867
+
868
+ current_row += rg_num_rows
869
+
870
+ # Read only necessary row groups
871
+ parquet_file = pq.ParquetFile(shard_path)
872
+ table = parquet_file.read_row_groups(list(row_groups_to_read))
873
+
874
+ # Process items
875
+ for global_idx, local_idx in batch_pairs:
876
+ try:
877
+ # Get item as dictionary (efficient row extraction)
878
+ row_dict = table.slice(local_idx, 1).to_pydict()
879
+ item = {k: v[0] for k, v in row_dict.items()}
880
+
881
+ # Process image
882
+ image = None
883
+ image_url = None
884
+
885
+ if self.mock_results:
886
+ # In mock mode, create a dummy image
887
+ logger.debug(f"Creating mock image for index {global_idx}")
888
+
889
+ # Still extract URL if available for metadata
890
+ if self.url_column and self.url_column in item:
891
+ image_url = item[self.url_column]
892
+
893
+ # Create dummy image with metadata context
894
+ image = self._create_dummy_image(
895
+ global_idx,
896
+ {
897
+ "_shard_id": shard_id,
898
+ "_local_index": local_idx,
899
+ },
900
+ )
901
+ else:
902
+ # Normal processing - load real images
903
+ if self.url_column and self.url_column in item:
904
+ image_url = item[self.url_column]
905
+ try:
906
+ response = requests.get(image_url, timeout=30)
907
+ response.raise_for_status()
908
+ image = Image.open(io.BytesIO(response.content))
909
+ except Exception as e:
910
+ logger.error(
911
+ f"Error downloading image from {image_url}: {e}"
912
+ )
913
+ continue
914
+
915
+ elif self.image_column and self.image_column in item:
916
+ image_data = item[self.image_column]
917
+ if isinstance(image_data, dict) and "bytes" in image_data:
918
+ image = Image.open(io.BytesIO(image_data["bytes"]))
919
+ elif isinstance(image_data, bytes):
920
+ image = Image.open(io.BytesIO(image_data))
921
+
922
+ if image is None:
923
+ logger.warning(f"No image found for item at index {global_idx}")
924
+ continue
925
+
926
+ # Build job ID
927
+ chunk_index = unit.metadata["chunk_index"]
928
+ shard_name = unit.metadata["shard_name"]
929
+ job_id_obj = JobId(
930
+ shard_id=shard_name,
931
+ chunk_id=str(chunk_index),
932
+ sample_id=str(global_idx),
933
+ )
934
+ job_id = job_id_obj.get_sample_str()
935
+
936
+ # Clean metadata
937
+ clean_metadata = {
938
+ k: v
939
+ for k, v in item.items()
940
+ if k not in [self.image_column, self.url_column]
941
+ and not k.startswith("_")
942
+ }
943
+
944
+ clean_metadata.update(
945
+ {
946
+ "_item_index": global_idx,
947
+ "_chunk_relative_index": global_idx - start_index,
948
+ "_job_id": job_id,
949
+ "_shard_id": shard_id,
950
+ "_local_index": local_idx,
951
+ "_url": image_url,
952
+ "_mock": self.mock_results, # Add flag to indicate mock data
953
+ }
954
+ )
955
+
956
+ yield {
957
+ "image": image,
958
+ "item_key": str(global_idx),
959
+ "item_index": global_idx,
960
+ "metadata": clean_metadata,
961
+ "job_id": job_id,
962
+ "_processed_indices": processed_indices,
963
+ }
964
+
965
+ processed_indices.append(global_idx)
966
+
967
+ except Exception as e:
968
+ logger.error(f"Error processing item at index {global_idx}: {e}")
969
+
970
+ # Explicitly delete table to free memory
971
+ del table
972
+ gc.collect()
973
+
974
+ except Exception as e:
975
+ logger.error(f"Error reading batch from shard {shard_id}: {e}")
807
976
 
808
977
  # Store processed indices in context
809
978
  context["_processed_indices"] = processed_indices
810
- logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
979
+ logger.debug(
980
+ f"Processed {len(processed_indices)} indices for unit {unit.unit_id}: {processed_indices}, {context}"
981
+ )
982
+ log_memory(f"end processing unit {unit.unit_id}")
811
983
 
812
984
  def prepare_result(
813
985
  self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float