caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,119 +0,0 @@
1
- """Shard processing abstraction for different dataset types."""
2
-
3
- import io
4
- import logging
5
- import time
6
- from abc import ABC, abstractmethod
7
- from pathlib import Path
8
- from typing import Generator, Tuple, Optional, Dict, Any
9
- from dataclasses import dataclass
10
- from .image_processor import ImageProcessor
11
- from threading import Event
12
- import shlex
13
-
14
- import webdataset as wds
15
- from PIL import Image
16
-
17
- from .dataset_loader import DatasetLoader
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
class ShardProcessor(ABC):
    """Interface for objects that stream the items belonging to one dataset chunk.

    Concrete subclasses decide how a shard is opened and which of its items
    fall inside the requested chunk.
    """

    @abstractmethod
    def iterate_chunk_with_metadata(
        self,
        chunk,
        dataset_loader: Optional["DatasetLoader"],
        should_stop: Event,
        connected: Event,
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
        """Stream the items of *chunk*, each paired with a metadata dict.

        Args:
            chunk: Work unit describing what to read; the attributes it must
                expose are defined by each concrete processor.
            dataset_loader: Optional loader helper; implementations may ignore it.
            should_stop: Event that, once set, asks the implementation to end
                iteration early.
            connected: Event expected to be set while the connection is alive.

        Yields:
            ``(key, url, image_data, metadata)`` tuples.
        """
        ...
40
-
41
-
42
class WebDatasetShardProcessor(ShardProcessor):
    """Processor for WebDataset tar shards with range support.

    Streams a tar shard through ``webdataset`` and yields only the samples
    whose chunk-relative index falls inside the chunk's unprocessed ranges.
    """

    # Sample fields tried for the image payload (passed to ``wds.to_tuple``).
    _IMAGE_KEYS = "jpg;png;jpeg;webp;jxl"

    def __init__(self, hf_token: Optional[str] = None, dataset_type: str = "local"):
        """
        Args:
            hf_token: Token sent as an ``Authorization: Bearer`` header when
                streaming HuggingFace shards; ``None``/empty sends no header.
            dataset_type: ``"huggingface"`` streams ``chunk.shard_url`` through
                ``curl``; any other value passes the URL/path to webdataset as-is.
        """
        self.hf_token = hf_token
        self.dataset_type = dataset_type

    def _curl_command(self, shard_url: str) -> str:
        """Return a webdataset ``pipe:`` source that downloads *shard_url* via curl."""
        args = ["curl", "-s", "-L"]
        if self.hf_token:
            # Quote the complete header as ONE shell word so a token containing
            # spaces or shell metacharacters cannot split the argument or inject
            # commands. Skipped when no token is set (shlex.quote(None) raises).
            args += ["-H", shlex.quote(f"Authorization: Bearer {self.hf_token}")]
        args.append(shlex.quote(shard_url))
        # "|| true" makes the shell command exit 0 even if curl fails; the
        # shard then presumably yields no samples rather than erroring.
        return "pipe:" + " ".join(args) + " || true"

    def _build_pipeline(self, shard_url: str):
        """Create a webdataset pipeline yielding ``(key, image_data)`` tuples."""
        if self.dataset_type == "huggingface":
            source = self._curl_command(shard_url)
        else:
            # Local file
            source = shard_url
        return wds.DataPipeline(
            wds.SimpleShardList(source),
            wds.tarfile_to_samples(),
            wds.to_tuple("__key__", self._IMAGE_KEYS),
        )

    def _select_in_ranges(self, samples, chunk, ranges, should_stop: Event, connected: Event):
        """Yield ``(relative_idx, key, image_data)`` for samples inside *ranges*.

        *samples* is consumed in shard order. A sample's position in the shard
        is its absolute index, and ``absolute - chunk.start_index`` is its
        chunk-relative index. *ranges* holds inclusive ``(start, end)`` pairs
        of relative indices.

        Iteration ends early when ``should_stop`` is set or ``connected`` is
        cleared. It also ends once the last index that any range (capped at
        the chunk end) can cover has been passed, so the rest of the shard is
        not read.
        """
        # Highest relative index still worth reading; -1 when ranges is empty.
        last_needed = min(
            chunk.chunk_size - 1,
            max((end for _, end in ranges), default=-1),
        )

        for absolute_idx, (key, image_data) in enumerate(samples):
            if should_stop.is_set() or not connected.is_set():
                logger.info(
                    f"Stopping WebDataset chunk {chunk.chunk_id} processing early "
                    "(stop requested or disconnected)"
                )
                break

            relative_idx = absolute_idx - chunk.start_index
            if relative_idx < 0:
                # Still before the chunk's first sample.
                continue
            if relative_idx > last_needed:
                # Past the chunk end or past every requested range.
                break

            if any(start <= relative_idx <= end for start, end in ranges):
                yield relative_idx, key, image_data

    def iterate_chunk_with_metadata(
        self,
        chunk,
        dataset_loader: Optional["DatasetLoader"],
        should_stop: Event,
        connected: Event,
    ) -> Generator[Tuple[str, str, bytes, Dict[str, Any]], None, None]:
        """Process WebDataset shard chunk with metadata and range support.

        Args:
            chunk: Object exposing ``chunk_id``, ``shard_url``, ``start_index``
                and ``chunk_size``, plus optionally ``unprocessed_ranges``: a
                sequence of inclusive ``(start, end)`` chunk-relative index
                pairs. A missing or ``None`` value means the whole chunk.
            dataset_loader: Not used by this processor.
            should_stop: When set, iteration stops before the next sample.
            connected: When cleared, iteration stops before the next sample.

        Yields:
            ``(key, chunk.shard_url, image_data, {"_chunk_relative_index": idx})``.
        """
        unprocessed_ranges = getattr(chunk, "unprocessed_ranges", None)
        if unprocessed_ranges is None:
            unprocessed_ranges = [(0, chunk.chunk_size - 1)]
        else:
            # Materialize once: the ranges are scanned for every sample.
            unprocessed_ranges = list(unprocessed_ranges)

        logger.info(
            f"Processing WebDataset chunk {chunk.chunk_id} with ranges: {unprocessed_ranges}"
        )

        ds = self._build_pipeline(chunk.shard_url)

        items_yielded = 0
        for relative_idx, key, image_data in self._select_in_ranges(
            ds, chunk, unprocessed_ranges, should_stop, connected
        ):
            items_yielded += 1
            yield key, chunk.shard_url, image_data, {"_chunk_relative_index": relative_idx}

        logger.info(
            f"WebDataset chunk {chunk.chunk_id}: yielded {items_yielded} items "
            f"from ranges {unprocessed_ranges}"
        )
@@ -1,83 +0,0 @@
1
- """Shard tracking using CheckpointTracker base class."""
2
-
3
- from pathlib import Path
4
- from typing import Dict, Any, List, Set
5
-
6
- from .checkpoint_tracker import CheckpointTracker
7
-
8
-
9
class ShardTracker(CheckpointTracker):
    """Track which shards are fully processed and which are partially done.

    State is persisted through the ``CheckpointTracker`` base class;
    ``mark_complete`` and ``update_partial`` save the checkpoint immediately.
    """

    def __init__(self, checkpoint_path: Path):
        """Initialize shard tracker with checkpoint file.

        The state attributes are created *before* ``super().__init__``. This
        presumably ensures they exist if the base constructor loads an existing
        checkpoint via ``_deserialize_state``; confirm in CheckpointTracker.
        """
        # Names of shards that have been fully processed.
        self.completed_shards: Set[str] = set()
        # shard name -> {"keys": [processed keys], "count": len(keys)}
        self.partial_shards: Dict[str, Dict[str, Any]] = {}
        super().__init__(checkpoint_path)

    def _get_default_state(self) -> Dict[str, Any]:
        """Return default state structure for new checkpoints."""
        return {"completed_shards": [], "partial_shards": {}}

    def _deserialize_state(self, data: Dict[str, Any]) -> None:
        """Load instance state from checkpoint data (tolerates missing/null fields)."""
        self.completed_shards = set(data.get("completed_shards") or [])
        self.partial_shards = data.get("partial_shards") or {}

    def _serialize_state(self) -> Dict[str, Any]:
        """Serialize instance state for saving.

        Completed shards are sorted so the checkpoint content is deterministic
        rather than depending on set iteration order.
        """
        return {
            "completed_shards": sorted(self.completed_shards),
            "partial_shards": self.partial_shards,
        }

    def mark_complete(self, shard_name: str) -> None:
        """Mark *shard_name* complete, drop any partial progress, and save."""
        self.completed_shards.add(shard_name)
        self.partial_shards.pop(shard_name, None)
        self.save()

    def update_partial(self, shard_name: str, processed_keys: List[str]) -> None:
        """Record the keys processed so far for *shard_name* and save.

        A copy of *processed_keys* is stored. If the caller later mutates its
        list, the tracked keys cannot drift out of sync with the stored count.
        """
        keys = list(processed_keys)
        self.partial_shards[shard_name] = {"keys": keys, "count": len(keys)}
        self.save()

    def get_processed_keys(self, shard_name: str) -> Set[str]:
        """Get set of processed keys for a shard.

        Returns an empty set both for completed shards and for shards with no
        recorded progress; use ``is_complete`` to tell those cases apart.
        """
        if shard_name in self.completed_shards:
            return set()  # All done

        entry = self.partial_shards.get(shard_name)
        return set(entry.get("keys", [])) if entry else set()

    def is_complete(self, shard_name: str) -> bool:
        """Check if a shard is complete."""
        return shard_name in self.completed_shards

    def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
        """Return the entries of *all_shards* whose file stem is not completed.

        Order is preserved. Completion is matched on ``Path(s).stem`` (e.g.
        ``"data/shard-0001.tar"`` -> ``"shard-0001"``).
        """
        return [s for s in all_shards if Path(s).stem not in self.completed_shards]

    def get_stats(self) -> Dict[str, Any]:
        """Return base checkpoint stats extended with shard counts and partial-key total."""
        base_stats = super().get_stats()
        base_stats.update(
            {
                "completed_shards": len(self.completed_shards),
                "partial_shards": len(self.partial_shards),
                "total_partial_keys": sum(
                    len(data.get("keys", [])) for data in self.partial_shards.values()
                ),
            }
        )
        return base_stats
@@ -1,35 +0,0 @@
1
- caption_flow/__init__.py,sha256=NLPJ25lRN7xHqncXweINDNwbt0q8lgjZ30G21zlPdRs,303
2
- caption_flow/cli.py,sha256=qEueeJhf3DvxSBxnOp5t32p6gAnZskvIDe6cwtPA0-Y,28892
3
- caption_flow/models.py,sha256=bpr7yMy3vPErZCQwmgOYIix489rRGbT6lVw8wxxwTkc,4931
4
- caption_flow/monitor.py,sha256=bAt9EJqfPgT_KdbknGdCxwBRH002pRDgyUmYIj6Dyso,7885
5
- caption_flow/orchestrator.py,sha256=ciqWghxUxk-5s6u7W3JwD7_JLSFYV57NgOwiMkxME-I,36133
6
- caption_flow/storage.py,sha256=Wqgtsk6yZ9Kf-izeUKHLwSvPUH3xFqIbzox20QHbc64,43370
7
- caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
8
- caption_flow/processors/base.py,sha256=JlTqCHo5HRXrXMVzgle_6pNwh4HGHsF7jLF6PeSnWr0,6783
9
- caption_flow/processors/huggingface.py,sha256=MNz9vDMtrrTOSXe9Q_kbBrQ7XBv69X6x5xD_QP9icdg,33765
10
- caption_flow/processors/local_filesystem.py,sha256=EYmsImbkqsIU7UZL2FijL0hotKLtPOtkzfwernQDSxA,27860
11
- caption_flow/processors/webdataset.py,sha256=xsrYx7_5FCqez30dc4hSDYfyA9A0oKqHqwt7CRc1J0c,33812
12
- caption_flow/utils/__init__.py,sha256=F1BChVoCsj9zn1GJRBOLHET1kLW6xrAmsbzcR7hHy6Y,202
13
- caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
14
- caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
15
- caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
16
- caption_flow/utils/checkpoint_tracker.py,sha256=-nN5gLvXyMdKOCT2SNNL2Km6UYm2Hii9wuXeezWhwx4,3339
17
- caption_flow/utils/chunk_tracker.py,sha256=x9UwFxpj-nMeAJ6bpKw5E09QNUqu7L0pejTlk8nxgE8,19402
18
- caption_flow/utils/dataset_loader.py,sha256=2-SgXPGQkF4CyA3zyVYfSbZMSk4YzTsVFY0izmOZPrM,8771
19
- caption_flow/utils/dataset_metadata_cache.py,sha256=AJ8Z1GYT0DC9_LLjxNvrePKU7ecenNZun5GhaB2gvj0,2650
20
- caption_flow/utils/image_processor.py,sha256=7Ed92iUJ-OvjzQmAGPaULoYEqoirVHHo0lxtceWGc44,5586
21
- caption_flow/utils/job_queue.py,sha256=itdfXcrkvGjmXn4qtpgMF63k1ufRBaejDe4V6WcxzgU,1104
22
- caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
23
- caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
24
- caption_flow/utils/shard_processor.py,sha256=_PCW5TfSHFfCc63Sn7bVzgjA625-aWzL4cWwZLjW0rQ,3935
25
- caption_flow/utils/shard_tracker.py,sha256=1OqiueaC8WoxhY2nc03erZAc50mnQCZazATS6R14lbQ,3029
26
- caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
27
- caption_flow/workers/base.py,sha256=2AGWERC5hbmO-0V_A1MUbgRVvRNN3blqGPyDokvvzmM,7575
28
- caption_flow/workers/caption.py,sha256=_uvpdoBzym1TKWKXtky7hBfj8YnG1EaJz-NRwaH2X1A,36722
29
- caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
30
- caption_flow-0.2.3.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
31
- caption_flow-0.2.3.dist-info/METADATA,sha256=bk5Gk3eWuDH_UWXPEDKulksPc3hVHvnzm3sstLbuU-0,11914
32
- caption_flow-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
33
- caption_flow-0.2.3.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
34
- caption_flow-0.2.3.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
35
- caption_flow-0.2.3.dist-info/RECORD,,