caption-flow 0.2.1-py3-none-any.whl → 0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +2 -1
- caption_flow/models.py +108 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1595
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage.py +415 -406
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +94 -35
- caption_flow/utils/dataset_loader.py +64 -522
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +4 -200
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/METADATA +29 -27
- caption_flow-0.2.3.dist-info/RECORD +35 -0
- caption_flow-0.2.1.dist-info/RECORD +0 -29
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.1.dist-info → caption_flow-0.2.3.dist-info}/top_level.txt +0 -0
caption_flow/processors/huggingface.py (new file)
@@ -0,0 +1,832 @@
"""HuggingFace Datasets processor implementation."""

import logging
import threading
import re
import requests
from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
from collections import deque, defaultdict
from pathlib import Path
import json
import io
from datetime import datetime
from PIL import Image
from datasets import (
    Dataset,
    get_dataset_config_names,
    get_dataset_split_names,
    load_dataset_builder,
)
from huggingface_hub import hf_hub_download, get_token
from caption_flow.storage import StorageManager

from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
from ..utils import ChunkTracker
from ..models import JobId

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
    """Orchestrator processor for HuggingFace datasets."""

    def __init__(self):
        logger.debug("Initializing HuggingFaceDatasetOrchestratorProcessor")
        self.dataset_name: Optional[str] = None
        self.config: Optional[str] = None
        self.split: Optional[str] = None
        self.chunk_tracker: Optional[ChunkTracker] = None
        self.chunk_size: int = 1000
        self.token = get_token()

        # Shard information
        self.shard_info: Dict[int, Dict[str, Any]] = {}
        self.total_items: int = 0

        # Work unit management
        self.work_units: Dict[str, WorkUnit] = {}
        self.pending_units: Deque[str] = deque()
        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> unit_ids
        self.lock = threading.Lock()

        # Background thread for creating work units
        self.unit_creation_thread: Optional[threading.Thread] = None
        self.stop_creation = threading.Event()

    def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
        """Initialize HuggingFace dataset processor."""
        logger.debug("Initializing orchestrator with config: %s", config.config)
        cfg = config.config

        # Dataset configuration
        dataset_cfg = cfg.get("dataset", {})
        self.dataset_name = dataset_cfg.get("dataset_path")
        if not self.dataset_name:
            raise ValueError("dataset_path is required in config")

        # Auto-detect config if not provided
        provided_config = dataset_cfg.get("dataset_config")
        self.config = self._detect_config(provided_config)

        # Auto-detect split if not provided
        provided_split = dataset_cfg.get("dataset_split")
        self.split = self._detect_split(provided_split)

        logger.info(
            f"Using dataset: {self.dataset_name}, config: {self.config}, split: {self.split}"
        )

        # Chunk settings
        self.chunk_size = cfg.get("chunk_size", 1000)
        self.min_buffer = cfg.get("min_chunk_buffer", 10)
        self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)

        # Initialize chunk tracking
        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

        # Discover shards
        self._discover_shards()

        # Restore existing state
        self._restore_state(storage=storage)

        # Start background unit creation
        self.unit_creation_thread = threading.Thread(
            target=self._create_units_background, daemon=True
        )
        self.unit_creation_thread.start()
        logger.debug("Unit creation thread started")

    def _detect_config(self, provided_config: Optional[str]) -> str:
        """Auto-detect config if not provided."""
        if provided_config:
            return provided_config

        try:
            configs = get_dataset_config_names(self.dataset_name, token=self.token)
            if not configs:
                return "default"

            # Prefer common config names
            preferred = ["default", "en", "train", "main"]
            for pref in preferred:
                if pref in configs:
                    logger.info(f"Auto-selected config: {pref}")
                    return pref

            # Otherwise use first available
            logger.info(f"Auto-selected first available config: {configs[0]}")
            return configs[0]
        except Exception as e:
            logger.warning(f"Error detecting config: {e}, using 'default'")
            return "default"

    def _detect_split(self, provided_split: Optional[str]) -> str:
        """Auto-detect split if not provided."""
        if provided_split:
            return provided_split

        try:
            splits = get_dataset_split_names(
                self.dataset_name, config_name=self.config, token=self.token
            )
            if not splits:
                logger.warning("No splits found, using 'train'")
                return "train"

            # Prefer training splits
            preferred = ["train", "training", "test", "validation", "dev"]
            for pref in preferred:
                if pref in splits:
                    logger.info(f"Auto-selected split: {pref}")
                    return pref

            # Otherwise use first available
            logger.info(f"Auto-selected first available split: {splits[0]}")
            return splits[0]
        except Exception as e:
            logger.warning(f"Error detecting split: {e}, using 'train'")
            return "train"

    def _extract_filename_from_url(self, url: str) -> str:
        """Extract filename from HF URL format."""
        # Format: hf://datasets/user/dataset@hash/filename
        match = re.search(r"@[a-f0-9]+/(.+)$", url)
        if match:
            return match.group(1)
        # Fallback: just get last part
        return url.split("/")[-1]

    def _discover_shards(self):
        """Discover all shards and their sizes."""
        logger.info("Discovering shards...")

        # Load dataset builder to get file info
        builder = load_dataset_builder(self.dataset_name, self.config)

        # Get data files for our split
        data_files = []
        if hasattr(builder.config, "data_files"):
            if isinstance(builder.config.data_files, dict):
                files = builder.config.data_files.get(self.split, [])
                if isinstance(files, str):
                    files = [files]
                data_files = files

        if not data_files:
            raise ValueError(f"No data files found for split '{self.split}'")

        logger.info(f"Found {len(data_files)} data files")

        # Get info about each shard
        cumulative_offset = 0
        for i, file_url in enumerate(data_files):
            filename = self._extract_filename_from_url(file_url)
            logger.info(f"Discovering shard {i}: {filename}")

            # We don't download shards here - workers will do that
            # For now, store the info we have
            self.shard_info[i] = {
                "shard_id": i,
                "file_url": file_url,
                "filename": filename,
                "start_offset": cumulative_offset,
                # Size will be determined when first worker needs it
                "size": None,
                "end_offset": None,
            }

            # Try to get size from builder info if available
            if hasattr(builder.info, "splits") and self.split in builder.info.splits:
                split_info = builder.info.splits[self.split]
                if split_info.num_examples and len(data_files) == 1:
                    # Single shard case
                    self.shard_info[i]["size"] = split_info.num_examples
                    self.shard_info[i]["end_offset"] = (
                        cumulative_offset + split_info.num_examples - 1
                    )
                    cumulative_offset += split_info.num_examples

        # If we couldn't get sizes, we'll need to load shards on demand
        if self.shard_info[0]["size"] is None:
            logger.warning("Shard sizes not available from metadata, will load on demand")
        else:
            self.total_items = cumulative_offset
            logger.info(f"Total items across all shards: {self.total_items}")

    def _get_shard_size(self, shard_id: int) -> int:
        """Get size of a shard, loading it if necessary."""
        if self.shard_info[shard_id]["size"] is not None:
            return self.shard_info[shard_id]["size"]

        # Need to load the shard to get its size
        logger.info(f"Loading shard {shard_id} to determine size...")
        filename = self.shard_info[shard_id]["filename"]

        local_path = hf_hub_download(
            repo_id=self.dataset_name, filename=filename, repo_type="dataset", token=self.token
        )

        # Load just to get size
        dataset = Dataset.from_parquet(local_path)
        size = len(dataset)

        # Update shard info
        self.shard_info[shard_id]["size"] = size

        # Update offsets for this and subsequent shards
        for sid in range(shard_id, len(self.shard_info)):
            if sid > shard_id:
                self.shard_info[sid]["start_offset"] = self.shard_info[sid - 1]["end_offset"] + 1
            self.shard_info[sid]["end_offset"] = (
                self.shard_info[sid]["start_offset"] + self.shard_info[sid]["size"] - 1
            )

        # Update total items
        if all(s["size"] is not None for s in self.shard_info.values()):
            self.total_items = sum(s["size"] for s in self.shard_info.values())
            logger.info(f"Total items: {self.total_items}")

        return size

    def _restore_state(self, storage: StorageManager) -> None:
        """Restore state from chunk tracker."""
        logger.debug("Restoring state from chunk tracker")
        if not self.chunk_tracker:
            return

        all_processed_jobs = storage.get_all_processed_job_ids()

        with self.lock:
            for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
                # Calculate actual unprocessed ranges
                chunk_range = (
                    chunk_state.start_index,
                    chunk_state.start_index + chunk_state.chunk_size - 1,
                )

                # Get processed indices for this chunk
                processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
                    chunk_id, all_processed_jobs
                )

                # Calculate unprocessed ranges
                unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)

                if unprocessed_ranges:
                    # Find which shard(s) this chunk belongs to
                    shard_ids = []
                    for sid, sinfo in self.shard_info.items():
                        # Need size to check
                        if sinfo["size"] is None:
                            self._get_shard_size(sid)

                        if (
                            sinfo["start_offset"]
                            <= chunk_state.start_index + chunk_state.chunk_size - 1
                            and sinfo["end_offset"] >= chunk_state.start_index
                        ):
                            shard_ids.append(sid)
                            logger.info(f"Found shard {sid} for chunk {chunk_id}: {sinfo}")

                    chunk_index = chunk_state.start_index // self.chunk_size
                    shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
                    unit = WorkUnit(
                        unit_id=chunk_id,
                        chunk_id=chunk_id,
                        source_id=shard_name,
                        data={
                            "dataset_name": self.dataset_name,
                            "config": self.config,
                            "split": self.split,
                            "start_index": chunk_state.start_index,
                            "chunk_size": chunk_state.chunk_size,
                            "unprocessed_ranges": unprocessed_ranges,
                            "shard_ids": shard_ids,
                        },
                        metadata={
                            "dataset": self.dataset_name,
                            "shard_name": shard_name,
                            "chunk_index": chunk_index,
                        },
                    )

                    self.work_units[unit.unit_id] = unit
                    self.pending_units.append(unit.unit_id)

    def _create_units_background(self) -> None:
        """Background thread to create work units on demand."""
        logger.info("Starting work unit creation thread")

        current_index = 0

        while not self.stop_creation.is_set():
            # Check if we need more units
            with self.lock:
                pending_count = len(self.pending_units)
                assigned_count = sum(len(units) for units in self.assigned_units.values())
                worker_count = max(1, len(self.assigned_units))

                target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
                units_needed = max(0, target_buffer - (pending_count + assigned_count))

            if units_needed == 0:
                threading.Event().wait(5)
                continue

            # Make sure we know total items
            if self.total_items == 0:
                # Load all shard sizes
                for sid in range(len(self.shard_info)):
                    self._get_shard_size(sid)

            # Create units as needed
            units_created = 0

            while units_created < units_needed and current_index < self.total_items:
                chunk_size = min(self.chunk_size, self.total_items - current_index)
                chunk_id = current_index // self.chunk_size

                with self.lock:
                    shard_ids = []
                    for sid, sinfo in self.shard_info.items():
                        if (
                            sinfo["start_offset"] <= current_index + chunk_size - 1
                            and sinfo["end_offset"] >= current_index
                        ):
                            shard_ids.append(sid)
                    shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem

                    job_id_obj = JobId(
                        shard_id=shard_name, chunk_id=chunk_id, sample_id=current_index
                    )
                    unit_id = (
                        job_id_obj.get_chunk_str()
                    )  # just the chunk part, eg pixel-images:chunk:0
                    if unit_id in self.work_units:
                        current_index += self.chunk_size
                        continue

                    # Check if chunk is already completed
                    if self.chunk_tracker:
                        chunk_state = self.chunk_tracker.chunks.get(unit_id)
                        if chunk_state and chunk_state.status == "completed":
                            current_index += self.chunk_size
                            continue

                    # Find which shard(s) this chunk belongs to

                    unit = WorkUnit(
                        unit_id=unit_id,
                        chunk_id=unit_id,
                        source_id=shard_name,
                        data={
                            "dataset_name": self.dataset_name,
                            "config": self.config,
                            "split": self.split,
                            "start_index": current_index,
                            "chunk_size": chunk_size,
                            "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
                            "shard_ids": shard_ids,
                        },
                        metadata={
                            "dataset": self.dataset_name,
                            "shard_name": shard_name,
                            "chunk_index": chunk_id,
                        },
                    )
                    logger.debug(f"Created WorkUnit: {unit}")

                    self.work_units[unit_id] = unit
                    self.pending_units.append(unit_id)

                    if self.chunk_tracker:
                        self.chunk_tracker.add_chunk(
                            unit_id,
                            self.dataset_name,
                            "",  # No shard URL
                            current_index,
                            chunk_size,
                        )

                    units_created += 1

                current_index += self.chunk_size

            if units_created > 0:
                logger.debug(f"Created {units_created} work units")

    def _subtract_ranges(
        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """Subtract processed ranges from total ranges."""
        if not processed_ranges:
            return total_ranges

        # Create a set of all processed indices
        processed_indices = set()
        for start, end in processed_ranges:
            processed_indices.update(range(start, end + 1))

        # Find unprocessed ranges
        unprocessed_ranges = []
        for start, end in total_ranges:
            current_start = None
            for i in range(start, end + 1):
                if i not in processed_indices:
                    if current_start is None:
                        current_start = i
                else:
                    if current_start is not None:
                        unprocessed_ranges.append((current_start, i - 1))
                        current_start = None

            if current_start is not None:
                unprocessed_ranges.append((current_start, end))

        return unprocessed_ranges

    def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
        """Get available work units for a worker."""
        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
        assigned = []

        with self.lock:
            while len(assigned) < count and self.pending_units:
                unit_id = self.pending_units.popleft()
                unit = self.work_units.get(unit_id)

                if unit:
                    self.assigned_units[worker_id].add(unit_id)
                    assigned.append(unit)
                    logger.debug("Assigning unit %s to worker %s", unit_id, worker_id)

                    if self.chunk_tracker:
                        self.chunk_tracker.mark_assigned(unit_id, worker_id)

        logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
        return assigned

    def mark_completed(self, unit_id: str, worker_id: str) -> None:
        """Mark a work unit as completed."""
        logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
        with self.lock:
            if unit_id in self.work_units:
                self.assigned_units[worker_id].discard(unit_id)

                if self.chunk_tracker:
                    self.chunk_tracker.mark_completed(unit_id)

    def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
        """Mark a work unit as failed."""
        logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
        with self.lock:
            if unit_id in self.work_units:
                self.assigned_units[worker_id].discard(unit_id)
                self.pending_units.append(unit_id)

                if self.chunk_tracker:
                    self.chunk_tracker.mark_failed(unit_id)

    def release_assignments(self, worker_id: str) -> None:
        """Release all assignments for a disconnected worker."""
        logger.debug("Releasing assignments for worker %s", worker_id)
        with self.lock:
            unit_ids = list(self.assigned_units.get(worker_id, []))

            for unit_id in unit_ids:
                if unit_id in self.work_units:
                    self.pending_units.append(unit_id)

            if worker_id in self.assigned_units:
                del self.assigned_units[worker_id]

            if self.chunk_tracker:
                self.chunk_tracker.release_worker_chunks(worker_id)

    def update_from_storage(self, processed_job_ids: Set[str]) -> None:
        """Update work units based on what's been processed."""
        logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")

        with self.lock:
            for unit_id, unit in self.work_units.items():
                # Extract chunk info from unit
                logger.debug(f"Checking unit {unit_id} for updates")
                logger.debug(f"Unit data: {unit.data}")
                logger.debug(f"Unit metadata: {unit.metadata}")
                start_index = unit.data["start_index"]
                chunk_size = unit.data["chunk_size"]
                shard_name = unit.metadata["shard_name"]
                chunk_index = unit.metadata["chunk_index"]

                # Find processed indices for this chunk
                processed_indices = []
                for job_id in processed_job_ids:
                    # Parse job_id format: "data-0000:chunk:0:idx:42"
                    job_id = JobId.from_str(job_id=job_id)
                    if job_id.shard_id == shard_name and int(job_id.chunk_id) == chunk_index:
                        idx = int(job_id.sample_id)
                        if start_index <= idx < start_index + chunk_size:
                            processed_indices.append(idx)

                if processed_indices:
                    # Convert to ranges
                    processed_indices.sort()
                    processed_ranges = []
                    start = processed_indices[0]
                    end = processed_indices[0]

                    for idx in processed_indices[1:]:
                        if idx == end + 1:
                            end = idx
                        else:
                            processed_ranges.append((start, end))
                            start = idx
                            end = idx

                    processed_ranges.append((start, end))

                    # Calculate unprocessed ranges
                    total_range = [(start_index, start_index + chunk_size - 1)]
                    unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)

                    # Update unit
                    unit.data["unprocessed_ranges"] = unprocessed_ranges

                    logger.debug(
                        f"Updated unit {unit_id}: {len(processed_indices)} processed, "
                        f"unprocessed ranges: {unprocessed_ranges}"
                    )

    def get_stats(self) -> Dict[str, Any]:
        """Get processor statistics."""
        with self.lock:
            stats = {
                "dataset": self.dataset_name,
                "config": self.config,
                "split": self.split,
                "total_units": len(self.work_units),
                "pending_units": len(self.pending_units),
                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
                "total_shards": len(self.shard_info),
                "total_items": self.total_items,
                "workers": len(self.assigned_units),
            }
            return stats

    def handle_result(self, result: WorkResult) -> Dict[str, Any]:
        """Handle result processing."""
        base_result = super().handle_result(result)

        # Track processed items
        if self.chunk_tracker:
            if "item_indices" not in result.metadata:
                result.metadata["item_indices"] = [result.metadata.get("_item_index")]
            indices = result.metadata["item_indices"]

            if indices:
                indices.sort()
                ranges = []
                start = indices[0]
                end = indices[0]

                for i in range(1, len(indices)):
                    if indices[i] == end + 1:
                        end = indices[i]
                    else:
                        ranges.append((start, end))
                        start = indices[i]
                        end = indices[i]

                ranges.append((start, end))

                for start_idx, end_idx in ranges:
                    self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)

        return base_result


class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
    """Worker processor for HuggingFace datasets."""

    def __init__(self):
        logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
        self.dataset_config: Dict[str, Any] = {}
        self.token = get_token()
        self.shard_cache: Dict[int, Dataset] = {}  # Cache loaded shards
        self.image_column: Optional[str] = None
        self.url_column: Optional[str] = None

    def initialize(self, config: ProcessorConfig) -> None:
        """Initialize processor."""
        logger.debug("Initializing worker with config: %s", config.config)
        self.dataset_config = config.config.get("dataset", {})

        # Determine if this is an image URL dataset or binary image dataset
        self.image_column = self.dataset_config.get("dataset_image_column", "image")
        self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
        self.dataset_path = self.dataset_config.get("dataset_path", None)

    def _load_shard(self, dataset_name: str, shard_filename: str, shard_id: int) -> Dataset:
        """Load a shard if not already cached."""
        if shard_id in self.shard_cache:
            return self.shard_cache[shard_id]

        logger.info(f"Loading shard {shard_id}: {shard_filename}")

        local_path = hf_hub_download(
            repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
        )

        dataset = Dataset.from_parquet(local_path)
        self.shard_cache[shard_id] = dataset

        return dataset

    def _extract_filename_from_url(self, url: str) -> str:
        """Extract filename from HF URL format."""
        match = re.search(r"@[a-f0-9]+/(.+)$", url)
        if match:
            return match.group(1)
        return url.split("/")[-1]

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """Process a work unit, yielding items to be captioned."""
        logger.debug("Processing unit: %s", unit.unit_id)

        dataset_name = unit.data["dataset_name"]
        config = unit.data["config"]
        split = unit.data["split"]
        start_index = unit.data["start_index"]
        chunk_size = unit.data["chunk_size"]
        unprocessed_ranges = unit.data.get(
            "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
        )
        shard_ids = unit.data.get("shard_ids", [])

        logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")

        # Need to get shard info - should be passed in unit data
        # For now, we'll need to load dataset builder to get file info
        from datasets import load_dataset_builder

        builder = load_dataset_builder(dataset_name, config)

        data_files = []
        if hasattr(builder.config, "data_files"):
            if isinstance(builder.config.data_files, dict):
                files = builder.config.data_files.get(split, [])
                if isinstance(files, str):
                    files = [files]
                data_files = files

        # Build shard info
        shard_info = {}
        cumulative_offset = 0

        for i, file_url in enumerate(data_files):
            if i not in shard_ids:
                # Skip loading this shard, but we need its size for offsets
                # This is inefficient - in real implementation, orchestrator should pass this info
                filename = self._extract_filename_from_url(file_url)
                dataset = self._load_shard(dataset_name, filename, i)
                size = len(dataset)
                cumulative_offset += size
                continue

            filename = self._extract_filename_from_url(file_url)
            dataset = self._load_shard(dataset_name, filename, i)

            shard_info[i] = {
                "dataset": dataset,
                "start_offset": cumulative_offset,
                "end_offset": cumulative_offset + len(dataset) - 1,
                "columns": dataset.column_names,
            }
            cumulative_offset += len(dataset)

        # Create set of indices to process
        indices_to_process = set()
        for start, end in unprocessed_ranges:
            indices_to_process.update(range(start, end + 1))

        processed_indices = []

        # Process items
        for global_idx in sorted(indices_to_process):
            # Find which shard contains this index
            shard_id = None
            local_idx = None

            for sid, sinfo in shard_info.items():
                if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
                    shard_id = sid
                    local_idx = global_idx - sinfo["start_offset"]
                    break

            if shard_id is None:
                logger.warning(f"Could not find shard for global index {global_idx}")
                continue

            try:
                # Get item from shard
                item = shard_info[shard_id]["dataset"][local_idx]

                # Check if this is a URL dataset or binary image dataset
                image = None
                image_url = None

                # Try URL column first
                if self.url_column and self.url_column in item:
                    image_url = item[self.url_column]
                    # Download image from URL
                    try:
                        response = requests.get(image_url, timeout=30)
                        response.raise_for_status()
                        image = Image.open(io.BytesIO(response.content))
                    except Exception as e:
                        logger.error(f"Error downloading image from {image_url}: {e}")
                        continue

                # Try binary image column
                elif self.image_column and self.image_column in item:
                    image_data = item[self.image_column]
                    if isinstance(image_data, Image.Image):
                        image = image_data
                    elif isinstance(image_data, dict) and "bytes" in image_data:
                        # Handle datasets Image feature
                        image = Image.open(io.BytesIO(image_data["bytes"]))
                    elif isinstance(image_data, bytes):
                        image = Image.open(io.BytesIO(image_data))

                if image is None:
                    logger.warning(f"No image found for item at index {global_idx}")
                    continue

                # Build job ID
                chunk_index = unit.metadata["chunk_index"]
                shard_name = unit.metadata["shard_name"]
                job_id_obj = JobId(
                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
                )
                job_id = job_id_obj.get_sample_str()

                # Clean metadata
                clean_metadata = {
                    k: v
                    for k, v in item.items()
                    if k not in [self.image_column, self.url_column] and not k.startswith("_")
                }

                clean_metadata.update(
                    {
                        "_item_index": global_idx,
                        "_chunk_relative_index": global_idx - start_index,
                        "_job_id": job_id,
                        "_shard_id": shard_id,
                        "_local_index": local_idx,
                        "_url": image_url,
                    }
                )

                yield {
                    "image": image,
                    "item_key": str(global_idx),
                    "item_index": global_idx,
                    "metadata": clean_metadata,
                    "job_id": job_id,
                }

                processed_indices.append(global_idx)

            except Exception as e:
                logger.error(f"Error processing item at index {global_idx}: {e}")

        # Store processed indices in context
        context["_processed_indices"] = processed_indices
        logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)

    def prepare_result(
        self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
    ) -> WorkResult:
        """Prepare result."""
        logger.debug("Preparing result for unit %s", unit.unit_id)
        result = super().prepare_result(unit, outputs, processing_time_ms)

        # Add processed indices to metadata
        if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
            result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]

        return result

    def get_dataset_info(self) -> Dict[str, Any]:
        """Get dataset information."""
        return {
            "dataset_path": self.dataset_config.get("dataset_path"),
            "dataset_type": "huggingface",
            "config": self.dataset_config.get("dataset_config"),
            "split": self.dataset_config.get("dataset_split"),
        }
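For orientation, the configuration surface of the new processor can be summarized as follows. This is a minimal sketch assembled only from the keys the code above reads via `config.config`; the example repo id is hypothetical, and the exact wrapper that caption-flow uses to build `ProcessorConfig` (e.g. its YAML layout) is not shown in this diff and is assumed here.

```python
# Hypothetical shape of the dict the new HuggingFace processors read as config.config,
# inferred from HuggingFaceDatasetOrchestratorProcessor.initialize and
# HuggingFaceDatasetWorkerProcessor.initialize in the diff above.
processor_config = {
    "dataset": {
        "dataset_path": "user/my-image-dataset",  # required; HF Hub dataset repo id (example value)
        "dataset_config": None,                   # optional; auto-detected when omitted
        "dataset_split": None,                    # optional; auto-detected when omitted
        "dataset_image_column": "image",          # worker side: binary/PIL image column
        "dataset_url_column": "image_url",        # worker side: image URL column
    },
    "chunk_size": 1000,                 # items per work unit
    "min_chunk_buffer": 10,             # minimum number of units kept queued
    "chunk_buffer_multiplier": 3,       # target buffer = worker count * multiplier
    "checkpoint_dir": "./checkpoints",  # where ChunkTracker writes chunks.json
}
```

Job ids emitted by the worker follow the pattern noted in the `update_from_storage` comment, i.e. `{shard_name}:chunk:{chunk_index}:idx:{sample_index}` (for example "data-0000:chunk:0:idx:42"), which is how the orchestrator maps processed samples back to unprocessed ranges.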