caption_flow-0.3.3-py3-none-any.whl → caption_flow-0.4.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +937 -416
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +5 -3
  5. caption_flow/orchestrator.py +186 -116
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +440 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +66 -25
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +41 -19
  18. caption_flow/utils/chunk_tracker.py +200 -65
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +12 -6
  25. caption_flow/workers/caption.py +272 -91
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.3.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.3.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
--- a/caption_flow/processors/local_filesystem.py
+++ b/caption_flow/processors/local_filesystem.py
@@ -1,27 +1,27 @@
 """Local filesystem datasets processor implementation."""
 
+import asyncio
+import io
 import logging
-import threading
+import mimetypes
 import os
-from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
-from collections import deque, defaultdict
+import threading
+from collections import defaultdict, deque
 from pathlib import Path
-import json
-import io
-import mimetypes
-from datetime import datetime
-from PIL import Image
+from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
+
 import aiofiles
-from fastapi import FastAPI, HTTPException, Response
-from fastapi.responses import StreamingResponse
-import uvicorn
-import asyncio
 import requests
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from PIL import Image
 
 from caption_flow.storage import StorageManager
-from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import ChunkTracker
+
 from ..models import JobId
+from ..utils import ChunkTracker
+from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -217,23 +217,19 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
         if not self.chunk_tracker:
             return
 
-        all_processed_jobs = storage.get_all_processed_job_ids()
+        storage.get_all_processed_job_ids()
 
         with self.lock:
             for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
-                # Calculate actual unprocessed ranges
-                chunk_range = (
-                    chunk_state.start_index,
-                    chunk_state.start_index + chunk_state.chunk_size - 1,
-                )
-
-                # Get processed indices for this chunk
-                processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
-                    chunk_id, all_processed_jobs
-                )
+                # Get unprocessed ranges (relative coordinates from ChunkTracker)
+                relative_unprocessed_ranges = chunk_state.get_unprocessed_ranges()
 
-                # Calculate unprocessed ranges
-                unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
+                # Convert relative ranges to absolute ranges
+                unprocessed_ranges = []
+                for start, end in relative_unprocessed_ranges:
+                    abs_start = chunk_state.start_index + start
+                    abs_end = chunk_state.start_index + end
+                    unprocessed_ranges.append((abs_start, abs_end))
 
                 if unprocessed_ranges:
                     # Create work unit for unprocessed items
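For context, the tracker now reports unprocessed ranges relative to each chunk's start, so the orchestrator only shifts them by start_index. A minimal standalone sketch of that conversion, assuming inclusive (start, end) pairs as the hunk above implies (the helper name is illustrative, not the package's API):

```python
# Hypothetical helper mirroring the relative -> absolute shift above.
def to_absolute(start_index, relative_ranges):
    """Shift chunk-relative inclusive (start, end) ranges to dataset coordinates."""
    return [(start_index + s, start_index + e) for s, e in relative_ranges]

# A chunk starting at item 1000 with items 0-4 and 10-12 still unprocessed:
assert to_absolute(1000, [(0, 4), (10, 12)]) == [(1000, 1004), (1010, 1012)]
```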
@@ -588,7 +584,7 @@ class LocalFilesystemWorkerProcessor(WorkerProcessor):
         processed_indices = []
 
         # Get orchestrator info if we need HTTP
-        orchestrator = context.get("orchestrator")
+        context.get("orchestrator")
 
         for idx in sorted(indices_to_process):
             try:
--- a/caption_flow/processors/webdataset.py
+++ b/caption_flow/processors/webdataset.py
@@ -1,28 +1,27 @@
 """WebDataset processor implementation using webshart TarDataLoader."""
 
-import logging
-import threading
 import gc
+import io
+import logging
 import os
-from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
-from collections import deque, defaultdict
+import threading
+from collections import defaultdict, deque
 from pathlib import Path
-import json
-from datetime import datetime
+from typing import Any, Deque, Dict, Iterator, List, Optional, Set
+
+import cv2
+import numpy as np
+import webshart
 from PIL import Image
-import io
 
 from caption_flow.models import JobId
 from caption_flow.storage import StorageManager
-from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import ChunkTracker
 
-import webshart
-import cv2
-import numpy as np
+from ..utils import ChunkTracker
+from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
@@ -217,7 +216,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         while units_created < units_needed and not self.stop_creation.is_set():
             # Get current shard info
             if current_shard_idx >= self.dataset.num_shards:
-                logger.info("All shards processed")
+                threading.Event().wait(5)
                 break
 
             shard_info = self._get_shard_info_cached(current_shard_idx)
@@ -240,8 +239,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             self.current_chunk_index = current_file_idx // self.chunk_size
             job_id_obj = JobId(
                 shard_id=shard_name,
-                chunk_id=self.current_chunk_index,
-                sample_id=current_file_idx,
+                chunk_id=str(self.current_chunk_index),
+                sample_id=str(current_file_idx),
             )
             chunk_id = job_id_obj.get_chunk_str()
 
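The new str() casts suggest JobId's fields are string-typed in 0.4.0. A hedged approximation of that contract (the real dataclass lives in caption_flow/models.py; the composite-key format below is an assumption, not shown in this diff):

```python
from dataclasses import dataclass

# Illustrative stand-in for caption_flow.models.JobId; field names follow the
# call site above, but this get_chunk_str() format is assumed.
@dataclass
class JobId:
    shard_id: str
    chunk_id: str
    sample_id: str

    def get_chunk_str(self) -> str:
        return f"{self.shard_id}:chunk:{self.chunk_id}"

job = JobId(shard_id="data-0000", chunk_id=str(3), sample_id=str(3042))
print(job.get_chunk_str())  # data-0000:chunk:3
```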
@@ -306,8 +305,15 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         assigned = []
 
         with self.lock:
-            while len(assigned) < count and self.pending_units:
+            units_checked = 0
+            max_units_to_check = len(self.pending_units)
+
+            while len(assigned) < count and units_checked < max_units_to_check:
+                if not self.pending_units:
+                    break
+
                 unit_id = self.pending_units.popleft()
+                units_checked += 1
                 unit = self.work_units.get(unit_id)
 
                 if unit:
@@ -316,6 +322,16 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
                     chunk_state = self.chunk_tracker.chunks[unit_id]
                     relative_unprocessed = chunk_state.get_unprocessed_ranges()
 
+                    # If no unprocessed ranges, mark as completed and skip
+                    if not relative_unprocessed:
+                        logger.info(
+                            f"Chunk {unit_id} has no unprocessed ranges, marking as completed"
+                        )
+                        self.chunk_tracker.mark_completed(unit_id)
+                        # Remove from work units
+                        del self.work_units[unit_id]
+                        continue
+
                     # Convert relative to absolute indices
                     absolute_ranges = []
                     for start, end in relative_unprocessed:
@@ -335,6 +351,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
 
                 if self.chunk_tracker:
                     self.chunk_tracker.mark_assigned(unit_id, worker_id)
+            else:
+                # Put it back if we couldn't get the unit
+                self.pending_units.append(unit_id)
 
         logger.debug(f"Assigned {len(assigned)} units to worker {worker_id}")
         return assigned
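Taken together, the three hunks above bound each assignment call to a single pass over the queue: IDs whose unit is missing are re-appended, and units_checked stops the loop before they come around again in the same call. A self-contained sketch of the pattern, with hypothetical names standing in for the processor's state:

```python
from collections import deque

# Standalone model of the bounded-scan assignment loop; `pending` and `units`
# stand in for the processor's pending_units deque and work_units dict.
def assign(pending: deque, units: dict, count: int) -> list:
    assigned = []
    units_checked = 0
    max_units_to_check = len(pending)  # one pass over the queue, no more

    while len(assigned) < count and units_checked < max_units_to_check:
        if not pending:
            break
        unit_id = pending.popleft()
        units_checked += 1
        unit = units.get(unit_id)
        if unit:
            assigned.append(unit)
        else:
            pending.append(unit_id)  # re-queue; not re-checked this call

    return assigned

print(assign(deque(["a", "b", "c"]), {"a": "unit-a", "c": "unit-c"}, 2))
# ['unit-a', 'unit-c'] -- "b" is re-queued for a later call
```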
@@ -394,8 +413,20 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         logger.info(f"Released {len(unit_ids)} assignments from {worker_id}")
 
     def handle_result(self, result: WorkResult) -> Dict[str, Any]:
-        """Handle result from worker."""
-        # Track processed items if we have chunk tracker
+        """Handle result from worker and update chunk tracker."""
+        # Extract the actual item index from the metadata
+        item_index = result.metadata.get("_item_index", None)
+
+        # If we have an item index, mark it as processed in the chunk tracker
+        if self.chunk_tracker and item_index is not None and result.chunk_id:
+            try:
+                # Mark single item as processed
+                self.chunk_tracker.mark_items_processed(result.chunk_id, item_index, item_index)
+                # logger.debug(f"Marked item {item_index} as processed in chunk {result.chunk_id}")
+            except Exception as e:
+                logger.error(f"Error marking item {item_index} as processed: {e}")
+
+        # Also handle batch results if present (backward compatibility)
         if self.chunk_tracker and "item_indices" in result.metadata:
             indices = result.metadata["item_indices"]
 
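Both paths funnel into mark_items_processed(chunk_id, start, end) with inclusive bounds; a single item is just the degenerate range (index, index). The batch path's conversion of indices to ranges (elided from this hunk) amounts to collapsing sorted indices into runs, roughly like this hypothetical helper (not the package's own code):

```python
# Collapse indices into sorted inclusive (start, end) runs.
def to_ranges(indices):
    ranges = []
    for idx in sorted(set(indices)):
        if ranges and idx == ranges[-1][1] + 1:
            ranges[-1] = (ranges[-1][0], idx)  # extend the current run
        else:
            ranges.append((idx, idx))  # single items stay degenerate
    return ranges

assert to_ranges([7, 3, 4, 5, 9]) == [(3, 5), (7, 7), (9, 9)]
```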
@@ -419,6 +450,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             # Mark ranges as processed
             for start_idx, end_idx in ranges:
                 self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)
+                logger.debug(
+                    f"Marked range {start_idx}-{end_idx} as processed in chunk {result.chunk_id}"
+                )
 
         return {
             "source_id": result.source_id,
@@ -499,8 +533,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             for start_idx, end_idx in ranges:
                 self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
 
-        # Save checkpoint after updating
-        self.chunk_tracker.save()
+        # Flush checkpoint after major update
+        self.chunk_tracker.flush()
 
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
@@ -537,9 +571,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         if self.unit_creation_thread:
             self.unit_creation_thread.join(timeout=5)
 
-        # Save checkpoint
+        # Flush final checkpoint on cleanup
         if self.chunk_tracker:
-            self.chunk_tracker.save_checkpoint()
+            self.chunk_tracker.flush()
 
 
 class WebDatasetWorkerProcessor(WorkerProcessor):
@@ -559,6 +593,9 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         self.dataset_path = dataset_cfg.get("dataset_path")
         metadata_path = dataset_cfg.get("metadata_path", None)
         self.mock_results = dataset_cfg.get("mock_results", False)
+        split_worker_cache = dataset_cfg.get(
+            "split_worker_cache", True
+        )  # multiple workers get their own cache by default
 
         # Cache configuration
         cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
@@ -574,7 +611,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         # Enable caching
         self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
         self.dataset.enable_shard_cache(
-            location=str(cache_dir / "shard_cache"),
+            location=(
+                str(cache_dir / "shard_cache" / str(self.gpu_id))
+                if split_worker_cache
+                else str(cache_dir / "shard_cache")
+            ),
             cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
         )
 
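With split_worker_cache defaulting to True, each worker keys its shard cache under its gpu_id, presumably so concurrent workers on one host do not contend for a single size-limited cache directory. An illustrative resolution of the path (the config keys mirror the two hunks above; the surrounding cfg structure is assumed):

```python
from pathlib import Path

# Assumed config shape; only cache_dir, split_worker_cache, and gpu_id matter here.
cfg = {"cache_dir": "./webshart_cache"}
dataset_cfg = {"split_worker_cache": True}
gpu_id = 0  # each worker process would pass its own id

cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
split_worker_cache = dataset_cfg.get("split_worker_cache", True)
location = (
    str(cache_dir / "shard_cache" / str(gpu_id))
    if split_worker_cache
    else str(cache_dir / "shard_cache")
)
print(location)  # webshart_cache/shard_cache/0
```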
@@ -646,7 +687,7 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
             # Iterate through the range
             for idx in range(start_idx, end_idx + 1):
                 try:
-                    entry = next(self.loader)
+                    entry = webshart.next_with_cache_wait(self.loader)
 
                     # Decode image
                     image = None
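Finally, the worker now pulls entries via webshart.next_with_cache_wait rather than bare next(). A hedged sketch of the consuming loop, assuming the call behaves like next() but waits for the shard cache and raises StopIteration when the loader is exhausted (those semantics are inferred from the one-line change above, not documented here):

```python
import webshart  # imported at the top of webdataset.py in 0.4.0

# Hypothetical consumption loop around next_with_cache_wait.
def iter_entries(loader, start_idx, end_idx):
    for idx in range(start_idx, end_idx + 1):  # inclusive bounds, as above
        try:
            entry = webshart.next_with_cache_wait(loader)
        except StopIteration:
            break  # loader exhausted before the requested range ended
        yield idx, entry
```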