caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff compares two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that public registry.
@@ -4,18 +4,21 @@ import asyncio
  import json
  import logging
  from dataclasses import asdict
- from datetime import datetime
+ from datetime import datetime, timedelta
  from pathlib import Path
  from typing import List, Optional, Set, Dict, Any
  import pyarrow as pa
  import pyarrow.parquet as pq
  from pyarrow import fs
  import pandas as pd
- from collections import defaultdict
+ from collections import defaultdict, deque
+ import time
+ import numpy as np

- from .models import Job, Caption, Contributor, JobStatus
+ from ..models import Job, Caption, Contributor, StorageContents, JobId

  logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)


  class StorageManager:
@@ -60,6 +63,11 @@ class StorageManager:
  self.total_flushes = 0
  self.duplicates_skipped = 0

+ # Rate tracking
+ self.row_additions = deque(maxlen=10000) # Store (timestamp, row_count) tuples
+ self.start_time = time.time()
+ self.last_rate_log_time = time.time()
+
  # Base caption schema without dynamic output fields
  self.base_caption_fields = [
  ("job_id", pa.string()),
@@ -68,6 +76,8 @@ class StorageManager:
  ("chunk_id", pa.string()),
  ("item_key", pa.string()),
  ("item_index", pa.int32()),
+ ("filename", pa.string()),
+ ("url", pa.string()),
  ("caption_count", pa.int32()),
  ("contributor_id", pa.string()),
  ("timestamp", pa.timestamp("us")),
@@ -105,6 +115,137 @@ class StorageManager:
  ]
  )

+ def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
+ """Check if a column is entirely empty, null, or contains only zeros/empty lists."""
+ if column_name not in df.columns:
+ return True
+
+ col = df[column_name]
+
+ # Check if all values are null/NaN
+ if col.isna().all():
+ return True
+
+ # For numeric columns, check if all non-null values are 0
+ if pd.api.types.is_numeric_dtype(col):
+ non_null_values = col.dropna()
+ if len(non_null_values) > 0 and (non_null_values == 0).all():
+ return True
+
+ # For list columns, check if all are None or empty lists
+ if col.dtype == "object":
+ non_null_values = col.dropna()
+ if len(non_null_values) == 0:
+ return True
+ # Check if all non-null values are empty lists
+ all_empty_lists = True
+ for val in non_null_values:
+ if isinstance(val, list) and len(val) > 0:
+ all_empty_lists = False
+ break
+ elif not isinstance(val, list):
+ all_empty_lists = False
+ break
+ if all_empty_lists:
+ return True
+
+ return False
+
+ def _get_non_empty_columns(
+ self, df: pd.DataFrame, preserve_base_fields: bool = True
+ ) -> List[str]:
+ """Get list of columns that contain actual data.
+
+ Args:
+ df: DataFrame to check
+ preserve_base_fields: If True, always include base fields even if empty
+ """
+ base_field_names = {field[0] for field in self.base_caption_fields}
+ non_empty_columns = []
+
+ for col in df.columns:
+ # Always keep base fields if preserve_base_fields is True
+ if preserve_base_fields and col in base_field_names:
+ non_empty_columns.append(col)
+ elif not self._is_column_empty(df, col):
+ non_empty_columns.append(col)
+
+ return non_empty_columns
+
+ def _calculate_rates(self) -> Dict[str, float]:
+ """Calculate row addition rates over different time windows."""
+ current_time = time.time()
+ rates = {}
+
+ # Define time windows in minutes
+ windows = {"1min": 1, "5min": 5, "15min": 15, "60min": 60}
+
+ # Clean up old entries beyond the largest window
+ cutoff_time = current_time - (60 * 60) # 60 minutes
+ while self.row_additions and self.row_additions[0][0] < cutoff_time:
+ self.row_additions.popleft()
+
+ # Calculate rates for each window
+ for window_name, window_minutes in windows.items():
+ window_seconds = window_minutes * 60
+ window_start = current_time - window_seconds
+
+ # Sum rows added within this window
+ rows_in_window = sum(
+ count for timestamp, count in self.row_additions if timestamp >= window_start
+ )
+
+ # Calculate rate (rows per second)
+ # For windows larger than elapsed time, use elapsed time
+ elapsed = current_time - self.start_time
+ actual_window = min(window_seconds, elapsed)
+
+ if actual_window > 0:
+ rate = rows_in_window / actual_window
+ rates[window_name] = rate
+ else:
+ rates[window_name] = 0.0
+
+ # Calculate instantaneous rate (last minute)
+ instant_window_start = current_time - 60 # Last 60 seconds
+ instant_rows = sum(
+ count for timestamp, count in self.row_additions if timestamp >= instant_window_start
+ )
+ instant_window = min(60, current_time - self.start_time)
+ rates["instant"] = instant_rows / instant_window if instant_window > 0 else 0.0
+
+ # Calculate overall rate since start
+ total_elapsed = current_time - self.start_time
+ if total_elapsed > 0:
+ rates["overall"] = self.total_captions_written / total_elapsed
+ else:
+ rates["overall"] = 0.0
+
+ return rates
+
+ def _log_rates(self, rows_added: int):
+ """Log rate information if enough time has passed."""
+ current_time = time.time()
+
+ # Log rates every 10 seconds or if it's been more than 30 seconds
+ time_since_last_log = current_time - self.last_rate_log_time
+ if time_since_last_log < 10 and rows_added < 50:
+ return
+
+ rates = self._calculate_rates()
+
+ # Format the rate information
+ rate_str = (
+ f"Rate stats - Instant: {rates['instant']:.1f} rows/s | "
+ f"Avg (5m): {rates['5min']:.1f} | "
+ f"Avg (15m): {rates['15min']:.1f} | "
+ f"Avg (60m): {rates['60min']:.1f} | "
+ f"Overall: {rates['overall']:.1f} rows/s"
+ )
+
+ logger.info(rate_str)
+ self.last_rate_log_time = current_time
+
  def _get_existing_output_columns(self) -> Set[str]:
  """Get output field columns that actually exist in the parquet file."""
  if not self.captions_path.exists():
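The new rate-tracking helpers come down to a deque of (timestamp, count) pairs plus a windowed sum divided by the effective window length. A minimal standalone sketch of that calculation, with hypothetical names, not part of the package:

    import time
    from collections import deque

    row_additions = deque(maxlen=10000)
    start_time = time.time()

    def record(count: int) -> None:
        # Mirror _flush_captions: log how many rows just landed on disk.
        row_additions.append((time.time(), count))

    def rate(window_seconds: float) -> float:
        now = time.time()
        rows = sum(c for t, c in row_additions if t >= now - window_seconds)
        # If the process has run for less than the window, divide by the
        # elapsed time instead, mirroring the diff's "actual_window" logic.
        effective = min(window_seconds, now - start_time)
        return rows / effective if effective > 0 else 0.0

    record(128)
    print(f"{rate(60):.1f} rows/s over the last minute")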
@@ -216,9 +357,14 @@ class StorageManager:
  if "outputs" in df.columns:
  df = df.drop(columns=["outputs"])

- # Update known fields and schema
- self.known_output_fields = output_fields
- self.caption_schema = self._build_caption_schema(output_fields)
+ # Remove empty columns before saving (but preserve base fields)
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
+ df = df[non_empty_columns]
+
+ # Update known fields and schema based on non-empty columns
+ base_field_names = {field[0] for field in self.base_caption_fields}
+ self.known_output_fields = set(non_empty_columns) - base_field_names
+ self.caption_schema = self._build_caption_schema(self.known_output_fields)

  # Write migrated table
  migrated_table = pa.Table.from_pandas(df, schema=self.caption_schema)
@@ -226,8 +372,7 @@ class StorageManager:
  logger.info("Migration complete - outputs now stored in dynamic columns")

  async def save_caption(self, caption: Caption):
- """Save a caption entry with dynamic output columns."""
- # Convert to dict
+ """Save a caption entry, grouping outputs by job_id/item_key (not separating captions)."""
  caption_dict = asdict(caption)

  # Extract item_index from metadata if present
@@ -242,16 +387,61 @@ class StorageManager:
  # Remove old "captions" field if it exists (will be in outputs)
  caption_dict.pop("captions", None)

- new_fields = set()
+ # Grouping key: (job_id, item_key)
+ _job_id = caption_dict.get("job_id")
+ job_id = JobId.from_dict(_job_id).get_sample_str()
+ group_key = job_id
+ logger.debug(
+ f"save_caption: group_key={group_key}, outputs={list(outputs.keys())}, caption_count={caption_dict.get('caption_count')}, item_index={caption_dict.get('item_index')}"
+ )
+
+ # Try to find existing buffered row for this group
+ found_row = False
+ for idx, row in enumerate(self.caption_buffer):
+ check_key = row.get("job_id")
+ logger.debug(f"Checking buffer row {idx}: check_key={check_key}, group_key={group_key}")
+ if check_key == group_key:
+ found_row = True
+ logger.debug(f"Found existing buffer row for group_key={group_key} at index {idx}")
+ # Merge outputs into existing row
+ for field_name, field_values in outputs.items():
+ if field_name not in self.known_output_fields:
+ self.known_output_fields.add(field_name)
+ logger.info(f"New output field detected: {field_name}")
+ if field_name in row and isinstance(row[field_name], list):
+ logger.debug(
+ f"Merging output field '{field_name}' into existing row: before={row[field_name]}, adding={field_values}"
+ )
+ row[field_name].extend(field_values)
+ logger.debug(f"After merge: {row[field_name]}")
+ else:
+ logger.debug(
+ f"Setting new output field '{field_name}' in existing row: {field_values}"
+ )
+ row[field_name] = list(field_values)
+ # Optionally update other fields (e.g., caption_count)
+ if "caption_count" in caption_dict:
+ old_count = row.get("caption_count", 0)
+ row["caption_count"] = old_count + caption_dict["caption_count"]
+ logger.debug(
+ f"Updated caption_count for group_key={group_key}: {old_count} + {caption_dict['caption_count']} = {row['caption_count']}"
+ )
+ return # Already merged, no need to add new row
+ else:
+ logger.debug(f"Caption row not found for group key: {group_key} vs {check_key}")
+
+ if not found_row:
+ logger.debug(
+ f"No existing buffer row found for group_key={group_key}, creating new row."
+ )
+
+ # If not found, create new row
  for field_name, field_values in outputs.items():
- caption_dict[field_name] = field_values
  if field_name not in self.known_output_fields:
- new_fields.add(field_name)
- self.known_output_fields.add(field_name) # Add immediately
-
- if new_fields:
- logger.info(f"New output fields detected: {sorted(new_fields)}")
- logger.info(f"Total known output fields: {sorted(self.known_output_fields)}")
+ self.known_output_fields.add(field_name)
+ logger.info(f"New output field detected: {field_name}")
+ caption_dict[field_name] = list(field_values)
+ logger.debug(f"Adding output field '{field_name}' to new row: {field_values}")

  # Serialize metadata to JSON if present
  if "metadata" in caption_dict:
@@ -259,68 +449,16 @@ class StorageManager:
  else:
  caption_dict["metadata"] = "{}"

- # Add to buffer
- self.caption_buffer.append(caption_dict)
-
- # Log buffer status
- logger.debug(f"Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}")
-
- # Flush if buffer is large enough
- if len(self.caption_buffer) >= self.caption_buffer_size:
- await self._flush_captions()
-
- async def save_captions(self, caption_data: Dict[str, Any]):
- """Save captions for an image - compatible with dict input."""
- job_id = caption_data["job_id"]
-
- # Check if we already have captions for this job_id
- if job_id in self.existing_caption_job_ids:
- self.duplicates_skipped += 1
- logger.debug(f"Skipping duplicate captions for job_id: {job_id}")
- return
-
- # Check if it's already in the buffer
- for buffered in self.caption_buffer:
- if buffered["job_id"] == job_id:
- logger.debug(f"Captions for job_id {job_id} already in buffer")
- return
-
- # Handle outputs if present
- if "outputs" in caption_data:
- outputs = caption_data.pop("outputs")
- # Add each output field directly to caption_data
- for field_name, field_values in outputs.items():
- caption_data[field_name] = field_values
- if field_name not in self.known_output_fields:
- self.known_output_fields.add(field_name)
- logger.info(f"New output field detected: {field_name}")
-
- # Handle legacy captions field
- if "captions" in caption_data and "captions" not in self.known_output_fields:
- self.known_output_fields.add("captions")
-
- # Count all outputs
- caption_count = 0
- for field_name in self.known_output_fields:
- if field_name in caption_data and isinstance(caption_data[field_name], list):
- caption_count += len(caption_data[field_name])
-
- caption_data["caption_count"] = caption_count
+ if isinstance(caption_dict.get("job_id"), dict):
+ caption_dict["job_id"] = job_id

- # Add default values for optional fields
- if "quality_scores" not in caption_data:
- caption_data["quality_scores"] = None
-
- if "metadata" in caption_data and isinstance(caption_data["metadata"], dict):
- caption_data["metadata"] = json.dumps(caption_data["metadata"])
- elif "metadata" not in caption_data:
- caption_data["metadata"] = "{}"
-
- self.caption_buffer.append(caption_data)
- self.existing_caption_job_ids.add(job_id)
+ self.caption_buffer.append(caption_dict)
+ logger.debug(
+ f"Appended new caption row for group_key={group_key}. Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}"
+ )

- # Flush if buffer is large enough
  if len(self.caption_buffer) >= self.caption_buffer_size:
+ logger.debug("Caption buffer full, flushing captions.")
  await self._flush_captions()

  async def _flush_captions(self):
@@ -337,25 +475,7 @@ class StorageManager:
  if field_name in row and isinstance(row[field_name], list):
  total_outputs += len(row[field_name])

- logger.info(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
-
- # Check if we need to evolve the schema
- current_schema_fields = set(self.caption_schema.names) if self.caption_schema else set()
- all_fields_needed = set(
- self.base_caption_fields[i][0] for i in range(len(self.base_caption_fields))
- )
- all_fields_needed.update(self.known_output_fields)
-
- if all_fields_needed != current_schema_fields:
- # Schema evolution needed
- logger.info(
- f"Evolving schema to include new fields: {all_fields_needed - current_schema_fields}"
- )
- self.caption_schema = self._build_caption_schema(self.known_output_fields)
-
- # If file exists, we need to migrate it
- if self.captions_path.exists():
- await self._evolve_schema_on_disk()
+ logger.debug(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")

  # Prepare data with all required columns
  prepared_buffer = []
@@ -374,8 +494,9 @@ class StorageManager:

  prepared_buffer.append(prepared_row)

- # Create table from buffer
- table = pa.Table.from_pylist(prepared_buffer, schema=self.caption_schema)
+ # Build schema with all known fields (base + output)
+ schema = self._build_caption_schema(self.known_output_fields)
+ table = pa.Table.from_pylist(prepared_buffer, schema=schema)

  if self.captions_path.exists():
  # Read existing table
@@ -391,45 +512,146 @@ class StorageManager:
  if row["job_id"] not in existing_job_ids:
  new_rows.append(row)
  elif row not in duplicate_rows:
- duplicate_rows.append(row)
+ duplicate_rows.append(
+ {
+ "input": row,
+ "existing_job": existing.to_pandas()[
+ existing.to_pandas()["job_id"] == row["job_id"]
+ ].to_dict(orient="records"),
+ }
+ )

  if duplicate_rows:
  logger.info(f"Example duplicate row: {duplicate_rows[0]}")
+
  if new_rows:
  # Create table from new rows only
- new_table = pa.Table.from_pylist(new_rows, schema=self.caption_schema)
+ new_table = pa.Table.from_pylist(new_rows, schema=schema)

- # Combine tables
- combined = pa.concat_tables([existing, new_table])
+ # Concatenate with promote_options="default" to handle schema differences automatically
+ combined = pa.concat_tables([existing, new_table], promote_options="default")

- # Write with proper preservation
+ # Write combined table
  pq.write_table(combined, self.captions_path, compression="snappy")

- logger.info(
- f"Added {len(new_rows)} new rows (skipped {num_rows - len(new_rows)} duplicates)"
- )
+ self.duplicates_skipped = num_rows - len(new_rows)
  actual_new = len(new_rows)
  else:
- logger.info(f"All {num_rows} rows were duplicates, skipping write")
- actual_new = 0
+ logger.info(f"All {num_rows} rows were duplicates, exiting")
+ raise SystemError("No duplicates can be submitted")
  else:
- # Write new file
+ # Write new file with all fields
  pq.write_table(table, self.captions_path, compression="snappy")
  actual_new = num_rows

+ # Update statistics
  self.total_captions_written += actual_new
  self.total_caption_entries_written += total_outputs
  self.total_flushes += 1
  self.caption_buffer.clear()

+ # Track row additions for rate calculation
+ if actual_new > 0:
+ current_time = time.time()
+ self.row_additions.append((current_time, actual_new))
+
+ # Log rates
+ self._log_rates(actual_new)
+
  logger.info(
- f"Successfully wrote captions (rows: {self.total_captions_written}, "
- f"total outputs: {self.total_caption_entries_written}, "
- f"duplicates skipped: {self.duplicates_skipped})"
+ f"Successfully wrote captions (new rows: {actual_new}, "
+ f"total rows written: {self.total_captions_written}, "
+ f"total captions written: {self.total_caption_entries_written}, "
+ f"duplicates skipped: {self.duplicates_skipped}, "
+ f"output fields: {sorted(list(self.known_output_fields))})"
  )

+ async def optimize_storage(self):
+ """Optimize storage by dropping empty columns. Run this periodically or on-demand."""
+ if not self.captions_path.exists():
+ logger.info("No captions file to optimize")
+ return
+
+ logger.info("Starting storage optimization...")
+
+ # Read the full table
+ backup_path = None
+ table = pq.read_table(self.captions_path)
+ df = table.to_pandas()
+ original_columns = len(df.columns)
+
+ # Find non-empty columns (don't preserve empty base fields)
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=False)
+
+ # Always keep at least job_id
+ if "job_id" not in non_empty_columns:
+ non_empty_columns.append("job_id")
+
+ if len(non_empty_columns) < original_columns:
+ # We have columns to drop
+ df_optimized = df[non_empty_columns]
+
+ # Rebuild schema for non-empty columns only
+ base_field_names = {f[0] for f in self.base_caption_fields}
+ fields = []
+ output_fields = set()
+
+ # Process columns in a consistent order: base fields first, then output fields
+ for col in non_empty_columns:
+ if col in base_field_names:
+ # Find the base field definition
+ for fname, ftype in self.base_caption_fields:
+ if fname == col:
+ fields.append((fname, ftype))
+ break
+ else:
+ # Output field
+ output_fields.add(col)
+
+ # Add output fields in sorted order
+ for field_name in sorted(output_fields):
+ fields.append((field_name, pa.list_(pa.string())))
+
+ # Create optimized schema and table
+ optimized_schema = pa.schema(fields)
+ optimized_table = pa.Table.from_pandas(df_optimized, schema=optimized_schema)
+
+ # Backup the original file (optional)
+ backup_path = self.captions_path.with_suffix(".parquet.bak")
+ import shutil
+
+ shutil.copy2(self.captions_path, backup_path)
+
+ # Write optimized table
+ pq.write_table(optimized_table, self.captions_path, compression="snappy")
+
+ # Update known output fields
+ self.known_output_fields = output_fields
+
+ # Clean up backup (optional - keep it for safety)
+ # backup_path.unlink()
+
+ logger.info(
+ f"Storage optimization complete: {original_columns} -> {len(non_empty_columns)} columns. "
+ f"Removed columns: {sorted(set(df.columns) - set(non_empty_columns))}"
+ )
+ else:
+ logger.info(f"No optimization needed - all {original_columns} columns contain data")
+
+ # Report file size reduction
+ import os
+
+ if backup_path and backup_path.exists():
+ original_size = os.path.getsize(backup_path)
+ new_size = os.path.getsize(self.captions_path)
+ reduction_pct = (1 - new_size / original_size) * 100
+ logger.info(
+ f"File size: {original_size/1024/1024:.1f}MB -> {new_size/1024/1024:.1f}MB "
+ f"({reduction_pct:.1f}% reduction)"
+ )
+
  async def _evolve_schema_on_disk(self):
- """Evolve the schema of the existing parquet file to include new columns."""
+ """Evolve the schema of the existing parquet file to include new columns, removing empty ones."""
  logger.info("Evolving schema on disk to add new columns...")

  # Read existing data
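The switch to promote_options="default" in the concat call above is what lets the on-disk table and the freshly buffered rows differ in columns: the union of the two schemas is used and absent columns are null-filled instead of raising. A small illustrative sketch, assuming a pyarrow version that accepts the promote_options argument:

    import pyarrow as pa

    # Two tables whose schemas differ by one column, as happens when a new
    # output field appears after the parquet file was first written.
    existing = pa.table({"job_id": ["a", "b"], "captions": [["x"], ["y"]]})
    incoming = pa.table({"job_id": ["c"], "captions": [["z"]], "tags": [["t1"]]})

    # With promote_options="none" the schemas would have to match exactly;
    # with "default" the missing "tags" column in `existing` is filled with nulls.
    combined = pa.concat_tables([existing, incoming], promote_options="default")
    print(combined.schema.names)  # expect ['job_id', 'captions', 'tags']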
@@ -442,63 +664,24 @@ class StorageManager:
  df[field_name] = None
  logger.info(f"Added new column: {field_name}")

- # Recreate table with new schema
- evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
- pq.write_table(evolved_table, self.captions_path, compression="snappy")
- logger.info("Schema evolution complete")
+ # Remove empty columns (but preserve base fields)
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
+ df = df[non_empty_columns]

- async def get_captions(self, job_id: str) -> Optional[Dict[str, List[str]]]:
- """Retrieve all output fields for a specific job_id."""
- # Check buffer first
- for buffered in self.caption_buffer:
- if buffered["job_id"] == job_id:
- outputs = {}
- for field_name in self.known_output_fields:
- if field_name in buffered and buffered[field_name]:
- outputs[field_name] = buffered[field_name]
- return outputs
-
- if not self.captions_path.exists():
- return None
-
- table = pq.read_table(self.captions_path)
- df = table.to_pandas()
+ # Update known output fields
+ base_field_names = {field[0] for field in self.base_caption_fields}
+ self.known_output_fields = set(non_empty_columns) - base_field_names

- row = df[df["job_id"] == job_id]
- if row.empty:
- return None
+ # Recreate schema with only non-empty fields
+ self.caption_schema = self._build_caption_schema(self.known_output_fields)

- # Collect all output fields
- outputs = {}
- for field_name in self.known_output_fields:
- if field_name in row.columns:
- value = row.iloc[0][field_name]
- if pd.notna(value) and value is not None:
- outputs[field_name] = value
-
- return outputs if outputs else None
-
- async def save_job(self, job: Job):
- """Save or update a job - buffers until batch size reached."""
- # For updates, we still add to buffer (will be handled in flush)
- self.job_buffer.append(
- {
- "job_id": job.job_id,
- "dataset": job.dataset,
- "shard": job.shard,
- "item_key": job.item_key,
- "status": job.status.value,
- "assigned_to": job.assigned_to,
- "created_at": job.created_at,
- "updated_at": datetime.utcnow(),
- }
+ # Recreate table with new schema
+ evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
+ pq.write_table(evolved_table, self.captions_path, compression="snappy")
+ logger.info(
+ f"Schema evolution complete. Active output fields: {sorted(list(self.known_output_fields))}"
  )

- self.existing_job_ids.add(job.job_id)
-
- if len(self.job_buffer) >= self.job_buffer_size:
- await self._flush_jobs()
-
  async def save_contributor(self, contributor: Contributor):
  """Save or update contributor stats - buffers until batch size reached."""
  self.contributor_buffer.append(asdict(contributor))
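For context, the empty-column pruning that _evolve_schema_on_disk and optimize_storage now share reduces to a per-column emptiness test before the frame is rewritten. A rough standalone sketch of that test, with hypothetical data and simplified relative to _is_column_empty:

    import pandas as pd

    # Stand-in for the captions table: one column has real data, one is
    # all-null, one holds nothing but empty lists.
    df = pd.DataFrame(
        {
            "captions": [["a cat"], ["a dog"]],
            "quality_scores": [None, None],
            "tags": [[], []],
        }
    )

    def is_effectively_empty(col: pd.Series) -> bool:
        # All-null, all-zero numeric, or object columns of empty lists count as empty.
        non_null = col.dropna()
        if non_null.empty:
            return True
        if pd.api.types.is_numeric_dtype(col):
            return bool((non_null == 0).all())
        return all(isinstance(v, list) and len(v) == 0 for v in non_null)

    keep = [c for c in df.columns if not is_effectively_empty(df[c])]
    print(keep)  # ['captions']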
@@ -575,84 +758,134 @@ class StorageManager:
  await self._flush_jobs()
  await self._flush_contributors()

- logger.info(
- f"Checkpoint complete. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped}"
- )
+ # Log final rate statistics
+ if self.total_captions_written > 0:
+ rates = self._calculate_rates()
+ logger.info(
+ f"Checkpoint complete. Total rows: {self.total_captions_written}, "
+ f"Total caption entries: {self.total_caption_entries_written}, "
+ f"Duplicates skipped: {self.duplicates_skipped} | "
+ f"Overall rate: {rates['overall']:.1f} rows/s"
+ )
+ else:
+ logger.info(
+ f"Checkpoint complete. Total rows: {self.total_captions_written}, "
+ f"Total caption entries: {self.total_caption_entries_written}, "
+ f"Duplicates skipped: {self.duplicates_skipped}"
+ )

- async def job_exists(self, job_id: str) -> bool:
- """Check if a job already exists in storage or buffer."""
- if job_id in self.existing_job_ids:
- return True
+ def get_all_processed_job_ids(self) -> Set[str]:
+ """Get all processed job_ids - useful for resumption."""
+ if not self.captions_path.exists():
+ logger.info("No captions file found, returning empty processed job_ids set")
+ return set()

- # Check buffer
- for buffered in self.job_buffer:
- if buffered["job_id"] == job_id:
- return True
+ # Read only the job_id column
+ table = pq.read_table(self.captions_path, columns=["job_id"])
+ job_ids = set(table["job_id"].to_pylist())

- return False
+ # Add buffered job_ids
+ for row in self.caption_buffer:
+ if "job_id" in row:
+ job_ids.add(row["job_id"])

- async def get_job(self, job_id: str) -> Optional[Job]:
- """Retrieve a job by ID."""
- # Check buffer first
- for buffered in self.job_buffer:
- if buffered["job_id"] == job_id:
- return Job(
- job_id=buffered["job_id"],
- dataset=buffered["dataset"],
- shard=buffered["shard"],
- item_key=buffered["item_key"],
- status=JobStatus(buffered["status"]),
- assigned_to=buffered["assigned_to"],
- created_at=buffered["created_at"],
- )
+ return job_ids

- if not self.jobs_path.exists():
- return None
+ async def get_storage_contents(
+ self,
+ limit: Optional[int] = None,
+ columns: Optional[List[str]] = None,
+ include_metadata: bool = True,
+ ) -> StorageContents:
+ """Retrieve storage contents for export.
+
+ Args:
+ limit: Maximum number of rows to retrieve
+ columns: Specific columns to include (None for all)
+ include_metadata: Whether to include metadata in the result
+
+ Returns:
+ StorageContents instance with the requested data
+ """
+ if not self.captions_path.exists():
+ return StorageContents(
+ rows=[],
+ columns=[],
+ output_fields=list(self.known_output_fields),
+ total_rows=0,
+ metadata={"message": "No captions file found"},
+ )
+
+ # Flush buffers first to ensure all data is on disk
+ await self.checkpoint()
+
+ # Determine columns to read
+ if columns:
+ # Validate requested columns exist
+ table_metadata = pq.read_metadata(self.captions_path)
+ available_columns = set(table_metadata.schema.names)
+ invalid_columns = set(columns) - available_columns
+ if invalid_columns:
+ raise ValueError(f"Columns not found: {invalid_columns}")
+ columns_to_read = columns
+ else:
+ # Read all columns
+ columns_to_read = None

- table = pq.read_table(self.jobs_path)
+ # Read the table
+ table = pq.read_table(self.captions_path, columns=columns_to_read)
  df = table.to_pandas()

- row = df[df["job_id"] == job_id]
- if row.empty:
- return None
+ # Apply limit if specified
+ if limit:
+ df = df.head(limit)
+
+ # Convert to list of dicts
+ rows = df.to_dict("records")
+
+ # Parse metadata JSON strings back to dicts if present
+ if "metadata" in df.columns:
+ for row in rows:
+ if row.get("metadata"):
+ try:
+ row["metadata"] = json.loads(row["metadata"])
+ except:
+ pass # Keep as string if parsing fails
+
+ # Prepare metadata
+ metadata = {}
+ if include_metadata:
+ stats = await self.get_caption_stats()
+ metadata.update(
+ {
+ "export_timestamp": pd.Timestamp.now().isoformat(),
+ "total_available_rows": stats.get("total_rows", 0),
+ "rows_exported": len(rows),
+ "storage_path": str(self.captions_path),
+ "field_stats": stats.get("field_stats", {}),
+ }
+ )

- return Job(
- job_id=row.iloc[0]["job_id"],
- dataset=row.iloc[0]["dataset"],
- shard=row.iloc[0]["shard"],
- item_key=row.iloc[0]["item_key"],
- status=JobStatus(row.iloc[0]["status"]),
- assigned_to=row.iloc[0]["assigned_to"],
- created_at=row.iloc[0]["created_at"],
+ return StorageContents(
+ rows=rows,
+ columns=list(df.columns),
+ output_fields=list(self.known_output_fields),
+ total_rows=len(df),
+ metadata=metadata,
  )

- async def get_jobs_by_worker(self, worker_id: str) -> List[Job]:
- """Get all jobs assigned to a worker."""
- if not self.jobs_path.exists():
- return []
+ async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
+ """Get all processed job_ids for a given chunk."""
+ if not self.captions_path.exists():
+ return set()

- table = pq.read_table(self.jobs_path)
+ # Read only job_id and chunk_id columns
+ table = pq.read_table(self.captions_path, columns=["job_id", "chunk_id"])
  df = table.to_pandas()

- rows = df[df["assigned_to"] == worker_id]
-
- jobs = []
- for _, row in rows.iterrows():
- jobs.append(
- Job(
- job_id=row["job_id"],
- dataset=row["dataset"],
- shard=row["shard"],
- item_key=row["item_key"],
- status=JobStatus(row["status"]),
- assigned_to=row["assigned_to"],
- created_at=row["created_at"],
- )
- )
-
- return jobs
+ # Filter by chunk_id and return job_ids
+ chunk_jobs = df[df["chunk_id"] == chunk_id]["job_id"].tolist()
+ return set(chunk_jobs)

  async def get_caption_stats(self) -> Dict[str, Any]:
  """Get statistics about stored captions including field-specific stats."""
@@ -683,11 +916,12 @@ class StorageManager:
  field_lengths = []

  for value in df.loc[non_null_mask, field_name]:
+ # list or array-like
  if isinstance(value, list):
  length = len(value)
  field_total += length
  field_lengths.append(length)
- elif pd.notna(value):
+ elif value.any():
  length = 1
  field_total += length
  field_lengths.append(length)
@@ -719,46 +953,6 @@ class StorageManager:
  },
  }

- async def get_sample_captions(self, n: int = 5) -> List[Dict[str, Any]]:
- """Get a sample of caption entries showing all output fields."""
- if not self.captions_path.exists():
- return []
-
- table = pq.read_table(self.captions_path)
- df = table.to_pandas()
-
- if len(df) == 0:
- return []
-
- sample_df = df.sample(min(n, len(df)))
- samples = []
-
- for _, row in sample_df.iterrows():
- # Collect outputs from dynamic columns
- outputs = {}
- total_outputs = 0
-
- for field_name in self.known_output_fields:
- if field_name in row and pd.notna(row[field_name]):
- value = row[field_name]
- outputs[field_name] = value
- if isinstance(value, list):
- total_outputs += len(value)
-
- samples.append(
- {
- "job_id": row["job_id"],
- "item_key": row["item_key"],
- "outputs": outputs,
- "field_count": len(outputs),
- "total_outputs": total_outputs,
- "image_dims": f"{row.get('image_width', 'N/A')}x{row.get('image_height', 'N/A')}",
- "has_metadata": bool(row.get("metadata") and row["metadata"] != "{}"),
- }
- )
-
- return samples
-
  async def count_captions(self) -> int:
  """Count total outputs across all dynamic fields."""
@@ -888,142 +1082,26 @@ class StorageManager:
  "fields": sorted(list(field_counts.keys())),
  }

- async def get_captions_with_field(
- self, field_name: str, limit: int = 100
- ) -> List[Dict[str, Any]]:
- """Get captions that have a specific output field."""
- if not self.captions_path.exists():
- return []
-
- if field_name not in self.known_output_fields:
- logger.warning(f"Field '{field_name}' not found in known output fields")
- return []
-
- # Check if the field actually exists in the file
- existing_output_columns = self._get_existing_output_columns()
- if field_name not in existing_output_columns:
- logger.warning(
- f"Field '{field_name}' exists in known fields but not in parquet file yet"
- )
- return []
-
- # Only read necessary columns
- columns_to_read = ["job_id", "item_key", field_name]
-
- try:
- table = pq.read_table(self.captions_path, columns=columns_to_read)
- except Exception as e:
- logger.error(f"Error reading field '{field_name}': {e}")
- return []
-
- df = table.to_pandas()
-
- # Filter rows where field has data
- mask = df[field_name].notna()
- filtered_df = df[mask].head(limit)
-
- results = []
- for _, row in filtered_df.iterrows():
- results.append(
- {
- "job_id": row["job_id"],
- "item_key": row["item_key"],
- field_name: row[field_name],
- "value_count": len(row[field_name]) if isinstance(row[field_name], list) else 1,
- }
- )
-
- return results
-
- async def export_by_field(self, field_name: str, output_path: Path, format: str = "jsonl"):
- """Export all captions for a specific field."""
- if not self.captions_path.exists():
- logger.warning("No captions to export")
- return 0
-
- if field_name not in self.known_output_fields:
- logger.warning(f"Field '{field_name}' not found in known output fields")
- return 0
-
- # Check if the field actually exists in the file
- existing_output_columns = self._get_existing_output_columns()
- if field_name not in existing_output_columns:
- logger.warning(f"Field '{field_name}' not found in parquet file")
- return 0
-
- # Read only necessary columns
- columns_to_read = ["item_key", "dataset", field_name]
- table = pq.read_table(self.captions_path, columns=columns_to_read)
- df = table.to_pandas()
-
- exported = 0
- with open(output_path, "w") as f:
- for _, row in df.iterrows():
- if pd.notna(row[field_name]) and row[field_name]:
- if format == "jsonl":
- record = {
- "item_key": row["item_key"],
- "dataset": row["dataset"],
- field_name: row[field_name],
- }
- f.write(json.dumps(record) + "\n")
- exported += 1
-
- logger.info(f"Exported {exported} items with field '{field_name}' to {output_path}")
- return exported
-
- async def get_pending_jobs(self) -> List[Job]:
- """Get all pending jobs for restoration on startup."""
- if not self.jobs_path.exists():
- return []
-
- table = pq.read_table(self.jobs_path)
- df = table.to_pandas()
-
- # Get jobs with PENDING or PROCESSING status
- pending_df = df[df["status"].isin([JobStatus.PENDING.value, JobStatus.PROCESSING.value])]
-
- jobs = []
- for _, row in pending_df.iterrows():
- jobs.append(
- Job(
- job_id=row["job_id"],
- dataset=row["dataset"],
- shard=row["shard"],
- item_key=row["item_key"],
- status=JobStatus(row["status"]),
- assigned_to=row.get("assigned_to"),
- created_at=row["created_at"],
- )
- )
-
- return jobs
-
- async def count_jobs(self) -> int:
- """Count total jobs."""
- if not self.jobs_path.exists():
- return 0
-
- table = pq.read_table(self.jobs_path)
- return len(table)
-
- async def count_completed_jobs(self) -> int:
- """Count completed jobs."""
- if not self.jobs_path.exists():
- return 0
-
- table = pq.read_table(self.jobs_path)
- df = table.to_pandas()
- return len(df[df["status"] == JobStatus.COMPLETED.value])
-
  async def close(self):
  """Close storage and flush buffers."""
  await self.checkpoint()
- logger.info(
- f"Storage closed. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped}"
- )
+
+ # Log final rate statistics
+ if self.total_captions_written > 0:
+ rates = self._calculate_rates()
+ logger.info(
+ f"Storage closed. Total rows: {self.total_captions_written}, "
+ f"Total caption entries: {self.total_caption_entries_written}, "
+ f"Duplicates skipped: {self.duplicates_skipped} | "
+ f"Final rates - Overall: {rates['overall']:.1f} rows/s, "
+ f"Last hour: {rates['60min']:.1f} rows/s"
+ )
+ else:
+ logger.info(
+ f"Storage closed. Total rows: {self.total_captions_written}, "
+ f"Total caption entries: {self.total_caption_entries_written}, "
+ f"Duplicates skipped: {self.duplicates_skipped}"
+ )

  async def get_storage_stats(self) -> Dict[str, Any]:
  """Get all storage-related statistics."""
@@ -1041,6 +1119,9 @@ class StorageManager:
  field_stats = await self.get_caption_stats()
  total_rows_including_buffer = await self.count_caption_rows() + len(self.caption_buffer)

+ # Calculate rates
+ rates = self._calculate_rates()
+
  return {
  "total_captions": disk_outputs + buffer_outputs,
  "total_rows": total_rows_including_buffer,
@@ -1053,4 +1134,11 @@ class StorageManager:
  "field_breakdown": field_stats.get("field_stats", None),
  "job_buffer_size": len(self.job_buffer),
  "contributor_buffer_size": len(self.contributor_buffer),
+ "rates": {
+ "instant": f"{rates['instant']:.1f} rows/s",
+ "5min": f"{rates['5min']:.1f} rows/s",
+ "15min": f"{rates['15min']:.1f} rows/s",
+ "60min": f"{rates['60min']:.1f} rows/s",
+ "overall": f"{rates['overall']:.1f} rows/s",
+ },
  }