caption-flow: caption_flow-0.3.4-py3-none-any.whl → caption_flow-0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +921 -427
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +463 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +303 -92
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
  28. caption_flow-0.4.1.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,29 +1,29 @@
1
- """Arrow/Parquet storage management with dynamic column support for outputs."""
1
+ """Storage management with Lance backend using a single dataset."""
2
2
 
3
- import asyncio
4
3
  import gc
5
4
  import json
6
5
  import logging
6
+ import os
7
+ import time
8
+ from collections import defaultdict, deque
7
9
  from dataclasses import asdict
8
- from datetime import datetime, timedelta
10
+ from datetime import datetime
9
11
  from pathlib import Path
10
- from typing import List, Optional, Set, Dict, Any
11
- import pyarrow as pa
12
- import pyarrow.parquet as pq
13
- from pyarrow import fs
12
+ from typing import Any, Dict, List, Optional, Set
13
+
14
+ import duckdb
15
+ import lance
14
16
  import pandas as pd
15
- from collections import defaultdict, deque
16
- import time
17
- import numpy as np
17
+ import pyarrow as pa
18
18
 
19
- from ..models import Job, Caption, Contributor, StorageContents, JobId
19
+ from ..models import Caption, Contributor, JobId, StorageContents
20
20
 
21
21
  logger = logging.getLogger(__name__)
22
- logger.setLevel(logging.DEBUG)
22
+ logger.setLevel(os.environ.get("CAPTIONFLOW_LOG_LEVEL", "INFO").upper())
23
23
 
24
24
 
25
25
  class StorageManager:
26
- """Manages Arrow/Parquet storage with dynamic columns for output fields."""
26
+ """Manages Lance storage with a single dataset and dynamic columns."""
27
27
 
28
28
  def __init__(
29
29
  self,
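
A behavioural note on this hunk: the module logger level is no longer hardcoded to DEBUG; it is read from the CAPTIONFLOW_LOG_LEVEL environment variable and falls back to INFO. A minimal sketch of opting back into verbose storage logging, assuming the package is installed; the variable has to be set before the module is first imported, because the level is resolved at import time:

    import logging
    import os

    # Must be set before the import below; manager.py calls
    # logger.setLevel(...) at module import time.
    os.environ["CAPTIONFLOW_LOG_LEVEL"] = "DEBUG"

    from caption_flow.storage import manager

    assert manager.logger.level == logging.DEBUG
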
@@ -31,16 +31,17 @@ class StorageManager:
31
31
  caption_buffer_size: int = 100,
32
32
  contributor_buffer_size: int = 50,
33
33
  ):
34
+ self.duckdb_shard_connections = {}
34
35
  self.data_dir = Path(data_dir)
35
36
  self.data_dir.mkdir(parents=True, exist_ok=True)
36
37
 
37
38
  # File paths
38
- self.captions_path = self.data_dir / "captions.parquet"
39
- self.jobs_path = self.data_dir / "jobs.parquet"
40
- self.contributors_path = self.data_dir / "contributors.parquet"
41
- self.stats_path = self.data_dir / "storage_stats.json" # Persist stats here
39
+ self.captions_path = self.data_dir / "captions.lance"
40
+ self.jobs_path = self.data_dir / "jobs.lance"
41
+ self.contributors_path = self.data_dir / "contributors.lance"
42
+ self.stats_path = self.data_dir / "storage_stats.json"
42
43
 
43
- # In-memory buffers for batching writes
44
+ # In-memory buffers
44
45
  self.caption_buffer = []
45
46
  self.job_buffer = []
46
47
  self.contributor_buffer = []
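
The constructor changes above replace the per-table Parquet files with Lance dataset paths under the same data_dir and add a per-shard DuckDB connection cache. A small sketch of the on-disk layout the new __init__ derives, assuming an illustrative ./caption_data directory:

    from pathlib import Path

    data_dir = Path("./caption_data")  # illustrative location

    # Paths derived in the new __init__; Lance datasets are directories, not single files.
    layout = {
        "captions": data_dir / "captions.lance",
        "jobs": data_dir / "jobs.lance",
        "contributors": data_dir / "contributors.lance",
        "stats": data_dir / "storage_stats.json",
    }

    for name, path in layout.items():
        print(f"{name}: {path}")
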
@@ -49,32 +50,37 @@ class StorageManager:
49
50
  self.caption_buffer_size = caption_buffer_size
50
51
  self.contributor_buffer_size = contributor_buffer_size
51
52
 
52
- # Track existing job_ids to prevent duplicates
53
- self.existing_contributor_ids: Set[str] = set()
53
+ # Track existing IDs
54
54
  self.existing_caption_job_ids: Set[str] = set()
55
+ self.existing_contributor_ids: Set[str] = set()
55
56
  self.existing_job_ids: Set[str] = set()
56
57
 
57
- # Track known output fields for schema evolution
58
+ # Track known output fields
58
59
  self.known_output_fields: Set[str] = set()
59
60
 
60
- # In-memory statistics (loaded once at startup, then tracked incrementally)
61
+ # Lance datasets
62
+ self.captions_dataset: Optional[lance.Dataset] = None
63
+ self.jobs_dataset: Optional[lance.Dataset] = None
64
+ self.contributors_dataset: Optional[lance.Dataset] = None
65
+
66
+ # Statistics
61
67
  self.stats = {
62
- "disk_rows": 0, # Rows in parquet file
63
- "disk_outputs": 0, # Total outputs in parquet file
64
- "field_counts": {}, # Count of outputs per field on disk
65
- "total_captions_written": 0, # Total rows written during this session
66
- "total_caption_entries_written": 0, # Total individual captions written
68
+ "disk_rows": 0,
69
+ "disk_outputs": 0,
70
+ "field_counts": {},
71
+ "total_captions_written": 0,
72
+ "total_caption_entries_written": 0,
67
73
  "total_flushes": 0,
68
74
  "duplicates_skipped": 0,
69
- "session_field_counts": {}, # Outputs per field written this session
75
+ "session_field_counts": {},
70
76
  }
71
77
 
72
78
  # Rate tracking
73
- self.row_additions = deque(maxlen=100) # Store (timestamp, row_count) tuples
79
+ self.row_additions = deque(maxlen=100)
74
80
  self.start_time = time.time()
75
81
  self.last_rate_log_time = time.time()
76
82
 
77
- # Base caption schema without dynamic output fields
83
+ # Base caption schema
78
84
  self.base_caption_fields = [
79
85
  ("job_id", pa.string()),
80
86
  ("dataset", pa.string()),
@@ -121,6 +127,130 @@ class StorageManager:
121
127
  ]
122
128
  )
123
129
 
130
+ def init_duckdb_connection(
131
+ self, output_shard: Optional[str] = None
132
+ ) -> duckdb.DuckDBPyConnection:
133
+ """Initialize or retrieve a DuckDB connection for a given output shard.
134
+ Currently, we just use a single output shard, but this allows for future implementation of multiple.
135
+
136
+ Args:
137
+ ----
138
+ output_shard (Optional[str]): The output shard identifier. If None, uses default shard.
139
+
140
+ Returns:
141
+ -------
142
+ duckdb.DuckDBPyConnection: The DuckDB connection for the specified shard.
143
+
144
+ """
145
+ shard_key = output_shard or "default"
146
+ if shard_key in self.duckdb_shard_connections:
147
+ return self.duckdb_shard_connections[shard_key]
148
+
149
+ conn = duckdb.connect(database=":memory:")
150
+
151
+ # For the default shard, register the captions Lance dataset if it exists
152
+ if shard_key == "default":
153
+ # Force refresh the dataset to handle cases where it was recreated due to schema evolution
154
+ if self.captions_path.exists():
155
+ try:
156
+ # Always reload from disk to ensure we have the latest version
157
+ logger.debug(f"Reloading Lance dataset from {self.captions_path}")
158
+ self.captions_dataset = lance.dataset(str(self.captions_path))
159
+ logger.debug("Successfully loaded Lance dataset, converting to Arrow table")
160
+ # Convert Lance dataset to Arrow table for DuckDB compatibility
161
+ arrow_table = self.captions_dataset.to_table()
162
+ logger.debug(
163
+ f"Successfully converted Lance dataset to Arrow table with {arrow_table.num_rows} rows"
164
+ )
165
+ # Register the Arrow table in DuckDB so it can be queried
166
+ conn.register("captions", arrow_table)
167
+ logger.debug(
168
+ f"Registered Lance dataset {self.captions_path} as 'captions' table in DuckDB"
169
+ )
170
+
171
+ # Verify the table was registered
172
+ tables = conn.execute("SHOW TABLES").fetchall()
173
+ logger.debug(f"Available tables in DuckDB: {tables}")
174
+ except Exception as e:
175
+ logger.warning(f"Failed to register Lance dataset in DuckDB: {e}")
176
+ # Fall back to direct file path queries
177
+
178
+ self.duckdb_shard_connections[shard_key] = conn
179
+
180
+ return conn
181
+
182
+ def _init_lance_dataset(self) -> Optional[lance.LanceDataset]:
183
+ """Initialize or retrieve the captions Lance dataset."""
184
+ if self.captions_dataset:
185
+ logger.debug("Captions dataset already initialized")
186
+ return self.captions_dataset
187
+
188
+ if not self.captions_path.exists():
189
+ logger.debug("Captions dataset does not exist, creating new one")
190
+ # Create initial schema with just base fields
191
+ self.caption_schema = self._build_caption_schema(set())
192
+
193
+ # Create empty dataset on disk with proper schema
194
+ empty_dict = {}
195
+ for field_name, field_type in self.base_caption_fields:
196
+ if field_type == pa.string():
197
+ empty_dict[field_name] = []
198
+ elif field_type == pa.int32():
199
+ empty_dict[field_name] = []
200
+ elif field_type == pa.int64():
201
+ empty_dict[field_name] = []
202
+ elif field_type == pa.float32():
203
+ empty_dict[field_name] = []
204
+ elif field_type == pa.timestamp("us"):
205
+ empty_dict[field_name] = []
206
+ elif field_type == pa.list_(pa.float32()):
207
+ empty_dict[field_name] = []
208
+ else:
209
+ empty_dict[field_name] = []
210
+
211
+ empty_table = pa.Table.from_pydict(empty_dict, schema=self.caption_schema)
212
+ self.captions_dataset = lance.write_dataset(
213
+ empty_table, str(self.captions_path), mode="create"
214
+ )
215
+ logger.info(f"Created empty captions storage at {self.captions_path}")
216
+
217
+ return self.captions_dataset
218
+
219
+ try:
220
+ logger.debug(f"Loading Lance dataset from {self.captions_path}")
221
+ self.captions_dataset = lance.dataset(str(self.captions_path))
222
+ return self.captions_dataset
223
+ except Exception as e:
224
+ logger.error(f"Failed to load Lance dataset from {self.captions_path}: {e}")
225
+ return None
226
+
227
+ def _update_duckdb_connections_after_schema_change(self):
228
+ """Update DuckDB connections after dataset schema has changed."""
229
+ logger.debug(
230
+ f"Updating {len(self.duckdb_shard_connections)} DuckDB connections after schema change"
231
+ )
232
+ for shard_key, conn in self.duckdb_shard_connections.items():
233
+ if shard_key == "default" and self.captions_dataset:
234
+ try:
235
+ # Re-register the updated dataset
236
+ arrow_table = self.captions_dataset.to_table()
237
+ conn.register("captions", arrow_table)
238
+ logger.debug(
239
+ f"Updated DuckDB registration for {self.captions_path} after schema change"
240
+ )
241
+ except Exception as e:
242
+ logger.warning(f"Failed to update DuckDB connection after schema change: {e}")
243
+
244
+ def _build_caption_schema(self, output_fields: Set[str]) -> pa.Schema:
245
+ """Build caption schema with dynamic output fields."""
246
+ fields = self.base_caption_fields.copy()
247
+
248
+ # Add dynamic output fields
249
+ for field_name in sorted(output_fields):
250
+ fields.append((field_name, pa.list_(pa.string())))
251
+
252
+ return pa.schema(fields)
253
+
124
254
  def _save_stats(self):
125
255
  """Persist current stats to disk."""
126
256
  try:
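
The new init_duckdb_connection path materialises the captions Lance dataset as an Arrow table and registers it with an in-memory DuckDB connection, so later export queries can be written as plain SQL. A self-contained sketch of that pattern outside the class, assuming lance and duckdb are installed; the path, table contents, and query are illustrative:

    import duckdb
    import lance
    import pyarrow as pa

    # Tiny stand-in for captions.lance.
    table = pa.table({
        "job_id": ["a:0", "a:1"],
        "captions": [["a cat"], ["a dog", "two dogs"]],
    })
    dataset = lance.write_dataset(table, "/tmp/example_captions.lance", mode="overwrite")

    # Register the materialised Arrow table with an in-memory DuckDB connection,
    # mirroring what init_duckdb_connection() does for the "default" shard.
    conn = duckdb.connect(database=":memory:")
    conn.register("captions", dataset.to_table())

    # Any SQL over the registered table now works, e.g. the shape of the
    # export query used later in get_storage_contents().
    result = conn.execute('SELECT "job_id", "captions" FROM captions LIMIT 10').fetch_arrow_table()
    print(result.num_rows)

Because the registration holds a materialised snapshot, it has to be refreshed whenever the dataset is rewritten, which is what _update_duckdb_connections_after_schema_change does after schema evolution.
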
@@ -135,60 +265,107 @@ class StorageManager:
135
265
  try:
136
266
  with open(self.stats_path, "r") as f:
137
267
  loaded_stats = json.load(f)
138
- # Merge loaded stats with defaults
139
268
  self.stats.update(loaded_stats)
140
269
  logger.info(f"Loaded stats from {self.stats_path}")
141
270
  except Exception as e:
142
271
  logger.error(f"Failed to load stats: {e}")
143
272
 
144
- def _calculate_initial_stats(self):
145
- """Calculate stats from parquet file - only called once at initialization."""
146
- if not self.captions_path.exists():
273
+ async def initialize(self):
274
+ """Initialize storage and load existing data."""
275
+ self._load_stats()
276
+
277
+ # Initialize caption storage
278
+ if self._init_lance_dataset():
279
+ logger.debug(f"Initialized captions dataset from {self.captions_path}")
280
+ # Load existing job IDs
281
+ job_ids = (
282
+ self.captions_dataset.to_table(columns=["job_id"]).column("job_id").to_pylist()
283
+ )
284
+ self.existing_caption_job_ids = set(job_ids)
285
+
286
+ # Detect output fields
287
+ schema = self.captions_dataset.schema
288
+ base_field_names = {field[0] for field in self.base_caption_fields}
289
+ self.known_output_fields = set(schema.names) - base_field_names
290
+
291
+ # Update caption schema
292
+ self.caption_schema = self._build_caption_schema(self.known_output_fields)
293
+
294
+ logger.info(
295
+ f"Loaded Lance dataset: {len(job_ids)} rows, output fields: {sorted(self.known_output_fields)}"
296
+ )
297
+
298
+ # Calculate stats if not loaded
299
+ if self.stats["disk_rows"] == 0:
300
+ await self._calculate_initial_stats()
301
+ else:
302
+ logger.warning("No existing captions dataset found, starting fresh.")
303
+
304
+ # Initialize contributors storage
305
+ if not self.contributors_path.exists():
306
+ # Create empty contributors dataset
307
+ empty_dict = {"contributor_id": [], "name": [], "total_captions": [], "trust_level": []}
308
+ empty_table = pa.Table.from_pydict(empty_dict, schema=self.contributor_schema)
309
+ self.contributors_dataset = lance.write_dataset(
310
+ empty_table, str(self.contributors_path), mode="create"
311
+ )
312
+ logger.info(f"Created empty contributor storage at {self.contributors_path}")
313
+ else:
314
+ self.contributors_dataset = lance.dataset(str(self.contributors_path))
315
+ contributor_ids = (
316
+ self.contributors_dataset.to_table(columns=["contributor_id"])
317
+ .column("contributor_id")
318
+ .to_pylist()
319
+ )
320
+ self.existing_contributor_ids = set(contributor_ids)
321
+ logger.info(f"Loaded contributors dataset: {len(contributor_ids)} contributors")
322
+
323
+ # Initialize jobs storage
324
+ if not self.jobs_path.exists():
325
+ # Create empty jobs dataset
326
+ empty_dict = {
327
+ "job_id": [],
328
+ "dataset": [],
329
+ "shard": [],
330
+ "item_key": [],
331
+ "status": [],
332
+ "assigned_to": [],
333
+ "created_at": [],
334
+ "updated_at": [],
335
+ }
336
+ empty_table = pa.Table.from_pydict(empty_dict, schema=self.job_schema)
337
+ self.jobs_dataset = lance.write_dataset(empty_table, str(self.jobs_path), mode="create")
338
+ logger.info(f"Created empty jobs storage at {self.jobs_path}")
339
+ else:
340
+ self.jobs_dataset = lance.dataset(str(self.jobs_path))
341
+ logger.info(f"Loaded jobs dataset: {self.jobs_dataset.count_rows()} rows")
342
+
343
+ async def _calculate_initial_stats(self):
344
+ """Calculate initial statistics from Lance dataset."""
345
+ if not self.captions_dataset:
147
346
  return
148
347
 
149
- logger.info("Calculating initial statistics from parquet file...")
348
+ logger.info("Calculating initial statistics...")
150
349
 
151
350
  try:
152
- # Get metadata to determine row count
153
- table_metadata = pq.read_metadata(self.captions_path)
154
- self.stats["disk_rows"] = table_metadata.num_rows
155
-
156
- # Simply use the known_output_fields that were already detected during initialization
157
- # This avoids issues with PyArrow schema parsing
158
- if self.known_output_fields and table_metadata.num_rows > 0:
159
- # Read the entire table since column-specific reading is causing issues
160
- table = pq.read_table(self.captions_path)
351
+ self.stats["disk_rows"] = self.captions_dataset.count_rows()
352
+
353
+ if self.known_output_fields and self.stats["disk_rows"] > 0:
354
+ # Sample data to calculate stats efficiently
355
+ table = self.captions_dataset.to_table()
161
356
  df = table.to_pandas()
162
357
 
163
358
  total_outputs = 0
164
359
  field_counts = {}
165
360
 
166
- # Count outputs in each known output field
167
361
  for field_name in self.known_output_fields:
168
362
  if field_name in df.columns:
169
363
  field_count = 0
170
364
  column_data = df[field_name]
171
365
 
172
- # Iterate through values more carefully to avoid numpy array issues
173
- for i in range(len(column_data)):
174
- value = column_data.iloc[i]
175
- # Handle None/NaN values first
176
- if value is None:
177
- continue
178
- # For lists/arrays, check if they're actually None or have content
179
- try:
180
- # If it's a list-like object, count its length
181
- if hasattr(value, "__len__") and not isinstance(value, str):
182
- # Check if it's a pandas null value by trying to check the first element
183
- if len(value) > 0:
184
- field_count += len(value)
185
- # Empty lists are valid but contribute 0
186
- except TypeError:
187
- # Value is scalar NA/NaN, skip it
188
- continue
189
- except:
190
- # Any other error, skip this value
191
- continue
366
+ for value in column_data:
367
+ if value is not None and isinstance(value, list) and len(value) > 0:
368
+ field_count += len(value)
192
369
 
193
370
  if field_count > 0:
194
371
  field_counts[field_name] = field_count
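
initialize() now derives the dynamic output fields by subtracting the base field names from the stored Lance schema, and loads existing job_ids through a plain column projection. A standalone sketch of those two steps, using a toy dataset and a trimmed base-field set (both illustrative):

    import lance
    import pyarrow as pa

    # Toy stand-in for captions.lance with one dynamic output column ("captions").
    table = pa.table({"job_id": ["a:0"], "dataset": ["demo"], "captions": [["a cat"]]})
    dataset = lance.write_dataset(table, "/tmp/example_detect.lance", mode="overwrite")

    # Trimmed base-field set; the real set comes from base_caption_fields.
    base_field_names = {"job_id", "dataset"}

    known_output_fields = set(dataset.schema.names) - base_field_names
    existing_job_ids = set(
        dataset.to_table(columns=["job_id"]).column("job_id").to_pylist()
    )

    print(sorted(known_output_fields))  # ['captions']
    print(existing_job_ids)             # {'a:0'}
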
@@ -197,7 +374,6 @@ class StorageManager:
197
374
  self.stats["disk_outputs"] = total_outputs
198
375
  self.stats["field_counts"] = field_counts
199
376
 
200
- # Clean up
201
377
  del df, table
202
378
  gc.collect()
203
379
  else:
@@ -205,684 +381,323 @@ class StorageManager:
205
381
  self.stats["field_counts"] = {}
206
382
 
207
383
  logger.info(
208
- f"Initial stats: {self.stats['disk_rows']} rows, {self.stats['disk_outputs']} outputs, fields: {list(self.stats['field_counts'].keys())}"
384
+ f"Initial stats: {self.stats['disk_rows']} rows, "
385
+ f"{self.stats['disk_outputs']} outputs, "
386
+ f"fields: {list(self.stats['field_counts'].keys())}"
209
387
  )
210
388
 
211
389
  except Exception as e:
212
390
  logger.error(f"Failed to calculate initial stats: {e}", exc_info=True)
213
- # Set default values
214
- self.stats["disk_rows"] = 0
215
- self.stats["disk_outputs"] = 0
216
- self.stats["field_counts"] = {}
217
391
 
218
- # Save the calculated stats
219
392
  self._save_stats()
220
393
 
221
- def _update_stats_for_new_captions(self, captions_added: List[dict], rows_added: int):
222
- """Update stats incrementally as new captions are added."""
223
- # Update row counts
224
- self.stats["disk_rows"] += rows_added
225
- self.stats["total_captions_written"] += rows_added
226
-
227
- # Count outputs in the new captions
228
- outputs_added = 0
229
- for caption in captions_added:
230
- for field_name in self.known_output_fields:
231
- if field_name in caption and isinstance(caption[field_name], list):
232
- count = len(caption[field_name])
233
- outputs_added += count
234
-
235
- # Update field-specific counts
236
- if field_name not in self.stats["field_counts"]:
237
- self.stats["field_counts"][field_name] = 0
238
- self.stats["field_counts"][field_name] += count
239
-
240
- if field_name not in self.stats["session_field_counts"]:
241
- self.stats["session_field_counts"][field_name] = 0
242
- self.stats["session_field_counts"][field_name] += count
243
-
244
- self.stats["disk_outputs"] += outputs_added
245
- self.stats["total_caption_entries_written"] += outputs_added
246
-
247
- def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
248
- """Check if a column is entirely empty, null, or contains only zeros/empty lists."""
249
- if column_name not in df.columns:
250
- return True
251
-
252
- col = df[column_name]
253
-
254
- # Check if all values are null/NaN
255
- if col.isna().all():
256
- return True
257
-
258
- # For numeric columns, check if all non-null values are 0
259
- if pd.api.types.is_numeric_dtype(col):
260
- non_null_values = col.dropna()
261
- if len(non_null_values) > 0 and (non_null_values == 0).all():
262
- return True
263
-
264
- # For list columns, check if all are None or empty lists
265
- if col.dtype == "object":
266
- non_null_values = col.dropna()
267
- if len(non_null_values) == 0:
268
- return True
269
- # Check if all non-null values are empty lists
270
- all_empty_lists = True
271
- for val in non_null_values:
272
- if isinstance(val, list) and len(val) > 0:
273
- all_empty_lists = False
274
- break
275
- elif not isinstance(val, list):
276
- all_empty_lists = False
277
- break
278
- if all_empty_lists:
279
- return True
280
-
281
- return False
282
-
283
- def _get_non_empty_columns(
284
- self, df: pd.DataFrame, preserve_base_fields: bool = True
285
- ) -> List[str]:
286
- """Get list of columns that contain actual data.
287
-
288
- Args:
289
- df: DataFrame to check
290
- preserve_base_fields: If True, always include base fields even if empty
291
- """
292
- base_field_names = {field[0] for field in self.base_caption_fields}
293
- non_empty_columns = []
294
-
295
- for col in df.columns:
296
- # Always keep base fields if preserve_base_fields is True
297
- if preserve_base_fields and col in base_field_names:
298
- non_empty_columns.append(col)
299
- elif not self._is_column_empty(df, col):
300
- non_empty_columns.append(col)
301
-
302
- return non_empty_columns
303
-
304
- def _calculate_rates(self) -> Dict[str, float]:
305
- """Calculate row addition rates over different time windows."""
306
- current_time = time.time()
307
- rates = {}
308
-
309
- # Define time windows in minutes
310
- windows = {"1min": 1, "5min": 5, "15min": 15, "60min": 60}
311
-
312
- # Clean up old entries beyond the largest window
313
- cutoff_time = current_time - (60 * 60) # 60 minutes
314
- while self.row_additions and self.row_additions[0][0] < cutoff_time:
315
- self.row_additions.popleft()
316
-
317
- # Calculate rates for each window
318
- for window_name, window_minutes in windows.items():
319
- window_seconds = window_minutes * 60
320
- window_start = current_time - window_seconds
321
-
322
- # Sum rows added within this window
323
- rows_in_window = sum(
324
- count for timestamp, count in self.row_additions if timestamp >= window_start
325
- )
326
-
327
- # Calculate rate (rows per second)
328
- # For windows larger than elapsed time, use elapsed time
329
- elapsed = current_time - self.start_time
330
- actual_window = min(window_seconds, elapsed)
331
-
332
- if actual_window > 0:
333
- rate = rows_in_window / actual_window
334
- rates[window_name] = rate
335
- else:
336
- rates[window_name] = 0.0
337
-
338
- # Calculate instantaneous rate (last minute)
339
- instant_window_start = current_time - 60 # Last 60 seconds
340
- instant_rows = sum(
341
- count for timestamp, count in self.row_additions if timestamp >= instant_window_start
342
- )
343
- instant_window = min(60, current_time - self.start_time)
344
- rates["instant"] = instant_rows / instant_window if instant_window > 0 else 0.0
345
-
346
- # Calculate overall rate since start
347
- total_elapsed = current_time - self.start_time
348
- if total_elapsed > 0:
349
- rates["overall"] = self.stats["total_captions_written"] / total_elapsed
350
- else:
351
- rates["overall"] = 0.0
352
-
353
- return rates
354
-
355
- def _log_rates(self, rows_added: int):
356
- """Log rate information if enough time has passed."""
357
- current_time = time.time()
358
-
359
- # Log rates every 10 seconds or if it's been more than 30 seconds
360
- time_since_last_log = current_time - self.last_rate_log_time
361
- if time_since_last_log < 10 and rows_added < 50:
362
- return
363
-
364
- rates = self._calculate_rates()
365
-
366
- # Format the rate information
367
- rate_str = (
368
- f"Rate stats - Instant: {rates['instant']:.1f} rows/s | "
369
- f"Avg (5m): {rates['5min']:.1f} | "
370
- f"Avg (15m): {rates['15min']:.1f} | "
371
- f"Avg (60m): {rates['60min']:.1f} | "
372
- f"Overall: {rates['overall']:.1f} rows/s"
373
- )
374
-
375
- logger.info(rate_str)
376
- self.last_rate_log_time = current_time
377
-
378
- def _get_existing_output_columns(self) -> Set[str]:
379
- """Get output field columns that actually exist in the parquet file."""
380
- if not self.captions_path.exists():
381
- return set()
382
-
383
- table_metadata = pq.read_metadata(self.captions_path)
384
- existing_columns = set(table_metadata.schema.names)
385
- base_field_names = {field[0] for field in self.base_caption_fields}
386
-
387
- return existing_columns - base_field_names
388
-
389
- def _build_caption_schema(self, output_fields: Set[str]) -> pa.Schema:
390
- """Build caption schema with dynamic output fields."""
391
- fields = self.base_caption_fields.copy()
392
-
393
- # Add dynamic output fields (all as list of strings for now)
394
- for field_name in sorted(output_fields): # Sort for consistent ordering
395
- fields.append((field_name, pa.list_(pa.string())))
396
-
397
- return pa.schema(fields)
398
-
399
- async def initialize(self):
400
- """Initialize storage files if they don't exist."""
401
- # Load persisted stats if available
402
- self._load_stats()
403
-
404
- if not self.captions_path.exists():
405
- # Create initial schema with just base fields
406
- self.caption_schema = self._build_caption_schema(set())
407
-
408
- # Create empty table
409
- empty_dict = {field[0]: [] for field in self.base_caption_fields}
410
- empty_table = pa.Table.from_pydict(empty_dict, schema=self.caption_schema)
411
- pq.write_table(empty_table, self.captions_path)
412
- logger.info(f"Created empty caption storage at {self.captions_path}")
413
-
414
- # Initialize stats
415
- self.stats["disk_rows"] = 0
416
- self.stats["disk_outputs"] = 0
417
- self.stats["field_counts"] = {}
418
- else:
419
- # Load existing schema and detect output fields
420
- existing_table = pq.read_table(self.captions_path)
421
- existing_columns = set(existing_table.column_names)
422
-
423
- # Identify output fields (columns not in base schema)
424
- base_field_names = {field[0] for field in self.base_caption_fields}
425
- self.known_output_fields = existing_columns - base_field_names
426
-
427
- # Check if we need to migrate from old "outputs" JSON column
428
- if "outputs" in existing_columns:
429
- logger.info("Migrating from JSON outputs to dynamic columns...")
430
- await self._migrate_outputs_to_columns(existing_table)
431
- else:
432
- # Build current schema from existing columns
433
- self.caption_schema = self._build_caption_schema(self.known_output_fields)
434
-
435
- # Load existing caption job_ids
436
- self.existing_caption_job_ids = set(existing_table["job_id"].to_pylist())
437
- logger.info(f"Loaded {len(self.existing_caption_job_ids)} existing caption job_ids")
438
- logger.info(f"Known output fields: {sorted(self.known_output_fields)}")
439
-
440
- # Calculate initial stats if not already loaded from file
441
- if self.stats["disk_rows"] == 0:
442
- self._calculate_initial_stats()
443
-
444
- # Initialize other storage files...
445
- if not self.contributors_path.exists():
446
- empty_dict = {"contributor_id": [], "name": [], "total_captions": [], "trust_level": []}
447
- empty_table = pa.Table.from_pydict(empty_dict, schema=self.contributor_schema)
448
- pq.write_table(empty_table, self.contributors_path)
449
- logger.info(f"Created empty contributor storage at {self.contributors_path}")
450
- else:
451
- existing_contributors = pq.read_table(
452
- self.contributors_path, columns=["contributor_id"]
453
- )
454
- self.existing_contributor_ids = set(existing_contributors["contributor_id"].to_pylist())
455
- logger.info(f"Loaded {len(self.existing_contributor_ids)} existing contributor IDs")
456
-
457
- async def _migrate_outputs_to_columns(self, existing_table: pa.Table):
458
- """Migrate from JSON outputs column to dynamic columns."""
459
- df = existing_table.to_pandas()
460
-
461
- # Collect all unique output field names
462
- output_fields = set()
463
- for outputs_json in df.get("outputs", []):
464
- if outputs_json:
465
- try:
466
- outputs = json.loads(outputs_json)
467
- output_fields.update(outputs.keys())
468
- except:
469
- continue
470
-
471
- # Add legacy "captions" field if it exists and isn't already a base field
472
- if "captions" in df.columns and "captions" not in {f[0] for f in self.base_caption_fields}:
473
- output_fields.add("captions")
474
-
475
- logger.info(f"Found output fields to migrate: {sorted(output_fields)}")
476
-
477
- # Create new columns for each output field
478
- for field_name in output_fields:
479
- if field_name not in df.columns:
480
- df[field_name] = None
481
-
482
- # Migrate data from outputs JSON to columns
483
- for idx, row in df.iterrows():
484
- if pd.notna(row.get("outputs")):
485
- try:
486
- outputs = json.loads(row["outputs"])
487
- for field_name, field_values in outputs.items():
488
- df.at[idx, field_name] = field_values
489
- except:
490
- continue
491
-
492
- # Handle legacy captions column if it's becoming a dynamic field
493
- if "captions" in output_fields and pd.notna(row.get("captions")):
494
- if pd.isna(df.at[idx, "captions"]):
495
- df.at[idx, "captions"] = row["captions"]
496
-
497
- # Drop the old outputs column
498
- if "outputs" in df.columns:
499
- df = df.drop(columns=["outputs"])
500
-
501
- # Remove empty columns before saving (but preserve base fields)
502
- non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=True)
503
- df = df[non_empty_columns]
504
-
505
- # Update known fields and schema based on non-empty columns
506
- base_field_names = {field[0] for field in self.base_caption_fields}
507
- self.known_output_fields = set(non_empty_columns) - base_field_names
508
- self.caption_schema = self._build_caption_schema(self.known_output_fields)
509
-
510
- # Write migrated table
511
- migrated_table = pa.Table.from_pandas(df, schema=self.caption_schema)
512
- pq.write_table(migrated_table, self.captions_path)
513
- logger.info("Migration complete - outputs now stored in dynamic columns")
514
-
515
- # Recalculate stats after migration
516
- self._calculate_initial_stats()
517
-
518
394
  async def save_caption(self, caption: Caption):
519
- """Save a caption entry, grouping outputs by job_id/item_key (not separating captions)."""
395
+ """Save a caption entry."""
520
396
  caption_dict = asdict(caption)
521
397
 
522
- # Extract item_index from metadata if present
398
+ # Extract item_index from metadata
523
399
  if "metadata" in caption_dict and isinstance(caption_dict["metadata"], dict):
524
400
  item_index = caption_dict["metadata"].get("_item_index")
525
401
  if item_index is not None:
526
402
  caption_dict["item_index"] = item_index
527
403
 
528
- # Extract outputs and handle them separately
404
+ # Extract outputs
529
405
  outputs = caption_dict.pop("outputs", {})
530
-
531
- # Remove old "captions" field if it exists (will be in outputs)
532
406
  caption_dict.pop("captions", None)
533
407
 
534
- # Grouping key: (job_id, item_key)
408
+ # Get job_id for deduplication - convert to string early
535
409
  _job_id = caption_dict.get("job_id")
536
- job_id = JobId.from_dict(_job_id).get_sample_str()
537
- group_key = job_id
410
+ job_id = JobId.from_dict(_job_id).get_sample_str() if isinstance(_job_id, dict) else _job_id
411
+ caption_dict["job_id"] = job_id # Update dict with string version
538
412
 
539
- # Check for duplicate - if this job_id already exists on disk, skip it
540
- if group_key in self.existing_caption_job_ids:
413
+ # Check for duplicate
414
+ if job_id in self.existing_caption_job_ids:
541
415
  self.stats["duplicates_skipped"] += 1
542
- logger.debug(f"Skipping duplicate job_id: {group_key}")
416
+ logger.debug(f"Skipping duplicate job_id: {job_id}")
543
417
  return
544
418
 
545
- # Try to find existing buffered row for this group
546
- found_row = False
547
- for idx, row in enumerate(self.caption_buffer):
548
- check_key = row.get("job_id")
549
- if check_key == group_key:
550
- found_row = True
551
- # Merge outputs into existing row
419
+ # Try to find existing buffered row
420
+ for _idx, row in enumerate(self.caption_buffer):
421
+ if row.get("job_id") == job_id:
422
+ # Merge outputs
552
423
  for field_name, field_values in outputs.items():
553
424
  if field_name not in self.known_output_fields:
554
425
  self.known_output_fields.add(field_name)
426
+ logger.info(f"New output field detected: {field_name}")
555
427
  if field_name in row and isinstance(row[field_name], list):
556
428
  row[field_name].extend(field_values)
557
429
  else:
558
430
  row[field_name] = list(field_values)
559
- # Optionally update other fields (e.g., caption_count)
560
431
  if "caption_count" in caption_dict:
561
432
  old_count = row.get("caption_count", 0)
562
433
  row["caption_count"] = old_count + caption_dict["caption_count"]
563
- return # Already merged, no need to add new row
434
+ return
564
435
 
565
- # If not found, create new row
436
+ # Create new row
566
437
  for field_name, field_values in outputs.items():
567
438
  if field_name not in self.known_output_fields:
568
439
  self.known_output_fields.add(field_name)
569
440
  logger.info(f"New output field detected: {field_name}")
570
441
  caption_dict[field_name] = list(field_values)
571
442
 
572
- # Serialize metadata to JSON if present
443
+ # Serialize metadata
573
444
  if "metadata" in caption_dict:
574
445
  caption_dict["metadata"] = json.dumps(caption_dict.get("metadata", {}))
575
446
  else:
576
447
  caption_dict["metadata"] = "{}"
577
448
 
578
- if isinstance(caption_dict.get("job_id"), dict):
579
- caption_dict["job_id"] = job_id
580
-
581
449
  self.caption_buffer.append(caption_dict)
582
450
 
583
451
  if len(self.caption_buffer) >= self.caption_buffer_size:
584
- logger.debug("Caption buffer full, flushing captions.")
452
+ logger.debug("Caption buffer full, flushing.")
585
453
  await self._flush_captions()
586
454
 
587
455
  async def _flush_captions(self):
588
- """Write caption buffer to parquet with dynamic schema."""
456
+ """Flush caption buffer to Lance dataset."""
589
457
  if not self.caption_buffer:
590
458
  return
591
459
 
592
460
  try:
593
461
  num_rows = len(self.caption_buffer)
594
-
595
- # Count total outputs across all fields
596
- total_outputs = 0
597
- for row in self.caption_buffer:
598
- for field_name in self.known_output_fields:
599
- if field_name in row and isinstance(row[field_name], list):
600
- total_outputs += len(row[field_name])
601
-
602
- logger.debug(
603
- f"Flushing {num_rows} rows with {total_outputs} total outputs to disk. preparing data..."
604
- )
605
-
606
- # Keep a copy of captions for stats update
607
462
  captions_to_write = list(self.caption_buffer)
608
463
 
609
- # Prepare data with all required columns
464
+ # Prepare data
610
465
  prepared_buffer = []
611
466
  new_job_ids = []
467
+
612
468
  for row in self.caption_buffer:
613
469
  prepared_row = row.copy()
614
-
615
- # Track job_ids for deduplication
616
470
  job_id = prepared_row.get("job_id")
617
471
  if job_id:
618
472
  new_job_ids.append(job_id)
619
473
 
620
474
  # Ensure all base fields are present
621
- for field_name, field_type in self.base_caption_fields:
475
+ for field_name, _field_type in self.base_caption_fields:
622
476
  if field_name not in prepared_row:
623
477
  prepared_row[field_name] = None
624
478
 
625
- # Ensure all output fields are present (even if None)
479
+ # Ensure all output fields are present
626
480
  for field_name in self.known_output_fields:
627
481
  if field_name not in prepared_row:
628
482
  prepared_row[field_name] = None
629
483
 
630
484
  prepared_buffer.append(prepared_row)
631
485
 
632
- # Build schema with all known fields (base + output)
633
- logger.debug("building schema...")
486
+ # Build schema and create table
634
487
  schema = self._build_caption_schema(self.known_output_fields)
635
- logger.debug("schema built, creating table...")
636
488
  table = pa.Table.from_pylist(prepared_buffer, schema=schema)
637
489
 
638
- if self.captions_path.exists():
639
- # Read existing table - this is necessary for parquet format
640
- logger.debug("Reading existing captions file...")
641
- existing = pq.read_table(self.captions_path)
490
+ # Write to Lance
491
+ if self.captions_dataset is None and self.captions_path.exists():
492
+ try:
493
+ self.captions_dataset = lance.dataset(str(self.captions_path))
494
+ except Exception:
495
+ # Dataset might be corrupted or incomplete
496
+ self.captions_dataset = None
497
+
498
+ if self.captions_dataset is not None:
499
+ # Check if schema has changed (new output fields added)
500
+ existing_schema_fields = set(self.captions_dataset.schema.names)
501
+ new_schema_fields = set(schema.names)
502
+
503
+ if new_schema_fields != existing_schema_fields:
504
+ # Schema has changed, need to merge existing data with new schema
505
+ logger.info(
506
+ f"Schema evolution detected. New fields: {new_schema_fields - existing_schema_fields}"
507
+ )
642
508
 
643
- logger.debug("writing new rows...")
644
- # Concatenate with promote_options="default" to handle schema differences automatically
645
- logger.debug("concat tables...")
646
- combined = pa.concat_tables([existing, table], promote_options="default")
509
+ # Read existing data
510
+ existing_table = self.captions_dataset.to_table()
511
+ existing_df = existing_table.to_pandas()
647
512
 
648
- # Write combined table
649
- pq.write_table(combined, self.captions_path, compression="snappy")
513
+ # Add missing columns to existing data
514
+ for field_name in new_schema_fields - existing_schema_fields:
515
+ existing_df[field_name] = None
650
516
 
651
- actual_new = len(table)
652
- else:
653
- # Write new file with all fields
654
- pq.write_table(table, self.captions_path, compression="snappy")
655
- actual_new = num_rows
656
- logger.debug("write complete.")
517
+ # Convert back to table with new schema
518
+ existing_table_updated = pa.Table.from_pandas(existing_df, schema=schema)
657
519
 
658
- # Clean up
659
- del prepared_buffer, table
520
+ # Concatenate existing and new data
521
+ combined_table = pa.concat_tables([existing_table_updated, table])
660
522
 
661
- # Update the in-memory job_id set for efficient deduplication
523
+ # Recreate dataset with combined data
524
+ self.captions_dataset = lance.write_dataset(
525
+ combined_table, str(self.captions_path), mode="overwrite"
526
+ )
527
+
528
+ # Update DuckDB connections after schema evolution
529
+ logger.debug("Updating DuckDB connections after schema evolution")
530
+ self._update_duckdb_connections_after_schema_change()
531
+ else:
532
+ # Schema hasn't changed, normal append
533
+ self.captions_dataset = lance.write_dataset(
534
+ table, str(self.captions_path), mode="append"
535
+ )
536
+ else:
537
+ # Create new dataset
538
+ self.captions_dataset = lance.write_dataset(
539
+ table, str(self.captions_path), mode="create"
540
+ )
541
+
542
+ # Update tracking
662
543
  self.existing_caption_job_ids.update(new_job_ids)
663
544
 
664
- # Update statistics incrementally
665
- self._update_stats_for_new_captions(captions_to_write, actual_new)
545
+ # Update stats
546
+ self._update_stats_for_new_captions(captions_to_write, num_rows)
666
547
  self.stats["total_flushes"] += 1
667
548
 
668
- # Clear buffer
669
- self.caption_buffer.clear()
549
+ # Track row additions
550
+ current_time = time.time()
551
+ self.row_additions.append((current_time, num_rows))
552
+ self._log_rates(num_rows)
670
553
 
671
- # Track row additions for rate calculation
672
- if actual_new > 0:
673
- current_time = time.time()
674
- self.row_additions.append((current_time, actual_new))
675
-
676
- # Log rates
677
- self._log_rates(actual_new)
554
+ # Clear buffer only on success
555
+ self.caption_buffer.clear()
678
556
 
679
- logger.info(
680
- f"Successfully wrote captions (new rows: {actual_new}, "
681
- f"total rows written: {self.stats['total_captions_written']}, "
682
- f"total captions written: {self.stats['total_caption_entries_written']}, "
683
- f"duplicates skipped: {self.stats['duplicates_skipped']}, "
684
- f"output fields: {sorted(list(self.known_output_fields))})"
685
- )
557
+ logger.info(f"Flushed {num_rows} rows to Lance dataset")
686
558
 
687
559
  # Save stats periodically
688
560
  if self.stats["total_flushes"] % 10 == 0:
689
561
  self._save_stats()
690
562
 
563
+ except Exception as e:
564
+ logger.error(f"Failed to flush captions: {e}")
565
+ # Don't clear buffer on failure - preserve data
566
+ raise
691
567
  finally:
692
- self.caption_buffer.clear()
693
- # Force garbage collection
694
568
  gc.collect()
695
569
 
696
- async def optimize_storage(self):
697
- """Optimize storage by dropping empty columns. Run this periodically or on-demand."""
698
- if not self.captions_path.exists():
699
- logger.info("No captions file to optimize")
700
- return
701
-
702
- logger.info("Starting storage optimization...")
703
-
704
- # Read the full table
705
- backup_path = None
706
- table = pq.read_table(self.captions_path)
707
- df = table.to_pandas()
708
- original_columns = len(df.columns)
709
-
710
- # Find non-empty columns (don't preserve empty base fields)
711
- non_empty_columns = self._get_non_empty_columns(df, preserve_base_fields=False)
712
-
713
- # Always keep at least job_id
714
- if "job_id" not in non_empty_columns:
715
- non_empty_columns.append("job_id")
716
-
717
- if len(non_empty_columns) < original_columns:
718
- # We have columns to drop
719
- df_optimized = df[non_empty_columns]
720
-
721
- # Rebuild schema for non-empty columns only
722
- base_field_names = {f[0] for f in self.base_caption_fields}
723
- fields = []
724
- output_fields = set()
725
-
726
- # Process columns in a consistent order: base fields first, then output fields
727
- for col in non_empty_columns:
728
- if col in base_field_names:
729
- # Find the base field definition
730
- for fname, ftype in self.base_caption_fields:
731
- if fname == col:
732
- fields.append((fname, ftype))
733
- break
734
- else:
735
- # Output field
736
- output_fields.add(col)
570
+ def _update_stats_for_new_captions(self, captions_added: List[dict], rows_added: int):
571
+ """Update stats incrementally."""
572
+ self.stats["disk_rows"] += rows_added
573
+ self.stats["total_captions_written"] += rows_added
737
574
 
738
- # Add output fields in sorted order
739
- for field_name in sorted(output_fields):
740
- fields.append((field_name, pa.list_(pa.string())))
575
+ outputs_added = 0
576
+ for caption in captions_added:
577
+ for field_name in self.known_output_fields:
578
+ if field_name in caption and isinstance(caption[field_name], list):
579
+ count = len(caption[field_name])
580
+ outputs_added += count
741
581
 
742
- # Create optimized schema and table
743
- optimized_schema = pa.schema(fields)
744
- optimized_table = pa.Table.from_pandas(df_optimized, schema=optimized_schema)
582
+ if field_name not in self.stats["field_counts"]:
583
+ self.stats["field_counts"][field_name] = 0
584
+ self.stats["field_counts"][field_name] += count
745
585
 
746
- # Backup the original file (optional)
747
- backup_path = self.captions_path.with_suffix(".parquet.bak")
748
- import shutil
586
+ if field_name not in self.stats["session_field_counts"]:
587
+ self.stats["session_field_counts"][field_name] = 0
588
+ self.stats["session_field_counts"][field_name] += count
749
589
 
750
- shutil.copy2(self.captions_path, backup_path)
590
+ self.stats["disk_outputs"] += outputs_added
591
+ self.stats["total_caption_entries_written"] += outputs_added
751
592
 
752
- # Write optimized table
753
- pq.write_table(optimized_table, self.captions_path, compression="snappy")
593
+ def _calculate_rates(self) -> Dict[str, float]:
594
+ """Calculate row addition rates."""
595
+ current_time = time.time()
596
+ rates = {}
754
597
 
755
- # Update known output fields and recalculate stats
756
- self.known_output_fields = output_fields
598
+ windows = {"1min": 1, "5min": 5, "15min": 15, "60min": 60}
599
+ cutoff_time = current_time - (60 * 60)
757
600
 
758
- # Remove dropped fields from stats
759
- dropped_fields = set(self.stats["field_counts"].keys()) - output_fields
760
- for field in dropped_fields:
761
- del self.stats["field_counts"][field]
762
- if field in self.stats["session_field_counts"]:
763
- del self.stats["session_field_counts"][field]
601
+ while self.row_additions and self.row_additions[0][0] < cutoff_time:
602
+ self.row_additions.popleft()
764
603
 
765
- logger.info(
766
- f"Storage optimization complete: {original_columns} -> {len(non_empty_columns)} columns. "
767
- f"Removed columns: {sorted(set(df.columns) - set(non_empty_columns))}"
604
+ for window_name, window_minutes in windows.items():
605
+ window_seconds = window_minutes * 60
606
+ window_start = current_time - window_seconds
607
+ rows_in_window = sum(
608
+ count for timestamp, count in self.row_additions if timestamp >= window_start
768
609
  )
610
+ elapsed = current_time - self.start_time
611
+ actual_window = min(window_seconds, elapsed)
612
+ if actual_window > 0:
613
+ rates[window_name] = rows_in_window / actual_window
614
+ else:
615
+ rates[window_name] = 0.0
616
+
617
+ # Add instant and overall rates
618
+ rates["instant"] = rates.get("1min", 0.0)
619
+ total_elapsed = current_time - self.start_time
620
+ if total_elapsed > 0:
621
+ rates["overall"] = self.stats["total_captions_written"] / total_elapsed
769
622
  else:
770
- logger.info(f"No optimization needed - all {original_columns} columns contain data")
623
+ rates["overall"] = 0.0
771
624
 
772
- # Report file size reduction
773
- import os
625
+ return rates
774
626
 
775
- if backup_path and backup_path.exists():
776
- original_size = os.path.getsize(backup_path)
777
- new_size = os.path.getsize(self.captions_path)
778
- reduction_pct = (1 - new_size / original_size) * 100
779
- logger.info(
780
- f"File size: {original_size/1024/1024:.1f}MB -> {new_size/1024/1024:.1f}MB "
781
- f"({reduction_pct:.1f}% reduction)"
782
- )
627
+ def _log_rates(self, rows_added: int):
628
+ """Log rate information."""
629
+ current_time = time.time()
630
+ time_since_last_log = current_time - self.last_rate_log_time
783
631
 
784
- # Save updated stats
785
- self._save_stats()
632
+ if time_since_last_log < 10 and rows_added < 50:
633
+ return
634
+
635
+ rates = self._calculate_rates()
636
+ rate_str = (
637
+ f"Rate stats - Instant: {rates.get('1min', 0):.1f} rows/s | "
638
+ f"Avg (5m): {rates.get('5min', 0):.1f} | "
639
+ f"Avg (15m): {rates.get('15min', 0):.1f} | "
640
+ f"Overall: {rates['overall']:.1f} rows/s"
641
+ )
642
+ logger.info(rate_str)
643
+ self.last_rate_log_time = current_time
786
644
 
787
645
  async def save_contributor(self, contributor: Contributor):
788
- """Save or update contributor stats - buffers until batch size reached."""
646
+ """Save or update contributor stats."""
789
647
  self.contributor_buffer.append(asdict(contributor))
790
648
 
791
649
  if len(self.contributor_buffer) >= self.contributor_buffer_size:
792
650
  await self._flush_contributors()
793
651
 
794
- async def _flush_jobs(self):
795
- """Write job buffer to parquet."""
796
- if not self.job_buffer:
652
+ async def _flush_contributors(self):
653
+ """Flush contributor buffer to Lance."""
654
+ if not self.contributor_buffer:
797
655
  return
798
656
 
799
- table = pa.Table.from_pylist(self.job_buffer, schema=self.job_schema)
800
-
801
- # For jobs, we need to handle updates (upsert logic)
802
- if self.jobs_path.exists():
803
- existing = pq.read_table(self.jobs_path).to_pandas()
804
- new_df = table.to_pandas()
805
-
806
- # Update existing records or add new ones
807
- for _, row in new_df.iterrows():
808
- mask = existing["job_id"] == row["job_id"]
809
- if mask.any():
810
- # Update existing
811
- for col in row.index:
812
- existing.loc[existing[mask].index, col] = row[col]
813
- else:
814
- # Add new
815
- existing = pd.concat([existing, pd.DataFrame([row])], ignore_index=True)
657
+ table = pa.Table.from_pylist(self.contributor_buffer, schema=self.contributor_schema)
816
658
 
817
- updated_table = pa.Table.from_pandas(existing, schema=self.job_schema)
818
- pq.write_table(updated_table, self.jobs_path)
819
- else:
820
- pq.write_table(table, self.jobs_path)
659
+ mode = "append" if self.contributors_path.exists() else "create"
660
+ self.contributors_dataset = lance.write_dataset(
661
+ table, str(self.contributors_path), mode=mode
662
+ )
663
+ if mode == "create":
664
+ logger.info(f"Created contributor storage at {self.contributors_path}")
821
665
 
822
- self.job_buffer.clear()
823
- logger.debug(f"Flushed {len(self.job_buffer)} jobs")
666
+ self.contributor_buffer.clear()
824
667
 
825
- async def _flush_contributors(self):
826
- """Write contributor buffer to parquet."""
827
- if not self.contributor_buffer:
668
+ async def _flush_jobs(self):
669
+ """Flush job buffer to Lance."""
670
+ if not self.job_buffer:
828
671
  return
829
672
 
830
- table = pa.Table.from_pylist(self.contributor_buffer, schema=self.contributor_schema)
831
-
832
- # Handle updates for contributors
833
- if self.contributors_path.exists():
834
- existing = pq.read_table(self.contributors_path).to_pandas()
835
- new_df = table.to_pandas()
673
+ table = pa.Table.from_pylist(self.job_buffer, schema=self.job_schema)
836
674
 
837
- for _, row in new_df.iterrows():
838
- mask = existing["contributor_id"] == row["contributor_id"]
839
- if mask.any():
840
- for col in row.index:
841
- existing.loc[mask, col] = row[col]
842
- else:
843
- existing = pd.concat([existing, pd.DataFrame([row])], ignore_index=True)
675
+ table = pa.Table.from_pylist(self.job_buffer, schema=self.job_schema)
844
676
 
845
- updated_table = pa.Table.from_pandas(existing, schema=self.contributor_schema)
846
- pq.write_table(updated_table, self.contributors_path)
847
- else:
848
- pq.write_table(table, self.contributors_path)
677
+ mode = "append" if self.jobs_path.exists() else "create"
678
+ self.jobs_dataset = lance.write_dataset(table, str(self.jobs_path), mode=mode)
679
+ if mode == "create":
680
+ logger.info(f"Created jobs storage at {self.jobs_path}")
849
681
 
850
- self.contributor_buffer.clear()
682
+ self.job_buffer.clear()
851
683
 
852
684
  async def checkpoint(self):
853
- """Force flush all buffers to disk - called periodically by orchestrator."""
854
- logger.info(
855
- f"Checkpoint: Flushing buffers (captions: {len(self.caption_buffer)}, "
856
- f"jobs: {len(self.job_buffer)}, contributors: {len(self.contributor_buffer)})"
857
- )
685
+ """Flush all buffers to disk."""
686
+ logger.info("Checkpoint: Flushing buffers")
858
687
 
859
688
  await self._flush_captions()
860
- await self._flush_jobs()
861
689
  await self._flush_contributors()
690
+ await self._flush_jobs()
862
691
 
863
- # Save stats on checkpoint
864
692
  self._save_stats()
865
693
 
866
- # Log final rate statistics
867
- if self.stats["total_captions_written"] > 0:
868
- rates = self._calculate_rates()
869
- logger.info(
870
- f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
871
- f"Total caption entries: {self.stats['total_caption_entries_written']}, "
872
- f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
873
- f"Overall rate: {rates['overall']:.1f} rows/s"
874
- )
875
- else:
876
- logger.info(
877
- f"Checkpoint complete. Total rows: {self.stats['total_captions_written']}, "
878
- f"Total caption entries: {self.stats['total_caption_entries_written']}, "
879
- f"Duplicates skipped: {self.stats['duplicates_skipped']}"
880
- )
694
+ logger.info(
695
+ f"Checkpoint complete. Total rows: {self.stats['disk_rows']}, "
696
+ f"Total outputs: {self.stats['disk_outputs']}"
697
+ )
881
698
 
882
699
  def get_all_processed_job_ids(self) -> Set[str]:
883
- """Get all processed job_ids - useful for resumption."""
884
- # Return the in-memory set which is kept up-to-date
885
- # Also add any job_ids currently in the buffer
700
+ """Get all processed job_ids."""
886
701
  all_job_ids = self.existing_caption_job_ids.copy()
887
702
 
888
703
  for row in self.caption_buffer:
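
The heart of the new _flush_captions is the schema-evolution branch: when the buffered rows introduce output columns that the on-disk dataset does not have yet, the existing rows are read back, backfilled with nulls for the new columns, concatenated with the new batch, and the whole dataset is rewritten with mode="overwrite"; otherwise the batch is simply appended. A compact, illustrative sketch of that decision (path, schema, and values are not from the package):

    import lance
    import pyarrow as pa

    path = "/tmp/example_flush.lance"  # illustrative path

    # Existing dataset that only knows about a "captions" output column.
    existing = pa.table({"job_id": ["a:0"], "captions": [["a cat"]]})
    dataset = lance.write_dataset(existing, path, mode="overwrite")

    # The next batch introduces a "tags" column, so the schemas no longer match.
    new_schema = pa.schema([
        ("job_id", pa.string()),
        ("captions", pa.list_(pa.string())),
        ("tags", pa.list_(pa.string())),
    ])
    new_batch = pa.table(
        {"job_id": ["a:1"], "captions": [["a dog"]], "tags": [["animal"]]},
        schema=new_schema,
    )

    if set(new_schema.names) != set(dataset.schema.names):
        # Schema evolution: backfill the missing columns on existing rows, then rewrite.
        df = dataset.to_table().to_pandas()
        for name in set(new_schema.names) - set(dataset.schema.names):
            df[name] = None
        backfilled = pa.Table.from_pandas(df, schema=new_schema, preserve_index=False)
        combined = pa.concat_tables([backfilled, new_batch])
        dataset = lance.write_dataset(combined, path, mode="overwrite")
    else:
        # No new columns: a cheap append is enough.
        dataset = lance.write_dataset(new_batch, path, mode="append")

    print(dataset.count_rows())  # 2

Rewriting on evolution keeps the dataset at a single schema, at the cost of re-reading and rewriting all existing rows whenever a new output field first appears.
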
@@ -897,97 +712,89 @@ class StorageManager:
897
712
  columns: Optional[List[str]] = None,
898
713
  include_metadata: bool = True,
899
714
  ) -> StorageContents:
900
- """Retrieve storage contents for export.
901
-
902
- Args:
903
- limit: Maximum number of rows to retrieve
904
- columns: Specific columns to include (None for all)
905
- include_metadata: Whether to include metadata in the result
715
+ """Get storage contents for export using DuckDB."""
716
+ # Flush buffers first
717
+ await self.checkpoint()
906
718
 
907
- Returns:
908
- StorageContents instance with the requested data
909
- """
910
- if not self.captions_path.exists():
719
+ if not self.captions_path.exists() or not self.captions_dataset:
911
720
  return StorageContents(
912
721
  rows=[],
913
722
  columns=[],
914
723
  output_fields=list(self.known_output_fields),
915
724
  total_rows=0,
916
- metadata={"message": "No captions file found"},
725
+ metadata={"message": "No data available"},
917
726
         )
 
-        # Flush buffers first to ensure all data is on disk
-        await self.checkpoint()
-
-        # Determine columns to read
-        if columns:
-            # Validate requested columns exist
-            table_metadata = pq.read_metadata(self.captions_path)
-            available_columns = set(table_metadata.schema.names)
-            invalid_columns = set(columns) - available_columns
-            if invalid_columns:
-                raise ValueError(f"Columns not found: {invalid_columns}")
-            columns_to_read = columns
-        else:
-            # Read all columns
-            columns_to_read = None
-
-        # Read the table
-        table = pq.read_table(self.captions_path, columns=columns_to_read)
-        df = table.to_pandas()
-
-        # Apply limit if specified
-        if limit:
-            df = df.head(limit)
-
-        # Convert to list of dicts
-        rows = df.to_dict("records")
-
-        # Parse metadata JSON strings back to dicts if present
-        if "metadata" in df.columns:
-            for row in rows:
-                if row.get("metadata"):
-                    try:
-                        row["metadata"] = json.loads(row["metadata"])
-                    except:
-                        pass  # Keep as string if parsing fails
-
-        # Prepare metadata
-        metadata = {}
-        if include_metadata:
-            metadata.update(
-                {
-                    "export_timestamp": pd.Timestamp.now().isoformat(),
-                    "total_available_rows": self.stats.get("disk_rows", 0),
+        try:
+            logger.debug("Getting DuckDB connection")
+            con = self.init_duckdb_connection()
+            logger.debug("Got DuckDB connection, building query")
+
+            # Build query
+            column_str = "*"
+            if columns:
+                # Quote column names to handle special characters
+                column_str = ", ".join([f'"{c}"' for c in columns])
+
+            query = f"SELECT {column_str} FROM captions"
+            if limit:
+                query += f" LIMIT {limit}"
+
+            logger.debug(f"Executing DuckDB query: {query}")
+            # Execute query and fetch data
+            table = con.execute(query).fetch_arrow_table()
+            logger.debug(f"Query executed successfully, got {table.num_rows} rows")
+            rows = table.to_pylist()
+            actual_columns = table.schema.names
+
+            # Parse metadata
+            if "metadata" in actual_columns:
+                for row in rows:
+                    if row.get("metadata"):
+                        try:
+                            row["metadata"] = json.loads(row["metadata"])
+                        except (json.JSONDecodeError, TypeError):
+                            pass  # Keep as string if not valid JSON
+
+            metadata = {}
+            if include_metadata:
+                metadata = {
+                    "export_timestamp": datetime.now().isoformat(),
+                    "total_available_rows": self.stats["disk_rows"],
                     "rows_exported": len(rows),
                     "storage_path": str(self.captions_path),
-                    "field_stats": self.stats.get("field_counts", {}),
+                    "field_stats": self.stats["field_counts"],
                 }
-            )
 
-        return StorageContents(
-            rows=rows,
-            columns=list(df.columns),
-            output_fields=list(self.known_output_fields),
-            total_rows=len(df),
-            metadata=metadata,
-        )
+            return StorageContents(
+                rows=rows,
+                columns=actual_columns,
+                output_fields=list(self.known_output_fields),
+                total_rows=len(rows),
+                metadata=metadata,
+            )
+        except Exception as e:
+            logger.error(f"Failed to get storage contents with DuckDB: {e}", exc_info=True)
+            return StorageContents(
+                rows=[],
+                columns=[],
+                output_fields=list(self.known_output_fields),
+                total_rows=0,
+                metadata={"error": str(e)},
+            )
 
     async def get_processed_jobs_for_chunk(self, chunk_id: str) -> Set[str]:
         """Get all processed job_ids for a given chunk."""
-        if not self.captions_path.exists():
+        if not self.captions_dataset:
             return set()
 
-        # Read only job_id and chunk_id columns
-        table = pq.read_table(self.captions_path, columns=["job_id", "chunk_id"])
-        df = table.to_pandas()
-
-        # Filter by chunk_id and return job_ids
-        chunk_jobs = df[df["chunk_id"] == chunk_id]["job_id"].tolist()
-        return set(chunk_jobs)
+        table = self.captions_dataset.to_table(
+            columns=["job_id", "chunk_id"], filter=f"chunk_id = '{chunk_id}'"
+        )
+        return set(table.column("job_id").to_pylist())
 
     async def get_caption_stats(self) -> Dict[str, Any]:
-        """Get statistics about stored captions from cached values."""
+        """Get statistics about stored captions."""
         total_rows = self.stats["disk_rows"] + len(self.caption_buffer)
 
         # Count outputs in buffer
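The rewritten `get_storage_contents` above now routes reads through DuckDB instead of `pyarrow.parquet`. A minimal standalone sketch of that pattern, assuming a Lance dataset at `./caption_data/captions.lance` (the path, the column names, and registering the table by hand are illustrative stand-ins for the `init_duckdb_connection()` helper defined elsewhere in this module):

```python
# Minimal sketch: querying a Lance dataset through DuckDB.
# Path and column names are assumptions for illustration only.
import duckdb
import lance

ds = lance.dataset("./caption_data/captions.lance")

con = duckdb.connect()
# Register the captions as an Arrow table so SQL can refer to them by name.
con.register("captions", ds.to_table())

# Quote requested column names (as the real code does) and push LIMIT into SQL.
requested = ["job_id", "chunk_id"]
column_str = ", ".join(f'"{c}"' for c in requested)
query = f"SELECT {column_str} FROM captions LIMIT 10"

table = con.execute(query).fetch_arrow_table()
rows = table.to_pylist()
print(table.schema.names, len(rows))
```

Registering `ds.to_table()` materializes the dataset in this sketch; the package's own connection helper may expose the data to DuckDB without a full copy.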
@@ -1021,24 +828,20 @@ class StorageManager:
             "total_outputs": total_outputs,
             "output_fields": sorted(list(self.known_output_fields)),
             "field_stats": field_stats,
+            # Compatibility fields for CLI
+            "shard_count": 1,
+            "shards": ["default"],
         }
 
     async def count_captions(self) -> int:
-        """Count total outputs across all dynamic fields from cached values."""
-        # Use cached disk count
-        total = self.stats["disk_outputs"]
-
-        # Add buffer counts
-        for row in self.caption_buffer:
-            for field_name in self.known_output_fields:
-                if field_name in row and isinstance(row[field_name], list):
-                    total += len(row[field_name])
-
-        return total
+        """Count total outputs across all fields."""
+        stats = await self.get_caption_stats()
+        return stats["total_outputs"]
 
     async def count_caption_rows(self) -> int:
-        """Count total rows from cached values."""
-        return self.stats["disk_rows"] + len(self.caption_buffer)
+        """Count total rows."""
+        stats = await self.get_caption_stats()
+        return stats["total_rows"]
 
     async def get_contributor(self, contributor_id: str) -> Optional[Contributor]:
         """Retrieve a contributor by ID."""
@@ -1047,32 +850,35 @@ class StorageManager:
             if buffered["contributor_id"] == contributor_id:
                 return Contributor(**buffered)
 
-        if not self.contributors_path.exists():
+        if not self.contributors_dataset:
             return None
 
-        table = pq.read_table(self.contributors_path)
-        df = table.to_pandas()
+        try:
+            table = self.contributors_dataset.to_table(
+                filter=f"contributor_id = '{contributor_id}'"
+            )
+            if table.num_rows == 0:
+                return None
 
-        row = df[df["contributor_id"] == contributor_id]
-        if row.empty:
+            df = table.to_pandas()
+            row = df.iloc[0]
+            return Contributor(
+                contributor_id=row["contributor_id"],
+                name=row["name"],
+                total_captions=int(row["total_captions"]),
+                trust_level=int(row["trust_level"]),
+            )
+        except Exception as e:
+            logger.error(f"Failed to get contributor {contributor_id}: {e}")
             return None
 
-        return Contributor(
-            contributor_id=row.iloc[0]["contributor_id"],
-            name=row.iloc[0]["name"],
-            total_captions=int(row.iloc[0]["total_captions"]),
-            trust_level=int(row.iloc[0]["trust_level"]),
-        )
-
     async def get_top_contributors(self, limit: int = 10) -> List[Contributor]:
         """Get top contributors by caption count."""
         contributors = []
 
-        if self.contributors_path.exists():
-            table = pq.read_table(self.contributors_path)
+        if self.contributors_dataset:
+            table = self.contributors_dataset.to_table()
             df = table.to_pandas()
-
-            # Sort by total_captions descending
             df = df.sort_values("total_captions", ascending=False).head(limit)
 
             for _, row in df.iterrows():
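`get_contributor` now asks Lance for a single filtered row rather than materializing the whole contributors table. A standalone sketch with an illustrative dataset path and rows; note that the filter string interpolates the id directly, so this pattern assumes ids without embedded quotes:

```python
# Standalone sketch: filtered single-row lookup, mirroring get_contributor.
# The dataset path and contributor rows are illustrative.
import lance
import pyarrow as pa

contributors = pa.table(
    {
        "contributor_id": ["alice", "bob"],
        "name": ["Alice", "Bob"],
        "total_captions": [120, 45],
        "trust_level": [2, 1],
    }
)
ds = lance.write_dataset(contributors, "./example_contributors.lance", mode="overwrite")

table = ds.to_table(filter="contributor_id = 'alice'")
if table.num_rows:
    row = table.to_pandas().iloc[0]
    print(row["name"], int(row["total_captions"]), int(row["trust_level"]))
```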
@@ -1088,18 +894,9 @@ class StorageManager:
         return contributors
 
     async def get_output_field_stats(self) -> Dict[str, Any]:
-        """Get statistics about output fields from cached values."""
-        # Combine disk and buffer stats
-        field_counts = self.stats["field_counts"].copy()
-
-        # Add buffer counts
-        for row in self.caption_buffer:
-            for field_name in self.known_output_fields:
-                if field_name in row and isinstance(row[field_name], list):
-                    if field_name not in field_counts:
-                        field_counts[field_name] = 0
-                    field_counts[field_name] += len(row[field_name])
-
+        """Get statistics about output fields."""
+        stats = await self.get_caption_stats()
+        field_counts = {field: info["total_items"] for field, info in stats["field_stats"].items()}
         total_outputs = sum(field_counts.values())
 
         return {
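`get_output_field_stats` now flattens the nested `field_stats` mapping returned by `get_caption_stats` into plain per-field counts. A tiny sketch of that transformation with made-up numbers; only the `total_items` key is relied on here, the other key is illustrative:

```python
# Sketch: flattening nested field_stats into per-field counts, as get_output_field_stats does.
field_stats = {
    "captions": {"count": 10, "total_items": 30},
    "tags": {"count": 10, "total_items": 90},
}
field_counts = {field: info["total_items"] for field, info in field_stats.items()}
total_outputs = sum(field_counts.values())
print(field_counts, total_outputs)  # {'captions': 30, 'tags': 90} 120
```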
@@ -1113,55 +910,138 @@ class StorageManager:
         """Close storage and flush buffers."""
         await self.checkpoint()
 
-        # Log final rate statistics
-        if self.stats["total_captions_written"] > 0:
-            rates = self._calculate_rates()
-            logger.info(
-                f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
-                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
-                f"Duplicates skipped: {self.stats['duplicates_skipped']} | "
-                f"Final rates - Overall: {rates['overall']:.1f} rows/s, "
-                f"Last hour: {rates['60min']:.1f} rows/s"
-            )
-        else:
-            logger.info(
-                f"Storage closed. Total rows: {self.stats['total_captions_written']}, "
-                f"Total caption entries: {self.stats['total_caption_entries_written']}, "
-                f"Duplicates skipped: {self.stats['duplicates_skipped']}"
-            )
+        rates = self._calculate_rates()
+        logger.info(
+            f"Storage closed. Total rows written: {self.stats['total_captions_written']}, "
+            f"Total outputs: {self.stats['total_caption_entries_written']}, "
+            f"Overall rate: {rates['overall']:.1f} rows/s"
+        )
 
     async def get_storage_stats(self) -> Dict[str, Any]:
-        """Get all storage-related statistics from cached values."""
-        # Count outputs in buffer
-        buffer_outputs = 0
-        for row in self.caption_buffer:
-            for field_name in self.known_output_fields:
-                if field_name in row and isinstance(row[field_name], list):
-                    buffer_outputs += len(row[field_name])
-
-        # Get field-specific stats
-        field_stats = await self.get_caption_stats()
-        total_rows_including_buffer = self.stats["disk_rows"] + len(self.caption_buffer)
-
-        # Calculate rates
+        """Get all storage-related statistics."""
+        caption_stats = await self.get_caption_stats()
         rates = self._calculate_rates()
 
+        # Format field_breakdown to match expected format (dict of dicts with total_items)
+        field_breakdown = {}
+        for field, stats in caption_stats.get("field_stats", {}).items():
+            if isinstance(stats, dict):
+                # Already in correct format
+                field_breakdown[field] = stats
+            else:
+                # Convert simple int to expected format
+                field_breakdown[field] = {"total_items": stats}
+
         return {
-            "total_captions": self.stats["disk_outputs"] + buffer_outputs,
-            "total_rows": total_rows_including_buffer,
+            "total_captions": caption_stats["total_outputs"],
+            "total_rows": caption_stats["total_rows"],
             "buffer_size": len(self.caption_buffer),
             "total_written": self.stats["total_captions_written"],
             "total_entries_written": self.stats["total_caption_entries_written"],
             "duplicates_skipped": self.stats["duplicates_skipped"],
             "total_flushes": self.stats["total_flushes"],
             "output_fields": sorted(list(self.known_output_fields)),
-            "field_breakdown": field_stats.get("field_stats", {}),
+            "field_breakdown": field_breakdown,
             "contributor_buffer_size": len(self.contributor_buffer),
             "rates": {
-                "instant": f"{rates['instant']:.1f} rows/s",
-                "5min": f"{rates['5min']:.1f} rows/s",
-                "15min": f"{rates['15min']:.1f} rows/s",
-                "60min": f"{rates['60min']:.1f} rows/s",
-                "overall": f"{rates['overall']:.1f} rows/s",
+                "instant": f"{rates.get('instant', 0.0):.1f} rows/s",
+                "5min": f"{rates.get('5min', 0.0):.1f} rows/s",
+                "15min": f"{rates.get('15min', 0.0):.1f} rows/s",
+                "60min": f"{rates.get('60min', 0.0):.1f} rows/s",
+                "overall": f"{rates.get('overall', 0.0):.1f} rows/s",
            },
         }
+
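`get_storage_stats` normalizes `field_breakdown` so CLI consumers always see `{"total_items": ...}` dicts, whether the underlying stats arrive as dicts or bare integers. A short sketch of that normalization with sample values:

```python
# Sketch: normalizing field_breakdown entries, as get_storage_stats does.
# Sample values; a bare int is wrapped so consumers always see {"total_items": ...}.
raw_stats = {"captions": {"count": 5, "total_items": 15}, "tags": 40}
field_breakdown = {}
for field, stats in raw_stats.items():
    if isinstance(stats, dict):
        field_breakdown[field] = stats
    else:
        field_breakdown[field] = {"total_items": stats}
print(field_breakdown)  # {'captions': {...}, 'tags': {'total_items': 40}}
```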
+    async def optimize_storage(self):
+        """Optimize storage by compacting Lance dataset."""
+        if self.captions_dataset:
+            logger.info("Optimizing Lance dataset...")
+            self.captions_dataset.optimize.compact_files()
+            self.captions_dataset.cleanup_old_versions()
+            logger.info("Storage optimization complete")
+
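`optimize_storage` wraps two Lance maintenance calls. A standalone sketch of the same calls against an existing dataset (the path is illustrative): `compact_files()` merges small fragments left by incremental appends, and `cleanup_old_versions()` removes dataset versions outside Lance's retention window.

```python
# Standalone sketch of the maintenance calls behind optimize_storage.
# The dataset path is illustrative; both calls require an existing Lance dataset.
import lance

ds = lance.dataset("./caption_data/captions.lance")

# Merge the small fragments left behind by incremental appends.
compaction = ds.optimize.compact_files()
print(compaction)

# Remove old dataset versions no longer needed after compaction.
cleanup = ds.cleanup_old_versions()
print(cleanup)
```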
+    def _is_column_empty(self, df: pd.DataFrame, column_name: str) -> bool:
+        """Check if a column is entirely empty."""
+        if column_name not in df.columns:
+            return True
+
+        col = df[column_name]
+        if col.isna().all():
+            return True
+
+        if pd.api.types.is_numeric_dtype(col):
+            non_null_values = col.dropna()
+            if len(non_null_values) > 0 and (non_null_values == 0).all():
+                return True
+
+        if col.dtype == "object":
+            non_null_values = col.dropna()
+            if len(non_null_values) == 0:
+                return True
+            all_empty_lists = True
+            for val in non_null_values:
+                if isinstance(val, list) and len(val) > 0:
+                    all_empty_lists = False
+                    break
+                elif not isinstance(val, list):
+                    all_empty_lists = False
+                    break
+            if all_empty_lists:
+                return True
+
+        return False
+
+    def _get_non_empty_columns(
+        self, df: pd.DataFrame, preserve_base_fields: bool = True
+    ) -> List[str]:
+        """Get list of columns that contain actual data."""
+        base_field_names = {field[0] for field in self.base_caption_fields}
+        non_empty_columns = []
+
+        for col in df.columns:
+            if preserve_base_fields and col in base_field_names:
+                non_empty_columns.append(col)
+            elif not self._is_column_empty(df, col):
+                non_empty_columns.append(col)
+
+        return non_empty_columns
+
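`_is_column_empty` treats all-NaN, all-zero numeric, and all-empty-list object columns as empty so exports can drop dead columns. A self-contained sketch that mirrors those checks on a toy DataFrame (the helper and column names below are illustrative, not part of the package):

```python
# Sketch: the kinds of columns _is_column_empty is meant to flag, on a toy frame.
import pandas as pd


def is_column_empty(df: pd.DataFrame, column_name: str) -> bool:
    # Mirrors the all-NaN / all-zero / all-empty-list checks from the diff.
    if column_name not in df.columns:
        return True
    col = df[column_name]
    if col.isna().all():
        return True
    if pd.api.types.is_numeric_dtype(col):
        non_null = col.dropna()
        if len(non_null) > 0 and (non_null == 0).all():
            return True
    if col.dtype == "object":
        non_null = col.dropna()
        if len(non_null) == 0:
            return True
        if all(isinstance(v, list) and len(v) == 0 for v in non_null):
            return True
    return False


df = pd.DataFrame(
    {
        "job_id": ["a", "b"],
        "captions": [["a cat"], ["a dog"]],
        "orphaned_field": [[], []],  # all empty lists -> treated as empty
        "unused_score": [0, 0],      # all zeros -> treated as empty
    }
)
print([c for c in df.columns if not is_column_empty(df, c)])  # ['job_id', 'captions']
```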
+    def _get_existing_output_columns(self) -> Set[str]:
+        """Get output field columns that exist - for API compatibility."""
+        return self.known_output_fields.copy()
+
+    # Compatibility methods for LanceStorageExporter
+    @property
+    def shard_datasets(self) -> Dict[str, Any]:
+        """Compatibility property for exporter - returns single default shard."""
+        if self.captions_dataset:
+            return {"default": self.captions_dataset}
+        return {}
+
+    @property
+    def shard_output_fields(self) -> Dict[str, Set[str]]:
+        """Compatibility property for exporter - returns output fields for default shard."""
+        return {"default": self.known_output_fields.copy()}
+
+    async def get_shard_contents(
+        self,
+        shard_name: str,
+        limit: Optional[int] = None,
+        columns: Optional[List[str]] = None,
+        include_metadata: bool = True,
+    ) -> StorageContents:
+        """Compatibility method for exporter - delegates to get_storage_contents for default shard."""
+        if shard_name != "default":
+            return StorageContents(
+                rows=[],
+                columns=[],
+                output_fields=list(self.known_output_fields),
+                total_rows=0,
+                metadata={
+                    "error": f"Shard '{shard_name}' not found. Only 'default' shard is supported."
+                },
+            )
+
+        return await self.get_storage_contents(
+            limit=limit, columns=columns, include_metadata=include_metadata
+        )
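The compatibility surface above keeps the exporter's shard-oriented API working against the single Lance dataset. A hypothetical consumer loop, assuming `storage` is an already-initialized `StorageManager` (constructor arguments omitted; the loop itself is not exporter code):

```python
# Hypothetical consumer of the single-shard compatibility surface.
import asyncio


async def export_all(storage) -> None:
    for shard_name in storage.shard_datasets:  # only "default" exists here
        fields = storage.shard_output_fields[shard_name]
        contents = await storage.get_shard_contents(shard_name, limit=1000)
        print(shard_name, sorted(fields), contents.total_rows)


# asyncio.run(export_all(storage))  # run with a real StorageManager instance
```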