caption-flow 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +307 -0
- caption_flow/models.py +26 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +85 -1
- caption_flow/viewer.py +594 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/METADATA +44 -177
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/RECORD +12 -9
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.3.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,550 @@
|
|
1
|
+
"""Storage exporter for converting Parquet data to various formats."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
import csv
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import List, Dict, Any, Optional, Union
|
7
|
+
from dataclasses import dataclass, field
|
8
|
+
import logging
|
9
|
+
import pandas as pd
|
10
|
+
import numpy as np
|
11
|
+
from ..models import StorageContents, ExportError
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class StorageExporter:
|
17
|
+
"""Exports StorageContents to various formats."""
|
18
|
+
|
19
|
+
def __init__(self, contents: StorageContents):
    """Bind the exporter to a set of storage contents and validate them.

    Args:
        contents: The StorageContents snapshot that every export method
            of this instance reads from.
    """
    self.contents = contents
    # Validation runs immediately, so problems with the contents surface
    # when the exporter is constructed rather than at export time.
    self._validate_contents()
|
27
|
+
|
28
|
+
def _validate_contents(self):
|
29
|
+
"""Validate that contents are suitable for export."""
|
30
|
+
if not self.contents.rows:
|
31
|
+
logger.warning("No rows to export")
|
32
|
+
if not self.contents.columns:
|
33
|
+
raise ExportError("No columns defined for export")
|
34
|
+
|
35
|
+
def _flatten_lists(self, value: Any) -> str:
|
36
|
+
"""Convert list values to newline-separated strings."""
|
37
|
+
if isinstance(value, list):
|
38
|
+
# Strip newlines from each element and join
|
39
|
+
return "\n".join(str(item).replace("\n", " ") for item in value)
|
40
|
+
return str(value) if value is not None else ""
|
41
|
+
|
42
|
+
def _serialize_value(self, value: Any) -> Any:
|
43
|
+
"""Convert values to JSON-serializable format."""
|
44
|
+
if pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(value, pd.Timestamp):
|
45
|
+
return value.isoformat()
|
46
|
+
elif isinstance(value, np.integer):
|
47
|
+
return int(value)
|
48
|
+
elif isinstance(value, np.floating):
|
49
|
+
return float(value)
|
50
|
+
elif isinstance(value, np.ndarray):
|
51
|
+
return value.tolist()
|
52
|
+
elif isinstance(value, dict):
|
53
|
+
return {k: self._serialize_value(v) for k, v in value.items()}
|
54
|
+
elif isinstance(value, list):
|
55
|
+
return [self._serialize_value(item) for item in value]
|
56
|
+
return value
|
57
|
+
|
58
|
+
def to_jsonl(self, output_path: Union[str, Path]) -> int:
    """Write all rows to a JSON Lines file, one JSON object per line.

    Parent directories of *output_path* are created when missing and an
    existing file is overwritten.

    Args:
        output_path: Destination of the JSONL file.

    Returns:
        Number of rows written.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    rows_written = 0  # stays 0 when there are no rows at all
    with output_path.open("w", encoding="utf-8") as sink:
        for rows_written, record in enumerate(self.contents.rows, start=1):
            # Make every cell JSON-friendly before dumping the row.
            payload = {key: self._serialize_value(cell) for key, cell in record.items()}
            sink.write(json.dumps(payload, ensure_ascii=False) + "\n")

    logger.info(f"Exported {rows_written} rows to JSONL: {output_path}")
    return rows_written
|
82
|
+
|
83
|
+
def _get_filename_from_row(self, row: Dict[str, Any], filename_column: str) -> Optional[str]:
|
84
|
+
"""Extract filename from row, falling back to URL if needed."""
|
85
|
+
# Try the specified filename column first
|
86
|
+
filename = row.get(filename_column)
|
87
|
+
if filename:
|
88
|
+
return filename
|
89
|
+
|
90
|
+
# Fall back to URL if available
|
91
|
+
url = row.get("url")
|
92
|
+
if url:
|
93
|
+
# Extract filename from URL path
|
94
|
+
from urllib.parse import urlparse
|
95
|
+
|
96
|
+
parsed = urlparse(str(url))
|
97
|
+
path_parts = parsed.path.rstrip("/").split("/")
|
98
|
+
if path_parts and path_parts[-1]:
|
99
|
+
return path_parts[-1]
|
100
|
+
|
101
|
+
return None
|
102
|
+
|
103
|
+
def to_json(self, output_dir: Union[str, Path], filename_column: str = "filename") -> int:
    """Write every row to its own ``<stem>.json`` file inside *output_dir*.

    The file stem is taken from *filename_column*, falling back to the last
    path segment of the row's ``url``. Rows for which no filename can be
    derived are skipped. When several rows map to the same stem the later
    row overwrites the earlier file; this is reported in a warning and the
    file is only counted once.

    Args:
        output_dir: Directory to write JSON files (created when missing).
        filename_column: Column containing the base filename.

    Returns:
        Number of distinct JSON files written.

    Raises:
        ExportError: If neither *filename_column* nor a ``url`` column is
            present in the contents.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Without the filename column the URL is the only usable name source.
    if filename_column not in self.contents.columns:
        if "url" not in self.contents.columns:
            raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
        logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")

    written_paths = set()
    skipped_count = 0
    overwritten_count = 0

    for row in self.contents.rows:
        filename = self._get_filename_from_row(row, filename_column)
        if not filename:
            skipped_count += 1
            logger.warning("Skipping row with no extractable filename")
            continue

        # str() guards against non-string filename cells; .stem also drops
        # any directory part, so the file always lands inside output_dir.
        json_path = output_dir / f"{Path(str(filename)).stem}.json"
        if json_path in written_paths:
            overwritten_count += 1

        serializable_row = {k: self._serialize_value(v) for k, v in row.items()}
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(serializable_row, f, ensure_ascii=False, indent=2)
        written_paths.add(json_path)

    if skipped_count > 0:
        logger.warning(f"Skipped {skipped_count} rows with no extractable filename")
    if overwritten_count > 0:
        logger.warning(
            f"{overwritten_count} rows shared a filename with an earlier row and overwrote its file"
        )

    files_created = len(written_paths)
    logger.info(f"Created {files_created} JSON files in: {output_dir}")
    return files_created
|
152
|
+
|
153
|
+
def to_csv(self, output_path: Union[str, Path]) -> int:
    """Write the rows to a CSV file, leaving out columns CSV cannot hold.

    Column suitability is decided from the first ten rows only: a column
    is dropped when a sampled value is a dict, or is a list and the column
    is not one of the contents' output fields. Lists in the remaining
    columns are flattened to newline-separated text and timestamps are
    written in ISO format.

    Args:
        output_path: Destination of the CSV file.

    Returns:
        Number of rows written.

    Raises:
        ExportError: If no column is left after dropping complex ones.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Pass 1: sample the leading rows to find columns that must be dropped.
    complex_columns = set()
    for sampled_row in self.contents.rows[:10]:
        for col, value in sampled_row.items():
            if value is None or col in complex_columns:
                continue
            if isinstance(value, dict):
                complex_columns.add(col)
                logger.warning(
                    f"Column '{col}' contains dict type and will be skipped. "
                    "Consider using JSONL format for complete data export."
                )
            elif isinstance(value, list) and col not in self.contents.output_fields:
                complex_columns.add(col)
                logger.warning(
                    f"Column '{col}' contains list type and will be skipped. "
                    "Consider using JSONL format for complete data export."
                )

    csv_safe_columns = [col for col in self.contents.columns if col not in complex_columns]
    if not csv_safe_columns:
        raise ExportError("No columns suitable for CSV export. Use JSONL format instead.")

    def _to_cell(value):
        # Lists (e.g. captions) become newline-joined text, timestamps ISO strings.
        if isinstance(value, list):
            return self._flatten_lists(value)
        if pd.api.types.is_datetime64_any_dtype(type(value)) or isinstance(value, pd.Timestamp):
            return self._serialize_value(value)
        return value

    # Pass 2: project every row onto the safe columns.
    csv_rows = [
        {col: _to_cell(row.get(col)) for col in csv_safe_columns}
        for row in self.contents.rows
    ]

    with open(output_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=csv_safe_columns)
        writer.writeheader()
        writer.writerows(csv_rows)

    if complex_columns:
        skipped_msg = f"Skipped {len(complex_columns)} complex columns: {', '.join(sorted(complex_columns))}"
        logger.warning(skipped_msg)

    logger.info(
        f"Exported {len(csv_rows)} rows to CSV: {output_path} "
        f"(with {len(csv_safe_columns)}/{len(self.contents.columns)} columns)"
    )

    return len(csv_rows)
|
228
|
+
|
229
|
+
def to_txt(
    self,
    output_dir: Union[str, Path],
    filename_column: str = "filename",
    export_column: str = "captions",
) -> int:
    """Export one column of every row to an individual ``<stem>.txt`` file.

    The file stem is taken from *filename_column*, falling back to the last
    path segment of the row's ``url``. Rows without a derivable filename or
    whose *export_column* value is ``None`` are skipped. When several rows
    map to the same stem the later row overwrites the earlier file; this is
    reported in a warning and the file is only counted once.

    Args:
        output_dir: Directory to write text files (created when missing).
        filename_column: Column containing the base filename.
        export_column: Column to export to text files.

    Returns:
        Number of distinct text files written.

    Raises:
        ExportError: If neither *filename_column* nor a ``url`` column is
            present, or if *export_column* is neither a column nor one of
            the contents' output fields.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Without the filename column the URL is the only usable name source.
    if filename_column not in self.contents.columns:
        if "url" not in self.contents.columns:
            raise ExportError(f"Column '{filename_column}' not found and no 'url' column available")
        logger.warning(f"Column '{filename_column}' not found, falling back to 'url' column")

    if (
        export_column not in self.contents.columns
        and export_column not in self.contents.output_fields
    ):
        raise ExportError(f"Column '{export_column}' not found in data")

    written_paths = set()
    skipped_no_filename = 0
    skipped_no_content = 0
    overwritten_count = 0

    for row in self.contents.rows:
        filename = self._get_filename_from_row(row, filename_column)
        if not filename:
            skipped_no_filename += 1
            logger.warning("Skipping row with no extractable filename")
            continue

        content = row.get(export_column)
        if content is None:
            skipped_no_content += 1
            logger.warning(f"No {export_column} for {filename}")
            continue

        # str() guards against non-string filename cells; .stem also drops
        # any directory part, so the file always lands inside output_dir.
        txt_path = output_dir / f"{Path(str(filename)).stem}.txt"
        if txt_path in written_paths:
            overwritten_count += 1

        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(self._flatten_lists(content))
        written_paths.add(txt_path)

    if skipped_no_filename > 0:
        logger.warning(f"Skipped {skipped_no_filename} rows with no extractable filename")
    if skipped_no_content > 0:
        logger.warning(f"Skipped {skipped_no_content} rows with no {export_column} content")
    if overwritten_count > 0:
        logger.warning(
            f"{overwritten_count} rows shared a filename with an earlier row and overwrote its file"
        )

    files_created = len(written_paths)
    logger.info(f"Created {files_created} text files in: {output_dir}")
    return files_created
|
295
|
+
|
296
|
+
def to_huggingface_hub(
    self,
    dataset_name: str,
    token: Optional[str] = None,
    license: Optional[str] = None,
    private: bool = False,
    nsfw: bool = False,
    tags: Optional[List[str]] = None,
    language: str = "en",
    task_categories: Optional[List[str]] = None,
) -> str:
    """Export to Hugging Face Hub as a dataset.

    The rows are written to a temporary Parquet file that is uploaded as
    ``data.parquet``, followed by a generated dataset card. The dataset
    repository is created first when it cannot be looked up.

    Args:
        dataset_name: Name for the dataset (e.g., "username/dataset-name")
        token: Hugging Face API token
        license: License for the dataset (required for new repos)
        private: Whether to make the dataset private
        nsfw: Whether to add not-for-all-audiences tag
        tags: Additional tags for the dataset (never modified in place)
        language: Language code (default: "en")
        task_categories: Task categories (default: ["text-to-image", "image-to-image"])

    Returns:
        URL of the uploaded dataset

    Raises:
        ExportError: If huggingface_hub or pyarrow cannot be imported, or
            if the dataset has to be created and no license was given.
    """
    import tempfile

    try:
        from huggingface_hub import HfApi, DatasetCard, create_repo
        import pyarrow as pa
        import pyarrow.parquet as pq
    except ImportError as exc:
        raise ExportError(
            "huggingface_hub and pyarrow are required for HF export. "
            "Install with: pip install huggingface_hub pyarrow"
        ) from exc

    # Initialize HF API
    api = HfApi(token=token)

    # Check if repo exists. Any lookup failure is still treated as "does not
    # exist" (as before), but KeyboardInterrupt/SystemExit now propagate.
    # NOTE(review): catching only the hub's "repository not found" error
    # would stop auth/network failures from being read as a missing repo.
    try:
        api.dataset_info(dataset_name)
    except Exception:
        repo_exists = False
    else:
        repo_exists = True

    if repo_exists:
        logger.info(f"Dataset {dataset_name} already exists, will update it")
    else:
        logger.info(f"Creating new dataset: {dataset_name}")
        # Raised outside the except block so the traceback is not chained
        # to the failed lookup.
        if not license:
            raise ExportError("License is required when creating a new dataset")
        create_repo(repo_id=dataset_name, repo_type="dataset", private=private, token=token)

    # Prepare data for parquet
    df = pd.DataFrame(self.contents.rows)

    # Convert any remaining non-serializable types
    for col in df.columns:
        if df[col].dtype == "object":
            df[col] = df[col].apply(
                lambda x: self._serialize_value(x) if x is not None else None
            )

    # Determine size category
    num_rows = len(df)
    size_category = "n>10M"
    for upper_bound, label in (
        (1000, "n<1K"),
        (10000, "1K<n<10K"),
        (100000, "10K<n<100K"),
        (1000000, "100K<n<1M"),
        (10000000, "1M<n<10M"),
    ):
        if num_rows < upper_bound:
            size_category = label
            break

    # Prepare tags on a copy so the caller's list is never mutated
    all_tags = list(tags) if tags else []
    if nsfw:
        all_tags.append("not-for-all-audiences")

    # Default task categories
    if task_categories is None:
        task_categories = ["text-to-image", "image-to-image"]

    # Create dataset card (YAML front matter first, then markdown body)
    card_content = f"""---
license: {license or 'unknown'}
language:
- {language}
size_categories:
- {size_category}
task_categories:
{self._yaml_list(task_categories)}"""

    if all_tags:
        card_content += f"\ntags:\n{self._yaml_list(all_tags)}"

    card_content += f"""
---

# Caption Dataset

This dataset contains {num_rows:,} captioned items exported from CaptionFlow.

## Dataset Structure

### Data Fields

"""

    # Add field descriptions
    for col in df.columns:
        if col in self.contents.output_fields:
            card_content += f"- `{col}`: List of captions/outputs\n"
        else:
            dtype = str(df[col].dtype)
            card_content += f"- `{col}`: {dtype}\n"

    if self.contents.metadata:
        card_content += "\n## Export Information\n\n"
        if "export_timestamp" in self.contents.metadata:
            card_content += (
                f"- Export timestamp: {self.contents.metadata['export_timestamp']}\n"
            )
        if "field_stats" in self.contents.metadata:
            card_content += "\n### Field Statistics\n\n"
            for field, stats in self.contents.metadata["field_stats"].items():
                card_content += f"- `{field}`: {stats['total_items']:,} items across {stats['rows_with_data']:,} rows\n"

    # Reserve a temp file name; the handle is closed on leaving the with
    # block so the parquet writer can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp_file:
        temp_path = Path(tmp_file.name)

    try:
        # Write parquet file
        table = pa.Table.from_pandas(df)
        pq.write_table(table, temp_path, compression="snappy")

        # Upload parquet file
        api.upload_file(
            path_or_fileobj=str(temp_path),
            path_in_repo="data.parquet",
            repo_id=dataset_name,
            repo_type="dataset",
            token=token,
        )

        # Create and upload dataset card
        card = DatasetCard(card_content)
        card.push_to_hub(dataset_name, token=token)

        dataset_url = f"https://huggingface.co/datasets/{dataset_name}"
        logger.info(f"Successfully uploaded dataset to: {dataset_url}")

        return dataset_url

    finally:
        # Clean up temp file (delete=False above makes this our job)
        if temp_path.exists():
            temp_path.unlink()
|
462
|
+
|
463
|
+
def _yaml_list(self, items: List[str]) -> str:
|
464
|
+
"""Format a list for YAML."""
|
465
|
+
return "\n".join(f"- {item}" for item in items)
|
466
|
+
|
467
|
+
|
468
|
+
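# Addition to StorageManager class: the same method is also added to StorageManager in storage/manager.py; this copy expects `self` to be a StorageManager instance.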
|
469
|
+
async def get_storage_contents(
    self,
    limit: Optional[int] = None,
    columns: Optional[List[str]] = None,
    include_metadata: bool = True,
) -> StorageContents:
    """Retrieve storage contents for export.

    Buffered data is checkpointed first, then the captions Parquet file is
    read, optionally restricted to *columns* and truncated to the first
    *limit* rows. A ``metadata`` column holding JSON text is decoded back
    into Python objects where possible.

    Args:
        limit: Maximum number of rows to retrieve. Falsy values (``None``
            or ``0``) mean "no limit".
        columns: Specific columns to include (None or empty for all)
        include_metadata: Whether to include export metadata in the result

    Returns:
        StorageContents instance with the requested data. When the
        captions file does not exist, an empty instance is returned.

    Raises:
        ValueError: If any requested column is missing from the file.
    """
    # NOTE(review): `pq` is not imported at module level in the visible
    # part of this file - confirm it is in scope wherever this runs.
    if not self.captions_path.exists():
        return StorageContents(
            rows=[],
            columns=[],
            output_fields=list(self.known_output_fields),
            total_rows=0,
            metadata={"message": "No captions file found"},
        )

    # Flush buffers first to ensure all data is on disk
    await self.checkpoint()

    # Determine columns to read (None means every column)
    columns_to_read = None
    if columns:
        # Validate requested columns against the file schema before reading
        table_metadata = pq.read_metadata(self.captions_path)
        available_columns = set(table_metadata.schema.names)
        invalid_columns = set(columns) - available_columns
        if invalid_columns:
            raise ValueError(f"Columns not found: {invalid_columns}")
        columns_to_read = columns

    # Read the table
    table = pq.read_table(self.captions_path, columns=columns_to_read)
    df = table.to_pandas()

    # Apply limit if specified
    if limit:
        df = df.head(limit)

    # Convert to list of dicts
    rows = df.to_dict("records")

    # Parse metadata JSON strings back to dicts if present
    if "metadata" in df.columns:
        for row in rows:
            raw_metadata = row.get("metadata")
            if not raw_metadata:
                continue
            try:
                row["metadata"] = json.loads(raw_metadata)
            except (TypeError, ValueError):
                # Best effort: not JSON text (or not text at all), so the
                # original value is kept. JSONDecodeError is a ValueError.
                pass

    # Prepare metadata
    metadata = {}
    if include_metadata:
        stats = await self.get_caption_stats()
        metadata.update(
            {
                "export_timestamp": pd.Timestamp.now().isoformat(),
                "total_available_rows": stats.get("total_rows", 0),
                "rows_exported": len(rows),
                "storage_path": str(self.captions_path),
                "field_stats": stats.get("field_stats", {}),
            }
        )

    return StorageContents(
        rows=rows,
        columns=list(df.columns),
        output_fields=list(self.known_output_fields),
        total_rows=len(df),
        metadata=metadata,
    )
|
@@ -15,7 +15,7 @@ from collections import defaultdict, deque
|
|
15
15
|
import time
|
16
16
|
import numpy as np
|
17
17
|
|
18
|
-
from
|
18
|
+
from ..models import Job, Caption, Contributor, StorageContents, JobId
|
19
19
|
|
20
20
|
logger = logging.getLogger(__name__)
|
21
21
|
logger.setLevel(logging.INFO)
|
@@ -575,6 +575,7 @@ class StorageManager:
|
|
575
575
|
logger.info("Starting storage optimization...")
|
576
576
|
|
577
577
|
# Read the full table
|
578
|
+
backup_path = None
|
578
579
|
table = pq.read_table(self.captions_path)
|
579
580
|
df = table.to_pandas()
|
580
581
|
original_columns = len(df.columns)
|
@@ -790,6 +791,89 @@ class StorageManager:
|
|
790
791
|
|
791
792
|
return job_ids
|
792
793
|
|
794
|
+
async def get_storage_contents(
    self,
    limit: Optional[int] = None,
    columns: Optional[List[str]] = None,
    include_metadata: bool = True,
) -> StorageContents:
    """Retrieve storage contents for export.

    Buffered data is checkpointed first, then the captions Parquet file is
    read, optionally restricted to *columns* and truncated to the first
    *limit* rows. A ``metadata`` column holding JSON text is decoded back
    into Python objects where possible.

    Args:
        limit: Maximum number of rows to retrieve. Falsy values (``None``
            or ``0``) mean "no limit".
        columns: Specific columns to include (None or empty for all)
        include_metadata: Whether to include export metadata in the result

    Returns:
        StorageContents instance with the requested data. When the
        captions file does not exist, an empty instance is returned.

    Raises:
        ValueError: If any requested column is missing from the file.
    """
    if not self.captions_path.exists():
        return StorageContents(
            rows=[],
            columns=[],
            output_fields=list(self.known_output_fields),
            total_rows=0,
            metadata={"message": "No captions file found"},
        )

    # Flush buffers first to ensure all data is on disk
    await self.checkpoint()

    # Determine columns to read (None means every column)
    columns_to_read = None
    if columns:
        # Validate requested columns against the file schema before reading
        table_metadata = pq.read_metadata(self.captions_path)
        available_columns = set(table_metadata.schema.names)
        invalid_columns = set(columns) - available_columns
        if invalid_columns:
            raise ValueError(f"Columns not found: {invalid_columns}")
        columns_to_read = columns

    # Read the table
    table = pq.read_table(self.captions_path, columns=columns_to_read)
    df = table.to_pandas()

    # Apply limit if specified
    if limit:
        df = df.head(limit)

    # Convert to list of dicts
    rows = df.to_dict("records")

    # Parse metadata JSON strings back to dicts if present
    if "metadata" in df.columns:
        for row in rows:
            raw_metadata = row.get("metadata")
            if not raw_metadata:
                continue
            try:
                row["metadata"] = json.loads(raw_metadata)
            except (TypeError, ValueError):
                # Best effort: not JSON text (or not text at all), so the
                # original value is kept. JSONDecodeError is a ValueError.
                pass

    # Prepare metadata
    metadata = {}
    if include_metadata:
        stats = await self.get_caption_stats()
        metadata.update(
            {
                "export_timestamp": pd.Timestamp.now().isoformat(),
                "total_available_rows": stats.get("total_rows", 0),
                "rows_exported": len(rows),
                "storage_path": str(self.captions_path),
                "field_stats": stats.get("field_stats", {}),
            }
        )

    return StorageContents(
        rows=rows,
        columns=list(df.columns),
        output_fields=list(self.known_output_fields),
        total_rows=len(df),
        metadata=metadata,
    )
|
876
|
+
|
793
877
|
async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
|
794
878
|
"""Get all processed job_ids for a given chunk."""
|
795
879
|
if not self.captions_path.exists():
|