caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the public registry.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +921 -427
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +463 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +303 -92
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
  28. caption_flow-0.4.1.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/processors/local_filesystem.py
@@ -1,27 +1,27 @@
  """Local filesystem datasets processor implementation."""

+ import asyncio
+ import io
  import logging
- import threading
+ import mimetypes
  import os
- from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
- from collections import deque, defaultdict
+ import threading
+ from collections import defaultdict, deque
  from pathlib import Path
- import json
- import io
- import mimetypes
- from datetime import datetime
- from PIL import Image
+ from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
+
  import aiofiles
- from fastapi import FastAPI, HTTPException, Response
- from fastapi.responses import StreamingResponse
- import uvicorn
- import asyncio
  import requests
+ import uvicorn
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
+ from PIL import Image

  from caption_flow.storage import StorageManager
- from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
- from ..utils import ChunkTracker
+
  from ..models import JobId
+ from ..utils import ChunkTracker
+ from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit

  logger = logging.getLogger(__name__)
  logger.setLevel(logging.DEBUG)
@@ -217,23 +217,19 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
          if not self.chunk_tracker:
              return

-         all_processed_jobs = storage.get_all_processed_job_ids()
+         storage.get_all_processed_job_ids()

          with self.lock:
              for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
-                 # Calculate actual unprocessed ranges
-                 chunk_range = (
-                     chunk_state.start_index,
-                     chunk_state.start_index + chunk_state.chunk_size - 1,
-                 )
-
-                 # Get processed indices for this chunk
-                 processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
-                     chunk_id, all_processed_jobs
-                 )
+                 # Get unprocessed ranges (relative coordinates from ChunkTracker)
+                 relative_unprocessed_ranges = chunk_state.get_unprocessed_ranges()

-                 # Calculate unprocessed ranges
-                 unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
+                 # Convert relative ranges to absolute ranges
+                 unprocessed_ranges = []
+                 for start, end in relative_unprocessed_ranges:
+                     abs_start = chunk_state.start_index + start
+                     abs_end = chunk_state.start_index + end
+                     unprocessed_ranges.append((abs_start, abs_end))

                  if unprocessed_ranges:
                      # Create work unit for unprocessed items
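
The replacement logic above leans on ChunkTracker reporting unprocessed ranges in chunk-relative coordinates, which the orchestrator then shifts by the chunk's start index. A minimal standalone sketch of that conversion, using made-up numbers (a 1000-item chunk at absolute offset 2000 whose first 500 items are done):

    # Hypothetical chunk: start_index=2000, chunk_size=1000,
    # relative items 0-499 already processed.
    start_index = 2000
    relative_unprocessed = [(500, 999)]  # shape of get_unprocessed_ranges() output

    # Shift each relative (start, end) pair by start_index to get absolute indices.
    absolute_unprocessed = [(start_index + s, start_index + e) for s, e in relative_unprocessed]
    assert absolute_unprocessed == [(2500, 2999)]
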
@@ -588,7 +584,7 @@ class LocalFilesystemWorkerProcessor(WorkerProcessor):
          processed_indices = []

          # Get orchestrator info if we need HTTP
-         orchestrator = context.get("orchestrator")
+         context.get("orchestrator")

          for idx in sorted(indices_to_process):
              try:
caption_flow/processors/webdataset.py
@@ -1,28 +1,27 @@
  """WebDataset processor implementation using webshart TarDataLoader."""

- import logging
- import threading
  import gc
+ import io
+ import logging
  import os
- from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
- from collections import deque, defaultdict
+ import threading
+ from collections import defaultdict, deque
  from pathlib import Path
- import json
- from datetime import datetime
+ from typing import Any, Deque, Dict, Iterator, List, Optional, Set
+
+ import cv2
+ import numpy as np
+ import webshart
  from PIL import Image
- import io

  from caption_flow.models import JobId
  from caption_flow.storage import StorageManager
- from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
- from ..utils import ChunkTracker

- import webshart
- import cv2
- import numpy as np
+ from ..utils import ChunkTracker
+ from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit

  logger = logging.getLogger(__name__)
- logger.setLevel(logging.INFO)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())


  class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
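
One behavioral change rides along with the import cleanup: the module's log level now comes from the CAPTIONFLOW_LOG_LEVEL environment variable instead of being hard-coded to INFO. A standalone sanity check of that lookup (the logger name here is illustrative):

    import logging
    import os

    os.environ["CAPTIONFLOW_LOG_LEVEL"] = "debug"

    logger = logging.getLogger("demo")
    # setLevel() accepts level *names*, so the value is upper-cased first;
    # an unset variable falls back to "INFO", and an unknown name raises ValueError.
    logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
    assert logger.level == logging.DEBUG
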
@@ -217,7 +216,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
          while units_created < units_needed and not self.stop_creation.is_set():
              # Get current shard info
              if current_shard_idx >= self.dataset.num_shards:
-                 logger.info("All shards processed")
+                 threading.Event().wait(5)
                  break

              shard_info = self._get_shard_info_cached(current_shard_idx)
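
For context on the new line: threading.Event().wait(5) on a freshly created, never-set Event is simply an interruptible 5-second pause (wait() returns False on timeout), so the loop now sleeps briefly before breaking instead of logging. A quick standard-library demonstration:

    import threading
    import time

    t0 = time.monotonic()
    flag = threading.Event().wait(0.1)  # a fresh Event is never set
    assert flag is False                # False means the timeout elapsed
    assert time.monotonic() - t0 >= 0.1
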
@@ -240,8 +239,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
              self.current_chunk_index = current_file_idx // self.chunk_size
              job_id_obj = JobId(
                  shard_id=shard_name,
-                 chunk_id=self.current_chunk_index,
-                 sample_id=current_file_idx,
+                 chunk_id=str(self.current_chunk_index),
+                 sample_id=str(current_file_idx),
              )
              chunk_id = job_id_obj.get_chunk_str()

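JobId now receives chunk_id and sample_id as strings rather than ints. A hedged sketch of the call shape (field values are invented; the exact format get_chunk_str() returns is defined in caption_flow.models and is not shown in this diff):

    from caption_flow.models import JobId

    job_id_obj = JobId(
        shard_id="shard-00042",  # hypothetical shard name
        chunk_id=str(7),         # int in 0.3.4, string in 0.4.1
        sample_id=str(3500),     # int in 0.3.4, string in 0.4.1
    )
    chunk_id = job_id_obj.get_chunk_str()  # canonical chunk identifier string
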
@@ -534,8 +533,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
          for start_idx, end_idx in ranges:
              self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)

-         # Save checkpoint after updating
-         self.chunk_tracker.save()
+         # Flush checkpoint after major update
+         self.chunk_tracker.flush()

      def get_stats(self) -> Dict[str, Any]:
          """Get processor statistics."""
@@ -572,9 +571,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
          if self.unit_creation_thread:
              self.unit_creation_thread.join(timeout=5)

-         # Save checkpoint
+         # Flush final checkpoint on cleanup
          if self.chunk_tracker:
-             self.chunk_tracker.save()
+             self.chunk_tracker.flush()


  class WebDatasetWorkerProcessor(WorkerProcessor):
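
Both hunks above switch checkpointing from save() to flush(), which reads as a move to buffered updates that are persisted on demand. ChunkTracker's internals are not part of this diff, so the following is only a generic sketch of that pattern, not the package's implementation:

    import json
    import threading
    from pathlib import Path

    class ToyCheckpointWriter:
        """Buffered checkpoint: mutations only mark state dirty; flush() persists."""

        def __init__(self, path: Path):
            self.path = path
            self.state: dict = {}
            self.dirty = False
            self.lock = threading.Lock()

        def mark_items_processed(self, chunk_id: str, start: int, end: int) -> None:
            with self.lock:
                self.state.setdefault(chunk_id, []).append((start, end))
                self.dirty = True  # cheap: no disk I/O per update

        def flush(self) -> None:
            with self.lock:
                if self.dirty:
                    self.path.write_text(json.dumps(self.state))
                    self.dirty = False
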
@@ -594,6 +593,9 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
          self.dataset_path = dataset_cfg.get("dataset_path")
          metadata_path = dataset_cfg.get("metadata_path", None)
          self.mock_results = dataset_cfg.get("mock_results", False)
+         split_worker_cache = dataset_cfg.get(
+             "split_worker_cache", True
+         )  # multiple workers get their own cache by default

          # Cache configuration
          cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
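
split_worker_cache defaults to True, so each worker gets its own shard cache unless the dataset config opts out. An illustrative config as the processor reads it (keys taken from the dataset_cfg.get() calls above; values are examples only):

    dataset_cfg = {
        "dataset_path": "hf://datasets/example/my-webdataset",  # hypothetical
        "metadata_path": None,
        "mock_results": False,
        "split_worker_cache": False,  # opt out: all workers share one shard cache
    }
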
@@ -609,7 +611,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
          # Enable caching
          self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
          self.dataset.enable_shard_cache(
-             location=str(cache_dir / "shard_cache"),
+             location=(
+                 str(cache_dir / "shard_cache" / str(self.gpu_id))
+                 if split_worker_cache
+                 else str(cache_dir / "shard_cache")
+             ),
              cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
          )

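With split_worker_cache enabled, the shard cache location is suffixed with the worker's GPU id, so co-located workers never contend for (or evict) each other's cached shards. The resulting layout, sketched with the default cache_dir:

    from pathlib import Path

    cache_dir = Path("./webshart_cache")
    split_worker_cache = True

    for gpu_id in (0, 1):
        location = (
            cache_dir / "shard_cache" / str(gpu_id)
            if split_worker_cache
            else cache_dir / "shard_cache"
        )
        print(location)  # webshart_cache/shard_cache/0, then .../1
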
@@ -681,7 +687,7 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
          # Iterate through the range
          for idx in range(start_idx, end_idx + 1):
              try:
-                 entry = next(self.loader)
+                 entry = webshart.next_with_cache_wait(self.loader)

                  # Decode image
                  image = None
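
Entries are now pulled via webshart.next_with_cache_wait() rather than plain next(); the name suggests it blocks until the shard cache can serve the entry, with the exact semantics defined by webshart itself. A hedged compatibility sketch (the hasattr fallback is mine, not package code):

    import webshart

    def pull_entry(loader):
        """Prefer the cache-aware reader where this webshart version provides it."""
        if hasattr(webshart, "next_with_cache_wait"):
            return webshart.next_with_cache_wait(loader)  # 0.4.1 behavior
        return next(loader)  # 0.3.4 behavior
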