caption-flow 0.2.4__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/huggingface.py +636 -464
- caption_flow/processors/webdataset.py +379 -534
- caption_flow/storage/manager.py +328 -305
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +196 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/workers/caption.py +164 -129
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/METADATA +2 -1
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/RECORD +15 -20
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt +0 -0
caption_flow/processors/webdataset.py
@@ -1,104 +1,111 @@
-"""WebDataset processor implementation."""
+"""WebDataset processor implementation using webshart TarDataLoader."""
 
 import logging
 import threading
+import gc
+import os
 from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
 from collections import deque, defaultdict
 from pathlib import Path
 import json
-import io
 from datetime import datetime
 from PIL import Image
-
+import io
 
+from caption_flow.storage import StorageManager
 from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import …
+from ..utils import ChunkTracker
+
+import webshart
+import cv2
+import numpy as np
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
 
 
 class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
-    """Orchestrator processor for WebDataset shards."""
+    """Orchestrator processor for WebDataset shards using webshart with ChunkTracker."""
 
     def __init__(self):
-        logger.…
-        self.…
+        logger.info("Initializing WebDatasetOrchestratorProcessor with webshart + ChunkTracker")
+        self.dataset: Optional[webshart.DiscoveredDataset] = None
         self.chunk_tracker: Optional[ChunkTracker] = None
         self.chunk_size: int = 1000
 
         # Work unit management
         self.work_units: Dict[str, WorkUnit] = {}
         self.pending_units: Deque[str] = deque()
-        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
+        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
         self.lock = threading.Lock()
 
-        # Shard …
-        self.…
-        self.current_shard_index = 0
-        self.current_shard_items = 0
+        # Shard info cache
+        self.shard_info_cache: Dict[int, Dict] = {}
 
         # Background thread for creating work units
         self.unit_creation_thread: Optional[threading.Thread] = None
         self.stop_creation = threading.Event()
+        self.min_buffer = 10
+        self.buffer_multiplier = 3
 
     def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
-        """Initialize …
-        logger.…
-        cfg = config.config
+        """Initialize with webshart dataset discovery and ChunkTracker."""
+        logger.info("Initializing orchestrator with config")
 
-
+        cfg = config.config
         dataset_cfg = cfg.get("dataset", {})
-        dataset_path = dataset_cfg.get("dataset_path")
-
-        dataset_split = dataset_cfg.get("dataset_split", "train")
-        image_column = dataset_cfg.get("dataset_image_column", "image")
+        self.dataset_path = dataset_cfg.get("dataset_path")
+        metadata_path = dataset_cfg.get("metadata_path", None)
 
         # Chunk settings
         self.chunk_size = cfg.get("chunk_size", 1000)
         self.min_buffer = cfg.get("min_chunk_buffer", 10)
         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
 
-
-
-
-            self.min_buffer,
-            self.buffer_multiplier,
-        )
+        # Cache configuration
+        cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
+        cache_dir.mkdir(parents=True, exist_ok=True)
 
  [… removed lines not captured in the source view …]
-            self.dataset_loader = DatasetLoader(
-                dataset_path=dataset_path,
-                dataset_type=dataset_type,
-                split=dataset_split,
-                image_column=image_column,
-                cache_dir=checkpoint_dir,
+        if self.dataset_path:
+            # Initialize dataset with webshart
+            self.dataset = webshart.discover_dataset(
+                source=self.dataset_path,
+                metadata=metadata_path,
             )
-            logger.debug("DatasetLoader initialized")
 
-
-
+            # Enable caching for efficient access
+            self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
+            self.dataset.enable_shard_cache(
+                location=str(cache_dir / "shard_cache"),
+                cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
+            )
 
  [… removed lines not captured in the source view …]
+            logger.info(f"Dataset discovered: {self.dataset.num_shards} shards")
+
+            # Initialize chunk tracker
+            checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
 
             # Restore existing state from chunk tracker
-            self._restore_state(storage…
+            self._restore_state(storage)
 
             # Start background unit creation
             self.unit_creation_thread = threading.Thread(
                 target=self._create_units_background, daemon=True
             )
             self.unit_creation_thread.start()
-            logger.debug("Unit creation thread started")
         else:
-            logger.error("No dataset_path provided…
+            logger.error("No dataset_path provided")
+
+    def _get_shard_info_cached(self, shard_idx: int) -> Optional[Dict]:
+        """Get shard info with caching."""
+        if shard_idx not in self.shard_info_cache:
+            try:
+                self.shard_info_cache[shard_idx] = self.dataset.get_shard_info(shard_idx)
+            except Exception as e:
+                logger.error(f"Error getting shard info for idx {shard_idx}: {e}")
+                return None
+        return self.shard_info_cache[shard_idx]
 
     def _restore_state(self, storage: StorageManager) -> None:
         """Restore state from chunk tracker."""
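Note on the change above: 0.3.1 drops the old DatasetLoader wiring and instead discovers the dataset through webshart, enabling a metadata cache and a size-capped shard cache alongside the checkpoint directory. A minimal sketch of just those calls as they appear in this diff, runnable outside the orchestrator (the dataset source and cache location below are illustrative placeholders, not values shipped with the package):

from pathlib import Path

import webshart

cache_dir = Path("./webshart_cache")  # illustrative location
cache_dir.mkdir(parents=True, exist_ok=True)

# discover_dataset() takes the same source/metadata pair the orchestrator reads from its config.
dataset = webshart.discover_dataset(
    source="/data/my-webdataset",  # placeholder dataset_path
    metadata=None,
)
dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
dataset.enable_shard_cache(
    location=str(cache_dir / "shard_cache"),
    cache_limit_gb=10.0,
)

print(f"Discovered {dataset.num_shards} shards")
print(dataset.get_shard_info(0))  # per-shard dict with "name", "num_files", "path"

The orchestrator keeps this DiscoveredDataset around and only queries per-shard info lazily via _get_shard_info_cached().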
@@ -108,38 +115,36 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
         shards_summary = self.chunk_tracker.get_shards_summary()
 
-        # Get all processed job_ids from storage
-        all_processed_jobs = storage.get_all_processed_job_ids()
-
         with self.lock:
             for shard_name, shard_info in shards_summary.items():
  [… removed lines not captured in the source view …]
-                )
+                chunks = shard_info.get("chunks", [])
+                for chunk_state in chunks:
+                    # Only add incomplete chunks
+                    if chunk_state.status != "completed":
+                        logger.debug(f"Restoring incomplete chunk {chunk_state.chunk_id}")
 
  [… removed lines not captured in the source view …]
+                        # Get unprocessed ranges
+                        unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                        if not unprocessed_ranges:
+                            continue
 
-
-
+                        # Convert relative ranges to absolute file indices
+                        absolute_ranges = []
+                        for start, end in unprocessed_ranges:
+                            abs_start = chunk_state.start_index + start
+                            abs_end = chunk_state.start_index + end
+                            absolute_ranges.append((abs_start, abs_end))
 
-                if unprocessed_ranges:
-                    # Create work unit for unprocessed items
-                    logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
                         unit = WorkUnit(
                             unit_id=chunk_state.chunk_id,
                             chunk_id=chunk_state.chunk_id,
                             source_id=shard_name,
                             data={
                                 "shard_url": chunk_state.shard_url,
+                                "shard_name": shard_name,
                                 "start_index": chunk_state.start_index,
                                 "chunk_size": chunk_state.chunk_size,
-                                "unprocessed_ranges": …
+                                "unprocessed_ranges": absolute_ranges,
                             },
                             metadata={
                                 "shard_name": shard_name,
@@ -154,11 +159,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Background thread to create work units on demand."""
         logger.info("Starting work unit creation thread")
 
-
-
-        current_shard_name = None
-        current_shard_items = 0
-        current_index = 0
+        current_shard_idx = 0
+        current_file_idx = 0
 
         while not self.stop_creation.is_set():
             # Check if we need more units
@@ -169,14 +171,6 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
             target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
             units_needed = max(0, target_buffer - (pending_count + assigned_count))
-            logger.debug(
-                "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
-                pending_count,
-                assigned_count,
-                worker_count,
-                target_buffer,
-                units_needed,
-            )
 
             if units_needed == 0:
                 threading.Event().wait(5)
@@ -184,169 +178,91 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
             # Create units as needed
             units_created = 0
-
             while units_created < units_needed and not self.stop_creation.is_set():
-                #…
-                if…
-
-
-                current_shard_name = Path(current_shard_url).stem
+                # Get current shard info
+                if current_shard_idx >= self.dataset.num_shards:
+                    logger.info("All shards processed")
+                    break
 
  [… removed lines not captured in the source view …]
+                shard_info = self._get_shard_info_cached(current_shard_idx)
+                if not shard_info:
+                    current_shard_idx += 1
+                    current_file_idx = 0
+                    continue
+
+                shard_name = shard_info["name"]
+                shard_files = shard_info["num_files"]
+
+                # Check if we need to move to next shard
+                if current_file_idx >= shard_files:
+                    current_shard_idx += 1
+                    current_file_idx = 0
+                    continue
+
+                # Create chunk for current position
+                chunk_size = min(self.chunk_size, shard_files - current_file_idx)
+                chunk_id = f"{shard_name}:chunk:{current_file_idx // self.chunk_size}"
+
+                with self.lock:
+                    # Skip if already exists
+                    if chunk_id in self.work_units:
+                        current_file_idx += self.chunk_size
+                        continue
+
+                    # Check if chunk is already completed
+                    if self.chunk_tracker:
+                        chunk_state = self.chunk_tracker.chunks.get(chunk_id)
+                        if chunk_state and chunk_state.status == "completed":
+                            current_file_idx += self.chunk_size
                             continue
 
  [… removed lines not captured in the source view …]
-                        )
-
-                        unit = WorkUnit(
-                            unit_id=unit_id,
-                            chunk_id=unit_id,
-                            source_id=current_shard_name,
-                            data={
-                                "shard_url": current_shard_url,
-                                "start_index": current_index,
-                                "chunk_size": chunk_size,
-                                "unprocessed_ranges": [
-                                    (
-                                        r[0] + chunk_state.start_index,
-                                        r[1] + chunk_state.start_index,
-                                    )
-                                    for r in unprocessed_ranges
-                                ],  # Convert relative to absolute
-                            },
-                            metadata={
-                                "shard_name": current_shard_name,
-                                "chunk_index": current_index // self.chunk_size,
-                            },
-                        )
-                    else:
-                        # New chunk
-                        logger.debug(
-                            "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
-                            unit_id,
-                            current_shard_name,
-                            current_index,
-                            chunk_size,
-                        )
-
-                        unit = WorkUnit(
-                            unit_id=unit_id,
-                            chunk_id=unit_id,
-                            source_id=current_shard_name,
-                            data={
-                                "shard_url": current_shard_url,
-                                "start_index": current_index,
-                                "chunk_size": chunk_size,
-                                "unprocessed_ranges": [
-                                    (current_index, current_index + chunk_size - 1)
-                                ],
-                            },
-                            metadata={
-                                "shard_name": current_shard_name,
-                                "chunk_index": current_index // self.chunk_size,
-                            },
-                        )
-                else:
-                    # No chunk tracker, create normally
-                    unit = WorkUnit(
-                        unit_id=unit_id,
-                        chunk_id=unit_id,
-                        source_id=current_shard_name,
-                        data={
-                            "shard_url": current_shard_url,
-                            "start_index": current_index,
-                            "chunk_size": chunk_size,
-                            "unprocessed_ranges": [
-                                (current_index, current_index + chunk_size - 1)
-                            ],
-                        },
-                        metadata={
-                            "shard_name": current_shard_name,
-                            "chunk_index": current_index // self.chunk_size,
-                        },
-                    )
+                    # Get shard URL (path for webshart)
+                    shard_url = shard_info.get("path", f"{shard_name}.tar")
+
+                    # Create work unit
+                    unit = WorkUnit(
+                        unit_id=chunk_id,
+                        chunk_id=chunk_id,
+                        source_id=shard_name,
+                        data={
+                            "shard_url": shard_url,
+                            "shard_name": shard_name,
+                            "shard_idx": current_shard_idx,
+                            "start_index": current_file_idx,
+                            "chunk_size": chunk_size,
+                            "unprocessed_ranges": [
+                                (current_file_idx, current_file_idx + chunk_size - 1)
+                            ],
+                        },
+                        metadata={
+                            "shard_name": shard_name,
+                            "chunk_index": current_file_idx // self.chunk_size,
+                        },
+                    )
 
-
-
-                    logger.debug("Added work unit %s to pending_units", unit_id)
+                    self.work_units[chunk_id] = unit
+                    self.pending_units.append(chunk_id)
 
  [… removed lines not captured in the source view …]
-                        current_index,
-                        chunk_size,
-                    )
-                    if added_chunk:
-                        logger.debug("Added chunk to chunk_tracker: %s", unit_id)
-                    else:
-                        logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)
+                    # Add to chunk tracker
+                    if self.chunk_tracker:
+                        self.chunk_tracker.add_chunk(
+                            chunk_id, shard_name, shard_url, current_file_idx, chunk_size
+                        )
 
-
+                    units_created += 1
 
-
+                current_file_idx += self.chunk_size
 
             if units_created > 0:
                 logger.debug(f"Created {units_created} work units")
 
+        logger.info("Work unit creation thread exiting")
+
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
-        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
         assigned = []
 
         with self.lock:
-            # Get new units if needed
             while len(assigned) < count and self.pending_units:
                 unit_id = self.pending_units.popleft()
                 unit = self.work_units.get(unit_id)
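The unit-creation loop above also fixes the identifier scheme the rest of this diff depends on: a chunk is keyed as "<shard>:chunk:<n>" and an individual item as "<shard>:chunk:<n>:idx:<i>", which update_from_storage() later parses back apart. A small worked example, using the shard name from the old code's comment ("data-0000") and an arbitrary item index:

chunk_size = 1000
shard_name = "data-0000"   # example shard name
item_index = 2105          # arbitrary item; chunks start at multiples of chunk_size

chunk_id = f"{shard_name}:chunk:{item_index // chunk_size}"  # -> "data-0000:chunk:2"
job_id = f"{chunk_id}:idx:{item_index}"                      # -> "data-0000:chunk:2:idx:2105"

# update_from_storage() reverses this mapping:
parts = job_id.split(":")
assert parts[3] == "idx"
assert ":".join(parts[:3]) == chunk_id
assert int(parts[4]) == item_index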
@@ -354,148 +270,74 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                 if unit:
                     self.assigned_units[worker_id].add(unit_id)
                     assigned.append(unit)
-                    logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)
 
                     if self.chunk_tracker:
                         self.chunk_tracker.mark_assigned(unit_id, worker_id)
 
-        logger.debug("…
+        logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
 
-    def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
-        """Check if a work unit has unprocessed items."""
-        if not self.chunk_tracker:
-            logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
-            return True
-
-        chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
-        has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
-        logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
-        return has_unprocessed
-
     def mark_completed(self, unit_id: str, worker_id: str) -> None:
         """Mark a work unit as completed."""
-        logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
         with self.lock:
             if unit_id in self.work_units:
                 self.assigned_units[worker_id].discard(unit_id)
-                logger.debug(
-                    "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
-                )
 
                 if self.chunk_tracker:
                     self.chunk_tracker.mark_completed(unit_id)
-
+
+                # Remove from memory
+                del self.work_units[unit_id]
 
     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         """Mark a work unit as failed."""
-        logger.…
+        logger.error(f"Unit {unit_id} failed on {worker_id}: {error}")
         with self.lock:
             if unit_id in self.work_units:
                 self.assigned_units[worker_id].discard(unit_id)
                 self.pending_units.append(unit_id)
-                logger.debug("Returned unit %s to pending_units", unit_id)
 
             if self.chunk_tracker:
                 self.chunk_tracker.mark_failed(unit_id)
-                logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)
 
     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
-        logger.debug("Releasing assignments for worker %s", worker_id)
         with self.lock:
             unit_ids = list(self.assigned_units.get(worker_id, []))
 
             for unit_id in unit_ids:
-                if unit_id in self.work_units:
-
-
-
-                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
-                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                if unit_id in self.work_units and self.chunk_tracker:
+                    # Get updated unprocessed ranges from chunk tracker
+                    chunk_state = self.chunk_tracker.chunks.get(unit_id)
+                    if chunk_state:
                         unprocessed_ranges = chunk_state.get_unprocessed_ranges()
-
-                        # Convert relative ranges back to absolute
+                        # Convert relative to absolute
                         absolute_ranges = []
                         for start, end in unprocessed_ranges:
                             abs_start = chunk_state.start_index + start
                             abs_end = chunk_state.start_index + end
                             absolute_ranges.append((abs_start, abs_end))
 
-                        # Update…
-
-
-                        logger.debug(
-                            f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
-                        )
+                        # Update work unit
+                        self.work_units[unit_id].data["unprocessed_ranges"] = absolute_ranges
 
                 self.pending_units.append(unit_id)
-                logger.debug("Returned unit %s to pending_units", unit_id)
 
             if worker_id in self.assigned_units:
                 del self.assigned_units[worker_id]
-                logger.debug("Deleted worker %s from assigned_units", worker_id)
 
             if self.chunk_tracker:
                 self.chunk_tracker.release_worker_chunks(worker_id)
-                logger.debug("Released worker %s chunks in chunk_tracker", worker_id)
-
-    def _subtract_ranges(
-        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
-    ) -> List[Tuple[int, int]]:
-        """Subtract processed ranges from total ranges."""
-        if not processed_ranges:
-            return total_ranges
-
-        # Create a set of all processed indices
-        processed_indices = set()
-        for start, end in processed_ranges:
-            processed_indices.update(range(start, end + 1))
-
-        # Find unprocessed ranges
-        unprocessed_ranges = []
-        for start, end in total_ranges:
-            current_start = None
-            for i in range(start, end + 1):
-                if i not in processed_indices:
-                    if current_start is None:
-                        current_start = i
-                else:
-                    if current_start is not None:
-                        unprocessed_ranges.append((current_start, i - 1))
-                        current_start = None
-
-            if current_start is not None:
-                unprocessed_ranges.append((current_start, end))
-
-        return unprocessed_ranges
 
-
-        """Get processor statistics."""
-        with self.lock:
-            stats = {
-                "total_units": len(self.work_units),
-                "pending_units": len(self.pending_units),
-                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
-                "total_shards": len(self.all_shards),
-                "workers": len(self.assigned_units),
-            }
-            logger.debug("Stats: %s", stats)
-            return stats
+        logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
 
     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle…
-        # logger.debug("Handling result for unit %s", result.unit_id)
-        base_result = super().handle_result(result)
-
+        """Handle result from worker."""
         # Track processed items if we have chunk tracker
-        if self.chunk_tracker:
-            if "item_indices" not in result.metadata:
-                result.metadata["item_indices"] = [result.metadata.get("_item_index")]
+        if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]
-            logger.debug("Result metadata item_indices: %s", indices)
 
-            #…
+            # Convert to ranges and mark as processed
             if indices:
                 indices.sort()
                 ranges = []
@@ -514,269 +356,272 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
                 # Mark ranges as processed
                 for start_idx, end_idx in ranges:
-                    logger.debug(f"Marking chunk as processed: {result.to_repr()}")
                     self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
-                    logger.debug(
-                        "Marked items processed for unit %s: %d-%d",
-                        result.unit_id,
-                        start_idx,
-                        end_idx,
-                    )
-        else:
-            logger.error(
-                f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
-            )
 
-        return…
+        return {
+            "source_id": result.source_id,
+            "chunk_id": result.chunk_id,
+            "outputs": result.outputs,
+            "metadata": result.metadata,
+        }
 
     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
         """Update work units based on what's been processed."""
-        logger.info(f"Updating…
+        logger.info(f"Updating from {len(processed_job_ids)} processed jobs")
 
         with self.lock:
  [… removed lines not captured in the source view …]
-            # Parse job_id format: "data-0000:chunk:0:idx:42"
-            parts = job_id.split(":")
-            if (
-                len(parts) == 5
-                and parts[0] == shard_name
-                and parts[1] == "chunk"
-                and int(parts[2]) == chunk_index
-                and parts[3] == "idx"
-            ):
-
+            # Group by chunk
+            processed_by_chunk = defaultdict(set)
+
+            for job_id in processed_job_ids:
+                # Parse job_id to extract chunk and index
+                # Expected format: "shard:chunk:X:idx:Y"
+                parts = job_id.split(":")
+                if len(parts) >= 5 and parts[3] == "idx":
+                    chunk_id = ":".join(parts[:3])  # "shard:chunk:X"
+                    try:
                         idx = int(parts[4])
-
-
+                        processed_by_chunk[chunk_id].add(idx)
+                    except ValueError:
+                        continue
 
  [… removed lines not captured in the source view …]
-                    if idx == end + 1:
-                        end = idx
-                    else:
-                        processed_ranges.append((start, end))
-                        start = idx
-                        end = idx
-
-                processed_ranges.append((start, end))
-
-                # Calculate unprocessed ranges
-                total_range = [(start_index, start_index + chunk_size - 1)]
-                unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
-
-                # Update unit
-                unit.data["unprocessed_ranges"] = unprocessed_ranges
-
-                logger.debug(
-                    f"Updated unit {unit_id}: {len(processed_indices)} processed, "
-                    f"unprocessed ranges: {unprocessed_ranges}"
-                )
+            # Update chunk tracker with processed items
+            if self.chunk_tracker:
+                for chunk_id, indices in processed_by_chunk.items():
+                    if indices:
+                        # Sort indices and convert to ranges
+                        sorted_indices = sorted(indices)
+                        for idx in sorted_indices:
+                            self.chunk_tracker.mark_items_processed(chunk_id, idx, idx)
 
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor statistics."""
+        with self.lock:
+            # Get chunk tracker stats if available
+            if self.chunk_tracker:
+                shards_summary = self.chunk_tracker.get_shards_summary()
+                total_chunks = sum(len(s.get("chunks", [])) for s in shards_summary.values())
+                completed_chunks = sum(
+                    1
+                    for s in shards_summary.values()
+                    for c in s.get("chunks", [])
+                    if c.status == "completed"
+                )
+            else:
+                total_chunks = len(self.work_units)
+                completed_chunks = 0
 
-
-
+            return {
+                "total_shards": self.dataset.num_shards if self.dataset else 0,
+                "total_chunks": total_chunks,
+                "pending_units": len(self.pending_units),
+                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
+                "completed_chunks": completed_chunks,
+                "workers": len(self.assigned_units),
+            }
 
-    def…
-
-
-        self.dataset_config: Dict[str, Any] = {}
-        self.dataset_name: Optional[str] = None
+    def cleanup(self):
+        """Clean up resources."""
+        logger.info("Cleaning up orchestrator")
 
  [… removed lines not captured in the source view …]
-        # Store config
-        self.dataset_config = cfg
-
-        # Initialize dataset loader
-        dataset_path = cfg.get("dataset_path")
-        self.dataset_path = dataset_path
-        dataset_type = cfg.get("dataset_type", "huggingface")
-        dataset_split = cfg.get("dataset_split", "train")
-        image_column = cfg.get("dataset_image_column", "image")
-
-        if dataset_path:
-            self.dataset_loader = DatasetLoader(
-                dataset_path=dataset_path,
-                dataset_type=dataset_type,
-                split=dataset_split,
-                image_column=image_column,
-            )
-            logger.debug("DatasetLoader initialized for worker")
-        else:
-            logger.error("No dataset_path provided in worker config")
+        # Stop background threads
+        self.stop_creation.set()
+        if self.unit_creation_thread:
+            self.unit_creation_thread.join(timeout=5)
 
-
-
-
-        if not self.dataset_loader:
-            logger.error("Dataset loader not initialized")
-            return
+        # Save checkpoint
+        if self.chunk_tracker:
+            self.chunk_tracker.save_checkpoint()
 
-        shard_name = unit.metadata["shard_name"]
-        chunk_index = unit.metadata["chunk_index"]
-        shard_url = unit.data["shard_url"]
-        start_index = unit.data["start_index"]
-        chunk_size = unit.data["chunk_size"]
-        unprocessed_ranges = unit.data.get(
-            "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
-        )
 
-
+class WebDatasetWorkerProcessor(WorkerProcessor):
+    """Worker processor for WebDataset shards using webshart."""
 
  [… removed lines not captured in the source view …]
+    def __init__(self):
+        logger.info("Initializing WebDatasetWorkerProcessor with webshart")
+        self.loader: Optional[webshart.TarDataLoader] = None
+        self.dataset: Optional[webshart.DiscoveredDataset] = None
+        self.mock_results = False
 
-
+    def initialize(self, config: ProcessorConfig) -> None:
+        """Initialize worker with webshart loader."""
+        cfg = config.config
+        dataset_cfg = cfg.get("dataset", {})
 
-
-
-
-        ):
-            # Skip if not in our chunk range
-            if idx < start_index or idx >= start_index + chunk_size:
-                # logger.debug(f"Skipping idx={idx} not in chunk range")
-                continue
+        self.dataset_path = dataset_cfg.get("dataset_path")
+        metadata_path = dataset_cfg.get("metadata_path", None)
+        self.mock_results = dataset_cfg.get("mock_results", False)
 
-
-
-
-                continue
+        # Cache configuration
+        cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
+        cache_dir.mkdir(parents=True, exist_ok=True)
 
  [… removed lines not captured in the source view …]
-            clean_metadata.update(
-                {
-                    "_item_index": idx,
-                    "_chunk_relative_index": idx - start_index,
-                    "_job_id": job_id,
-                }
-            )
+        if self.dataset_path and not self.mock_results:
+            # Discover dataset
+            self.dataset = webshart.discover_dataset(
+                source=self.dataset_path,
+                metadata=metadata_path,
+            )
+
+            # Enable caching
+            self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
+            self.dataset.enable_shard_cache(
+                location=str(cache_dir / "shard_cache"),
+                cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
+            )
 
  [… removed lines not captured in the source view …]
-                "job_id": job_id,
-            }
+            # Create loader
+            self.loader = webshart.TarDataLoader(
+                self.dataset,
+                buffer_size=cfg.get("buffer_size", 10),
+                max_file_size=cfg.get("max_file_size", 100 * 1024 * 1024),
+                load_file_data=True,
+            )
 
-
+            logger.info("webshart TarDataLoader initialized")
 
-
-
+    def _create_mock_image(self, idx: int) -> Image.Image:
+        """Create a dummy test image."""
+        color = ((idx * 37) % 256, (idx * 53) % 256, (idx * 71) % 256)
+        image = Image.new("RGB", (256, 256), color=color)
+        return image
 
-
-
-        logger.debug("…
+    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+        """Process a work unit by iterating specified ranges."""
+        logger.debug(f"Processing unit: {unit.unit_id}")
 
-
-
-
-        """Iterate through a shard with metadata."""
-        logger.debug("Iterating shard with metadata: %s", shard_url)
+        shard_name = unit.data["shard_name"]
+        shard_idx = unit.data.get("shard_idx")
+        unprocessed_ranges = unit.data.get("unprocessed_ranges", [])
 
-
-
-
+        # For chunk tracking
+        chunk_index = unit.metadata.get("chunk_index", 0)
+        processed_indices = []
 
  [… removed lines not captured in the source view …]
+        if self.mock_results:
+            # Generate mock results for unprocessed ranges
+            for start_idx, end_idx in unprocessed_ranges:
+                for idx in range(start_idx, end_idx + 1):
+                    job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+
+                    yield {
+                        "image": self._create_mock_image(idx),
+                        "image_data": None,
+                        "item_key": f"mock_{idx}",
+                        "item_index": idx,
+                        "metadata": {
+                            "_item_index": idx,
+                            "_chunk_relative_index": idx - unit.data["start_index"],
+                            "_job_id": job_id,
+                            "_mock": True,
+                        },
+                        "job_id": job_id,
+                    }
 
-
-
+                    processed_indices.append(idx)
+        else:
+            # Use webshart to process unprocessed ranges
+            for start_idx, end_idx in unprocessed_ranges:
+                try:
+                    # Jump to shard and starting position
+                    if shard_idx is not None:
+                        self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
+                    else:
+                        # Try to find shard by name
+                        self.loader.shard(filename=shard_name, cursor_idx=start_idx)
+
+                    # Iterate through the range
+                    for idx in range(start_idx, end_idx + 1):
+                        try:
+                            entry = next(self.loader)
+
+                            # Decode image
+                            image = None
+                            if entry.data:
+                                try:
+                                    # Use cv2 to decode from memory
+                                    nparr = np.frombuffer(entry.data, np.uint8)
+                                    img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+                                    if img_np is not None:
+                                        # Convert from BGR (OpenCV default) to RGB (PIL default)
+                                        img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+                                        image = Image.fromarray(img_rgb)
+                                    else:
+                                        logger.warning(f"cv2.imdecode failed for {entry.path}")
+
+                                except ImportError:
+                                    logger.warning(
+                                        "cv2 or numpy not installed, falling back to PIL"
+                                    )
+                                    image = Image.open(io.BytesIO(entry.data))
+                                except Exception as img_e:
+                                    logger.error(
+                                        f"Error decoding image {entry.path} with cv2: {img_e}"
+                                    )
 
  [… removed lines not captured in the source view …]
+                            # Generate job ID compatible with chunk tracker
+                            job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+
+                            yield {
+                                "image": image,
+                                "image_data": entry.data,
+                                "item_key": Path(entry.path).stem,
+                                "item_index": idx,
+                                "metadata": {
+                                    "_item_index": idx,
+                                    "_chunk_relative_index": idx - unit.data["start_index"],
+                                    "_job_id": job_id,
+                                    "_filename": entry.path,
+                                    "_file_size": entry.size,
+                                },
+                                "job_id": job_id,
+                            }
 
-
-            logger.debug(
-                "No image data found for item key=%s, available keys: %s",
-                key,
-                list(sample.keys()),
-            )
-            continue
+                            processed_indices.append(idx)
 
  [… removed lines not captured in the source view …]
+                            if len(processed_indices) % 10 == 0:
+                                gc.collect()
+
+                        except StopIteration:
+                            logger.warning(f"Unexpected end of shard at index {idx}")
+                            break
+                        except Exception as e:
+                            logger.error(f"Error processing index {idx}: {e}")
+                            continue
 
-
-
-
+                except Exception as e:
+                    logger.error(f"Error processing range {start_idx}-{end_idx}: {e}")
+                    continue
 
-
+        # Store processed indices for result
+        context["_processed_indices"] = processed_indices
+        logger.info(f"Processed {len(processed_indices)} items from unit {unit.unit_id}")
 
     def prepare_result(
         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
     ) -> WorkResult:
-        """Prepare…
-        logger.debug("Preparing result for unit %s", unit.unit_id)
+        """Prepare result with processing details."""
         result = super().prepare_result(unit, outputs, processing_time_ms)
 
-        # Add processed indices
+        # Add processed indices for chunk tracker
        if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
             result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
-            logger.debug(
-                "Added item_indices to result metadata: %s", result.metadata["item_indices"]
-            )
 
         return result
 
     def get_dataset_info(self) -> Dict[str, Any]:
         """Get dataset information."""
-        if self.…
  [… removed lines not captured in the source view …]
+        if self.dataset:
+            stats = self.dataset.get_stats()
+            return {
+                "dataset_name": self.dataset.name,
+                "format": self.dataset.dataset_format,
+                "total_shards": stats["total_shards"],
+                "total_files": stats.get("total_files", "Unknown"),
+                "mock_results": self.mock_results,
+            }
+        return {
+            "dataset_name": "Mock Dataset" if self.mock_results else "Unknown",
+            "mock_results": self.mock_results,
         }
-        logger.debug("Dataset info (no loader): %s", info)
-        return info
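The worker-side process_unit() above drives webshart's TarDataLoader directly: it seeks to a shard and cursor position, pulls entries with next(), and decodes the bytes into a PIL image. A condensed sketch of that read path, using only the loader calls visible in this diff (the discovery step is the same as in the orchestrator note above; shard index and range are illustrative):

import io

import webshart
from PIL import Image

# `dataset` is a webshart.DiscoveredDataset from discover_dataset(), with caches enabled.
loader = webshart.TarDataLoader(
    dataset,
    buffer_size=10,
    max_file_size=100 * 1024 * 1024,
    load_file_data=True,
)

loader.shard(shard_idx=0, cursor_idx=42)   # jump to shard 0, item 42
for idx in range(42, 45):                  # read a small absolute-index range
    entry = next(loader)                   # entry exposes .path, .size and raw .data
    image = Image.open(io.BytesIO(entry.data)) if entry.data else None
    print(idx, entry.path, entry.size, None if image is None else image.size)

The production code prefers cv2.imdecode for speed and falls back to PIL, and reports the processed absolute indices back to the orchestrator so ChunkTracker can mark the per-chunk ranges complete.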