caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/cli.py +307 -0
- caption_flow/models.py +26 -0
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/huggingface.py +636 -464
- caption_flow/processors/webdataset.py +379 -534
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +410 -303
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +196 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/viewer.py +594 -0
- caption_flow/workers/caption.py +164 -129
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/METADATA +45 -177
- caption_flow-0.3.1.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.3.dist-info/RECORD +0 -35
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt +0 -0
caption_flow/processors/huggingface.py

@@ -1,22 +1,24 @@
-"""HuggingFace Datasets processor implementation."""
+"""HuggingFace Datasets processor implementation - Memory Optimized Version."""
 
 import logging
 import threading
 import re
+import queue
 import requests
+import json
+import io
+import os
+import gc
+import psutil
+from concurrent.futures import ThreadPoolExecutor, Future
 from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
 from collections import deque, defaultdict
 from pathlib import Path
-import json
-import io
 from datetime import datetime
 from PIL import Image
-
-
-
-    get_dataset_split_names,
-    load_dataset_builder,
-)
+import pyarrow as pa
+import pyarrow.parquet as pq
+from datasets import get_dataset_config_names, get_dataset_split_names
 from huggingface_hub import hf_hub_download, get_token
 from caption_flow.storage import StorageManager
 
@@ -28,11 +30,82 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 
+def log_memory(location: str):
+    """Log memory usage at specific location."""
+    process = psutil.Process(os.getpid())
+    mem_info = process.memory_info()
+    logger.info(
+        f"Memory at {location}: RSS={mem_info.rss/1024/1024:.1f}MB, VMS={mem_info.vms/1024/1024:.1f}MB"
+    )
+    # Force garbage collection
+    gc.collect()
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class NonBlockingQueueHandler:
+    """Handles non-blocking retrieval from queues using concurrent futures."""
+
+    def __init__(self, max_workers: int = 1):
+        self.executor = ThreadPoolExecutor(max_workers=max_workers)
+        self.pending_futures: Dict[int, Future] = {}  # queue_id -> Future
+
+    def get_from_queue_async(self, response_queue: queue.Queue, timeout: float = None) -> Future:
+        """Start an async queue retrieval."""
+        queue_id = id(response_queue)
+
+        # Check if we already have a pending future for this queue
+        if queue_id in self.pending_futures and not self.pending_futures[queue_id].done():
+            return self.pending_futures[queue_id]
+
+        # Start new async retrieval
+        future = self.executor.submit(response_queue.get, timeout=timeout)
+        self.pending_futures[queue_id] = future
+        return future
+
+    def check_response(self, response_queue: queue.Queue, timeout: float = None) -> Optional[Any]:
+        """Non-blocking check for queue response."""
+        queue_id = id(response_queue)
+
+        # Start async retrieval if needed
+        future = self.get_from_queue_async(response_queue, timeout)
+
+        # Check if result is ready (non-blocking)
+        if future.done():
+            try:
+                result = future.result(timeout=0)
+                # Clear future for next retrieval
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return result
+            except queue.Empty:
+                # Queue was empty, clear future
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return None
+            except Exception as e:
+                logger.error(f"Error retrieving from queue: {e}")
+                if queue_id in self.pending_futures:
+                    del self.pending_futures[queue_id]
+                return None
+
+        # Result not ready yet
+        return None
+
+    def shutdown(self):
+        """Shutdown the executor."""
+        self.executor.shutdown(wait=True)
+
+
 class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
-    """
+    """Memory-optimized orchestrator processor for HuggingFace datasets with non-blocking operations."""
 
     def __init__(self):
-        logger.debug(
+        logger.debug(
+            "Initializing HuggingFaceDatasetOrchestratorProcessor (Optimized + Non-blocking)"
+        )
         self.dataset_name: Optional[str] = None
         self.config: Optional[str] = None
         self.split: Optional[str] = None
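Note on the non-blocking queue handling added above: NonBlockingQueueHandler hands the blocking queue.Queue.get call to a ThreadPoolExecutor and then polls the returned Future, so the orchestrator loop never has to block on an empty queue. A minimal standalone sketch of that pattern follows; the queue contents and variable names are illustrative only, not part of caption_flow.

import queue
from concurrent.futures import ThreadPoolExecutor

# Sketch: submit the blocking get() to a worker thread and poll the Future.
response_queue: queue.Queue = queue.Queue()
executor = ThreadPoolExecutor(max_workers=1)

future = executor.submit(response_queue.get, timeout=5.0)
response_queue.put({"unit_id": "shard-0:chunk:0"})  # placeholder payload

# Poll without blocking; fall back to a short bounded wait for the demo.
if future.done():
    result = future.result(timeout=0)
else:
    result = future.result(timeout=1.0)
print(result)  # {'unit_id': 'shard-0:chunk:0'}

executor.shutdown(wait=True)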
@@ -44,19 +117,33 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         self.shard_info: Dict[int, Dict[str, Any]] = {}
         self.total_items: int = 0
 
-        # Work unit management
-        self.work_units: Dict[str, WorkUnit] = {}
+        # Work unit management - only store active units
         self.pending_units: Deque[str] = deque()
-        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
+        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
         self.lock = threading.Lock()
 
+        # Track current chunk index for on-demand creation
+        self.current_chunk_index = 0
+
+        # Cache data files info instead of loading builder repeatedly
+        self.data_files: List[str] = []
+
         # Background thread for creating work units
         self.unit_creation_thread: Optional[threading.Thread] = None
         self.stop_creation = threading.Event()
 
+        # Non-blocking queue handler
+        self.queue_handler = NonBlockingQueueHandler()
+
+        # Response processing state
+        self.last_maintenance_time = datetime.now()
+        self.maintenance_interval = 30  # seconds
+
     def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
         """Initialize HuggingFace dataset processor."""
         logger.debug("Initializing orchestrator with config: %s", config.config)
+        log_memory("start of initialize")
+
         cfg = config.config
 
         # Dataset configuration
@@ -83,12 +170,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
 
         # Initialize chunk tracking
-        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
-        checkpoint_dir.mkdir(parents=True, exist_ok=True)
-        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
+        self.checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+        self.chunk_tracker = ChunkTracker(self.checkpoint_dir / "chunks.json")
 
-        # Discover shards
-        self.
+        # Discover shards (optimized)
+        self._discover_shards_optimized()
 
         # Restore existing state
         self._restore_state(storage=storage)
@@ -98,7 +185,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             target=self._create_units_background, daemon=True
         )
         self.unit_creation_thread.start()
-
+
+        log_memory("end of initialize")
 
     def _detect_config(self, provided_config: Optional[str]) -> str:
         """Auto-detect config if not provided."""
@@ -110,14 +198,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             if not configs:
                 return "default"
 
-            # Prefer common config names
             preferred = ["default", "en", "train", "main"]
             for pref in preferred:
                 if pref in configs:
                     logger.info(f"Auto-selected config: {pref}")
                     return pref
 
-            # Otherwise use first available
             logger.info(f"Auto-selected first available config: {configs[0]}")
             return configs[0]
         except Exception as e:
@@ -134,17 +220,14 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 self.dataset_name, config_name=self.config, token=self.token
             )
             if not splits:
-                logger.warning("No splits found, using 'train'")
                 return "train"
 
-            # Prefer training splits
             preferred = ["train", "training", "test", "validation", "dev"]
             for pref in preferred:
                 if pref in splits:
                     logger.info(f"Auto-selected split: {pref}")
                     return pref
 
-            # Otherwise use first available
             logger.info(f"Auto-selected first available split: {splits[0]}")
             return splits[0]
         except Exception as e:
@@ -153,18 +236,16 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
     def _extract_filename_from_url(self, url: str) -> str:
         """Extract filename from HF URL format."""
-        # Format: hf://datasets/user/dataset@hash/filename
         match = re.search(r"@[a-f0-9]+/(.+)$", url)
         if match:
             return match.group(1)
-        # Fallback: just get last part
         return url.split("/")[-1]
 
-    def
-        """
-
+    def _get_data_files_from_builder(self) -> List[str]:
+        """Get data files using dataset builder with minimal memory usage."""
+        # Load builder to get correct file structure
+        from datasets import load_dataset_builder
 
-        # Load dataset builder to get file info
         builder = load_dataset_builder(self.dataset_name, self.config)
 
         # Get data files for our split
@@ -176,81 +257,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 files = [files]
             data_files = files
 
-
+        # Explicitly delete builder to free memory
+        del builder
+        gc.collect()
+
+        return data_files
+
+    def _discover_shards_optimized(self):
+        """Discover all shards using dataset builder but release memory immediately."""
+        logger.info("Discovering shards...")
+
+        # Try to load cached shard info first
+        shard_info_cache_path = (
+            self.checkpoint_dir / f"{self.dataset_name}_{self.config}_{self.split}_shard_info.json"
+        )
+
+        if shard_info_cache_path.exists():
+            try:
+                with open(shard_info_cache_path, "r") as f:
+                    cached_info = json.load(f)
+                if (
+                    cached_info.get("dataset") == self.dataset_name
+                    and cached_info.get("config") == self.config
+                    and cached_info.get("split") == self.split
+                ):
+                    self.shard_info = {int(k): v for k, v in cached_info["shards"].items()}
+                    self.total_items = cached_info["total_items"]
+                    self.data_files = cached_info.get("data_files", [])
+                    logger.info(
+                        f"Loaded cached shard info: {len(self.shard_info)} shards, {self.total_items} total items"
+                    )
+                    return
+            except Exception as e:
+                logger.warning(f"Failed to load cached shard info: {e}")
+
+        # Get data files using dataset builder
+        self.data_files = self._get_data_files_from_builder()
+
+        if not self.data_files:
             raise ValueError(f"No data files found for split '{self.split}'")
 
-        logger.info(f"Found {len(data_files)} data files")
+        logger.info(f"Found {len(self.data_files)} data files")
 
-        # Get
+        # Get metadata for each shard
         cumulative_offset = 0
-        for i, file_url in enumerate(data_files):
+        for i, file_url in enumerate(self.data_files):
             filename = self._extract_filename_from_url(file_url)
             logger.info(f"Discovering shard {i}: {filename}")
 
-
-
-
-
-
-
-
-
-                "size": None,
-                "end_offset": None,
-            }
-
-            # Try to get size from builder info if available
-            if hasattr(builder.info, "splits") and self.split in builder.info.splits:
-                split_info = builder.info.splits[self.split]
-                if split_info.num_examples and len(data_files) == 1:
-                    # Single shard case
-                    self.shard_info[i]["size"] = split_info.num_examples
-                    self.shard_info[i]["end_offset"] = (
-                        cumulative_offset + split_info.num_examples - 1
-                    )
-                    cumulative_offset += split_info.num_examples
-
-        # If we couldn't get sizes, we'll need to load shards on demand
-        if self.shard_info[0]["size"] is None:
-            logger.warning("Shard sizes not available from metadata, will load on demand")
-        else:
-            self.total_items = cumulative_offset
-            logger.info(f"Total items across all shards: {self.total_items}")
-
-    def _get_shard_size(self, shard_id: int) -> int:
-        """Get size of a shard, loading it if necessary."""
-        if self.shard_info[shard_id]["size"] is not None:
-            return self.shard_info[shard_id]["size"]
+            try:
+                # Download file to get metadata
+                local_path = hf_hub_download(
+                    repo_id=self.dataset_name,
+                    filename=filename,
+                    repo_type="dataset",
+                    token=self.token,
+                )
 
-
-
-
+                # Read only metadata
+                metadata = pq.read_metadata(local_path)
+                size = metadata.num_rows
+
+                self.shard_info[i] = {
+                    "shard_id": i,
+                    "file_url": file_url,
+                    "filename": filename,
+                    "start_offset": cumulative_offset,
+                    "size": size,
+                    "end_offset": cumulative_offset + size - 1,
+                }
 
-
-
-            )
+                cumulative_offset += size
+                logger.info(f"Shard {i} ({filename}): {size} rows")
 
-
-
-
+            except Exception as e:
+                logger.error(f"Failed to discover shard {i}: {e}")
+                # Skip this shard
+                continue
 
-
-
+        self.total_items = cumulative_offset
+        logger.info(f"Total items across all shards: {self.total_items}")
 
-        #
-
-
-
-
-
-
+        # Cache shard info
+        try:
+            cache_data = {
+                "dataset": self.dataset_name,
+                "config": self.config,
+                "split": self.split,
+                "shards": self.shard_info,
+                "total_items": self.total_items,
+                "data_files": self.data_files,
+            }
+            with open(shard_info_cache_path, "w") as f:
+                json.dump(cache_data, f)
+            logger.info(f"Cached shard info to {shard_info_cache_path}")
+        except Exception as e:
+            logger.warning(f"Failed to cache shard info: {e}")
 
-        #
-
-
-        logger.info(f"Total items: {self.total_items}")
+        # Force garbage collection
+        gc.collect()
+        log_memory("after discovering shards")
 
-
+    def _get_shard_for_index(self, global_index: int) -> Tuple[int, int]:
+        """Get shard ID and local index for a global index."""
+        for shard_id, sinfo in self.shard_info.items():
+            if sinfo["start_offset"] <= global_index <= sinfo["end_offset"]:
+                local_index = global_index - sinfo["start_offset"]
+                return shard_id, local_index
+        raise ValueError(f"Global index {global_index} not found in any shard")
 
     def _restore_state(self, storage: StorageManager) -> None:
         """Restore state from chunk tracker."""
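Note on the shard discovery rewrite above: instead of materializing each shard as a Dataset, the new code downloads the Parquet file and reads only its footer via pyarrow.parquet.read_metadata to learn the row count. A rough sketch of that idea, using a placeholder file path rather than anything from the package:

import pyarrow.parquet as pq

# Sketch: size a Parquet shard from its footer without reading row data.
# "shard-00000.parquet" is a hypothetical local path.
metadata = pq.read_metadata("shard-00000.parquet")
print(metadata.num_rows)        # total rows in the shard
print(metadata.num_row_groups)  # row groups available for selective reads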
@@ -258,73 +372,83 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         if not self.chunk_tracker:
             return
 
-        all_processed_jobs = storage.get_all_processed_job_ids()
-
         with self.lock:
+            max_chunk_index = -1
+
             for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
-
-
-
-
-
+                chunk_index = chunk_state.start_index // self.chunk_size
+                max_chunk_index = max(max_chunk_index, chunk_index)
+
+                # Only add incomplete chunks to pending
+                if chunk_state.status != "completed":
+                    self.pending_units.append(chunk_id)
+                elif chunk_state.status == "completed" and chunk_state.processed_ranges:
+                    logger.warning(
+                        f"Chunk {chunk_id} has processed_ranges stored in the checkpoint."
+                    )
+
+            self.current_chunk_index = max_chunk_index + 1
+            logger.info(f"Resuming from chunk index {self.current_chunk_index}")
 
-
-
-
+    def _create_work_unit(self, chunk_index: int) -> Optional[WorkUnit]:
+        """Create a single work unit for a chunk index."""
+        current_index = chunk_index * self.chunk_size
+
+        if current_index >= self.total_items:
+            return None
+
+        chunk_size = min(self.chunk_size, self.total_items - current_index)
+
+        # Find shard for this chunk
+        shard_id, _ = self._get_shard_for_index(current_index)
+        shard_name = Path(self.shard_info[shard_id]["filename"]).stem
+
+        job_id_obj = JobId(shard_id=shard_name, chunk_id=chunk_index, sample_id=current_index)
+        unit_id = job_id_obj.get_chunk_str()
+
+        # Calculate unprocessed ranges based on existing chunk state
+        unprocessed_ranges = [(current_index, current_index + chunk_size - 1)]
+
+        if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+            chunk_state = self.chunk_tracker.chunks[unit_id]
+            if chunk_state.processed_ranges:
+                # Subtract processed ranges from total range
+                unprocessed_ranges = self._subtract_ranges(
+                    [(current_index, current_index + chunk_size - 1)], chunk_state.processed_ranges
                 )
 
-
-
-
-            if unprocessed_ranges:
-                # Find which shard(s) this chunk belongs to
-                shard_ids = []
-                for sid, sinfo in self.shard_info.items():
-                    # Need size to check
-                    if sinfo["size"] is None:
-                        self._get_shard_size(sid)
-
-                    if (
-                        sinfo["start_offset"]
-                        <= chunk_state.start_index + chunk_state.chunk_size - 1
-                        and sinfo["end_offset"] >= chunk_state.start_index
-                    ):
-                        shard_ids.append(sid)
-                        logger.info(f"Found shard {sid} for chunk {chunk_id}: {sinfo}")
-
-                chunk_index = chunk_state.start_index // self.chunk_size
-                shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
-                unit = WorkUnit(
-                    unit_id=chunk_id,
-                    chunk_id=chunk_id,
-                    source_id=shard_name,
-                    data={
-                        "dataset_name": self.dataset_name,
-                        "config": self.config,
-                        "split": self.split,
-                        "start_index": chunk_state.start_index,
-                        "chunk_size": chunk_state.chunk_size,
-                        "unprocessed_ranges": unprocessed_ranges,
-                        "shard_ids": shard_ids,
-                    },
-                    metadata={
-                        "dataset": self.dataset_name,
-                        "shard_name": shard_name,
-                        "chunk_index": chunk_index,
-                    },
-                )
+        # If all ranges are processed, return None (shouldn't happen if status tracking is correct)
+        if not unprocessed_ranges:
+            return None
 
-
-
+        unit = WorkUnit(
+            unit_id=unit_id,
+            chunk_id=unit_id,
+            source_id=shard_name,
+            data={
+                "dataset_name": self.dataset_name,
+                "config": self.config,
+                "split": self.split,
+                "start_index": current_index,
+                "chunk_size": chunk_size,
+                "unprocessed_ranges": unprocessed_ranges,  # Use calculated ranges
+                "shard_ids": [shard_id],
+                "data_files": self.data_files,
+            },
+            metadata={
+                "dataset": self.dataset_name,
+                "shard_name": shard_name,
+                "chunk_index": chunk_index,
+            },
+        )
+
+        return unit
 
     def _create_units_background(self) -> None:
         """Background thread to create work units on demand."""
         logger.info("Starting work unit creation thread")
 
-        current_index = 0
-
         while not self.stop_creation.is_set():
-            # Check if we need more units
             with self.lock:
                 pending_count = len(self.pending_units)
                 assigned_count = sum(len(units) for units in self.assigned_units.values())
@@ -337,127 +461,114 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
                 threading.Event().wait(5)
                 continue
 
-            # Make sure we know total items
-            if self.total_items == 0:
-                # Load all shard sizes
-                for sid in range(len(self.shard_info)):
-                    self._get_shard_size(sid)
-
             # Create units as needed
             units_created = 0
 
-            while units_created < units_needed
-
-
-
-
-
-
-
-
-
-                    ):
-                        shard_ids.append(sid)
-                shard_name = Path(self.shard_info[shard_ids[0]]["filename"]).stem
+            while units_created < units_needed:
+                logger.debug(f"Creating work unit for chunk {self.current_chunk_index}")
+                if self.current_chunk_index * self.chunk_size >= self.total_items:
+                    threading.Event().wait(30)
+                    break
+                # Get shard info for proper unit_id
+                current_index = self.current_chunk_index * self.chunk_size
+                if current_index < self.total_items:
+                    shard_id, _ = self._get_shard_for_index(current_index)
+                    shard_name = Path(self.shard_info[shard_id]["filename"]).stem
 
                 job_id_obj = JobId(
-                    shard_id=shard_name,
+                    shard_id=shard_name,
+                    chunk_id=self.current_chunk_index,
+                    sample_id=current_index,
                 )
-                unit_id = (
-                    job_id_obj.get_chunk_str()
-                )  # just the chunk part, eg pixel-images:chunk:0
-                if unit_id in self.work_units:
-                    current_index += self.chunk_size
-                    continue
-
-                # Check if chunk is already completed
-                if self.chunk_tracker:
-                    chunk_state = self.chunk_tracker.chunks.get(unit_id)
-                    if chunk_state and chunk_state.status == "completed":
-                        current_index += self.chunk_size
-                        continue
+                unit_id = job_id_obj.get_chunk_str()
 
-
-
-
-
-
-
-
-                        "dataset_name": self.dataset_name,
-                        "config": self.config,
-                        "split": self.split,
-                        "start_index": current_index,
-                        "chunk_size": chunk_size,
-                        "unprocessed_ranges": [(current_index, current_index + chunk_size - 1)],
-                        "shard_ids": shard_ids,
-                    },
-                    metadata={
-                        "dataset": self.dataset_name,
-                        "shard_name": shard_name,
-                        "chunk_index": chunk_id,
-                    },
-                )
-                logger.debug(f"Created WorkUnit: {unit}")
+                with self.lock:
+                    # Check if already tracked
+                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                        if chunk_state.status == "completed":
+                            self.current_chunk_index += 1
+                            continue
 
-
+                    # Add to pending
                     self.pending_units.append(unit_id)
 
+                    # Track in chunk tracker
                     if self.chunk_tracker:
+                        start_index = self.current_chunk_index * self.chunk_size
+                        chunk_size = min(self.chunk_size, self.total_items - start_index)
                         self.chunk_tracker.add_chunk(
                             unit_id,
                             self.dataset_name,
-                            "",
-
+                            "",
+                            start_index,
                             chunk_size,
                         )
 
                 units_created += 1
-
-                current_index += self.chunk_size
+                self.current_chunk_index += 1
 
             if units_created > 0:
-                logger.debug(f"Created {units_created} work
-
-
-
-    ) ->
-        """
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                logger.debug(f"Created {units_created} work unit IDs")
+
+        logger.info("Thread for creating units has completed. Exiting thread.")
+
+    def process_responses_non_blocking(self, response_queue: queue.Queue) -> Optional[WorkResult]:
+        """
+        Non-blocking method to process responses from workers.
+        Returns a WorkResult if one is available, None otherwise.
+        """
+        # Check for response without blocking
+        response = self.queue_handler.check_response(response_queue, timeout=0.1)
+
+        if response is not None:
+            # Process the response
+            if isinstance(response, WorkResult):
+                logger.debug(f"Processing response for unit {response.unit_id}")
+                return response
+            else:
+                logger.warning(f"Unexpected response type: {type(response)}")
+
+        # Perform periodic maintenance tasks
+        now = datetime.now()
+        if (now - self.last_maintenance_time).total_seconds() > self.maintenance_interval:
+            self._perform_maintenance()
+            self.last_maintenance_time = now
+
+        return None
+
+    def _perform_maintenance(self):
+        """Perform periodic maintenance tasks."""
+        with self.lock:
+            # Log current state
+            pending_count = len(self.pending_units)
+            assigned_count = sum(len(units) for units in self.assigned_units.values())
+            logger.debug(f"Maintenance: {pending_count} pending, {assigned_count} assigned units")
 
-
-
+            # Check for stale assignments (workers that might have disconnected)
+            # This would be implemented based on your worker heartbeat mechanism
 
-
+        # Force checkpoint save if needed
+        if self.chunk_tracker:
+            self.chunk_tracker.save_checkpoint()
 
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
-        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
-        assigned = []
 
+        logger.debug(
+            "get_work_units called: count=%d worker_id=%s, pending: %d",
+            count,
+            worker_id,
+            len(self.pending_units),
+        )
+        assigned = []
         with self.lock:
             while len(assigned) < count and self.pending_units:
                 unit_id = self.pending_units.popleft()
-
+
+                # Create work unit on demand
+                chunk_index = int(unit_id.split(":")[-1])
+                unit = self._create_work_unit(chunk_index)
 
                 if unit:
                     self.assigned_units[worker_id].add(unit_id)
@@ -474,22 +585,26 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Mark a work unit as completed."""
         logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
         with self.lock:
-
-            self.assigned_units[worker_id].discard(unit_id)
+            self.assigned_units[worker_id].discard(unit_id)
 
-
-
+            if self.chunk_tracker:
+                self.chunk_tracker.mark_completed(unit_id)
+
+            # remove from pending deque if it's there.
+            try:
+                self.pending_units.remove(unit_id)
+            except:
+                pass
 
     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         """Mark a work unit as failed."""
-        logger.
+        logger.error("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
         with self.lock:
-
-
-            self.pending_units.append(unit_id)
+            self.assigned_units[worker_id].discard(unit_id)
+            self.pending_units.append(unit_id)
 
-
-
+            if self.chunk_tracker:
+                self.chunk_tracker.mark_failed(unit_id)
 
     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
@@ -498,8 +613,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             unit_ids = list(self.assigned_units.get(worker_id, []))
 
             for unit_id in unit_ids:
-
-
+                logger.debug(f"Adding {unit_id} to pending queue")
+                self.pending_units.append(unit_id)
 
             if worker_id in self.assigned_units:
                 del self.assigned_units[worker_id]
@@ -509,57 +624,8 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
 
     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
         """Update work units based on what's been processed."""
-        logger.info(f"Updating
-
-        with self.lock:
-            for unit_id, unit in self.work_units.items():
-                # Extract chunk info from unit
-                logger.debug(f"Checking unit {unit_id} for updates")
-                logger.debug(f"Unit data: {unit.data}")
-                logger.debug(f"Unit metadata: {unit.metadata}")
-                start_index = unit.data["start_index"]
-                chunk_size = unit.data["chunk_size"]
-                shard_name = unit.metadata["shard_name"]
-                chunk_index = unit.metadata["chunk_index"]
-
-                # Find processed indices for this chunk
-                processed_indices = []
-                for job_id in processed_job_ids:
-                    # Parse job_id format: "data-0000:chunk:0:idx:42"
-                    job_id = JobId.from_str(job_id=job_id)
-                    if job_id.shard_id == shard_name and int(job_id.chunk_id) == chunk_index:
-                        idx = int(job_id.sample_id)
-                        if start_index <= idx < start_index + chunk_size:
-                            processed_indices.append(idx)
-
-                if processed_indices:
-                    # Convert to ranges
-                    processed_indices.sort()
-                    processed_ranges = []
-                    start = processed_indices[0]
-                    end = processed_indices[0]
-
-                    for idx in processed_indices[1:]:
-                        if idx == end + 1:
-                            end = idx
-                        else:
-                            processed_ranges.append((start, end))
-                            start = idx
-                            end = idx
-
-                    processed_ranges.append((start, end))
-
-                    # Calculate unprocessed ranges
-                    total_range = [(start_index, start_index + chunk_size - 1)]
-                    unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
-
-                    # Update unit
-                    unit.data["unprocessed_ranges"] = unprocessed_ranges
-
-                    logger.debug(
-                        f"Updated unit {unit_id}: {len(processed_indices)} processed, "
-                        f"unprocessed ranges: {unprocessed_ranges}"
-                    )
+        logger.info(f"Updating from storage with {len(processed_job_ids)} processed jobs")
+        # No need to update in-memory work units since we create on demand
 
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
@@ -568,12 +634,12 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
             "dataset": self.dataset_name,
             "config": self.config,
             "split": self.split,
-            "total_units": len(self.work_units),
             "pending_units": len(self.pending_units),
             "assigned_units": sum(len(units) for units in self.assigned_units.values()),
             "total_shards": len(self.shard_info),
             "total_items": self.total_items,
             "workers": len(self.assigned_units),
+            "current_chunk_index": self.current_chunk_index,
         }
         return stats
 
@@ -581,71 +647,111 @@ class HuggingFaceDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Handle result processing."""
         base_result = super().handle_result(result)
 
-        # Track processed items
         if self.chunk_tracker:
-            if "item_indices"
-
-
+            if "item_indices" in result.metadata:
+                indices = result.metadata["item_indices"]
+                if indices:
+                    # Convert to ranges for efficient tracking
+                    indices.sort()
+                    ranges = []
+                    start = indices[0]
+                    end = indices[0]
+
+                    for i in range(1, len(indices)):
+                        if indices[i] == end + 1:
+                            end = indices[i]
+                        else:
+                            ranges.append((start, end))
+                            start = indices[i]
+                            end = indices[i]
 
-
-                indices.sort()
-                ranges = []
-                start = indices[0]
-                end = indices[0]
+                    ranges.append((start, end))
 
-
-
-                        end = indices[i]
-                    else:
-                        ranges.append((start, end))
-                        start = indices[i]
-                        end = indices[i]
+                    for start_idx, end_idx in ranges:
+                        self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
 
-
+        return base_result
 
-
-
+    def _subtract_ranges(
+        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
+    ) -> List[Tuple[int, int]]:
+        """Subtract processed ranges from total ranges."""
+        if not processed_ranges:
+            return total_ranges
 
-
+        # Create a set of all processed indices
+        processed_indices = set()
+        for start, end in processed_ranges:
+            processed_indices.update(range(start, end + 1))
+
+        # Find unprocessed ranges
+        unprocessed_ranges = []
+        for start, end in total_ranges:
+            current_start = None
+            for i in range(start, end + 1):
+                if i not in processed_indices:
+                    if current_start is None:
+                        current_start = i
+                else:
+                    if current_start is not None:
+                        unprocessed_ranges.append((current_start, i - 1))
+                        current_start = None
+
+            if current_start is not None:
+                unprocessed_ranges.append((current_start, end))
+
+        return unprocessed_ranges
+
+    def cleanup(self):
+        """Clean up resources."""
+        logger.info("Cleaning up orchestrator resources")
+
+        # Stop background threads
+        self.stop_creation.set()
+        if self.unit_creation_thread:
+            self.unit_creation_thread.join(timeout=5)
+
+        # Shutdown queue handler
+        self.queue_handler.shutdown()
+
+        # Save final state
+        if self.chunk_tracker:
+            self.chunk_tracker.save_checkpoint()
 
 
 class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
-    """
+    """Memory-optimized worker processor for HuggingFace datasets."""
 
     def __init__(self):
-        logger.debug("Initializing HuggingFaceDatasetWorkerProcessor")
+        logger.debug("Initializing HuggingFaceDatasetWorkerProcessor (Optimized)")
         self.dataset_config: Dict[str, Any] = {}
         self.token = get_token()
-        self.shard_cache: Dict[int, Dataset] = {}  # Cache loaded shards
        self.image_column: Optional[str] = None
         self.url_column: Optional[str] = None
 
+        # Thread-local storage for shard info to avoid repeated builder loading
+        self._thread_local = threading.local()
+
     def initialize(self, config: ProcessorConfig) -> None:
         """Initialize processor."""
         logger.debug("Initializing worker with config: %s", config.config)
         self.dataset_config = config.config.get("dataset", {})
 
-        # Determine if this is an image URL dataset or binary image dataset
         self.image_column = self.dataset_config.get("dataset_image_column", "image")
         self.url_column = self.dataset_config.get("dataset_url_column", "image_url")
         self.dataset_path = self.dataset_config.get("dataset_path", None)
 
-
-
-        if
-
-
-        logger.info(f"Loading shard {shard_id}: {shard_filename}")
+        # Add mock results flag
+        self.mock_results = self.dataset_config.get("mock_results", False)
+        if self.mock_results:
+            logger.info("Mock results mode enabled - will generate dummy images")
 
-
+    def _get_shard_path(self, dataset_name: str, shard_filename: str) -> str:
+        """Get local path for a shard, downloading if needed."""
+        return hf_hub_download(
             repo_id=dataset_name, filename=shard_filename, repo_type="dataset", token=self.token
         )
 
-        dataset = Dataset.from_parquet(local_path)
-        self.shard_cache[shard_id] = dataset
-
-        return dataset
-
     def _extract_filename_from_url(self, url: str) -> str:
         """Extract filename from HF URL format."""
         match = re.search(r"@[a-f0-9]+/(.+)$", url)
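Note on the range bookkeeping above: _subtract_ranges expands the processed ranges into an index set and re-compacts whatever is left into contiguous (start, end) pairs, which handle_result then feeds back to the chunk tracker. A small worked example of the expected behaviour, with values chosen purely for illustration:

# Sketch of the range arithmetic _subtract_ranges implements:
# removing processed [(10, 14), (40, 59)] from the chunk [(0, 99)]
# should leave [(0, 9), (15, 39), (60, 99)].
total_start, total_end = 0, 99
processed_ranges = [(10, 14), (40, 59)]

processed = set()
for start, end in processed_ranges:
    processed.update(range(start, end + 1))

unprocessed = []
run_start = None
for i in range(total_start, total_end + 1):
    if i not in processed:
        if run_start is None:
            run_start = i
    elif run_start is not None:
        unprocessed.append((run_start, i - 1))
        run_start = None
if run_start is not None:
    unprocessed.append((run_start, total_end))

print(unprocessed)  # [(0, 9), (15, 39), (60, 99)]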
@@ -653,161 +759,227 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
             return match.group(1)
         return url.split("/")[-1]
 
+    def _create_dummy_image(self, index: int, metadata: Dict[str, Any]) -> Image.Image:
+        """Create a dummy image"""
+        color = (0, 0, 0)
+        width, height = 128, 128
+        image = Image.new("RGB", (width, height), color=color)
+
+        return image
+
     def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
         """Process a work unit, yielding items to be captioned."""
-        logger.debug("Processing unit: %s", unit.unit_id)
+        logger.debug("Processing unit: %s (mock_results=%s)", unit.unit_id, self.mock_results)
+        log_memory(f"start processing unit {unit.unit_id}")
 
         dataset_name = unit.data["dataset_name"]
-        config = unit.data["config"]
-        split = unit.data["split"]
         start_index = unit.data["start_index"]
         chunk_size = unit.data["chunk_size"]
         unprocessed_ranges = unit.data.get(
             "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
         )
         shard_ids = unit.data.get("shard_ids", [])
+        data_files = unit.data.get("data_files", [])
 
         logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
 
-        #
-        # For now, we'll need to load dataset builder to get file info
-        from datasets import load_dataset_builder
-
-        builder = load_dataset_builder(dataset_name, config)
-
-        data_files = []
-        if hasattr(builder.config, "data_files"):
-            if isinstance(builder.config.data_files, dict):
-                files = builder.config.data_files.get(split, [])
-                if isinstance(files, str):
-                    files = [files]
-                data_files = files
-
-        # Build shard info
+        # Build shard info from provided data files (no dataset builder needed)
         shard_info = {}
-        cumulative_offset = 0
-
-        for i, file_url in enumerate(data_files):
-            if i not in shard_ids:
-                # Skip loading this shard, but we need its size for offsets
-                # This is inefficient - in real implementation, orchestrator should pass this info
-                filename = self._extract_filename_from_url(file_url)
-                dataset = self._load_shard(dataset_name, filename, i)
-                size = len(dataset)
-                cumulative_offset += size
-                continue
 
-
-
+        if data_files:
+            # Use provided data files
+            for i, file_url in enumerate(data_files):
+                if i in shard_ids:
+                    filename = self._extract_filename_from_url(file_url)
+                    shard_path = self._get_shard_path(dataset_name, filename)
+
+                    # Get size from metadata
+                    metadata = pq.read_metadata(shard_path)
+                    size = metadata.num_rows
+
+                    shard_info[i] = {
+                        "path": shard_path,
+                        "start_offset": 0,  # Will be set below
+                        "end_offset": 0,  # Will be set below
+                        "size": size,
+                        "metadata": metadata,
+                    }
 
-
-
-
-
-
-
-
+            # Calculate offsets
+            cumulative_offset = 0
+            for i in range(max(shard_info.keys()) + 1):
+                if i in shard_info:
+                    shard_info[i]["start_offset"] = cumulative_offset
+                    shard_info[i]["end_offset"] = cumulative_offset + shard_info[i]["size"] - 1
+                    cumulative_offset += shard_info[i]["size"]
+                else:
+                    # Need to get size for offset calculation
+                    filename = self._extract_filename_from_url(data_files[i])
+                    shard_path = self._get_shard_path(dataset_name, filename)
+                    metadata = pq.read_metadata(shard_path)
+                    cumulative_offset += metadata.num_rows
+        else:
+            # This should never happen with the new orchestrator
+            raise ValueError("No data files provided in work unit")
 
         # Create set of indices to process
         indices_to_process = set()
         for start, end in unprocessed_ranges:
             indices_to_process.update(range(start, end + 1))
 
-
-
-
-
-            # Find which shard contains this index
-            shard_id = None
-            local_idx = None
-
-            for sid, sinfo in shard_info.items():
+        # Group indices by shard
+        indices_by_shard = defaultdict(list)
+        for global_idx in indices_to_process:
+            for shard_id, sinfo in shard_info.items():
                 if sinfo["start_offset"] <= global_idx <= sinfo["end_offset"]:
-                    shard_id = sid
                     local_idx = global_idx - sinfo["start_offset"]
+                    indices_by_shard[shard_id].append((global_idx, local_idx))
                     break
 
-
-                logger.warning(f"Could not find shard for global index {global_idx}")
-                continue
-
-            try:
-                # Get item from shard
-                item = shard_info[shard_id]["dataset"][local_idx]
-
-                # Check if this is a URL dataset or binary image dataset
-                image = None
-                image_url = None
-
-                # Try URL column first
-                if self.url_column and self.url_column in item:
-                    image_url = item[self.url_column]
-                    # Download image from URL
-                    try:
-                        response = requests.get(image_url, timeout=30)
-                        response.raise_for_status()
-                        image = Image.open(io.BytesIO(response.content))
-                    except Exception as e:
-                        logger.error(f"Error downloading image from {image_url}: {e}")
-                        continue
-
-                # Try binary image column
-                elif self.image_column and self.image_column in item:
-                    image_data = item[self.image_column]
-                    if isinstance(image_data, Image.Image):
-                        image = image_data
-                    elif isinstance(image_data, dict) and "bytes" in image_data:
-                        # Handle datasets Image feature
-                        image = Image.open(io.BytesIO(image_data["bytes"]))
-                    elif isinstance(image_data, bytes):
-                        image = Image.open(io.BytesIO(image_data))
-
-                if image is None:
-                    logger.warning(f"No image found for item at index {global_idx}")
-                    continue
-
-                # Build job ID
-                chunk_index = unit.metadata["chunk_index"]
-                shard_name = unit.metadata["shard_name"]
-                job_id_obj = JobId(
-                    shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(global_idx)
-                )
-                job_id = job_id_obj.get_sample_str()
-
-                # Clean metadata
-                clean_metadata = {
-                    k: v
-                    for k, v in item.items()
-                    if k not in [self.image_column, self.url_column] and not k.startswith("_")
-                }
-
-                clean_metadata.update(
-                    {
-                        "_item_index": global_idx,
-                        "_chunk_relative_index": global_idx - start_index,
-                        "_job_id": job_id,
-                        "_shard_id": shard_id,
-                        "_local_index": local_idx,
-                        "_url": image_url,
-                    }
-                )
-
-                yield {
-                    "image": image,
-                    "item_key": str(global_idx),
-                    "item_index": global_idx,
-                    "metadata": clean_metadata,
-                    "job_id": job_id,
-                }
-
-                processed_indices.append(global_idx)
+        processed_indices = []
 
-
-
+        # Process items shard by shard
+        for shard_id, idx_pairs in indices_by_shard.items():
+            shard_path = shard_info[shard_id]["path"]
+
+            # Process in batches to avoid loading entire table
+            batch_size = 100
+            for batch_start in range(0, len(idx_pairs), batch_size):
+                batch_pairs = idx_pairs[batch_start : batch_start + batch_size]
+                local_indices = [local_idx for _, local_idx in batch_pairs]
+
+                # Read only specific rows using PyArrow
+                try:
+                    # Create row group filters based on metadata
+                    metadata = shard_info[shard_id]["metadata"]
+                    row_groups_to_read = set()
+
+                    # Find which row groups contain our indices
+                    current_row = 0
+                    for rg_idx in range(metadata.num_row_groups):
+                        rg_metadata = metadata.row_group(rg_idx)
+                        rg_num_rows = rg_metadata.num_rows
+
+                        # Check if any of our indices are in this row group
+                        for local_idx in local_indices:
+                            if current_row <= local_idx < current_row + rg_num_rows:
+                                row_groups_to_read.add(rg_idx)
+
+                        current_row += rg_num_rows
+
+                    # Read only necessary row groups
+                    parquet_file = pq.ParquetFile(shard_path)
+                    table = parquet_file.read_row_groups(list(row_groups_to_read))
+
+                    # Process items
+                    for global_idx, local_idx in batch_pairs:
+                        try:
+                            # Get item as dictionary (efficient row extraction)
+                            row_dict = table.slice(local_idx, 1).to_pydict()
+                            item = {k: v[0] for k, v in row_dict.items()}
+
+                            # Process image
+                            image = None
+                            image_url = None
+
+                            if self.mock_results:
+                                # In mock mode, create a dummy image
+                                logger.debug(f"Creating mock image for index {global_idx}")
+
+                                # Still extract URL if available for metadata
+                                if self.url_column and self.url_column in item:
+                                    image_url = item[self.url_column]
+
+                                # Create dummy image with metadata context
+                                image = self._create_dummy_image(
+                                    global_idx,
+                                    {
+                                        "_shard_id": shard_id,
+                                        "_local_index": local_idx,
+                                    },
+                                )
+                            else:
+                                # Normal processing - load real images
+                                if self.url_column and self.url_column in item:
+                                    image_url = item[self.url_column]
+                                    try:
+                                        response = requests.get(image_url, timeout=30)
+                                        response.raise_for_status()
+                                        image = Image.open(io.BytesIO(response.content))
+                                    except Exception as e:
+                                        logger.error(
+                                            f"Error downloading image from {image_url}: {e}"
+                                        )
+                                        continue
+
+                                elif self.image_column and self.image_column in item:
+                                    image_data = item[self.image_column]
+                                    if isinstance(image_data, dict) and "bytes" in image_data:
+                                        image = Image.open(io.BytesIO(image_data["bytes"]))
+                                    elif isinstance(image_data, bytes):
+                                        image = Image.open(io.BytesIO(image_data))
+
+                            if image is None:
+                                logger.warning(f"No image found for item at index {global_idx}")
+                                continue
+
+                            # Build job ID
+                            chunk_index = unit.metadata["chunk_index"]
+                            shard_name = unit.metadata["shard_name"]
+                            job_id_obj = JobId(
+                                shard_id=shard_name,
+                                chunk_id=str(chunk_index),
+                                sample_id=str(global_idx),
+                            )
+                            job_id = job_id_obj.get_sample_str()
+
+                            # Clean metadata
+                            clean_metadata = {
+                                k: v
+                                for k, v in item.items()
+                                if k not in [self.image_column, self.url_column]
+                                and not k.startswith("_")
+                            }
+
+                            clean_metadata.update(
+                                {
+                                    "_item_index": global_idx,
+                                    "_chunk_relative_index": global_idx - start_index,
+                                    "_job_id": job_id,
+                                    "_shard_id": shard_id,
+                                    "_local_index": local_idx,
+                                    "_url": image_url,
+                                    "_mock": self.mock_results,  # Add flag to indicate mock data
+                                }
+                            )
+
+                            yield {
+                                "image": image,
+                                "item_key": str(global_idx),
+                                "item_index": global_idx,
+                                "metadata": clean_metadata,
+                                "job_id": job_id,
+                                "_processed_indices": processed_indices,
+                            }
+
+                            processed_indices.append(global_idx)
+
+                        except Exception as e:
+                            logger.error(f"Error processing item at index {global_idx}: {e}")
+
+                    # Explicitly delete table to free memory
+                    del table
+                    gc.collect()
+
+                except Exception as e:
+                    logger.error(f"Error reading batch from shard {shard_id}: {e}")
 
         # Store processed indices in context
         context["_processed_indices"] = processed_indices
-        logger.debug(
+        logger.debug(
+            f"Processed {len(processed_indices)} indices for unit {unit.unit_id}: {processed_indices}, {context}"
+        )
+        log_memory(f"end processing unit {unit.unit_id}")
 
     def prepare_result(
         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float