caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
  """Arrow/Parquet storage management with dynamic column support for outputs."""

  import asyncio
+ import gc
  import json
  import logging
  from dataclasses import asdict
@@ -15,10 +16,10 @@ from collections import defaultdict, deque
  import time
  import numpy as np

- from .models import Job, Caption, Contributor, JobStatus, JobId
+ from ..models import Job, Caption, Contributor, StorageContents, JobId

  logger = logging.getLogger(__name__)
- logger.setLevel(logging.INFO)
+ logger.setLevel(logging.DEBUG)


  class StorageManager:
@@ -28,8 +29,7 @@ class StorageManager:
  self,
  data_dir: Path,
  caption_buffer_size: int = 100,
- job_buffer_size: int = 100,
- contributor_buffer_size: int = 10,
+ contributor_buffer_size: int = 50,
  ):
  self.data_dir = Path(data_dir)
  self.data_dir.mkdir(parents=True, exist_ok=True)
@@ -38,6 +38,7 @@ class StorageManager:
  self.captions_path = self.data_dir / "captions.parquet"
  self.jobs_path = self.data_dir / "jobs.parquet"
  self.contributors_path = self.data_dir / "contributors.parquet"
+ self.stats_path = self.data_dir / "storage_stats.json" # Persist stats here

  # In-memory buffers for batching writes
  self.caption_buffer = []
@@ -46,7 +47,6 @@ class StorageManager:

  # Buffer size configuration
  self.caption_buffer_size = caption_buffer_size
- self.job_buffer_size = job_buffer_size
  self.contributor_buffer_size = contributor_buffer_size

  # Track existing job_ids to prevent duplicates
@@ -57,14 +57,20 @@ class StorageManager:
  # Track known output fields for schema evolution
  self.known_output_fields: Set[str] = set()

- # Statistics
- self.total_captions_written = 0
- self.total_caption_entries_written = 0 # Total individual captions
- self.total_flushes = 0
- self.duplicates_skipped = 0
+ # In-memory statistics (loaded once at startup, then tracked incrementally)
+ self.stats = {
+ "disk_rows": 0, # Rows in parquet file
+ "disk_outputs": 0, # Total outputs in parquet file
+ "field_counts": {}, # Count of outputs per field on disk
+ "total_captions_written": 0, # Total rows written during this session
+ "total_caption_entries_written": 0, # Total individual captions written
+ "total_flushes": 0,
+ "duplicates_skipped": 0,
+ "session_field_counts": {}, # Outputs per field written this session
+ }

  # Rate tracking
- self.row_additions = deque(maxlen=10000) # Store (timestamp, row_count) tuples
+ self.row_additions = deque(maxlen=100) # Store (timestamp, row_count) tuples
  self.start_time = time.time()
  self.last_rate_log_time = time.time()

@@ -115,6 +121,129 @@ class StorageManager:
  ]
  )

+ def _save_stats(self):
+ """Persist current stats to disk."""
+ try:
+ with open(self.stats_path, "w") as f:
+ json.dump(self.stats, f, indent=2)
+ except Exception as e:
+ logger.error(f"Failed to save stats: {e}")
+
+ def _load_stats(self):
+ """Load stats from disk if available."""
+ if self.stats_path.exists():
+ try:
+ with open(self.stats_path, "r") as f:
+ loaded_stats = json.load(f)
+ # Merge loaded stats with defaults
+ self.stats.update(loaded_stats)
+ logger.info(f"Loaded stats from {self.stats_path}")
+ except Exception as e:
+ logger.error(f"Failed to load stats: {e}")
+
+ def _calculate_initial_stats(self):
+ """Calculate stats from parquet file - only called once at initialization."""
+ if not self.captions_path.exists():
+ return
+
+ logger.info("Calculating initial statistics from parquet file...")
+
+ try:
+ # Get metadata to determine row count
+ table_metadata = pq.read_metadata(self.captions_path)
+ self.stats["disk_rows"] = table_metadata.num_rows
+
+ # Simply use the known_output_fields that were already detected during initialization
+ # This avoids issues with PyArrow schema parsing
+ if self.known_output_fields and table_metadata.num_rows > 0:
+ # Read the entire table since column-specific reading is causing issues
+ table = pq.read_table(self.captions_path)
+ df = table.to_pandas()
+
+ total_outputs = 0
+ field_counts = {}
+
+ # Count outputs in each known output field
+ for field_name in self.known_output_fields:
+ if field_name in df.columns:
+ field_count = 0
+ column_data = df[field_name]
+
+ # Iterate through values more carefully to avoid numpy array issues
+ for i in range(len(column_data)):
+ value = column_data.iloc[i]
+ # Handle None/NaN values first
+ if value is None:
+ continue
+ # For lists/arrays, check if they're actually None or have content
+ try:
+ # If it's a list-like object, count its length
+ if hasattr(value, "__len__") and not isinstance(value, str):
+ # Check if it's a pandas null value by trying to check the first element
+ if len(value) > 0:
+ field_count += len(value)
+ # Empty lists are valid but contribute 0
+ except TypeError:
+ # Value is scalar NA/NaN, skip it
+ continue
+ except:
+ # Any other error, skip this value
+ continue
+
+ if field_count > 0:
+ field_counts[field_name] = field_count
+ total_outputs += field_count
+
+ self.stats["disk_outputs"] = total_outputs
+ self.stats["field_counts"] = field_counts
+
+ # Clean up
+ del df, table
+ gc.collect()
+ else:
+ self.stats["disk_outputs"] = 0
+ self.stats["field_counts"] = {}
+
+ logger.info(
+ f"Initial stats: {self.stats['disk_rows']} rows, {self.stats['disk_outputs']} outputs, fields: {list(self.stats['field_counts'].keys())}"
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to calculate initial stats: {e}", exc_info=True)
+ # Set default values
+ self.stats["disk_rows"] = 0
+ self.stats["disk_outputs"] = 0
+ self.stats["field_counts"] = {}
+
+ # Save the calculated stats
+ self._save_stats()
+
+ def _update_stats_for_new_captions(self, captions_added: List[dict], rows_added: int):
+ """Update stats incrementally as new captions are added."""
+ # Update row counts
+ self.stats["disk_rows"] += rows_added
+ self.stats["total_captions_written"] += rows_added
+
+ # Count outputs in the new captions
+ outputs_added = 0
+ for caption in captions_added:
+ for field_name in self.known_output_fields:
+ if field_name in caption and isinstance(caption[field_name], list):
+ count = len(caption[field_name])
+ outputs_added += count
+
+ # Update field-specific counts
+ if field_name not in self.stats["field_counts"]:
+ self.stats["field_counts"][field_name] = 0
+ self.stats["field_counts"][field_name] += count
+
+ if field_name not in self.stats["session_field_counts"]:
+ self.stats["session_field_counts"][field_name] = 0
+ self.stats["session_field_counts"][field_name] += count
+
+ self.stats["disk_outputs"] += outputs_added
+ self.stats["total_caption_entries_written"] += outputs_added
+
  def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
  """Check if a column is entirely empty, null, or contains only zeros/empty lists."""
  if column_name not in df.columns:
@@ -217,7 +346,7 @@ class StorageManager:
  # Calculate overall rate since start
  total_elapsed = current_time - self.start_time
  if total_elapsed > 0:
- rates["overall"] = self.total_captions_written / total_elapsed
+ rates["overall"] = self.stats["total_captions_written"] / total_elapsed
  else:
  rates["overall"] = 0.0

@@ -269,6 +398,9 @@ class StorageManager:

  async def initialize(self):
  """Initialize storage files if they don't exist."""
+ # Load persisted stats if available
+ self._load_stats()
+
  if not self.captions_path.exists():
  # Create initial schema with just base fields
  self.caption_schema = self._build_caption_schema(set())
@@ -278,6 +410,11 @@ class StorageManager:
  empty_table = pa.Table.from_pydict(empty_dict, schema=self.caption_schema)
  pq.write_table(empty_table, self.captions_path)
  logger.info(f"Created empty caption storage at {self.captions_path}")
+
+ # Initialize stats
+ self.stats["disk_rows"] = 0
+ self.stats["disk_outputs"] = 0
+ self.stats["field_counts"] = {}
  else:
  # Load existing schema and detect output fields
  existing_table = pq.read_table(self.captions_path)
@@ -300,6 +437,10 @@ class StorageManager:
  logger.info(f"Loaded {len(self.existing_caption_job_ids)} existing caption job_ids")
  logger.info(f"Known output fields: {sorted(self.known_output_fields)}")

+ # Calculate initial stats if not already loaded from file
+ if self.stats["disk_rows"] == 0:
+ self._calculate_initial_stats()
+
  # Initialize other storage files...
  if not self.contributors_path.exists():
  empty_dict = {"contributor_id": [], "name": [], "total_captions": [], "trust_level": []}
@@ -371,6 +512,9 @@ class StorageManager:
  pq.write_table(migrated_table, self.captions_path)
  logger.info("Migration complete - outputs now stored in dynamic columns")

+ # Recalculate stats after migration
+ self._calculate_initial_stats()
+
  async def save_caption(self, caption: Caption):
  """Save a caption entry, grouping outputs by job_id/item_key (not separating captions)."""
  caption_dict = asdict(caption)
@@ -391,49 +535,32 @@ class StorageManager:
  _job_id = caption_dict.get("job_id")
  job_id = JobId.from_dict(_job_id).get_sample_str()
  group_key = job_id
- logger.debug(
- f"save_caption: group_key={group_key}, outputs={list(outputs.keys())}, caption_count={caption_dict.get('caption_count')}, item_index={caption_dict.get('item_index')}"
- )
+
+ # Check for duplicate - if this job_id already exists on disk, skip it
+ if group_key in self.existing_caption_job_ids:
+ self.stats["duplicates_skipped"] += 1
+ logger.debug(f"Skipping duplicate job_id: {group_key}")
+ return

  # Try to find existing buffered row for this group
  found_row = False
  for idx, row in enumerate(self.caption_buffer):
  check_key = row.get("job_id")
- logger.debug(f"Checking buffer row {idx}: check_key={check_key}, group_key={group_key}")
  if check_key == group_key:
  found_row = True
- logger.debug(f"Found existing buffer row for group_key={group_key} at index {idx}")
  # Merge outputs into existing row
  for field_name, field_values in outputs.items():
  if field_name not in self.known_output_fields:
  self.known_output_fields.add(field_name)
- logger.info(f"New output field detected: {field_name}")
  if field_name in row and isinstance(row[field_name], list):
- logger.debug(
- f"Merging output field '{field_name}' into existing row: before={row[field_name]}, adding={field_values}"
- )
  row[field_name].extend(field_values)
- logger.debug(f"After merge: {row[field_name]}")
  else:
- logger.debug(
- f"Setting new output field '{field_name}' in existing row: {field_values}"
- )
  row[field_name] = list(field_values)
  # Optionally update other fields (e.g., caption_count)
  if "caption_count" in caption_dict:
  old_count = row.get("caption_count", 0)
  row["caption_count"] = old_count + caption_dict["caption_count"]
- logger.debug(
- f"Updated caption_count for group_key={group_key}: {old_count} + {caption_dict['caption_count']} = {row['caption_count']}"
- )
  return # Already merged, no need to add new row
- else:
- logger.debug(f"Caption row not found for group key: {group_key} vs {check_key}")
-
- if not found_row:
- logger.debug(
- f"No existing buffer row found for group_key={group_key}, creating new row."
- )

  # If not found, create new row
  for field_name, field_values in outputs.items():
@@ -441,7 +568,6 @@ class StorageManager:
  self.known_output_fields.add(field_name)
  logger.info(f"New output field detected: {field_name}")
  caption_dict[field_name] = list(field_values)
- logger.debug(f"Adding output field '{field_name}' to new row: {field_values}")

  # Serialize metadata to JSON if present
  if "metadata" in caption_dict:
@@ -453,9 +579,6 @@ class StorageManager:
  caption_dict["job_id"] = job_id

  self.caption_buffer.append(caption_dict)
- logger.debug(
- f"Appended new caption row for group_key={group_key}. Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}"
- )

  if len(self.caption_buffer) >= self.caption_buffer_size:
  logger.debug("Caption buffer full, flushing captions.")
@@ -466,105 +589,109 @@ class StorageManager:
  if not self.caption_buffer:
  return

- num_rows = len(self.caption_buffer)
+ try:
+ num_rows = len(self.caption_buffer)

- # Count total outputs across all fields
- total_outputs = 0
- for row in self.caption_buffer:
- for field_name in self.known_output_fields:
- if field_name in row and isinstance(row[field_name], list):
- total_outputs += len(row[field_name])
+ # Count total outputs across all fields
+ total_outputs = 0
+ for row in self.caption_buffer:
+ for field_name in self.known_output_fields:
+ if field_name in row and isinstance(row[field_name], list):
+ total_outputs += len(row[field_name])

- logger.debug(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
+ logger.debug(
+ f"Flushing {num_rows} rows with {total_outputs} total outputs to disk. preparing data..."
+ )

- # Prepare data with all required columns
- prepared_buffer = []
- for row in self.caption_buffer:
- prepared_row = row.copy()
+ # Keep a copy of captions for stats update
+ captions_to_write = list(self.caption_buffer)

- # Ensure all base fields are present
- for field_name, field_type in self.base_caption_fields:
- if field_name not in prepared_row:
- prepared_row[field_name] = None
+ # Prepare data with all required columns
+ prepared_buffer = []
+ new_job_ids = []
+ for row in self.caption_buffer:
+ prepared_row = row.copy()

- # Ensure all output fields are present (even if None)
- for field_name in self.known_output_fields:
- if field_name not in prepared_row:
- prepared_row[field_name] = None
-
- prepared_buffer.append(prepared_row)
-
- # Build schema with all known fields (base + output)
- schema = self._build_caption_schema(self.known_output_fields)
- table = pa.Table.from_pylist(prepared_buffer, schema=schema)
-
- if self.captions_path.exists():
- # Read existing table
- existing = pq.read_table(self.captions_path)
-
- # Get existing job_ids for deduplication
- existing_job_ids = set(existing.column("job_id").to_pylist())
-
- # Filter new data to exclude duplicates
- new_rows = []
- duplicate_rows = []
- for row in prepared_buffer:
- if row["job_id"] not in existing_job_ids:
- new_rows.append(row)
- elif row not in duplicate_rows:
- duplicate_rows.append(
- {
- "input": row,
- "existing_job": existing.to_pandas()[
- existing.to_pandas()["job_id"] == row["job_id"]
- ].to_dict(orient="records"),
- }
- )
+ # Track job_ids for deduplication
+ job_id = prepared_row.get("job_id")
+ if job_id:
+ new_job_ids.append(job_id)
+
+ # Ensure all base fields are present
+ for field_name, field_type in self.base_caption_fields:
+ if field_name not in prepared_row:
+ prepared_row[field_name] = None
+
+ # Ensure all output fields are present (even if None)
+ for field_name in self.known_output_fields:
+ if field_name not in prepared_row:
+ prepared_row[field_name] = None

- if duplicate_rows:
- logger.info(f"Example duplicate row: {duplicate_rows[0]}")
+ prepared_buffer.append(prepared_row)

- if new_rows:
- # Create table from new rows only
- new_table = pa.Table.from_pylist(new_rows, schema=schema)
+ # Build schema with all known fields (base + output)
+ logger.debug("building schema...")
+ schema = self._build_caption_schema(self.known_output_fields)
+ logger.debug("schema built, creating table...")
+ table = pa.Table.from_pylist(prepared_buffer, schema=schema)

+ if self.captions_path.exists():
+ # Read existing table - this is necessary for parquet format
+ logger.debug("Reading existing captions file...")
+ existing = pq.read_table(self.captions_path)
+
+ logger.debug("writing new rows...")
  # Concatenate with promote_options="default" to handle schema differences automatically
- combined = pa.concat_tables([existing, new_table], promote_options="default")
+ logger.debug("concat tables...")
+ combined = pa.concat_tables([existing, table], promote_options="default")

  # Write combined table
  pq.write_table(combined, self.captions_path, compression="snappy")

- self.duplicates_skipped = num_rows - len(new_rows)
- actual_new = len(new_rows)
+ actual_new = len(table)
  else:
- logger.info(f"All {num_rows} rows were duplicates, exiting")
- raise SystemError("No duplicates can be submitted")
- else:
- # Write new file with all fields
- pq.write_table(table, self.captions_path, compression="snappy")
- actual_new = num_rows
+ # Write new file with all fields
+ pq.write_table(table, self.captions_path, compression="snappy")
+ actual_new = num_rows
+ logger.debug("write complete.")

- # Update statistics
- self.total_captions_written += actual_new
- self.total_caption_entries_written += total_outputs
- self.total_flushes += 1
- self.caption_buffer.clear()
+ # Clean up
+ del prepared_buffer, table

- # Track row additions for rate calculation
- if actual_new > 0:
- current_time = time.time()
- self.row_additions.append((current_time, actual_new))
+ # Update the in-memory job_id set for efficient deduplication
+ self.existing_caption_job_ids.update(new_job_ids)

- # Log rates
- self._log_rates(actual_new)
+ # Update statistics incrementally
+ self._update_stats_for_new_captions(captions_to_write, actual_new)
+ self.stats["total_flushes"] += 1

- logger.info(
- f"Successfully wrote captions (new rows: {actual_new}, "
- f"total rows written: {self.total_captions_written}, "
- f"total captions written: {self.total_caption_entries_written}, "
- f"duplicates skipped: {self.duplicates_skipped}, "
- f"output fields: {sorted(list(self.known_output_fields))})"
- )
+ # Clear buffer
+ self.caption_buffer.clear()
+
+ # Track row additions for rate calculation
+ if actual_new > 0:
+ current_time = time.time()
+ self.row_additions.append((current_time, actual_new))
+
+ # Log rates
+ self._log_rates(actual_new)
+
+ logger.info(
+ f"Successfully wrote captions (new rows: {actual_new}, "
+ f"total rows written: {self.stats['total_captions_written']}, "
+ f"total captions written: {self.stats['total_caption_entries_written']}, "
+ f"duplicates skipped: {self.stats['duplicates_skipped']}, "
+ f"output fields: {sorted(list(self.known_output_fields))})"
+ )
+
+ # Save stats periodically
+ if self.stats["total_flushes"] % 10 == 0:
+ self._save_stats()
+
+ finally:
+ self.caption_buffer.clear()
+ # Force garbage collection
+ gc.collect()

  async def optimize_storage(self):
  """Optimize storage by dropping empty columns. Run this periodically or on-demand."""
@@ -575,6 +702,7 @@ class StorageManager:
  logger.info("Starting storage optimization...")

  # Read the full table
+ backup_path = None
  table = pq.read_table(self.captions_path)
  df = table.to_pandas()
  original_columns = len(df.columns)
@@ -624,11 +752,15 @@ class StorageManager:
  # Write optimized table
  pq.write_table(optimized_table, self.captions_path, compression="snappy")

- # Update known output fields
+ # Update known output fields and recalculate stats
  self.known_output_fields = output_fields

- # Clean up backup (optional - keep it for safety)
- # backup_path.unlink()
+ # Remove dropped fields from stats
+ dropped_fields = set(self.stats["field_counts"].keys()) - output_fields
+ for field in dropped_fields:
+ del self.stats["field_counts"][field]
+ if field in self.stats["session_field_counts"]:
+ del self.stats["session_field_counts"][field]

  logger.info(
  f"Storage optimization complete: {original_columns} -> {len(non_empty_columns)} columns. "
@@ -649,37 +781,8 @@ class StorageManager:
  f"({reduction_pct:.1f}% reduction)"
  )

- async def _evolve_schema_on_disk(self):
- """Evolve the schema of the existing parquet file to include new columns, removing empty ones."""
- logger.info("Evolving schema on disk to add new columns...")
-
- # Read existing data
- existing_table = pq.read_table(self.captions_path)
- df = existing_table.to_pandas()
-
- # Add missing columns with None values
- for field_name in self.known_output_fields:
- if field_name not in df.columns:
- df[field_name] = None
- logger.info(f"Added new column: {field_name}")
-
- # Remove empty columns (but preserve base fields)
- non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
- df = df[non_empty_columns]
-
- # Update known output fields
- base_field_names = {field[0] for field in self.base_caption_fields}
- self.known_output_fields = set(non_empty_columns) - base_field_names
-
- # Recreate schema with only non-empty fields
- self.caption_schema = self._build_caption_schema(self.known_output_fields)
-
- # Recreate table with new schema
- evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
- pq.write_table(evolved_table, self.captions_path, compression="snappy")
- logger.info(
- f"Schema evolution complete. Active output fields: {sorted(list(self.known_output_fields))}"
- )
+ # Save updated stats
+ self._save_stats()

  async def save_contributor(self, contributor: Contributor):
  """Save or update contributor stats - buffers until batch size reached."""
@@ -757,38 +860,118 @@ class StorageManager:
  await self._flush_jobs()
  await self._flush_contributors()

+ # Save stats on checkpoint
+ self._save_stats()
+
  # Log final rate statistics
- if self.total_captions_written > 0:
+ if self.stats["total_captions_written"] > 0:
  rates = self._calculate_rates()
  logger.info(
- f"Checkpoint complete. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped} | "
+ f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
+ f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+ f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
  f"Overall rate: {rates['overall']:.1f} rows/s"
  )
  else:
  logger.info(
- f"Checkpoint complete. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped}"
+ f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
+ f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+ f"Duplicates skipped: {self.stats['duplicates_skipped']}"
  )

  def get_all_processed_job_ids(self) -> Set[str]:
  """Get all processed job_ids - useful for resumption."""
- if not self.captions_path.exists():
- logger.info("No captions file found, returning empty processed job_ids set")
- return set()
-
- # Read only the job_id column
- table = pq.read_table(self.captions_path, columns=["job_id"])
- job_ids = set(table["job_id"].to_pylist())
+ # Return the in-memory set which is kept up-to-date
+ # Also add any job_ids currently in the buffer
+ all_job_ids = self.existing_caption_job_ids.copy()

- # Add buffered job_ids
  for row in self.caption_buffer:
  if "job_id" in row:
- job_ids.add(row["job_id"])
+ all_job_ids.add(row["job_id"])
+
+ return all_job_ids
+
+ async def get_storage_contents(
+ self,
+ limit: Optional[int] = None,
+ columns: Optional[List[str]] = None,
+ include_metadata: bool = True,
+ ) -> StorageContents:
+ """Retrieve storage contents for export.
+
+ Args:
+ limit: Maximum number of rows to retrieve
+ columns: Specific columns to include (None for all)
+ include_metadata: Whether to include metadata in the result
+
+ Returns:
+ StorageContents instance with the requested data
+ """
+ if not self.captions_path.exists():
+ return StorageContents(
+ rows=[],
+ columns=[],
+ output_fields=list(self.known_output_fields),
+ total_rows=0,
+ metadata={"message": "No captions file found"},
+ )
+
+ # Flush buffers first to ensure all data is on disk
+ await self.checkpoint()

- return job_ids
+ # Determine columns to read
+ if columns:
+ # Validate requested columns exist
+ table_metadata = pq.read_metadata(self.captions_path)
+ available_columns = set(table_metadata.schema.names)
+ invalid_columns = set(columns) - available_columns
+ if invalid_columns:
+ raise ValueError(f"Columns not found: {invalid_columns}")
+ columns_to_read = columns
+ else:
+ # Read all columns
+ columns_to_read = None
+
+ # Read the table
+ table = pq.read_table(self.captions_path, columns=columns_to_read)
+ df = table.to_pandas()
+
+ # Apply limit if specified
+ if limit:
+ df = df.head(limit)
+
+ # Convert to list of dicts
+ rows = df.to_dict("records")
+
+ # Parse metadata JSON strings back to dicts if present
+ if "metadata" in df.columns:
+ for row in rows:
+ if row.get("metadata"):
+ try:
+ row["metadata"] = json.loads(row["metadata"])
+ except:
+ pass # Keep as string if parsing fails
+
+ # Prepare metadata
+ metadata = {}
+ if include_metadata:
+ metadata.update(
+ {
+ "export_timestamp": pd.Timestamp.now().isoformat(),
+ "total_available_rows": self.stats.get("disk_rows", 0),
+ "rows_exported": len(rows),
+ "storage_path": str(self.captions_path),
+ "field_stats": self.stats.get("field_counts", {}),
+ }
+ )
+
+ return StorageContents(
+ rows=rows,
+ columns=list(df.columns),
+ output_fields=list(self.known_output_fields),
+ total_rows=len(df),
+ metadata=metadata,
+ )

  async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
  """Get all processed job_ids for a given chunk."""
@@ -804,91 +987,46 @@ class StorageManager:
  return set(chunk_jobs)

  async def get_caption_stats(self) -> Dict[str, Any]:
- """Get statistics about stored captions including field-specific stats."""
- if not self.captions_path.exists():
- return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
-
- table = pq.read_table(self.captions_path)
- df = table.to_pandas()
-
- if len(df) == 0:
- return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
+ """Get statistics about stored captions from cached values."""
+ total_rows = self.stats["disk_rows"] + len(self.caption_buffer)

- # Get actual columns in the dataframe
- existing_columns = set(df.columns)
+ # Count outputs in buffer
+ buffer_outputs = 0
+ buffer_field_counts = defaultdict(int)
+ for row in self.caption_buffer:
+ for field_name in self.known_output_fields:
+ if field_name in row and isinstance(row[field_name], list):
+ count = len(row[field_name])
+ buffer_outputs += count
+ buffer_field_counts[field_name] += count

- # Calculate stats per field (only for fields that exist in the file)
+ # Merge buffer counts with disk counts
  field_stats = {}
- total_outputs = 0
-
  for field_name in self.known_output_fields:
- if field_name in existing_columns:
- # Count non-null entries
- non_null_mask = df[field_name].notna()
- non_null_count = non_null_mask.sum()
-
- # Count total items in lists
- field_total = 0
- field_lengths = []
-
- for value in df.loc[non_null_mask, field_name]:
- # list or array-like
- if isinstance(value, list):
- length = len(value)
- field_total += length
- field_lengths.append(length)
- elif value.any():
- length = 1
- field_total += length
- field_lengths.append(length)
-
- if field_lengths:
- field_stats[field_name] = {
- "rows_with_data": non_null_count,
- "total_items": field_total,
- "avg_items_per_row": sum(field_lengths) / len(field_lengths),
- }
- if min(field_lengths) != max(field_lengths):
- field_stats[field_name].update(
- {
- "min_items": min(field_lengths),
- "max_items": max(field_lengths),
- }
- )
- total_outputs += field_total
+ disk_count = self.stats["field_counts"].get(field_name, 0)
+ buffer_count = buffer_field_counts.get(field_name, 0)
+ total_count = disk_count + buffer_count
+
+ if total_count > 0:
+ field_stats[field_name] = {
+ "total_items": total_count,
+ "disk_items": disk_count,
+ "buffer_items": buffer_count,
+ }
+
+ total_outputs = self.stats["disk_outputs"] + buffer_outputs

  return {
- "total_rows": len(df),
+ "total_rows": total_rows,
  "total_outputs": total_outputs,
  "output_fields": sorted(list(self.known_output_fields)),
  "field_stats": field_stats,
- "caption_count_stats": {
- "mean": df["caption_count"].mean() if "caption_count" in df.columns else 0,
- "min": df["caption_count"].min() if "caption_count" in df.columns else 0,
- "max": df["caption_count"].max() if "caption_count" in df.columns else 0,
- },
  }

  async def count_captions(self) -> int:
- """Count total outputs across all dynamic fields."""
- total = 0
-
- if self.captions_path.exists():
- # Get actual columns in the file
- table_metadata = pq.read_metadata(self.captions_path)
- existing_columns = set(table_metadata.schema.names)
-
- # Only read output fields that actually exist in the file
- columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
-
- if columns_to_read:
- table = pq.read_table(self.captions_path, columns=columns_to_read)
- df = table.to_pandas()
-
- for field_name in columns_to_read:
- for value in df[field_name]:
- if pd.notna(value) and isinstance(value, list):
- total += len(value)
+ """Count total outputs across all dynamic fields from cached values."""
+ # Use cached disk count
+ total = self.stats["disk_outputs"]

  # Add buffer counts
  for row in self.caption_buffer:
@@ -899,12 +1037,8 @@ class StorageManager:
  return total

  async def count_caption_rows(self) -> int:
- """Count total rows (unique images with captions)."""
- if not self.captions_path.exists():
- return 0
-
- table = pq.read_table(self.captions_path)
- return len(table)
+ """Count total rows from cached values."""
+ return self.stats["disk_rows"] + len(self.caption_buffer)

  async def get_contributor(self, contributor_id: str) -> Optional[Contributor]:
  """Retrieve a contributor by ID."""
@@ -954,42 +1088,19 @@ class StorageManager:
  return contributors

  async def get_output_field_stats(self) -> Dict[str, Any]:
- """Get statistics about output fields in stored captions."""
- if not self.captions_path.exists():
- return {"total_fields": 0, "field_counts": {}}
+ """Get statistics about output fields from cached values."""
+ # Combine disk and buffer stats
+ field_counts = self.stats["field_counts"].copy()

- if not self.known_output_fields:
- return {"total_fields": 0, "field_counts": {}}
-
- # Get actual columns in the file
- table_metadata = pq.read_metadata(self.captions_path)
- existing_columns = set(table_metadata.schema.names)
-
- # Only read output fields that actually exist in the file
- columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
-
- if not columns_to_read:
- return {"total_fields": 0, "field_counts": {}}
-
- table = pq.read_table(self.captions_path, columns=columns_to_read)
- df = table.to_pandas()
-
- if len(df) == 0:
- return {"total_fields": 0, "field_counts": {}}
-
- # Count outputs by field
- field_counts = {}
- total_outputs = 0
-
- for field_name in columns_to_read:
- field_count = 0
- for value in df[field_name]:
- if pd.notna(value) and isinstance(value, list):
- field_count += len(value)
+ # Add buffer counts
+ for row in self.caption_buffer:
+ for field_name in self.known_output_fields:
+ if field_name in row and isinstance(row[field_name], list):
+ if field_name not in field_counts:
+ field_counts[field_name] = 0
+ field_counts[field_name] += len(row[field_name])

- if field_count > 0:
- field_counts[field_name] = field_count
- total_outputs += field_count
+ total_outputs = sum(field_counts.values())

  return {
  "total_fields": len(field_counts),
@@ -1003,27 +1114,24 @@ class StorageManager:
  await self.checkpoint()

  # Log final rate statistics
- if self.total_captions_written > 0:
+ if self.stats["total_captions_written"] > 0:
  rates = self._calculate_rates()
  logger.info(
- f"Storage closed. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped} | "
+ f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
+ f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+ f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
  f"Final rates - Overall: {rates['overall']:.1f} rows/s, "
  f"Last hour: {rates['60min']:.1f} rows/s"
  )
  else:
  logger.info(
- f"Storage closed. Total rows: {self.total_captions_written}, "
- f"Total caption entries: {self.total_caption_entries_written}, "
- f"Duplicates skipped: {self.duplicates_skipped}"
+ f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
+ f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+ f"Duplicates skipped: {self.stats['duplicates_skipped']}"
  )

  async def get_storage_stats(self) -> Dict[str, Any]:
- """Get all storage-related statistics."""
- # Count outputs on disk
- disk_outputs = await self.count_captions()
-
+ """Get all storage-related statistics from cached values."""

  # Count outputs in buffer
  buffer_outputs = 0
@@ -1033,22 +1141,21 @@ class StorageManager:

  # Get field-specific stats
  field_stats = await self.get_caption_stats()
- total_rows_including_buffer = await self.count_caption_rows() + len(self.caption_buffer)
+ total_rows_including_buffer = self.stats["disk_rows"] + len(self.caption_buffer)

  # Calculate rates
  rates = self._calculate_rates()

  return {
- "total_captions": disk_outputs + buffer_outputs,
+ "total_captions": self.stats["disk_outputs"] + buffer_outputs,
  "total_rows": total_rows_including_buffer,
  "buffer_size": len(self.caption_buffer),
- "total_written": self.total_captions_written,
- "total_entries_written": self.total_caption_entries_written,
- "duplicates_skipped": self.duplicates_skipped,
- "total_flushes": self.total_flushes,
+ "total_written": self.stats["total_captions_written"],
+ "total_entries_written": self.stats["total_caption_entries_written"],
+ "duplicates_skipped": self.stats["duplicates_skipped"],
+ "total_flushes": self.stats["total_flushes"],
  "output_fields": sorted(list(self.known_output_fields)),
- "field_breakdown": field_stats.get("field_stats", None),
- "job_buffer_size": len(self.job_buffer),
+ "field_breakdown": field_stats.get("field_stats", {}),
  "contributor_buffer_size": len(self.contributor_buffer),
  "rates": {
  "instant": f"{rates['instant']:.1f} rows/s",