caption_flow-0.2.4-py3-none-any.whl → caption_flow-0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/base.py +3 -0
- caption_flow/processors/huggingface.py +637 -464
- caption_flow/processors/local_filesystem.py +2 -0
- caption_flow/processors/webdataset.py +438 -538
- caption_flow/storage/manager.py +328 -305
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +197 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/workers/caption.py +191 -138
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/METADATA +2 -1
- caption_flow-0.3.2.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.4.dist-info/RECORD +0 -38
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/top_level.txt +0 -0
--- a/caption_flow/processors/huggingface.py
+++ b/caption_flow/processors/huggingface.py
@@ -1,22 +1,24 @@
-"""HuggingFace Datasets processor implementation."""
+"""HuggingFace Datasets processor implementation - Memory Optimized Version."""
 
 import logging
 import threading
 import re
+import queue
 import requests
+import json
+import io
+import os
+import gc
+import psutil
+from concurrent.futures import ThreadPoolExecutor, Future
 from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
 from collections import deque, defaultdict
 from pathlib import Path
-import json
-import io
 from datetime import datetime
 from PIL import Image
-
-
-
-    get_dataset_split_names,
-    load_dataset_builder,
-)
+import pyarrow as pa
+import pyarrow.parquet as pq
+from datasets import get_dataset_config_names, get_dataset_split_names
 from huggingface_hub import hf_hub_download, get_token
 from caption_flow.storage import StorageManager
 
@@ -28,11 +30,82 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 
+def log_memory(location: str):
+    """Log memory usage at specific location."""
+    process = psutil.Process(os.getpid())
+    mem_info = process.memory_info()
+    logger.info(
+        f"Memory at {location}: RSS={mem_info.rss/1024/1024:.1f}MB, VMS={mem_info.vms/1024/1024:.1f}MB"
+    )
+    # Force garbage collection
+    gc.collect()
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class NonBlockingQueueHandler:
+    """Handles non-blocking retrieval from queues using concurrent futures."""
+
+    def __init__(self, max_workers: int = 1):
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.pending_futures: Dict[int, Future] = {}  # queue_id -> Future
+
+    def get_from_queue_async(self, response_queue: queue.Queue, timeout: float = None) -> Future:
+        """Start an async queue retrieval."""
+        queue_id = id(response_queue)
+
+        # Check if we already have a pending future for this queue
+        if queue_id in self.pending_futures and not self.pending_futures[queue_id].done():
+            return self.pending_futures[queue_id]
+
+        # Start new async retrieval
+        future = self.executor.submit(response_queue.get, timeout=timeout)
+        self.pending_futures[queue_id] = future
+        return future
+
+    def check_response(self, response_queue: queue.Queue, timeout: float = None) -> Optional[Any]:
+        """Non-blocking check for queue response."""
+        queue_id = id(response_queue)
+
+        # Start async retrieval if needed
+        future = self.get_from_queue_async(response_queue, timeout)
+
+        # Check if result is ready (non-blocking)
+        if future.done():
+            try:
+                result = future.result(timeout=0)
+                # Clear future for next retrieval
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return result
+            except queue.Empty:
+                # Queue was empty, clear future
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return None
+            except Exception as e:
+                logger.error(f"Error retrieving from queue: {e}")
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return None
+
+        # Result not ready yet
+        return None
+
+    def shutdown(self):
+        """Shutdown the executor."""
+        self.executor.shutdown(wait=True)
+
+
 class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
-    """
+    """Memory-optimized orchestrator processor for HuggingFace datasets with non-blocking operations."""
 
     def __init__(self):
-        logger.debug(
+        logger.debug(
+            "Initializing HuggingFaceDatasetOrchestratorProcessor (Optimized + Non-blocking)"
+        )
         self.dataset_name: Optional[str] = None
         self.config: Optional[str] = None
         self.split: Optional[str] = None
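The `NonBlockingQueueHandler` added in the hunk above swaps blocking `queue.get()` calls for a `ThreadPoolExecutor` future that the orchestrator can poll. The snippet below is a minimal standalone sketch of that polling pattern; the names and queue contents are illustrative and not part of caption_flow's API.

```python
# Illustrative sketch of polling a queue through an executor future instead of
# blocking on queue.get(). Names here are examples, not caption_flow's API.
import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

response_queue: "queue.Queue[str]" = queue.Queue()
executor = ThreadPoolExecutor(max_workers=1)
pending: Future = executor.submit(response_queue.get, timeout=1.0)


def poll() -> Optional[str]:
    """Return an item if the background get() has finished, else None."""
    global pending
    if not pending.done():
        return None  # nothing ready; the caller stays free to do other work
    try:
        item = pending.result(timeout=0)
    except queue.Empty:
        item = None  # the timed-out get() simply found nothing
    # Re-arm the background retrieval for the next call.
    pending = executor.submit(response_queue.get, timeout=1.0)
    return item


response_queue.put("work-result-1")
print(poll())  # may be None on the first call, the item on a later one
executor.shutdown(wait=False)
```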
@@ -44,19 +117,33 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         self.shard_info: Dict[int, Dict[str, Any]] = {}
         self.total_items: int = 0
 
-        # Work unit management
-        self.work_units: Dict[str, WorkUnit] = {}
+        # Work unit management - only store active units
         self.pending_units: Deque[str] = deque()
-        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
+        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
         self.lock = threading.Lock()
 
+        # Track current chunk index for on-demand creation
+        self.current_chunk_index = 0
+
+        # Cache data files info instead of loading builder repeatedly
+        self.data_files: List[str] = []
+
         # Background thread for creating work units
         self.unit_creation_thread: Optional[threading.Thread] = None
         self.stop_creation = threading.Event()
 
+        # Non-blocking queue handler
+        self.queue_handler = NonBlockingQueueHandler()
+
+        # Response processing state
+        self.last_maintenance_time = datetime.now()
+        self.maintenance_interval = 30  # seconds
+
     def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
         """Initialize HuggingFace dataset processor."""
         logger.debug("Initializing orchestrator with config: %s", config.config)
+        log_memory("start of initialize")
+
         cfg = config.config
 
         # Dataset configuration
@@ -83,12 +170,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
 
         # Initialize chunk tracking
-        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
-        checkpoint_dir.mkdir(parents=True, exist_ok=True)
-        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
+        self.checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")
 
-        # Discover shards
-        self.
+        # Discover shards (optimized)
+        self._discover_shards_optimized()
 
         # Restore existing state
         self._restore_state(storage=storage)
@@ -98,7 +185,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             target=self._create_units_background, daemon=True
         )
         self.unit_creation_thread.start()
-
+
+        log_memory("end of initialize")
 
     def _detect_config(self, provided_config: Optional[str]) -> str:
         """Auto-detect config if not provided."""
@@ -110,14 +198,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             if not configs:
                 return "default"
 
-            # Prefer common config names
             preferred = ["default", "en", "train", "main"]
             for pref in preferred:
                 if pref in configs:
                     logger.info(f"Auto-selected config: {pref}")
                     return pref
 
-            # Otherwise use first available
             logger.info(f"Auto-selected first available config: {configs[0]}")
             return configs[0]
         except Exception as e:
@@ -134,17 +220,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 self.dataset_name, config_name=self.config, token=self.token
             )
             if not splits:
-                logger.warning("No splits found, using 'train'")
                 return "train"
 
-            # Prefer training splits
             preferred = ["train", "training", "test", "validation", "dev"]
             for pref in preferred:
                 if pref in splits:
                     logger.info(f"Auto-selected split: {pref}")
                     return pref
 
-            # Otherwise use first available
             logger.info(f"Auto-selected first available split: {splits[0]}")
             return splits[0]
         except Exception as e:
@@ -153,18 +236,16 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
     def _extract_filename_from_url(self, url: str) -> str:
         """Extract filename from HF URL format."""
-        # Format: hf://datasets/user/dataset@hash/filename
         match = re.search(r"@[a-f0-9]+/(.+)$", url)
         if match:
             return match.group(1)
-        # Fallback: just get last part
        return url.split("/")[-1]
 
-    def
-        """
-
+    def _get_data_files_from_builder(self) -> List[str]:
+        """Get data files using dataset builder with minimal memory usage."""
+        # Load builder to get correct file structure
+        from datasets import load_dataset_builder
 
-        # Load dataset builder to get file info
         builder = load_dataset_builder(self.dataset_name, self.config)
 
         # Get data files for our split
@@ -176,81 +257,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                     files = [files]
                 data_files = files
 
-
+        # Explicitly delete builder to free memory
+        del builder
+        gc.collect()
+
+        return data_files
+
+    def _discover_shards_optimized(self):
+        """Discover all shards using dataset builder but release memory immediately."""
+        logger.info("Discovering shards...")
+
+        # Try to load cached shard info first
+        shard_info_cache_path = (
+            self.checkpoint_dir / f"{self.dataset_name}_{self.config}_{self.split}_shard_info.json"
+        )
+
+        if shard_info_cache_path.exists():
+            try:
+                with open(shard_info_cache_path, "r") as f:
+                    cached_info = json.load(f)
+                if (
+                    cached_info.get("dataset") == self.dataset_name
+                    and cached_info.get("config") == self.config
+                    and cached_info.get("split") == self.split
+                ):
+                    self.shard_info = {int(k): v for k, v in cached_info["shards"].items()}
+                    self.total_items = cached_info["total_items"]
+                    self.data_files = cached_info.get("data_files", [])
+                    logger.info(
+                        f"Loaded cached shard info: {len(self.shard_info)} shards, {self.total_items} total items"
+                    )
+                    return
+            except Exception as e:
+                logger.warning(f"Failed to load cached shard info: {e}")
+
+        # Get data files using dataset builder
+        self.data_files = self._get_data_files_from_builder()
+
+        if not self.data_files:
             raise ValueError(f"No data files found for split '{self.split}'")
 
-        logger.info(f"Found {len(data_files)} data files")
+        logger.info(f"Found {len(self.data_files)} data files")
 
-        # Get
+        # Get metadata for each shard
         cumulative_offset = 0
-        for i, file_url in enumerate(data_files):
+        for i, file_url in enumerate(self.data_files):
             filename = self._extract_filename_from_url(file_url)
             logger.info(f"Discovering shard {i}: {filename}")
 
-
-
-
-
-
-
-
-
-                "size": None,
-                "end_offset": None,
-            }
-
-            # Try to get size from builder info if available
-            if hasattr(builder.info, "splits") and self.split in builder.info.splits:
-                split_info = builder.info.splits[self.split]
-                if split_info.num_examples and len(data_files) == 1:
-                    # Single shard case
-                    self.shard_info[i]["size"] = split_info.num_examples
-                    self.shard_info[i]["end_offset"] = (
-                        cumulative_offset + split_info.num_examples - 1
-                    )
-                    cumulative_offset += split_info.num_examples
-
-        # If we couldn't get sizes, we'll need to load shards on demand
-        if self.shard_info[0]["size"] is None:
-            logger.warning("Shard sizes not available from metadata, will load on demand")
-        else:
-            self.total_items = cumulative_offset
-            logger.info(f"Total items across all shards: {self.total_items}")
-
-    def _get_shard_size(self, shard_id: int) -> int:
-        """Get size of a shard, loading it if necessary."""
-        if self.shard_info[shard_id]["size"] is not None:
-            return self.shard_info[shard_id]["size"]
+            try:
+                # Download file to get metadata
+                local_path = hf_hub_download(
+                    repo_id=self.dataset_name,
+                    filename=filename,
+                    repo_type="dataset",
+                    token=self.token,
+                )
 
-
-
-
+                # Read only metadata
+                metadata = pq.read_metadata(local_path)
+                size = metadata.num_rows
+
+                self.shard_info[i] = {
+                    "shard_id": i,
+                    "file_url": file_url,
+                    "filename": filename,
+                    "start_offset": cumulative_offset,
+                    "size": size,
+                    "end_offset": cumulative_offset + size - 1,
+                }
 
-
-
-        )
+                cumulative_offset += size
+                logger.info(f"Shard {i} ({filename}): {size} rows")
 
-
-
-
+            except Exception as e:
+                logger.error(f"Failed to discover shard {i}: {e}")
+                # Skip this shard
+                continue
 
-
-
+        self.total_items = cumulative_offset
+        logger.info(f"Total items across all shards: {self.total_items}")
 
-        #
-
-
-
-
-
-
+        # Cache shard info
+        try:
+            cache_data = {
+                "dataset": self.dataset_name,
+                "config": self.config,
+                "split": self.split,
+                "shards": self.shard_info,
+                "total_items": self.total_items,
+                "data_files": self.data_files,
+            }
+            with open(shard_info_cache_path, "w") as f:
+                json.dump(cache_data, f)
+            logger.info(f"Cached shard info to {shard_info_cache_path}")
+        except Exception as e:
+            logger.warning(f"Failed to cache shard info: {e}")
 
-        #
-
-
-        logger.info(f"Total items: {self.total_items}")
+        # Force garbage collection
+        gc.collect()
+        log_memory("after discovering shards")
 
-
+    def _get_shard_for_index(self, global_index: int) -> Tuple[int, int]:
+        """Get shard ID and local index for a global index."""
+        for shard_id, sinfo in self.shard_info.items():
+            if sinfo["start_offset"] <= global_index <= sinfo["end_offset"]:
+                local_index = global_index - sinfo["start_offset"]
+                return shard_id, local_index
+        raise ValueError(f"Global index {global_index} not found in any shard")
 
     def _restore_state(self, storage: StorageManager) -> None:
         """Restore state from chunk tracker."""
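Shard discovery above now sizes each Parquet shard from its footer metadata via `pyarrow.parquet.read_metadata`, so no row data is materialized. The small self-contained illustration below shows that call; the file written here is only a stand-in for a shard fetched with `hf_hub_download`.

```python
# Size a Parquet shard from its footer only; no column data is materialized.
import pyarrow as pa
import pyarrow.parquet as pq

path = "example-shard.parquet"  # stand-in for a downloaded shard file
pq.write_table(pa.table({"idx": list(range(1000))}), path, row_group_size=256)

meta = pq.read_metadata(path)
print(meta.num_rows)               # 1000 -> used as the shard "size"
print(meta.num_row_groups)         # 4    -> used later for selective reads
print(meta.row_group(0).num_rows)  # 256
```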
@@ -258,73 +372,84 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         if not self.chunk_tracker:
             return
 
-        all_processed_jobs = storage.get_all_processed_job_ids()
-
         with self.lock:
+            max_chunk_index = -1
+
             for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
-
-
-
-
-
+                chunk_index = chunk_state.start_index // self.chunk_size
+                max_chunk_index = max(max_chunk_index, chunk_index)
+
+                # Only add incomplete chunks to pending
+                if chunk_state.status != "completed":
+                    self.pending_units.append(chunk_id)
+                elif chunk_state.status == "completed" and chunk_state.processed_ranges:
+                    logger.warning(
+                        f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
+                    )
+
+            self.current_chunk_index = max_chunk_index + 1
+            logger.info(f"Resuming from chunk index {self.current_chunk_index}")
 
-
-
-
+    def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
+        """Create a single work unit for a chunk index."""
+        current_index = chunk_index * self.chunk_size
+
+        if current_index >= self.total_items:
+            return None
+
+        chunk_size = min(self.chunk_size, self.total_items - current_index)
+
+        # Find shard for this chunk
+        shard_id, _ = self._get_shard_for_index(current_index)
+        shard_name = Path(self.shard_info[shard_id]["filename"]).stem
+
+        job_id_obj = JobId(shard_id=shard_name, chunk_id=chunk_index, sample_id=current_index)
+        unit_id = job_id_obj.get_chunk_str()
+
+        # Calculate unprocessed ranges based on existing chunk state
+        unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
+
+        if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+            chunk_state = self.chunk_tracker.chunks[unit_id]
+            if chunk_state.processed_ranges:
+                # Subtract processed ranges from total range
+                unprocessed_ranges = self._subtract_ranges(
+                    [(current_index, current_index + chunk_size - 1)], chunk_state.processed_ranges
                )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                data={
-                    "dataset_name": self.dataset_name,
-                    "config": self.config,
-                    "split": self.split,
-                    "start_index": chunk_state.start_index,
-                    "chunk_size": chunk_state.chunk_size,
-                    "unprocessed_ranges": unprocessed_ranges,
-                    "shard_ids": shard_ids,
-                },
-                metadata={
-                    "dataset": self.dataset_name,
-                    "shard_name": shard_name,
-                    "chunk_index": chunk_index,
-                },
-            )
+        # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
+        if not unprocessed_ranges:
+            return None
+
+        unit = WorkUnit(
+            unit_id=unit_id,
+            chunk_id=unit_id,
+            source_id=shard_name,
+            unit_size=chunk_size,
+            data={
+                "dataset_name": self.dataset_name,
+                "config": self.config,
+                "split": self.split,
+                "start_index": current_index,
+                "chunk_size": chunk_size,
+                "unprocessed_ranges": unprocessed_ranges,  # Use calculated ranges
+                "shard_ids": [shard_id],
+                "data_files": self.data_files,
+            },
+            metadata={
+                "dataset": self.dataset_name,
+                "shard_name": shard_name,
+                "chunk_index": chunk_index,
+            },
+        )
 
-
-            self.pending_units.append(unit.unit_id)
+        return unit
 
     def _create_units_background(self) -> None:
         """Background thread to create work units on demand."""
         logger.info("Starting work unit creation thread")
 
-        current_index = 0
-
         while not self.stop_creation.is_set():
-            # Check if we need more units
             with self.lock:
                 pending_count = len(self.pending_units)
                 assigned_count = sum(len(units) for units in self.assigned_units.values())
@@ -337,127 +462,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 threading.Event().wait(5)
                 continue
 
-            # Make sure we know total items
-            if self.total_items == 0:
-                # Load all shard sizes
-                for sid in range(len(self.shard_info)):
-                    self._get_shard_size(sid)
-
             # Create units as needed
             units_created = 0
 
-            while units_created < units_needed
-
-
-
-
-
-
-
-
-
-            ):
-                shard_ids.append(sid)
-                shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
+            while units_created < units_needed:
+                logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
+                if self.current_chunk_index * self.chunk_size >= self.total_items:
+                    threading.Event().wait(30)
+                    break
+                # Get shard info for proper unit_id
+                current_index = self.current_chunk_index * self.chunk_size
+                if current_index < self.total_items:
+                    shard_id, _ = self._get_shard_for_index(current_index)
+                    shard_name = Path(self.shard_info[shard_id]["filename"]).stem
 
                     job_id_obj = JobId(
-                        shard_id=shard_name,
+                        shard_id=shard_name,
+                        chunk_id=self.current_chunk_index,
+                        sample_id=current_index,
                     )
-                    unit_id = (
-                        job_id_obj.get_chunk_str()
-                    )  # just the chunk part, eg pixel-images:chunk:0
-                    if unit_id in self.work_units:
-                        current_index += self.chunk_size
-                        continue
-
-                    # Check if chunk is already completed
-                    if self.chunk_tracker:
-                        chunk_state = self.chunk_tracker.chunks.get(unit_id)
-                        if chunk_state and chunk_state.status == "completed":
-                            current_index += self.chunk_size
-                            continue
+                    unit_id = job_id_obj.get_chunk_str()
 
-
-
-
-
-
-
-
-                            "dataset_name": self.dataset_name,
-                            "config": self.config,
-                            "split": self.split,
-                            "start_index": current_index,
-                            "chunk_size": chunk_size,
-                            "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
-                            "shard_ids": shard_ids,
-                        },
-                        metadata={
-                            "dataset": self.dataset_name,
-                            "shard_name": shard_name,
-                            "chunk_index": chunk_id,
-                        },
-                    )
-                    logger.debug(f"Created WorkUnit: {unit}")
+                with self.lock:
+                    # Check if already tracked
+                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                        if chunk_state.status == "completed":
+                            self.current_chunk_index += 1
+                            continue
 
-
+                    # Add to pending
                    self.pending_units.append(unit_id)
 
+                    # Track in chunk tracker
                    if self.chunk_tracker:
+                        start_index = self.current_chunk_index * self.chunk_size
+                        chunk_size = min(self.chunk_size, self.total_items - start_index)
                        self.chunk_tracker.add_chunk(
                            unit_id,
                            self.dataset_name,
-                            "",
-
+                            "",
+                            start_index,
                            chunk_size,
                        )
 
                    units_created += 1
-
-                current_index += self.chunk_size
+                    self.current_chunk_index += 1
 
            if units_created > 0:
-                logger.debug(f"Created {units_created} work
-
-
-
-    ) ->
-        """
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                logger.debug(f"Created {units_created} work unit IDs")
+
+        logger.info("Thread for creating units has completed. Exiting thread.")
+
+    def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
+        """
+        Non-blocking method to process responses from workers.
+        Returns a WorkResult if one is available, None otherwise.
+        """
+        # Check for response without blocking
+        response = self.queue_handler.check_response(response_queue, timeout=0.1)
+
+        if response is not None:
+            # Process the response
+            if isinstance(response, WorkResult):
+                logger.debug(f"Processing response for unit {response.unit_id}")
+                return response
+            else:
+                logger.warning(f"Unexpected response type: {type(response)}")
+
+        # Perform periodic maintenance tasks
+        now = datetime.now()
+        if (now - self.last_maintenance_time).total_seconds() > self.maintenance_interval:
+            self._perform_maintenance()
+            self.last_maintenance_time = now
+
+        return None
+
+    def _perform_maintenance(self):
+        """Perform periodic maintenance tasks."""
+        with self.lock:
+            # Log current state
+            pending_count = len(self.pending_units)
+            assigned_count = sum(len(units) for units in self.assigned_units.values())
+            logger.debug(f"Maintenance: {pending_count} pending, {assigned_count} assigned units")
 
-
-
+            # Check for stale assignments (workers that might have disconnected)
+            # This would be implemented based on your worker heartbeat mechanism
 
-
+            # Force checkpoint save if needed
+            if self.chunk_tracker:
+                self.chunk_tracker.save_checkpoint()
 
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
-        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
-        assigned = []
 
+        logger.debug(
+            "get_work_units called: count=%d worker_id=%s, pending: %d",
+            count,
+            worker_id,
+            len(self.pending_units),
+        )
+        assigned = []
         with self.lock:
             while len(assigned) < count and self.pending_units:
                 unit_id = self.pending_units.popleft()
-
+
+                # Create work unit on demand
+                chunk_index = int(unit_id.split(":")[-1])
+                unit = self._create_work_unit(chunk_index)
 
                 if unit:
                     self.assigned_units[worker_id].add(unit_id)
@@ -474,22 +586,26 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Mark a work unit as completed."""
         logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
         with self.lock:
-
-
+            self.assigned_units[worker_id].discard(unit_id)
+
+            if self.chunk_tracker:
+                self.chunk_tracker.mark_completed(unit_id)
 
-
-
+            # remove from pending deque if it's there.
+            try:
+                self.pending_units.remove(unit_id)
+            except:
+                pass
 
     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         """Mark a work unit as failed."""
-        logger.
+        logger.error("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
         with self.lock:
-
-
-            self.pending_units.append(unit_id)
+            self.assigned_units[worker_id].discard(unit_id)
+            self.pending_units.append(unit_id)
 
-
-
+            if self.chunk_tracker:
+                self.chunk_tracker.mark_failed(unit_id)
 
     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
@@ -498,8 +614,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             unit_ids = list(self.assigned_units.get(worker_id, []))
 
             for unit_id in unit_ids:
-
-
+                logger.debug(f"Adding {unit_id} to pending queue")
+                self.pending_units.append(unit_id)
 
             if worker_id in self.assigned_units:
                 del self.assigned_units[worker_id]
@@ -509,57 +625,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
         """Update work units based on what's been processed."""
-        logger.info(f"Updating
-
-        with self.lock:
-            for unit_id, unit in self.work_units.items():
-                # Extract chunk info from unit
-                logger.debug(f"Checking unit {unit_id} for updates")
-                logger.debug(f"Unit data: {unit.data}")
-                logger.debug(f"Unit metadata: {unit.metadata}")
-                start_index = unit.data["start_index"]
-                chunk_size = unit.data["chunk_size"]
-                shard_name = unit.metadata["shard_name"]
-                chunk_index = unit.metadata["chunk_index"]
-
-                # Find processed indices for this chunk
-                processed_indices = []
-                for job_id in processed_job_ids:
-                    # Parse job_id format: "data-0000:chunk:0:idx:42"
-                    job_id = JobId.from_str(job_id=job_id)
-                    if job_id.shard_id == shard_name and int(job_id.chunk_id) == chunk_index:
-                        idx = int(job_id.sample_id)
-                        if start_index <= idx < start_index + chunk_size:
-                            processed_indices.append(idx)
-
-                if processed_indices:
-                    # Convert to ranges
-                    processed_indices.sort()
-                    processed_ranges = []
-                    start = processed_indices[0]
-                    end = processed_indices[0]
-
-                    for idx in processed_indices[1:]:
-                        if idx == end + 1:
-                            end = idx
-                        else:
-                            processed_ranges.append((start, end))
-                            start = idx
-                            end = idx
-
-                    processed_ranges.append((start, end))
-
-                    # Calculate unprocessed ranges
-                    total_range = [(start_index, start_index + chunk_size - 1)]
-                    unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
-
-                    # Update unit
-                    unit.data["unprocessed_ranges"] = unprocessed_ranges
-
-                    logger.debug(
-                        f"Updated unit {unit_id}: {len(processed_indices)} processed, "
-                        f"unprocessed ranges: {unprocessed_ranges}"
-                    )
+        logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
+        # No need to update in-memory work units since we create on demand
 
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
@@ -568,12 +635,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             "dataset": self.dataset_name,
             "config": self.config,
             "split": self.split,
-            "total_units": len(self.work_units),
             "pending_units": len(self.pending_units),
             "assigned_units": sum(len(units) for units in self.assigned_units.values()),
             "total_shards": len(self.shard_info),
             "total_items": self.total_items,
             "workers": len(self.assigned_units),
+            "current_chunk_index": self.current_chunk_index,
         }
         return stats
 
@@ -581,71 +648,111 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Handle result processing."""
         base_result = super().handle_result(result)
 
-        # Track processed items
         if self.chunk_tracker:
-            if "item_indices"
-
-
+            if "item_indices" in result.metadata:
+                indices = result.metadata["item_indices"]
+                if indices:
+                    # Convert to ranges for efficient tracking
+                    indices.sort()
+                    ranges = []
+                    start = indices[0]
+                    end = indices[0]
+
+                    for i in range(1, len(indices)):
+                        if indices[i] == end + 1:
+                            end = indices[i]
+                        else:
+                            ranges.append((start, end))
+                            start = indices[i]
+                            end = indices[i]
 
-
-                indices.sort()
-                ranges = []
-                start = indices[0]
-                end = indices[0]
+                    ranges.append((start, end))
 
-
-
-                        end = indices[i]
-                    else:
-                        ranges.append((start, end))
-                        start = indices[i]
-                        end = indices[i]
+                    for start_idx, end_idx in ranges:
+                        self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
 
-
+        return base_result
 
-
-
+    def _subtract_ranges(
+        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
+    ) -> List[Tuple[int, int]]:
+        """Subtract processed ranges from total ranges."""
+        if not processed_ranges:
+            return total_ranges
 
-
+        # Create a set of all processed indices
+        processed_indices = set()
+        for start, end in processed_ranges:
+            processed_indices.update(range(start, end + 1))
+
+        # Find unprocessed ranges
+        unprocessed_ranges = []
+        for start, end in total_ranges:
+            current_start = None
+            for i in range(start, end + 1):
+                if i not in processed_indices:
+                    if current_start is None:
+                        current_start = i
+                else:
+                    if current_start is not None:
+                        unprocessed_ranges.append((current_start, i - 1))
+                        current_start = None
+
+            if current_start is not None:
+                unprocessed_ranges.append((current_start, end))
+
+        return unprocessed_ranges
+
+    def cleanup(self):
+        """Clean up resources."""
+        logger.info("Cleaning up orchestrator resources")
+
+        # Stop background threads
+        self.stop_creation.set()
+        if self.unit_creation_thread:
+            self.unit_creation_thread.join(timeout=5)
+
+        # Shutdown queue handler
+        self.queue_handler.shutdown()
+
+        # Save final state
+        if self.chunk_tracker:
+            self.chunk_tracker.save_checkpoint()
 
 
 class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
-    """
+    """Memory-optimized worker processor for HuggingFace datasets."""
 
     def __init__(self):
-        logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
+        logger.debug("Initializing HuggingFaceDatasetWorkerProcessor (Optimized)")
         self.dataset_config: Dict[str, Any] = {}
         self.token = get_token()
-        self.shard_cache: Dict[int, Dataset] = {}  # Cache loaded shards
         self.image_column: Optional[str] = None
         self.url_column: Optional[str] = None
 
+        # Thread-local storage for shard info to avoid repeated builder loading
+        self._thread_local = threading.local()
+
     def initialize(self, config: ProcessorConfig) -> None:
         """Initialize processor."""
         logger.debug("Initializing worker with config: %s", config.config)
         self.dataset_config = config.config.get("dataset", {})
 
-        # Determine if this is an image URL dataset or binary image dataset
         self.image_column = self.dataset_config.get("dataset_image_column", "image")
         self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
         self.dataset_path = self.dataset_config.get("dataset_path", None)
 
-
-
-        if
-
-
-            logger.info(f"Loading shard {shard_id}: {shard_filename}")
+        # Add mock results flag
+        self.mock_results = self.dataset_config.get("mock_results", False)
+        if self.mock_results:
+            logger.info("Mock results mode enabled - will generate dummy images")
 
-
+    def _get_shard_path(self, dataset_name: str, shard_filename: str) -> str:
+        """Get local path for a shard, downloading if needed."""
+        return hf_hub_download(
            repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
        )
 
-        dataset = Dataset.from_parquet(local_path)
-        self.shard_cache[shard_id] = dataset
-
-        return dataset
-
     def _extract_filename_from_url(self, url: str) -> str:
         """Extract filename from HF URL format."""
         match = re.search(r"@[a-f0-9]+/(.+)$", url)
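The `_subtract_ranges` helper introduced above computes which inclusive index ranges of a chunk still need processing. Below is a compact standalone restatement of the same idea with a worked example; it is an illustration, not the package's code.

```python
# Standalone restatement of range subtraction over inclusive (start, end) pairs.
from typing import List, Tuple


def subtract_ranges(
    total: List[Tuple[int, int]], processed: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    """Return the sub-ranges of `total` not covered by `processed`."""
    done = set()
    for start, end in processed:
        done.update(range(start, end + 1))

    out: List[Tuple[int, int]] = []
    for start, end in total:
        run_start = None
        for i in range(start, end + 1):
            if i not in done:
                if run_start is None:
                    run_start = i
            elif run_start is not None:
                out.append((run_start, i - 1))
                run_start = None
        if run_start is not None:
            out.append((run_start, end))
    return out


# A chunk covering items 0-9, with items 2-4 and 7 already captioned:
print(subtract_ranges([(0, 9)], [(2, 4), (7, 7)]))  # [(0, 1), (5, 6), (8, 9)]
```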
@@ -653,161 +760,227 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
             return match.group(1)
         return url.split("/")[-1]
 
+    def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
+        """Create a dummy image"""
+        color = (0, 0, 0)
+        width, height = 128, 128
+        image = Image.new("RGB", (width, height), color=color)
+
+        return image
+
     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         """Process a work unit, yielding items to be captioned."""
-        logger.debug("Processing unit: %s", unit.unit_id)
+        logger.debug("Processing unit: %s (mock_results=%s)", unit.unit_id, self.mock_results)
+        log_memory(f"start processing unit {unit.unit_id}")
 
         dataset_name = unit.data["dataset_name"]
-        config = unit.data["config"]
-        split = unit.data["split"]
         start_index = unit.data["start_index"]
         chunk_size = unit.data["chunk_size"]
         unprocessed_ranges = unit.data.get(
             "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
         )
         shard_ids = unit.data.get("shard_ids", [])
+        data_files = unit.data.get("data_files", [])
 
         logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
 
-        #
-        # For now, we'll need to load dataset builder to get file info
-        from datasets import load_dataset_builder
-
-        builder = load_dataset_builder(dataset_name, config)
-
-        data_files = []
-        if hasattr(builder.config, "data_files"):
-            if isinstance(builder.config.data_files, dict):
-                files = builder.config.data_files.get(split, [])
-                if isinstance(files, str):
-                    files = [files]
-                data_files = files
-
-        # Build shard info
+        # Build shard info from provided data files (no dataset builder needed)
         shard_info = {}
-        cumulative_offset = 0
-
-        for i, file_url in enumerate(data_files):
-            if i not in shard_ids:
-                # Skip loading this shard, but we need its size for offsets
-                # This is inefficient - in real implementation, orchestrator should pass this info
-                filename = self._extract_filename_from_url(file_url)
-                dataset = self._load_shard(dataset_name, filename, i)
-                size = len(dataset)
-                cumulative_offset += size
-                continue
 
-
-
+        if data_files:
+            # Use provided data files
+            for i, file_url in enumerate(data_files):
+                if i in shard_ids:
+                    filename = self._extract_filename_from_url(file_url)
+                    shard_path = self._get_shard_path(dataset_name, filename)
+
+                    # Get size from metadata
+                    metadata = pq.read_metadata(shard_path)
+                    size = metadata.num_rows
+
+                    shard_info[i] = {
+                        "path": shard_path,
+                        "start_offset": 0,  # Will be set below
+                        "end_offset": 0,  # Will be set below
+                        "size": size,
+                        "metadata": metadata,
+                    }
 
-
-
-
-
-
-
-
+            # Calculate offsets
+            cumulative_offset = 0
+            for i in range(max(shard_info.keys()) + 1):
+                if i in shard_info:
+                    shard_info[i]["start_offset"] = cumulative_offset
+                    shard_info[i]["end_offset"] = cumulative_offset + shard_info[i]["size"] - 1
+                    cumulative_offset += shard_info[i]["size"]
+                else:
+                    # Need to get size for offset calculation
+                    filename = self._extract_filename_from_url(data_files[i])
+                    shard_path = self._get_shard_path(dataset_name, filename)
+                    metadata = pq.read_metadata(shard_path)
+                    cumulative_offset += metadata.num_rows
+        else:
+            # This should never happen with the new orchestrator
+            raise ValueError("No data files provided in work unit")
 
         # Create set of indices to process
         indices_to_process = set()
         for start, end in unprocessed_ranges:
             indices_to_process.update(range(start, end + 1))
 
-
-
-
-
-            # Find which shard contains this index
-            shard_id = None
-            local_idx = None
-
-            for sid, sinfo in shard_info.items():
+        # Group indices by shard
+        indices_by_shard = defaultdict(list)
+        for global_idx in indices_to_process:
+            for shard_id, sinfo in shard_info.items():
                if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
-                    shard_id = sid
                    local_idx = global_idx - sinfo["start_offset"]
+                    indices_by_shard[shard_id].append((global_idx, local_idx))
                    break
 
-
-                logger.warning(f"Could not find shard for global index {global_idx}")
-                continue
-
-            try:
-                # Get item from shard
-                item = shard_info[shard_id]["dataset"][local_idx]
-
-                # Check if this is a URL dataset or binary image dataset
-                image = None
-                image_url = None
-
-                # Try URL column first
-                if self.url_column and self.url_column in item:
-                    image_url = item[self.url_column]
-                    # Download image from URL
-                    try:
-                        response = requests.get(image_url, timeout=30)
-                        response.raise_for_status()
-                        image = Image.open(io.BytesIO(response.content))
-                    except Exception as e:
-                        logger.error(f"Error downloading image from {image_url}: {e}")
-                        continue
-
-                # Try binary image column
-                elif self.image_column and self.image_column in item:
-                    image_data = item[self.image_column]
-                    if isinstance(image_data, Image.Image):
-                        image = image_data
-                    elif isinstance(image_data, dict) and "bytes" in image_data:
-                        # Handle datasets Image feature
-                        image = Image.open(io.BytesIO(image_data["bytes"]))
-                    elif isinstance(image_data, bytes):
-                        image = Image.open(io.BytesIO(image_data))
-
-                if image is None:
-                    logger.warning(f"No image found for item at index {global_idx}")
-                    continue
-
-                # Build job ID
-                chunk_index = unit.metadata["chunk_index"]
-                shard_name = unit.metadata["shard_name"]
-                job_id_obj = JobId(
-                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
-                )
-                job_id = job_id_obj.get_sample_str()
-
-                # Clean metadata
-                clean_metadata = {
-                    k: v
-                    for k, v in item.items()
-                    if k not in [self.image_column, self.url_column] and not k.startswith("_")
-                }
-
-                clean_metadata.update(
-                    {
-                        "_item_index": global_idx,
-                        "_chunk_relative_index": global_idx - start_index,
-                        "_job_id": job_id,
-                        "_shard_id": shard_id,
-                        "_local_index": local_idx,
-                        "_url": image_url,
-                    }
-                )
-
-                yield {
-                    "image": image,
-                    "item_key": str(global_idx),
-                    "item_index": global_idx,
-                    "metadata": clean_metadata,
-                    "job_id": job_id,
-                }
-
-                processed_indices.append(global_idx)
+        processed_indices = []
 
-
-
+        # Process items shard by shard
+        for shard_id, idx_pairs in indices_by_shard.items():
+            shard_path = shard_info[shard_id]["path"]
+
+            # Process in batches to avoid loading entire table
+            batch_size = 100
+            for batch_start in range(0, len(idx_pairs), batch_size):
+                batch_pairs = idx_pairs[batch_start : batch_start + batch_size]
+                local_indices = [local_idx for _, local_idx in batch_pairs]
+
+                # Read only specific rows using PyArrow
+                try:
+                    # Create row group filters based on metadata
+                    metadata = shard_info[shard_id]["metadata"]
+                    row_groups_to_read = set()
+
+                    # Find which row groups contain our indices
+                    current_row = 0
+                    for rg_idx in range(metadata.num_row_groups):
+                        rg_metadata = metadata.row_group(rg_idx)
+                        rg_num_rows = rg_metadata.num_rows
+
+                        # Check if any of our indices are in this row group
+                        for local_idx in local_indices:
+                            if current_row <= local_idx < current_row + rg_num_rows:
+                                row_groups_to_read.add(rg_idx)
+
+                        current_row += rg_num_rows
+
+                    # Read only necessary row groups
+                    parquet_file = pq.ParquetFile(shard_path)
+                    table = parquet_file.read_row_groups(list(row_groups_to_read))
+
+                    # Process items
+                    for global_idx, local_idx in batch_pairs:
+                        try:
+                            # Get item as dictionary (efficient row extraction)
+                            row_dict = table.slice(local_idx, 1).to_pydict()
+                            item = {k: v[0] for k, v in row_dict.items()}
+
+                            # Process image
+                            image = None
+                            image_url = None
+
+                            if self.mock_results:
+                                # In mock mode, create a dummy image
+                                logger.debug(f"Creating mock image for index {global_idx}")
+
+                                # Still extract URL if available for metadata
+                                if self.url_column and self.url_column in item:
+                                    image_url = item[self.url_column]
+
+                                # Create dummy image with metadata context
+                                image = self._create_dummy_image(
+                                    global_idx,
+                                    {
+                                        "_shard_id": shard_id,
+                                        "_local_index": local_idx,
+                                    },
+                                )
+                            else:
+                                # Normal processing - load real images
+                                if self.url_column and self.url_column in item:
+                                    image_url = item[self.url_column]
+                                    try:
+                                        response = requests.get(image_url, timeout=30)
+                                        response.raise_for_status()
+                                        image = Image.open(io.BytesIO(response.content))
+                                    except Exception as e:
+                                        logger.error(
+                                            f"Error downloading image from {image_url}: {e}"
+                                        )
+                                        continue
+
+                                elif self.image_column and self.image_column in item:
+                                    image_data = item[self.image_column]
+                                    if isinstance(image_data, dict) and "bytes" in image_data:
+                                        image = Image.open(io.BytesIO(image_data["bytes"]))
+                                    elif isinstance(image_data, bytes):
+                                        image = Image.open(io.BytesIO(image_data))
+
+                            if image is None:
+                                logger.warning(f"No image found for item at index {global_idx}")
+                                continue
+
+                            # Build job ID
+                            chunk_index = unit.metadata["chunk_index"]
+                            shard_name = unit.metadata["shard_name"]
+                            job_id_obj = JobId(
+                                shard_id=shard_name,
+                                chunk_id=str(chunk_index),
+                                sample_id=str(global_idx),
+                            )
+                            job_id = job_id_obj.get_sample_str()
+
+                            # Clean metadata
+                            clean_metadata = {
+                                k: v
+                                for k, v in item.items()
+                                if k not in [self.image_column, self.url_column]
+                                and not k.startswith("_")
+                            }
+
+                            clean_metadata.update(
+                                {
+                                    "_item_index": global_idx,
+                                    "_chunk_relative_index": global_idx - start_index,
+                                    "_job_id": job_id,
+                                    "_shard_id": shard_id,
+                                    "_local_index": local_idx,
+                                    "_url": image_url,
+                                    "_mock": self.mock_results,  # Add flag to indicate mock data
+                                }
+                            )
+
+                            yield {
+                                "image": image,
+                                "item_key": str(global_idx),
+                                "item_index": global_idx,
+                                "metadata": clean_metadata,
+                                "job_id": job_id,
+                                "_processed_indices": processed_indices,
+                            }
+
+                            processed_indices.append(global_idx)
+
+                        except Exception as e:
+                            logger.error(f"Error processing item at index {global_idx}: {e}")
+
+                    # Explicitly delete table to free memory
+                    del table
+                    gc.collect()
+
+                except Exception as e:
+                    logger.error(f"Error reading batch from shard {shard_id}: {e}")
 
         # Store processed indices in context
         context["_processed_indices"] = processed_indices
-        logger.debug(
+        logger.debug(
+            f"Processed {len(processed_indices)} indices for unit {unit.unit_id}: {processed_indices}, {context}"
+        )
+        log_memory(f"end processing unit {unit.unit_id}")
 
     def prepare_result(
         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float