caption-flow 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caption_flow/storage.py CHANGED
@@ -1,4 +1,4 @@
-"""Arrow/Parquet storage management with list column support for captions."""
+"""Arrow/Parquet storage management with dynamic column support for outputs."""
 
 import asyncio
 import json
@@ -11,6 +11,7 @@ import pyarrow as pa
 import pyarrow.parquet as pq
 from pyarrow import fs
 import pandas as pd
+from collections import defaultdict
 
 from .models import Job, Caption, Contributor, JobStatus
 
@@ -18,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 
 class StorageManager:
-    """Manages Arrow/Parquet storage for captions and jobs with list column support."""
+    """Manages Arrow/Parquet storage with dynamic columns for output fields."""
 
     def __init__(
         self,
@@ -50,32 +51,37 @@ class StorageManager:
         self.existing_caption_job_ids: Set[str] = set()
         self.existing_job_ids: Set[str] = set()
 
+        # Track known output fields for schema evolution
+        self.known_output_fields: Set[str] = set()
+
         # Statistics
         self.total_captions_written = 0
         self.total_caption_entries_written = 0  # Total individual captions
         self.total_flushes = 0
         self.duplicates_skipped = 0
 
-        # Schemas - Updated caption schema to support list of captions
-        self.caption_schema = pa.schema(
-            [
-                ("job_id", pa.string()),
-                ("dataset", pa.string()),
-                ("shard", pa.string()),
-                ("chunk_id", pa.string()),
-                ("item_key", pa.string()),
-                ("captions", pa.list_(pa.string())),  # Changed from single caption to list
-                ("caption_count", pa.int32()),  # Number of captions for this item
-                ("contributor_id", pa.string()),
-                ("timestamp", pa.timestamp("us")),
-                ("quality_scores", pa.list_(pa.float32())),  # Optional quality scores per caption
-                ("image_width", pa.int32()),
-                ("image_height", pa.int32()),
-                ("image_format", pa.string()),
-                ("file_size", pa.int64()),
-                ("processing_time_ms", pa.float32()),
-            ]
-        )
+        # Base caption schema without dynamic output fields
+        self.base_caption_fields = [
+            ("job_id", pa.string()),
+            ("dataset", pa.string()),
+            ("shard", pa.string()),
+            ("chunk_id", pa.string()),
+            ("item_key", pa.string()),
+            ("item_index", pa.int32()),
+            ("caption_count", pa.int32()),
+            ("contributor_id", pa.string()),
+            ("timestamp", pa.timestamp("us")),
+            ("quality_scores", pa.list_(pa.float32())),
+            ("image_width", pa.int32()),
+            ("image_height", pa.int32()),
+            ("image_format", pa.string()),
+            ("file_size", pa.int64()),
+            ("processing_time_ms", pa.float32()),
+            ("metadata", pa.string()),
+        ]
+
+        # Current caption schema (will be updated dynamically)
+        self.caption_schema = None
 
         self.job_schema = pa.schema(
             [
@@ -99,134 +105,159 @@ class StorageManager:
             ]
         )
 
+    def _get_existing_output_columns(self) -> Set[str]:
+        """Get output field columns that actually exist in the parquet file."""
+        if not self.captions_path.exists():
+            return set()
+
+        table_metadata = pq.read_metadata(self.captions_path)
+        existing_columns = set(table_metadata.schema.names)
+        base_field_names = {field[0] for field in self.base_caption_fields}
+
+        return existing_columns - base_field_names
+
+    def _build_caption_schema(self, output_fields: Set[str]) -> pa.Schema:
+        """Build caption schema with dynamic output fields."""
+        fields = self.base_caption_fields.copy()
+
+        # Add dynamic output fields (all as list of strings for now)
+        for field_name in sorted(output_fields):  # Sort for consistent ordering
+            fields.append((field_name, pa.list_(pa.string())))
+
+        return pa.schema(fields)
+
     async def initialize(self):
         """Initialize storage files if they don't exist."""
-        # Create empty parquet files if needed
         if not self.captions_path.exists():
-            # Create empty table with schema using from_pydict
-            empty_dict = {
-                "job_id": [],
-                "dataset": [],
-                "shard": [],
-                "chunk_id": [],
-                "item_key": [],
-                "captions": [],
-                "caption_count": [],
-                "contributor_id": [],
-                "timestamp": [],
-                "quality_scores": [],
-                "image_width": [],
-                "image_height": [],
-                "image_format": [],
-                "file_size": [],
-                "processing_time_ms": [],
-            }
+            # Create initial schema with just base fields
+            self.caption_schema = self._build_caption_schema(set())
+
+            # Create empty table
+            empty_dict = {field[0]: [] for field in self.base_caption_fields}
             empty_table = pa.Table.from_pydict(empty_dict, schema=self.caption_schema)
             pq.write_table(empty_table, self.captions_path)
             logger.info(f"Created empty caption storage at {self.captions_path}")
         else:
-            # Load existing caption job_ids to prevent duplicates
-            existing_captions = pq.read_table(self.captions_path, columns=["job_id"])
-            self.existing_caption_job_ids = set(existing_captions["job_id"].to_pylist())
-            logger.info(f"Loaded {len(self.existing_caption_job_ids)} existing caption job_ids")
+            # Load existing schema and detect output fields
+            existing_table = pq.read_table(self.captions_path)
+            existing_columns = set(existing_table.column_names)
+
+            # Identify output fields (columns not in base schema)
+            base_field_names = {field[0] for field in self.base_caption_fields}
+            self.known_output_fields = existing_columns - base_field_names
+
+            # Check if we need to migrate from old "outputs" JSON column
+            if "outputs" in existing_columns:
+                logger.info("Migrating from JSON outputs to dynamic columns...")
+                await self._migrate_outputs_to_columns(existing_table)
+            else:
+                # Build current schema from existing columns
+                self.caption_schema = self._build_caption_schema(self.known_output_fields)
 
-        if not self.jobs_path.exists():
-            # Create empty table with schema using from_pydict
-            empty_dict = {
-                "job_id": [],
-                "dataset": [],
-                "shard": [],
-                "item_key": [],
-                "status": [],
-                "assigned_to": [],
-                "created_at": [],
-                "updated_at": [],
-            }
-            empty_table = pa.Table.from_pydict(empty_dict, schema=self.job_schema)
-            pq.write_table(empty_table, self.jobs_path)
-            logger.info(f"Created empty job storage at {self.jobs_path}")
-        else:
-            # Load existing job_ids
-            existing_jobs = pq.read_table(self.jobs_path, columns=["job_id"])
-            self.existing_job_ids = set(existing_jobs["job_id"].to_pylist())
-            logger.info(f"Loaded {len(self.existing_job_ids)} existing job_ids")
+            # Load existing caption job_ids
+            self.existing_caption_job_ids = set(existing_table["job_id"].to_pylist())
+            logger.info(f"Loaded {len(self.existing_caption_job_ids)} existing caption job_ids")
+            logger.info(f"Known output fields: {sorted(self.known_output_fields)}")
 
+        # Initialize other storage files...
         if not self.contributors_path.exists():
-            # Create empty table with schema using from_pydict
             empty_dict = {"contributor_id": [], "name": [], "total_captions": [], "trust_level": []}
             empty_table = pa.Table.from_pydict(empty_dict, schema=self.contributor_schema)
             pq.write_table(empty_table, self.contributors_path)
             logger.info(f"Created empty contributor storage at {self.contributors_path}")
         else:
-            # Load existing contributors
            existing_contributors = pq.read_table(
                 self.contributors_path, columns=["contributor_id"]
             )
             self.existing_contributor_ids = set(existing_contributors["contributor_id"].to_pylist())
             logger.info(f"Loaded {len(self.existing_contributor_ids)} existing contributor IDs")
 
-    async def save_captions(self, caption_data: Dict[str, Any]):
-        """Save captions for an image - single row with list of captions."""
-        job_id = caption_data["job_id"]
+    async def _migrate_outputs_to_columns(self, existing_table: pa.Table):
+        """Migrate from JSON outputs column to dynamic columns."""
+        df = existing_table.to_pandas()
 
-        # Check if we already have captions for this job_id
-        if job_id in self.existing_caption_job_ids:
-            self.duplicates_skipped += 1
-            logger.debug(f"Skipping duplicate captions for job_id: {job_id}")
-            return
+        # Collect all unique output field names
+        output_fields = set()
+        for outputs_json in df.get("outputs", []):
+            if outputs_json:
+                try:
+                    outputs = json.loads(outputs_json)
+                    output_fields.update(outputs.keys())
+                except json.JSONDecodeError:
+                    continue
 
-        # Check if it's already in the buffer
-        for buffered in self.caption_buffer:
-            if buffered["job_id"] == job_id:
-                logger.debug(f"Captions for job_id {job_id} already in buffer")
-                return
+        # Add legacy "captions" field if it exists and isn't already a base field
+        if "captions" in df.columns and "captions" not in {f[0] for f in self.base_caption_fields}:
+            output_fields.add("captions")
 
-        # Ensure captions is a list (not a JSON string)
-        captions = caption_data.get("captions")
-        if isinstance(captions, str):
-            # If it's a JSON string, decode it
-            import json
-
-            try:
-                captions = json.loads(captions)
-                caption_data["captions"] = captions
-                logger.warning(f"Decoded JSON string to list for job_id {job_id}")
-            except json.JSONDecodeError:
-                logger.error(f"Invalid captions format for job_id {job_id}")
-                return
+        logger.info(f"Found output fields to migrate: {sorted(output_fields)}")
 
-        if not isinstance(captions, list):
-            logger.error(f"Captions must be a list for job_id {job_id}, got {type(captions)}")
-            return
+        # Create new columns for each output field
+        for field_name in output_fields:
+            if field_name not in df.columns:
+                df[field_name] = None
 
-        # Add caption count
-        caption_data["caption_count"] = len(captions)
+        # Migrate data from outputs JSON to columns
+        for idx, row in df.iterrows():
+            if pd.notna(row.get("outputs")):
+                try:
+                    outputs = json.loads(row["outputs"])
+                    for field_name, field_values in outputs.items():
+                        df.at[idx, field_name] = field_values
+                except json.JSONDecodeError:
+                    continue
 
-        # Add default values for optional fields if not present
-        if "quality_scores" not in caption_data:
-            caption_data["quality_scores"] = None
+            # Handle legacy captions column if it's becoming a dynamic field
+            if "captions" in output_fields and pd.notna(row.get("captions")):
+                if pd.isna(df.at[idx, "captions"]):
+                    df.at[idx, "captions"] = row["captions"]
 
-        self.caption_buffer.append(caption_data)
-        self.existing_caption_job_ids.add(job_id)
+        # Drop the old outputs column
+        if "outputs" in df.columns:
+            df = df.drop(columns=["outputs"])
 
-        # Log buffer status
-        logger.debug(f"Caption buffer size: {len(self.caption_buffer)}/{self.caption_buffer_size}")
-        logger.debug(f"  Added captions for {job_id}: {len(captions)} captions")
+        # Update known fields and schema
+        self.known_output_fields = output_fields
+        self.caption_schema = self._build_caption_schema(output_fields)
 
-        # Flush if buffer is large enough
-        if len(self.caption_buffer) >= self.caption_buffer_size:
-            await self._flush_captions()
+        # Write migrated table
+        migrated_table = pa.Table.from_pandas(df, schema=self.caption_schema)
+        pq.write_table(migrated_table, self.captions_path)
+        logger.info("Migration complete - outputs now stored in dynamic columns")
 
     async def save_caption(self, caption: Caption):
-        """Save a single caption entry."""
-        # Convert to dict and ensure it's a list of captions
+        """Save a caption entry with dynamic output columns."""
+        # Convert to dict
         caption_dict = asdict(caption)
-        if "captions" in caption_dict and not isinstance(caption_dict["captions"], list):
-            caption_dict["captions"] = [caption_dict["captions"]]
-        elif "caption" in caption_dict and isinstance(caption_dict["caption"], str):
-            # If it's a single caption string, wrap it in a list
-            caption_dict["captions"] = [caption_dict["caption"]]
-            del caption_dict["caption"]
+
+        # Extract item_index from metadata if present
+        if "metadata" in caption_dict and isinstance(caption_dict["metadata"], dict):
+            item_index = caption_dict["metadata"].get("_item_index")
+            if item_index is not None:
+                caption_dict["item_index"] = item_index
+
+        # Extract outputs and handle them separately
+        outputs = caption_dict.pop("outputs", {})
+
+        # Remove old "captions" field if it exists (will be in outputs)
+        caption_dict.pop("captions", None)
+
+        new_fields = set()
+        for field_name, field_values in outputs.items():
+            caption_dict[field_name] = field_values
+            if field_name not in self.known_output_fields:
+                new_fields.add(field_name)
+                self.known_output_fields.add(field_name)  # Add immediately
+
+        if new_fields:
+            logger.info(f"New output fields detected: {sorted(new_fields)}")
+            logger.info(f"Total known output fields: {sorted(self.known_output_fields)}")
+
+        # Serialize metadata to JSON if present
+        if "metadata" in caption_dict:
+            caption_dict["metadata"] = json.dumps(caption_dict.get("metadata", {}))
+        else:
+            caption_dict["metadata"] = "{}"
 
         # Add to buffer
         self.caption_buffer.append(caption_dict)
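
The core of this release is visible in the hunk above: a fixed base_caption_fields list plus output columns that are discovered at runtime, each typed as list<string>. A minimal, self-contained sketch of how such a schema composes in PyArrow; the "tags" output field is invented here for illustration:

import pyarrow as pa

# Abbreviated base columns; the real list appears in the hunk above.
base_fields = [
    ("job_id", pa.string()),
    ("item_key", pa.string()),
    ("caption_count", pa.int32()),
]

def build_schema(output_fields):
    fields = list(base_fields)
    for name in sorted(output_fields):  # sorted, so rebuilds give a stable column order
        fields.append((name, pa.list_(pa.string())))
    return pa.schema(fields)

print(build_schema({"captions", "tags"}))

Typing every dynamic column as list<string> avoids per-field type inference, at the cost of forcing all outputs through strings.
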
@@ -238,56 +269,113 @@ class StorageManager:
         if len(self.caption_buffer) >= self.caption_buffer_size:
             await self._flush_captions()
 
-    async def save_job(self, job: Job):
-        """Save or update a job - buffers until batch size reached."""
-        # For updates, we still add to buffer (will be handled in flush)
-        self.job_buffer.append(
-            {
-                "job_id": job.job_id,
-                "dataset": job.dataset,
-                "shard": job.shard,
-                "item_key": job.item_key,
-                "status": job.status.value,
-                "assigned_to": job.assigned_to,
-                "created_at": job.created_at,
-                "updated_at": datetime.utcnow(),
-            }
-        )
+    async def save_captions(self, caption_data: Dict[str, Any]):
+        """Save captions for an image - compatible with dict input."""
+        job_id = caption_data["job_id"]
 
-        self.existing_job_ids.add(job.job_id)
+        # Check if we already have captions for this job_id
+        if job_id in self.existing_caption_job_ids:
+            self.duplicates_skipped += 1
+            logger.debug(f"Skipping duplicate captions for job_id: {job_id}")
+            return
 
-        if len(self.job_buffer) >= self.job_buffer_size:
-            await self._flush_jobs()
+        # Check if it's already in the buffer
+        for buffered in self.caption_buffer:
+            if buffered["job_id"] == job_id:
+                logger.debug(f"Captions for job_id {job_id} already in buffer")
+                return
 
-    async def save_contributor(self, contributor: Contributor):
-        """Save or update contributor stats - buffers until batch size reached."""
-        self.contributor_buffer.append(asdict(contributor))
+        # Handle outputs if present
+        if "outputs" in caption_data:
+            outputs = caption_data.pop("outputs")
+            # Add each output field directly to caption_data
+            for field_name, field_values in outputs.items():
+                caption_data[field_name] = field_values
+                if field_name not in self.known_output_fields:
+                    self.known_output_fields.add(field_name)
+                    logger.info(f"New output field detected: {field_name}")
+
+        # Handle legacy captions field
+        if "captions" in caption_data and "captions" not in self.known_output_fields:
+            self.known_output_fields.add("captions")
+
+        # Count all outputs
+        caption_count = 0
+        for field_name in self.known_output_fields:
+            if field_name in caption_data and isinstance(caption_data[field_name], list):
+                caption_count += len(caption_data[field_name])
+
+        caption_data["caption_count"] = caption_count
+
+        # Add default values for optional fields
+        if "quality_scores" not in caption_data:
+            caption_data["quality_scores"] = None
 
-        if len(self.contributor_buffer) >= self.contributor_buffer_size:
-            await self._flush_contributors()
+        if "metadata" in caption_data and isinstance(caption_data["metadata"], dict):
+            caption_data["metadata"] = json.dumps(caption_data["metadata"])
+        elif "metadata" not in caption_data:
+            caption_data["metadata"] = "{}"
+
+        self.caption_buffer.append(caption_data)
+        self.existing_caption_job_ids.add(job_id)
+
+        # Flush if buffer is large enough
+        if len(self.caption_buffer) >= self.caption_buffer_size:
+            await self._flush_captions()
 
     async def _flush_captions(self):
-        """Write caption buffer to parquet with deduplication."""
+        """Write caption buffer to parquet with dynamic schema."""
         if not self.caption_buffer:
             return
 
         num_rows = len(self.caption_buffer)
-        num_captions = sum(len(row["captions"]) for row in self.caption_buffer)
-        logger.info(f"Flushing {num_rows} rows with {num_captions} total captions to disk")
 
-        # Ensure all captions are proper lists before creating table
+        # Count total outputs across all fields
+        total_outputs = 0
         for row in self.caption_buffer:
-            if isinstance(row["captions"], str):
-                import json
+            for field_name in self.known_output_fields:
+                if field_name in row and isinstance(row[field_name], list):
+                    total_outputs += len(row[field_name])
 
-                try:
-                    row["captions"] = json.loads(row["captions"])
-                except:
-                    logger.error(f"Failed to decode captions for {row['job_id']}")
-                    row["captions"] = [row["captions"]]  # Wrap string in list as fallback
+        logger.info(f"Flushing {num_rows} rows with {total_outputs} total outputs to disk")
+
+        # Check if we need to evolve the schema
+        current_schema_fields = set(self.caption_schema.names) if self.caption_schema else set()
+        all_fields_needed = set(
+            self.base_caption_fields[i][0] for i in range(len(self.base_caption_fields))
+        )
+        all_fields_needed.update(self.known_output_fields)
 
-        # Create table from buffer with explicit schema
-        table = pa.Table.from_pylist(self.caption_buffer, schema=self.caption_schema)
+        if all_fields_needed != current_schema_fields:
+            # Schema evolution needed
+            logger.info(
+                f"Evolving schema to include new fields: {all_fields_needed - current_schema_fields}"
+            )
+            self.caption_schema = self._build_caption_schema(self.known_output_fields)
+
+            # If file exists, we need to migrate it
+            if self.captions_path.exists():
+                await self._evolve_schema_on_disk()
+
+        # Prepare data with all required columns
+        prepared_buffer = []
+        for row in self.caption_buffer:
+            prepared_row = row.copy()
+
+            # Ensure all base fields are present
+            for field_name, field_type in self.base_caption_fields:
+                if field_name not in prepared_row:
+                    prepared_row[field_name] = None
+
+            # Ensure all output fields are present (even if None)
+            for field_name in self.known_output_fields:
+                if field_name not in prepared_row:
+                    prepared_row[field_name] = None
+
+            prepared_buffer.append(prepared_row)
+
+        # Create table from buffer
+        table = pa.Table.from_pylist(prepared_buffer, schema=self.caption_schema)
 
         if self.captions_path.exists():
             # Read existing table
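
Both save_caption and the dict-based save_captions above reduce to the same move: pop the outputs mapping and promote each entry to a top-level key, one future column per field. A sketch of that flow with a stand-in dataclass, since the real Caption model lives in caption_flow/models.py and its exact fields are not shown in this diff:

from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List

@dataclass
class CaptionLike:  # stand-in for caption_flow.models.Caption; field names are assumptions
    job_id: str
    outputs: Dict[str, List[str]] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

c = CaptionLike(
    job_id="data-0001:42",
    outputs={"captions": ["a dog on a beach"], "tags": ["dog", "beach"]},
    metadata={"_item_index": 42},
)

row = asdict(c)
row.update(row.pop("outputs"))  # each output field becomes its own column
print(sorted(row))  # ['captions', 'job_id', 'metadata', 'tags']
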
@@ -298,7 +386,7 @@ class StorageManager:
 
             # Filter new data to exclude duplicates
             new_rows = []
-            for row in self.caption_buffer:
+            for row in prepared_buffer:
                 if row["job_id"] not in existing_job_ids:
                     new_rows.append(row)
 
@@ -306,10 +394,10 @@ class StorageManager:
                 # Create table from new rows only
                 new_table = pa.Table.from_pylist(new_rows, schema=self.caption_schema)
 
-                # Combine tables using PyArrow concat (preserves list types better)
+                # Combine tables
                 combined = pa.concat_tables([existing, new_table])
 
-                # Write with proper list column preservation
+                # Write with proper preservation
                 pq.write_table(combined, self.captions_path, compression="snappy")
 
                 logger.info(
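
The write path in the two hunks above is: read the existing file, drop buffered rows whose job_id is already on disk, concatenate, rewrite the whole file. The same pattern in isolation, against a temporary file with toy data:

import tempfile
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("job_id", pa.string()), ("captions", pa.list_(pa.string()))])
path = Path(tempfile.mkdtemp()) / "captions.parquet"
pq.write_table(pa.Table.from_pylist([{"job_id": "a", "captions": ["x"]}], schema=schema), path)

buffer = [{"job_id": "a", "captions": ["x"]}, {"job_id": "b", "captions": ["y"]}]
existing = pq.read_table(path)
seen = set(existing["job_id"].to_pylist())
new_rows = [r for r in buffer if r["job_id"] not in seen]  # "a" is dropped as a duplicate

combined = pa.concat_tables([existing, pa.Table.from_pylist(new_rows, schema=schema)])
pq.write_table(combined, path, compression="snappy")
assert combined.num_rows == 2

Every flush rewrites the full parquet file, so flush cost grows with file size; that trade-off is unchanged from 0.1.0.
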
@@ -320,23 +408,99 @@ class StorageManager:
                 logger.info(f"All {num_rows} rows were duplicates, skipping write")
                 actual_new = 0
         else:
-            # Write new file with proper list columns
+            # Write new file
             pq.write_table(table, self.captions_path, compression="snappy")
             actual_new = num_rows
 
         self.total_captions_written += actual_new
-        self.total_caption_entries_written += sum(
-            len(row["captions"]) for row in self.caption_buffer[:actual_new]
-        )
+        self.total_caption_entries_written += total_outputs
         self.total_flushes += 1
         self.caption_buffer.clear()
 
         logger.info(
             f"Successfully wrote captions (rows: {self.total_captions_written}, "
-            f"total captions: {self.total_caption_entries_written}, "
+            f"total outputs: {self.total_caption_entries_written}, "
             f"duplicates skipped: {self.duplicates_skipped})"
         )
 
+    async def _evolve_schema_on_disk(self):
+        """Evolve the schema of the existing parquet file to include new columns."""
+        logger.info("Evolving schema on disk to add new columns...")
+
+        # Read existing data
+        existing_table = pq.read_table(self.captions_path)
+        df = existing_table.to_pandas()
+
+        # Add missing columns with None values
+        for field_name in self.known_output_fields:
+            if field_name not in df.columns:
+                df[field_name] = None
+                logger.info(f"Added new column: {field_name}")
+
+        # Recreate table with new schema
+        evolved_table = pa.Table.from_pandas(df, schema=self.caption_schema)
+        pq.write_table(evolved_table, self.captions_path, compression="snappy")
+        logger.info("Schema evolution complete")
+
+    async def get_captions(self, job_id: str) -> Optional[Dict[str, List[str]]]:
+        """Retrieve all output fields for a specific job_id."""
+        # Check buffer first
+        for buffered in self.caption_buffer:
+            if buffered["job_id"] == job_id:
+                outputs = {}
+                for field_name in self.known_output_fields:
+                    if field_name in buffered and buffered[field_name]:
+                        outputs[field_name] = buffered[field_name]
+                return outputs
+
+        if not self.captions_path.exists():
+            return None
+
+        table = pq.read_table(self.captions_path)
+        df = table.to_pandas()
+
+        row = df[df["job_id"] == job_id]
+        if row.empty:
+            return None
+
+        # Collect all output fields
+        outputs = {}
+        for field_name in self.known_output_fields:
+            if field_name in row.columns:
+                value = row.iloc[0][field_name]
+                if pd.notna(value) and value is not None:
+                    outputs[field_name] = value
+
+        return outputs if outputs else None
+
+    async def save_job(self, job: Job):
+        """Save or update a job - buffers until batch size reached."""
+        # For updates, we still add to buffer (will be handled in flush)
+        self.job_buffer.append(
+            {
+                "job_id": job.job_id,
+                "dataset": job.dataset,
+                "shard": job.shard,
+                "item_key": job.item_key,
+                "status": job.status.value,
+                "assigned_to": job.assigned_to,
+                "created_at": job.created_at,
+                "updated_at": datetime.utcnow(),
+            }
+        )
+
+        self.existing_job_ids.add(job.job_id)
+
+        if len(self.job_buffer) >= self.job_buffer_size:
+            await self._flush_jobs()
+
+    async def save_contributor(self, contributor: Contributor):
+        """Save or update contributor stats - buffers until batch size reached."""
+        self.contributor_buffer.append(asdict(contributor))
+
+        if len(self.contributor_buffer) >= self.contributor_buffer_size:
+            await self._flush_contributors()
 
     async def _flush_jobs(self):
         """Write job buffer to parquet."""
         if not self.job_buffer:
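
_evolve_schema_on_disk amounts to: load everything into pandas, add each missing output column as all-nulls, rewrite under the widened schema. A standalone sketch of that step under the same assumptions, with an invented "tags" column and a temporary path:

import tempfile
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

path = Path(tempfile.mkdtemp()) / "captions.parquet"
old_schema = pa.schema([("job_id", pa.string()), ("captions", pa.list_(pa.string()))])
pq.write_table(pa.Table.from_pylist([{"job_id": "a", "captions": ["x"]}], schema=old_schema), path)

df = pq.read_table(path).to_pandas()
if "tags" not in df.columns:
    df["tags"] = None  # new dynamic column, null for every existing row

new_schema = old_schema.append(pa.field("tags", pa.list_(pa.string())))
pq.write_table(pa.Table.from_pandas(df, schema=new_schema), path, compression="snappy")
print(pq.read_metadata(path).schema.names)  # ['job_id', 'captions', 'tags']
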
@@ -424,36 +588,6 @@ class StorageManager:
 
         return False
 
-    async def get_captions(self, job_id: str) -> Optional[List[str]]:
-        """Retrieve captions for a specific job_id."""
-        # Check buffer first
-        for buffered in self.caption_buffer:
-            if buffered["job_id"] == job_id:
-                return buffered["captions"]
-
-        if not self.captions_path.exists():
-            return None
-
-        table = pq.read_table(self.captions_path)
-        df = table.to_pandas()
-
-        row = df[df["job_id"] == job_id]
-        if row.empty:
-            return None
-
-        captions = row.iloc[0]["captions"]
-
-        # Handle both correct list storage and incorrect JSON string storage
-        if isinstance(captions, str):
-            # This shouldn't happen with correct storage, but handle legacy data
-            try:
-                captions = json.loads(captions)
-                logger.warning(f"Had to decode JSON string for job_id {job_id} - file needs fixing")
-            except json.JSONDecodeError:
-                captions = [captions]  # Wrap single string as list
-
-        return captions
-
     async def get_job(self, job_id: str) -> Optional[Job]:
         """Retrieve a job by ID."""
         # Check buffer first
@@ -516,41 +650,72 @@ class StorageManager:
         return jobs
 
     async def get_caption_stats(self) -> Dict[str, Any]:
-        """Get statistics about stored captions."""
+        """Get statistics about stored captions including field-specific stats."""
         if not self.captions_path.exists():
-            return {
-                "total_rows": 0,
-                "total_captions": 0,
-                "avg_captions_per_image": 0,
-                "min_captions": 0,
-                "max_captions": 0,
-            }
+            return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
 
         table = pq.read_table(self.captions_path)
         df = table.to_pandas()
 
         if len(df) == 0:
-            return {
-                "total_rows": 0,
-                "total_captions": 0,
-                "avg_captions_per_image": 0,
-                "min_captions": 0,
-                "max_captions": 0,
-            }
-
-        caption_counts = df["caption_count"].values
+            return {"total_rows": 0, "total_outputs": 0, "output_fields": [], "field_stats": {}}
+
+        # Get actual columns in the dataframe
+        existing_columns = set(df.columns)
+
+        # Calculate stats per field (only for fields that exist in the file)
+        field_stats = {}
+        total_outputs = 0
+
+        for field_name in self.known_output_fields:
+            if field_name in existing_columns:
+                # Count non-null entries
+                non_null_mask = df[field_name].notna()
+                non_null_count = non_null_mask.sum()
+
+                # Count total items in lists
+                field_total = 0
+                field_lengths = []
+
+                for value in df.loc[non_null_mask, field_name]:
+                    if isinstance(value, list):
+                        length = len(value)
+                        field_total += length
+                        field_lengths.append(length)
+                    elif pd.notna(value):
+                        length = 1
+                        field_total += length
+                        field_lengths.append(length)
+
+                if field_lengths:
+                    field_stats[field_name] = {
+                        "rows_with_data": non_null_count,
+                        "total_items": field_total,
+                        "avg_items_per_row": sum(field_lengths) / len(field_lengths),
+                    }
+                    if min(field_lengths) != max(field_lengths):
+                        field_stats[field_name].update(
+                            {
+                                "min_items": min(field_lengths),
+                                "max_items": max(field_lengths),
+                            }
+                        )
+                total_outputs += field_total
 
         return {
             "total_rows": len(df),
-            "total_captions": caption_counts.sum(),
-            "avg_captions_per_image": caption_counts.mean(),
-            "min_captions": caption_counts.min(),
-            "max_captions": caption_counts.max(),
-            "std_captions": caption_counts.std(),
+            "total_outputs": total_outputs,
+            "output_fields": sorted(list(self.known_output_fields)),
+            "field_stats": field_stats,
+            "caption_count_stats": {
+                "mean": df["caption_count"].mean() if "caption_count" in df.columns else 0,
+                "min": df["caption_count"].min() if "caption_count" in df.columns else 0,
+                "max": df["caption_count"].max() if "caption_count" in df.columns else 0,
+            },
         }
 
     async def get_sample_captions(self, n: int = 5) -> List[Dict[str, Any]]:
-        """Get a sample of caption entries for inspection."""
+        """Get a sample of caption entries showing all output fields."""
         if not self.captions_path.exists():
            return []
 
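
What the per-field loop in get_caption_stats computes, reduced to a three-row toy frame (data invented):

import pandas as pd

df = pd.DataFrame({"captions": [["a", "b"], None, ["c"]]})
mask = df["captions"].notna()
lengths = [len(v) for v in df.loc[mask, "captions"]]
print({
    "rows_with_data": int(mask.sum()),
    "total_items": sum(lengths),
    "avg_items_per_row": sum(lengths) / len(lengths),
})
# {'rows_with_data': 2, 'total_items': 3, 'avg_items_per_row': 1.5}
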
@@ -564,26 +729,59 @@ class StorageManager:
         samples = []
 
         for _, row in sample_df.iterrows():
+            # Collect outputs from dynamic columns
+            outputs = {}
+            total_outputs = 0
+
+            for field_name in self.known_output_fields:
+                if field_name in row and pd.notna(row[field_name]):
+                    value = row[field_name]
+                    outputs[field_name] = value
+                    if isinstance(value, list):
+                        total_outputs += len(value)
+
             samples.append(
                 {
                     "job_id": row["job_id"],
                     "item_key": row["item_key"],
-                    "captions": row["captions"],
-                    "caption_count": row["caption_count"],
+                    "outputs": outputs,
+                    "field_count": len(outputs),
+                    "total_outputs": total_outputs,
                     "image_dims": f"{row.get('image_width', 'N/A')}x{row.get('image_height', 'N/A')}",
+                    "has_metadata": bool(row.get("metadata") and row["metadata"] != "{}"),
                 }
             )
 
         return samples
 
     async def count_captions(self) -> int:
-        """Count total caption entries (not rows)."""
-        if not self.captions_path.exists():
-            return 0
+        """Count total outputs across all dynamic fields."""
+        total = 0
 
-        table = pq.read_table(self.captions_path, columns=["caption_count"])
-        df = table.to_pandas()
-        return df["caption_count"].sum()
+        if self.captions_path.exists():
+            # Get actual columns in the file
+            table_metadata = pq.read_metadata(self.captions_path)
+            existing_columns = set(table_metadata.schema.names)
+
+            # Only read output fields that actually exist in the file
+            columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
+
+            if columns_to_read:
+                table = pq.read_table(self.captions_path, columns=columns_to_read)
+                df = table.to_pandas()
+
+                for field_name in columns_to_read:
+                    for value in df[field_name]:
+                        if pd.notna(value) and isinstance(value, list):
+                            total += len(value)
+
+        # Add buffer counts
+        for row in self.caption_buffer:
+            for field_name in self.known_output_fields:
+                if field_name in row and isinstance(row[field_name], list):
+                    total += len(row[field_name])
+
+        return total
 
     async def count_caption_rows(self) -> int:
         """Count total rows (unique images with captions)."""
@@ -640,6 +838,135 @@ class StorageManager:
 
         return contributors
 
+    async def get_output_field_stats(self) -> Dict[str, Any]:
+        """Get statistics about output fields in stored captions."""
+        if not self.captions_path.exists():
+            return {"total_fields": 0, "field_counts": {}}
+
+        if not self.known_output_fields:
+            return {"total_fields": 0, "field_counts": {}}
+
+        # Get actual columns in the file
+        table_metadata = pq.read_metadata(self.captions_path)
+        existing_columns = set(table_metadata.schema.names)
+
+        # Only read output fields that actually exist in the file
+        columns_to_read = [f for f in self.known_output_fields if f in existing_columns]
+
+        if not columns_to_read:
+            return {"total_fields": 0, "field_counts": {}}
+
+        table = pq.read_table(self.captions_path, columns=columns_to_read)
+        df = table.to_pandas()
+
+        if len(df) == 0:
+            return {"total_fields": 0, "field_counts": {}}
+
+        # Count outputs by field
+        field_counts = {}
+        total_outputs = 0
+
+        for field_name in columns_to_read:
+            field_count = 0
+            for value in df[field_name]:
+                if pd.notna(value) and isinstance(value, list):
+                    field_count += len(value)
+
+            if field_count > 0:
+                field_counts[field_name] = field_count
+                total_outputs += field_count
+
+        return {
+            "total_fields": len(field_counts),
+            "field_counts": field_counts,
+            "total_outputs": total_outputs,
+            "fields": sorted(list(field_counts.keys())),
+        }
+
+    async def get_captions_with_field(
+        self, field_name: str, limit: int = 100
+    ) -> List[Dict[str, Any]]:
+        """Get captions that have a specific output field."""
+        if not self.captions_path.exists():
+            return []
+
+        if field_name not in self.known_output_fields:
+            logger.warning(f"Field '{field_name}' not found in known output fields")
+            return []
+
+        # Check if the field actually exists in the file
+        existing_output_columns = self._get_existing_output_columns()
+        if field_name not in existing_output_columns:
+            logger.warning(
+                f"Field '{field_name}' exists in known fields but not in parquet file yet"
+            )
+            return []
+
+        # Only read necessary columns
+        columns_to_read = ["job_id", "item_key", field_name]
+
+        try:
+            table = pq.read_table(self.captions_path, columns=columns_to_read)
+        except Exception as e:
+            logger.error(f"Error reading field '{field_name}': {e}")
+            return []
+
+        df = table.to_pandas()
+
+        # Filter rows where field has data
+        mask = df[field_name].notna()
+        filtered_df = df[mask].head(limit)
+
+        results = []
+        for _, row in filtered_df.iterrows():
+            results.append(
+                {
+                    "job_id": row["job_id"],
+                    "item_key": row["item_key"],
+                    field_name: row[field_name],
+                    "value_count": len(row[field_name]) if isinstance(row[field_name], list) else 1,
+                }
+            )
+
+        return results
+
+    async def export_by_field(self, field_name: str, output_path: Path, format: str = "jsonl"):
+        """Export all captions for a specific field."""
+        if not self.captions_path.exists():
+            logger.warning("No captions to export")
+            return 0
+
+        if field_name not in self.known_output_fields:
+            logger.warning(f"Field '{field_name}' not found in known output fields")
+            return 0
+
+        # Check if the field actually exists in the file
+        existing_output_columns = self._get_existing_output_columns()
+        if field_name not in existing_output_columns:
+            logger.warning(f"Field '{field_name}' not found in parquet file")
+            return 0
+
+        # Read only necessary columns
+        columns_to_read = ["item_key", "dataset", field_name]
+        table = pq.read_table(self.captions_path, columns=columns_to_read)
+        df = table.to_pandas()
+
+        exported = 0
+        with open(output_path, "w") as f:
+            for _, row in df.iterrows():
+                if pd.notna(row[field_name]) and row[field_name]:
+                    if format == "jsonl":
+                        record = {
+                            "item_key": row["item_key"],
+                            "dataset": row["dataset"],
+                            field_name: row[field_name],
+                        }
+                        f.write(json.dumps(record) + "\n")
+                        exported += 1
+
+        logger.info(f"Exported {exported} items with field '{field_name}' to {output_path}")
+        return exported
+
     async def get_pending_jobs(self) -> List[Job]:
         """Get all pending jobs for restoration on startup."""
         if not self.jobs_path.exists():
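
Driving the new export helper from async code would look roughly like this; the StorageManager construction is outside this diff, and the field name and output path are examples only:

import asyncio
from pathlib import Path

async def dump_captions(storage):
    # storage: an initialized StorageManager from caption_flow.storage
    n = await storage.export_by_field("captions", Path("captions.jsonl"), format="jsonl")
    print(f"wrote {n} JSONL records")

# asyncio.run(dump_captions(storage)) once a manager has been initialized
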
@@ -692,3 +1019,33 @@ class StorageManager:
             f"Total caption entries: {self.total_caption_entries_written}, "
             f"Duplicates skipped: {self.duplicates_skipped}"
         )
+
+    async def get_storage_stats(self) -> Dict[str, Any]:
+        """Get all storage-related statistics."""
+        # Count outputs on disk
+        disk_outputs = await self.count_captions()
+
+        # Count outputs in buffer
+        buffer_outputs = 0
+        for row in self.caption_buffer:
+            for field_name in self.known_output_fields:
+                if field_name in row and isinstance(row[field_name], list):
+                    buffer_outputs += len(row[field_name])
+
+        # Get field-specific stats
+        field_stats = await self.get_caption_stats()
+        total_rows_including_buffer = await self.count_caption_rows() + len(self.caption_buffer)
+
+        return {
+            "total_captions": disk_outputs + buffer_outputs,
+            "total_rows": total_rows_including_buffer,
+            "buffer_size": len(self.caption_buffer),
+            "total_written": self.total_captions_written,
+            "total_entries_written": self.total_caption_entries_written,
+            "duplicates_skipped": self.duplicates_skipped,
+            "total_flushes": self.total_flushes,
+            "output_fields": sorted(list(self.known_output_fields)),
+            "field_breakdown": field_stats.get("field_stats", None),
+            "job_buffer_size": len(self.job_buffer),
+            "contributor_buffer_size": len(self.contributor_buffer),
+        }