caption-flow 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,550 @@
+ """Storage exporter for converting Parquet data to various formats."""
+
+ import json
+ import csv
+ from pathlib import Path
+ from typing import List, Dict, Any, Optional, Union
+ from dataclasses import dataclass, field
+ import logging
+ import pandas as pd
+ import numpy as np
+ from ..models import StorageContents, ExportError
+
+ logger = logging.getLogger(__name__)
+
+
+ class StorageExporter:
+     """Exports StorageContents to various formats."""
+
+     def __init__(self, contents: StorageContents):
+         """Initialize exporter with storage contents.
+
+         Args:
+             contents: StorageContents instance to export
+         """
+         self.contents = contents
+         self._validate_contents()
+
+     def _validate_contents(self):
+         """Validate that contents are suitable for export."""
+         if not self.contents.rows:
+             logger.warning("No rows to export")
+         if not self.contents.columns:
+             raise ExportError("No columns defined for export")
+
+     def _flatten_lists(self, value: Any) -> str:
+         """Convert list values to newline-separated strings."""
+         if isinstance(value, list):
+             # Strip newlines from each element and join
+             return "\n".join(str(item).replace("\n", " ") for item in value)
+         return str(value) if value is not None else ""
+
+     def _serialize_value(self, value: Any) -> Any:
+         """Convert values to JSON-serializable format."""
+         if pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(value, pd.Timestamp):
+             return value.isoformat()
+         elif isinstance(value, np.integer):
+             return int(value)
+         elif isinstance(value, np.floating):
+             return float(value)
+         elif isinstance(value, np.ndarray):
+             return value.tolist()
+         elif isinstance(value, dict):
+             return {k: self._serialize_value(v) for k, v in value.items()}
+         elif isinstance(value, list):
+             return [self._serialize_value(item) for item in value]
+         return value
+
+     def to_jsonl(self, output_path: Union[str, Path]) -> int:
+         """Export to JSONL (JSON Lines) format.
+
+         Args:
+             output_path: Path to output JSONL file
+
+         Returns:
+             Number of rows exported
+         """
+         output_path = Path(output_path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         rows_written = 0
+         with open(output_path, "w", encoding="utf-8") as f:
+             for row in self.contents.rows:
+                 # Convert non-serializable values
+                 serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
+                 # Write each row as a JSON object on its own line
+                 json_line = json.dumps(serializable_row, ensure_ascii=False)
+                 f.write(json_line + "\n")
+                 rows_written += 1
+
+         logger.info(f"Exported {rows_written} rows to JSONL: {output_path}")
+         return rows_written
+
+     def _get_filename_from_row(self, row: Dict[str, Any], filename_column: str) -> Optional[str]:
+         """Extract filename from row, falling back to URL if needed."""
+         # Try the specified filename column first
+         filename = row.get(filename_column)
+         if filename:
+             return filename
+
+         # Fall back to URL if available
+         url = row.get("url")
+         if url:
+             # Extract filename from URL path
+             from urllib.parse import urlparse
+
+             parsed = urlparse(str(url))
+             path_parts = parsed.path.rstrip("/").split("/")
+             if path_parts and path_parts[-1]:
+                 return path_parts[-1]
+
+         return None
+
+     def to_json(self, output_dir: Union[str, Path], filename_column: str = "filename") -> int:
+         """Export to individual JSON files based on filename column.
+
+         Args:
+             output_dir: Directory to write JSON files
+             filename_column: Column containing the base filename
+
+         Returns:
+             Number of files created
+         """
+         output_dir = Path(output_dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # Check if we need to fall back to URL
+         using_url_fallback = False
+         if filename_column not in self.contents.columns and "url" in self.contents.columns:
+             logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
+             using_url_fallback = True
+         elif filename_column not in self.contents.columns:
+             raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
+
+         files_created = 0
+         skipped_count = 0
+
+         for row in self.contents.rows:
+             filename = self._get_filename_from_row(row, filename_column)
+             if not filename:
+                 skipped_count += 1
+                 logger.warning(f"Skipping row with no extractable filename")
+                 continue
+
+             # Create JSON filename from original filename
+             base_name = Path(filename).stem
+             json_path = output_dir / f"{base_name}.json"
+
+             # Convert non-serializable values
+             serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
+
+             # Write row data as JSON
+             with open(json_path, "w", encoding="utf-8") as f:
+                 json.dump(serializable_row, f, ensure_ascii=False, indent=2)
+
+             files_created += 1
+
+         if skipped_count > 0:
+             logger.warning(f"Skipped {skipped_count} rows with no extractable filename")
+
+         logger.info(f"Created {files_created} JSON files in: {output_dir}")
+         return files_created
+
+     def to_csv(self, output_path: Union[str, Path]) -> int:
+         """Export to CSV format, skipping complex columns.
+
+         Args:
+             output_path: Path to output CSV file
+
+         Returns:
+             Number of rows exported
+         """
+         output_path = Path(output_path)
+         output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         # Identify complex columns to skip
+         complex_columns = set()
+         csv_safe_columns = []
+
+         # Check column types by sampling data
+         sample_size = min(10, len(self.contents.rows))
+         for row in self.contents.rows[:sample_size]:
+             for col, value in row.items():
+                 if col not in complex_columns and value is not None:
+                     # Skip dictionaries and non-output field lists
+                     if isinstance(value, dict):
+                         complex_columns.add(col)
+                         logger.warning(
+                             f"Column '{col}' contains dict type and will be skipped. "
+                             "Consider using JSONL format for complete data export."
+                         )
+                     elif isinstance(value, list) and col not in self.contents.output_fields:
+                         complex_columns.add(col)
+                         logger.warning(
+                             f"Column '{col}' contains list type and will be skipped. "
+                             "Consider using JSONL format for complete data export."
+                         )
+
+         # Build list of CSV-safe columns
+         csv_safe_columns = [col for col in self.contents.columns if col not in complex_columns]
+
+         if not csv_safe_columns:
+             raise ExportError("No columns suitable for CSV export. Use JSONL format instead.")
+
+         # Prepare rows for CSV export with safe columns only
+         csv_rows = []
+         for row in self.contents.rows:
+             csv_row = {}
+             for col in csv_safe_columns:
+                 value = row.get(col)
+                 # Handle list values (like captions) by joining with newlines
+                 if isinstance(value, list):
+                     csv_row[col] = self._flatten_lists(value)
+                 elif pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(
+                     value, pd.Timestamp
+                 ):
+                     csv_row[col] = self._serialize_value(value)
+                 else:
+                     csv_row[col] = value
+             csv_rows.append(csv_row)
+
+         # Write to CSV
+         with open(output_path, "w", encoding="utf-8", newline="") as f:
+             writer = csv.DictWriter(f, fieldnames=csv_safe_columns)
+             writer.writeheader()
+             writer.writerows(csv_rows)
+
+         # Log results
+         if complex_columns:
+             skipped_msg = f"Skipped {len(complex_columns)} complex columns: {', '.join(sorted(complex_columns))}"
+             logger.warning(skipped_msg)
+
+         logger.info(
+             f"Exported {len(csv_rows)} rows to CSV: {output_path} "
+             f"(with {len(csv_safe_columns)}/{len(self.contents.columns)} columns)"
+         )
+
+         return len(csv_rows)
+
+     def to_txt(
+         self,
+         output_dir: Union[str, Path],
+         filename_column: str = "filename",
+         export_column: str = "captions",
+     ) -> int:
+         """Export specific column to individual text files.
+
+         Args:
+             output_dir: Directory to write text files
+             filename_column: Column containing the base filename
+             export_column: Column to export to text files
+
+         Returns:
+             Number of files created
+         """
+         output_dir = Path(output_dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # Check if we need to fall back to URL
+         using_url_fallback = False
+         if filename_column not in self.contents.columns and "url" in self.contents.columns:
+             logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
+             using_url_fallback = True
+         elif filename_column not in self.contents.columns:
+             raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
+
+         if export_column not in self.contents.columns:
+             # Check if it's an output field
+             if export_column not in self.contents.output_fields:
+                 raise ExportError(f"Column '{export_column}' not found in data")
+
+         files_created = 0
+         skipped_no_filename = 0
+         skipped_no_content = 0
+
+         for row in self.contents.rows:
+             filename = self._get_filename_from_row(row, filename_column)
+             if not filename:
+                 skipped_no_filename += 1
+                 logger.warning(f"Skipping row with no extractable filename")
+                 continue
+
+             content = row.get(export_column)
+             if content is None:
+                 skipped_no_content += 1
+                 logger.warning(f"No {export_column} for {filename}")
+                 continue
+
+             # Create text filename from original filename
+             base_name = Path(filename).stem
+             txt_path = output_dir / f"{base_name}.txt"
+
+             # Write content
+             with open(txt_path, "w", encoding="utf-8") as f:
+                 f.write(self._flatten_lists(content))
+
+             files_created += 1
+
+         if skipped_no_filename > 0:
+             logger.warning(f"Skipped {skipped_no_filename} rows with no extractable filename")
+         if skipped_no_content > 0:
+             logger.warning(f"Skipped {skipped_no_content} rows with no {export_column} content")
+
+         logger.info(f"Created {files_created} text files in: {output_dir}")
+         return files_created
+
+     def to_huggingface_hub(
+         self,
+         dataset_name: str,
+         token: Optional[str] = None,
+         license: Optional[str] = None,
+         private: bool = False,
+         nsfw: bool = False,
+         tags: Optional[List[str]] = None,
+         language: str = "en",
+         task_categories: Optional[List[str]] = None,
+     ) -> str:
+         """Export to Hugging Face Hub as a dataset.
+
+         Args:
+             dataset_name: Name for the dataset (e.g., "username/dataset-name")
+             token: Hugging Face API token
+             license: License for the dataset (required for new repos)
+             private: Whether to make the dataset private
+             nsfw: Whether to add not-for-all-audiences tag
+             tags: Additional tags for the dataset
+             language: Language code (default: "en")
+             task_categories: Task categories (default: ["text-to-image", "image-to-image"])
+
+         Returns:
+             URL of the uploaded dataset
+         """
+         try:
+             from huggingface_hub import HfApi, DatasetCard, create_repo
+             import pyarrow as pa
+             import pyarrow.parquet as pq
+         except ImportError:
+             raise ExportError(
+                 "huggingface_hub and pyarrow are required for HF export. "
+                 "Install with: pip install huggingface_hub pyarrow"
+             )
+
+         # Initialize HF API
+         api = HfApi(token=token)
+
+         # Check if repo exists
+         repo_exists = False
+         try:
+             api.dataset_info(dataset_name)
+             repo_exists = True
+             logger.info(f"Dataset {dataset_name} already exists, will update it")
+         except:
+             logger.info(f"Creating new dataset: {dataset_name}")
+             if not license:
+                 raise ExportError("License is required when creating a new dataset")
+
+         # Create repo if it doesn't exist
+         if not repo_exists:
+             create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)
+
+         # Prepare data for parquet
+         df = pd.DataFrame(self.contents.rows)
+
+         # Convert any remaining non-serializable types
+         for col in df.columns:
+             if df[col].dtype == "object":
+                 df[col] = df[col].apply(
+                     lambda x: self._serialize_value(x) if x is not None else None
+                 )
+
+         # Determine size category
+         num_rows = len(df)
+         if num_rows < 1000:
+             size_category = "n<1K"
+         elif num_rows < 10000:
+             size_category = "1K<n<10K"
+         elif num_rows < 100000:
+             size_category = "10K<n<100K"
+         elif num_rows < 1000000:
+             size_category = "100K<n<1M"
+         elif num_rows < 10000000:
+             size_category = "1M<n<10M"
+         else:
+             size_category = "n>10M"
+
+         # Prepare tags
+         all_tags = tags or []
+         if nsfw:
+             all_tags.append("not-for-all-audiences")
+
+         # Default task categories
+         if task_categories is None:
+             task_categories = ["text-to-image", "image-to-image"]
+
+         # Create dataset card
+         card_content = f"""---
+ license: {license or 'unknown'}
+ language:
+ - {language}
+ size_categories:
+ - {size_category}
+ task_categories:
+ {self._yaml_list(task_categories)}"""
+
+         if all_tags:
+             card_content += f"\ntags:\n{self._yaml_list(all_tags)}"
+
+         card_content += f"""
+ ---
+
+ # Caption Dataset
+
+ This dataset contains {num_rows:,} captioned items exported from CaptionFlow.
+
+ ## Dataset Structure
+
+ ### Data Fields
+
+ """
+
+         # Add field descriptions
+         for col in df.columns:
+             dtype = str(df[col].dtype)
+             if col in self.contents.output_fields:
+                 card_content += f"- `{col}`: List of captions/outputs\n"
+             else:
+                 card_content += f"- `{col}`: {dtype}\n"
+
+         if self.contents.metadata:
+             card_content += "\n## Export Information\n\n"
+             if "export_timestamp" in self.contents.metadata:
+                 card_content += (
+                     f"- Export timestamp: {self.contents.metadata['export_timestamp']}\n"
+                 )
+             if "field_stats" in self.contents.metadata:
+                 card_content += "\n### Field Statistics\n\n"
+                 for field, stats in self.contents.metadata["field_stats"].items():
+                     card_content += f"- `{field}`: {stats['total_items']:,} items across {stats['rows_with_data']:,} rows\n"
+
+         # Create temporary parquet file
+         import tempfile
+
+         with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp_file:
+             temp_path = Path(tmp_file.name)
+
+         try:
+             # Write parquet file
+             table = pa.Table.from_pandas(df)
+             pq.write_table(table, temp_path, compression="snappy")
+
+             # Upload parquet file
+             api.upload_file(
+                 path_or_fileobj=str(temp_path),
+                 path_in_repo="data.parquet",
+                 repo_id=dataset_name,
+                 repo_type="dataset",
+                 token=token,
+             )
+
+             # Create and upload dataset card
+             card = DatasetCard(card_content)
+             card.push_to_hub(dataset_name, token=token)
+
+             dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
+             logger.info(f"Successfully uploaded dataset to: {dataset_url}")
+
+             return dataset_url
+
+         finally:
+             # Clean up temp file
+             if temp_path.exists():
+                 temp_path.unlink()
+
+     def _yaml_list(self, items: List[str]) -> str:
+         """Format a list for YAML."""
+         return "\n".join(f"- {item}" for item in items)
+
+
+ # Addition to StorageManager class
+ async def get_storage_contents(
+     self,
+     limit: Optional[int] = None,
+     columns: Optional[List[str]] = None,
+     include_metadata: bool = True,
+ ) -> StorageContents:
+     """Retrieve storage contents for export.
+
+     Args:
+         limit: Maximum number of rows to retrieve
+         columns: Specific columns to include (None for all)
+         include_metadata: Whether to include metadata in the result
+
+     Returns:
+         StorageContents instance with the requested data
+     """
+     if not self.captions_path.exists():
+         return StorageContents(
+             rows=[],
+             columns=[],
+             output_fields=list(self.known_output_fields),
+             total_rows=0,
+             metadata={"message": "No captions file found"},
+         )
+
+     # Flush buffers first to ensure all data is on disk
+     await self.checkpoint()
+
+     # Determine columns to read
+     if columns:
+         # Validate requested columns exist
+         table_metadata = pq.read_metadata(self.captions_path)
+         available_columns = set(table_metadata.schema.names)
+         invalid_columns = set(columns) - available_columns
+         if invalid_columns:
+             raise ValueError(f"Columns not found: {invalid_columns}")
+         columns_to_read = columns
+     else:
+         # Read all columns
+         columns_to_read = None
+
+     # Read the table
+     table = pq.read_table(self.captions_path, columns=columns_to_read)
+     df = table.to_pandas()
+
+     # Apply limit if specified
+     if limit:
+         df = df.head(limit)
+
+     # Convert to list of dicts
+     rows = df.to_dict("records")
+
+     # Parse metadata JSON strings back to dicts if present
+     if "metadata" in df.columns:
+         for row in rows:
+             if row.get("metadata"):
+                 try:
+                     row["metadata"] = json.loads(row["metadata"])
+                 except:
+                     pass  # Keep as string if parsing fails
+
+     # Prepare metadata
+     metadata = {}
+     if include_metadata:
+         stats = await self.get_caption_stats()
+         metadata.update(
+             {
+                 "export_timestamp": pd.Timestamp.now().isoformat(),
+                 "total_available_rows": stats.get("total_rows", 0),
+                 "rows_exported": len(rows),
+                 "storage_path": str(self.captions_path),
+                 "field_stats": stats.get("field_stats", {}),
+             }
+         )
+
+     return StorageContents(
+         rows=rows,
+         columns=list(df.columns),
+         output_fields=list(self.known_output_fields),
+         total_rows=len(df),
+         metadata=metadata,
+     )
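
For orientation, here is a minimal usage sketch of the new exporter. It is not part of the diff: the import paths (`caption_flow.models`, `caption_flow.storage.exporter`) and the sample data are assumptions for illustration, while the `StorageContents` fields and exporter methods mirror the code above.

# Usage sketch (assumed import paths and sample data; not part of the published diff).
from caption_flow.models import StorageContents  # shown above as "from ..models import ..."
from caption_flow.storage.exporter import StorageExporter  # hypothetical module path for the new file

contents = StorageContents(
    rows=[{"filename": "cat.jpg", "url": "https://example.com/cat.jpg", "captions": ["a cat"]}],
    columns=["filename", "url", "captions"],
    output_fields=["captions"],
    total_rows=1,
    metadata={},
)

exporter = StorageExporter(contents)
exporter.to_jsonl("exports/captions.jsonl")  # one JSON object per row
exporter.to_csv("exports/captions.csv")      # complex columns are skipped, list fields joined
exporter.to_txt("exports/txt", filename_column="filename", export_column="captions")
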
@@ -15,7 +15,7 @@ from collections import defaultdict, deque
  import time
  import numpy as np
 
- from .models import Job, Caption, Contributor, JobStatus, JobId
+ from ..models import Job, Caption, Contributor, StorageContents, JobId
 
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)
@@ -575,6 +575,7 @@ class StorageManager:
  logger.info("Starting storage optimization...")
 
  # Read the full table
+ backup_path = None
  table = pq.read_table(self.captions_path)
  df = table.to_pandas()
  original_columns = len(df.columns)
@@ -790,6 +791,89 @@ class StorageManager:
 
          return job_ids
 
+     async def get_storage_contents(
+         self,
+         limit: Optional[int] = None,
+         columns: Optional[List[str]] = None,
+         include_metadata: bool = True,
+     ) -> StorageContents:
+         """Retrieve storage contents for export.
+
+         Args:
+             limit: Maximum number of rows to retrieve
+             columns: Specific columns to include (None for all)
+             include_metadata: Whether to include metadata in the result
+
+         Returns:
+             StorageContents instance with the requested data
+         """
+         if not self.captions_path.exists():
+             return StorageContents(
+                 rows=[],
+                 columns=[],
+                 output_fields=list(self.known_output_fields),
+                 total_rows=0,
+                 metadata={"message": "No captions file found"},
+             )
+
+         # Flush buffers first to ensure all data is on disk
+         await self.checkpoint()
+
+         # Determine columns to read
+         if columns:
+             # Validate requested columns exist
+             table_metadata = pq.read_metadata(self.captions_path)
+             available_columns = set(table_metadata.schema.names)
+             invalid_columns = set(columns) - available_columns
+             if invalid_columns:
+                 raise ValueError(f"Columns not found: {invalid_columns}")
+             columns_to_read = columns
+         else:
+             # Read all columns
+             columns_to_read = None
+
+         # Read the table
+         table = pq.read_table(self.captions_path, columns=columns_to_read)
+         df = table.to_pandas()
+
+         # Apply limit if specified
+         if limit:
+             df = df.head(limit)
+
+         # Convert to list of dicts
+         rows = df.to_dict("records")
+
+         # Parse metadata JSON strings back to dicts if present
+         if "metadata" in df.columns:
+             for row in rows:
+                 if row.get("metadata"):
+                     try:
+                         row["metadata"] = json.loads(row["metadata"])
+                     except:
+                         pass  # Keep as string if parsing fails
+
+         # Prepare metadata
+         metadata = {}
+         if include_metadata:
+             stats = await self.get_caption_stats()
+             metadata.update(
+                 {
+                     "export_timestamp": pd.Timestamp.now().isoformat(),
+                     "total_available_rows": stats.get("total_rows", 0),
+                     "rows_exported": len(rows),
+                     "storage_path": str(self.captions_path),
+                     "field_stats": stats.get("field_stats", {}),
+                 }
+             )
+
+         return StorageContents(
+             rows=rows,
+             columns=list(df.columns),
+             output_fields=list(self.known_output_fields),
+             total_rows=len(df),
+             metadata=metadata,
+         )
+
 
      async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
          """Get all processed job_ids for a given chunk."""