caption-flow 0.3.4-py3-none-any.whl → 0.4.1-py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- caption_flow/__init__.py +3 -3
- caption_flow/cli.py +921 -427
- caption_flow/models.py +45 -3
- caption_flow/monitor.py +2 -3
- caption_flow/orchestrator.py +153 -104
- caption_flow/processors/__init__.py +3 -3
- caption_flow/processors/base.py +8 -7
- caption_flow/processors/huggingface.py +463 -68
- caption_flow/processors/local_filesystem.py +24 -28
- caption_flow/processors/webdataset.py +28 -22
- caption_flow/storage/exporter.py +420 -339
- caption_flow/storage/manager.py +636 -756
- caption_flow/utils/__init__.py +1 -1
- caption_flow/utils/auth.py +1 -1
- caption_flow/utils/caption_utils.py +1 -1
- caption_flow/utils/certificates.py +15 -8
- caption_flow/utils/checkpoint_tracker.py +30 -28
- caption_flow/utils/chunk_tracker.py +153 -56
- caption_flow/utils/image_processor.py +9 -9
- caption_flow/utils/json_utils.py +37 -20
- caption_flow/utils/prompt_template.py +24 -16
- caption_flow/utils/vllm_config.py +5 -4
- caption_flow/viewer.py +4 -12
- caption_flow/workers/base.py +5 -4
- caption_flow/workers/caption.py +303 -92
- caption_flow/workers/data.py +6 -8
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
- caption_flow-0.4.1.dist-info/RECORD +33 -0
- caption_flow-0.3.4.dist-info/RECORD +0 -33
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
caption_flow/storage/exporter.py
CHANGED
```diff
@@ -1,51 +1,453 @@
-"""Storage exporter for
+"""Storage exporter for Lance datasets to various formats."""
 
-import json
 import csv
-
-from typing import List, Dict, Any, Optional, Union
-from dataclasses import dataclass, field
+import json
 import logging
-import
+import os
+import tempfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+from urllib.parse import urlparse
+
+import lance
 import numpy as np
-
+import pandas as pd
+
+from ..models import ExportError, StorageContents
+from .manager import StorageManager
 
 logger = logging.getLogger(__name__)
+logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
 
 
-class
-    """Exports
+class LanceStorageExporter:
+    """Exports Lance storage contents to various formats."""
 
-    def __init__(self,
-        """Initialize exporter with storage
+    def __init__(self, storage_manager: StorageManager):
+        """Initialize exporter with storage manager.
+
+        Args:
+        ----
+            storage_manager: StorageManager instance
+
+        """
+        self.storage_manager = storage_manager
+
+    async def export_shard(
+        self,
+        shard_name: str,
+        format: str,
+        output_path: Union[str, Path],
+        columns: Optional[List[str]] = None,
+        limit: Optional[int] = None,
+        **kwargs,
+    ) -> int:
+        """Export a single shard to specified format.
+
+        Args:
+        ----
+            shard_name: Name of the shard to export
+            format: Export format ('jsonl', 'json', 'csv', 'parquet', 'txt')
+            output_path: Output file or directory path
+            columns: Specific columns to export
+            limit: Maximum number of rows to export
+            **kwargs: Format-specific options
+
+        Returns:
+        -------
+            Number of items exported
+
+        """
+        logger.debug(f"Getting shard contents for {shard_name}")
+        await self.storage_manager.initialize()
+        contents = await self.storage_manager.get_shard_contents(
+            shard_name, limit=limit, columns=columns
+        )
+
+        if not contents.rows:
+            logger.warning(f"No data to export for shard {shard_name}")
+            return 0
+
+        exporter = StorageExporter(contents)
+
+        # Add shard suffix to output path
+        output_path = Path(output_path)
+        if format in ["jsonl", "csv", "parquet"]:
+            # Single file formats - add shard name to filename
+            if output_path.suffix:
+                output_file = (
+                    output_path.parent / f"{output_path.stem}_{shard_name}{output_path.suffix}"
+                )
+            else:
+                output_file = output_path / f"{shard_name}.{format}"
+        else:
+            # Directory-based formats
+            output_file = output_path / shard_name
+
+        # Export based on format
+        if format == "jsonl":
+            return exporter.to_jsonl(output_file)
+        elif format == "json":
+            return exporter.to_json(output_file, kwargs.get("filename_column", "filename"))
+        elif format == "csv":
+            return exporter.to_csv(output_file)
+        elif format == "parquet":
+            return await self.export_shard_to_parquet(shard_name, output_file, columns, limit)
+        elif format == "txt":
+            return exporter.to_txt(
+                output_file,
+                kwargs.get("filename_column", "filename"),
+                kwargs.get("export_column", "captions"),
+            )
+        else:
+            raise ValueError(f"Unsupported format: {format}")
+
+    async def export_all_shards(
+        self,
+        format: str,
+        output_path: Union[str, Path],
+        columns: Optional[List[str]] = None,
+        limit_per_shard: Optional[int] = None,
+        shard_filter: Optional[List[str]] = None,
+        **kwargs,
+    ) -> Dict[str, int]:
+        """Export all shards (or filtered shards) to specified format.
+
+        Args:
+        ----
+            format: Export format
+            output_path: Base output path
+            columns: Columns to export
+            limit_per_shard: Max rows per shard
+            shard_filter: List of specific shards to export
+            **kwargs: Format-specific options
+
+        Returns:
+        -------
+            Dictionary mapping shard names to export counts
+
+        """
+        results = {}
+
+        # Get shards to export
+        await self.storage_manager.initialize()
+        if shard_filter:
+            shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+        else:
+            shards = list(self.storage_manager.shard_datasets.keys())
+
+        logger.info(f"Exporting {len(shards)} shards to {format} format")
+
+        for shard_name in shards:
+            try:
+                count = await self.export_shard(
+                    shard_name,
+                    format,
+                    output_path,
+                    columns=columns,
+                    limit=limit_per_shard,
+                    **kwargs,
+                )
+                results[shard_name] = count
+                logger.info(f"Exported {count} items from shard {shard_name}")
+            except Exception as e:
+                logger.error(f"Failed to export shard {shard_name}: {e}")
+                results[shard_name] = 0
+
+        return results
+
+    async def export_shard_to_parquet(
+        self,
+        shard_name: str,
+        output_path: Union[str, Path],
+        columns: Optional[List[str]] = None,
+        limit: Optional[int] = None,
+    ) -> int:
+        """Export a shard directly to Parquet format.
+
+        This is efficient as Lance is already columnar.
+        """
+        if shard_name not in self.storage_manager.shard_datasets:
+            raise ValueError(f"Shard {shard_name} not found")
+
+        dataset = self.storage_manager.shard_datasets[shard_name]
+
+        # Build scanner
+        scanner = dataset.scanner(columns=columns)
+        if limit:
+            scanner = scanner.limit(limit)
+
+        # Get table and write to parquet
+        table = scanner.to_table()
+
+        import pyarrow.parquet as pq
+
+        pq.write_table(table, str(output_path), compression="snappy")
+
+        return table.num_rows
+
+    async def export_to_lance(
+        self,
+        output_path: Union[str, Path],
+        columns: Optional[List[str]] = None,
+        shard_filter: Optional[List[str]] = None,
+    ) -> int:
+        """Export to a new Lance dataset, optionally filtering shards.
+
+        Args:
+        ----
+            output_path: Path for the output Lance dataset
+            columns: Specific columns to include
+            shard_filter: List of shard names to include
+
+        Returns:
+        -------
+            Total number of rows exported
+
+        """
+        output_path = Path(output_path)
+        if output_path.exists():
+            raise ValueError(f"Output path already exists: {output_path}")
+
+        # Get shards to export
+        if shard_filter:
+            shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+        else:
+            shards = list(self.storage_manager.shard_datasets.keys())
+
+        if not shards:
+            raise ValueError("No shards to export")
+
+        total_rows = 0
+        first_shard = True
+
+        for shard_name in shards:
+            dataset = self.storage_manager.shard_datasets[shard_name]
+
+            # Build scanner
+            scanner = dataset.scanner(columns=columns)
+            table = scanner.to_table()
+
+            if first_shard:
+                # Create new dataset
+                lance.write_dataset(table, str(output_path), mode="create")
+                first_shard = False
+            else:
+                # Append to existing
+                lance.write_dataset(table, str(output_path), mode="append")
+
+            total_rows += table.num_rows
+            logger.info(f"Exported {table.num_rows} rows from shard {shard_name}")
+
+        logger.info(f"Created Lance dataset at {output_path} with {total_rows} rows")
+        return total_rows
+
+    async def export_to_huggingface_hub(
+        self,
+        dataset_name: str,
+        token: Optional[str] = None,
+        license: str = "apache-2.0",
+        private: bool = False,
+        nsfw: bool = False,
+        tags: Optional[List[str]] = None,
+        language: str = "en",
+        task_categories: Optional[List[str]] = None,
+        shard_filter: Optional[List[str]] = None,
+        max_shard_size_gb: float = 2.0,
+    ) -> str:
+        """Export to Hugging Face Hub with per-shard parquet files.
 
         Args:
-
+        ----
+            dataset_name: Name for the dataset (e.g., "username/dataset-name")
+            token: Hugging Face API token
+            license: License for the dataset
+            private: Whether to make the dataset private
+            nsfw: Whether to add not-for-all-audiences tag
+            tags: Additional tags
+            language: Language code
+            task_categories: Task categories
+            shard_filter: Specific shards to export
+            max_shard_size_gb: Max size per parquet file in GB
+
+        Returns:
+        -------
+            URL of the uploaded dataset
+
         """
+        try:
+            import pyarrow.parquet as pq
+            from huggingface_hub import DatasetCard, HfApi, create_repo
+        except ImportError:
+            raise ExportError(
+                "huggingface_hub is required for HF export. "
+                "Install with: pip install huggingface_hub"
+            )
+
+        api = HfApi(token=token)
+
+        # Check/create repo
+        try:
+            api.dataset_info(dataset_name)
+            logger.info(f"Dataset {dataset_name} already exists, will update it")
+        except:
+            logger.info(f"Creating new dataset: {dataset_name}")
+            create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)
+
+        # Get shards to export
+        if shard_filter:
+            shards = [s for s in shard_filter if s in self.storage_manager.shard_datasets]
+        else:
+            shards = sorted(self.storage_manager.shard_datasets.keys())
+
+        # Export each shard as a separate parquet file
+        total_rows = 0
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmpdir = Path(tmpdir)
+            data_dir = tmpdir / "data"
+            data_dir.mkdir(exist_ok=True)
+
+            # Export all shards to the data directory
+            for shard_name in shards:
+                # Export shard to parquet
+                parquet_path = data_dir / f"{shard_name}.parquet"
+                rows = await self.export_shard_to_parquet(shard_name, parquet_path)
+
+                if rows > 0:
+                    # Check file size
+                    file_size_gb = parquet_path.stat().st_size / (1024**3)
+                    if file_size_gb > max_shard_size_gb:
+                        logger.warning(
+                            f"Shard {shard_name} is {file_size_gb:.2f}GB, "
+                            f"exceeds limit of {max_shard_size_gb}GB"
+                        )
+
+                    total_rows += rows
+                    logger.info(f"Prepared {shard_name}: {rows} rows, {file_size_gb:.2f}GB")
+
+            # Create dataset card
+            stats = await self.storage_manager.get_caption_stats()
+
+            # Size category
+            if total_rows < 1000:
+                size_category = "n<1K"
+            elif total_rows < 10000:
+                size_category = "1K<n<10K"
+            elif total_rows < 100000:
+                size_category = "10K<n<100K"
+            elif total_rows < 1000000:
+                size_category = "100K<n<1M"
+            elif total_rows < 1000000:
+                size_category = "1M<n<10M"
+            else:
+                size_category = "n>10M"
+
+            # Prepare tags
+            default_tags = ["lance"]
+            all_tags = default_tags + (tags or [])
+            if nsfw:
+                all_tags.append("not-for-all-audiences")
+
+            # Default task categories
+            if task_categories is None:
+                task_categories = ["text-to-image", "image-to-image"]
+
+            # Create card content
+            card_content = f"""---
+license: {license}
+language:
+- {language}
+size_categories:
+- {size_category}
+task_categories:
+{self._yaml_list(task_categories)}"""
+
+            if all_tags:
+                card_content += f"\ntags:\n{self._yaml_list(all_tags)}"
+
+            card_content += f"""
+---
+
+# Caption Dataset
+
+This dataset contains {total_rows:,} captioned items exported from CaptionFlow.
+
+## Dataset Structure
+
+"""
+
+            card_content += "\n\n### Data Fields\n\n"
+
+            # Add field descriptions
+            all_fields = set()
+            for field, _ in self.storage_manager.base_caption_fields:
+                all_fields.add(field)
+            for fields in self.storage_manager.shard_output_fields.values():
+                all_fields.update(fields)
+
+            for field in sorted(all_fields):
+                if field in stats.get("output_fields", []):
+                    card_content += f"- `{field}`: List of captions/outputs\n"
+                else:
+                    card_content += f"- `{field}`\n"
+
+            if stats.get("field_stats"):
+                card_content += "\n### Output Field Statistics\n\n"
+                for field, count in stats["field_stats"].items():
+                    card_content += f"- `{field}`: {count:,} total items\n"
+
+            # Save README.md
+            readme_path = tmpdir / "README.md"
+            with open(readme_path, "w", encoding="utf-8") as f:
+                f.write(card_content)
+
+            # Upload the entire folder at once
+            logger.info(f"Uploading dataset to {dataset_name}...")
+            api.upload_large_folder(
+                repo_id=dataset_name,
+                folder_path=str(tmpdir),
+                repo_type="dataset",
+            )
+
+        dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
+        logger.info(f"Successfully uploaded dataset to: {dataset_url}")
+
+        return dataset_url
+
+    def _yaml_list(self, items: List[str]) -> str:
+        """Format a list for YAML."""
+        return "\n".join(f"- {item}" for item in items)
+
+
+class StorageExporter:
+    """Legacy exporter for StorageContents objects."""
+
+    def __init__(self, contents: StorageContents):
         self.contents = contents
         self._validate_contents()
 
     def _validate_contents(self):
-        """Validate that contents are suitable for export."""
         if not self.contents.rows:
             logger.warning("No rows to export")
         if not self.contents.columns:
             raise ExportError("No columns defined for export")
 
     def _flatten_lists(self, value: Any) -> str:
-        """Convert list values to newline-separated strings."""
         if isinstance(value, list):
-            # Strip newlines from each element and join
             return "\n".join(str(item).replace("\n", " ") for item in value)
         return str(value) if value is not None else ""
 
     def _serialize_value(self, value: Any) -> Any:
-
+        import datetime as dt
+
         if pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(value, pd.Timestamp):
             return value.isoformat()
-        elif isinstance(value,
+        elif isinstance(value, (dt.datetime, dt.date)):
+            return value.isoformat()
+        elif isinstance(value, (np.integer, np.int64)):
             return int(value)
-        elif isinstance(value, np.floating):
+        elif isinstance(value, (np.floating, np.float64)):
             return float(value)
         elif isinstance(value, np.ndarray):
             return value.tolist()
@@ -56,23 +458,13 @@ class StorageExporter:
         return value
 
     def to_jsonl(self, output_path: Union[str, Path]) -> int:
-        """Export to JSONL (JSON Lines) format.
-
-        Args:
-            output_path: Path to output JSONL file
-
-        Returns:
-            Number of rows exported
-        """
         output_path = Path(output_path)
         output_path.parent.mkdir(parents=True, exist_ok=True)
 
         rows_written = 0
         with open(output_path, "w", encoding="utf-8") as f:
             for row in self.contents.rows:
-                # Convert non-serializable values
                 serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
-                # Write each row as a JSON object on its own line
                 json_line = json.dumps(serializable_row, ensure_ascii=False)
                 f.write(json_line + "\n")
                 rows_written += 1
@@ -81,18 +473,12 @@
         return rows_written
 
     def _get_filename_from_row(self, row: Dict[str, Any], filename_column: str) -> Optional[str]:
-        """Extract filename from row, falling back to URL if needed."""
-        # Try the specified filename column first
         filename = row.get(filename_column)
         if filename:
            return filename
 
-        # Fall back to URL if available
         url = row.get("url")
         if url:
-            # Extract filename from URL path
-            from urllib.parse import urlparse
-
             parsed = urlparse(str(url))
             path_parts = parsed.path.rstrip("/").split("/")
             if path_parts and path_parts[-1]:
@@ -101,23 +487,11 @@
         return None
 
     def to_json(self, output_dir: Union[str, Path], filename_column: str = "filename") -> int:
-        """Export to individual JSON files based on filename column.
-
-        Args:
-            output_dir: Directory to write JSON files
-            filename_column: Column containing the base filename
-
-        Returns:
-            Number of files created
-        """
         output_dir = Path(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
 
-        # Check if we need to fall back to URL
-        using_url_fallback = False
         if filename_column not in self.contents.columns and "url" in self.contents.columns:
             logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
-            using_url_fallback = True
         elif filename_column not in self.contents.columns:
             raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
 
@@ -128,17 +502,13 @@
             filename = self._get_filename_from_row(row, filename_column)
             if not filename:
                 skipped_count += 1
-                logger.warning(f"Skipping row with no extractable filename")
                 continue
 
-            # Create JSON filename from original filename
             base_name = Path(filename).stem
             json_path = output_dir / f"{base_name}.json"
 
-            # Convert non-serializable values
             serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
 
-            # Write row data as JSON
             with open(json_path, "w", encoding="utf-8") as f:
                 json.dump(serializable_row, f, ensure_ascii=False, indent=2)
 
@@ -151,14 +521,6 @@
         return files_created
 
     def to_csv(self, output_path: Union[str, Path]) -> int:
-        """Export to CSV format, skipping complex columns.
-
-        Args:
-            output_path: Path to output CSV file
-
-        Returns:
-            Number of rows exported
-        """
         output_path = Path(output_path)
         output_path.parent.mkdir(parents=True, exist_ok=True)
 
@@ -166,12 +528,10 @@
         complex_columns = set()
         csv_safe_columns = []
 
-        # Check column types by sampling data
         sample_size = min(10, len(self.contents.rows))
         for row in self.contents.rows[:sample_size]:
             for col, value in row.items():
                 if col not in complex_columns and value is not None:
-                    # Skip dictionaries and non-output field lists
                     if isinstance(value, dict):
                         complex_columns.add(col)
                         logger.warning(
@@ -185,19 +545,16 @@
                             "Consider using JSONL format for complete data export."
                         )
 
-        # Build list of CSV-safe columns
         csv_safe_columns = [col for col in self.contents.columns if col not in complex_columns]
 
         if not csv_safe_columns:
             raise ExportError("No columns suitable for CSV export. Use JSONL format instead.")
 
-        # Prepare rows for CSV export with safe columns only
         csv_rows = []
         for row in self.contents.rows:
             csv_row = {}
             for col in csv_safe_columns:
                 value = row.get(col)
-                # Handle list values (like captions) by joining with newlines
                 if isinstance(value, list):
                     csv_row[col] = self._flatten_lists(value)
                 elif pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(
@@ -208,13 +565,11 @@
                     csv_row[col] = value
             csv_rows.append(csv_row)
 
-        # Write to CSV
         with open(output_path, "w", encoding="utf-8", newline="") as f:
             writer = csv.DictWriter(f, fieldnames=csv_safe_columns)
             writer.writeheader()
             writer.writerows(csv_rows)
 
-        # Log results
         if complex_columns:
             skipped_msg = f"Skipped {len(complex_columns)} complex columns: {', '.join(sorted(complex_columns))}"
             logger.warning(skipped_msg)
@@ -232,29 +587,15 @@
         filename_column: str = "filename",
         export_column: str = "captions",
     ) -> int:
-        """Export specific column to individual text files.
-
-        Args:
-            output_dir: Directory to write text files
-            filename_column: Column containing the base filename
-            export_column: Column to export to text files
-
-        Returns:
-            Number of files created
-        """
         output_dir = Path(output_dir)
         output_dir.mkdir(parents=True, exist_ok=True)
 
-        # Check if we need to fall back to URL
-        using_url_fallback = False
         if filename_column not in self.contents.columns and "url" in self.contents.columns:
             logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")
-            using_url_fallback = True
         elif filename_column not in self.contents.columns:
             raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
 
         if export_column not in self.contents.columns:
-            # Check if it's an output field
             if export_column not in self.contents.output_fields:
                 raise ExportError(f"Column '{export_column}' not found in data")
 
@@ -266,20 +607,16 @@
             filename = self._get_filename_from_row(row, filename_column)
             if not filename:
                 skipped_no_filename += 1
-                logger.warning(f"Skipping row with no extractable filename")
                 continue
 
             content = row.get(export_column)
             if content is None:
                 skipped_no_content += 1
-                logger.warning(f"No {export_column} for {filename}")
                 continue
 
-            # Create text filename from original filename
             base_name = Path(filename).stem
             txt_path = output_dir / f"{base_name}.txt"
 
-            # Write content
             with open(txt_path, "w", encoding="utf-8") as f:
                 f.write(self._flatten_lists(content))
 
@@ -292,259 +629,3 @@ class StorageExporter:
 
         logger.info(f"Created {files_created} text files in: {output_dir}")
         return files_created
-
-    def to_huggingface_hub(
-        self,
-        dataset_name: str,
-        token: Optional[str] = None,
-        license: Optional[str] = None,
-        private: bool = False,
-        nsfw: bool = False,
-        tags: Optional[List[str]] = None,
-        language: str = "en",
-        task_categories: Optional[List[str]] = None,
-    ) -> str:
-        """Export to Hugging Face Hub as a dataset.
-
-        Args:
-            dataset_name: Name for the dataset (e.g., "username/dataset-name")
-            token: Hugging Face API token
-            license: License for the dataset (required for new repos)
-            private: Whether to make the dataset private
-            nsfw: Whether to add not-for-all-audiences tag
-            tags: Additional tags for the dataset
-            language: Language code (default: "en")
-            task_categories: Task categories (default: ["text-to-image", "image-to-image"])
-
-        Returns:
-            URL of the uploaded dataset
-        """
-        try:
-            from huggingface_hub import HfApi, DatasetCard, create_repo
-            import pyarrow as pa
-            import pyarrow.parquet as pq
-        except ImportError:
-            raise ExportError(
-                "huggingface_hub and pyarrow are required for HF export. "
-                "Install with: pip install huggingface_hub pyarrow"
-            )
-
-        # Initialize HF API
-        api = HfApi(token=token)
-
-        # Check if repo exists
-        repo_exists = False
-        try:
-            api.dataset_info(dataset_name)
-            repo_exists = True
-            logger.info(f"Dataset {dataset_name} already exists, will update it")
-        except:
-            logger.info(f"Creating new dataset: {dataset_name}")
-            if not license:
-                raise ExportError("License is required when creating a new dataset")
-
-        # Create repo if it doesn't exist
-        if not repo_exists:
-            create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)
-
-        # Prepare data for parquet
-        df = pd.DataFrame(self.contents.rows)
-
-        # Convert any remaining non-serializable types
-        for col in df.columns:
-            if df[col].dtype == "object":
-                df[col] = df[col].apply(
-                    lambda x: self._serialize_value(x) if x is not None else None
-                )
-
-        # Determine size category
-        num_rows = len(df)
-        if num_rows < 1000:
-            size_category = "n<1K"
-        elif num_rows < 10000:
-            size_category = "1K<n<10K"
-        elif num_rows < 100000:
-            size_category = "10K<n<100K"
-        elif num_rows < 1000000:
-            size_category = "100K<n<1M"
-        elif num_rows < 10000000:
-            size_category = "1M<n<10M"
-        else:
-            size_category = "n>10M"
-
-        # Prepare tags
-        all_tags = tags or []
-        if nsfw:
-            all_tags.append("not-for-all-audiences")
-
-        # Default task categories
-        if task_categories is None:
-            task_categories = ["text-to-image", "image-to-image"]
-
-        # Create dataset card
-        card_content = f"""---
-license: {license or 'unknown'}
-language:
-- {language}
-size_categories:
-- {size_category}
-task_categories:
-{self._yaml_list(task_categories)}"""
-
-        if all_tags:
-            card_content += f"\ntags:\n{self._yaml_list(all_tags)}"
-
-        card_content += f"""
----
-
-# Caption Dataset
-
-This dataset contains {num_rows:,} captioned items exported from CaptionFlow.
-
-## Dataset Structure
-
-### Data Fields
-
-"""
-
-        # Add field descriptions
-        for col in df.columns:
-            dtype = str(df[col].dtype)
-            if col in self.contents.output_fields:
-                card_content += f"- `{col}`: List of captions/outputs\n"
-            else:
-                card_content += f"- `{col}`: {dtype}\n"
-
-        if self.contents.metadata:
-            card_content += "\n## Export Information\n\n"
-            if "export_timestamp" in self.contents.metadata:
-                card_content += (
-                    f"- Export timestamp: {self.contents.metadata['export_timestamp']}\n"
-                )
-            if "field_stats" in self.contents.metadata:
-                card_content += "\n### Field Statistics\n\n"
-                for field, stats in self.contents.metadata["field_stats"].items():
-                    card_content += f"- `{field}`: {stats['total_items']:,} items across {stats['rows_with_data']:,} rows\n"
-
-        # Create temporary parquet file
-        import tempfile
-
-        with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp_file:
-            temp_path = Path(tmp_file.name)
-
-        try:
-            # Write parquet file
-            table = pa.Table.from_pandas(df)
-            pq.write_table(table, temp_path, compression="snappy")
-
-            # Upload parquet file
-            api.upload_file(
-                path_or_fileobj=str(temp_path),
-                path_in_repo="data.parquet",
-                repo_id=dataset_name,
-                repo_type="dataset",
-                token=token,
-            )
-
-            # Create and upload dataset card
-            card = DatasetCard(card_content)
-            card.push_to_hub(dataset_name, token=token)
-
-            dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
-            logger.info(f"Successfully uploaded dataset to: {dataset_url}")
-
-            return dataset_url
-
-        finally:
-            # Clean up temp file
-            if temp_path.exists():
-                temp_path.unlink()
-
-    def _yaml_list(self, items: List[str]) -> str:
-        """Format a list for YAML."""
-        return "\n".join(f"- {item}" for item in items)
-
-
-# Addition to StorageManager class
-async def get_storage_contents(
-    self,
-    limit: Optional[int] = None,
-    columns: Optional[List[str]] = None,
-    include_metadata: bool = True,
-) -> StorageContents:
-    """Retrieve storage contents for export.
-
-    Args:
-        limit: Maximum number of rows to retrieve
-        columns: Specific columns to include (None for all)
-        include_metadata: Whether to include metadata in the result
-
-    Returns:
-        StorageContents instance with the requested data
-    """
-    if not self.captions_path.exists():
-        return StorageContents(
-            rows=[],
-            columns=[],
-            output_fields=list(self.known_output_fields),
-            total_rows=0,
-            metadata={"message": "No captions file found"},
-        )
-
-    # Flush buffers first to ensure all data is on disk
-    await self.checkpoint()
-
-    # Determine columns to read
-    if columns:
-        # Validate requested columns exist
-        table_metadata = pq.read_metadata(self.captions_path)
-        available_columns = set(table_metadata.schema.names)
-        invalid_columns = set(columns) - available_columns
-        if invalid_columns:
-            raise ValueError(f"Columns not found: {invalid_columns}")
-        columns_to_read = columns
-    else:
-        # Read all columns
-        columns_to_read = None
-
-    # Read the table
-    table = pq.read_table(self.captions_path, columns=columns_to_read)
-    df = table.to_pandas()
-
-    # Apply limit if specified
-    if limit:
-        df = df.head(limit)
-
-    # Convert to list of dicts
-    rows = df.to_dict("records")
-
-    # Parse metadata JSON strings back to dicts if present
-    if "metadata" in df.columns:
-        for row in rows:
-            if row.get("metadata"):
-                try:
-                    row["metadata"] = json.loads(row["metadata"])
-                except:
-                    pass  # Keep as string if parsing fails
-
-    # Prepare metadata
-    metadata = {}
-    if include_metadata:
-        stats = await self.get_caption_stats()
-        metadata.update(
-            {
-                "export_timestamp": pd.Timestamp.now().isoformat(),
-                "total_available_rows": stats.get("total_rows", 0),
-                "rows_exported": len(rows),
-                "storage_path": str(self.captions_path),
-                "field_stats": stats.get("field_stats", {}),
-            }
-        )
-
-    return StorageContents(
-        rows=rows,
-        columns=list(df.columns),
-        output_fields=list(self.known_output_fields),
-        total_rows=len(df),
-        metadata=metadata,
-    )
```