caption_flow-0.2.4-py3-none-any.whl → caption_flow-0.3.1-py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
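For reference, the configuration keys that the 0.3.1 processors read in `initialize()` are sketched below. Key names and default values are taken from the diff that follows; the overall dict shape and the example dataset path are assumptions, and how caption-flow wraps this into a `ProcessorConfig` is not shown here.

# Hypothetical config sketch: keys mirror those read by the new webshart-based
# orchestrator and worker processors in 0.3.1; values shown are the defaults
# from the diff below. The dataset_path value is a placeholder.
config = {
    "dataset": {
        "dataset_path": "path/or/url/to/webdataset",  # assumption: any webshart-discoverable source
        "metadata_path": None,
        "mock_results": False,           # worker-only: emit synthetic images instead of reading shards
    },
    "chunk_size": 1000,
    "min_chunk_buffer": 10,
    "chunk_buffer_multiplier": 3,
    "cache_dir": "./webshart_cache",     # metadata and shard caches are created under this directory
    "shard_cache_gb": 10.0,
    "checkpoint_dir": "./checkpoints",   # ChunkTracker writes chunks.json here
    "buffer_size": 10,                   # worker-only: TarDataLoader read-ahead
    "max_file_size": 100 * 1024 * 1024,  # worker-only: per-file size cap passed to TarDataLoader
}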
@@ -1,104 +1,111 @@
1
- """WebDataset processor implementation."""
1
+ """WebDataset processor implementation using webshart TarDataLoader."""
2
2
 
3
3
  import logging
4
4
  import threading
5
+ import gc
6
+ import os
5
7
  from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
6
8
  from collections import deque, defaultdict
7
9
  from pathlib import Path
8
10
  import json
9
- import io
10
11
  from datetime import datetime
11
12
  from PIL import Image
12
- from caption_flow.storage import StorageManager
13
+ import io
13
14
 
15
+ from caption_flow.storage import StorageManager
14
16
  from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
15
- from ..utils import DatasetLoader, ChunkTracker
17
+ from ..utils import ChunkTracker
18
+
19
+ import webshart
20
+ import cv2
21
+ import numpy as np
16
22
 
17
23
  logger = logging.getLogger(__name__)
18
- logger.setLevel(logging.INFO)
19
24
 
20
25
 
21
26
  class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
22
- """Orchestrator processor for WebDataset shards."""
27
+ """Orchestrator processor for WebDataset shards using webshart with ChunkTracker."""
23
28
 
24
29
  def __init__(self):
25
- logger.debug("Initializing WebDatasetOrchestratorProcessor")
26
- self.dataset_loader: Optional[DatasetLoader] = None
30
+ logger.info("Initializing WebDatasetOrchestratorProcessor with webshart + ChunkTracker")
31
+ self.dataset: Optional[webshart.DiscoveredDataset] = None
27
32
  self.chunk_tracker: Optional[ChunkTracker] = None
28
33
  self.chunk_size: int = 1000
29
34
 
30
35
  # Work unit management
31
36
  self.work_units: Dict[str, WorkUnit] = {}
32
37
  self.pending_units: Deque[str] = deque()
33
- self.assigned_units: Dict[str, Set[str]] = defaultdict(set) # worker_id -> unit_ids
38
+ self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
34
39
  self.lock = threading.Lock()
35
40
 
36
- # Shard processing state
37
- self.all_shards: List[str] = []
38
- self.current_shard_index = 0
39
- self.current_shard_items = 0
41
+ # Shard info cache
42
+ self.shard_info_cache: Dict[int, Dict] = {}
40
43
 
41
44
  # Background thread for creating work units
42
45
  self.unit_creation_thread: Optional[threading.Thread] = None
43
46
  self.stop_creation = threading.Event()
47
+ self.min_buffer = 10
48
+ self.buffer_multiplier = 3
44
49
 
45
50
  def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
46
- """Initialize WebDataset processor."""
47
- logger.debug("Initializing orchestrator with config: %s", config.config)
48
- cfg = config.config
51
+ """Initialize with webshart dataset discovery and ChunkTracker."""
52
+ logger.info("Initializing orchestrator with config")
49
53
 
50
- # Dataset configuration
54
+ cfg = config.config
51
55
  dataset_cfg = cfg.get("dataset", {})
52
- dataset_path = dataset_cfg.get("dataset_path")
53
- dataset_type = dataset_cfg.get("dataset_type", "huggingface")
54
- dataset_split = dataset_cfg.get("dataset_split", "train")
55
- image_column = dataset_cfg.get("dataset_image_column", "image")
56
+ self.dataset_path = dataset_cfg.get("dataset_path")
57
+ metadata_path = dataset_cfg.get("metadata_path", None)
56
58
 
57
59
  # Chunk settings
58
60
  self.chunk_size = cfg.get("chunk_size", 1000)
59
61
  self.min_buffer = cfg.get("min_chunk_buffer", 10)
60
62
  self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
61
63
 
62
- logger.debug(
63
- "Chunk size: %d, min_buffer: %d, buffer_multiplier: %d",
64
- self.chunk_size,
65
- self.min_buffer,
66
- self.buffer_multiplier,
67
- )
64
+ # Cache configuration
65
+ cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
66
+ cache_dir.mkdir(parents=True, exist_ok=True)
68
67
 
69
- # Initialize dataset loader
70
- if dataset_path:
71
- checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
72
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
73
- logger.debug("Checkpoint dir: %s", checkpoint_dir)
74
-
75
- self.dataset_loader = DatasetLoader(
76
- dataset_path=dataset_path,
77
- dataset_type=dataset_type,
78
- split=dataset_split,
79
- image_column=image_column,
80
- cache_dir=checkpoint_dir,
68
+ if self.dataset_path:
69
+ # Initialize dataset with webshart
70
+ self.dataset = webshart.discover_dataset(
71
+ source=self.dataset_path,
72
+ metadata=metadata_path,
81
73
  )
82
- logger.debug("DatasetLoader initialized")
83
74
 
84
- self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
85
- logger.debug("ChunkTracker initialized at %s", checkpoint_dir / "chunks.json")
75
+ # Enable caching for efficient access
76
+ self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
77
+ self.dataset.enable_shard_cache(
78
+ location=str(cache_dir / "shard_cache"),
79
+ cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
80
+ )
86
81
 
87
- # Get all shards
88
- self.all_shards = self.dataset_loader.get_shard_list()
89
- logger.debug("All shards: %s", self.all_shards)
82
+ logger.info(f"Dataset discovered: {self.dataset.num_shards} shards")
83
+
84
+ # Initialize chunk tracker
85
+ checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
86
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
87
+ self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
90
88
 
91
89
  # Restore existing state from chunk tracker
92
- self._restore_state(storage=storage)
90
+ self._restore_state(storage)
93
91
 
94
92
  # Start background unit creation
95
93
  self.unit_creation_thread = threading.Thread(
96
94
  target=self._create_units_background, daemon=True
97
95
  )
98
96
  self.unit_creation_thread.start()
99
- logger.debug("Unit creation thread started")
100
97
  else:
101
- logger.error("No dataset_path provided in config")
98
+ logger.error("No dataset_path provided")
99
+
100
+ def _get_shard_info_cached(self, shard_idx: int) -> Optional[Dict]:
101
+ """Get shard info with caching."""
102
+ if shard_idx not in self.shard_info_cache:
103
+ try:
104
+ self.shard_info_cache[shard_idx] = self.dataset.get_shard_info(shard_idx)
105
+ except Exception as e:
106
+ logger.error(f"Error getting shard info for idx {shard_idx}: {e}")
107
+ return None
108
+ return self.shard_info_cache[shard_idx]
102
109
 
103
110
  def _restore_state(self, storage: StorageManager) -> None:
104
111
  """Restore state from chunk tracker."""
@@ -108,38 +115,36 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
108
115
 
109
116
  shards_summary = self.chunk_tracker.get_shards_summary()
110
117
 
111
- # Get all processed job_ids from storage
112
- all_processed_jobs = storage.get_all_processed_job_ids()
113
-
114
118
  with self.lock:
115
119
  for shard_name, shard_info in shards_summary.items():
116
- for chunk_state in shard_info["chunks"]:
117
- # Calculate actual unprocessed ranges based on what's in storage
118
- chunk_range = (
119
- chunk_state.start_index,
120
- chunk_state.start_index + chunk_state.chunk_size - 1,
121
- )
120
+ chunks = shard_info.get("chunks", [])
121
+ for chunk_state in chunks:
122
+ # Only add incomplete chunks
123
+ if chunk_state.status != "completed":
124
+ logger.debug(f"Restoring incomplete chunk {chunk_state.chunk_id}")
122
125
 
123
- # Get processed indices for this chunk
124
- processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
125
- chunk_state.chunk_id, all_processed_jobs
126
- )
126
+ # Get unprocessed ranges
127
+ unprocessed_ranges = chunk_state.get_unprocessed_ranges()
128
+ if not unprocessed_ranges:
129
+ continue
127
130
 
128
- # Calculate unprocessed ranges
129
- unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
131
+ # Convert relative ranges to absolute file indices
132
+ absolute_ranges = []
133
+ for start, end in unprocessed_ranges:
134
+ abs_start = chunk_state.start_index + start
135
+ abs_end = chunk_state.start_index + end
136
+ absolute_ranges.append((abs_start, abs_end))
130
137
 
131
- if unprocessed_ranges:
132
- # Create work unit for unprocessed items
133
- logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
134
138
  unit = WorkUnit(
135
139
  unit_id=chunk_state.chunk_id,
136
140
  chunk_id=chunk_state.chunk_id,
137
141
  source_id=shard_name,
138
142
  data={
139
143
  "shard_url": chunk_state.shard_url,
144
+ "shard_name": shard_name,
140
145
  "start_index": chunk_state.start_index,
141
146
  "chunk_size": chunk_state.chunk_size,
142
- "unprocessed_ranges": unprocessed_ranges,
147
+ "unprocessed_ranges": absolute_ranges,
143
148
  },
144
149
  metadata={
145
150
  "shard_name": shard_name,
@@ -154,11 +159,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
154
159
  """Background thread to create work units on demand."""
155
160
  logger.info("Starting work unit creation thread")
156
161
 
157
- shard_iter = iter(self.all_shards)
158
- current_shard_url = None
159
- current_shard_name = None
160
- current_shard_items = 0
161
- current_index = 0
162
+ current_shard_idx = 0
163
+ current_file_idx = 0
162
164
 
163
165
  while not self.stop_creation.is_set():
164
166
  # Check if we need more units
@@ -169,14 +171,6 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
169
171
 
170
172
  target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
171
173
  units_needed = max(0, target_buffer - (pending_count + assigned_count))
172
- logger.debug(
173
- "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
174
- pending_count,
175
- assigned_count,
176
- worker_count,
177
- target_buffer,
178
- units_needed,
179
- )
180
174
 
181
175
  if units_needed == 0:
182
176
  threading.Event().wait(5)
@@ -184,169 +178,91 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
184
178
 
185
179
  # Create units as needed
186
180
  units_created = 0
187
-
188
181
  while units_created < units_needed and not self.stop_creation.is_set():
189
- # Load next shard if needed
190
- if current_shard_url is None or current_index >= current_shard_items:
191
- try:
192
- current_shard_url = next(shard_iter)
193
- current_shard_name = Path(current_shard_url).stem
182
+ # Get current shard info
183
+ if current_shard_idx >= self.dataset.num_shards:
184
+ logger.info("All shards processed")
185
+ break
194
186
 
195
- logger.debug("Loading shard: %s", current_shard_url)
196
- # Count items in shard
197
- current_shard_items = sum(
198
- 1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
199
- )
200
- logger.info(
201
- f"Processing shard {current_shard_name} with {current_shard_items} items"
202
- )
203
- current_index = 0
204
-
205
- except StopIteration:
206
- logger.info("All shards processed")
207
- break
208
- except Exception as e:
209
- logger.error("Error loading shard: %s", e)
210
- break
211
-
212
- # Create work unit
213
- if current_shard_url and current_index < current_shard_items:
214
- chunk_size = min(self.chunk_size, current_shard_items - current_index)
215
- unit_id = f"{current_shard_name}:chunk:{current_index // self.chunk_size}"
216
-
217
- with self.lock:
218
- # Check if this unit already exists in work_units
219
- if unit_id in self.work_units:
220
- logger.debug(
221
- f"Unit {unit_id} already exists in work_units, skipping creation"
222
- )
223
- current_index += self.chunk_size
187
+ shard_info = self._get_shard_info_cached(current_shard_idx)
188
+ if not shard_info:
189
+ current_shard_idx += 1
190
+ current_file_idx = 0
191
+ continue
192
+
193
+ shard_name = shard_info["name"]
194
+ shard_files = shard_info["num_files"]
195
+
196
+ # Check if we need to move to next shard
197
+ if current_file_idx >= shard_files:
198
+ current_shard_idx += 1
199
+ current_file_idx = 0
200
+ continue
201
+
202
+ # Create chunk for current position
203
+ chunk_size = min(self.chunk_size, shard_files - current_file_idx)
204
+ chunk_id = f"{shard_name}:chunk:{current_file_idx // self.chunk_size}"
205
+
206
+ with self.lock:
207
+ # Skip if already exists
208
+ if chunk_id in self.work_units:
209
+ current_file_idx += self.chunk_size
210
+ continue
211
+
212
+ # Check if chunk is already completed
213
+ if self.chunk_tracker:
214
+ chunk_state = self.chunk_tracker.chunks.get(chunk_id)
215
+ if chunk_state and chunk_state.status == "completed":
216
+ current_file_idx += self.chunk_size
224
217
  continue
225
218
 
226
- # Check if chunk is already completed or has no unprocessed items
227
- if self.chunk_tracker:
228
- chunk_state = self.chunk_tracker.chunks.get(unit_id)
229
-
230
- if chunk_state:
231
- # Check if completed
232
- if chunk_state.status == "completed":
233
- logger.debug(f"Unit {unit_id} already completed, skipping")
234
- current_index += self.chunk_size
235
- continue
236
-
237
- # Check if has unprocessed items
238
- unprocessed_ranges = chunk_state.get_unprocessed_ranges()
239
- if not unprocessed_ranges:
240
- logger.debug(
241
- f"Unit {unit_id} has no unprocessed items, skipping"
242
- )
243
- current_index += self.chunk_size
244
- continue
245
-
246
- # If chunk exists but has unprocessed items, use those ranges
247
- logger.debug(
248
- f"Existing chunk {unit_id} has unprocessed ranges: {unprocessed_ranges}"
249
- )
250
-
251
- unit = WorkUnit(
252
- unit_id=unit_id,
253
- chunk_id=unit_id,
254
- source_id=current_shard_name,
255
- data={
256
- "shard_url": current_shard_url,
257
- "start_index": current_index,
258
- "chunk_size": chunk_size,
259
- "unprocessed_ranges": [
260
- (
261
- r[0] + chunk_state.start_index,
262
- r[1] + chunk_state.start_index,
263
- )
264
- for r in unprocessed_ranges
265
- ], # Convert relative to absolute
266
- },
267
- metadata={
268
- "shard_name": current_shard_name,
269
- "chunk_index": current_index // self.chunk_size,
270
- },
271
- )
272
- else:
273
- # New chunk
274
- logger.debug(
275
- "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
276
- unit_id,
277
- current_shard_name,
278
- current_index,
279
- chunk_size,
280
- )
281
-
282
- unit = WorkUnit(
283
- unit_id=unit_id,
284
- chunk_id=unit_id,
285
- source_id=current_shard_name,
286
- data={
287
- "shard_url": current_shard_url,
288
- "start_index": current_index,
289
- "chunk_size": chunk_size,
290
- "unprocessed_ranges": [
291
- (current_index, current_index + chunk_size - 1)
292
- ],
293
- },
294
- metadata={
295
- "shard_name": current_shard_name,
296
- "chunk_index": current_index // self.chunk_size,
297
- },
298
- )
299
- else:
300
- # No chunk tracker, create normally
301
- unit = WorkUnit(
302
- unit_id=unit_id,
303
- chunk_id=unit_id,
304
- source_id=current_shard_name,
305
- data={
306
- "shard_url": current_shard_url,
307
- "start_index": current_index,
308
- "chunk_size": chunk_size,
309
- "unprocessed_ranges": [
310
- (current_index, current_index + chunk_size - 1)
311
- ],
312
- },
313
- metadata={
314
- "shard_name": current_shard_name,
315
- "chunk_index": current_index // self.chunk_size,
316
- },
317
- )
219
+ # Get shard URL (path for webshart)
220
+ shard_url = shard_info.get("path", f"{shard_name}.tar")
221
+
222
+ # Create work unit
223
+ unit = WorkUnit(
224
+ unit_id=chunk_id,
225
+ chunk_id=chunk_id,
226
+ source_id=shard_name,
227
+ data={
228
+ "shard_url": shard_url,
229
+ "shard_name": shard_name,
230
+ "shard_idx": current_shard_idx,
231
+ "start_index": current_file_idx,
232
+ "chunk_size": chunk_size,
233
+ "unprocessed_ranges": [
234
+ (current_file_idx, current_file_idx + chunk_size - 1)
235
+ ],
236
+ },
237
+ metadata={
238
+ "shard_name": shard_name,
239
+ "chunk_index": current_file_idx // self.chunk_size,
240
+ },
241
+ )
318
242
 
319
- self.work_units[unit_id] = unit
320
- self.pending_units.append(unit_id)
321
- logger.debug("Added work unit %s to pending_units", unit_id)
243
+ self.work_units[chunk_id] = unit
244
+ self.pending_units.append(chunk_id)
322
245
 
323
- if self.chunk_tracker:
324
- added_chunk = self.chunk_tracker.add_chunk(
325
- unit_id,
326
- current_shard_name,
327
- current_shard_url,
328
- current_index,
329
- chunk_size,
330
- )
331
- if added_chunk:
332
- logger.debug("Added chunk to chunk_tracker: %s", unit_id)
333
- else:
334
- logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)
246
+ # Add to chunk tracker
247
+ if self.chunk_tracker:
248
+ self.chunk_tracker.add_chunk(
249
+ chunk_id, shard_name, shard_url, current_file_idx, chunk_size
250
+ )
335
251
 
336
- units_created += 1
252
+ units_created += 1
337
253
 
338
- current_index += self.chunk_size
254
+ current_file_idx += self.chunk_size
339
255
 
340
256
  if units_created > 0:
341
257
  logger.debug(f"Created {units_created} work units")
342
258
 
259
+ logger.info("Work unit creation thread exiting")
260
+
343
261
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
344
262
  """Get available work units for a worker."""
345
- logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
346
263
  assigned = []
347
264
 
348
265
  with self.lock:
349
- # Get new units if needed
350
266
  while len(assigned) < count and self.pending_units:
351
267
  unit_id = self.pending_units.popleft()
352
268
  unit = self.work_units.get(unit_id)
@@ -354,148 +270,74 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
354
270
  if unit:
355
271
  self.assigned_units[worker_id].add(unit_id)
356
272
  assigned.append(unit)
357
- logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)
358
273
 
359
274
  if self.chunk_tracker:
360
275
  self.chunk_tracker.mark_assigned(unit_id, worker_id)
361
276
 
362
- logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
277
+ logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
363
278
  return assigned
364
279
 
365
- def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
366
- """Check if a work unit has unprocessed items."""
367
- if not self.chunk_tracker:
368
- logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
369
- return True
370
-
371
- chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
372
- has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
373
- logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
374
- return has_unprocessed
375
-
376
280
  def mark_completed(self, unit_id: str, worker_id: str) -> None:
377
281
  """Mark a work unit as completed."""
378
- logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
379
282
  with self.lock:
380
283
  if unit_id in self.work_units:
381
284
  self.assigned_units[worker_id].discard(unit_id)
382
- logger.debug(
383
- "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
384
- )
385
285
 
386
286
  if self.chunk_tracker:
387
287
  self.chunk_tracker.mark_completed(unit_id)
388
- logger.debug("Marked unit %s as completed in chunk_tracker", unit_id)
288
+
289
+ # Remove from memory
290
+ del self.work_units[unit_id]
389
291
 
390
292
  def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
391
293
  """Mark a work unit as failed."""
392
- logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
294
+ logger.error(f"Unit {unit_id} failed on {worker_id}: {error}")
393
295
  with self.lock:
394
296
  if unit_id in self.work_units:
395
297
  self.assigned_units[worker_id].discard(unit_id)
396
298
  self.pending_units.append(unit_id)
397
- logger.debug("Returned unit %s to pending_units", unit_id)
398
299
 
399
300
  if self.chunk_tracker:
400
301
  self.chunk_tracker.mark_failed(unit_id)
401
- logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)
402
302
 
403
303
  def release_assignments(self, worker_id: str) -> None:
404
304
  """Release all assignments for a disconnected worker."""
405
- logger.debug("Releasing assignments for worker %s", worker_id)
406
305
  with self.lock:
407
306
  unit_ids = list(self.assigned_units.get(worker_id, []))
408
307
 
409
308
  for unit_id in unit_ids:
410
- if unit_id in self.work_units:
411
- unit = self.work_units[unit_id]
412
-
413
- # Update unprocessed ranges based on what's been processed
414
- if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
415
- chunk_state = self.chunk_tracker.chunks[unit_id]
309
+ if unit_id in self.work_units and self.chunk_tracker:
310
+ # Get updated unprocessed ranges from chunk tracker
311
+ chunk_state = self.chunk_tracker.chunks.get(unit_id)
312
+ if chunk_state:
416
313
  unprocessed_ranges = chunk_state.get_unprocessed_ranges()
417
-
418
- # Convert relative ranges back to absolute
314
+ # Convert relative to absolute
419
315
  absolute_ranges = []
420
316
  for start, end in unprocessed_ranges:
421
317
  abs_start = chunk_state.start_index + start
422
318
  abs_end = chunk_state.start_index + end
423
319
  absolute_ranges.append((abs_start, abs_end))
424
320
 
425
- # Update the work unit's data
426
- unit.data["unprocessed_ranges"] = absolute_ranges
427
-
428
- logger.debug(
429
- f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
430
- )
321
+ # Update work unit
322
+ self.work_units[unit_id].data["unprocessed_ranges"] = absolute_ranges
431
323
 
432
324
  self.pending_units.append(unit_id)
433
- logger.debug("Returned unit %s to pending_units", unit_id)
434
325
 
435
326
  if worker_id in self.assigned_units:
436
327
  del self.assigned_units[worker_id]
437
- logger.debug("Deleted worker %s from assigned_units", worker_id)
438
328
 
439
329
  if self.chunk_tracker:
440
330
  self.chunk_tracker.release_worker_chunks(worker_id)
441
- logger.debug("Released worker %s chunks in chunk_tracker", worker_id)
442
-
443
- def _subtract_ranges(
444
- self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
445
- ) -> List[Tuple[int, int]]:
446
- """Subtract processed ranges from total ranges."""
447
- if not processed_ranges:
448
- return total_ranges
449
-
450
- # Create a set of all processed indices
451
- processed_indices = set()
452
- for start, end in processed_ranges:
453
- processed_indices.update(range(start, end + 1))
454
-
455
- # Find unprocessed ranges
456
- unprocessed_ranges = []
457
- for start, end in total_ranges:
458
- current_start = None
459
- for i in range(start, end + 1):
460
- if i not in processed_indices:
461
- if current_start is None:
462
- current_start = i
463
- else:
464
- if current_start is not None:
465
- unprocessed_ranges.append((current_start, i - 1))
466
- current_start = None
467
-
468
- if current_start is not None:
469
- unprocessed_ranges.append((current_start, end))
470
-
471
- return unprocessed_ranges
472
331
 
473
- def get_stats(self) -> Dict[str, Any]:
474
- """Get processor statistics."""
475
- with self.lock:
476
- stats = {
477
- "total_units": len(self.work_units),
478
- "pending_units": len(self.pending_units),
479
- "assigned_units": sum(len(units) for units in self.assigned_units.values()),
480
- "total_shards": len(self.all_shards),
481
- "workers": len(self.assigned_units),
482
- }
483
- logger.debug("Stats: %s", stats)
484
- return stats
332
+ logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
485
333
 
486
334
  def handle_result(self, result: WorkResult) -> Dict[str, Any]:
487
- """Handle WebDataset-specific result processing."""
488
- # logger.debug("Handling result for unit %s", result.unit_id)
489
- base_result = super().handle_result(result)
490
-
335
+ """Handle result from worker."""
491
336
  # Track processed items if we have chunk tracker
492
- if self.chunk_tracker:
493
- if "item_indices" not in result.metadata:
494
- result.metadata["item_indices"] = [result.metadata.get("_item_index")]
337
+ if self.chunk_tracker and "item_indices" in result.metadata:
495
338
  indices = result.metadata["item_indices"]
496
- logger.debug("Result metadata item_indices: %s", indices)
497
339
 
498
- # Group consecutive indices into ranges
340
+ # Convert to ranges and mark as processed
499
341
  if indices:
500
342
  indices.sort()
501
343
  ranges = []
@@ -514,269 +356,272 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
514
356
 
515
357
  # Mark ranges as processed
516
358
  for start_idx, end_idx in ranges:
517
- logger.debug(f"Marking chunk as processed: {result.to_repr()}")
518
359
  self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
519
- logger.debug(
520
- "Marked items processed for unit %s: %d-%d",
521
- result.unit_id,
522
- start_idx,
523
- end_idx,
524
- )
525
- else:
526
- logger.error(
527
- f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
528
- )
529
360
 
530
- return base_result
361
+ return {
362
+ "source_id": result.source_id,
363
+ "chunk_id": result.chunk_id,
364
+ "outputs": result.outputs,
365
+ "metadata": result.metadata,
366
+ }
531
367
 
532
368
  def update_from_storage(self, processed_job_ids: Set[str]) -> None:
533
369
  """Update work units based on what's been processed."""
534
- logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
370
+ logger.info(f"Updating from {len(processed_job_ids)} processed jobs")
535
371
 
536
372
  with self.lock:
537
- for unit_id, unit in self.work_units.items():
538
- # Extract chunk info from unit
539
- start_index = unit.data["start_index"]
540
- chunk_size = unit.data["chunk_size"]
541
- shard_name = unit.metadata["shard_name"]
542
- chunk_index = unit.metadata["chunk_index"]
543
-
544
- # Find processed indices for this chunk
545
- processed_indices = []
546
- for job_id in processed_job_ids:
547
- # Parse job_id format: "data-0000:chunk:0:idx:42"
548
- parts = job_id.split(":")
549
- if (
550
- len(parts) == 5
551
- and parts[0] == shard_name
552
- and parts[1] == "chunk"
553
- and int(parts[2]) == chunk_index
554
- and parts[3] == "idx"
555
- ):
556
-
373
+ # Group by chunk
374
+ processed_by_chunk = defaultdict(set)
375
+
376
+ for job_id in processed_job_ids:
377
+ # Parse job_id to extract chunk and index
378
+ # Expected format: "shard:chunk:X:idx:Y"
379
+ parts = job_id.split(":")
380
+ if len(parts) >= 5 and parts[3] == "idx":
381
+ chunk_id = ":".join(parts[:3]) # "shard:chunk:X"
382
+ try:
557
383
  idx = int(parts[4])
558
- if start_index <= idx < start_index + chunk_size:
559
- processed_indices.append(idx)
384
+ processed_by_chunk[chunk_id].add(idx)
385
+ except ValueError:
386
+ continue
560
387
 
561
- if processed_indices:
562
- # Convert to ranges
563
- processed_indices.sort()
564
- processed_ranges = []
565
- start = processed_indices[0]
566
- end = processed_indices[0]
567
-
568
- for idx in processed_indices[1:]:
569
- if idx == end + 1:
570
- end = idx
571
- else:
572
- processed_ranges.append((start, end))
573
- start = idx
574
- end = idx
575
-
576
- processed_ranges.append((start, end))
577
-
578
- # Calculate unprocessed ranges
579
- total_range = [(start_index, start_index + chunk_size - 1)]
580
- unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
581
-
582
- # Update unit
583
- unit.data["unprocessed_ranges"] = unprocessed_ranges
584
-
585
- logger.debug(
586
- f"Updated unit {unit_id}: {len(processed_indices)} processed, "
587
- f"unprocessed ranges: {unprocessed_ranges}"
588
- )
388
+ # Update chunk tracker with processed items
389
+ if self.chunk_tracker:
390
+ for chunk_id, indices in processed_by_chunk.items():
391
+ if indices:
392
+ # Sort indices and convert to ranges
393
+ sorted_indices = sorted(indices)
394
+ for idx in sorted_indices:
395
+ self.chunk_tracker.mark_items_processed(chunk_id, idx, idx)
589
396
 
397
+ def get_stats(self) -> Dict[str, Any]:
398
+ """Get processor statistics."""
399
+ with self.lock:
400
+ # Get chunk tracker stats if available
401
+ if self.chunk_tracker:
402
+ shards_summary = self.chunk_tracker.get_shards_summary()
403
+ total_chunks = sum(len(s.get("chunks", [])) for s in shards_summary.values())
404
+ completed_chunks = sum(
405
+ 1
406
+ for s in shards_summary.values()
407
+ for c in s.get("chunks", [])
408
+ if c.status == "completed"
409
+ )
410
+ else:
411
+ total_chunks = len(self.work_units)
412
+ completed_chunks = 0
590
413
 
591
- class WebDatasetWorkerProcessor(WorkerProcessor):
592
- """Worker processor for WebDataset shards."""
414
+ return {
415
+ "total_shards": self.dataset.num_shards if self.dataset else 0,
416
+ "total_chunks": total_chunks,
417
+ "pending_units": len(self.pending_units),
418
+ "assigned_units": sum(len(units) for units in self.assigned_units.values()),
419
+ "completed_chunks": completed_chunks,
420
+ "workers": len(self.assigned_units),
421
+ }
593
422
 
594
- def __init__(self):
595
- logger.debug("Initializing WebDatasetWorkerProcessor")
596
- self.dataset_loader: Optional[DatasetLoader] = None
597
- self.dataset_config: Dict[str, Any] = {}
598
- self.dataset_name: Optional[str] = None
423
+ def cleanup(self):
424
+ """Clean up resources."""
425
+ logger.info("Cleaning up orchestrator")
599
426
 
600
- def initialize(self, config: ProcessorConfig) -> None:
601
- """Initialize WebDataset processor."""
602
- logger.debug("Initializing worker with config: %s", config.config)
603
- cfg = config.config["dataset"]
604
-
605
- # Store config
606
- self.dataset_config = cfg
607
-
608
- # Initialize dataset loader
609
- dataset_path = cfg.get("dataset_path")
610
- self.dataset_path = dataset_path
611
- dataset_type = cfg.get("dataset_type", "huggingface")
612
- dataset_split = cfg.get("dataset_split", "train")
613
- image_column = cfg.get("dataset_image_column", "image")
614
-
615
- if dataset_path:
616
- self.dataset_loader = DatasetLoader(
617
- dataset_path=dataset_path,
618
- dataset_type=dataset_type,
619
- split=dataset_split,
620
- image_column=image_column,
621
- )
622
- logger.debug("DatasetLoader initialized for worker")
623
- else:
624
- logger.error("No dataset_path provided in worker config")
427
+ # Stop background threads
428
+ self.stop_creation.set()
429
+ if self.unit_creation_thread:
430
+ self.unit_creation_thread.join(timeout=5)
625
431
 
626
- def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
627
- """Process a WebDataset chunk, yielding items to be captioned."""
628
- logger.debug("Processing unit: %s", unit.unit_id)
629
- if not self.dataset_loader:
630
- logger.error("Dataset loader not initialized")
631
- return
432
+ # Save checkpoint
433
+ if self.chunk_tracker:
434
+ self.chunk_tracker.save_checkpoint()
632
435
 
633
- shard_name = unit.metadata["shard_name"]
634
- chunk_index = unit.metadata["chunk_index"]
635
- shard_url = unit.data["shard_url"]
636
- start_index = unit.data["start_index"]
637
- chunk_size = unit.data["chunk_size"]
638
- unprocessed_ranges = unit.data.get(
639
- "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
640
- )
641
436
 
642
- logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
437
+ class WebDatasetWorkerProcessor(WorkerProcessor):
438
+ """Worker processor for WebDataset shards using webshart."""
643
439
 
644
- # Create set of indices to process
645
- indices_to_process = set()
646
- for start, end in unprocessed_ranges:
647
- indices_to_process.update(range(start, end + 1))
648
- logger.debug("Indices to process: %s", indices_to_process)
440
+ def __init__(self):
441
+ logger.info("Initializing WebDatasetWorkerProcessor with webshart")
442
+ self.loader: Optional[webshart.TarDataLoader] = None
443
+ self.dataset: Optional[webshart.DiscoveredDataset] = None
444
+ self.mock_results = False
649
445
 
650
- processed_indices = []
446
+ def initialize(self, config: ProcessorConfig) -> None:
447
+ """Initialize worker with webshart loader."""
448
+ cfg = config.config
449
+ dataset_cfg = cfg.get("dataset", {})
651
450
 
652
- # Iterate through shard
653
- for idx, (key, url, image_data, metadata) in enumerate(
654
- self._iterate_shard_with_metadata(shard_url)
655
- ):
656
- # Skip if not in our chunk range
657
- if idx < start_index or idx >= start_index + chunk_size:
658
- # logger.debug(f"Skipping idx={idx} not in chunk range")
659
- continue
451
+ self.dataset_path = dataset_cfg.get("dataset_path")
452
+ metadata_path = dataset_cfg.get("metadata_path", None)
453
+ self.mock_results = dataset_cfg.get("mock_results", False)
660
454
 
661
- # Skip if already processed
662
- if idx not in indices_to_process:
663
- logger.debug(f"Skipping idx={idx} already processed")
664
- continue
455
+ # Cache configuration
456
+ cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
457
+ cache_dir.mkdir(parents=True, exist_ok=True)
665
458
 
666
- try:
667
- # Load image
668
- image = Image.open(io.BytesIO(image_data))
669
- job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
670
-
671
- # Clean metadata - remove sensitive and redundant fields
672
- clean_metadata = {
673
- k: v
674
- for k, v in metadata.items()
675
- if k not in ["url", "_shard_url", "shard_name"] # Remove these fields
676
- }
677
-
678
- # Add only necessary index information
679
- clean_metadata.update(
680
- {
681
- "_item_index": idx,
682
- "_chunk_relative_index": idx - start_index,
683
- "_job_id": job_id,
684
- }
685
- )
459
+ if self.dataset_path and not self.mock_results:
460
+ # Discover dataset
461
+ self.dataset = webshart.discover_dataset(
462
+ source=self.dataset_path,
463
+ metadata=metadata_path,
464
+ )
465
+
466
+ # Enable caching
467
+ self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
468
+ self.dataset.enable_shard_cache(
469
+ location=str(cache_dir / "shard_cache"),
470
+ cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
471
+ )
686
472
 
687
- # Prepare item for captioning
688
- # logger.debug("Yielding item idx=%d key=%s", idx, key)
689
- yield {
690
- "image": image,
691
- "item_key": key,
692
- "item_index": idx,
693
- "metadata": clean_metadata,
694
- "job_id": job_id,
695
- }
473
+ # Create loader
474
+ self.loader = webshart.TarDataLoader(
475
+ self.dataset,
476
+ buffer_size=cfg.get("buffer_size", 10),
477
+ max_file_size=cfg.get("max_file_size", 100 * 1024 * 1024),
478
+ load_file_data=True,
479
+ )
696
480
 
697
- processed_indices.append(idx)
481
+ logger.info("webshart TarDataLoader initialized")
698
482
 
699
- except Exception as e:
700
- logger.error(f"Error processing item {key}: {e}")
483
+ def _create_mock_image(self, idx: int) -> Image.Image:
484
+ """Create a dummy test image."""
485
+ color = ((idx * 37) % 256, (idx * 53) % 256, (idx * 71) % 256)
486
+ image = Image.new("RGB", (256, 256), color=color)
487
+ return image
701
488
 
702
- # Store processed indices in context for result preparation
703
- context["_processed_indices"] = processed_indices
704
- logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
489
+ def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
490
+ """Process a work unit by iterating specified ranges."""
491
+ logger.debug(f"Processing unit: {unit.unit_id}")
705
492
 
706
- def _iterate_shard_with_metadata(
707
- self, shard_url: str
708
- ) -> Iterator[Tuple[str, str, bytes, Dict]]:
709
- """Iterate through a shard with metadata."""
710
- logger.debug("Iterating shard with metadata: %s", shard_url)
493
+ shard_name = unit.data["shard_name"]
494
+ shard_idx = unit.data.get("shard_idx")
495
+ unprocessed_ranges = unit.data.get("unprocessed_ranges", [])
711
496
 
712
- if not self.dataset_loader:
713
- logger.error("Dataset loader not initialized")
714
- return
497
+ # For chunk tracking
498
+ chunk_index = unit.metadata.get("chunk_index", 0)
499
+ processed_indices = []
715
500
 
716
- # Use the DatasetLoader that returns full samples
717
- for sample in self.dataset_loader.iterate_shard(shard_url):
718
- if not isinstance(sample, dict):
719
- logger.warning("Unexpected sample format: %s", type(sample))
720
- continue
501
+ if self.mock_results:
502
+ # Generate mock results for unprocessed ranges
503
+ for start_idx, end_idx in unprocessed_ranges:
504
+ for idx in range(start_idx, end_idx + 1):
505
+ job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
506
+
507
+ yield {
508
+ "image": self._create_mock_image(idx),
509
+ "image_data": None,
510
+ "item_key": f"mock_{idx}",
511
+ "item_index": idx,
512
+ "metadata": {
513
+ "_item_index": idx,
514
+ "_chunk_relative_index": idx - unit.data["start_index"],
515
+ "_job_id": job_id,
516
+ "_mock": True,
517
+ },
518
+ "job_id": job_id,
519
+ }
721
520
 
722
- key = sample.get("__key__", "unknown")
723
- url = sample.get("__url__", "") # Don't use shard_url as default
521
+ processed_indices.append(idx)
522
+ else:
523
+ # Use webshart to process unprocessed ranges
524
+ for start_idx, end_idx in unprocessed_ranges:
525
+ try:
526
+ # Jump to shard and starting position
527
+ if shard_idx is not None:
528
+ self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
529
+ else:
530
+ # Try to find shard by name
531
+ self.loader.shard(filename=shard_name, cursor_idx=start_idx)
532
+
533
+ # Iterate through the range
534
+ for idx in range(start_idx, end_idx + 1):
535
+ try:
536
+ entry = next(self.loader)
537
+
538
+ # Decode image
539
+ image = None
540
+ if entry.data:
541
+ try:
542
+ # Use cv2 to decode from memory
543
+ nparr = np.frombuffer(entry.data, np.uint8)
544
+ img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
545
+
546
+ if img_np is not None:
547
+ # Convert from BGR (OpenCV default) to RGB (PIL default)
548
+ img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
549
+ image = Image.fromarray(img_rgb)
550
+ else:
551
+ logger.warning(f"cv2.imdecode failed for {entry.path}")
552
+
553
+ except ImportError:
554
+ logger.warning(
555
+ "cv2 or numpy not installed, falling back to PIL"
556
+ )
557
+ image = Image.open(io.BytesIO(entry.data))
558
+ except Exception as img_e:
559
+ logger.error(
560
+ f"Error decoding image {entry.path} with cv2: {img_e}"
561
+ )
724
562
 
725
- # Find image data
726
- image_data = None
727
- image_ext = None
728
- for ext in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]:
729
- if ext in sample:
730
- image_data = sample[ext]
731
- image_ext = ext
732
- break
563
+ # Generate job ID compatible with chunk tracker
564
+ job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
565
+
566
+ yield {
567
+ "image": image,
568
+ "image_data": entry.data,
569
+ "item_key": Path(entry.path).stem,
570
+ "item_index": idx,
571
+ "metadata": {
572
+ "_item_index": idx,
573
+ "_chunk_relative_index": idx - unit.data["start_index"],
574
+ "_job_id": job_id,
575
+ "_filename": entry.path,
576
+ "_file_size": entry.size,
577
+ },
578
+ "job_id": job_id,
579
+ }
733
580
 
734
- if not image_data:
735
- logger.debug(
736
- "No image data found for item key=%s, available keys: %s",
737
- key,
738
- list(sample.keys()),
739
- )
740
- continue
581
+ processed_indices.append(idx)
741
582
 
742
- # Extract metadata (all non-system and non-image keys)
743
- metadata = {
744
- k: v
745
- for k, v in sample.items()
746
- if not k.startswith("__") and k not in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]
747
- }
583
+ if len(processed_indices) % 10 == 0:
584
+ gc.collect()
585
+
586
+ except StopIteration:
587
+ logger.warning(f"Unexpected end of shard at index {idx}")
588
+ break
589
+ except Exception as e:
590
+ logger.error(f"Error processing index {idx}: {e}")
591
+ continue
748
592
 
749
- # Add image format but not URLs
750
- if image_ext:
751
- metadata["_image_format"] = image_ext
593
+ except Exception as e:
594
+ logger.error(f"Error processing range {start_idx}-{end_idx}: {e}")
595
+ continue
752
596
 
753
- yield key, url, image_data, metadata
597
+ # Store processed indices for result
598
+ context["_processed_indices"] = processed_indices
599
+ logger.info(f"Processed {len(processed_indices)} items from unit {unit.unit_id}")
754
600
 
755
601
  def prepare_result(
756
602
  self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
757
603
  ) -> WorkResult:
758
- """Prepare WebDataset-specific result."""
759
- logger.debug("Preparing result for unit %s", unit.unit_id)
604
+ """Prepare result with processing details."""
760
605
  result = super().prepare_result(unit, outputs, processing_time_ms)
761
606
 
762
- # Add processed indices to metadata if available
607
+ # Add processed indices for chunk tracker
763
608
  if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
764
609
  result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
765
- logger.debug(
766
- "Added item_indices to result metadata: %s", result.metadata["item_indices"]
767
- )
768
610
 
769
611
  return result
770
612
 
771
613
  def get_dataset_info(self) -> Dict[str, Any]:
772
614
  """Get dataset information."""
773
- if self.dataset_loader:
774
- info = self.dataset_loader.get_dataset_info()
775
- logger.debug("Dataset info: %s", info)
776
- return info
777
- info = {
778
- "dataset_path": self.dataset_config.get("dataset_path"),
779
- "dataset_type": self.dataset_config.get("type", "huggingface"),
615
+ if self.dataset:
616
+ stats = self.dataset.get_stats()
617
+ return {
618
+ "dataset_name": self.dataset.name,
619
+ "format": self.dataset.dataset_format,
620
+ "total_shards": stats["total_shards"],
621
+ "total_files": stats.get("total_files", "Unknown"),
622
+ "mock_results": self.mock_results,
623
+ }
624
+ return {
625
+ "dataset_name": "Mock Dataset" if self.mock_results else "Unknown",
626
+ "mock_results": self.mock_results,
780
627
  }
781
- logger.debug("Dataset info (no loader): %s", info)
782
- return info
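To summarize the read path introduced in this release, the sketch below condenses how the 0.3.1 worker pulls samples through webshart. The calls mirror those in the diff above (dataset discovery, cache setup, `TarDataLoader`, `loader.shard()`, `next(loader)`, cv2 decode with BGR→RGB conversion); chunk bookkeeping and error handling are omitted, and the dataset source string is a placeholder, so treat this as an illustrative sketch rather than part of the package.

import io

import cv2
import numpy as np
import webshart
from PIL import Image

# Discover the dataset and enable the same caches the processors configure.
dataset = webshart.discover_dataset(source="path/or/url/to/webdataset", metadata=None)
dataset.enable_metadata_cache(location="./webshart_cache/metadata_cache")
dataset.enable_shard_cache(location="./webshart_cache/shard_cache", cache_limit_gb=10.0)

# TarDataLoader streams file entries out of the tar shards.
loader = webshart.TarDataLoader(dataset, buffer_size=10, load_file_data=True)

# Jump to a shard and cursor position, then read a few entries, decoding each
# image with cv2 and converting to RGB for PIL, as the worker does.
loader.shard(shard_idx=0, cursor_idx=0)
for _ in range(4):
    entry = next(loader)
    nparr = np.frombuffer(entry.data, np.uint8)
    img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_np is not None:
        image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
    else:
        image = Image.open(io.BytesIO(entry.data))  # fallback mirrors the worker's PIL path
    print(entry.path, entry.size, image.size)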