caption-flow 0.3.4-py3-none-any.whl → 0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +934 -415
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +439 -67
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +265 -90
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/METADATA +5 -4
  28. caption_flow-0.4.0.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.0.dist-info}/top_level.txt +0 -0
@@ -1,51 +1,453 @@
- """Storage exporter for converting Parquet data to various formats."""
+ """Storage exporter for Lance datasets to various formats."""

- import json
  import csv
- from pathlib import Path
- from typing import List, Dict, Any, Optional, Union
- from dataclasses import dataclass, field
+ import json
  import logging
- import pandas as pd
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+ from urllib.parse import urlparse
+
+ import lance
  import numpy as np
- from ..models import StorageContents, ExportError
+ import pandas as pd
+
+ from ..models import ExportError, StorageContents
+ from .manager import StorageManager

  logger = logging.getLogger(__name__)
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())


- class StorageExporter:
- """Exports StorageContents to various formats."""
+ class LanceStorageExporter:
+ """Exports Lance storage contents to various formats."""

- def __init__(self, contents: StorageContents):
- """Initialize exporter with storage contents.
+ def __init__(self, storage_manager: StorageManager):
+ """Initialize exporter with storage manager.
+
+ Args:
+ ----
+ storage_manager: StorageManager instance
+
+ """
+ self.storage_manager = storage_manager
+
+ async def export_shard(
+ self,
+ shard_name: str,
+ format: str,
+ output_path: Union[str, Path],
+ columns: Optional[List[str]] = None,
+ limit: Optional[int] = None,
+ **kwargs,
+ ) -> int:
+ """Export a single shard to specified format.
+
+ Args:
+ ----
+ shard_name: Name of the shard to export
+ format: Export format ('jsonl', 'json', 'csv', 'parquet', 'txt')
+ output_path: Output file or directory path
+ columns: Specific columns to export
+ limit: Maximum number of rows to export
+ **kwargs: Format-specific options
+
+ Returns:
+ -------
+ Number of items exported
+
+ """
+ logger.debug(f"Getting shard contents for {shard_name}")
+ await self.storage_manager.initialize()
+ contents = await self.storage_manager.get_shard_contents(
+ shard_name, limit=limit, columns=columns
+ )
+
+ if not contents.rows:
+ logger.warning(f"No data to export for shard {shard_name}")
+ return 0
+
+ exporter = StorageExporter(contents)
+
+ # Add shard suffix to output path
+ output_path = Path(output_path)
+ if format in ["jsonl", "csv", "parquet"]:
+ # Single file formats - add shard name to filename
+ if output_path.suffix:
+ output_file = (
+ output_path.parent / f"{output_path.stem}_{shard_name}{output_path.suffix}"
+ )
+ else:
+ output_file = output_path / f"{shard_name}.{format}"
+ else:
+ # Directory-based formats
+ output_file = output_path / shard_name
+
+ # Export based on format
+ if format == "jsonl":
+ return exporter.to_jsonl(output_file)
+ elif format == "json":
+ return exporter.to_json(output_file, kwargs.get("filename_column", "filename"))
+ elif format == "csv":
+ return exporter.to_csv(output_file)
+ elif format == "parquet":
+ return await self.export_shard_to_parquet(shard_name, output_file, columns, limit)
+ elif format == "txt":
+ return exporter.to_txt(
+ output_file,
+ kwargs.get("filename_column", "filename"),
+ kwargs.get("export_column", "captions"),
+ )
+ else:
+ raise ValueError(f"Unsupported format: {format}")
+
+ async def export_all_shards(
+ self,
+ format: str,
+ output_path: Union[str, Path],
+ columns: Optional[List[str]] = None,
+ limit_per_shard: Optional[int] = None,
+ shard_filter: Optional[List[str]] = None,
+ **kwargs,
+ ) -> Dict[str, int]:
+ """Export all shards (or filtered shards) to specified format.
+
+ Args:
+ ----
+ format: Export format
+ output_path: Base output path
+ columns: Columns to export
+ limit_per_shard: Max rows per shard
+ shard_filter: List of specific shards to export
+ **kwargs: Format-specific options
+
+ Returns:
+ -------
+ Dictionary mapping shard names to export counts
+
+ """
+ results = {}
+
+ # Get shards to export
+ await self.storage_manager.initialize()
+ if shard_filter:
+ shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+ else:
+ shards = list(self.storage_manager.shard_datasets.keys())
+
+ logger.info(f"Exporting {len(shards)} shards to {format} format")
+
+ for shard_name in shards:
+ try:
+ count = await self.export_shard(
+ shard_name,
+ format,
+ output_path,
+ columns=columns,
+ limit=limit_per_shard,
+ **kwargs,
+ )
+ results[shard_name] = count
+ logger.info(f"Exported {count} items from shard {shard_name}")
+ except Exception as e:
+ logger.error(f"Failed to export shard {shard_name}: {e}")
+ results[shard_name] = 0
+
+ return results
+
+ async def export_shard_to_parquet(
+ self,
+ shard_name: str,
+ output_path: Union[str, Path],
+ columns: Optional[List[str]] = None,
+ limit: Optional[int] = None,
+ ) -> int:
+ """Export a shard directly to Parquet format.
+
+ This is efficient as Lance is already columnar.
+ """
+ if shard_name not in self.storage_manager.shard_datasets:
+ raise ValueError(f"Shard {shard_name} not found")
+
+ dataset = self.storage_manager.shard_datasets[shard_name]
+
+ # Build scanner
+ scanner = dataset.scanner(columns=columns)
+ if limit:
+ scanner = scanner.limit(limit)
+
+ # Get table and write to parquet
+ table = scanner.to_table()
+
+ import pyarrow.parquet as pq
+
+ pq.write_table(table, str(output_path), compression="snappy")
+
+ return table.num_rows
+
+ async def export_to_lance(
+ self,
+ output_path: Union[str, Path],
+ columns: Optional[List[str]] = None,
+ shard_filter: Optional[List[str]] = None,
+ ) -> int:
+ """Export to a new Lance dataset, optionally filtering shards.
+
+ Args:
+ ----
+ output_path: Path for the output Lance dataset
+ columns: Specific columns to include
+ shard_filter: List of shard names to include
+
+ Returns:
+ -------
+ Total number of rows exported
+
+ """
+ output_path = Path(output_path)
+ if output_path.exists():
+ raise ValueError(f"Output path already exists: {output_path}")
+
+ # Get shards to export
+ if shard_filter:
+ shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+ else:
+ shards = list(self.storage_manager.shard_datasets.keys())
+
+ if not shards:
+ raise ValueError("No shards to export")
+
+ total_rows = 0
+ first_shard = True
+
+ for shard_name in shards:
+ dataset = self.storage_manager.shard_datasets[shard_name]
+
+ # Build scanner
+ scanner = dataset.scanner(columns=columns)
+ table = scanner.to_table()
+
+ if first_shard:
+ # Create new dataset
+ lance.write_dataset(table, str(output_path), mode="create")
+ first_shard = False
+ else:
+ # Append to existing
+ lance.write_dataset(table, str(output_path), mode="append")
+
+ total_rows += table.num_rows
+ logger.info(f"Exported {table.num_rows} rows from shard {shard_name}")
+
+ logger.info(f"Created Lance dataset at {output_path} with {total_rows} rows")
+ return total_rows
+
+ async def export_to_huggingface_hub(
+ self,
+ dataset_name: str,
+ token: Optional[str] = None,
+ license: str = "apache-2.0",
+ private: bool = False,
+ nsfw: bool = False,
+ tags: Optional[List[str]] = None,
+ language: str = "en",
+ task_categories: Optional[List[str]] = None,
+ shard_filter: Optional[List[str]] = None,
+ max_shard_size_gb: float = 2.0,
+ ) -> str:
+ """Export to Hugging Face Hub with per-shard parquet files.

  Args:
- contents: StorageContents instance to export
+ ----
+ dataset_name: Name for the dataset (e.g., "username/dataset-name")
+ token: Hugging Face API token
+ license: License for the dataset
+ private: Whether to make the dataset private
+ nsfw: Whether to add not-for-all-audiences tag
+ tags: Additional tags
+ language: Language code
+ task_categories: Task categories
+ shard_filter: Specific shards to export
+ max_shard_size_gb: Max size per parquet file in GB
+
+ Returns:
+ -------
+ URL of the uploaded dataset
+
  """
+ try:
+ import pyarrow.parquet as pq
+ from huggingface_hub import DatasetCard, HfApi, create_repo
+ except ImportError:
+ raise ExportError(
+ "huggingface_hub is required for HF export. "
+ "Install with: pip install huggingface_hub"
+ )
+
+ api = HfApi(token=token)
+
+ # Check/create repo
+ try:
+ api.dataset_info(dataset_name)
+ logger.info(f"Dataset {dataset_name} already exists, will update it")
+ except:
+ logger.info(f"Creating new dataset: {dataset_name}")
+ create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)
+
+ # Get shards to export
+ if shard_filter:
+ shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+ else:
+ shards = sorted(self.storage_manager.shard_datasets.keys())
+
+ # Export each shard as a separate parquet file
+ total_rows = 0
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ data_dir = tmpdir / "data"
+ data_dir.mkdir(exist_ok=True)
+
+ # Export all shards to the data directory
+ for shard_name in shards:
+ # Export shard to parquet
+ parquet_path = data_dir / f"{shard_name}.parquet"
+ rows = await self.export_shard_to_parquet(shard_name, parquet_path)
+
+ if rows > 0:
+ # Check file size
+ file_size_gb = parquet_path.stat().st_size / (1024**3)
+ if file_size_gb > max_shard_size_gb:
+ logger.warning(
+ f"Shard {shard_name} is {file_size_gb:.2f}GB, "
+ f"exceeds limit of {max_shard_size_gb}GB"
+ )
+
+ total_rows += rows
+ logger.info(f"Prepared {shard_name}: {rows} rows, {file_size_gb:.2f}GB")
+
+ # Create dataset card
+ stats = await self.storage_manager.get_caption_stats()
+
+ # Size category
+ if total_rows < 1000:
+ size_category = "n<1K"
+ elif total_rows < 10000:
+ size_category = "1K<n<10K"
+ elif total_rows < 100000:
+ size_category = "10K<n<100K"
+ elif total_rows < 1000000:
+ size_category = "100K<n<1M"
+ elif total_rows < 1000000:
+ size_category = "1M<n<10M"
+ else:
+ size_category = "n>10M"
+
+ # Prepare tags
+ default_tags = ["lance"]
+ all_tags = default_tags + (tags or [])
+ if nsfw:
+ all_tags.append("not-for-all-audiences")
+
+ # Default task categories
+ if task_categories is None:
+ task_categories = ["text-to-image", "image-to-image"]
+
+ # Create card content
+ card_content = f"""---
+ license: {license}
+ language:
+ - {language}
+ size_categories:
+ - {size_category}
+ task_categories:
+ {self._yaml_list(task_categories)}"""
+
+ if all_tags:
+ card_content += f"\ntags:\n{self._yaml_list(all_tags)}"
+
+ card_content += f"""
+ ---
+
+ # Caption Dataset
+
+ This dataset contains {total_rows:,} captioned items exported from CaptionFlow.
+
+ ## Dataset Structure
+
+ """
+
+ card_content += "\n\n### Data Fields\n\n"
+
+ # Add field descriptions
+ all_fields = set()
+ for field, _ in self.storage_manager.base_caption_fields:
+ all_fields.add(field)
+ for fields in self.storage_manager.shard_output_fields.values():
+ all_fields.update(fields)
+
+ for field in sorted(all_fields):
+ if field in stats.get("output_fields", []):
+ card_content += f"- `{field}`: List of captions/outputs\n"
+ else:
+ card_content += f"- `{field}`\n"
+
+ if stats.get("field_stats"):
+ card_content += "\n### Output Field Statistics\n\n"
+ for field, count in stats["field_stats"].items():
+ card_content += f"- `{field}`: {count:,} total items\n"
+
+ # Save README.md
+ readme_path = tmpdir / "README.md"
+ with open(readme_path, "w", encoding="utf-8") as f:
+ f.write(card_content)
+
+ # Upload the entire folder at once
+ logger.info(f"Uploading dataset to {dataset_name}...")
+ api.upload_large_folder(
+ repo_id=dataset_name,
+ folder_path=str(tmpdir),
+ repo_type="dataset",
+ )
+
+ dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
+ logger.info(f"Successfully uploaded dataset to: {dataset_url}")
+
+ return dataset_url
+
+ def _yaml_list(self, items: List[str]) -> str:
+ """Format a list for YAML."""
+ return "\n".join(f"- {item}" for item in items)
+
+
+ class StorageExporter:
+ """Legacy exporter for StorageContents objects."""
+
+ def __init__(self, contents: StorageContents):
  self.contents = contents
  self._validate_contents()

  def _validate_contents(self):
- """Validate that contents are suitable for export."""
  if not self.contents.rows:
  logger.warning("No rows to export")
  if not self.contents.columns:
  raise ExportError("No columns defined for export")

  def _flatten_lists(self, value: Any) -> str:
- """Convert list values to newline-separated strings."""
  if isinstance(value, list):
- # Strip newlines from each element and join
  return "\n".join(str(item).replace("\n", " ") for item in value)
  return str(value) if value is not None else ""

  def _serialize_value(self, value: Any) -> Any:
- """Convert values to JSON-serializable format."""
+ import datetime as dt
+
  if pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(value, pd.Timestamp):
  return value.isoformat()
- elif isinstance(value, np.integer):
+ elif isinstance(value, (dt.datetime, dt.date)):
+ return value.isoformat()
+ elif isinstance(value, (np.integer, np.int64)):
  return int(value)
- elif isinstance(value, np.floating):
+ elif isinstance(value, (np.floating, np.float64)):
  return float(value)
  elif isinstance(value, np.ndarray):
  return value.tolist()
@@ -56,23 +458,13 @@ class StorageExporter:
  return value

  def to_jsonl(self, output_path: Union[str, Path]) -> int:
- """Export to JSONL (JSON Lines) format.
-
- Args:
- output_path: Path to output JSONL file
-
- Returns:
- Number of rows exported
- """
  output_path = Path(output_path)
  output_path.parent.mkdir(parents=True, exist_ok=True)

  rows_written = 0
  with open(output_path, "w", encoding="utf-8") as f:
  for row in self.contents.rows:
- # Convert non-serializable values
  serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
- # Write each row as a JSON object on its own line
  json_line = json.dumps(serializable_row, ensure_ascii=False)
  f.write(json_line + "\n")
  rows_written += 1
@@ -81,18 +473,12 @@ class StorageExporter:
  return rows_written

  def _get_filename_from_row(self, row: Dict[str, Any], filename_column: str) -> Optional[str]:
- """Extract filename from row, falling back to URL if needed."""
- # Try the specified filename column first
  filename = row.get(filename_column)
  if filename:
  return filename

- # Fall back to URL if available
  url = row.get("url")
  if url:
- # Extract filename from URL path
- from urllib.parse import urlparse
-
  parsed = urlparse(str(url))
  path_parts = parsed.path.rstrip("/").split("/")
  if path_parts and path_parts[-1]:
@@ -101,23 +487,11 @@ class StorageExporter:
  return None

  def to_json(self, output_dir: Union[str, Path], filename_column: str = "filename") -> int:
- """Export to individual JSON files based on filename column.
-
- Args:
- output_dir: Directory to write JSON files
- filename_column: Column containing the base filename
-
- Returns:
- Number of files created
- """
  output_dir = Path(output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)

- # Check if we need to fall back to URL
- using_url_fallback = False
  if filename_column not in self.contents.columns and "url" in self.contents.columns:
  logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
- using_url_fallback = True
  elif filename_column not in self.contents.columns:
  raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")

@@ -128,17 +502,13 @@ class StorageExporter:
  filename = self._get_filename_from_row(row, filename_column)
  if not filename:
  skipped_count += 1
- logger.warning(f"Skipping row with no extractable filename")
  continue

- # Create JSON filename from original filename
  base_name = Path(filename).stem
  json_path = output_dir / f"{base_name}.json"

- # Convert non-serializable values
  serializable_row = {k: self._serialize_value(v) for k, v in row.items()}

- # Write row data as JSON
  with open(json_path, "w", encoding="utf-8") as f:
  json.dump(serializable_row, f, ensure_ascii=False, indent=2)

@@ -151,14 +521,6 @@ class StorageExporter:
  return files_created

  def to_csv(self, output_path: Union[str, Path]) -> int:
- """Export to CSV format, skipping complex columns.
-
- Args:
- output_path: Path to output CSV file
-
- Returns:
- Number of rows exported
- """
  output_path = Path(output_path)
  output_path.parent.mkdir(parents=True, exist_ok=True)

@@ -166,12 +528,10 @@ class StorageExporter:
  complex_columns = set()
  csv_safe_columns = []

- # Check column types by sampling data
  sample_size = min(10, len(self.contents.rows))
  for row in self.contents.rows[:sample_size]:
  for col, value in row.items():
  if col not in complex_columns and value is not None:
- # Skip dictionaries and non-output field lists
  if isinstance(value, dict):
  complex_columns.add(col)
  logger.warning(
@@ -185,19 +545,16 @@ class StorageExporter:
  "Consider using JSONL format for complete data export."
  )

- # Build list of CSV-safe columns
  csv_safe_columns = [col for col in self.contents.columns if col not in complex_columns]

  if not csv_safe_columns:
  raise ExportError("No columns suitable for CSV export. Use JSONL format instead.")

- # Prepare rows for CSV export with safe columns only
  csv_rows = []
  for row in self.contents.rows:
  csv_row = {}
  for col in csv_safe_columns:
  value = row.get(col)
- # Handle list values (like captions) by joining with newlines
  if isinstance(value, list):
  csv_row[col] = self._flatten_lists(value)
  elif pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(
@@ -208,13 +565,11 @@ class StorageExporter:
  csv_row[col] = value
  csv_rows.append(csv_row)

- # Write to CSV
  with open(output_path, "w", encoding="utf-8", newline="") as f:
  writer = csv.DictWriter(f, fieldnames=csv_safe_columns)
  writer.writeheader()
  writer.writerows(csv_rows)

- # Log results
  if complex_columns:
  skipped_msg = f"Skipped {len(complex_columns)} complex columns: {', '.join(sorted(complex_columns))}"
  logger.warning(skipped_msg)
@@ -232,29 +587,15 @@ class StorageExporter:
  filename_column: str = "filename",
  export_column: str = "captions",
  ) -> int:
- """Export specific column to individual text files.
-
- Args:
- output_dir: Directory to write text files
- filename_column: Column containing the base filename
- export_column: Column to export to text files
-
- Returns:
- Number of files created
- """
  output_dir = Path(output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)

- # Check if we need to fall back to URL
- using_url_fallback = False
  if filename_column not in self.contents.columns and "url" in self.contents.columns:
  logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
- using_url_fallback = True
  elif filename_column not in self.contents.columns:
  raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")

  if export_column not in self.contents.columns:
- # Check if it's an output field
  if export_column not in self.contents.output_fields:
  raise ExportError(f"Column '{export_column}' not found in data")

@@ -266,20 +607,16 @@ class StorageExporter:
  filename = self._get_filename_from_row(row, filename_column)
  if not filename:
  skipped_no_filename += 1
- logger.warning(f"Skipping row with no extractable filename")
  continue

  content = row.get(export_column)
  if content is None:
  skipped_no_content += 1
- logger.warning(f"No {export_column} for {filename}")
  continue

- # Create text filename from original filename
  base_name = Path(filename).stem
  txt_path = output_dir / f"{base_name}.txt"

- # Write content
  with open(txt_path, "w", encoding="utf-8") as f:
  f.write(self._flatten_lists(content))

@@ -292,259 +629,3 @@ class StorageExporter:

  logger.info(f"Created {files_created} text files in: {output_dir}")
  return files_created
-
- def to_huggingface_hub(
- self,
- dataset_name: str,
- token: Optional[str] = None,
- license: Optional[str] = None,
- private: bool = False,
- nsfw: bool = False,
- tags: Optional[List[str]] = None,
- language: str = "en",
- task_categories: Optional[List[str]] = None,
- ) -> str:
- """Export to Hugging Face Hub as a dataset.
-
- Args:
- dataset_name: Name for the dataset (e.g., "username/dataset-name")
- token: Hugging Face API token
- license: License for the dataset (required for new repos)
- private: Whether to make the dataset private
- nsfw: Whether to add not-for-all-audiences tag
- tags: Additional tags for the dataset
- language: Language code (default: "en")
- task_categories: Task categories (default: ["text-to-image", "image-to-image"])
-
- Returns:
- URL of the uploaded dataset
- """
- try:
- from huggingface_hub import HfApi, DatasetCard, create_repo
- import pyarrow as pa
- import pyarrow.parquet as pq
- except ImportError:
- raise ExportError(
- "huggingface_hub and pyarrow are required for HF export. "
- "Install with: pip install huggingface_hub pyarrow"
- )
-
- # Initialize HF API
- api = HfApi(token=token)
-
- # Check if repo exists
- repo_exists = False
- try:
- api.dataset_info(dataset_name)
- repo_exists = True
- logger.info(f"Dataset {dataset_name} already exists, will update it")
- except:
- logger.info(f"Creating new dataset: {dataset_name}")
- if not license:
- raise ExportError("License is required when creating a new dataset")
-
- # Create repo if it doesn't exist
- if not repo_exists:
- create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)
-
- # Prepare data for parquet
- df = pd.DataFrame(self.contents.rows)
-
- # Convert any remaining non-serializable types
- for col in df.columns:
- if df[col].dtype == "object":
- df[col] = df[col].apply(
- lambda x: self._serialize_value(x) if x is not None else None
- )
-
- # Determine size category
- num_rows = len(df)
- if num_rows < 1000:
- size_category = "n<1K"
- elif num_rows < 10000:
- size_category = "1K<n<10K"
- elif num_rows < 100000:
- size_category = "10K<n<100K"
- elif num_rows < 1000000:
- size_category = "100K<n<1M"
- elif num_rows < 10000000:
- size_category = "1M<n<10M"
- else:
- size_category = "n>10M"
-
- # Prepare tags
- all_tags = tags or []
- if nsfw:
- all_tags.append("not-for-all-audiences")
-
- # Default task categories
- if task_categories is None:
- task_categories = ["text-to-image", "image-to-image"]
-
- # Create dataset card
- card_content = f"""---
- license: {license or 'unknown'}
- language:
- - {language}
- size_categories:
- - {size_category}
- task_categories:
- {self._yaml_list(task_categories)}"""
-
- if all_tags:
- card_content += f"\ntags:\n{self._yaml_list(all_tags)}"
-
- card_content += f"""
- ---
-
- # Caption Dataset
-
- This dataset contains {num_rows:,} captioned items exported from CaptionFlow.
-
- ## Dataset Structure
-
- ### Data Fields
-
- """
-
- # Add field descriptions
- for col in df.columns:
- dtype = str(df[col].dtype)
- if col in self.contents.output_fields:
- card_content += f"- `{col}`: List of captions/outputs\n"
- else:
- card_content += f"- `{col}`: {dtype}\n"
-
- if self.contents.metadata:
- card_content += "\n## Export Information\n\n"
- if "export_timestamp" in self.contents.metadata:
- card_content += (
- f"- Export timestamp: {self.contents.metadata['export_timestamp']}\n"
- )
- if "field_stats" in self.contents.metadata:
- card_content += "\n### Field Statistics\n\n"
- for field, stats in self.contents.metadata["field_stats"].items():
- card_content += f"- `{field}`: {stats['total_items']:,} items across {stats['rows_with_data']:,} rows\n"
-
- # Create temporary parquet file
- import tempfile
-
- with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp_file:
- temp_path = Path(tmp_file.name)
-
- try:
- # Write parquet file
- table = pa.Table.from_pandas(df)
- pq.write_table(table, temp_path, compression="snappy")
-
- # Upload parquet file
- api.upload_file(
- path_or_fileobj=str(temp_path),
- path_in_repo="data.parquet",
- repo_id=dataset_name,
- repo_type="dataset",
- token=token,
- )
-
- # Create and upload dataset card
- card = DatasetCard(card_content)
- card.push_to_hub(dataset_name, token=token)
-
- dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
- logger.info(f"Successfully uploaded dataset to: {dataset_url}")
-
- return dataset_url
-
- finally:
- # Clean up temp file
- if temp_path.exists():
- temp_path.unlink()
-
- def _yaml_list(self, items: List[str]) -> str:
- """Format a list for YAML."""
- return "\n".join(f"- {item}" for item in items)
-
-
- # Addition to StorageManager class
- async def get_storage_contents(
- self,
- limit: Optional[int] = None,
- columns: Optional[List[str]] = None,
- include_metadata: bool = True,
- ) -> StorageContents:
- """Retrieve storage contents for export.
-
- Args:
- limit: Maximum number of rows to retrieve
- columns: Specific columns to include (None for all)
- include_metadata: Whether to include metadata in the result
-
- Returns:
- StorageContents instance with the requested data
- """
- if not self.captions_path.exists():
- return StorageContents(
- rows=[],
- columns=[],
- output_fields=list(self.known_output_fields),
- total_rows=0,
- metadata={"message": "No captions file found"},
- )
-
- # Flush buffers first to ensure all data is on disk
- await self.checkpoint()
-
- # Determine columns to read
- if columns:
- # Validate requested columns exist
- table_metadata = pq.read_metadata(self.captions_path)
- available_columns = set(table_metadata.schema.names)
- invalid_columns = set(columns) - available_columns
- if invalid_columns:
- raise ValueError(f"Columns not found: {invalid_columns}")
- columns_to_read = columns
- else:
- # Read all columns
- columns_to_read = None
-
- # Read the table
- table = pq.read_table(self.captions_path, columns=columns_to_read)
- df = table.to_pandas()
-
- # Apply limit if specified
- if limit:
- df = df.head(limit)
-
- # Convert to list of dicts
- rows = df.to_dict("records")
-
- # Parse metadata JSON strings back to dicts if present
- if "metadata" in df.columns:
- for row in rows:
- if row.get("metadata"):
- try:
- row["metadata"] = json.loads(row["metadata"])
- except:
- pass  # Keep as string if parsing fails
-
- # Prepare metadata
- metadata = {}
- if include_metadata:
- stats = await self.get_caption_stats()
- metadata.update(
- {
- "export_timestamp": pd.Timestamp.now().isoformat(),
- "total_available_rows": stats.get("total_rows", 0),
- "rows_exported": len(rows),
- "storage_path": str(self.captions_path),
- "field_stats": stats.get("field_stats", {}),
- }
- )
-
- return StorageContents(
- rows=rows,
- columns=list(df.columns),
- output_fields=list(self.known_output_fields),
- total_rows=len(df),
- metadata=metadata,
- )
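
Usage note (editor's sketch): in 0.4.0 the new LanceStorageExporter wraps a StorageManager instead of a StorageContents object and exposes async, per-shard export methods. The sketch below exercises those methods exactly as their signatures appear in the diff above; the StorageManager("./caption_data") construction arguments and the repository id are hypothetical placeholders, since storage/manager.py is not shown here.

import asyncio

from caption_flow.storage.exporter import LanceStorageExporter
from caption_flow.storage.manager import StorageManager


async def main() -> None:
    # Hypothetical constructor argument; see storage/manager.py for the real signature.
    storage = StorageManager("./caption_data")
    exporter = LanceStorageExporter(storage)

    # Write one JSONL file per shard under ./exports/ (directory path has no suffix,
    # so each shard becomes <shard_name>.jsonl, per export_shard in the diff).
    counts = await exporter.export_all_shards("jsonl", "./exports")
    print(counts)

    # Or push per-shard parquet files plus a generated dataset card to the Hub.
    url = await exporter.export_to_huggingface_hub(
        "username/my-captions",  # hypothetical repo id
        token="hf_...",          # placeholder token
        private=True,
    )
    print(url)


asyncio.run(main())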