caption-flow 0.2.2-py3-none-any.whl → 0.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caption_flow/cli.py CHANGED
@@ -161,6 +161,7 @@ def main(ctx, verbose: bool):
  @click.option("--key", help="SSL key path")
  @click.option("--no-ssl", is_flag=True, help="Disable SSL (development only)")
  @click.option("--vllm", is_flag=True, help="Use vLLM orchestrator for WebDataset/HF datasets")
+ @click.option("--verbose", is_flag=True, help="Enable verbose logging")
  @click.pass_context
  def orchestrator(ctx, config: Optional[str], **kwargs):
      """Start the orchestrator server."""
@@ -366,6 +367,63 @@ def monitor(
          sys.exit(1)


+ # Add this command after the export command in cli.py
+
+
+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option("--refresh-rate", default=10, type=int, help="Display refresh rate (Hz)")
+ @click.option("--no-images", is_flag=True, help="Disable image preview")
+ @click.pass_context
+ def view(ctx, data_dir: str, refresh_rate: int, no_images: bool):
+     """Browse captioned dataset with interactive TUI viewer."""
+     from .viewer import DatasetViewer
+
+     data_path = Path(data_dir)
+
+     if not data_path.exists():
+         console.print(f"[red]Storage directory not found: {data_dir}[/red]")
+         sys.exit(1)
+
+     if not (data_path / "captions.parquet").exists():
+         console.print(f"[red]No captions file found in {data_dir}[/red]")
+         console.print("[yellow]Have you exported any captions yet?[/yellow]")
+         sys.exit(1)
+
+     # Check for term-image if images are enabled
+     if not no_images:
+         try:
+             import term_image
+         except ImportError:
+             console.print("[yellow]Warning: term-image not installed[/yellow]")
+             console.print("Install with: pip install term-image")
+             console.print("Running without image preview...")
+             no_images = True
+
+     try:
+         viewer = DatasetViewer(data_path)
+         if no_images:
+             viewer.disable_images = True
+         viewer.refresh_rate = refresh_rate
+
+         console.print(f"[cyan]Starting dataset viewer...[/cyan]")
+         console.print(f"[dim]Data directory: {data_path}[/dim]")
+
+         asyncio.run(viewer.run())
+
+     except FileNotFoundError as e:
+         console.print(f"[red]Error: {e}[/red]")
+         sys.exit(1)
+     except KeyboardInterrupt:
+         console.print("\n[yellow]Viewer closed[/yellow]")
+     except Exception as e:
+         console.print(f"[red]Error: {e}[/red]")
+         import traceback
+
+         traceback.print_exc()
+         sys.exit(1)
+
+
  @main.command()
  @click.option("--config", type=click.Path(exists=True), help="Configuration file")
  @click.option("--server", help="Orchestrator WebSocket URL")
@@ -635,6 +693,256 @@ def scan_chunks(data_dir: str, checkpoint_dir: str, fix: bool, verbose: bool):
      tracker.save_checkpoint()


+ @main.command()
+ @click.option("--data-dir", default="./caption_data", help="Storage directory")
+ @click.option(
+     "--format",
+     type=click.Choice(
+         ["jsonl", "json", "csv", "txt", "huggingface_hub", "all"], case_sensitive=False
+     ),
+     default="jsonl",
+     help="Export format (default: jsonl)",
+ )
+ @click.option("--output", "-o", help="Output path (file for jsonl/csv, directory for json/txt)")
+ @click.option("--limit", type=int, help="Limit number of rows to export")
+ @click.option("--columns", help="Comma-separated list of columns to export (default: all)")
+ @click.option("--export-column", default="captions", help="Column to export for txt format")
+ @click.option("--filename-column", default="filename", help="Column containing filenames")
+ @click.option("--include-empty", is_flag=True, help="Include rows with empty export column")
+ @click.option("--stats-only", is_flag=True, help="Show statistics without exporting")
+ @click.option(
+     "--optimize", is_flag=True, help="Optimize storage before export (remove empty columns)"
+ )
+ @click.option("--verbose", is_flag=True, help="Show detailed export progress")
+ @click.option("--hf-dataset", help="Dataset name on HF Hub (e.g., username/dataset-name)")
+ @click.option("--license", help="License for the dataset (required for new HF datasets)")
+ @click.option("--private", is_flag=True, help="Make HF dataset private")
+ @click.option("--nsfw", is_flag=True, help="Add not-for-all-audiences tag")
+ @click.option("--tags", help="Comma-separated tags for HF dataset")
+ def export(
+     data_dir: str,
+     format: str,
+     output: Optional[str],
+     limit: Optional[int],
+     columns: Optional[str],
+     export_column: str,
+     filename_column: str,
+     include_empty: bool,
+     stats_only: bool,
+     optimize: bool,
+     verbose: bool,
+     hf_dataset: Optional[str],
+     license: Optional[str],
+     private: bool,
+     nsfw: bool,
+     tags: Optional[str],
+ ):
+     """Export caption data to various formats."""
+     from .storage import StorageManager
+     from .storage.exporter import StorageExporter, ExportError
+
+     # Initialize storage manager
+     storage_path = Path(data_dir)
+     if not storage_path.exists():
+         console.print(f"[red]Storage directory not found: {data_dir}[/red]")
+         sys.exit(1)
+
+     storage = StorageManager(storage_path)
+
+     async def run_export():
+         await storage.initialize()
+
+         # Show statistics first
+         stats = await storage.get_caption_stats()
+         console.print("\n[bold cyan]Storage Statistics:[/bold cyan]")
+         console.print(f"[green]Total rows:[/green] {stats['total_rows']:,}")
+         console.print(f"[green]Total outputs:[/green] {stats['total_outputs']:,}")
+         console.print(f"[green]Output fields:[/green] {', '.join(stats['output_fields'])}")
+
+         if stats.get("field_stats"):
+             console.print("\n[cyan]Field breakdown:[/cyan]")
+             for field, field_stat in stats["field_stats"].items():
+                 console.print(
+                     f" • {field}: {field_stat['total_items']:,} items "
+                     f"in {field_stat['rows_with_data']:,} rows"
+                 )
+
+         if stats_only:
+             return
+
+         # Optimize storage if requested
+         if optimize:
+             console.print("\n[yellow]Optimizing storage (removing empty columns)...[/yellow]")
+             await storage.optimize_storage()
+
+         # Prepare columns list
+         column_list = None
+         if columns:
+             column_list = [col.strip() for col in columns.split(",")]
+             console.print(f"\n[cyan]Exporting columns:[/cyan] {', '.join(column_list)}")
+
+         # Get storage contents
+         console.print("\n[yellow]Loading data...[/yellow]")
+         try:
+             contents = await storage.get_storage_contents(
+                 limit=limit, columns=column_list, include_metadata=True
+             )
+         except ValueError as e:
+             console.print(f"[red]Error: {e}[/red]")
+             sys.exit(1)
+
+         if not contents.rows:
+             console.print("[yellow]No data to export![/yellow]")
+             return
+
+         # Filter out empty rows if not including empty
+         if not include_empty and format in ["txt", "json"]:
+             original_count = len(contents.rows)
+             contents.rows = [
+                 row
+                 for row in contents.rows
+                 if row.get(export_column)
+                 and (not isinstance(row[export_column], list) or len(row[export_column]) > 0)
+             ]
+             filtered_count = original_count - len(contents.rows)
+             if filtered_count > 0:
+                 console.print(f"[dim]Filtered {filtered_count} empty rows[/dim]")
+
+         # Create exporter
+         exporter = StorageExporter(contents)
+
+         # Determine output paths
+         if format == "all":
+             # Export to all formats
+             base_name = output or "caption_export"
+             base_path = Path(base_name)
+
+             formats_exported = []
+
+             # JSONL
+             jsonl_path = base_path.with_suffix(".jsonl")
+             console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {jsonl_path}")
+             rows = exporter.to_jsonl(jsonl_path)
+             formats_exported.append(f"JSONL: {rows:,} rows")
+
+             # CSV
+             csv_path = base_path.with_suffix(".csv")
+             console.print(f"[cyan]Exporting to CSV:[/cyan] {csv_path}")
+             try:
+                 rows = exporter.to_csv(csv_path)
+                 formats_exported.append(f"CSV: {rows:,} rows")
+             except ExportError as e:
+                 console.print(f"[yellow]Skipping CSV: {e}[/yellow]")
+
+             # JSON files
+             json_dir = base_path.parent / f"{base_path.stem}_json"
+             console.print(f"[cyan]Exporting to JSON files:[/cyan] {json_dir}/")
+             try:
+                 files = exporter.to_json(json_dir, filename_column)
+                 formats_exported.append(f"JSON: {files:,} files")
+             except ExportError as e:
+                 console.print(f"[yellow]Skipping JSON files: {e}[/yellow]")
+
+             # Text files
+             txt_dir = base_path.parent / f"{base_path.stem}_txt"
+             console.print(f"[cyan]Exporting to text files:[/cyan] {txt_dir}/")
+             try:
+                 files = exporter.to_txt(txt_dir, filename_column, export_column)
+                 formats_exported.append(f"Text: {files:,} files")
+             except ExportError as e:
+                 console.print(f"[yellow]Skipping text files: {e}[/yellow]")
+
+             console.print(f"\n[green]✓ Export complete![/green]")
+             for fmt in formats_exported:
+                 console.print(f" • {fmt}")
+
+         else:
+             # Single format export
+             try:
+                 if format == "jsonl":
+                     output_path = output or "captions.jsonl"
+                     console.print(f"\n[cyan]Exporting to JSONL:[/cyan] {output_path}")
+                     rows = exporter.to_jsonl(output_path)
+                     console.print(f"[green]✓ Exported {rows:,} rows[/green]")
+
+                 elif format == "csv":
+                     output_path = output or "captions.csv"
+                     console.print(f"\n[cyan]Exporting to CSV:[/cyan] {output_path}")
+                     rows = exporter.to_csv(output_path)
+                     console.print(f"[green]✓ Exported {rows:,} rows[/green]")
+
+                 elif format == "json":
+                     output_dir = output or "./json_output"
+                     console.print(f"\n[cyan]Exporting to JSON files:[/cyan] {output_dir}/")
+                     files = exporter.to_json(output_dir, filename_column)
+                     console.print(f"[green]✓ Created {files:,} JSON files[/green]")
+
+                 elif format == "txt":
+                     output_dir = output or "./txt_output"
+                     console.print(f"\n[cyan]Exporting to text files:[/cyan] {output_dir}/")
+                     console.print(f"[dim]Export column: {export_column}[/dim]")
+                     files = exporter.to_txt(output_dir, filename_column, export_column)
+                     console.print(f"[green]✓ Created {files:,} text files[/green]")
+
+                 elif format == "huggingface_hub":
+                     # Validate required parameters
+                     if not hf_dataset:
+                         console.print(
+                             "[red]Error: --hf-dataset required for huggingface_hub format[/red]"
+                         )
+                         console.print(
+                             "[dim]Example: --hf-dataset username/my-caption-dataset[/dim]"
+                         )
+                         sys.exit(1)
+
+                     # Parse tags
+                     tag_list = None
+                     if tags:
+                         tag_list = [tag.strip() for tag in tags.split(",")]
+
+                     console.print(f"\n[cyan]Uploading to Hugging Face Hub:[/cyan] {hf_dataset}")
+                     if private:
+                         console.print("[dim]Privacy: Private dataset[/dim]")
+                     if nsfw:
+                         console.print("[dim]Content: Not for all audiences[/dim]")
+                     if tag_list:
+                         console.print(f"[dim]Tags: {', '.join(tag_list)}[/dim]")
+
+                     url = exporter.to_huggingface_hub(
+                         dataset_name=hf_dataset,
+                         license=license,
+                         private=private,
+                         nsfw=nsfw,
+                         tags=tag_list,
+                     )
+                     console.print(f"[green]✓ Dataset uploaded to: {url}[/green]")
+
+             except ExportError as e:
+                 console.print(f"[red]Export error: {e}[/red]")
+                 sys.exit(1)
+
+         # Show export metadata
+         if verbose and contents.metadata:
+             console.print("\n[dim]Export metadata:[/dim]")
+             console.print(f" Timestamp: {contents.metadata.get('export_timestamp')}")
+             console.print(f" Total available: {contents.metadata.get('total_available_rows'):,}")
+             console.print(f" Rows exported: {contents.metadata.get('rows_exported'):,}")
+
+     # Run the async export
+     try:
+         asyncio.run(run_export())
+     except KeyboardInterrupt:
+         console.print("\n[yellow]Export cancelled[/yellow]")
+         sys.exit(1)
+     except Exception as e:
+         console.print(f"[red]Unexpected error: {e}[/red]")
+         if verbose:
+             import traceback
+
+             traceback.print_exc()
+         sys.exit(1)
+
+
  @main.command()
  @click.option("--domain", help="Domain for Let's Encrypt certificate")
  @click.option("--email", help="Email for Let's Encrypt registration")
caption_flow/models.py CHANGED
@@ -1,9 +1,11 @@
  """Data models for CaptionFlow."""

+ import PIL
  from dataclasses import dataclass, field
  from datetime import datetime
  from enum import Enum
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Tuple
+ from PIL import Image


  class JobStatus(Enum):
@@ -38,6 +40,38 @@ class Job:
          self.created_at = datetime.utcnow()


+ @dataclass
+ class JobId:
+     shard_id: str
+     chunk_id: str
+     sample_id: str
+
+     def get_shard_str(self):
+         return f"{self.shard_id}"
+
+     def get_chunk_str(self):
+         return f"{self.shard_id}:chunk:{self.chunk_id}"
+
+     def get_sample_str(self):
+         return f"{self.shard_id}:chunk:{self.chunk_id}:idx:{self.sample_id}"
+
+     @staticmethod
+     def from_dict(job: dict) -> "JobId":
+         return JobId(shard_id=job["shard_id"], chunk_id=job["chunk_id"], sample_id=job["sample_id"])
+
+     @staticmethod
+     def from_values(shard_id: str, chunk_id: str, sample_id: str) -> "JobId":
+         return JobId(shard_id=shard_id, chunk_id=chunk_id, sample_id=sample_id)
+
+     @staticmethod
+     def from_str(job_id: str):
+         # from data-0000:chunk:0:idx:0
+         parts = job_id.split(":")
+         if len(parts) != 5:
+             raise ValueError(f"Invalid job_id format: {job_id}")
+         return JobId(shard_id=parts[0], chunk_id=parts[2], sample_id=parts[4])
+
+
  @dataclass
  class Caption:
      """Generated caption with attribution and image metadata."""
@@ -61,6 +95,8 @@ class Caption:
      image_height: Optional[int] = None
      image_format: Optional[str] = None
      file_size: Optional[int] = None
+     filename: Optional[str] = None
+     url: Optional[str] = None

      # Processing metadata
      caption_index: Optional[int] = None  # Which caption this is (0, 1, 2...)
@@ -82,3 +118,100 @@ class Contributor:
      name: str
      total_captions: int = 0
      trust_level: int = 1
+
+
+ @dataclass
+ class ProcessingStage:
+     """Configuration for a single processing stage."""
+
+     name: str
+     model: str
+     prompts: List[str]
+     output_field: str
+     requires: List[str] = field(default_factory=list)
+     sampling: Optional[Dict[str, Any]] = None
+
+     # Model-specific overrides
+     tensor_parallel_size: Optional[int] = None
+     max_model_len: Optional[int] = None
+     dtype: Optional[str] = None
+     gpu_memory_utilization: Optional[float] = None
+
+
+ @dataclass
+ class StageResult:
+     """Results from a single stage."""
+
+     stage_name: str
+     output_field: str
+     outputs: List[str]  # Multiple outputs from multiple prompts
+     error: Optional[str] = None
+
+     def is_success(self) -> bool:
+         return self.error is None and bool(self.outputs)
+
+
+ @dataclass
+ class ShardChunk:
+     """Shard chunk assignment with unprocessed ranges."""
+
+     chunk_id: str
+     shard_url: str
+     shard_name: str
+     start_index: int
+     chunk_size: int
+     unprocessed_ranges: List[Tuple[int, int]] = field(default_factory=list)
+
+
+ @dataclass
+ class ProcessingItem:
+     """Item being processed."""
+
+     chunk_id: str
+     item_key: str
+     image: Image.Image
+     image_data: bytes
+     metadata: Dict[str, Any] = field(default_factory=dict)
+     stage_results: Dict[str, StageResult] = field(default_factory=dict)  # Accumulated results
+
+
+ @dataclass
+ class ProcessedResult:
+     """Result with multi-stage outputs."""
+
+     chunk_id: str
+     shard_name: str
+     item_key: str
+     outputs: Dict[str, List[str]]  # field_name -> list of outputs
+     image_width: int
+     image_height: int
+     image_format: str
+     file_size: int
+     processing_time_ms: float
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+ @dataclass
+ class StorageContents:
+     """Container for storage data to be exported."""
+
+     rows: List[Dict[str, Any]]
+     columns: List[str]
+     output_fields: List[str]
+     total_rows: int
+     metadata: Dict[str, Any] = field(default_factory=dict)
+
+     def __post_init__(self):
+         """Validate data consistency."""
+         if self.rows and self.columns:
+             # Ensure all rows have the expected columns
+             for row in self.rows:
+                 missing_cols = set(self.columns) - set(row.keys())
+                 if missing_cols:
+                     logger.warning(f"Row missing columns: {missing_cols}")
+
+
+ class ExportError(Exception):
+     """Base exception for export-related errors."""
+
+     pass
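The new dataclasses describe a multi-stage captioning pipeline: each `ProcessingStage` names a model and its prompts, and a `StageResult` carries that stage's outputs or an error. An illustrative construction with made-up values (the model id and strings are placeholders):

```python
from caption_flow.models import ProcessingStage, StageResult

stage = ProcessingStage(
    name="caption",
    model="Qwen/Qwen2-VL-7B-Instruct",  # illustrative model id
    prompts=["Describe this image."],
    output_field="captions",
)

ok = StageResult(stage_name="caption", output_field="captions",
                 outputs=["a cat sleeping on a windowsill"])
failed = StageResult(stage_name="caption", output_field="captions",
                     outputs=[], error="CUDA out of memory")

# is_success requires no error and at least one output
assert ok.is_success() and not failed.is_success()
```

Note that `StorageContents.__post_init__` calls `logger.warning`, so the module presumably defines a `logger` outside the hunks shown here.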
caption_flow/monitor.py CHANGED
@@ -83,7 +83,7 @@ class Monitor:
                  await self._handle_update(data)

          except Exception as e:
-             logger.error(f"Connection error: {e}")
+             logger.error(f"Connection error: {e}", exc_info=True)
              await asyncio.sleep(5)

      async def _handle_update(self, data: Dict):
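Passing `exc_info=True` makes the standard-library logger append the active exception's traceback to the log record, which is what this change buys over the bare message. A self-contained sketch of the behavior:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

try:
    raise ConnectionError("socket closed")
except Exception as e:
    # Logs the message plus the full traceback of the current exception
    logger.error(f"Connection error: {e}", exc_info=True)
```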