caption-flow 0.2.4-py3-none-any.whl → 0.3.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,104 +1,113 @@
1
- """WebDataset processor implementation."""
1
+ """WebDataset processor implementation using webshart TarDataLoader."""
2
2
 
3
3
  import logging
4
4
  import threading
5
+ import gc
6
+ import os
5
7
  from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
6
8
  from collections import deque, defaultdict
7
9
  from pathlib import Path
8
10
  import json
9
- import io
10
11
  from datetime import datetime
11
12
  from PIL import Image
12
- from caption_flow.storage import StorageManager
13
+ import io
13
14
 
15
+ from caption_flow.models import JobId
16
+ from caption_flow.storage import StorageManager
14
17
  from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
15
- from ..utils import DatasetLoader, ChunkTracker
18
+ from ..utils import ChunkTracker
19
+
20
+ import webshart
21
+ import cv2
22
+ import numpy as np
16
23
 
17
24
  logger = logging.getLogger(__name__)
18
25
  logger.setLevel(logging.INFO)
19
26
 
20
27
 
21
28
  class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
22
- """Orchestrator processor for WebDataset shards."""
29
+ """Orchestrator processor for WebDataset shards using webshart with ChunkTracker."""
23
30
 
24
31
  def __init__(self):
25
- logger.debug("Initializing WebDatasetOrchestratorProcessor")
26
- self.dataset_loader: Optional[DatasetLoader] = None
32
+ logger.info("Initializing WebDatasetOrchestratorProcessor with webshart + ChunkTracker")
33
+ self.dataset: Optional[webshart.DiscoveredDataset] = None
27
34
  self.chunk_tracker: Optional[ChunkTracker] = None
28
35
  self.chunk_size: int = 1000
29
36
 
30
37
  # Work unit management
31
38
  self.work_units: Dict[str, WorkUnit] = {}
32
39
  self.pending_units: Deque[str] = deque()
33
- self.assigned_units: Dict[str, Set[str]] = defaultdict(set) # worker_id -> unit_ids
40
+ self.assigned_units: Dict[str, Set[str]] = defaultdict(set)
34
41
  self.lock = threading.Lock()
35
42
 
36
- # Shard processing state
37
- self.all_shards: List[str] = []
38
- self.current_shard_index = 0
39
- self.current_shard_items = 0
43
+ # Shard info cache
44
+ self.shard_info_cache: Dict[int, Dict] = {}
40
45
 
41
46
  # Background thread for creating work units
42
47
  self.unit_creation_thread: Optional[threading.Thread] = None
43
48
  self.stop_creation = threading.Event()
49
+ self.min_buffer = 10
50
+ self.buffer_multiplier = 3
44
51
 
45
52
  def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
46
- """Initialize WebDataset processor."""
47
- logger.debug("Initializing orchestrator with config: %s", config.config)
48
- cfg = config.config
53
+ """Initialize with webshart dataset discovery and ChunkTracker."""
54
+ logger.info("Initializing orchestrator with config")
49
55
 
50
- # Dataset configuration
56
+ cfg = config.config
51
57
  dataset_cfg = cfg.get("dataset", {})
52
- dataset_path = dataset_cfg.get("dataset_path")
53
- dataset_type = dataset_cfg.get("dataset_type", "huggingface")
54
- dataset_split = dataset_cfg.get("dataset_split", "train")
55
- image_column = dataset_cfg.get("dataset_image_column", "image")
58
+ self.dataset_path = dataset_cfg.get("dataset_path")
59
+ metadata_path = dataset_cfg.get("metadata_path", None)
56
60
 
57
61
  # Chunk settings
58
62
  self.chunk_size = cfg.get("chunk_size", 1000)
59
63
  self.min_buffer = cfg.get("min_chunk_buffer", 10)
60
64
  self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)
61
65
 
62
- logger.debug(
63
- "Chunk size: %d, min_buffer: %d, buffer_multiplier: %d",
64
- self.chunk_size,
65
- self.min_buffer,
66
- self.buffer_multiplier,
67
- )
66
+ # Cache configuration
67
+ cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
68
+ cache_dir.mkdir(parents=True, exist_ok=True)
68
69
 
69
- # Initialize dataset loader
70
- if dataset_path:
71
- checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
72
- checkpoint_dir.mkdir(parents=True, exist_ok=True)
73
- logger.debug("Checkpoint dir: %s", checkpoint_dir)
74
-
75
- self.dataset_loader = DatasetLoader(
76
- dataset_path=dataset_path,
77
- dataset_type=dataset_type,
78
- split=dataset_split,
79
- image_column=image_column,
80
- cache_dir=checkpoint_dir,
70
+ if self.dataset_path:
71
+ # Initialize dataset with webshart
72
+ self.dataset = webshart.discover_dataset(
73
+ source=self.dataset_path,
74
+ metadata=metadata_path,
81
75
  )
82
- logger.debug("DatasetLoader initialized")
83
76
 
84
- self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
85
- logger.debug("ChunkTracker initialized at %s", checkpoint_dir / "chunks.json")
77
+ # Enable caching for efficient access
78
+ self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
79
+ self.dataset.enable_shard_cache(
80
+ location=str(cache_dir / "shard_cache"),
81
+ cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
82
+ )
83
+
84
+ logger.info(f"Dataset discovered: {self.dataset.num_shards} shards")
86
85
 
87
- # Get all shards
88
- self.all_shards = self.dataset_loader.get_shard_list()
89
- logger.debug("All shards: %s", self.all_shards)
86
+ # Initialize chunk tracker
87
+ checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
88
+ checkpoint_dir.mkdir(parents=True, exist_ok=True)
89
+ self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")
90
90
 
91
91
  # Restore existing state from chunk tracker
92
- self._restore_state(storage=storage)
92
+ self._restore_state(storage)
93
93
 
94
94
  # Start background unit creation
95
95
  self.unit_creation_thread = threading.Thread(
96
96
  target=self._create_units_background, daemon=True
97
97
  )
98
98
  self.unit_creation_thread.start()
99
- logger.debug("Unit creation thread started")
100
99
  else:
101
- logger.error("No dataset_path provided in config")
100
+ logger.error("No dataset_path provided")
101
+
102
+ def _get_shard_info_cached(self, shard_idx: int) -> Optional[Dict]:
103
+ """Get shard info with caching."""
104
+ if shard_idx not in self.shard_info_cache:
105
+ try:
106
+ self.shard_info_cache[shard_idx] = self.dataset.get_shard_info(shard_idx)
107
+ except Exception as e:
108
+ logger.error(f"Error getting shard info for idx {shard_idx}: {e}")
109
+ return None
110
+ return self.shard_info_cache[shard_idx]
102
111
 
103
112
  def _restore_state(self, storage: StorageManager) -> None:
104
113
  """Restore state from chunk tracker."""
@@ -107,39 +116,43 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
107
116
  return
108
117
 
109
118
  shards_summary = self.chunk_tracker.get_shards_summary()
110
-
111
- # Get all processed job_ids from storage
112
- all_processed_jobs = storage.get_all_processed_job_ids()
119
+ logger.debug(f"Restoring state: {shards_summary}")
113
120
 
114
121
  with self.lock:
115
122
  for shard_name, shard_info in shards_summary.items():
116
- for chunk_state in shard_info["chunks"]:
117
- # Calculate actual unprocessed ranges based on what's in storage
118
- chunk_range = (
119
- chunk_state.start_index,
120
- chunk_state.start_index + chunk_state.chunk_size - 1,
121
- )
122
-
123
- # Get processed indices for this chunk
124
- processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
125
- chunk_state.chunk_id, all_processed_jobs
126
- )
123
+ chunks = shard_info.get("chunks", [])
124
+ logger.debug(f"Existing job ids: {storage.get_all_processed_job_ids()}")
125
+ for chunk_state in chunks:
126
+ # Only add incomplete chunks
127
+ if chunk_state.status != "completed":
128
+ logger.debug(f"Restoring incomplete chunk {chunk_state}")
129
+
130
+ # Get unprocessed ranges
131
+ unprocessed_ranges = chunk_state.get_unprocessed_ranges()
132
+ logger.debug(
133
+ f"Chunk {chunk_state.chunk_id} unprocessed ranges: {unprocessed_ranges}"
134
+ )
135
+ if not unprocessed_ranges:
136
+ continue
127
137
 
128
- # Calculate unprocessed ranges
129
- unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
138
+ # Convert relative ranges to absolute file indices
139
+ absolute_ranges = []
140
+ for start, end in unprocessed_ranges:
141
+ abs_start = chunk_state.start_index + start
142
+ abs_end = chunk_state.start_index + end
143
+ absolute_ranges.append((abs_start, abs_end))
130
144
 
131
- if unprocessed_ranges:
132
- # Create work unit for unprocessed items
133
- logger.debug(f"Creating WorkUnit for chunk {chunk_state}")
134
145
  unit = WorkUnit(
135
146
  unit_id=chunk_state.chunk_id,
136
147
  chunk_id=chunk_state.chunk_id,
137
148
  source_id=shard_name,
149
+ unit_size=chunk_state.chunk_size,
138
150
  data={
139
151
  "shard_url": chunk_state.shard_url,
152
+ "shard_name": shard_name,
140
153
  "start_index": chunk_state.start_index,
141
154
  "chunk_size": chunk_state.chunk_size,
142
- "unprocessed_ranges": unprocessed_ranges,
155
+ "unprocessed_ranges": absolute_ranges,
143
156
  },
144
157
  metadata={
145
158
  "shard_name": shard_name,
@@ -154,11 +167,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
154
167
  """Background thread to create work units on demand."""
155
168
  logger.info("Starting work unit creation thread")
156
169
 
157
- shard_iter = iter(self.all_shards)
158
- current_shard_url = None
159
- current_shard_name = None
160
- current_shard_items = 0
161
- current_index = 0
170
+ current_shard_idx = 0
171
+ current_file_idx = 0
162
172
 
163
173
  while not self.stop_creation.is_set():
164
174
  # Check if we need more units
@@ -169,14 +179,6 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
169
179
 
170
180
  target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
171
181
  units_needed = max(0, target_buffer - (pending_count + assigned_count))
172
- logger.debug(
173
- "pending_count=%d assigned_count=%d worker_count=%d target_buffer=%d units_needed=%d",
174
- pending_count,
175
- assigned_count,
176
- worker_count,
177
- target_buffer,
178
- units_needed,
179
- )
180
182
 
181
183
  if units_needed == 0:
182
184
  threading.Event().wait(5)
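
The buffer arithmetic retained in this hunk (target_buffer and units_needed) is easier to read with concrete numbers. A worked example under the defaults set in initialize() (min_chunk_buffer=10, chunk_buffer_multiplier=3); the worker count and queue depths are illustrative.

```python
# Worked example of the target-buffer sizing in the unit-creation loop.
min_buffer = 10
buffer_multiplier = 3
worker_count = 4
pending_count, assigned_count = 5, 3

target_buffer = max(min_buffer, worker_count * buffer_multiplier)        # 12
units_needed = max(0, target_buffer - (pending_count + assigned_count))  # 12 - 8 = 4
print(target_buffer, units_needed)  # 12 4
```
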
@@ -184,318 +186,192 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
184
186
 
185
187
  # Create units as needed
186
188
  units_created = 0
187
-
188
189
  while units_created < units_needed and not self.stop_creation.is_set():
189
- # Load next shard if needed
190
- if current_shard_url is None or current_index >= current_shard_items:
191
- try:
192
- current_shard_url = next(shard_iter)
193
- current_shard_name = Path(current_shard_url).stem
190
+ # Get current shard info
191
+ if current_shard_idx >= self.dataset.num_shards:
192
+ logger.info("All shards processed")
193
+ break
194
194
 
195
- logger.debug("Loading shard: %s", current_shard_url)
196
- # Count items in shard
197
- current_shard_items = sum(
198
- 1 for _ in self.dataset_loader.iterate_shard(current_shard_url)
199
- )
200
- logger.info(
201
- f"Processing shard {current_shard_name} with {current_shard_items} items"
202
- )
203
- current_index = 0
204
-
205
- except StopIteration:
206
- logger.info("All shards processed")
207
- break
208
- except Exception as e:
209
- logger.error("Error loading shard: %s", e)
210
- break
211
-
212
- # Create work unit
213
- if current_shard_url and current_index < current_shard_items:
214
- chunk_size = min(self.chunk_size, current_shard_items - current_index)
215
- unit_id = f"{current_shard_name}:chunk:{current_index // self.chunk_size}"
216
-
217
- with self.lock:
218
- # Check if this unit already exists in work_units
219
- if unit_id in self.work_units:
220
- logger.debug(
221
- f"Unit {unit_id} already exists in work_units, skipping creation"
222
- )
223
- current_index += self.chunk_size
195
+ shard_info = self._get_shard_info_cached(current_shard_idx)
196
+ if not shard_info:
197
+ current_shard_idx += 1
198
+ current_file_idx = 0
199
+ continue
200
+
201
+ shard_name = shard_info["name"]
202
+ shard_files = shard_info["num_files"]
203
+
204
+ # Check if we need to move to next shard
205
+ if current_file_idx >= shard_files:
206
+ current_shard_idx += 1
207
+ current_file_idx = 0
208
+ continue
209
+
210
+ # Create chunk for current position
211
+ chunk_size = min(self.chunk_size, shard_files - current_file_idx)
212
+ self.current_chunk_index = current_file_idx // self.chunk_size
213
+ job_id_obj = JobId(
214
+ shard_id=shard_name,
215
+ chunk_id=self.current_chunk_index,
216
+ sample_id=current_file_idx,
217
+ )
218
+ chunk_id = job_id_obj.get_chunk_str()
219
+
220
+ with self.lock:
221
+ # Skip if already exists
222
+ if chunk_id in self.work_units:
223
+ current_file_idx += self.chunk_size
224
+ continue
225
+
226
+ # Check if chunk is already completed
227
+ if self.chunk_tracker:
228
+ chunk_state = self.chunk_tracker.chunks.get(chunk_id)
229
+ if chunk_state and chunk_state.status == "completed":
230
+ current_file_idx += self.chunk_size
224
231
  continue
225
232
 
226
- # Check if chunk is already completed or has no unprocessed items
227
- if self.chunk_tracker:
228
- chunk_state = self.chunk_tracker.chunks.get(unit_id)
229
-
230
- if chunk_state:
231
- # Check if completed
232
- if chunk_state.status == "completed":
233
- logger.debug(f"Unit {unit_id} already completed, skipping")
234
- current_index += self.chunk_size
235
- continue
236
-
237
- # Check if has unprocessed items
238
- unprocessed_ranges = chunk_state.get_unprocessed_ranges()
239
- if not unprocessed_ranges:
240
- logger.debug(
241
- f"Unit {unit_id} has no unprocessed items, skipping"
242
- )
243
- current_index += self.chunk_size
244
- continue
245
-
246
- # If chunk exists but has unprocessed items, use those ranges
247
- logger.debug(
248
- f"Existing chunk {unit_id} has unprocessed ranges: {unprocessed_ranges}"
249
- )
250
-
251
- unit = WorkUnit(
252
- unit_id=unit_id,
253
- chunk_id=unit_id,
254
- source_id=current_shard_name,
255
- data={
256
- "shard_url": current_shard_url,
257
- "start_index": current_index,
258
- "chunk_size": chunk_size,
259
- "unprocessed_ranges": [
260
- (
261
- r[0] + chunk_state.start_index,
262
- r[1] + chunk_state.start_index,
263
- )
264
- for r in unprocessed_ranges
265
- ], # Convert relative to absolute
266
- },
267
- metadata={
268
- "shard_name": current_shard_name,
269
- "chunk_index": current_index // self.chunk_size,
270
- },
271
- )
272
- else:
273
- # New chunk
274
- logger.debug(
275
- "Creating new work unit: unit_id=%s shard=%s start_index=%d chunk_size=%d",
276
- unit_id,
277
- current_shard_name,
278
- current_index,
279
- chunk_size,
280
- )
281
-
282
- unit = WorkUnit(
283
- unit_id=unit_id,
284
- chunk_id=unit_id,
285
- source_id=current_shard_name,
286
- data={
287
- "shard_url": current_shard_url,
288
- "start_index": current_index,
289
- "chunk_size": chunk_size,
290
- "unprocessed_ranges": [
291
- (current_index, current_index + chunk_size - 1)
292
- ],
293
- },
294
- metadata={
295
- "shard_name": current_shard_name,
296
- "chunk_index": current_index // self.chunk_size,
297
- },
298
- )
299
- else:
300
- # No chunk tracker, create normally
301
- unit = WorkUnit(
302
- unit_id=unit_id,
303
- chunk_id=unit_id,
304
- source_id=current_shard_name,
305
- data={
306
- "shard_url": current_shard_url,
307
- "start_index": current_index,
308
- "chunk_size": chunk_size,
309
- "unprocessed_ranges": [
310
- (current_index, current_index + chunk_size - 1)
311
- ],
312
- },
313
- metadata={
314
- "shard_name": current_shard_name,
315
- "chunk_index": current_index // self.chunk_size,
316
- },
317
- )
318
-
319
- self.work_units[unit_id] = unit
320
- self.pending_units.append(unit_id)
321
- logger.debug("Added work unit %s to pending_units", unit_id)
322
-
323
- if self.chunk_tracker:
324
- added_chunk = self.chunk_tracker.add_chunk(
325
- unit_id,
326
- current_shard_name,
327
- current_shard_url,
328
- current_index,
329
- chunk_size,
330
- )
331
- if added_chunk:
332
- logger.debug("Added chunk to chunk_tracker: %s", unit_id)
333
- else:
334
- logger.debug("Chunk already exists in chunk_tracker: %s", unit_id)
233
+ # Get shard URL (path for webshart)
234
+ shard_url = shard_info.get("path", f"{shard_name}.tar")
235
+
236
+ # Create work unit
237
+ unit = WorkUnit(
238
+ unit_id=chunk_id,
239
+ chunk_id=chunk_id,
240
+ source_id=shard_name,
241
+ unit_size=chunk_size,
242
+ data={
243
+ "shard_url": shard_url,
244
+ "shard_name": shard_name,
245
+ "shard_idx": current_shard_idx,
246
+ "start_index": current_file_idx,
247
+ "chunk_size": chunk_size,
248
+ "unprocessed_ranges": [
249
+ (current_file_idx, current_file_idx + chunk_size - 1)
250
+ ],
251
+ },
252
+ metadata={
253
+ "shard_name": shard_name,
254
+ "chunk_index": current_file_idx // self.chunk_size,
255
+ },
256
+ )
257
+
258
+ self.work_units[chunk_id] = unit
259
+ self.pending_units.append(chunk_id)
260
+
261
+ # Add to chunk tracker
262
+ if self.chunk_tracker:
263
+ self.chunk_tracker.add_chunk(
264
+ chunk_id, shard_name, shard_url, current_file_idx, chunk_size
265
+ )
335
266
 
336
- units_created += 1
267
+ units_created += 1
337
268
 
338
- current_index += self.chunk_size
269
+ current_file_idx += self.chunk_size
339
270
 
340
271
  if units_created > 0:
341
272
  logger.debug(f"Created {units_created} work units")
342
273
 
274
+ logger.info("Work unit creation thread exiting")
275
+
343
276
  def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
344
277
  """Get available work units for a worker."""
345
- logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
346
278
  assigned = []
347
279
 
348
280
  with self.lock:
349
- # Get new units if needed
350
281
  while len(assigned) < count and self.pending_units:
351
282
  unit_id = self.pending_units.popleft()
352
283
  unit = self.work_units.get(unit_id)
353
284
 
354
285
  if unit:
286
+ # Update unprocessed ranges from chunk tracker before assigning
287
+ if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
288
+ chunk_state = self.chunk_tracker.chunks[unit_id]
289
+ relative_unprocessed = chunk_state.get_unprocessed_ranges()
290
+
291
+ # Convert relative to absolute indices
292
+ absolute_ranges = []
293
+ for start, end in relative_unprocessed:
294
+ abs_start = chunk_state.start_index + start
295
+ abs_end = chunk_state.start_index + end
296
+ absolute_ranges.append((abs_start, abs_end))
297
+
298
+ # Update the work unit's unprocessed ranges
299
+ unit.data["unprocessed_ranges"] = absolute_ranges
300
+
301
+ logger.debug(
302
+ f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
303
+ )
304
+
355
305
  self.assigned_units[worker_id].add(unit_id)
356
306
  assigned.append(unit)
357
- logger.debug("Assigning new unit %s to worker %s", unit_id, worker_id)
358
307
 
359
308
  if self.chunk_tracker:
360
309
  self.chunk_tracker.mark_assigned(unit_id, worker_id)
361
310
 
362
- logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
311
+ logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
363
312
  return assigned
364
313
 
365
- def _has_unprocessed_items(self, unit: WorkUnit) -> bool:
366
- """Check if a work unit has unprocessed items."""
367
- if not self.chunk_tracker:
368
- logger.debug("No chunk_tracker, assuming unit %s has unprocessed items", unit.unit_id)
369
- return True
370
-
371
- chunk_info = self.chunk_tracker.get_chunk_with_unprocessed_items(unit.unit_id)
372
- has_unprocessed = bool(chunk_info and chunk_info.get("unprocessed_ranges"))
373
- logger.debug("Unit %s has unprocessed items: %s", unit.unit_id, has_unprocessed)
374
- return has_unprocessed
375
-
376
314
  def mark_completed(self, unit_id: str, worker_id: str) -> None:
377
315
  """Mark a work unit as completed."""
378
- logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
379
316
  with self.lock:
380
317
  if unit_id in self.work_units:
381
318
  self.assigned_units[worker_id].discard(unit_id)
382
- logger.debug(
383
- "Removed unit %s from assigned_units for worker %s", unit_id, worker_id
384
- )
385
319
 
386
320
  if self.chunk_tracker:
387
321
  self.chunk_tracker.mark_completed(unit_id)
388
- logger.debug("Marked unit %s as completed in chunk_tracker", unit_id)
322
+
323
+ # Remove from memory
324
+ del self.work_units[unit_id]
389
325
 
390
326
  def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
391
327
  """Mark a work unit as failed."""
392
- logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
328
+ logger.error(f"Unit {unit_id} failed on {worker_id}: {error}")
393
329
  with self.lock:
394
330
  if unit_id in self.work_units:
395
331
  self.assigned_units[worker_id].discard(unit_id)
396
332
  self.pending_units.append(unit_id)
397
- logger.debug("Returned unit %s to pending_units", unit_id)
398
333
 
399
334
  if self.chunk_tracker:
400
335
  self.chunk_tracker.mark_failed(unit_id)
401
- logger.debug("Marked unit %s as failed in chunk_tracker", unit_id)
402
336
 
403
337
  def release_assignments(self, worker_id: str) -> None:
404
338
  """Release all assignments for a disconnected worker."""
405
- logger.debug("Releasing assignments for worker %s", worker_id)
406
339
  with self.lock:
407
340
  unit_ids = list(self.assigned_units.get(worker_id, []))
408
341
 
409
342
  for unit_id in unit_ids:
410
- if unit_id in self.work_units:
411
- unit = self.work_units[unit_id]
412
-
413
- # Update unprocessed ranges based on what's been processed
414
- if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
415
- chunk_state = self.chunk_tracker.chunks[unit_id]
343
+ if unit_id in self.work_units and self.chunk_tracker:
344
+ # Get updated unprocessed ranges from chunk tracker
345
+ chunk_state = self.chunk_tracker.chunks.get(unit_id)
346
+ if chunk_state:
416
347
  unprocessed_ranges = chunk_state.get_unprocessed_ranges()
417
-
418
- # Convert relative ranges back to absolute
348
+ # Convert relative to absolute
419
349
  absolute_ranges = []
420
350
  for start, end in unprocessed_ranges:
421
351
  abs_start = chunk_state.start_index + start
422
352
  abs_end = chunk_state.start_index + end
423
353
  absolute_ranges.append((abs_start, abs_end))
424
354
 
425
- # Update the work unit's data
426
- unit.data["unprocessed_ranges"] = absolute_ranges
427
-
428
- logger.debug(
429
- f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
430
- )
355
+ # Update work unit
356
+ self.work_units[unit_id].data["unprocessed_ranges"] = absolute_ranges
431
357
 
432
358
  self.pending_units.append(unit_id)
433
- logger.debug("Returned unit %s to pending_units", unit_id)
434
359
 
435
360
  if worker_id in self.assigned_units:
436
361
  del self.assigned_units[worker_id]
437
- logger.debug("Deleted worker %s from assigned_units", worker_id)
438
362
 
439
363
  if self.chunk_tracker:
440
364
  self.chunk_tracker.release_worker_chunks(worker_id)
441
- logger.debug("Released worker %s chunks in chunk_tracker", worker_id)
442
-
443
- def _subtract_ranges(
444
- self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
445
- ) -> List[Tuple[int, int]]:
446
- """Subtract processed ranges from total ranges."""
447
- if not processed_ranges:
448
- return total_ranges
449
-
450
- # Create a set of all processed indices
451
- processed_indices = set()
452
- for start, end in processed_ranges:
453
- processed_indices.update(range(start, end + 1))
454
-
455
- # Find unprocessed ranges
456
- unprocessed_ranges = []
457
- for start, end in total_ranges:
458
- current_start = None
459
- for i in range(start, end + 1):
460
- if i not in processed_indices:
461
- if current_start is None:
462
- current_start = i
463
- else:
464
- if current_start is not None:
465
- unprocessed_ranges.append((current_start, i - 1))
466
- current_start = None
467
-
468
- if current_start is not None:
469
- unprocessed_ranges.append((current_start, end))
470
-
471
- return unprocessed_ranges
472
365
 
473
- def get_stats(self) -> Dict[str, Any]:
474
- """Get processor statistics."""
475
- with self.lock:
476
- stats = {
477
- "total_units": len(self.work_units),
478
- "pending_units": len(self.pending_units),
479
- "assigned_units": sum(len(units) for units in self.assigned_units.values()),
480
- "total_shards": len(self.all_shards),
481
- "workers": len(self.assigned_units),
482
- }
483
- logger.debug("Stats: %s", stats)
484
- return stats
366
+ logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
485
367
 
486
368
  def handle_result(self, result: WorkResult) -> Dict[str, Any]:
487
- """Handle WebDataset-specific result processing."""
488
- # logger.debug("Handling result for unit %s", result.unit_id)
489
- base_result = super().handle_result(result)
490
-
369
+ """Handle result from worker."""
491
370
  # Track processed items if we have chunk tracker
492
- if self.chunk_tracker:
493
- if "item_indices" not in result.metadata:
494
- result.metadata["item_indices"] = [result.metadata.get("_item_index")]
371
+ if self.chunk_tracker and "item_indices" in result.metadata:
495
372
  indices = result.metadata["item_indices"]
496
- logger.debug("Result metadata item_indices: %s", indices)
497
373
 
498
- # Group consecutive indices into ranges
374
+ # Convert to ranges and mark as processed
499
375
  if indices:
500
376
  indices.sort()
501
377
  ranges = []
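
The rewritten creation loop above carves each shard's file count into fixed-size chunks and derives string ids via JobId. A pure-Python sketch of that bookkeeping follows; the helper name and the "data-0000" shard are illustrative, and it assumes (as the parsing in update_from_storage below implies) that JobId.get_chunk_str() yields the "<shard>:chunk:<n>" prefix of the "<shard>:chunk:<n>:idx:<i>" job ids built by the worker.

```python
# Sketch of the chunk-carving arithmetic used by _create_units_background().
def carve_chunks(shard_name: str, num_files: int, chunk_size: int = 1000):
    units = []
    file_idx = 0
    while file_idx < num_files:
        size = min(chunk_size, num_files - file_idx)  # last chunk may be short
        chunk_index = file_idx // chunk_size
        chunk_id = f"{shard_name}:chunk:{chunk_index}"
        units.append((chunk_id, file_idx, size))
        file_idx += chunk_size
    return units

units = carve_chunks("data-0000", 2500)
# [('data-0000:chunk:0', 0, 1000), ('data-0000:chunk:1', 1000, 1000),
#  ('data-0000:chunk:2', 2000, 500)]
job_id = f"{units[2][0]}:idx:2042"  # 'data-0000:chunk:2:idx:2042'
```
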
@@ -514,269 +390,293 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
514
390
 
515
391
  # Mark ranges as processed
516
392
  for start_idx, end_idx in ranges:
517
- logger.debug(f"Marking chunk as processed: {result.to_repr()}")
518
393
  self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
519
- logger.debug(
520
- "Marked items processed for unit %s: %d-%d",
521
- result.unit_id,
522
- start_idx,
523
- end_idx,
524
- )
525
- else:
526
- logger.error(
527
- f"No chunk tracker? {self.chunk_tracker} or no item_indices in {result.metadata}"
528
- )
529
394
 
530
- return base_result
395
+ return {
396
+ "source_id": result.source_id,
397
+ "chunk_id": result.chunk_id,
398
+ "outputs": result.outputs,
399
+ "metadata": result.metadata,
400
+ }
531
401
 
532
402
  def update_from_storage(self, processed_job_ids: Set[str]) -> None:
533
403
  """Update work units based on what's been processed."""
534
- logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")
404
+ logger.info(f"Updating from {len(processed_job_ids)} processed jobs")
535
405
 
536
406
  with self.lock:
537
- for unit_id, unit in self.work_units.items():
538
- # Extract chunk info from unit
539
- start_index = unit.data["start_index"]
540
- chunk_size = unit.data["chunk_size"]
541
- shard_name = unit.metadata["shard_name"]
542
- chunk_index = unit.metadata["chunk_index"]
543
-
544
- # Find processed indices for this chunk
545
- processed_indices = []
546
- for job_id in processed_job_ids:
547
- # Parse job_id format: "data-0000:chunk:0:idx:42"
548
- parts = job_id.split(":")
549
- if (
550
- len(parts) == 5
551
- and parts[0] == shard_name
552
- and parts[1] == "chunk"
553
- and int(parts[2]) == chunk_index
554
- and parts[3] == "idx"
555
- ):
556
-
407
+ # Group by chunk
408
+ processed_by_chunk = defaultdict(set)
409
+
410
+ for job_id in processed_job_ids:
411
+ # Parse job_id to extract chunk and index
412
+ # Expected format: "shard:chunk:X:idx:Y"
413
+ parts = job_id.split(":")
414
+ if len(parts) >= 5 and parts[3] == "idx":
415
+ chunk_id = ":".join(parts[:3]) # "shard:chunk:X"
416
+ try:
557
417
  idx = int(parts[4])
558
- if start_index <= idx < start_index + chunk_size:
559
- processed_indices.append(idx)
418
+ processed_by_chunk[chunk_id].add(idx)
419
+ except ValueError:
420
+ continue
560
421
 
561
- if processed_indices:
562
- # Convert to ranges
563
- processed_indices.sort()
564
- processed_ranges = []
565
- start = processed_indices[0]
566
- end = processed_indices[0]
567
-
568
- for idx in processed_indices[1:]:
569
- if idx == end + 1:
570
- end = idx
571
- else:
572
- processed_ranges.append((start, end))
573
- start = idx
574
- end = idx
575
-
576
- processed_ranges.append((start, end))
577
-
578
- # Calculate unprocessed ranges
579
- total_range = [(start_index, start_index + chunk_size - 1)]
580
- unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)
581
-
582
- # Update unit
583
- unit.data["unprocessed_ranges"] = unprocessed_ranges
584
-
585
- logger.debug(
586
- f"Updated unit {unit_id}: {len(processed_indices)} processed, "
587
- f"unprocessed ranges: {unprocessed_ranges}"
588
- )
422
+ # Update chunk tracker with processed items
423
+ if self.chunk_tracker:
424
+ for chunk_id, indices in processed_by_chunk.items():
425
+ if indices:
426
+ # Sort indices and convert to ranges
427
+ sorted_indices = sorted(indices)
428
+ if not sorted_indices:
429
+ continue
589
430
 
431
+ # Condense into contiguous ranges
432
+ ranges = []
433
+ start_range = sorted_indices[0]
434
+ end_range = sorted_indices[0]
590
435
 
591
- class WebDatasetWorkerProcessor(WorkerProcessor):
592
- """Worker processor for WebDataset shards."""
436
+ for i in range(1, len(sorted_indices)):
437
+ if sorted_indices[i] == end_range + 1:
438
+ end_range = sorted_indices[i]
439
+ else:
440
+ ranges.append((start_range, end_range))
441
+ start_range = sorted_indices[i]
442
+ end_range = sorted_indices[i]
443
+ ranges.append((start_range, end_range))
593
444
 
594
- def __init__(self):
595
- logger.debug("Initializing WebDatasetWorkerProcessor")
596
- self.dataset_loader: Optional[DatasetLoader] = None
597
- self.dataset_config: Dict[str, Any] = {}
598
- self.dataset_name: Optional[str] = None
445
+ # Mark each contiguous range as processed
446
+ logger.debug(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
447
+ for start_idx, end_idx in ranges:
448
+ self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
599
449
 
600
- def initialize(self, config: ProcessorConfig) -> None:
601
- """Initialize WebDataset processor."""
602
- logger.debug("Initializing worker with config: %s", config.config)
603
- cfg = config.config["dataset"]
604
-
605
- # Store config
606
- self.dataset_config = cfg
607
-
608
- # Initialize dataset loader
609
- dataset_path = cfg.get("dataset_path")
610
- self.dataset_path = dataset_path
611
- dataset_type = cfg.get("dataset_type", "huggingface")
612
- dataset_split = cfg.get("dataset_split", "train")
613
- image_column = cfg.get("dataset_image_column", "image")
614
-
615
- if dataset_path:
616
- self.dataset_loader = DatasetLoader(
617
- dataset_path=dataset_path,
618
- dataset_type=dataset_type,
619
- split=dataset_split,
620
- image_column=image_column,
621
- )
622
- logger.debug("DatasetLoader initialized for worker")
623
- else:
624
- logger.error("No dataset_path provided in worker config")
450
+ def get_stats(self) -> Dict[str, Any]:
451
+ """Get processor statistics."""
452
+ with self.lock:
453
+ # Get chunk tracker stats if available
454
+ if self.chunk_tracker:
455
+ shards_summary = self.chunk_tracker.get_shards_summary()
456
+ total_chunks = sum(len(s.get("chunks", [])) for s in shards_summary.values())
457
+ completed_chunks = sum(
458
+ 1
459
+ for s in shards_summary.values()
460
+ for c in s.get("chunks", [])
461
+ if c.status == "completed"
462
+ )
463
+ else:
464
+ total_chunks = len(self.work_units)
465
+ completed_chunks = 0
625
466
 
626
- def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
627
- """Process a WebDataset chunk, yielding items to be captioned."""
628
- logger.debug("Processing unit: %s", unit.unit_id)
629
- if not self.dataset_loader:
630
- logger.error("Dataset loader not initialized")
631
- return
467
+ return {
468
+ "total_shards": self.dataset.num_shards if self.dataset else 0,
469
+ "total_chunks": total_chunks,
470
+ "pending_units": len(self.pending_units),
471
+ "assigned_units": sum(len(units) for units in self.assigned_units.values()),
472
+ "completed_chunks": completed_chunks,
473
+ "workers": len(self.assigned_units),
474
+ }
632
475
 
633
- shard_name = unit.metadata["shard_name"]
634
- chunk_index = unit.metadata["chunk_index"]
635
- shard_url = unit.data["shard_url"]
636
- start_index = unit.data["start_index"]
637
- chunk_size = unit.data["chunk_size"]
638
- unprocessed_ranges = unit.data.get(
639
- "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
640
- )
476
+ def cleanup(self):
477
+ """Clean up resources."""
478
+ logger.info("Cleaning up orchestrator")
641
479
 
642
- logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")
480
+ # Stop background threads
481
+ self.stop_creation.set()
482
+ if self.unit_creation_thread:
483
+ self.unit_creation_thread.join(timeout=5)
643
484
 
644
- # Create set of indices to process
645
- indices_to_process = set()
646
- for start, end in unprocessed_ranges:
647
- indices_to_process.update(range(start, end + 1))
648
- logger.debug("Indices to process: %s", indices_to_process)
485
+ # Save checkpoint
486
+ if self.chunk_tracker:
487
+ self.chunk_tracker.save_checkpoint()
649
488
 
650
- processed_indices = []
651
489
 
652
- # Iterate through shard
653
- for idx, (key, url, image_data, metadata) in enumerate(
654
- self._iterate_shard_with_metadata(shard_url)
655
- ):
656
- # Skip if not in our chunk range
657
- if idx < start_index or idx >= start_index + chunk_size:
658
- # logger.debug(f"Skipping idx={idx} not in chunk range")
659
- continue
490
+ class WebDatasetWorkerProcessor(WorkerProcessor):
491
+ """Worker processor for WebDataset shards using webshart."""
660
492
 
661
- # Skip if already processed
662
- if idx not in indices_to_process:
663
- logger.debug(f"Skipping idx={idx} already processed")
664
- continue
493
+ def __init__(self):
494
+ logger.info("Initializing WebDatasetWorkerProcessor with webshart")
495
+ self.loader: Optional[webshart.TarDataLoader] = None
496
+ self.dataset: Optional[webshart.DiscoveredDataset] = None
497
+ self.mock_results = False
665
498
 
666
- try:
667
- # Load image
668
- image = Image.open(io.BytesIO(image_data))
669
- job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
670
-
671
- # Clean metadata - remove sensitive and redundant fields
672
- clean_metadata = {
673
- k: v
674
- for k, v in metadata.items()
675
- if k not in ["url", "_shard_url", "shard_name"] # Remove these fields
676
- }
677
-
678
- # Add only necessary index information
679
- clean_metadata.update(
680
- {
681
- "_item_index": idx,
682
- "_chunk_relative_index": idx - start_index,
683
- "_job_id": job_id,
684
- }
685
- )
499
+ def initialize(self, config: ProcessorConfig) -> None:
500
+ """Initialize worker with webshart loader."""
501
+ cfg = config.config
502
+ dataset_cfg = cfg.get("dataset", {})
686
503
 
687
- # Prepare item for captioning
688
- # logger.debug("Yielding item idx=%d key=%s", idx, key)
689
- yield {
690
- "image": image,
691
- "item_key": key,
692
- "item_index": idx,
693
- "metadata": clean_metadata,
694
- "job_id": job_id,
695
- }
504
+ self.dataset_path = dataset_cfg.get("dataset_path")
505
+ metadata_path = dataset_cfg.get("metadata_path", None)
506
+ self.mock_results = dataset_cfg.get("mock_results", False)
696
507
 
697
- processed_indices.append(idx)
508
+ # Cache configuration
509
+ cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
510
+ cache_dir.mkdir(parents=True, exist_ok=True)
698
511
 
699
- except Exception as e:
700
- logger.error(f"Error processing item {key}: {e}")
512
+ if self.dataset_path and not self.mock_results:
513
+ # Discover dataset
514
+ self.dataset = webshart.discover_dataset(
515
+ source=self.dataset_path,
516
+ metadata=metadata_path,
517
+ )
701
518
 
702
- # Store processed indices in context for result preparation
703
- context["_processed_indices"] = processed_indices
704
- logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)
519
+ # Enable caching
520
+ self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
521
+ self.dataset.enable_shard_cache(
522
+ location=str(cache_dir / "shard_cache"),
523
+ cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
524
+ )
705
525
 
706
- def _iterate_shard_with_metadata(
707
- self, shard_url: str
708
- ) -> Iterator[Tuple[str, str, bytes, Dict]]:
709
- """Iterate through a shard with metadata."""
710
- logger.debug("Iterating shard with metadata: %s", shard_url)
526
+ # Create loader
527
+ self.loader = webshart.TarDataLoader(
528
+ self.dataset,
529
+ buffer_size=cfg.get("buffer_size", 10),
530
+ max_file_size=cfg.get("max_file_size", 100 * 1024 * 1024),
531
+ load_file_data=True,
532
+ )
711
533
 
712
- if not self.dataset_loader:
713
- logger.error("Dataset loader not initialized")
714
- return
534
+ logger.info("webshart TarDataLoader initialized")
715
535
 
716
- # Use the DatasetLoader that returns full samples
717
- for sample in self.dataset_loader.iterate_shard(shard_url):
718
- if not isinstance(sample, dict):
719
- logger.warning("Unexpected sample format: %s", type(sample))
720
- continue
536
+ def _create_mock_image(self, idx: int) -> Image.Image:
537
+ """Create a dummy test image."""
538
+ color = ((idx * 37) % 256, (idx * 53) % 256, (idx * 71) % 256)
539
+ image = Image.new("RGB", (256, 256), color=color)
540
+ return image
721
541
 
722
- key = sample.get("__key__", "unknown")
723
- url = sample.get("__url__", "") # Don't use shard_url as default
542
+ def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
543
+ """Process a work unit by iterating specified ranges."""
544
+ logger.debug(f"Processing unit: {unit}")
724
545
 
725
- # Find image data
726
- image_data = None
727
- image_ext = None
728
- for ext in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]:
729
- if ext in sample:
730
- image_data = sample[ext]
731
- image_ext = ext
732
- break
546
+ shard_name = unit.data["shard_name"]
547
+ shard_idx = unit.data.get("shard_idx")
548
+ unprocessed_ranges = unit.data.get("unprocessed_ranges", [])
733
549
 
734
- if not image_data:
735
- logger.debug(
736
- "No image data found for item key=%s, available keys: %s",
737
- key,
738
- list(sample.keys()),
739
- )
740
- continue
550
+ # For chunk tracking
551
+ chunk_index = unit.metadata.get("chunk_index", 0)
552
+ processed_indices = []
741
553
 
742
- # Extract metadata (all non-system and non-image keys)
743
- metadata = {
744
- k: v
745
- for k, v in sample.items()
746
- if not k.startswith("__") and k not in ["jpg", "jpeg", "png", "webp", "bmp", "jxl"]
747
- }
554
+ if self.mock_results:
555
+ # Generate mock results for unprocessed ranges
556
+ for start_idx, end_idx in unprocessed_ranges:
557
+ for idx in range(start_idx, end_idx + 1):
558
+ job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
559
+
560
+ yield {
561
+ "image": self._create_mock_image(idx),
562
+ "image_data": None,
563
+ "item_key": f"mock_{idx}",
564
+ "item_index": idx,
565
+ "metadata": {
566
+ "_item_index": idx,
567
+ "_chunk_relative_index": idx - unit.data["start_index"],
568
+ "_job_id": job_id,
569
+ "_mock": True,
570
+ "_processed_indices": processed_indices,
571
+ },
572
+ "job_id": job_id,
573
+ }
748
574
 
749
- # Add image format but not URLs
750
- if image_ext:
751
- metadata["_image_format"] = image_ext
575
+ processed_indices.append(idx)
576
+ else:
577
+ # Use webshart to process unprocessed ranges
578
+ for start_idx, end_idx in unprocessed_ranges:
579
+ try:
580
+ # Jump to shard and starting position
581
+ if shard_idx is not None:
582
+ self.loader.shard(shard_idx=shard_idx, cursor_idx=start_idx)
583
+ else:
584
+ # Try to find shard by name
585
+ self.loader.shard(filename=shard_name, cursor_idx=start_idx)
586
+
587
+ # Iterate through the range
588
+ for idx in range(start_idx, end_idx + 1):
589
+ try:
590
+ entry = next(self.loader)
591
+
592
+ # Decode image
593
+ image = None
594
+ if entry.data:
595
+ try:
596
+ # Use cv2 to decode from memory
597
+ nparr = np.frombuffer(entry.data, np.uint8)
598
+ img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
599
+
600
+ if img_np is not None:
601
+ # Convert from BGR (OpenCV default) to RGB (PIL default)
602
+ img_rgb = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
603
+ image = Image.fromarray(img_rgb)
604
+ else:
605
+ logger.warning(f"cv2.imdecode failed for {entry.path}")
606
+
607
+ except ImportError:
608
+ logger.warning(
609
+ "cv2 or numpy not installed, falling back to PIL"
610
+ )
611
+ image = Image.open(io.BytesIO(entry.data))
612
+ except Exception as img_e:
613
+ logger.error(
614
+ f"Error decoding image {entry.path} with cv2: {img_e}"
615
+ )
752
616
 
753
- yield key, url, image_data, metadata
617
+ # Generate job ID compatible with chunk tracker
618
+ job_id = f"{shard_name}:chunk:{chunk_index}:idx:{idx}"
619
+
620
+ yield {
621
+ "image": image,
622
+ "image_data": entry.data,
623
+ "item_key": Path(entry.path).stem,
624
+ "item_index": idx,
625
+ "metadata": {
626
+ "_item_index": idx,
627
+ "_chunk_relative_index": idx - unit.data["start_index"],
628
+ "_job_id": job_id,
629
+ "_filename": entry.path,
630
+ "_file_size": entry.size,
631
+ "_processed_indices": processed_indices,
632
+ },
633
+ "job_id": job_id,
634
+ }
635
+
636
+ processed_indices.append(idx)
637
+
638
+ if len(processed_indices) % 10 == 0:
639
+ gc.collect()
640
+
641
+ except StopIteration:
642
+ logger.warning(f"Unexpected end of shard at index {idx}")
643
+ break
644
+ except Exception as e:
645
+ logger.error(f"Error processing index {idx}: {e}")
646
+ continue
647
+
648
+ except Exception as e:
649
+ logger.error(f"Error processing range {start_idx}-{end_idx}: {e}")
650
+ continue
651
+
652
+ # Store processed indices for result
653
+ context["_processed_indices"] = processed_indices
654
+ logger.info(f"Processed {len(processed_indices)} items from unit {unit.unit_id}")
754
655
 
755
656
  def prepare_result(
756
657
  self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
757
658
  ) -> WorkResult:
758
- """Prepare WebDataset-specific result."""
759
- logger.debug("Preparing result for unit %s", unit.unit_id)
659
+ """Prepare result with processing details."""
760
660
  result = super().prepare_result(unit, outputs, processing_time_ms)
761
661
 
762
- # Add processed indices to metadata if available
763
- if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
764
- result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]
765
- logger.debug(
766
- "Added item_indices to result metadata: %s", result.metadata["item_indices"]
767
- )
662
+ # Add processed indices for chunk tracker
663
+ if hasattr(self, "_last_context") and "_processed_indices" in self._last_context:
664
+ result.metadata["item_indices"] = self._last_context["_processed_indices"]
768
665
 
769
666
  return result
770
667
 
771
668
  def get_dataset_info(self) -> Dict[str, Any]:
772
669
  """Get dataset information."""
773
- if self.dataset_loader:
774
- info = self.dataset_loader.get_dataset_info()
775
- logger.debug("Dataset info: %s", info)
776
- return info
777
- info = {
778
- "dataset_path": self.dataset_config.get("dataset_path"),
779
- "dataset_type": self.dataset_config.get("type", "huggingface"),
670
+ if self.dataset:
671
+ stats = self.dataset.get_stats()
672
+ return {
673
+ "dataset_name": self.dataset.name,
674
+ "format": self.dataset.dataset_format,
675
+ "total_shards": stats["total_shards"],
676
+ "total_files": stats.get("total_files", "Unknown"),
677
+ "mock_results": self.mock_results,
678
+ }
679
+ return {
680
+ "dataset_name": "Mock Dataset" if self.mock_results else "Unknown",
681
+ "mock_results": self.mock_results,
780
682
  }
781
- logger.debug("Dataset info (no loader): %s", info)
782
- return info
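
Two sketches round out the changes above. First, the worker read path: process_unit() seeks the TarDataLoader to a shard and cursor, pulls entries with next(), and decodes the bytes with OpenCV (BGR) before handing PIL images downstream. The sketch reuses only the calls shown in this diff; the dataset source is a placeholder and the snippet assumes a reachable dataset, with cache setup omitted for brevity.

```python
# Sketch of the 0.3.2 worker read path: seek, read, decode.
import io
import cv2
import numpy as np
import webshart
from PIL import Image

dataset = webshart.discover_dataset(source="/data/my-webdataset-shards", metadata=None)
loader = webshart.TarDataLoader(
    dataset,
    buffer_size=10,
    max_file_size=100 * 1024 * 1024,
    load_file_data=True,
)

# Jump to shard 0 at file index 2000, then read a few entries.
loader.shard(shard_idx=0, cursor_idx=2000)
for _ in range(3):
    entry = next(loader)
    nparr = np.frombuffer(entry.data, np.uint8)
    img_np = cv2.imdecode(nparr, cv2.IMREAD_COLOR)  # BGR ndarray, or None on failure
    if img_np is not None:
        image = Image.fromarray(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
    else:
        image = Image.open(io.BytesIO(entry.data))  # PIL fallback
    print(entry.path, entry.size, image.size)
```
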
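
Second, the range bookkeeping the orchestrator leans on throughout: handle_result and update_from_storage condense processed sample indices into contiguous (start, end) ranges, while get_work_units and release_assignments shift the chunk tracker's relative ranges back to absolute file indices by adding the chunk's start_index. A self-contained sketch of both operations; the helper names are mine, the logic mirrors the diff.

```python
# Pure-Python sketch of the two range operations implied by the orchestrator code.
from typing import List, Tuple

def condense(indices: List[int]) -> List[Tuple[int, int]]:
    """Group indices into inclusive contiguous (start, end) ranges."""
    if not indices:
        return []
    ordered = sorted(indices)
    ranges = []
    start = end = ordered[0]
    for idx in ordered[1:]:
        if idx == end + 1:
            end = idx
        else:
            ranges.append((start, end))
            start = end = idx
    ranges.append((start, end))
    return ranges

def to_absolute(relative: List[Tuple[int, int]], start_index: int) -> List[Tuple[int, int]]:
    """Shift chunk-relative ranges to absolute dataset indices."""
    return [(start_index + s, start_index + e) for s, e in relative]

print(condense([2000, 2001, 2002, 2010]))         # [(2000, 2002), (2010, 2010)]
print(to_absolute([(0, 149), (600, 999)], 2000))  # [(2000, 2149), (2600, 2999)]
```
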