caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/__init__.py +1 -1
- caption_flow/cli.py +307 -0
- caption_flow/models.py +26 -0
- caption_flow/orchestrator.py +9 -9
- caption_flow/processors/huggingface.py +636 -464
- caption_flow/processors/webdataset.py +379 -534
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +410 -303
- caption_flow/utils/__init__.py +0 -2
- caption_flow/utils/chunk_tracker.py +196 -164
- caption_flow/utils/image_processor.py +19 -132
- caption_flow/viewer.py +594 -0
- caption_flow/workers/caption.py +164 -129
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/METADATA +45 -177
- caption_flow-0.3.1.dist-info/RECORD +33 -0
- caption_flow/utils/dataset_loader.py +0 -222
- caption_flow/utils/dataset_metadata_cache.py +0 -67
- caption_flow/utils/job_queue.py +0 -41
- caption_flow/utils/shard_processor.py +0 -119
- caption_flow/utils/shard_tracker.py +0 -83
- caption_flow-0.2.3.dist-info/RECORD +0 -35
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.3.1.dist-info}/top_level.txt +0 -0
@@ -1,119 +0,0 @@
|
|
1
|
-
"""Shard processing abstraction for different dataset types."""
|
2
|
-
|
3
|
-
import io
|
4
|
-
import logging
|
5
|
-
import time
|
6
|
-
from abc import ABC, abstractmethod
|
7
|
-
from pathlib import Path
|
8
|
-
from typing import Generator, Tuple, Optional, Dict, Any
|
9
|
-
from dataclasses import dataclass
|
10
|
-
from .image_processor import ImageProcessor
|
11
|
-
from threading import Event
|
12
|
-
import shlex
|
13
|
-
|
14
|
-
import webdataset as wds
|
15
|
-
from PIL import Image
|
16
|
-
|
17
|
-
from .dataset_loader import DatasetLoader
|
18
|
-
|
19
|
-
logger = logging.getLogger(__name__)
|
20
|
-
|
21
|
-
|
22
|
-
class ShardProcessor(ABC):
|
23
|
-
"""Abstract base for processing dataset shards."""
|
24
|
-
|
25
|
-
@abstractmethod
|
26
|
-
def iterate_chunk_with_metadata(
|
27
|
-
self,
|
28
|
-
chunk,
|
29
|
-
dataset_loader: Optional[DatasetLoader],
|
30
|
-
should_stop: Event,
|
31
|
-
connected: Event,
|
32
|
-
) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
|
33
|
-
"""
|
34
|
-
Iterate through items in a chunk with metadata.
|
35
|
-
|
36
|
-
Yields:
|
37
|
-
Tuple of (key, url, image_data, metadata)
|
38
|
-
"""
|
39
|
-
pass
|
40
|
-
|
41
|
-
|
42
|
-
class WebDatasetShardProcessor(ShardProcessor):
|
43
|
-
"""Processor for WebDataset tar shards with range support."""
|
44
|
-
|
45
|
-
def __init__(self, hf_token: Optional[str] = None, dataset_type: str = "local"):
|
46
|
-
self.hf_token = hf_token
|
47
|
-
self.dataset_type = dataset_type
|
48
|
-
|
49
|
-
def iterate_chunk_with_metadata(
|
50
|
-
self,
|
51
|
-
chunk,
|
52
|
-
dataset_loader: Optional[DatasetLoader],
|
53
|
-
should_stop: Event,
|
54
|
-
connected: Event,
|
55
|
-
) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
|
56
|
-
"""Process WebDataset shard chunk with metadata and range support."""
|
57
|
-
# Get unprocessed ranges
|
58
|
-
unprocessed_ranges = getattr(chunk, "unprocessed_ranges", [(0, chunk.chunk_size - 1)])
|
59
|
-
|
60
|
-
logger.info(
|
61
|
-
f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
|
62
|
-
)
|
63
|
-
|
64
|
-
# Create WebDataset pipeline
|
65
|
-
if self.dataset_type == "huggingface":
|
66
|
-
# Use curl with auth for HuggingFace WebDataset
|
67
|
-
url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.hf_token)}' {shlex.quote(chunk.shard_url)} || true"
|
68
|
-
ds = wds.DataPipeline(
|
69
|
-
wds.SimpleShardList(url_cmd),
|
70
|
-
wds.tarfile_to_samples(),
|
71
|
-
wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
|
72
|
-
)
|
73
|
-
else:
|
74
|
-
# Local file
|
75
|
-
ds = wds.DataPipeline(
|
76
|
-
wds.SimpleShardList(chunk.shard_url),
|
77
|
-
wds.tarfile_to_samples(),
|
78
|
-
wds.to_tuple("__key__", "jpg;png;jpeg;webp;jxl"),
|
79
|
-
)
|
80
|
-
|
81
|
-
# Process items
|
82
|
-
absolute_idx = 0 # Absolute index in the shard
|
83
|
-
items_yielded = 0
|
84
|
-
|
85
|
-
for key, image_data in ds:
|
86
|
-
# Check if we should stop
|
87
|
-
if should_stop.is_set() or not connected.is_set():
|
88
|
-
logger.info(f"Stopping WebDataset chunk processing early due to disconnect")
|
89
|
-
break
|
90
|
-
|
91
|
-
# Skip items before chunk start
|
92
|
-
if absolute_idx < chunk.start_index:
|
93
|
-
absolute_idx += 1
|
94
|
-
continue
|
95
|
-
|
96
|
-
# Calculate relative index within chunk
|
97
|
-
relative_idx = absolute_idx - chunk.start_index
|
98
|
-
|
99
|
-
# Stop if beyond chunk
|
100
|
-
if relative_idx >= chunk.chunk_size:
|
101
|
-
break
|
102
|
-
|
103
|
-
# Check if current index is in any unprocessed range
|
104
|
-
in_range = any(start <= relative_idx <= end for start, end in unprocessed_ranges)
|
105
|
-
|
106
|
-
if in_range:
|
107
|
-
# Create metadata with the relative index
|
108
|
-
metadata = {
|
109
|
-
"_chunk_relative_index": relative_idx,
|
110
|
-
}
|
111
|
-
items_yielded += 1
|
112
|
-
yield key, chunk.shard_url, image_data, metadata
|
113
|
-
|
114
|
-
absolute_idx += 1
|
115
|
-
|
116
|
-
logger.info(
|
117
|
-
f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
|
118
|
-
f"from ranges {unprocessed_ranges}"
|
119
|
-
)
|
@@ -1,83 +0,0 @@
|
|
1
|
-
"""Shard tracking using CheckpointTracker base class."""
|
2
|
-
|
3
|
-
from pathlib import Path
|
4
|
-
from typing import Dict, Any, List, Set
|
5
|
-
|
6
|
-
from .checkpoint_tracker import CheckpointTracker
|
7
|
-
|
8
|
-
|
9
|
-
class ShardTracker(CheckpointTracker):
|
10
|
-
"""Tracks shard processing progress."""
|
11
|
-
|
12
|
-
def __init__(self, checkpoint_path: Path):
|
13
|
-
"""Initialize shard tracker with checkpoint file."""
|
14
|
-
self.completed_shards: Set[str] = set()
|
15
|
-
self.partial_shards: Dict[str, Dict[str, Any]] = {}
|
16
|
-
super().__init__(checkpoint_path)
|
17
|
-
|
18
|
-
def _get_default_state(self) -> Dict[str, Any]:
|
19
|
-
"""Return default state structure for new checkpoints."""
|
20
|
-
return {"completed_shards": [], "partial_shards": {}}
|
21
|
-
|
22
|
-
def _deserialize_state(self, data: Dict[str, Any]) -> None:
|
23
|
-
"""Deserialize loaded data into instance state."""
|
24
|
-
self.completed_shards = set(data.get("completed_shards", []))
|
25
|
-
self.partial_shards = data.get("partial_shards", {})
|
26
|
-
|
27
|
-
def _serialize_state(self) -> Dict[str, Any]:
|
28
|
-
"""Serialize instance state for saving."""
|
29
|
-
return {
|
30
|
-
"completed_shards": list(self.completed_shards),
|
31
|
-
"partial_shards": self.partial_shards,
|
32
|
-
}
|
33
|
-
|
34
|
-
def mark_complete(self, shard_name: str) -> None:
|
35
|
-
"""Mark a shard as complete."""
|
36
|
-
self.completed_shards.add(shard_name)
|
37
|
-
if shard_name in self.partial_shards:
|
38
|
-
del self.partial_shards[shard_name]
|
39
|
-
self.save()
|
40
|
-
|
41
|
-
def update_partial(self, shard_name: str, processed_keys: List[str]) -> None:
|
42
|
-
"""Update partial progress for a shard."""
|
43
|
-
self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
|
44
|
-
self.save()
|
45
|
-
|
46
|
-
def get_processed_keys(self, shard_name: str) -> Set[str]:
|
47
|
-
"""Get set of processed keys for a shard."""
|
48
|
-
if shard_name in self.completed_shards:
|
49
|
-
return set() # All done
|
50
|
-
|
51
|
-
if shard_name in self.partial_shards:
|
52
|
-
return set(self.partial_shards[shard_name].get("keys", []))
|
53
|
-
|
54
|
-
return set()
|
55
|
-
|
56
|
-
def is_complete(self, shard_name: str) -> bool:
|
57
|
-
"""Check if a shard is complete."""
|
58
|
-
return shard_name in self.completed_shards
|
59
|
-
|
60
|
-
def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
|
61
|
-
"""Get list of shards that still need processing."""
|
62
|
-
remaining = []
|
63
|
-
for s in all_shards:
|
64
|
-
shard_name = Path(s).stem
|
65
|
-
|
66
|
-
if shard_name not in self.completed_shards:
|
67
|
-
remaining.append(s)
|
68
|
-
|
69
|
-
return remaining
|
70
|
-
|
71
|
-
def get_stats(self) -> Dict[str, Any]:
|
72
|
-
"""Get shard tracking statistics."""
|
73
|
-
base_stats = super().get_stats()
|
74
|
-
base_stats.update(
|
75
|
-
{
|
76
|
-
"completed_shards": len(self.completed_shards),
|
77
|
-
"partial_shards": len(self.partial_shards),
|
78
|
-
"total_partial_keys": sum(
|
79
|
-
len(data.get("keys", [])) for data in self.partial_shards.values()
|
80
|
-
),
|
81
|
-
}
|
82
|
-
)
|
83
|
-
return base_stats
|
@@ -1,35 +0,0 @@
|
|
1
|
-
caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
|
2
|
-
caption_flow/cli.py,sha256=qEueeJhf3DvxSBxnOp5t32p6gAnZskvIDe6cwtPA0-Y,28892
|
3
|
-
caption_flow/models.py,sha256=bpr7yMy3vPErZCQwmgOYIix489rRGbT6lVw8wxxwTkc,4931
|
4
|
-
caption_flow/monitor.py,sha256=bAt9EJqfPgT_KdbknGdCxwBRH002pRDgyUmYIj6Dyso,7885
|
5
|
-
caption_flow/orchestrator.py,sha256=ciqWghxUxk-5s6u7W3JwD7_JLSFYV57NgOwiMkxME-I,36133
|
6
|
-
caption_flow/storage.py,sha256=Wqgtsk6yZ9Kf-izeUKHLwSvPUH3xFqIbzox20QHbc64,43370
|
7
|
-
caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
|
8
|
-
caption_flow/processors/base.py,sha256=JlTqCHo5HRXrXMVzgle_6pNwh4HGHsF7jLF6PeSnWr0,6783
|
9
|
-
caption_flow/processors/huggingface.py,sha256=MNz9vDMtrrTOSXe9Q_kbBrQ7XBv69X6x5xD_QP9icdg,33765
|
10
|
-
caption_flow/processors/local_filesystem.py,sha256=EYmsImbkqsIU7UZL2FijL0hotKLtPOtkzfwernQDSxA,27860
|
11
|
-
caption_flow/processors/webdataset.py,sha256=xsrYx7_5FCqez30dc4hSDYfyA9A0oKqHqwt7CRc1J0c,33812
|
12
|
-
caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
|
13
|
-
caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
|
14
|
-
caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
|
15
|
-
caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
|
16
|
-
caption_flow/utils/checkpoint_tracker.py,sha256=-nN5gLvXyMdKOCT2SNNL2Km6UYm2Hii9wuXeezWhwx4,3339
|
17
|
-
caption_flow/utils/chunk_tracker.py,sha256=x9UwFxpj-nMeAJ6bpKw5E09QNUqu7L0pejTlk8nxgE8,19402
|
18
|
-
caption_flow/utils/dataset_loader.py,sha256=2-SgXPGQkF4CyA3zyVYfSbZMSk4YzTsVFY0izmOZPrM,8771
|
19
|
-
caption_flow/utils/dataset_metadata_cache.py,sha256=AJ8Z1GYT0DC9_LLjxNvrePKU7ecenNZun5GhaB2gvj0,2650
|
20
|
-
caption_flow/utils/image_processor.py,sha256=7Ed92iUJ-OvjzQmAGPaULoYEqoirVHHo0lxtceWGc44,5586
|
21
|
-
caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
|
22
|
-
caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
|
23
|
-
caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
|
24
|
-
caption_flow/utils/shard_processor.py,sha256=_PCW5TfSHFfCc63Sn7bVzgjA625-aWzL4cWwZLjW0rQ,3935
|
25
|
-
caption_flow/utils/shard_tracker.py,sha256=1OqiueaC8WoxhY2nc03erZAc50mnQCZazATS6R14lbQ,3029
|
26
|
-
caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
|
27
|
-
caption_flow/workers/base.py,sha256=2AGWERC5hbmO-0V_A1MUbgRVvRNN3blqGPyDokvvzmM,7575
|
28
|
-
caption_flow/workers/caption.py,sha256=_uvpdoBzym1TKWKXtky7hBfj8YnG1EaJz-NRwaH2X1A,36722
|
29
|
-
caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
|
30
|
-
caption_flow-0.2.3.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
|
31
|
-
caption_flow-0.2.3.dist-info/METADATA,sha256=bk5Gk3eWuDH_UWXPEDKulksPc3hVHvnzm3sstLbuU-0,11914
|
32
|
-
caption_flow-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
33
|
-
caption_flow-0.2.3.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
|
34
|
-
caption_flow-0.2.3.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
|
35
|
-
caption_flow-0.2.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|