caption-flow 0.2.2-py3-none-any.whl → 0.2.3-py3-none-any.whl

This diff shows the changes between publicly available package versions as they were released to a supported registry. It is provided for informational purposes only.
caption_flow/storage.py CHANGED
@@ -4,18 +4,21 @@ import asyncio
4
4
  import json
5
5
  import logging
6
6
  from dataclasses import asdict
7
- from datetime import datetime
7
+ from datetime import datetime, timedelta
8
8
  from pathlib import Path
9
9
  from typing import List, Optional, Set, Dict, Any
10
10
  import pyarrow as pa
11
11
  import pyarrow.parquet as pq
12
12
  from pyarrow import fs
13
13
  import pandas as pd
14
- from collections import defaultdict
14
+ from collections import defaultdict, deque
15
+ import time
16
+ import numpy as np
15
17
 
16
- from .models import Job, Caption, Contributor, JobStatus
18
+ from .models import Job, Caption, Contributor, JobStatus, JobId
17
19
 
18
20
  logger = logging.getLogger(__name__)
21
+ logger.setLevel(logging.INFO)
19
22
 
20
23
 
21
24
  class StorageManager:
@@ -60,6 +63,11 @@ class StorageManager:
60
63
  self.total_flushes = 0
61
64
  self.duplicates_skipped = 0
62
65
 
66
+ # Rate tracking
67
+ self.row_additions = deque(maxlen=10000) # Store (timestamp, row_count) tuples
68
+ self.start_time = time.time()
69
+ self.last_rate_log_time = time.time()
70
+
63
71
  # Base caption schema without dynamic output fields
64
72
  self.base_caption_fields = [
65
73
  ("job_id", pa.string()),
@@ -68,6 +76,8 @@ class StorageManager:
68
76
  ("chunk_id", pa.string()),
69
77
  ("item_key", pa.string()),
70
78
  ("item_index", pa.int32()),
79
+ ("filename", pa.string()),
80
+ ("url", pa.string()),
71
81
  ("caption_count", pa.int32()),
72
82
  ("contributor_id", pa.string()),
73
83
  ("timestamp", pa.timestamp("us")),
@@ -105,6 +115,137 @@ class StorageManager:
105
115
  ]
106
116
  )
107
117
 
118
+ def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
119
+ """Check if a column is entirely empty, null, or contains only zeros/empty lists."""
120
+ if column_name not in df.columns:
121
+ return True
122
+
123
+ col = df[column_name]
124
+
125
+ # Check if all values are null/NaN
126
+ if col.isna().all():
127
+ return True
128
+
129
+ # For numeric columns, check if all non-null values are 0
130
+ if pd.api.types.is_numeric_dtype(col):
131
+ non_null_values = col.dropna()
132
+ if len(non_null_values) > 0 and (non_null_values == 0).all():
133
+ return True
134
+
135
+ # For list columns, check if all are None or empty lists
136
+ if col.dtype == "object":
137
+ non_null_values = col.dropna()
138
+ if len(non_null_values) == 0:
139
+ return True
140
+ # Check if all non-null values are empty lists
141
+ all_empty_lists = True
142
+ for val in non_null_values:
143
+ if isinstance(val, list) and len(val) > 0:
144
+ all_empty_lists = False
145
+ break
146
+ elif not isinstance(val, list):
147
+ all_empty_lists = False
148
+ break
149
+ if all_empty_lists:
150
+ return True
151
+
152
+ return False
153
+
154
+ def _get_non_empty_columns(
155
+ self, df: pd.DataFrame, preserve_base_fields: bool = True
156
+ ) -> List[str]:
157
+ """Get list of columns that contain actual data.
158
+
159
+ Args:
160
+ df: DataFrame to check
161
+ preserve_base_fields: If True, always include base fields even if empty
162
+ """
163
+ base_field_names = {field[0] for field in self.base_caption_fields}
164
+ non_empty_columns = []
165
+
166
+ for col in df.columns:
167
+ # Always keep base fields if preserve_base_fields is True
168
+ if preserve_base_fields and col in base_field_names:
169
+ non_empty_columns.append(col)
170
+ elif not self._is_column_empty(df, col):
171
+ non_empty_columns.append(col)
172
+
173
+ return non_empty_columns
174
+
175
+ def _calculate_rates(self) -> Dict[str, float]:
176
+ """Calculate row addition rates over different time windows."""
177
+ current_time = time.time()
178
+ rates = {}
179
+
180
+ # Define time windows in minutes
181
+ windows = {"1min": 1, "5min": 5, "15min": 15, "60min": 60}
182
+
183
+ # Clean up old entries beyond the largest window
184
+ cutoff_time = current_time - (60 * 60) # 60 minutes
185
+ while self.row_additions and self.row_additions[0][0] < cutoff_time:
186
+ self.row_additions.popleft()
187
+
188
+ # Calculate rates for each window
189
+ for window_name, window_minutes in windows.items():
190
+ window_seconds = window_minutes * 60
191
+ window_start = current_time - window_seconds
192
+
193
+ # Sum rows added within this window
194
+ rows_in_window = sum(
195
+ count for timestamp, count in self.row_additions if timestamp >= window_start
196
+ )
197
+
198
+ # Calculate rate (rows per second)
199
+ # For windows larger than elapsed time, use elapsed time
200
+ elapsed = current_time - self.start_time
201
+ actual_window = min(window_seconds, elapsed)
202
+
203
+ if actual_window > 0:
204
+ rate = rows_in_window / actual_window
205
+ rates[window_name] = rate
206
+ else:
207
+ rates[window_name] = 0.0
208
+
209
+ # Calculate instantaneous rate (last minute)
210
+ instant_window_start = current_time - 60 # Last 60 seconds
211
+ instant_rows = sum(
212
+ count for timestamp, count in self.row_additions if timestamp >= instant_window_start
213
+ )
214
+ instant_window = min(60, current_time - self.start_time)
215
+ rates["instant"] = instant_rows / instant_window if instant_window > 0 else 0.0
216
+
217
+ # Calculate overall rate since start
218
+ total_elapsed = current_time - self.start_time
219
+ if total_elapsed > 0:
220
+ rates["overall"] = self.total_captions_written / total_elapsed
221
+ else:
222
+ rates["overall"] = 0.0
223
+
224
+ return rates
225
+
226
+ def _log_rates(self, rows_added: int):
227
+ """Log rate information if enough time has passed."""
228
+ current_time = time.time()
229
+
230
+ # Log rates every 10 seconds or if it's been more than 30 seconds
231
+ time_since_last_log = current_time - self.last_rate_log_time
232
+ if time_since_last_log < 10 and rows_added < 50:
233
+ return
234
+
235
+ rates = self._calculate_rates()
236
+
237
+ # Format the rate information
238
+ rate_str = (
239
+ f"Rate stats - Instant: {rates['instant']:.1f} rows/s | "
240
+ f"Avg (5m): {rates['5min']:.1f} | "
241
+ f"Avg (15m): {rates['15min']:.1f} | "
242
+ f"Avg (60m): {rates['60min']:.1f} | "
243
+ f"Overall: {rates['overall']:.1f} rows/s"
244
+ )
245
+
246
+ logger.info(rate_str)
247
+ self.last_rate_log_time = current_time
248
+
108
249
  def _get_existing_output_columns(self) -> Set[str]:
109
250
  """Get output field columns that actually exist in the parquet file."""
110
251
  if not self.captions_path.exists():
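
The new rate tracking keeps a bounded deque of (timestamp, row_count) tuples and, for each window, sums the counts that fall inside it, capping the divisor at the elapsed runtime so short-lived runs are not under-reported. A minimal standalone sketch of that sliding-window calculation (RateTracker and its record()/rates() methods are illustrative names, not part of caption_flow):

    # Sliding-window rate calculation, mirroring the logic added above.
    import time
    from collections import deque

    class RateTracker:
        def __init__(self):
            self.additions = deque(maxlen=10000)  # (timestamp, row_count) tuples
            self.start_time = time.time()

        def record(self, rows: int) -> None:
            self.additions.append((time.time(), rows))

        def rates(self) -> dict:
            now = time.time()
            elapsed = now - self.start_time
            out = {}
            for name, minutes in {"1min": 1, "5min": 5, "15min": 15, "60min": 60}.items():
                window = minutes * 60
                rows = sum(c for t, c in self.additions if t >= now - window)
                actual = min(window, elapsed)  # cap at elapsed runtime
                out[name] = rows / actual if actual > 0 else 0.0
            return out

    tracker = RateTracker()
    tracker.record(500)
    print(tracker.rates())  # e.g. {'1min': 8.3, '5min': 8.3, ...} depending on elapsed time
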
@@ -216,9 +357,14 @@ class StorageManager:
216
357
  if "outputs" in df.columns:
217
358
  df = df.drop(columns=["outputs"])
218
359
 
219
- # Update known fields and schema
220
- self.known_output_fields = output_fields
221
- self.caption_schema = self._build_caption_schema(output_fields)
360
+ # Remove empty columns before saving (but preserve base fields)
361
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
362
+ df = df[non_empty_columns]
363
+
364
+ # Update known fields and schema based on non-empty columns
365
+ base_field_names = {field[0] for field in self.base_caption_fields}
366
+ self.known_output_fields = set(non_empty_columns) - base_field_names
367
+ self.caption_schema = self._build_caption_schema(self.known_output_fields)
222
368
 
223
369
  # Write migrated table
224
370
  migrated_table = pa.Table.from_pandas(df, schema=self.caption_schema)
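
The migration now calls _get_non_empty_columns so columns holding no data (all null, all zero, or all empty lists) are dropped before the file is rewritten, while the base caption fields are always kept. A small hedged illustration of that test with pandas (non_empty_columns below is a stand-in, not the package's helper):

    # Stand-in for the empty-column check: keep base fields, drop columns with no data.
    import pandas as pd

    def non_empty_columns(df: pd.DataFrame, keep: set) -> list:
        cols = []
        for name in df.columns:
            col = df[name]
            if name in keep:
                cols.append(name)  # base fields are always preserved
            elif col.isna().all():
                continue  # entirely null
            elif pd.api.types.is_numeric_dtype(col) and (col.dropna() == 0).all():
                continue  # only zeros
            elif col.dtype == "object" and all(
                isinstance(v, list) and len(v) == 0 for v in col.dropna()
            ):
                continue  # only empty lists
            else:
                cols.append(name)
        return cols

    df = pd.DataFrame({"job_id": ["a", "b"], "unused": [None, None], "captions": [["x"], ["y"]]})
    print(non_empty_columns(df, keep={"job_id"}))  # ['job_id', 'captions']
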
@@ -226,8 +372,7 @@ class StorageManager:
226
372
  logger.info("Migration complete - outputs now stored in dynamic columns")
227
373
 
228
374
  async def save_caption(self, caption: Caption):
229
- """Save a caption entry with dynamic output columns."""
230
- # Convert to dict
375
+ """Save a caption entry, grouping outputs by job_id/item_key (not separating captions)."""
231
376
  caption_dict = asdict(caption)
232
377
 
233
378
  # Extract item_index from metadata if present
@@ -242,16 +387,61 @@ class StorageManager:
242
387
  # Remove old "captions" field if it exists (will be in outputs)
243
388
  caption_dict.pop("captions", None)
244
389
 
245
- new_fields = set()
390
+ # Grouping key: (job_id, item_key)
391
+ _job_id = caption_dict.get("job_id")
392
+ job_id = JobId.from_dict(_job_id).get_sample_str()
393
+ group_key = job_id
394
+ logger.debug(
395
+ f"save_caption: group_key={group_key}, outputs={list(outputs.keys())}, caption_count={caption_dict.get('caption_count')}, item_index={caption_dict.get('item_index')}"
396
+ )
397
+
398
+ # Try to find existing buffered row for this group
399
+ found_row = False
400
+ for idx, row in enumerate(self.caption_buffer):
401
+ check_key = row.get("job_id")
402
+ logger.debug(f"Checking buffer row {idx}: check_key={check_key}, group_key={group_key}")
403
+ if check_key == group_key:
404
+ found_row = True
405
+ logger.debug(f"Found existing buffer row for group_key={group_key} at index {idx}")
406
+ # Merge outputs into existing row
407
+ for field_name, field_values in outputs.items():
408
+ if field_name not in self.known_output_fields:
409
+ self.known_output_fields.add(field_name)
410
+ logger.info(f"New output field detected: {field_name}")
411
+ if field_name in row and isinstance(row[field_name], list):
412
+ logger.debug(
413
+ f"Merging output field '{field_name}' into existing row: before={row[field_name]}, adding={field_values}"
414
+ )
415
+ row[field_name].extend(field_values)
416
+ logger.debug(f"After merge: {row[field_name]}")
417
+ else:
418
+ logger.debug(
419
+ f"Setting new output field '{field_name}' in existing row: {field_values}"
420
+ )
421
+ row[field_name] = list(field_values)
422
+ # Optionally update other fields (e.g., caption_count)
423
+ if "caption_count" in caption_dict:
424
+ old_count = row.get("caption_count", 0)
425
+ row["caption_count"] = old_count + caption_dict["caption_count"]
426
+ logger.debug(
427
+ f"Updated caption_count for group_key={group_key}: {old_count} + {caption_dict['caption_count']} = {row['caption_count']}"
428
+ )
429
+ return # Already merged, no need to add new row
430
+ else:
431
+ logger.debug(f"Caption row not found for group key: {group_key} vs {check_key}")
432
+
433
+ if not found_row:
434
+ logger.debug(
435
+ f"No existing buffer row found for group_key={group_key}, creating new row."
436
+ )
437
+
438
+ # If not found, create new row
246
439
  for field_name, field_values in outputs.items():
247
- caption_dict[field_name] = field_values
248
440
  if field_name not in self.known_output_fields:
249
- new_fields.add(field_name)
250
- self.known_output_fields.add(field_name) # Add immediately
251
-
252
- if new_fields:
253
- logger.info(f"New output fields detected: {sorted(new_fields)}")
254
- logger.info(f"Total known output fields: {sorted(self.known_output_fields)}")
441
+ self.known_output_fields.add(field_name)
442
+ logger.info(f"New output field detected: {field_name}")
443
+ caption_dict[field_name] = list(field_values)
444
+ logger.debug(f"Adding output field '{field_name}' to new row: {field_values}")
255
445
 
256
446
  # Serialize metadata to JSON if present
257
447
  if "metadata" in caption_dict:
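
save_caption now keeps a single buffered row per sample and merges any further outputs for the same sample into that row's list columns, summing caption_count. A self-contained sketch of the merge step using plain dicts (the sample-id string and merge_output() helper are illustrative; the real code derives the key via JobId.from_dict(...).get_sample_str() and works on Caption objects):

    # Per-sample merge: extend list columns on an existing buffered row, or start a new row.
    from typing import Any, Dict, List

    buffer: List[Dict[str, Any]] = []

    def merge_output(job_id: str, outputs: Dict[str, List[str]], caption_count: int) -> None:
        for row in buffer:
            if row.get("job_id") == job_id:
                for field, values in outputs.items():
                    if isinstance(row.get(field), list):
                        row[field].extend(values)  # merge into the existing list column
                    else:
                        row[field] = list(values)
                row["caption_count"] = row.get("caption_count", 0) + caption_count
                return
        # No buffered row for this sample yet: create one.
        buffer.append({"job_id": job_id, "caption_count": caption_count,
                       **{f: list(v) for f, v in outputs.items()}})

    merge_output("shard_0:42", {"captions": ["a cat"]}, 1)
    merge_output("shard_0:42", {"captions": ["a small cat"], "tags": ["animal"]}, 2)
    print(buffer)  # one row: captions=['a cat', 'a small cat'], tags=['animal'], caption_count=3
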
@@ -259,68 +449,16 @@ class StorageManager:
259
449
  else:
260
450
  caption_dict["metadata"] = "{}"
261
451
 
262
- # Add to buffer
263
- self.caption_buffer.append(caption_dict)
264
-
265
- # Log buffer status
266
- logger.debug(f"Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}")
267
-
268
- # Flush if buffer is large enough
269
- if len(self.caption_buffer) >= self.caption_buffer_size:
270
- await self._flush_captions()
271
-
272
- async def save_captions(self, caption_data: Dict[str, Any]):
273
- """Save captions for an image - compatible with dict input."""
274
- job_id = caption_data["job_id"]
275
-
276
- # Check if we already have captions for this job_id
277
- if job_id in self.existing_caption_job_ids:
278
- self.duplicates_skipped += 1
279
- logger.debug(f"Skipping duplicate captions for job_id: {job_id}")
280
- return
281
-
282
- # Check if it's already in the buffer
283
- for buffered in self.caption_buffer:
284
- if buffered["job_id"] == job_id:
285
- logger.debug(f"Captions for job_id {job_id} already in buffer")
286
- return
287
-
288
- # Handle outputs if present
289
- if "outputs" in caption_data:
290
- outputs = caption_data.pop("outputs")
291
- # Add each output field directly to caption_data
292
- for field_name, field_values in outputs.items():
293
- caption_data[field_name] = field_values
294
- if field_name not in self.known_output_fields:
295
- self.known_output_fields.add(field_name)
296
- logger.info(f"New output field detected: {field_name}")
297
-
298
- # Handle legacy captions field
299
- if "captions" in caption_data and "captions" not in self.known_output_fields:
300
- self.known_output_fields.add("captions")
301
-
302
- # Count all outputs
303
- caption_count = 0
304
- for field_name in self.known_output_fields:
305
- if field_name in caption_data and isinstance(caption_data[field_name], list):
306
- caption_count += len(caption_data[field_name])
307
-
308
- caption_data["caption_count"] = caption_count
309
-
310
- # Add default values for optional fields
311
- if "quality_scores" not in caption_data:
312
- caption_data["quality_scores"] = None
452
+ if isinstance(caption_dict.get("job_id"), dict):
453
+ caption_dict["job_id"] = job_id
313
454
 
314
- if "metadata" in caption_data and isinstance(caption_data["metadata"], dict):
315
- caption_data["metadata"] = json.dumps(caption_data["metadata"])
316
- elif "metadata" not in caption_data:
317
- caption_data["metadata"] = "{}"
318
-
319
- self.caption_buffer.append(caption_data)
320
- self.existing_caption_job_ids.add(job_id)
455
+ self.caption_buffer.append(caption_dict)
456
+ logger.debug(
457
+ f"Appended new caption row for group_key={group_key}. Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}"
458
+ )
321
459
 
322
- # Flush if buffer is large enough
323
460
  if len(self.caption_buffer) >= self.caption_buffer_size:
461
+ logger.debug("Caption buffer full, flushing captions.")
324
462
  await self._flush_captions()
325
463
 
326
464
  async def _flush_captions(self):
@@ -337,25 +475,7 @@ class StorageManager:
337
475
  if field_name in row and isinstance(row[field_name], list):
338
476
  total_outputs += len(row[field_name])
339
477
 
340
- logger.info(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
341
-
342
- # Check if we need to evolve the schema
343
- current_schema_fields = set(self.caption_schema.names) if self.caption_schema else set()
344
- all_fields_needed = set(
345
- self.base_caption_fields[i][0] for i in range(len(self.base_caption_fields))
346
- )
347
- all_fields_needed.update(self.known_output_fields)
348
-
349
- if all_fields_needed != current_schema_fields:
350
- # Schema evolution needed
351
- logger.info(
352
- f"Evolving schema to include new fields: {all_fields_needed - current_schema_fields}"
353
- )
354
- self.caption_schema = self._build_caption_schema(self.known_output_fields)
355
-
356
- # If file exists, we need to migrate it
357
- if self.captions_path.exists():
358
- await self._evolve_schema_on_disk()
478
+ logger.debug(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
359
479
 
360
480
  # Prepare data with all required columns
361
481
  prepared_buffer = []
@@ -374,8 +494,9 @@ class StorageManager:
374
494
 
375
495
  prepared_buffer.append(prepared_row)
376
496
 
377
- # Create table from buffer
378
- table = pa.Table.from_pylist(prepared_buffer, schema=self.caption_schema)
497
+ # Build schema with all known fields (base + output)
498
+ schema = self._build_caption_schema(self.known_output_fields)
499
+ table = pa.Table.from_pylist(prepared_buffer, schema=schema)
379
500
 
380
501
  if self.captions_path.exists():
381
502
  # Read existing table
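
Instead of evolving the schema up front, the flush now rebuilds the Arrow schema from the base fields plus every known output field and hands it to pa.Table.from_pylist. A short demo of that call with an explicit schema containing a dynamic list column (field names are examples only):

    # from_pylist with an explicit schema; keys missing from a row become null.
    import pyarrow as pa

    schema = pa.schema([
        ("job_id", pa.string()),
        ("caption_count", pa.int32()),
        ("captions", pa.list_(pa.string())),  # dynamic output field
    ])

    rows = [
        {"job_id": "a", "caption_count": 2, "captions": ["a cat", "a small cat"]},
        {"job_id": "b", "caption_count": 0},  # no captions yet -> null list
    ]
    table = pa.Table.from_pylist(rows, schema=schema)
    print(table.num_rows, table.schema.names)  # 2 ['job_id', 'caption_count', 'captions']
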
@@ -391,45 +512,145 @@ class StorageManager:
391
512
  if row["job_id"] not in existing_job_ids:
392
513
  new_rows.append(row)
393
514
  elif row not in duplicate_rows:
394
- duplicate_rows.append(row)
515
+ duplicate_rows.append(
516
+ {
517
+ "input": row,
518
+ "existing_job": existing.to_pandas()[
519
+ existing.to_pandas()["job_id"] == row["job_id"]
520
+ ].to_dict(orient="records"),
521
+ }
522
+ )
395
523
 
396
524
  if duplicate_rows:
397
525
  logger.info(f"Example duplicate row: {duplicate_rows[0]}")
526
+
398
527
  if new_rows:
399
528
  # Create table from new rows only
400
- new_table = pa.Table.from_pylist(new_rows, schema=self.caption_schema)
529
+ new_table = pa.Table.from_pylist(new_rows, schema=schema)
401
530
 
402
- # Combine tables
403
- combined = pa.concat_tables([existing, new_table])
531
+ # Concatenate with promote_options="default" to handle schema differences automatically
532
+ combined = pa.concat_tables([existing, new_table], promote_options="default")
404
533
 
405
- # Write with proper preservation
534
+ # Write combined table
406
535
  pq.write_table(combined, self.captions_path, compression="snappy")
407
536
 
408
- logger.info(
409
- f"Added {len(new_rows)} new rows (skipped {num_rows - len(new_rows)} duplicates)"
410
- )
537
+ self.duplicates_skipped = num_rows - len(new_rows)
411
538
  actual_new = len(new_rows)
412
539
  else:
413
- logger.info(f"All {num_rows} rows were duplicates, skipping write")
414
- actual_new = 0
540
+ logger.info(f"All {num_rows} rows were duplicates, exiting")
541
+ raise SystemError("No duplicates can be submitted")
415
542
  else:
416
- # Write new file
543
+ # Write new file with all fields
417
544
  pq.write_table(table, self.captions_path, compression="snappy")
418
545
  actual_new = num_rows
419
546
 
547
+ # Update statistics
420
548
  self.total_captions_written += actual_new
421
549
  self.total_caption_entries_written += total_outputs
422
550
  self.total_flushes += 1
423
551
  self.caption_buffer.clear()
424
552
 
553
+ # Track row additions for rate calculation
554
+ if actual_new > 0:
555
+ current_time = time.time()
556
+ self.row_additions.append((current_time, actual_new))
557
+
558
+ # Log rates
559
+ self._log_rates(actual_new)
560
+
425
561
  logger.info(
426
- f"Successfully wrote captions (rows: {self.total_captions_written}, "
427
- f"total outputs: {self.total_caption_entries_written}, "
428
- f"duplicates skipped: {self.duplicates_skipped})"
562
+ f"Successfully wrote captions (new rows: {actual_new}, "
563
+ f"total rows written: {self.total_captions_written}, "
564
+ f"total captions written: {self.total_caption_entries_written}, "
565
+ f"duplicates skipped: {self.duplicates_skipped}, "
566
+ f"output fields: {sorted(list(self.known_output_fields))})"
429
567
  )
430
568
 
569
+ async def optimize_storage(self):
570
+ """Optimize storage by dropping empty columns. Run this periodically or on-demand."""
571
+ if not self.captions_path.exists():
572
+ logger.info("No captions file to optimize")
573
+ return
574
+
575
+ logger.info("Starting storage optimization...")
576
+
577
+ # Read the full table
578
+ table = pq.read_table(self.captions_path)
579
+ df = table.to_pandas()
580
+ original_columns = len(df.columns)
581
+
582
+ # Find non-empty columns (don't preserve empty base fields)
583
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=False)
584
+
585
+ # Always keep at least job_id
586
+ if "job_id" not in non_empty_columns:
587
+ non_empty_columns.append("job_id")
588
+
589
+ if len(non_empty_columns) < original_columns:
590
+ # We have columns to drop
591
+ df_optimized = df[non_empty_columns]
592
+
593
+ # Rebuild schema for non-empty columns only
594
+ base_field_names = {f[0] for f in self.base_caption_fields}
595
+ fields = []
596
+ output_fields = set()
597
+
598
+ # Process columns in a consistent order: base fields first, then output fields
599
+ for col in non_empty_columns:
600
+ if col in base_field_names:
601
+ # Find the base field definition
602
+ for fname, ftype in self.base_caption_fields:
603
+ if fname == col:
604
+ fields.append((fname, ftype))
605
+ break
606
+ else:
607
+ # Output field
608
+ output_fields.add(col)
609
+
610
+ # Add output fields in sorted order
611
+ for field_name in sorted(output_fields):
612
+ fields.append((field_name, pa.list_(pa.string())))
613
+
614
+ # Create optimized schema and table
615
+ optimized_schema = pa.schema(fields)
616
+ optimized_table = pa.Table.from_pandas(df_optimized, schema=optimized_schema)
617
+
618
+ # Backup the original file (optional)
619
+ backup_path = self.captions_path.with_suffix(".parquet.bak")
620
+ import shutil
621
+
622
+ shutil.copy2(self.captions_path, backup_path)
623
+
624
+ # Write optimized table
625
+ pq.write_table(optimized_table, self.captions_path, compression="snappy")
626
+
627
+ # Update known output fields
628
+ self.known_output_fields = output_fields
629
+
630
+ # Clean up backup (optional - keep it for safety)
631
+ # backup_path.unlink()
632
+
633
+ logger.info(
634
+ f"Storage optimization complete: {original_columns} -> {len(non_empty_columns)} columns. "
635
+ f"Removed columns: {sorted(set(df.columns) - set(non_empty_columns))}"
636
+ )
637
+ else:
638
+ logger.info(f"No optimization needed - all {original_columns} columns contain data")
639
+
640
+ # Report file size reduction
641
+ import os
642
+
643
+ if backup_path and backup_path.exists():
644
+ original_size = os.path.getsize(backup_path)
645
+ new_size = os.path.getsize(self.captions_path)
646
+ reduction_pct = (1 - new_size / original_size) * 100
647
+ logger.info(
648
+ f"File size: {original_size/1024/1024:.1f}MB -> {new_size/1024/1024:.1f}MB "
649
+ f"({reduction_pct:.1f}% reduction)"
650
+ )
651
+
431
652
  async def _evolve_schema_on_disk(self):
432
- """Evolve the schema of the existing parquet file to include new columns."""
653
+ """Evolve the schema of the existing parquet file to include new columns, removing empty ones."""
433
654
  logger.info("Evolving schema on disk to add new columns...")
434
655
 
435
656
  # Read existing data
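
Combining the existing file with the new rows now uses pa.concat_tables(..., promote_options="default"), which fills columns missing on one side with nulls instead of requiring identical schemas. A small demo (promote_options needs a reasonably recent pyarrow; older releases spelled this promote=True):

    # Schema promotion when concatenating tables with different column sets.
    import pyarrow as pa

    existing = pa.table({"job_id": ["a"], "captions": [["old caption"]]})
    incoming = pa.table({"job_id": ["b"], "captions": [["new caption"]], "tags": [["animal"]]})

    combined = pa.concat_tables([existing, incoming], promote_options="default")
    print(combined.column_names)       # ['job_id', 'captions', 'tags']
    print(combined.column("tags")[0])  # null for the row that came from `existing`
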
@@ -442,63 +663,24 @@ class StorageManager:
442
663
  df[field_name] = None
443
664
  logger.info(f"Added new column: {field_name}")
444
665
 
445
- # Recreate table with new schema
446
- evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
447
- pq.write_table(evolved_table, self.captions_path, compression="snappy")
448
- logger.info("Schema evolution complete")
666
+ # Remove empty columns (but preserve base fields)
667
+ non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
668
+ df = df[non_empty_columns]
449
669
 
450
- async def get_captions(self, job_id: str) -> Optional[Dict[str, List[str]]]:
451
- """Retrieve all output fields for a specific job_id."""
452
- # Check buffer first
453
- for buffered in self.caption_buffer:
454
- if buffered["job_id"] == job_id:
455
- outputs = {}
456
- for field_name in self.known_output_fields:
457
- if field_name in buffered and buffered[field_name]:
458
- outputs[field_name] = buffered[field_name]
459
- return outputs
460
-
461
- if not self.captions_path.exists():
462
- return None
463
-
464
- table = pq.read_table(self.captions_path)
465
- df = table.to_pandas()
670
+ # Update known output fields
671
+ base_field_names = {field[0] for field in self.base_caption_fields}
672
+ self.known_output_fields = set(non_empty_columns) - base_field_names
466
673
 
467
- row = df[df["job_id"] == job_id]
468
- if row.empty:
469
- return None
674
+ # Recreate schema with only non-empty fields
675
+ self.caption_schema = self._build_caption_schema(self.known_output_fields)
470
676
 
471
- # Collect all output fields
472
- outputs = {}
473
- for field_name in self.known_output_fields:
474
- if field_name in row.columns:
475
- value = row.iloc[0][field_name]
476
- if pd.notna(value) and value is not None:
477
- outputs[field_name] = value
478
-
479
- return outputs if outputs else None
480
-
481
- async def save_job(self, job: Job):
482
- """Save or update a job - buffers until batch size reached."""
483
- # For updates, we still add to buffer (will be handled in flush)
484
- self.job_buffer.append(
485
- {
486
- "job_id": job.job_id,
487
- "dataset": job.dataset,
488
- "shard": job.shard,
489
- "item_key": job.item_key,
490
- "status": job.status.value,
491
- "assigned_to": job.assigned_to,
492
- "created_at": job.created_at,
493
- "updated_at": datetime.utcnow(),
494
- }
677
+ # Recreate table with new schema
678
+ evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
679
+ pq.write_table(evolved_table, self.captions_path, compression="snappy")
680
+ logger.info(
681
+ f"Schema evolution complete. Active output fields: {sorted(list(self.known_output_fields))}"
495
682
  )
496
683
 
497
- self.existing_job_ids.add(job.job_id)
498
-
499
- if len(self.job_buffer) >= self.job_buffer_size:
500
- await self._flush_jobs()
501
-
502
684
  async def save_contributor(self, contributor: Contributor):
503
685
  """Save or update contributor stats - buffers until batch size reached."""
504
686
  self.contributor_buffer.append(asdict(contributor))
@@ -575,84 +757,51 @@ class StorageManager:
575
757
  await self._flush_jobs()
576
758
  await self._flush_contributors()
577
759
 
578
- logger.info(
579
- f"Checkpoint complete. Total rows: {self.total_captions_written}, "
580
- f"Total caption entries: {self.total_caption_entries_written}, "
581
- f"Duplicates skipped: {self.duplicates_skipped}"
582
- )
583
-
584
- async def job_exists(self, job_id: str) -> bool:
585
- """Check if a job already exists in storage or buffer."""
586
- if job_id in self.existing_job_ids:
587
- return True
588
-
589
- # Check buffer
590
- for buffered in self.job_buffer:
591
- if buffered["job_id"] == job_id:
592
- return True
593
-
594
- return False
595
-
596
- async def get_job(self, job_id: str) -> Optional[Job]:
597
- """Retrieve a job by ID."""
598
- # Check buffer first
599
- for buffered in self.job_buffer:
600
- if buffered["job_id"] == job_id:
601
- return Job(
602
- job_id=buffered["job_id"],
603
- dataset=buffered["dataset"],
604
- shard=buffered["shard"],
605
- item_key=buffered["item_key"],
606
- status=JobStatus(buffered["status"]),
607
- assigned_to=buffered["assigned_to"],
608
- created_at=buffered["created_at"],
609
- )
760
+ # Log final rate statistics
761
+ if self.total_captions_written > 0:
762
+ rates = self._calculate_rates()
763
+ logger.info(
764
+ f"Checkpoint complete. Total rows: {self.total_captions_written}, "
765
+ f"Total caption entries: {self.total_caption_entries_written}, "
766
+ f"Duplicates skipped: {self.duplicates_skipped} | "
767
+ f"Overall rate: {rates['overall']:.1f} rows/s"
768
+ )
769
+ else:
770
+ logger.info(
771
+ f"Checkpoint complete. Total rows: {self.total_captions_written}, "
772
+ f"Total caption entries: {self.total_caption_entries_written}, "
773
+ f"Duplicates skipped: {self.duplicates_skipped}"
774
+ )
610
775
 
611
- if not self.jobs_path.exists():
612
- return None
776
+ def get_all_processed_job_ids(self) -> Set[str]:
777
+ """Get all processed job_ids - useful for resumption."""
778
+ if not self.captions_path.exists():
779
+ logger.info("No captions file found, returning empty processed job_ids set")
780
+ return set()
613
781
 
614
- table = pq.read_table(self.jobs_path)
615
- df = table.to_pandas()
782
+ # Read only the job_id column
783
+ table = pq.read_table(self.captions_path, columns=["job_id"])
784
+ job_ids = set(table["job_id"].to_pylist())
616
785
 
617
- row = df[df["job_id"] == job_id]
618
- if row.empty:
619
- return None
786
+ # Add buffered job_ids
787
+ for row in self.caption_buffer:
788
+ if "job_id" in row:
789
+ job_ids.add(row["job_id"])
620
790
 
621
- return Job(
622
- job_id=row.iloc[0]["job_id"],
623
- dataset=row.iloc[0]["dataset"],
624
- shard=row.iloc[0]["shard"],
625
- item_key=row.iloc[0]["item_key"],
626
- status=JobStatus(row.iloc[0]["status"]),
627
- assigned_to=row.iloc[0]["assigned_to"],
628
- created_at=row.iloc[0]["created_at"],
629
- )
791
+ return job_ids
630
792
 
631
- async def get_jobs_by_worker(self, worker_id: str) -> List[Job]:
632
- """Get all jobs assigned to a worker."""
633
- if not self.jobs_path.exists():
634
- return []
793
+ async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
794
+ """Get all processed job_ids for a given chunk."""
795
+ if not self.captions_path.exists():
796
+ return set()
635
797
 
636
- table = pq.read_table(self.jobs_path)
798
+ # Read only job_id and chunk_id columns
799
+ table = pq.read_table(self.captions_path, columns=["job_id", "chunk_id"])
637
800
  df = table.to_pandas()
638
801
 
639
- rows = df[df["assigned_to"] == worker_id]
640
-
641
- jobs = []
642
- for _, row in rows.iterrows():
643
- jobs.append(
644
- Job(
645
- job_id=row["job_id"],
646
- dataset=row["dataset"],
647
- shard=row["shard"],
648
- item_key=row["item_key"],
649
- status=JobStatus(row["status"]),
650
- assigned_to=row["assigned_to"],
651
- created_at=row["created_at"],
652
- )
653
- )
654
-
655
- return jobs
802
+ # Filter by chunk_id and return job_ids
803
+ chunk_jobs = df[df["chunk_id"] == chunk_id]["job_id"].tolist()
804
+ return set(chunk_jobs)
656
805
 
657
806
  async def get_caption_stats(self) -> Dict[str, Any]:
658
807
  """Get statistics about stored captions including field-specific stats."""
@@ -683,11 +832,12 @@ class StorageManager:
683
832
  field_lengths = []
684
833
 
685
834
  for value in df.loc[non_null_mask, field_name]:
835
+ # list or array-like
686
836
  if isinstance(value, list):
687
837
  length = len(value)
688
838
  field_total += length
689
839
  field_lengths.append(length)
690
- elif pd.notna(value):
840
+ elif value.any():
691
841
  length = 1
692
842
  field_total += length
693
843
  field_lengths.append(length)
@@ -719,46 +869,6 @@ class StorageManager:
719
869
  },
720
870
  }
721
871
 
722
- async def get_sample_captions(self, n: int = 5) -> List[Dict[str, Any]]:
723
- """Get a sample of caption entries showing all output fields."""
724
- if not self.captions_path.exists():
725
- return []
726
-
727
- table = pq.read_table(self.captions_path)
728
- df = table.to_pandas()
729
-
730
- if len(df) == 0:
731
- return []
732
-
733
- sample_df = df.sample(min(n, len(df)))
734
- samples = []
735
-
736
- for _, row in sample_df.iterrows():
737
- # Collect outputs from dynamic columns
738
- outputs = {}
739
- total_outputs = 0
740
-
741
- for field_name in self.known_output_fields:
742
- if field_name in row and pd.notna(row[field_name]):
743
- value = row[field_name]
744
- outputs[field_name] = value
745
- if isinstance(value, list):
746
- total_outputs += len(value)
747
-
748
- samples.append(
749
- {
750
- "job_id": row["job_id"],
751
- "item_key": row["item_key"],
752
- "outputs": outputs,
753
- "field_count": len(outputs),
754
- "total_outputs": total_outputs,
755
- "image_dims": f"{row.get('image_width', 'N/A')}x{row.get('image_height', 'N/A')}",
756
- "has_metadata": bool(row.get("metadata") and row["metadata"] != "{}"),
757
- }
758
- )
759
-
760
- return samples
761
-
762
872
  async def count_captions(self) -> int:
763
873
  """Count total outputs across all dynamic fields."""
764
874
  total = 0
@@ -888,142 +998,26 @@ class StorageManager:
888
998
  "fields": sorted(list(field_counts.keys())),
889
999
  }
890
1000
 
891
- async def get_captions_with_field(
892
- self, field_name: str, limit: int = 100
893
- ) -> List[Dict[str, Any]]:
894
- """Get captions that have a specific output field."""
895
- if not self.captions_path.exists():
896
- return []
897
-
898
- if field_name not in self.known_output_fields:
899
- logger.warning(f"Field '{field_name}' not found in known output fields")
900
- return []
901
-
902
- # Check if the field actually exists in the file
903
- existing_output_columns = self._get_existing_output_columns()
904
- if field_name not in existing_output_columns:
905
- logger.warning(
906
- f"Field '{field_name}' exists in known fields but not in parquet file yet"
907
- )
908
- return []
909
-
910
- # Only read necessary columns
911
- columns_to_read = ["job_id", "item_key", field_name]
912
-
913
- try:
914
- table = pq.read_table(self.captions_path, columns=columns_to_read)
915
- except Exception as e:
916
- logger.error(f"Error reading field '{field_name}': {e}")
917
- return []
918
-
919
- df = table.to_pandas()
920
-
921
- # Filter rows where field has data
922
- mask = df[field_name].notna()
923
- filtered_df = df[mask].head(limit)
924
-
925
- results = []
926
- for _, row in filtered_df.iterrows():
927
- results.append(
928
- {
929
- "job_id": row["job_id"],
930
- "item_key": row["item_key"],
931
- field_name: row[field_name],
932
- "value_count": len(row[field_name]) if isinstance(row[field_name], list) else 1,
933
- }
934
- )
935
-
936
- return results
937
-
938
- async def export_by_field(self, field_name: str, output_path: Path, format: str = "jsonl"):
939
- """Export all captions for a specific field."""
940
- if not self.captions_path.exists():
941
- logger.warning("No captions to export")
942
- return 0
943
-
944
- if field_name not in self.known_output_fields:
945
- logger.warning(f"Field '{field_name}' not found in known output fields")
946
- return 0
947
-
948
- # Check if the field actually exists in the file
949
- existing_output_columns = self._get_existing_output_columns()
950
- if field_name not in existing_output_columns:
951
- logger.warning(f"Field '{field_name}' not found in parquet file")
952
- return 0
953
-
954
- # Read only necessary columns
955
- columns_to_read = ["item_key", "dataset", field_name]
956
- table = pq.read_table(self.captions_path, columns=columns_to_read)
957
- df = table.to_pandas()
958
-
959
- exported = 0
960
- with open(output_path, "w") as f:
961
- for _, row in df.iterrows():
962
- if pd.notna(row[field_name]) and row[field_name]:
963
- if format == "jsonl":
964
- record = {
965
- "item_key": row["item_key"],
966
- "dataset": row["dataset"],
967
- field_name: row[field_name],
968
- }
969
- f.write(json.dumps(record) + "\n")
970
- exported += 1
971
-
972
- logger.info(f"Exported {exported} items with field '{field_name}' to {output_path}")
973
- return exported
974
-
975
- async def get_pending_jobs(self) -> List[Job]:
976
- """Get all pending jobs for restoration on startup."""
977
- if not self.jobs_path.exists():
978
- return []
979
-
980
- table = pq.read_table(self.jobs_path)
981
- df = table.to_pandas()
982
-
983
- # Get jobs with PENDING or PROCESSING status
984
- pending_df = df[df["status"].isin([JobStatus.PENDING.value, JobStatus.PROCESSING.value])]
985
-
986
- jobs = []
987
- for _, row in pending_df.iterrows():
988
- jobs.append(
989
- Job(
990
- job_id=row["job_id"],
991
- dataset=row["dataset"],
992
- shard=row["shard"],
993
- item_key=row["item_key"],
994
- status=JobStatus(row["status"]),
995
- assigned_to=row.get("assigned_to"),
996
- created_at=row["created_at"],
997
- )
998
- )
999
-
1000
- return jobs
1001
-
1002
- async def count_jobs(self) -> int:
1003
- """Count total jobs."""
1004
- if not self.jobs_path.exists():
1005
- return 0
1006
-
1007
- table = pq.read_table(self.jobs_path)
1008
- return len(table)
1009
-
1010
- async def count_completed_jobs(self) -> int:
1011
- """Count completed jobs."""
1012
- if not self.jobs_path.exists():
1013
- return 0
1014
-
1015
- table = pq.read_table(self.jobs_path)
1016
- df = table.to_pandas()
1017
- return len(df[df["status"] == JobStatus.COMPLETED.value])
1018
-
1019
1001
  async def close(self):
1020
1002
  """Close storage and flush buffers."""
1021
1003
  await self.checkpoint()
1022
- logger.info(
1023
- f"Storage closed. Total rows: {self.total_captions_written}, "
1024
- f"Total caption entries: {self.total_caption_entries_written}, "
1025
- f"Duplicates skipped: {self.duplicates_skipped}"
1026
- )
1004
+
1005
+ # Log final rate statistics
1006
+ if self.total_captions_written > 0:
1007
+ rates = self._calculate_rates()
1008
+ logger.info(
1009
+ f"Storage closed. Total rows: {self.total_captions_written}, "
1010
+ f"Total caption entries: {self.total_caption_entries_written}, "
1011
+ f"Duplicates skipped: {self.duplicates_skipped} | "
1012
+ f"Final rates - Overall: {rates['overall']:.1f} rows/s, "
1013
+ f"Last hour: {rates['60min']:.1f} rows/s"
1014
+ )
1015
+ else:
1016
+ logger.info(
1017
+ f"Storage closed. Total rows: {self.total_captions_written}, "
1018
+ f"Total caption entries: {self.total_caption_entries_written}, "
1019
+ f"Duplicates skipped: {self.duplicates_skipped}"
1020
+ )
1027
1021
 
1028
1022
  async def get_storage_stats(self) -> Dict[str, Any]:
1029
1023
  """Get all storage-related statistics."""
@@ -1041,6 +1035,9 @@ class StorageManager:
1041
1035
  field_stats = await self.get_caption_stats()
1042
1036
  total_rows_including_buffer = await self.count_caption_rows() + len(self.caption_buffer)
1043
1037
 
1038
+ # Calculate rates
1039
+ rates = self._calculate_rates()
1040
+
1044
1041
  return {
1045
1042
  "total_captions": disk_outputs + buffer_outputs,
1046
1043
  "total_rows": total_rows_including_buffer,
@@ -1053,4 +1050,11 @@ class StorageManager:
1053
1050
  "field_breakdown": field_stats.get("field_stats", None),
1054
1051
  "job_buffer_size": len(self.job_buffer),
1055
1052
  "contributor_buffer_size": len(self.contributor_buffer),
1053
+ "rates": {
1054
+ "instant": f"{rates['instant']:.1f} rows/s",
1055
+ "5min": f"{rates['5min']:.1f} rows/s",
1056
+ "15min": f"{rates['15min']:.1f} rows/s",
1057
+ "60min": f"{rates['60min']:.1f} rows/s",
1058
+ "overall": f"{rates['overall']:.1f} rows/s",
1059
+ },
1056
1060
  }
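
The stats payload now includes a "rates" block of pre-formatted strings alongside the existing counters. A usage sketch, assuming a StorageManager can be constructed from a data directory (the constructor argument shown is a guess, not the package's documented API):

    # Hedged usage sketch for the new "rates" block in get_storage_stats().
    import asyncio
    from caption_flow.storage import StorageManager

    async def main():
        storage = StorageManager(data_dir="./caption_flow_data")  # hypothetical argument
        stats = await storage.get_storage_stats()
        for window, rate in stats["rates"].items():
            print(f"{window:>8}: {rate}")  # e.g. " instant: 12.3 rows/s"

    asyncio.run(main())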