caption-flow 0.2.4__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- caption_flow/__init__.py +1 -1
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/base.py +3 -0
- caption_flow/processors/huggingface.py +637 -464
- caption_flow/processors/local_filesystem.py +2 -0
- caption_flow/processors/webdataset.py +438 -538
- caption_flow/storage/manager.py +328 -305
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +197 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/workers/caption.py +191 -138
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/METADATA +2 -1
- caption_flow-0.3.2.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.4.dist-info/RECORD +0 -38
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.4.dist-info → caption_flow-0.3.2.dist-info}/top_level.txt +0 -0
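
The diff below covers `caption_flow/processors/webdataset.py`, which was rewritten in this release: the removed `DatasetLoader`/shard utilities (`dataset_loader.py`, `shard_processor.py`, `shard_tracker.py`, and related modules in the list above) are replaced by `webshart` dataset discovery plus its `TarDataLoader`, with resume bookkeeping consolidated onto `ChunkTracker`. Lines that the diff view truncated are marked with `…`; runs of removed lines whose content was not preserved are collapsed into placeholders.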
@@ -1,104 +1,113 @@
-"""WebDataset processor implementation."""
+"""WebDataset processor implementation using webshart TarDataLoader."""
 
 import logging
 import threading
+import gc
+import os
 from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
 from collections import deque, defaultdict
 from pathlib import Path
 import json
-import io
 from datetime import datetime
 from PIL import Image
-
+import io
 
+from caption_flow.models import JobId
+from caption_flow.storage import StorageManager
 from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import …
+from ..utils import ChunkTracker
+
+import webshart
+import cv2
+import numpy as np
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
 class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
-    """Orchestrator processor for WebDataset shards."""
+    """Orchestrator processor for WebDataset shards using webshart with ChunkTracker."""
 
     def __init__(self):
-        logger.…
-        self.…
+        logger.info("Initializing WebDatasetOrchestratorProcessor with webshart + ChunkTracker")
+        self.dataset: Optional[webshart.DiscoveredDataset] = None
         self.chunk_tracker: Optional[ChunkTracker] = None
         self.chunk_size: int = 1000
 
         # Work unit management
         self.work_units: Dict[str, WorkUnit] = {}
         self.pending_units: Deque[str] = deque()
-        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
+        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
         self.lock = threading.Lock()
 
-        # Shard …
-        self.…
-        self.current_shard_index = 0
-        self.current_shard_items = 0
+        # Shard info cache
+        self.shard_info_cache: Dict[int, Dict] = {}
 
         # Background thread for creating work units
         self.unit_creation_thread: Optional[threading.Thread] = None
         self.stop_creation = threading.Event()
+        self.min_buffer = 10
+        self.buffer_multiplier = 3
 
     def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
-        """Initialize …
-        logger.…
-        cfg = config.config
+        """Initialize with webshart dataset discovery and ChunkTracker."""
+        logger.info("Initializing orchestrator with config")
 
-
+        cfg = config.config
         dataset_cfg = cfg.get("dataset", {})
-        dataset_path = dataset_cfg.get("dataset_path")
-
-        dataset_split = dataset_cfg.get("dataset_split", "train")
-        image_column = dataset_cfg.get("dataset_image_column", "image")
+        self.dataset_path = dataset_cfg.get("dataset_path")
+        metadata_path = dataset_cfg.get("metadata_path", None)
 
         # Chunk settings
         self.chunk_size = cfg.get("chunk_size", 1000)
         self.min_buffer = cfg.get("min_chunk_buffer", 10)
         self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
 
-… (3 removed lines not preserved in this view)
-            self.min_buffer,
-            self.buffer_multiplier,
-        )
+        # Cache configuration
+        cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
+        cache_dir.mkdir(parents=True, exist_ok=True)
 
-… (6 removed lines not preserved in this view)
-        self.dataset_loader = DatasetLoader(
-            dataset_path=dataset_path,
-            dataset_type=dataset_type,
-            split=dataset_split,
-            image_column=image_column,
-            cache_dir=checkpoint_dir,
+        if self.dataset_path:
+            # Initialize dataset with webshart
+            self.dataset = webshart.discover_dataset(
+                source=self.dataset_path,
+                metadata=metadata_path,
             )
-        logger.debug("DatasetLoader initialized")
 
-
-
+            # Enable caching for efficient access
+            self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
+            self.dataset.enable_shard_cache(
+                location=str(cache_dir / "shard_cache"),
+                cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
+            )
+
+            logger.info(f"Dataset discovered: {self.dataset.num_shards} shards")
 
-            # …
-
-
+            # Initialize chunk tracker
+            checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
 
             # Restore existing state from chunk tracker
-            self._restore_state(storage…
+            self._restore_state(storage)
 
             # Start background unit creation
             self.unit_creation_thread = threading.Thread(
                 target=self._create_units_background, daemon=True
             )
             self.unit_creation_thread.start()
-            logger.debug("Unit creation thread started")
         else:
-            logger.error("No dataset_path provided…
+            logger.error("No dataset_path provided")
+
+    def _get_shard_info_cached(self, shard_idx: int) -> Optional[Dict]:
+        """Get shard info with caching."""
+        if shard_idx not in self.shard_info_cache:
+            try:
+                self.shard_info_cache[shard_idx] = self.dataset.get_shard_info(shard_idx)
+            except Exception as e:
+                logger.error(f"Error getting shard info for idx {shard_idx}: {e}")
+                return None
+        return self.shard_info_cache[shard_idx]
 
     def _restore_state(self, storage: StorageManager) -> None:
         """Restore state from chunk tracker."""
@@ -107,39 +116,43 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             return
 
         shards_summary = self.chunk_tracker.get_shards_summary()
-
-        # Get all processed job_ids from storage
-        all_processed_jobs = storage.get_all_processed_job_ids()
+        logger.debug(f"Restoring state: {shards_summary}")
 
         with self.lock:
             for shard_name, shard_info in shards_summary.items():
-… (11 removed lines not preserved in this view)
+                chunks = shard_info.get("chunks", [])
+                logger.debug(f"Existing job ids: {storage.get_all_processed_job_ids()}")
+                for chunk_state in chunks:
+                    # Only add incomplete chunks
+                    if chunk_state.status != "completed":
+                        logger.debug(f"Restoring incomplete chunk {chunk_state}")
+
+                        # Get unprocessed ranges
+                        unprocessed_ranges = chunk_state.get_unprocessed_ranges()
+                        logger.debug(
+                            f"Chunk {chunk_state.chunk_id} unprocessed ranges: {unprocessed_ranges}"
+                        )
+                        if not unprocessed_ranges:
+                            continue
 
-… (2 removed lines not preserved in this view)
+                        # Convert relative ranges to absolute file indices
+                        absolute_ranges = []
+                        for start, end in unprocessed_ranges:
+                            abs_start = chunk_state.start_index + start
+                            abs_end = chunk_state.start_index + end
+                            absolute_ranges.append((abs_start, abs_end))
 
-                if unprocessed_ranges:
-                    # Create work unit for unprocessed items
-                    logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
                         unit = WorkUnit(
                             unit_id=chunk_state.chunk_id,
                             chunk_id=chunk_state.chunk_id,
                             source_id=shard_name,
+                            unit_size=chunk_state.chunk_size,
                             data={
                                 "shard_url": chunk_state.shard_url,
+                                "shard_name": shard_name,
                                 "start_index": chunk_state.start_index,
                                 "chunk_size": chunk_state.chunk_size,
-                                "unprocessed_ranges": …
+                                "unprocessed_ranges": absolute_ranges,
                             },
                             metadata={
                                 "shard_name": shard_name,
@@ -154,11 +167,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         """Background thread to create work units on demand."""
         logger.info("Starting work unit creation thread")
 
-… (2 removed lines not preserved in this view)
-        current_shard_name = None
-        current_shard_items = 0
-        current_index = 0
+        current_shard_idx = 0
+        current_file_idx = 0
 
         while not self.stop_creation.is_set():
             # Check if we need more units
@@ -169,14 +179,6 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
             target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
             units_needed = max(0, target_buffer - (pending_count + assigned_count))
-            logger.debug(
-                "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
-                pending_count,
-                assigned_count,
-                worker_count,
-                target_buffer,
-                units_needed,
-            )
 
             if units_needed == 0:
                 threading.Event().wait(5)
@@ -184,318 +186,192 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
             # Create units as needed
             units_created = 0
-
             while units_created < units_needed and not self.stop_creation.is_set():
-                # …
-                if …
-… (2 removed lines not preserved in this view)
-                current_shard_name = Path(current_shard_url).stem
+                # Get current shard info
+                if current_shard_idx >= self.dataset.num_shards:
+                    logger.info("All shards processed")
+                    break
 
-… (29 removed lines not preserved in this view)
+                shard_info = self._get_shard_info_cached(current_shard_idx)
+                if not shard_info:
+                    current_shard_idx += 1
+                    current_file_idx = 0
+                    continue
+
+                shard_name = shard_info["name"]
+                shard_files = shard_info["num_files"]
+
+                # Check if we need to move to next shard
+                if current_file_idx >= shard_files:
+                    current_shard_idx += 1
+                    current_file_idx = 0
+                    continue
+
+                # Create chunk for current position
+                chunk_size = min(self.chunk_size, shard_files - current_file_idx)
+                self.current_chunk_index = current_file_idx // self.chunk_size
+                job_id_obj = JobId(
+                    shard_id=shard_name,
+                    chunk_id=self.current_chunk_index,
+                    sample_id=current_file_idx,
+                )
+                chunk_id = job_id_obj.get_chunk_str()
+
+                with self.lock:
+                    # Skip if already exists
+                    if chunk_id in self.work_units:
+                        current_file_idx += self.chunk_size
+                        continue
+
+                    # Check if chunk is already completed
+                    if self.chunk_tracker:
+                        chunk_state = self.chunk_tracker.chunks.get(chunk_id)
+                        if chunk_state and chunk_state.status == "completed":
+                            current_file_idx += self.chunk_size
                             continue
 
-… (33 removed lines not preserved in this view)
-                            "unprocessed_ranges": [
-                                (
-                                    r[0] + chunk_state.start_index,
-                                    r[1] + chunk_state.start_index,
-                                )
-                                for r in unprocessed_ranges
-                            ],  # Convert relative to absolute
-                        },
-                        metadata={
-                            "shard_name": current_shard_name,
-                            "chunk_index": current_index // self.chunk_size,
-                        },
-                    )
-                else:
-                    # New chunk
-                    logger.debug(
-                        "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
-                        unit_id,
-                        current_shard_name,
-                        current_index,
-                        chunk_size,
-                    )
-
-                    unit = WorkUnit(
-                        unit_id=unit_id,
-                        chunk_id=unit_id,
-                        source_id=current_shard_name,
-                        data={
-                            "shard_url": current_shard_url,
-                            "start_index": current_index,
-                            "chunk_size": chunk_size,
-                            "unprocessed_ranges": [
-                                (current_index, current_index + chunk_size - 1)
-                            ],
-                        },
-                        metadata={
-                            "shard_name": current_shard_name,
-                            "chunk_index": current_index // self.chunk_size,
-                        },
-                    )
-                else:
-                    # No chunk tracker, create normally
-                    unit = WorkUnit(
-                        unit_id=unit_id,
-                        chunk_id=unit_id,
-                        source_id=current_shard_name,
-                        data={
-                            "shard_url": current_shard_url,
-                            "start_index": current_index,
-                            "chunk_size": chunk_size,
-                            "unprocessed_ranges": [
-                                (current_index, current_index + chunk_size - 1)
-                            ],
-                        },
-                        metadata={
-                            "shard_name": current_shard_name,
-                            "chunk_index": current_index // self.chunk_size,
-                        },
-                    )
-
-                    self.work_units[unit_id] = unit
-                    self.pending_units.append(unit_id)
-                    logger.debug("Added work unit %s to pending_units", unit_id)
-
-                    if self.chunk_tracker:
-                        added_chunk = self.chunk_tracker.add_chunk(
-                            unit_id,
-                            current_shard_name,
-                            current_shard_url,
-                            current_index,
-                            chunk_size,
-                        )
-                        if added_chunk:
-                            logger.debug("Added chunk to chunk_tracker: %s", unit_id)
-                        else:
-                            logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)
+                    # Get shard URL (path for webshart)
+                    shard_url = shard_info.get("path", f"{shard_name}.tar")
+
+                    # Create work unit
+                    unit = WorkUnit(
+                        unit_id=chunk_id,
+                        chunk_id=chunk_id,
+                        source_id=shard_name,
+                        unit_size=chunk_size,
+                        data={
+                            "shard_url": shard_url,
+                            "shard_name": shard_name,
+                            "shard_idx": current_shard_idx,
+                            "start_index": current_file_idx,
+                            "chunk_size": chunk_size,
+                            "unprocessed_ranges": [
+                                (current_file_idx, current_file_idx + chunk_size - 1)
+                            ],
+                        },
+                        metadata={
+                            "shard_name": shard_name,
+                            "chunk_index": current_file_idx // self.chunk_size,
+                        },
+                    )
+
+                    self.work_units[chunk_id] = unit
+                    self.pending_units.append(chunk_id)
+
+                    # Add to chunk tracker
+                    if self.chunk_tracker:
+                        self.chunk_tracker.add_chunk(
+                            chunk_id, shard_name, shard_url, current_file_idx, chunk_size
+                        )
 
-…
+                units_created += 1
 
-…
+                current_file_idx += self.chunk_size
 
             if units_created > 0:
                 logger.debug(f"Created {units_created} work units")
 
+        logger.info("Work unit creation thread exiting")
+
     def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
         """Get available work units for a worker."""
-        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
         assigned = []
 
         with self.lock:
-            # Get new units if needed
             while len(assigned) < count and self.pending_units:
                 unit_id = self.pending_units.popleft()
                 unit = self.work_units.get(unit_id)
 
                 if unit:
+                    # Update unprocessed ranges from chunk tracker before assigning
+                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
+                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                        relative_unprocessed = chunk_state.get_unprocessed_ranges()
+
+                        # Convert relative to absolute indices
+                        absolute_ranges = []
+                        for start, end in relative_unprocessed:
+                            abs_start = chunk_state.start_index + start
+                            abs_end = chunk_state.start_index + end
+                            absolute_ranges.append((abs_start, abs_end))
+
+                        # Update the work unit's unprocessed ranges
+                        unit.data["unprocessed_ranges"] = absolute_ranges
+
+                        logger.debug(
+                            f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
+                        )
+
                     self.assigned_units[worker_id].add(unit_id)
                     assigned.append(unit)
-                    logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)
 
                     if self.chunk_tracker:
                         self.chunk_tracker.mark_assigned(unit_id, worker_id)
 
-        logger.debug("…
+        logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
 
-    def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
-        """Check if a work unit has unprocessed items."""
-        if not self.chunk_tracker:
-            logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
-            return True
-
-        chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
-        has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
-        logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
-        return has_unprocessed
-
     def mark_completed(self, unit_id: str, worker_id: str) -> None:
         """Mark a work unit as completed."""
-        logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
         with self.lock:
             if unit_id in self.work_units:
                 self.assigned_units[worker_id].discard(unit_id)
-                logger.debug(
-                    "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
-                )
 
                 if self.chunk_tracker:
                     self.chunk_tracker.mark_completed(unit_id)
-
+
+                # Remove from memory
+                del self.work_units[unit_id]
 
     def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
         """Mark a work unit as failed."""
-        logger.…
+        logger.error(f"Unit {unit_id} failed on {worker_id}: {error}")
         with self.lock:
             if unit_id in self.work_units:
                 self.assigned_units[worker_id].discard(unit_id)
                 self.pending_units.append(unit_id)
-                logger.debug("Returned unit %s to pending_units", unit_id)
 
             if self.chunk_tracker:
                 self.chunk_tracker.mark_failed(unit_id)
-                logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)
 
     def release_assignments(self, worker_id: str) -> None:
         """Release all assignments for a disconnected worker."""
-        logger.debug("Releasing assignments for worker %s", worker_id)
        with self.lock:
             unit_ids = list(self.assigned_units.get(worker_id, []))
 
             for unit_id in unit_ids:
-                if unit_id in self.work_units:
-… (3 removed lines not preserved in this view)
-                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
-                        chunk_state = self.chunk_tracker.chunks[unit_id]
+                if unit_id in self.work_units and self.chunk_tracker:
+                    # Get updated unprocessed ranges from chunk tracker
+                    chunk_state = self.chunk_tracker.chunks.get(unit_id)
+                    if chunk_state:
                         unprocessed_ranges = chunk_state.get_unprocessed_ranges()
-
-                        # Convert relative ranges back to absolute
+                        # Convert relative to absolute
                         absolute_ranges = []
                         for start, end in unprocessed_ranges:
                             abs_start = chunk_state.start_index + start
                             abs_end = chunk_state.start_index + end
                             absolute_ranges.append((abs_start, abs_end))
 
-                        # Update …
-… (2 removed lines not preserved in this view)
-                        logger.debug(
-                            f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
-                        )
+                        # Update work unit
+                        self.work_units[unit_id].data["unprocessed_ranges"] = absolute_ranges
 
                     self.pending_units.append(unit_id)
-                    logger.debug("Returned unit %s to pending_units", unit_id)
 
             if worker_id in self.assigned_units:
                 del self.assigned_units[worker_id]
-                logger.debug("Deleted worker %s from assigned_units", worker_id)
 
             if self.chunk_tracker:
                 self.chunk_tracker.release_worker_chunks(worker_id)
-                logger.debug("Released worker %s chunks in chunk_tracker", worker_id)
-
-    def _subtract_ranges(
-        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
-    ) -> List[Tuple[int, int]]:
-        """Subtract processed ranges from total ranges."""
-        if not processed_ranges:
-            return total_ranges
-
-        # Create a set of all processed indices
-        processed_indices = set()
-        for start, end in processed_ranges:
-            processed_indices.update(range(start, end + 1))
-
-        # Find unprocessed ranges
-        unprocessed_ranges = []
-        for start, end in total_ranges:
-            current_start = None
-            for i in range(start, end + 1):
-                if i not in processed_indices:
-                    if current_start is None:
-                        current_start = i
-                else:
-                    if current_start is not None:
-                        unprocessed_ranges.append((current_start, i - 1))
-                        current_start = None
-
-            if current_start is not None:
-                unprocessed_ranges.append((current_start, end))
-
-        return unprocessed_ranges
 
-…
-        """Get processor statistics."""
-        with self.lock:
-            stats = {
-                "total_units": len(self.work_units),
-                "pending_units": len(self.pending_units),
-                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
-                "total_shards": len(self.all_shards),
-                "workers": len(self.assigned_units),
-            }
-            logger.debug("Stats: %s", stats)
-            return stats
+        logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
 
     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle …
-        # logger.debug("Handling result for unit %s", result.unit_id)
-        base_result = super().handle_result(result)
-
+        """Handle result from worker."""
         # Track processed items if we have chunk tracker
-        if self.chunk_tracker:
-            if "item_indices" not in result.metadata:
-                result.metadata["item_indices"] = [result.metadata.get("_item_index")]
+        if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]
-            logger.debug("Result metadata item_indices: %s", indices)
 
-            # …
+            # Convert to ranges and mark as processed
             if indices:
                 indices.sort()
                 ranges = []
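
For reference between these two hunks: both the old and new `update_from_storage` key their bookkeeping off the job-id string. A minimal illustration of the parsing the new code performs (the sample value is taken from the old code's own comment):

```python
# Job IDs follow "shard:chunk:X:idx:Y", e.g. "data-0000:chunk:0:idx:42".
job_id = "data-0000:chunk:0:idx:42"
parts = job_id.split(":")          # ["data-0000", "chunk", "0", "idx", "42"]
assert len(parts) >= 5 and parts[3] == "idx"
chunk_id = ":".join(parts[:3])     # "data-0000:chunk:0" -> key into the chunk tracker
idx = int(parts[4])                # 42 -> absolute sample index within the shard
```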
@@ -514,269 +390,293 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
                 # Mark ranges as processed
                 for start_idx, end_idx in ranges:
-                    logger.debug(f"Marking chunk as processed: {result.to_repr()}")
                     self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
-                    logger.debug(
-                        "Marked items processed for unit %s: %d-%d",
-                        result.unit_id,
-                        start_idx,
-                        end_idx,
-                    )
-        else:
-            logger.error(
-                f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
-            )
 
-        return …
+        return {
+            "source_id": result.source_id,
+            "chunk_id": result.chunk_id,
+            "outputs": result.outputs,
+            "metadata": result.metadata,
+        }
 
     def update_from_storage(self, processed_job_ids: Set[str]) -> None:
         """Update work units based on what's been processed."""
-        logger.info(f"Updating …
+        logger.info(f"Updating from {len(processed_job_ids)} processed jobs")
 
         with self.lock:
-… (10 removed lines not preserved in this view)
-                    # Parse job_id format: "data-0000:chunk:0:idx:42"
-                    parts = job_id.split(":")
-                    if (
-                        len(parts) == 5
-                        and parts[0] == shard_name
-                        and parts[1] == "chunk"
-                        and int(parts[2]) == chunk_index
-                        and parts[3] == "idx"
-                    ):
-
+            # Group by chunk
+            processed_by_chunk = defaultdict(set)
+
+            for job_id in processed_job_ids:
+                # Parse job_id to extract chunk and index
+                # Expected format: "shard:chunk:X:idx:Y"
+                parts = job_id.split(":")
+                if len(parts) >= 5 and parts[3] == "idx":
+                    chunk_id = ":".join(parts[:3])  # "shard:chunk:X"
+                    try:
                         idx = int(parts[4])
-
-…
+                        processed_by_chunk[chunk_id].add(idx)
+                    except ValueError:
+                        continue
 
-… (8 removed lines not preserved in this view)
-                        if idx == end + 1:
-                            end = idx
-                        else:
-                            processed_ranges.append((start, end))
-                            start = idx
-                            end = idx
-
-                    processed_ranges.append((start, end))
-
-                    # Calculate unprocessed ranges
-                    total_range = [(start_index, start_index + chunk_size - 1)]
-                    unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
-
-                    # Update unit
-                    unit.data["unprocessed_ranges"] = unprocessed_ranges
-
-                    logger.debug(
-                        f"Updated unit {unit_id}: {len(processed_indices)} processed, "
-                        f"unprocessed ranges: {unprocessed_ranges}"
-                    )
+            # Update chunk tracker with processed items
+            if self.chunk_tracker:
+                for chunk_id, indices in processed_by_chunk.items():
+                    if indices:
+                        # Sort indices and convert to ranges
+                        sorted_indices = sorted(indices)
+                        if not sorted_indices:
+                            continue
 
+                        # Condense into contiguous ranges
+                        ranges = []
+                        start_range = sorted_indices[0]
+                        end_range = sorted_indices[0]
 
-… (2 removed lines not preserved in this view)
+                        for i in range(1, len(sorted_indices)):
+                            if sorted_indices[i] == end_range + 1:
+                                end_range = sorted_indices[i]
+                            else:
+                                ranges.append((start_range, end_range))
+                                start_range = sorted_indices[i]
+                                end_range = sorted_indices[i]
+                        ranges.append((start_range, end_range))
 
-… (4 removed lines not preserved in this view)
-        self.dataset_name: Optional[str] = None
+                        # Mark each contiguous range as processed
+                        logger.debug(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
+                        for start_idx, end_idx in ranges:
+                            self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
 
-    def …
-        """…
-… (14 removed lines not preserved in this view)
-        self.dataset_loader = DatasetLoader(
-            dataset_path=dataset_path,
-            dataset_type=dataset_type,
-            split=dataset_split,
-            image_column=image_column,
-        )
-        logger.debug("DatasetLoader initialized for worker")
-        else:
-            logger.error("No dataset_path provided in worker config")
+    def get_stats(self) -> Dict[str, Any]:
+        """Get processor statistics."""
+        with self.lock:
+            # Get chunk tracker stats if available
+            if self.chunk_tracker:
+                shards_summary = self.chunk_tracker.get_shards_summary()
+                total_chunks = sum(len(s.get("chunks", [])) for s in shards_summary.values())
+                completed_chunks = sum(
+                    1
+                    for s in shards_summary.values()
+                    for c in s.get("chunks", [])
+                    if c.status == "completed"
+                )
+            else:
+                total_chunks = len(self.work_units)
+                completed_chunks = 0
 
-… (6 removed lines not preserved in this view)
+            return {
+                "total_shards": self.dataset.num_shards if self.dataset else 0,
+                "total_chunks": total_chunks,
+                "pending_units": len(self.pending_units),
+                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
+                "completed_chunks": completed_chunks,
+                "workers": len(self.assigned_units),
+            }
 
-… (3 removed lines not preserved in this view)
-        start_index = unit.data["start_index"]
-        chunk_size = unit.data["chunk_size"]
-        unprocessed_ranges = unit.data.get(
-            "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
-        )
+    def cleanup(self):
+        """Clean up resources."""
+        logger.info("Cleaning up orchestrator")
 
-…
+        # Stop background threads
+        self.stop_creation.set()
+        if self.unit_creation_thread:
+            self.unit_creation_thread.join(timeout=5)
 
-        # …
-… (2 removed lines not preserved in this view)
-        indices_to_process.update(range(start, end + 1))
-        logger.debug("Indices to process: %s", indices_to_process)
+        # Save checkpoint
+        if self.chunk_tracker:
+            self.chunk_tracker.save_checkpoint()
 
-        processed_indices = []
 
-… (2 removed lines not preserved in this view)
-            self._iterate_shard_with_metadata(shard_url)
-        ):
-            # Skip if not in our chunk range
-            if idx < start_index or idx >= start_index + chunk_size:
-                # logger.debug(f"Skipping idx={idx} not in chunk range")
-                continue
+class WebDatasetWorkerProcessor(WorkerProcessor):
+    """Worker processor for WebDataset shards using webshart."""
 
-… (4 removed lines not preserved in this view)
+    def __init__(self):
+        logger.info("Initializing WebDatasetWorkerProcessor with webshart")
+        self.loader: Optional[webshart.TarDataLoader] = None
+        self.dataset: Optional[webshart.DiscoveredDataset] = None
+        self.mock_results = False
 
-… (5 removed lines not preserved in this view)
-        # Clean metadata - remove sensitive and redundant fields
-        clean_metadata = {
-            k: v
-            for k, v in metadata.items()
-            if k not in ["url", "_shard_url", "shard_name"]  # Remove these fields
-        }
-
-        # Add only necessary index information
-        clean_metadata.update(
-            {
-                "_item_index": idx,
-                "_chunk_relative_index": idx - start_index,
-                "_job_id": job_id,
-            }
-        )
+    def initialize(self, config: ProcessorConfig) -> None:
+        """Initialize worker with webshart loader."""
+        cfg = config.config
+        dataset_cfg = cfg.get("dataset", {})
 
-… (3 removed lines not preserved in this view)
-            "image": image,
-            "item_key": key,
-            "item_index": idx,
-            "metadata": clean_metadata,
-            "job_id": job_id,
-        }
+        self.dataset_path = dataset_cfg.get("dataset_path")
+        metadata_path = dataset_cfg.get("metadata_path", None)
+        self.mock_results = dataset_cfg.get("mock_results", False)
 
-…
+        # Cache configuration
+        cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
+        cache_dir.mkdir(parents=True, exist_ok=True)
 
-… (2 removed lines not preserved in this view)
+        if self.dataset_path and not self.mock_results:
+            # Discover dataset
+            self.dataset = webshart.discover_dataset(
+                source=self.dataset_path,
+                metadata=metadata_path,
+            )
 
-… (3 removed lines not preserved in this view)
+            # Enable caching
+            self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
+            self.dataset.enable_shard_cache(
+                location=str(cache_dir / "shard_cache"),
+                cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
+            )
 
-… (5 removed lines not preserved in this view)
+            # Create loader
+            self.loader = webshart.TarDataLoader(
+                self.dataset,
+                buffer_size=cfg.get("buffer_size", 10),
+                max_file_size=cfg.get("max_file_size", 100 * 1024 * 1024),
+                load_file_data=True,
+            )
 
-…
-            logger.error("Dataset loader not initialized")
-            return
+            logger.info("webshart TarDataLoader initialized")
 
-… (5 removed lines not preserved in this view)
+    def _create_mock_image(self, idx: int) -> Image.Image:
+        """Create a dummy test image."""
+        color = ((idx * 37) % 256, (idx * 53) % 256, (idx * 71) % 256)
+        image = Image.new("RGB", (256, 256), color=color)
+        return image
 
-… (2 removed lines not preserved in this view)
+    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
+        """Process a work unit by iterating specified ranges."""
+        logger.debug(f"Processing unit: {unit}")
 
-… (3 removed lines not preserved in this view)
-        for ext in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]:
-            if ext in sample:
-                image_data = sample[ext]
-                image_ext = ext
-                break
+        shard_name = unit.data["shard_name"]
+        shard_idx = unit.data.get("shard_idx")
+        unprocessed_ranges = unit.data.get("unprocessed_ranges", [])
 
-… (3 removed lines not preserved in this view)
-            key,
-            list(sample.keys()),
-        )
-        continue
+        # For chunk tracking
+        chunk_index = unit.metadata.get("chunk_index", 0)
+        processed_indices = []
 
-… (3 removed lines not preserved in this view)
-        for …
-… (2 removed lines not preserved in this view)
+        if self.mock_results:
+            # Generate mock results for unprocessed ranges
+            for start_idx, end_idx in unprocessed_ranges:
+                for idx in range(start_idx, end_idx + 1):
+                    job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+
+                    yield {
+                        "image": self._create_mock_image(idx),
+                        "image_data": None,
+                        "item_key": f"mock_{idx}",
+                        "item_index": idx,
+                        "metadata": {
+                            "_item_index": idx,
+                            "_chunk_relative_index": idx - unit.data["start_index"],
+                            "_job_id": job_id,
+                            "_mock": True,
+                            "_processed_indices": processed_indices,
+                        },
+                        "job_id": job_id,
+                    }
 
-… (3 removed lines not preserved in this view)
+                    processed_indices.append(idx)
+        else:
+            # Use webshart to process unprocessed ranges
+            for start_idx, end_idx in unprocessed_ranges:
+                try:
+                    # Jump to shard and starting position
+                    if shard_idx is not None:
+                        self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
+                    else:
+                        # Try to find shard by name
+                        self.loader.shard(filename=shard_name, cursor_idx=start_idx)
+
+                    # Iterate through the range
+                    for idx in range(start_idx, end_idx + 1):
+                        try:
+                            entry = next(self.loader)
+
+                            # Decode image
+                            image = None
+                            if entry.data:
+                                try:
+                                    # Use cv2 to decode from memory
+                                    nparr = np.frombuffer(entry.data, np.uint8)
+                                    img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+                                    if img_np is not None:
+                                        # Convert from BGR (OpenCV default) to RGB (PIL default)
+                                        img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
+                                        image = Image.fromarray(img_rgb)
+                                    else:
+                                        logger.warning(f"cv2.imdecode failed for {entry.path}")
+
+                                except ImportError:
+                                    logger.warning(
+                                        "cv2 or numpy not installed, falling back to PIL"
+                                    )
+                                    image = Image.open(io.BytesIO(entry.data))
+                                except Exception as img_e:
+                                    logger.error(
+                                        f"Error decoding image {entry.path} with cv2: {img_e}"
+                                    )
 
-…
+                            # Generate job ID compatible with chunk tracker
+                            job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
+
+                            yield {
+                                "image": image,
+                                "image_data": entry.data,
+                                "item_key": Path(entry.path).stem,
+                                "item_index": idx,
+                                "metadata": {
+                                    "_item_index": idx,
+                                    "_chunk_relative_index": idx - unit.data["start_index"],
+                                    "_job_id": job_id,
+                                    "_filename": entry.path,
+                                    "_file_size": entry.size,
+                                    "_processed_indices": processed_indices,
+                                },
+                                "job_id": job_id,
+                            }
+
+                            processed_indices.append(idx)
+
+                            if len(processed_indices) % 10 == 0:
+                                gc.collect()
+
+                        except StopIteration:
+                            logger.warning(f"Unexpected end of shard at index {idx}")
+                            break
+                        except Exception as e:
+                            logger.error(f"Error processing index {idx}: {e}")
+                            continue
+
+                except Exception as e:
+                    logger.error(f"Error processing range {start_idx}-{end_idx}: {e}")
+                    continue
+
+        # Store processed indices for result
+        context["_processed_indices"] = processed_indices
+        logger.info(f"Processed {len(processed_indices)} items from unit {unit.unit_id}")
 
     def prepare_result(
         self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
     ) -> WorkResult:
-        """Prepare …
-        logger.debug("Preparing result for unit %s", unit.unit_id)
+        """Prepare result with processing details."""
         result = super().prepare_result(unit, outputs, processing_time_ms)
 
-        # Add processed indices
-        if …
-            result.metadata["item_indices"] = …
-            logger.debug(
-                "Added item_indices to result metadata: %s", result.metadata["item_indices"]
-            )
+        # Add processed indices for chunk tracker
+        if hasattr(self, "_last_context") and "_processed_indices" in self._last_context:
+            result.metadata["item_indices"] = self._last_context["_processed_indices"]
 
         return result
 
     def get_dataset_info(self) -> Dict[str, Any]:
         """Get dataset information."""
-        if self.…
-… (6 removed lines not preserved in this view)
+        if self.dataset:
+            stats = self.dataset.get_stats()
+            return {
+                "dataset_name": self.dataset.name,
+                "format": self.dataset.dataset_format,
+                "total_shards": stats["total_shards"],
+                "total_files": stats.get("total_files", "Unknown"),
+                "mock_results": self.mock_results,
+            }
+        return {
+            "dataset_name": "Mock Dataset" if self.mock_results else "Unknown",
+            "mock_results": self.mock_results,
         }
-        logger.debug("Dataset info (no loader): %s", info)
-        return info