caption-flow: caption_flow-0.3.4-py3-none-any.whl → caption_flow-0.4.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +934 -415
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +439 -67
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +265 -90
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
- caption_flow-0.4.0.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
caption_flow/processors/local_filesystem.py

```diff
@@ -1,27 +1,27 @@
 """Local filesystem datasets processor implementation."""
 
+import asyncio
+import io
 import logging
-import …
+import mimetypes
 import os
-
-from collections import …
+import threading
+from collections import defaultdict, deque
 from pathlib import Path
-import …
-
-import mimetypes
-from datetime import datetime
-from PIL import Image
+from typing import Any, Deque, Dict, Iterator, List, Optional, Set, Tuple
+
 import aiofiles
-from fastapi import FastAPI, HTTPException, Response
-from fastapi.responses import StreamingResponse
-import uvicorn
-import asyncio
 import requests
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from PIL import Image
 
 from caption_flow.storage import StorageManager
-
-from ..utils import ChunkTracker
+
 from ..models import JobId
+from ..utils import ChunkTracker
+from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
```
```diff
@@ -217,23 +217,19 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
         if not self.chunk_tracker:
             return
 
-
+        storage.get_all_processed_job_ids()
 
         with self.lock:
             for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
-                # …
-
-                    chunk_state.start_index,
-                    chunk_state.start_index + chunk_state.chunk_size - 1,
-                )
-
-                # Get processed indices for this chunk
-                processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
-                    chunk_id, all_processed_jobs
-                )
+                # Get unprocessed ranges (relative coordinates from ChunkTracker)
+                relative_unprocessed_ranges = chunk_state.get_unprocessed_ranges()
 
-                # …
-                unprocessed_ranges = …
+                # Convert relative ranges to absolute ranges
+                unprocessed_ranges = []
+                for start, end in relative_unprocessed_ranges:
+                    abs_start = chunk_state.start_index + start
+                    abs_end = chunk_state.start_index + end
+                    unprocessed_ranges.append((abs_start, abs_end))
 
                 if unprocessed_ranges:
                     # Create work unit for unprocessed items
```
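The rewritten restore path asks each chunk's tracker state for its unprocessed ranges in chunk-relative coordinates and shifts them by `chunk_state.start_index`. A minimal sketch of that conversion; `ChunkStateStub` and its `get_unprocessed_ranges` are hypothetical stand-ins for the real ChunkTracker state, shown only to illustrate the coordinate shift:

```python
# Sketch of the relative-to-absolute range conversion shown above.
# ChunkStateStub is a hypothetical stand-in, not the real ChunkTracker state class.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ChunkStateStub:
    start_index: int                   # absolute index of the chunk's first sample
    chunk_size: int
    processed: List[Tuple[int, int]]   # processed ranges, chunk-relative

    def get_unprocessed_ranges(self) -> List[Tuple[int, int]]:
        """Return chunk-relative (start, end) ranges that still need work."""
        covered = set()
        for lo, hi in self.processed:
            covered.update(range(lo, hi + 1))
        gaps, run = [], None
        for i in range(self.chunk_size):
            if i not in covered:
                run = (run[0], i) if run else (i, i)
            elif run:
                gaps.append(run)
                run = None
        if run:
            gaps.append(run)
        return gaps


chunk = ChunkStateStub(start_index=1000, chunk_size=10, processed=[(0, 3), (7, 9)])
unprocessed = [
    (chunk.start_index + start, chunk.start_index + end)
    for start, end in chunk.get_unprocessed_ranges()
]
print(unprocessed)  # [(1004, 1006)] -- absolute sample indices 1004..1006 remain
```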
```diff
@@ -588,7 +584,7 @@ class LocalFilesystemWorkerProcessor(WorkerProcessor):
         processed_indices = []
 
         # Get orchestrator info if we need HTTP
-
+        context.get("orchestrator")
 
         for idx in sorted(indices_to_process):
             try:
```
caption_flow/processors/webdataset.py

```diff
@@ -1,28 +1,27 @@
 """WebDataset processor implementation using webshart TarDataLoader."""
 
-import logging
-import threading
 import gc
+import io
+import logging
 import os
-
-from collections import …
+import threading
+from collections import defaultdict, deque
 from pathlib import Path
-import …
-
+from typing import Any, Deque, Dict, Iterator, List, Optional, Set
+
+import cv2
+import numpy as np
+import webshart
 from PIL import Image
-import io
 
 from caption_flow.models import JobId
 from caption_flow.storage import StorageManager
-from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
-from ..utils import ChunkTracker
 
-import …
-import …
-import numpy as np
+from ..utils import ChunkTracker
+from .base import OrchestratorProcessor, ProcessorConfig, WorkerProcessor, WorkResult, WorkUnit
 
 logger = logging.getLogger(__name__)
-logger.setLevel(…
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
 class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
```
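Beyond the import reordering, the functional change here is the logger level: instead of being hard-coded, it is now read from the `CAPTIONFLOW_LOG_LEVEL` environment variable, defaulting to INFO. A standalone sketch of the same pattern (not caption-flow code):

```python
# Sketch of the environment-driven log level used above.
import logging
import os

logging.basicConfig()
logger = logging.getLogger("example")

# Logger.setLevel accepts level names as strings, so "debug" from the
# environment becomes "DEBUG" and maps to logging.DEBUG.
logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())

logger.debug("only emitted when CAPTIONFLOW_LOG_LEVEL=debug is exported")
logger.info("emitted at the default INFO level")
```

Unknown level names raise `ValueError`, so a typo in the variable fails loudly rather than silently falling back.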
```diff
@@ -217,7 +216,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         while units_created < units_needed and not self.stop_creation.is_set():
             # Get current shard info
             if current_shard_idx >= self.dataset.num_shards:
-
+                threading.Event().wait(5)
                 break
 
             shard_info = self._get_shard_info_cached(current_shard_idx)
```
```diff
@@ -240,8 +239,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             self.current_chunk_index = current_file_idx // self.chunk_size
             job_id_obj = JobId(
                 shard_id=shard_name,
-                chunk_id=self.current_chunk_index,
-                sample_id=current_file_idx,
+                chunk_id=str(self.current_chunk_index),
+                sample_id=str(current_file_idx),
             )
             chunk_id = job_id_obj.get_chunk_str()
 
```
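`chunk_id` and `sample_id` are now passed to `JobId` as strings rather than ints, so identifiers built from them compare and key consistently. The sketch below uses a hypothetical `JobIdStub`; its fields and the `get_chunk_str` format are illustrative and not the real `caption_flow.models.JobId`:

```python
# Hypothetical stand-in for caption_flow.models.JobId, only to illustrate why the
# orchestrator stringifies chunk_id and sample_id before constructing the job id.
from dataclasses import dataclass


@dataclass(frozen=True)
class JobIdStub:
    shard_id: str
    chunk_id: str
    sample_id: str

    def get_chunk_str(self) -> str:
        # Illustrative format only.
        return f"{self.shard_id}:chunk:{self.chunk_id}"


# Mixing int and str forms would produce ids that compare unequal even though they
# name the same chunk; stringifying at construction time keeps lookups consistent.
a = JobIdStub(shard_id="shard-0000", chunk_id=str(3), sample_id=str(1500))
b = JobIdStub(shard_id="shard-0000", chunk_id="3", sample_id="1500")
assert a == b and a.get_chunk_str() == "shard-0000:chunk:3"
```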
```diff
@@ -534,8 +533,8 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
             for start_idx, end_idx in ranges:
                 self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
 
-            # …
-            self.chunk_tracker.…
+            # Flush checkpoint after major update
+            self.chunk_tracker.flush()
 
     def get_stats(self) -> Dict[str, Any]:
         """Get processor statistics."""
```
```diff
@@ -572,9 +571,9 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
         if self.unit_creation_thread:
             self.unit_creation_thread.join(timeout=5)
 
-        # …
+        # Flush final checkpoint on cleanup
         if self.chunk_tracker:
-            self.chunk_tracker.…
+            self.chunk_tracker.flush()
 
 
 class WebDatasetWorkerProcessor(WorkerProcessor):
```
```diff
@@ -594,6 +593,9 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         self.dataset_path = dataset_cfg.get("dataset_path")
         metadata_path = dataset_cfg.get("metadata_path", None)
         self.mock_results = dataset_cfg.get("mock_results", False)
+        split_worker_cache = dataset_cfg.get(
+            "split_worker_cache", True
+        )  # multiple workers get their own cache by default
 
         # Cache configuration
         cache_dir = Path(cfg.get("cache_dir", "./webshart_cache"))
```
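This hunk introduces a `split_worker_cache` flag, read from the dataset config with a default of `True`. A hedged example of a config dict carrying the keys this code reads; the nesting under a `dataset` section is an assumption inferred from the `dataset_cfg`/`cfg` split, not documented caption-flow configuration:

```python
# Example config dict assembled from the keys read via dataset_cfg.get() and
# cfg.get() in this hunk; the surrounding structure is an assumption.
cfg = {
    "dataset": {
        "dataset_path": "hf://datasets/example/webdataset-shards",  # placeholder path
        "metadata_path": None,
        "mock_results": False,
        "split_worker_cache": True,  # give each worker its own shard cache (default)
    },
    "cache_dir": "./webshart_cache",
    "shard_cache_gb": 10.0,
}

dataset_cfg = cfg["dataset"]
split_worker_cache = dataset_cfg.get("split_worker_cache", True)
print(split_worker_cache)  # True unless explicitly disabled in the config
```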
```diff
@@ -609,7 +611,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         # Enable caching
         self.dataset.enable_metadata_cache(location=str(cache_dir / "metadata_cache"))
         self.dataset.enable_shard_cache(
-            location=…
+            location=(
+                str(cache_dir / "shard_cache" / str(self.gpu_id))
+                if split_worker_cache
+                else str(cache_dir / "shard_cache")
+            ),
             cache_limit_gb=cfg.get("shard_cache_gb", 10.0),
         )
 
```
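With `split_worker_cache` enabled, each worker derives its own shard-cache directory from its `gpu_id`; otherwise all workers share one directory. A small sketch of the path selection (the `gpu_id` values and base directory are illustrative):

```python
# Sketch of the cache-path selection above.
from pathlib import Path

cache_dir = Path("./webshart_cache")
split_worker_cache = True

for gpu_id in (0, 1):
    location = (
        str(cache_dir / "shard_cache" / str(gpu_id))
        if split_worker_cache
        else str(cache_dir / "shard_cache")
    )
    print(location)
# POSIX output:
#   webshart_cache/shard_cache/0
#   webshart_cache/shard_cache/1
```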
```diff
@@ -681,7 +687,7 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
         # Iterate through the range
         for idx in range(start_idx, end_idx + 1):
             try:
-                entry = …
+                entry = webshart.next_with_cache_wait(self.loader)
 
                 # Decode image
                 image = None
```