caption-flow 0.2.4__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
 """Arrow/Parquet storage management with dynamic column support for outputs."""
 
 import asyncio
+import gc
 import json
 import logging
 from dataclasses import asdict
@@ -18,7 +19,7 @@ import numpy as np
 from ..models import Job, Caption, Contributor, StorageContents, JobId
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(logging.DEBUG)
 
 
 class StorageManager:
@@ -28,8 +29,7 @@ class StorageManager:
         self,
         data_dir: Path,
         caption_buffer_size: int = 100,
-        job_buffer_size: int = 100,
-        contributor_buffer_size: int = 10,
+        contributor_buffer_size: int = 50,
     ):
         self.data_dir = Path(data_dir)
         self.data_dir.mkdir(parents=True, exist_ok=True)
@@ -38,6 +38,7 @@ class StorageManager:
         self.captions_path = self.data_dir / "captions.parquet"
         self.jobs_path = self.data_dir / "jobs.parquet"
         self.contributors_path = self.data_dir / "contributors.parquet"
+        self.stats_path = self.data_dir / "storage_stats.json"  # Persist stats here
 
         # In-memory buffers for batching writes
         self.caption_buffer = []
@@ -46,7 +47,6 @@ class StorageManager:
 
         # Buffer size configuration
         self.caption_buffer_size = caption_buffer_size
-        self.job_buffer_size = job_buffer_size
         self.contributor_buffer_size = contributor_buffer_size
 
         # Track existing job_ids to prevent duplicates
@@ -57,14 +57,20 @@ class StorageManager:
         # Track known output fields for schema evolution
         self.known_output_fields: Set[str] = set()
 
-        # Statistics
-        self.total_captions_written = 0
-        self.total_caption_entries_written = 0  # Total individual captions
-        self.total_flushes = 0
-        self.duplicates_skipped = 0
+        # In-memory statistics (loaded once at startup, then tracked incrementally)
+        self.stats = {
+            "disk_rows": 0,  # Rows in parquet file
+            "disk_outputs": 0,  # Total outputs in parquet file
+            "field_counts": {},  # Count of outputs per field on disk
+            "total_captions_written": 0,  # Total rows written during this session
+            "total_caption_entries_written": 0,  # Total individual captions written
+            "total_flushes": 0,
+            "duplicates_skipped": 0,
+            "session_field_counts": {},  # Outputs per field written this session
+        }
 
         # Rate tracking
-        self.row_additions = deque(maxlen=10000)  # Store (timestamp, row_count) tuples
+        self.row_additions = deque(maxlen=100)  # Store (timestamp, row_count) tuples
         self.start_time = time.time()
         self.last_rate_log_time = time.time()
 
@@ -115,6 +121,129 @@ class StorageManager:
             ]
         )
 
+    def _save_stats(self):
+        """Persist current stats to disk."""
+        try:
+            with open(self.stats_path, "w") as f:
+                json.dump(self.stats, f, indent=2)
+        except Exception as e:
+            logger.error(f"Failed to save stats: {e}")
+
+    def _load_stats(self):
+        """Load stats from disk if available."""
+        if self.stats_path.exists():
+            try:
+                with open(self.stats_path, "r") as f:
+                    loaded_stats = json.load(f)
+                # Merge loaded stats with defaults
+                self.stats.update(loaded_stats)
+                logger.info(f"Loaded stats from {self.stats_path}")
+            except Exception as e:
+                logger.error(f"Failed to load stats: {e}")
+
+    def _calculate_initial_stats(self):
+        """Calculate stats from parquet file - only called once at initialization."""
+        if not self.captions_path.exists():
+            return
+
+        logger.info("Calculating initial statistics from parquet file...")
+
+        try:
+            # Get metadata to determine row count
+            table_metadata = pq.read_metadata(self.captions_path)
+            self.stats["disk_rows"] = table_metadata.num_rows
+
+            # Simply use the known_output_fields that were already detected during initialization
+            # This avoids issues with PyArrow schema parsing
+            if self.known_output_fields and table_metadata.num_rows > 0:
+                # Read the entire table since column-specific reading is causing issues
+                table = pq.read_table(self.captions_path)
+                df = table.to_pandas()
+
+                total_outputs = 0
+                field_counts = {}
+
+                # Count outputs in each known output field
+                for field_name in self.known_output_fields:
+                    if field_name in df.columns:
+                        field_count = 0
+                        column_data = df[field_name]
+
+                        # Iterate through values more carefully to avoid numpy array issues
+                        for i in range(len(column_data)):
+                            value = column_data.iloc[i]
+                            # Handle None/NaN values first
+                            if value is None:
+                                continue
+                            # For lists/arrays, check if they're actually None or have content
+                            try:
+                                # If it's a list-like object, count its length
+                                if hasattr(value, "__len__") and not isinstance(value, str):
+                                    # Check if it's a pandas null value by trying to check the first element
+                                    if len(value) > 0:
+                                        field_count += len(value)
+                                    # Empty lists are valid but contribute 0
+                            except TypeError:
+                                # Value is scalar NA/NaN, skip it
+                                continue
+                            except:
+                                # Any other error, skip this value
+                                continue
+
+                        if field_count > 0:
+                            field_counts[field_name] = field_count
+                            total_outputs += field_count
+
+                self.stats["disk_outputs"] = total_outputs
+                self.stats["field_counts"] = field_counts
+
+                # Clean up
+                del df, table
+                gc.collect()
+            else:
+                self.stats["disk_outputs"] = 0
+                self.stats["field_counts"] = {}
+
+            logger.info(
+                f"Initial stats: {self.stats['disk_rows']} rows, {self.stats['disk_outputs']} outputs, fields: {list(self.stats['field_counts'].keys())}"
+            )
+
+        except Exception as e:
+            logger.error(f"Failed to calculate initial stats: {e}", exc_info=True)
+            # Set default values
+            self.stats["disk_rows"] = 0
+            self.stats["disk_outputs"] = 0
+            self.stats["field_counts"] = {}
+
+        # Save the calculated stats
+        self._save_stats()
+
+    def _update_stats_for_new_captions(self, captions_added: List[dict], rows_added: int):
+        """Update stats incrementally as new captions are added."""
+        # Update row counts
+        self.stats["disk_rows"] += rows_added
+        self.stats["total_captions_written"] += rows_added
+
+        # Count outputs in the new captions
+        outputs_added = 0
+        for caption in captions_added:
+            for field_name in self.known_output_fields:
+                if field_name in caption and isinstance(caption[field_name], list):
+                    count = len(caption[field_name])
+                    outputs_added += count
+
+                    # Update field-specific counts
+                    if field_name not in self.stats["field_counts"]:
+                        self.stats["field_counts"][field_name] = 0
+                    self.stats["field_counts"][field_name] += count
+
+                    if field_name not in self.stats["session_field_counts"]:
+                        self.stats["session_field_counts"][field_name] = 0
+                    self.stats["session_field_counts"][field_name] += count
+
+        self.stats["disk_outputs"] += outputs_added
+        self.stats["total_caption_entries_written"] += outputs_added
+
     def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
         """Check if a column is entirely empty, null, or contains only zeros/empty lists."""
         if column_name not in df.columns:
@@ -217,7 +346,7 @@ class StorageManager:
         # Calculate overall rate since start
         total_elapsed = current_time - self.start_time
         if total_elapsed > 0:
-            rates["overall"] = self.total_captions_written / total_elapsed
+            rates["overall"] = self.stats["total_captions_written"] / total_elapsed
         else:
             rates["overall"] = 0.0
 
@@ -269,6 +398,9 @@ class StorageManager:
 
     async def initialize(self):
         """Initialize storage files if they don't exist."""
+        # Load persisted stats if available
+        self._load_stats()
+
         if not self.captions_path.exists():
             # Create initial schema with just base fields
             self.caption_schema = self._build_caption_schema(set())
@@ -278,6 +410,11 @@ class StorageManager:
             empty_table = pa.Table.from_pydict(empty_dict, schema=self.caption_schema)
             pq.write_table(empty_table, self.captions_path)
             logger.info(f"Created empty caption storage at {self.captions_path}")
+
+            # Initialize stats
+            self.stats["disk_rows"] = 0
+            self.stats["disk_outputs"] = 0
+            self.stats["field_counts"] = {}
         else:
             # Load existing schema and detect output fields
             existing_table = pq.read_table(self.captions_path)
@@ -300,6 +437,10 @@ class StorageManager:
             logger.info(f"Loaded {len(self.existing_caption_job_ids)} existing caption job_ids")
             logger.info(f"Known output fields: {sorted(self.known_output_fields)}")
 
+        # Calculate initial stats if not already loaded from file
+        if self.stats["disk_rows"] == 0:
+            self._calculate_initial_stats()
+
         # Initialize other storage files...
         if not self.contributors_path.exists():
             empty_dict = {"contributor_id": [], "name": [], "total_captions": [], "trust_level": []}
@@ -371,6 +512,9 @@ class StorageManager:
         pq.write_table(migrated_table, self.captions_path)
         logger.info("Migration complete - outputs now stored in dynamic columns")
 
+        # Recalculate stats after migration
+        self._calculate_initial_stats()
+
     async def save_caption(self, caption: Caption):
         """Save a caption entry, grouping outputs by job_id/item_key (not separating captions)."""
         caption_dict = asdict(caption)
@@ -391,49 +535,32 @@ class StorageManager:
         _job_id = caption_dict.get("job_id")
         job_id = JobId.from_dict(_job_id).get_sample_str()
         group_key = job_id
-        logger.debug(
-            f"save_caption: group_key={group_key}, outputs={list(outputs.keys())}, caption_count={caption_dict.get('caption_count')}, item_index={caption_dict.get('item_index')}"
-        )
+
+        # Check for duplicate - if this job_id already exists on disk, skip it
+        if group_key in self.existing_caption_job_ids:
+            self.stats["duplicates_skipped"] += 1
+            logger.debug(f"Skipping duplicate job_id: {group_key}")
+            return
 
         # Try to find existing buffered row for this group
         found_row = False
         for idx, row in enumerate(self.caption_buffer):
             check_key = row.get("job_id")
-            logger.debug(f"Checking buffer row {idx}: check_key={check_key}, group_key={group_key}")
             if check_key == group_key:
                 found_row = True
-                logger.debug(f"Found existing buffer row for group_key={group_key} at index {idx}")
                 # Merge outputs into existing row
                 for field_name, field_values in outputs.items():
                     if field_name not in self.known_output_fields:
                         self.known_output_fields.add(field_name)
-                        logger.info(f"New output field detected: {field_name}")
                     if field_name in row and isinstance(row[field_name], list):
-                        logger.debug(
-                            f"Merging output field '{field_name}' into existing row: before={row[field_name]}, adding={field_values}"
-                        )
                         row[field_name].extend(field_values)
-                        logger.debug(f"After merge: {row[field_name]}")
                     else:
-                        logger.debug(
-                            f"Setting new output field '{field_name}' in existing row: {field_values}"
-                        )
                         row[field_name] = list(field_values)
                 # Optionally update other fields (e.g., caption_count)
                 if "caption_count" in caption_dict:
                     old_count = row.get("caption_count", 0)
                     row["caption_count"] = old_count + caption_dict["caption_count"]
-                    logger.debug(
-                        f"Updated caption_count for group_key={group_key}: {old_count} + {caption_dict['caption_count']} = {row['caption_count']}"
-                    )
                 return  # Already merged, no need to add new row
-            else:
-                logger.debug(f"Caption row not found for group key: {group_key} vs {check_key}")
-
-        if not found_row:
-            logger.debug(
-                f"No existing buffer row found for group_key={group_key}, creating new row."
-            )
 
         # If not found, create new row
         for field_name, field_values in outputs.items():
@@ -441,7 +568,6 @@ class StorageManager:
                 self.known_output_fields.add(field_name)
                 logger.info(f"New output field detected: {field_name}")
             caption_dict[field_name] = list(field_values)
-            logger.debug(f"Adding output field '{field_name}' to new row: {field_values}")
 
         # Serialize metadata to JSON if present
         if "metadata" in caption_dict:
@@ -453,9 +579,6 @@ class StorageManager:
         caption_dict["job_id"] = job_id
 
         self.caption_buffer.append(caption_dict)
-        logger.debug(
-            f"Appended new caption row for group_key={group_key}. Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}"
-        )
 
         if len(self.caption_buffer) >= self.caption_buffer_size:
             logger.debug("Caption buffer full, flushing captions.")
@@ -466,105 +589,109 @@ class StorageManager:
         if not self.caption_buffer:
             return
 
-        num_rows = len(self.caption_buffer)
+        try:
+            num_rows = len(self.caption_buffer)
 
-        # Count total outputs across all fields
-        total_outputs = 0
-        for row in self.caption_buffer:
-            for field_name in self.known_output_fields:
-                if field_name in row and isinstance(row[field_name], list):
-                    total_outputs += len(row[field_name])
+            # Count total outputs across all fields
+            total_outputs = 0
+            for row in self.caption_buffer:
+                for field_name in self.known_output_fields:
+                    if field_name in row and isinstance(row[field_name], list):
+                        total_outputs += len(row[field_name])
 
-        logger.debug(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
+            logger.debug(
+                f"Flushing {num_rows} rows with {total_outputs} total outputs to disk. preparing data..."
+            )
 
-        # Prepare data with all required columns
-        prepared_buffer = []
-        for row in self.caption_buffer:
-            prepared_row = row.copy()
+            # Keep a copy of captions for stats update
+            captions_to_write = list(self.caption_buffer)
 
-            # Ensure all base fields are present
-            for field_name, field_type in self.base_caption_fields:
-                if field_name not in prepared_row:
-                    prepared_row[field_name] = None
+            # Prepare data with all required columns
+            prepared_buffer = []
+            new_job_ids = []
+            for row in self.caption_buffer:
+                prepared_row = row.copy()
 
-            # Ensure all output fields are present (even if None)
-            for field_name in self.known_output_fields:
-                if field_name not in prepared_row:
-                    prepared_row[field_name] = None
-
-            prepared_buffer.append(prepared_row)
-
-        # Build schema with all known fields (base + output)
-        schema = self._build_caption_schema(self.known_output_fields)
-        table = pa.Table.from_pylist(prepared_buffer, schema=schema)
-
-        if self.captions_path.exists():
-            # Read existing table
-            existing = pq.read_table(self.captions_path)
-
-            # Get existing job_ids for deduplication
-            existing_job_ids = set(existing.column("job_id").to_pylist())
-
-            # Filter new data to exclude duplicates
-            new_rows = []
-            duplicate_rows = []
-            for row in prepared_buffer:
-                if row["job_id"] not in existing_job_ids:
-                    new_rows.append(row)
-                elif row not in duplicate_rows:
-                    duplicate_rows.append(
-                        {
-                            "input": row,
-                            "existing_job": existing.to_pandas()[
-                                existing.to_pandas()["job_id"] == row["job_id"]
-                            ].to_dict(orient="records"),
-                        }
-                    )
+                # Track job_ids for deduplication
+                job_id = prepared_row.get("job_id")
+                if job_id:
+                    new_job_ids.append(job_id)
+
+                # Ensure all base fields are present
+                for field_name, field_type in self.base_caption_fields:
+                    if field_name not in prepared_row:
+                        prepared_row[field_name] = None
+
+                # Ensure all output fields are present (even if None)
+                for field_name in self.known_output_fields:
+                    if field_name not in prepared_row:
+                        prepared_row[field_name] = None
+
+                prepared_buffer.append(prepared_row)
 
-            if duplicate_rows:
-                logger.info(f"Example duplicate row: {duplicate_rows[0]}")
+            # Build schema with all known fields (base + output)
+            logger.debug("building schema...")
+            schema = self._build_caption_schema(self.known_output_fields)
+            logger.debug("schema built, creating table...")
+            table = pa.Table.from_pylist(prepared_buffer, schema=schema)
 
-            if new_rows:
-                # Create table from new rows only
-                new_table = pa.Table.from_pylist(new_rows, schema=schema)
+            if self.captions_path.exists():
+                # Read existing table - this is necessary for parquet format
+                logger.debug("Reading existing captions file...")
+                existing = pq.read_table(self.captions_path)
 
+                logger.debug("writing new rows...")
                 # Concatenate with promote_options="default" to handle schema differences automatically
-                combined = pa.concat_tables([existing, new_table], promote_options="default")
+                logger.debug("concat tables...")
+                combined = pa.concat_tables([existing, table], promote_options="default")
 
                 # Write combined table
                 pq.write_table(combined, self.captions_path, compression="snappy")
 
-                self.duplicates_skipped = num_rows - len(new_rows)
-                actual_new = len(new_rows)
+                actual_new = len(table)
             else:
-                logger.info(f"All {num_rows} rows were duplicates, exiting")
-                raise SystemError("No duplicates can be submitted")
-        else:
-            # Write new file with all fields
-            pq.write_table(table, self.captions_path, compression="snappy")
-            actual_new = num_rows
+                # Write new file with all fields
+                pq.write_table(table, self.captions_path, compression="snappy")
+                actual_new = num_rows
+                logger.debug("write complete.")
 
-        # Update statistics
-        self.total_captions_written += actual_new
-        self.total_caption_entries_written += total_outputs
-        self.total_flushes += 1
-        self.caption_buffer.clear()
+            # Clean up
+            del prepared_buffer, table
 
-        # Track row additions for rate calculation
-        if actual_new > 0:
-            current_time = time.time()
-            self.row_additions.append((current_time, actual_new))
+            # Update the in-memory job_id set for efficient deduplication
+            self.existing_caption_job_ids.update(new_job_ids)
 
-        # Log rates
-        self._log_rates(actual_new)
+            # Update statistics incrementally
+            self._update_stats_for_new_captions(captions_to_write, actual_new)
+            self.stats["total_flushes"] += 1
 
-        logger.info(
-            f"Successfully wrote captions (new rows: {actual_new}, "
-            f"total rows written: {self.total_captions_written}, "
-            f"total captions written: {self.total_caption_entries_written}, "
-            f"duplicates skipped: {self.duplicates_skipped}, "
-            f"output fields: {sorted(list(self.known_output_fields))})"
-        )
+            # Clear buffer
+            self.caption_buffer.clear()
+
+            # Track row additions for rate calculation
+            if actual_new > 0:
+                current_time = time.time()
+                self.row_additions.append((current_time, actual_new))
+
+            # Log rates
+            self._log_rates(actual_new)
+
+            logger.info(
+                f"Successfully wrote captions (new rows: {actual_new}, "
+                f"total rows written: {self.stats['total_captions_written']}, "
+                f"total captions written: {self.stats['total_caption_entries_written']}, "
+                f"duplicates skipped: {self.stats['duplicates_skipped']}, "
+                f"output fields: {sorted(list(self.known_output_fields))})"
+            )
+
+            # Save stats periodically
+            if self.stats["total_flushes"] % 10 == 0:
+                self._save_stats()
+
+        finally:
+            self.caption_buffer.clear()
+            # Force garbage collection
+            gc.collect()
 
     async def optimize_storage(self):
         """Optimize storage by dropping empty columns. Run this periodically or on-demand."""
@@ -625,11 +752,15 @@ class StorageManager:
         # Write optimized table
         pq.write_table(optimized_table, self.captions_path, compression="snappy")
 
-        # Update known output fields
+        # Update known output fields and recalculate stats
         self.known_output_fields = output_fields
 
-        # Clean up backup (optional - keep it for safety)
-        # backup_path.unlink()
+        # Remove dropped fields from stats
+        dropped_fields = set(self.stats["field_counts"].keys()) - output_fields
+        for field in dropped_fields:
+            del self.stats["field_counts"][field]
+            if field in self.stats["session_field_counts"]:
+                del self.stats["session_field_counts"][field]
 
         logger.info(
             f"Storage optimization complete: {original_columns} -> {len(non_empty_columns)} columns. "
@@ -650,37 +781,8 @@ class StorageManager:
             f"({reduction_pct:.1f}% reduction)"
         )
 
-    async def _evolve_schema_on_disk(self):
-        """Evolve the schema of the existing parquet file to include new columns, removing empty ones."""
-        logger.info("Evolving schema on disk to add new columns...")
-
-        # Read existing data
-        existing_table = pq.read_table(self.captions_path)
-        df = existing_table.to_pandas()
-
-        # Add missing columns with None values
-        for field_name in self.known_output_fields:
-            if field_name not in df.columns:
-                df[field_name] = None
-                logger.info(f"Added new column: {field_name}")
-
-        # Remove empty columns (but preserve base fields)
-        non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
-        df = df[non_empty_columns]
-
-        # Update known output fields
-        base_field_names = {field[0] for field in self.base_caption_fields}
-        self.known_output_fields = set(non_empty_columns) - base_field_names
-
-        # Recreate schema with only non-empty fields
-        self.caption_schema = self._build_caption_schema(self.known_output_fields)
-
-        # Recreate table with new schema
-        evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
-        pq.write_table(evolved_table, self.captions_path, compression="snappy")
-        logger.info(
-            f"Schema evolution complete. Active output fields: {sorted(list(self.known_output_fields))}"
-        )
+        # Save updated stats
+        self._save_stats()
 
     async def save_contributor(self, contributor: Contributor):
         """Save or update contributor stats - buffers until batch size reached."""
@@ -758,38 +860,36 @@ class StorageManager:
         await self._flush_jobs()
         await self._flush_contributors()
 
+        # Save stats on checkpoint
+        self._save_stats()
+
         # Log final rate statistics
-        if self.total_captions_written > 0:
+        if self.stats["total_captions_written"] > 0:
             rates = self._calculate_rates()
             logger.info(
-                f"Checkpoint complete. Total rows: {self.total_captions_written}, "
-                f"Total caption entries: {self.total_caption_entries_written}, "
-                f"Duplicates skipped: {self.duplicates_skipped} | "
+                f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
+                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+                f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
                 f"Overall rate: {rates['overall']:.1f} rows/s"
             )
         else:
            logger.info(
-                f"Checkpoint complete. Total rows: {self.total_captions_written}, "
-                f"Total caption entries: {self.total_caption_entries_written}, "
-                f"Duplicates skipped: {self.duplicates_skipped}"
+                f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
+                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+                f"Duplicates skipped: {self.stats['duplicates_skipped']}"
            )
 
     def get_all_processed_job_ids(self) -> Set[str]:
         """Get all processed job_ids - useful for resumption."""
-        if not self.captions_path.exists():
-            logger.info("No captions file found, returning empty processed job_ids set")
-            return set()
-
-        # Read only the job_id column
-        table = pq.read_table(self.captions_path, columns=["job_id"])
-        job_ids = set(table["job_id"].to_pylist())
+        # Return the in-memory set which is kept up-to-date
+        # Also add any job_ids currently in the buffer
+        all_job_ids = self.existing_caption_job_ids.copy()
 
-        # Add buffered job_ids
         for row in self.caption_buffer:
             if "job_id" in row:
-                job_ids.add(row["job_id"])
+                all_job_ids.add(row["job_id"])
 
-        return job_ids
+        return all_job_ids
 
     async def get_storage_contents(
         self,
@@ -855,14 +955,13 @@ class StorageManager:
         # Prepare metadata
         metadata = {}
         if include_metadata:
-            stats = await self.get_caption_stats()
             metadata.update(
                 {
                     "export_timestamp": pd.Timestamp.now().isoformat(),
-                    "total_available_rows": stats.get("total_rows", 0),
+                    "total_available_rows": self.stats.get("disk_rows", 0),
                     "rows_exported": len(rows),
                     "storage_path": str(self.captions_path),
-                    "field_stats": stats.get("field_stats", {}),
+                    "field_stats": self.stats.get("field_counts", {}),
                 }
             )
 
@@ -888,91 +987,46 @@ class StorageManager:
         return set(chunk_jobs)
 
     async def get_caption_stats(self) -> Dict[str, Any]:
-        """Get statistics about stored captions including field-specific stats."""
-        if not self.captions_path.exists():
-            return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
-
-        table = pq.read_table(self.captions_path)
-        df = table.to_pandas()
-
-        if len(df) == 0:
-            return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
+        """Get statistics about stored captions from cached values."""
+        total_rows = self.stats["disk_rows"] + len(self.caption_buffer)
 
-        # Get actual columns in the dataframe
-        existing_columns = set(df.columns)
+        # Count outputs in buffer
+        buffer_outputs = 0
+        buffer_field_counts = defaultdict(int)
+        for row in self.caption_buffer:
+            for field_name in self.known_output_fields:
+                if field_name in row and isinstance(row[field_name], list):
+                    count = len(row[field_name])
+                    buffer_outputs += count
+                    buffer_field_counts[field_name] += count
 
-        # Calculate stats per field (only for fields that exist in the file)
+        # Merge buffer counts with disk counts
         field_stats = {}
-        total_outputs = 0
-
         for field_name in self.known_output_fields:
-            if field_name in existing_columns:
-                # Count non-null entries
-                non_null_mask = df[field_name].notna()
-                non_null_count = non_null_mask.sum()
-
-                # Count total items in lists
-                field_total = 0
-                field_lengths = []
-
-                for value in df.loc[non_null_mask, field_name]:
-                    # list or array-like
-                    if isinstance(value, list):
-                        length = len(value)
-                        field_total += length
-                        field_lengths.append(length)
-                    elif value.any():
-                        length = 1
-                        field_total += length
-                        field_lengths.append(length)
-
-                if field_lengths:
-                    field_stats[field_name] = {
-                        "rows_with_data": non_null_count,
-                        "total_items": field_total,
-                        "avg_items_per_row": sum(field_lengths) / len(field_lengths),
-                    }
-                    if min(field_lengths) != max(field_lengths):
-                        field_stats[field_name].update(
-                            {
-                                "min_items": min(field_lengths),
-                                "max_items": max(field_lengths),
-                            }
-                        )
-                    total_outputs += field_total
+            disk_count = self.stats["field_counts"].get(field_name, 0)
+            buffer_count = buffer_field_counts.get(field_name, 0)
+            total_count = disk_count + buffer_count
+
+            if total_count > 0:
+                field_stats[field_name] = {
+                    "total_items": total_count,
+                    "disk_items": disk_count,
+                    "buffer_items": buffer_count,
+                }
+
+        total_outputs = self.stats["disk_outputs"] + buffer_outputs
 
         return {
-            "total_rows": len(df),
+            "total_rows": total_rows,
             "total_outputs": total_outputs,
             "output_fields": sorted(list(self.known_output_fields)),
            "field_stats": field_stats,
-            "caption_count_stats": {
-                "mean": df["caption_count"].mean() if "caption_count" in df.columns else 0,
-                "min": df["caption_count"].min() if "caption_count" in df.columns else 0,
-                "max": df["caption_count"].max() if "caption_count" in df.columns else 0,
-            },
        }
 
     async def count_captions(self) -> int:
-        """Count total outputs across all dynamic fields."""
-        total = 0
-
-        if self.captions_path.exists():
-            # Get actual columns in the file
-            table_metadata = pq.read_metadata(self.captions_path)
-            existing_columns = set(table_metadata.schema.names)
-
-            # Only read output fields that actually exist in the file
-            columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
-
-            if columns_to_read:
-                table = pq.read_table(self.captions_path, columns=columns_to_read)
-                df = table.to_pandas()
-
-                for field_name in columns_to_read:
-                    for value in df[field_name]:
-                        if pd.notna(value) and isinstance(value, list):
-                            total += len(value)
+        """Count total outputs across all dynamic fields from cached values."""
+        # Use cached disk count
+        total = self.stats["disk_outputs"]
 
         # Add buffer counts
         for row in self.caption_buffer:
@@ -983,12 +1037,8 @@ class StorageManager:
         return total
 
     async def count_caption_rows(self) -> int:
-        """Count total rows (unique images with captions)."""
-        if not self.captions_path.exists():
-            return 0
-
-        table = pq.read_table(self.captions_path)
-        return len(table)
+        """Count total rows from cached values."""
+        return self.stats["disk_rows"] + len(self.caption_buffer)
 
     async def get_contributor(self, contributor_id: str) -> Optional[Contributor]:
         """Retrieve a contributor by ID."""
@@ -1038,42 +1088,19 @@ class StorageManager:
         return contributors
 
     async def get_output_field_stats(self) -> Dict[str, Any]:
-        """Get statistics about output fields in stored captions."""
-        if not self.captions_path.exists():
-            return {"total_fields": 0, "field_counts": {}}
-
-        if not self.known_output_fields:
-            return {"total_fields": 0, "field_counts": {}}
-
-        # Get actual columns in the file
-        table_metadata = pq.read_metadata(self.captions_path)
-        existing_columns = set(table_metadata.schema.names)
-
-        # Only read output fields that actually exist in the file
-        columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
-
-        if not columns_to_read:
-            return {"total_fields": 0, "field_counts": {}}
-
-        table = pq.read_table(self.captions_path, columns=columns_to_read)
-        df = table.to_pandas()
+        """Get statistics about output fields from cached values."""
+        # Combine disk and buffer stats
+        field_counts = self.stats["field_counts"].copy()
 
-        if len(df) == 0:
-            return {"total_fields": 0, "field_counts": {}}
-
-        # Count outputs by field
-        field_counts = {}
-        total_outputs = 0
-
-        for field_name in columns_to_read:
-            field_count = 0
-            for value in df[field_name]:
-                if pd.notna(value) and isinstance(value, list):
-                    field_count += len(value)
+        # Add buffer counts
+        for row in self.caption_buffer:
+            for field_name in self.known_output_fields:
+                if field_name in row and isinstance(row[field_name], list):
+                    if field_name not in field_counts:
+                        field_counts[field_name] = 0
+                    field_counts[field_name] += len(row[field_name])
 
-            if field_count > 0:
-                field_counts[field_name] = field_count
-                total_outputs += field_count
+        total_outputs = sum(field_counts.values())
 
         return {
             "total_fields": len(field_counts),
@@ -1087,27 +1114,24 @@ class StorageManager:
         await self.checkpoint()
 
         # Log final rate statistics
-        if self.total_captions_written > 0:
+        if self.stats["total_captions_written"] > 0:
             rates = self._calculate_rates()
             logger.info(
-                f"Storage closed. Total rows: {self.total_captions_written}, "
-                f"Total caption entries: {self.total_caption_entries_written}, "
-                f"Duplicates skipped: {self.duplicates_skipped} | "
+                f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
+                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+                f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
                 f"Final rates - Overall: {rates['overall']:.1f} rows/s, "
                 f"Last hour: {rates['60min']:.1f} rows/s"
             )
         else:
             logger.info(
-                f"Storage closed. Total rows: {self.total_captions_written}, "
-                f"Total caption entries: {self.total_caption_entries_written}, "
-                f"Duplicates skipped: {self.duplicates_skipped}"
+                f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
+                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
+                f"Duplicates skipped: {self.stats['duplicates_skipped']}"
             )
 
     async def get_storage_stats(self) -> Dict[str, Any]:
-        """Get all storage-related statistics."""
-        # Count outputs on disk
-        disk_outputs = await self.count_captions()
-
+        """Get all storage-related statistics from cached values."""
         # Count outputs in buffer
         buffer_outputs = 0
         for row in self.caption_buffer:
@@ -1117,22 +1141,21 @@ class StorageManager:
 
         # Get field-specific stats
         field_stats = await self.get_caption_stats()
-        total_rows_including_buffer = await self.count_caption_rows() + len(self.caption_buffer)
+        total_rows_including_buffer = self.stats["disk_rows"] + len(self.caption_buffer)
 
         # Calculate rates
         rates = self._calculate_rates()
 
         return {
-            "total_captions": disk_outputs + buffer_outputs,
+            "total_captions": self.stats["disk_outputs"] + buffer_outputs,
             "total_rows": total_rows_including_buffer,
             "buffer_size": len(self.caption_buffer),
-            "total_written": self.total_captions_written,
-            "total_entries_written": self.total_caption_entries_written,
-            "duplicates_skipped": self.duplicates_skipped,
-            "total_flushes": self.total_flushes,
+            "total_written": self.stats["total_captions_written"],
+            "total_entries_written": self.stats["total_caption_entries_written"],
+            "duplicates_skipped": self.stats["duplicates_skipped"],
+            "total_flushes": self.stats["total_flushes"],
             "output_fields": sorted(list(self.known_output_fields)),
-            "field_breakdown": field_stats.get("field_stats", None),
-            "job_buffer_size": len(self.job_buffer),
+            "field_breakdown": field_stats.get("field_stats", {}),
             "contributor_buffer_size": len(self.contributor_buffer),
             "rates": {
                 "instant": f"{rates['instant']:.1f} rows/s",