pixeltable 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

@@ -1,5 +1,4 @@
  import base64
- import datetime
  import io
  import json
  import logging
@@ -13,6 +12,7 @@ from uuid import UUID

  import more_itertools
  import numpy as np
+ import pgvector.sqlalchemy as sql_vector # type: ignore[import-untyped]
  import PIL.Image
  import pyarrow as pa
  import pyarrow.parquet as pq
@@ -21,6 +21,7 @@ import sqlalchemy as sql
  import pixeltable as pxt
  from pixeltable import catalog, exceptions as excs, metadata, type_system as ts
  from pixeltable.env import Env
+ from pixeltable.exprs.data_row import CellMd
  from pixeltable.metadata import schema
  from pixeltable.utils import sha256sum
  from pixeltable.utils.formatter import Formatter
@@ -109,9 +110,12 @@ class TablePackager:
  assert any(tv.id == base.id for base in self.table._tbl_version_path.get_tbl_versions())
  sql_types = {col.name: col.type for col in tv.store_tbl.sa_tbl.columns}
  media_cols: set[str] = set()
+ cellmd_cols: set[str] = set()
  for col in tv.cols:
  if col.is_stored and col.col_type.is_media_type():
  media_cols.add(col.store_name())
+ if col.stores_cellmd:
+ cellmd_cols.add(col.cellmd_store_name())

  parquet_schema = self.__to_parquet_schema(tv.store_tbl.sa_tbl)
  # TODO: Partition larger tables into multiple parquet files. (The parquet file naming scheme anticipates
@@ -126,10 +130,10 @@ class TablePackager:
  # excessive memory usage. The pyarrow tables are then amalgamated into the (single) Parquet table on disk.
  # We use snappy compression for the Parquet tables; the entire bundle will be bzip2-compressed later, so
  # faster compression should provide good performance while still reducing temporary storage utilization.
- parquet_writer = pq.ParquetWriter(parquet_file, parquet_schema, compression='SNAPPY')
+ parquet_writer = pq.ParquetWriter(parquet_file, parquet_schema, compression='snappy')
  filter_tv = self.table._tbl_version_path.tbl_version.get()
  row_iter = tv.store_tbl.dump_rows(tv.version, filter_tv.store_tbl, filter_tv.version)
- for pa_table in self.__to_pa_tables(row_iter, sql_types, media_cols, parquet_schema):
+ for pa_table in self.__to_pa_tables(row_iter, sql_types, media_cols, cellmd_cols, parquet_schema):
  parquet_writer.write_table(pa_table)
  parquet_writer.close()

@@ -138,7 +142,7 @@ class TablePackager:
  @classmethod
  def __to_parquet_schema(cls, store_tbl: sql.Table) -> pa.Schema:
  entries = [(col_name, cls.__to_parquet_type(col.type)) for col_name, col in store_tbl.columns.items()]
- return pa.schema(entries) # type: ignore[arg-type]
+ return pa.schema(entries)

  @classmethod
  def __to_parquet_type(cls, col_type: sql.types.TypeEngine[Any]) -> pa.DataType:
@@ -151,13 +155,17 @@ class TablePackager:
  if isinstance(col_type, sql.Float):
  return pa.float32()
  if isinstance(col_type, sql.TIMESTAMP):
- return pa.timestamp('us', tz=datetime.timezone.utc)
+ return pa.timestamp('us', tz='UTC')
  if isinstance(col_type, sql.Date):
  return pa.date32()
  if isinstance(col_type, sql.JSON):
  return pa.string() # JSON will be exported as strings
  if isinstance(col_type, sql.LargeBinary):
  return pa.binary()
+ if isinstance(col_type, sql_vector.Vector):
+ # Parquet/pyarrow do not handle null values properly for fixed_shape_tensor(), so we have to use list_()
+ # here instead.
+ return pa.list_(pa.float32())
  raise AssertionError(f'Unrecognized SQL type: {col_type} (type {type(col_type)})')

  def __to_pa_tables(
@@ -165,6 +173,7 @@ class TablePackager:
  row_iter: Iterator[dict[str, Any]],
  sql_types: dict[str, sql.types.TypeEngine[Any]],
  media_cols: set[str],
+ cellmd_cols: set[str],
  arrow_schema: pa.Schema,
  batch_size: int = 1_000,
  ) -> Iterator[pa.Table]:
@@ -176,14 +185,21 @@ class TablePackager:
  for rows in more_itertools.batched(row_iter, batch_size):
  cols = {}
  for name, sql_type in sql_types.items():
- is_media_col = name in media_cols
- values = [self.__to_pa_value(row.get(name), sql_type, is_media_col) for row in rows]
+ values = [
+ self.__to_pa_value(row.get(name), sql_type, name in media_cols, name in cellmd_cols) for row in rows
+ ]
  cols[name] = values
  yield pa.Table.from_pydict(cols, schema=arrow_schema)

- def __to_pa_value(self, val: Any, sql_type: sql.types.TypeEngine[Any], is_media_col: bool) -> Any:
+ def __to_pa_value(
+ self, val: Any, sql_type: sql.types.TypeEngine[Any], is_media_col: bool, is_cellmd_col: bool
+ ) -> Any:
  if val is None:
  return None
+ if is_cellmd_col:
+ assert isinstance(val, dict)
+ # Export JSON as strings
+ return json.dumps(self.__process_cellmd(val))
  if isinstance(sql_type, sql.JSON):
  # Export JSON as strings
  return json.dumps(val)
@@ -194,6 +210,10 @@ class TablePackager:
  return val

  def __process_media_url(self, url: str) -> str:
+ """
+ Process a media URL for export. If it's a local file URL (file://), then replace it with a pxtmedia:// URI,
+ copying the file into the tarball if necessary. If it's any other type of URL, return it unchanged.
+ """
  parsed_url = urllib.parse.urlparse(url)
  if parsed_url.scheme == 'file':
  # It's the URL of a local file. Replace it with a pxtmedia:// URI.
@@ -214,6 +234,21 @@ class TablePackager:
  # For any type of URL other than a local file, just return the URL as-is.
  return url

+ def __process_cellmd(self, cellmd: dict[str, Any]) -> dict[str, Any]:
+ """
+ Process a cellmd dictionary for export. This involves replacing any local file references
+ with pxtmedia:// URIs, as described above.
+ """
+ cellmd_ = CellMd.from_dict(cellmd)
+ if cellmd_.file_urls is None:
+ return cellmd # No changes
+
+ updated_urls: list[str] = []
+ for url in cellmd_.file_urls:
+ updated_urls.append(self.__process_media_url(url))
+ cellmd_.file_urls = updated_urls
+ return cellmd_.as_dict()
+
  def __build_tarball(self) -> Path:
  bundle_path = self.tmp_dir / 'bundle.tar.bz2'
  with tarfile.open(bundle_path, 'w:bz2') as tf:
@@ -409,6 +444,9 @@ class TableRestorer:
  # 2. "rectify" the v_max values in both the temporary table and the existing table (more on this below);
  # 3. Delete any row instances from the temporary table that are already present in the existing table;
  # 4. Copy the remaining rows from the temporary table into the existing table.
+ # 5. Rectify any index columns.
+
+ # STEP 1: Import the parquet data into a temporary table.

  # Create a temporary table for the initial data load, containing columns for all columns present in the
  # parquet table. The parquet columns have identical names to those in the store table, so we can use the
@@ -416,7 +454,7 @@ class TableRestorer:
  # e.g., pa.string() may hold either VARCHAR or serialized JSONB).
  temp_cols: dict[str, sql.Column] = {}
  for field in parquet_table.schema:
- assert field.name in store_sa_tbl.columns
+ assert field.name in store_sa_tbl.columns, f'{field.name} not in {list(store_sa_tbl.columns)}'
  col_type = store_sa_tbl.columns[field.name].type
  temp_cols[field.name] = sql.Column(field.name, col_type)
  temp_sa_tbl_name = f'temp_{uuid.uuid4().hex}'
@@ -432,6 +470,8 @@ class TableRestorer:
  rows = self.__from_pa_pydict(tv, pydict)
  conn.execute(sql.insert(temp_sa_tbl), rows)

+ # STEP 2: Rectify v_max values.
+
  # Each row version is identified uniquely by its pk, a tuple (row_id, pos_0, pos_1, ..., pos_k, v_min).
  # Conversely, v_max is not part of the primary key, but is simply a bookkeeping device.
  # In an original table, v_max is always equal to the v_min of the succeeding row instance with the same
@@ -540,6 +580,8 @@ class TableRestorer:
  result = conn.execute(q)
  _logger.debug(f'Rectified {result.rowcount} row(s) in {store_sa_tbl_name!r}.')

+ # STEP 3: Delete any row instances from the temporary table that are already present in the existing table.
+
  # Now we need to update rows in the existing table that are also present in the temporary table. This is to
  # account for the scenario where the temporary table has columns that are not present in the existing table.
  # (We can't simply replace the rows with their versions in the temporary table, because the converse scenario
@@ -570,7 +612,9 @@ class TableRestorer:
  result = conn.execute(q)
  _logger.debug(f'Deleted {result.rowcount} row(s) from {temp_sa_tbl_name!r}.')

- # Finally, copy the remaining data (consisting entirely of new row instances) from the temporary table into
+ # STEP 4: Copy the remaining rows from the temporary table into the existing table.
+
+ # Now copy the remaining data (consisting entirely of new row instances) from the temporary table into
  # the actual table.
  q = store_sa_tbl.insert().from_select(
  [store_sa_tbl.c[col_name] for col_name in temp_cols], sql.select(*temp_cols.values())
@@ -579,39 +623,113 @@ class TableRestorer:
  result = conn.execute(q)
  _logger.debug(f'Inserted {result.rowcount} row(s) from {temp_sa_tbl_name!r} into {store_sa_tbl_name!r}.')

+ # STEP 5: Rectify any index columns.
+
+ # Finally, rectify any index columns in the table. This involves shuffling data between the index's val and
+ # undo columns to ensure they appropriately reflect the most recent replicated version of the table.
+
+ # Get the most recent replicated version of the table. This might be the version we're currently importing,
+ # but it might be a different version of the table that was previously imported.
+ head_version_md = catalog.Catalog.get()._collect_tbl_history(tv.id, n=1)[0]
+ head_version = head_version_md.version_md.version
+ _logger.debug(f'Head version for index rectification is {head_version}.')
+
+ # Get the index info from the table metadata. Here we use the tbl_md that we just collected from the DB.
+ # This is to ensure we pick up ALL indices, including dropped indices and indices that are present in
+ # a previously replicated version of the table, but not in the one currently being imported.
+ index_md = head_version_md.tbl_md.index_md
+
+ # Now update the table. We can do this for all indices together with just two SQL queries. For each index,
+ # at most one of the val or undo columns will be non-NULL in any given row.
+ # For rows where v_min <= head_version < v_max, we set, for all indices:
+ # val_col = whichever of (val_col, undo_col) is non-NULL (or NULL if both are, e.g., for a dropped index)
+ # undo_col = NULL
+ # For rows where head_version < v_min or v_max <= head_version, vice versa.
+ val_sql_clauses: dict[str, sql.ColumnElement] = {}
+ undo_sql_clauses: dict[str, sql.ColumnElement] = {}
+ for index in index_md.values():
+ if index.class_fqn.endswith('.EmbeddingIndex'):
+ val_col_name = f'col_{index.index_val_col_id}'
+ undo_col_name = f'col_{index.index_val_undo_col_id}'
+ # Check that the val column for the index is actually present in the store table. We need to do this
+ # to properly handle the case where the replica represents a table version that was *not* the most
+ # recent version at the time it was published. In that case, it is possible for tbl_md to contain
+ # metadata for indices not known to any version that has been replicated. (However, the converse
+ # *does* hold: all replicated indices must have metadata in tbl_md; and that's what's important.)
+ if val_col_name in store_sa_tbl.c:
+ assert undo_col_name in store_sa_tbl.c
+ coalesce = sql.func.coalesce(store_sa_tbl.c[val_col_name], store_sa_tbl.c[undo_col_name])
+ val_sql_clauses[val_col_name] = coalesce
+ val_sql_clauses[undo_col_name] = sql.null()
+ undo_sql_clauses[undo_col_name] = coalesce
+ undo_sql_clauses[val_col_name] = sql.null()
+
+ if len(val_sql_clauses) > 0:
+ q2 = (
+ store_sa_tbl.update()
+ .values(**val_sql_clauses)
+ .where(sql.and_(tv.store_tbl.v_min_col <= head_version, tv.store_tbl.v_max_col > head_version))
+ )
+ _logger.debug(q2.compile())
+ _ = conn.execute(q2)
+ q2 = (
+ store_sa_tbl.update()
+ .values(**undo_sql_clauses)
+ .where(sql.or_(tv.store_tbl.v_min_col > head_version, tv.store_tbl.v_max_col <= head_version))
+ )
+ _logger.debug(q2.compile())
+ _ = conn.execute(q2)
+ _logger.debug(f'Rectified index columns in {store_sa_tbl_name!r}.')
+ else:
+ _logger.debug(f'No index columns to rectify in {store_sa_tbl_name!r}.')
+
  def __from_pa_pydict(self, tv: catalog.TableVersion, pydict: dict[str, Any]) -> list[dict[str, Any]]:
  # Data conversions from pyarrow to Pixeltable
  sql_types: dict[str, sql.types.TypeEngine[Any]] = {}
  for col_name in pydict:
  assert col_name in tv.store_tbl.sa_tbl.columns
  sql_types[col_name] = tv.store_tbl.sa_tbl.columns[col_name].type
- media_cols: dict[str, catalog.Column] = {}
- for col in tv.cols:
- if col.is_stored and col.col_type.is_media_type():
- assert tv.id == col.tbl.id
- assert tv.version == col.tbl.version
- media_cols[col.store_name()] = col
+ stored_cols: dict[str, catalog.Column] = {col.store_name(): col for col in tv.cols if col.is_stored}
+ stored_cols |= {col.cellmd_store_name(): col for col in tv.cols if col.stores_cellmd}

  row_count = len(next(iter(pydict.values())))
- rows: list[dict[str, Any]] = []
- for i in range(row_count):
- row = {
- col_name: self.__from_pa_value(col_vals[i], sql_types[col_name], media_cols.get(col_name))
- for col_name, col_vals in pydict.items()
- }
- rows.append(row)
+ rows: list[dict[str, Any]] = [{} for _ in range(row_count)]
+ for col_name, col_vals in pydict.items():
+ assert len(col_vals) == row_count
+ col = stored_cols.get(col_name) # Will be None for system columns
+ is_media_col = col is not None and col.is_stored and col.col_type.is_media_type()
+ is_cellmd_col = col is not None and col.stores_cellmd and col_name == col.cellmd_store_name()
+ assert col is None or is_cellmd_col or col_name == col.store_name()
+
+ for i, val in enumerate(col_vals):
+ rows[i][col_name] = self.__from_pa_value(val, sql_types[col_name], col, is_media_col, is_cellmd_col)

  return rows

  def __from_pa_value(
- self, val: Any, sql_type: sql.types.TypeEngine[Any], media_col: Optional[catalog.Column]
+ self,
+ val: Any,
+ sql_type: sql.types.TypeEngine[Any],
+ col: Optional[catalog.Column],
+ is_media_col: bool,
+ is_cellmd_col: bool,
  ) -> Any:
  if val is None:
  return None
+ if isinstance(sql_type, sql_vector.Vector):
+ if isinstance(val, list):
+ val = np.array(val, dtype=np.float32)
+ assert isinstance(val, np.ndarray) and val.dtype == np.float32 and val.ndim == 1
+ return val
+ if is_cellmd_col:
+ assert col is not None
+ assert isinstance(val, str)
+ return self.__restore_cellmd(col, json.loads(val))
  if isinstance(sql_type, sql.JSON):
  return json.loads(val)
- if media_col is not None:
- return self.__relocate_media_file(media_col, val)
+ if is_media_col:
+ assert col is not None
+ return self.__relocate_media_file(col, val)
  return val

  def __relocate_media_file(self, media_col: catalog.Column, url: str) -> str:
@@ -629,3 +747,14 @@ class TableRestorer:
  return self.media_files[url]
  # For any type of URL other than a local file, just return the URL as-is.
  return url
+
+ def __restore_cellmd(self, col: catalog.Column, cellmd: dict[str, Any]) -> dict[str, Any]:
+ cellmd_ = CellMd.from_dict(cellmd)
+ if cellmd_.file_urls is None:
+ return cellmd # No changes
+
+ updated_urls: list[str] = []
+ for url in cellmd_.file_urls:
+ updated_urls.append(self.__relocate_media_file(col, url))
+ cellmd_.file_urls = updated_urls
+ return cellmd_.as_dict()
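
Note: the packager changes above serialize pgvector embedding columns as Parquet list_(float32()) values (rather than fixed_shape_tensor(), which mishandles nulls) and restore them as 1-D float32 numpy arrays. A minimal sketch of that round trip, outside the packager classes (the file path and column name are illustrative):

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

# Export side: a nullable embedding column becomes list_(float32()), matching __to_parquet_type().
schema = pa.schema([('embedding', pa.list_(pa.float32()))])
vectors = [np.array([0.1, 0.2, 0.3], dtype=np.float32), None]
tbl = pa.Table.from_pydict({'embedding': [v.tolist() if v is not None else None for v in vectors]}, schema=schema)
pq.write_table(tbl, '/tmp/embedding_demo.parquet', compression='snappy')

# Restore side: list values come back as Python lists (or None) and are coerced to float32 arrays,
# mirroring the sql_vector.Vector branch in __from_pa_value().
restored = pq.read_table('/tmp/embedding_demo.parquet').to_pydict()['embedding']
out = [np.array(v, dtype=np.float32) if v is not None else None for v in restored]
assert out[0].ndim == 1 and out[0].dtype == np.float32 and out[1] is None
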
pixeltable/store.py CHANGED
@@ -111,6 +111,8 @@ class StoreBase:
  idx_name = f'vmax_idx_{tbl_version.id.hex}'
  idxs.append(sql.Index(idx_name, self.v_max_col, postgresql_using=Env.get().dbms.version_index_type))

+ # TODO: Include indices to ensure a completely accurate SA table definition?
+
  self.sa_tbl = sql.Table(self._storage_name(), self.sa_md, *all_cols, *idxs)
  # _logger.debug(f'created sa tbl for {tbl_version.id!s} (sa_tbl={id(self.sa_tbl):x}, tv={id(tbl_version):x})')

@@ -195,15 +197,34 @@ class StoreBase:
  log_stmt(_logger, stmt)
  Env.get().conn.execute(stmt)

- def ensure_columns_exist(self, cols: Iterable[catalog.Column]) -> None:
+ def ensure_updated_schema(self) -> None:
+ from pixeltable.utils.dbms import PostgresqlDbms
+
+ # This should only be called during replica creation where the underlying DBMS is Postgres.
+ assert isinstance(Env.get().dbms, PostgresqlDbms)
+
  conn = Env.get().conn
+ tv = self.tbl_version.get()
+
+ # Ensure columns exist
  sql_text = f'SELECT column_name FROM information_schema.columns WHERE table_name = {self._storage_name()!r}'
  result = conn.execute(sql.text(sql_text))
  existing_cols = {row[0] for row in result}
- for col in cols:
- if col.store_name() not in existing_cols:
+ for col in tv.cols:
+ if col.is_stored and col.store_name() not in existing_cols:
+ _logger.debug(f'Adding missing column {col.store_name()!r} to store table {self._storage_name()!r}')
  self.add_column(col)

+ # Ensure indices exist
+ sql_text = f'SELECT indexname FROM pg_indexes WHERE tablename = {self._storage_name()!r}'
+ result = conn.execute(sql.text(sql_text))
+ existing_idxs = {row[0] for row in result}
+ for idx_name, idx_info in tv.all_idxs.items():
+ store_name = tv._store_idx_name(idx_info.id)
+ if store_name not in existing_idxs:
+ _logger.debug(f'Creating missing index {idx_name!r} on store table {self._storage_name()!r}')
+ idx_info.idx.create_index(store_name, idx_info.val_col)
+
  def load_column(self, col: catalog.Column, exec_plan: ExecNode, abort_on_exc: bool) -> int:
  """Update store column of a computed column with values produced by an execution plan

@@ -434,8 +455,7 @@ class StoreBase:
  *[c1 == c2 for c1, c2 in zip(self.rowid_columns(), filter_view.rowid_columns())],
  )
  stmt = (
- sql.select('*') # TODO: Use a more specific list of columns?
- .select_from(self.sa_tbl)
+ sql.select(self.sa_tbl)
  .where(self.v_min_col <= version)
  .where(self.v_max_col > version)
  .where(sql.exists().where(filter_predicate))
pixeltable/utils/arrow.py CHANGED
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
  PA_TO_PXT_TYPES: dict[pa.DataType, ts.ColumnType] = {
  pa.string(): ts.StringType(nullable=True),
  pa.large_string(): ts.StringType(nullable=True),
- pa.timestamp('us', tz=datetime.timezone.utc): ts.TimestampType(nullable=True),
+ pa.timestamp('us', tz='UTC'): ts.TimestampType(nullable=True),
  pa.bool_(): ts.BoolType(nullable=True),
  pa.int8(): ts.IntType(nullable=True),
  pa.int16(): ts.IntType(nullable=True),
@@ -35,7 +35,7 @@ PA_TO_PXT_TYPES: dict[pa.DataType, ts.ColumnType] = {

  PXT_TO_PA_TYPES: dict[type[ts.ColumnType], pa.DataType] = {
  ts.StringType: pa.string(),
- ts.TimestampType: pa.timestamp('us', tz=datetime.timezone.utc), # postgres timestamp is microseconds
+ ts.TimestampType: pa.timestamp('us', tz='UTC'), # postgres timestamp is microseconds
  ts.DateType: pa.date32(), # This could be date64
  ts.BoolType: pa.bool_(),
  ts.IntType: pa.int64(),
@@ -61,7 +61,7 @@ def to_pixeltable_type(arrow_type: pa.DataType, nullable: bool) -> Optional[ts.C
  dtype = to_pixeltable_type(arrow_type.value_type, nullable)
  if dtype is None:
  return None
- return ts.ArrayType(shape=arrow_type.shape, dtype=dtype, nullable=nullable)
+ return ts.ArrayType(shape=tuple(arrow_type.shape), dtype=dtype, nullable=nullable)
  else:
  return None

@@ -92,7 +92,7 @@ def to_pxt_schema(


  def to_arrow_schema(pixeltable_schema: dict[str, Any]) -> pa.Schema:
- return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items()) # type: ignore[misc]
+ return pa.schema((name, to_arrow_type(typ)) for name, typ in pixeltable_schema.items())


  def _to_record_batch(column_vals: dict[str, list[Any]], schema: pa.Schema) -> pa.RecordBatch:
@@ -106,7 +106,7 @@ def _to_record_batch(column_vals: dict[str, list[Any]], schema: pa.Schema) -> pa
  else:
  pa_array = cast(pa.Array, pa.array(column_vals[field.name]))
  pa_arrays.append(pa_array)
- return pa.RecordBatch.from_arrays(pa_arrays, schema=schema) # type: ignore
+ return pa.RecordBatch.from_arrays(pa_arrays, schema=schema)


  def to_record_batches(df: 'pxt.DataFrame', batch_size_bytes: int) -> Iterator[pa.RecordBatch]:
@@ -192,7 +192,7 @@ def to_pydict(batch: pa.Table | pa.RecordBatch) -> dict[str, list | np.ndarray]:
  col = batch.column(k)
  if isinstance(col.type, pa.FixedShapeTensorType):
  # treat array columns as numpy arrays to easily preserve numpy type
- out[name] = col.to_numpy(zero_copy_only=False) # type: ignore[call-arg]
+ out[name] = col.to_numpy(zero_copy_only=False)
  else:
  # for the rest, use pydict to preserve python types
  out[name] = col.to_pylist()
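
Note: the arrow.py changes replace tz=datetime.timezone.utc with the string 'UTC' (which also lets the packager drop its datetime import). Assuming pyarrow's usual normalization of tzinfo objects to their string form, both spellings yield the same Arrow type, so the type-mapping dict keys are unaffected; a quick check:

import datetime
import pyarrow as pa

t_str = pa.timestamp('us', tz='UTC')
t_obj = pa.timestamp('us', tz=datetime.timezone.utc)
# Expected to print: timestamp[us, tz=UTC] timestamp[us, tz=UTC] True
print(t_str, t_obj, t_str == t_obj)
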
pixeltable/utils/av.py CHANGED
@@ -3,6 +3,8 @@ from typing import Any
  import av
  import av.stream

+ from pixeltable.env import Env
+

  def get_metadata(path: str) -> dict:
  with av.open(path) as container:
@@ -89,23 +91,114 @@ def has_audio_stream(path: str) -> bool:
  return any(stream['type'] == 'audio' for stream in md['streams'])


- def ffmpeg_clip_cmd(input_path: str, output_path: str, start_time: float, duration: float | None = None) -> list[str]:
- # the order of arguments is critical: -ss <start> -t <duration> -i <input>
- cmd = ['ffmpeg', '-ss', str(start_time)]
+ def ffmpeg_clip_cmd(
+ input_path: str,
+ output_path: str,
+ start_time: float,
+ duration: float | None = None,
+ fast: bool = True,
+ video_encoder: str | None = None,
+ video_encoder_args: dict[str, Any] | None = None,
+ ) -> list[str]:
+ cmd = ['ffmpeg']
+ if fast:
+ # fast: -ss before -i
+ cmd.extend(
+ [
+ '-ss',
+ str(start_time),
+ '-i',
+ input_path,
+ '-map',
+ '0', # Copy all streams from input
+ '-c',
+ 'copy', # Stream copy (no re-encoding)
+ ]
+ )
+ else:
+ if video_encoder is None:
+ video_encoder = Env.get().default_video_encoder
+
+ # accurate: -ss after -i
+ cmd.extend(
+ [
+ '-i',
+ input_path,
+ '-ss',
+ str(start_time),
+ '-map',
+ '0', # Copy all streams from input
+ '-c:a',
+ 'copy', # audio copy
+ '-c:s',
+ 'copy', # subtitle copy
+ '-c:v',
+ video_encoder, # re-encode video
+ ]
+ )
+ if video_encoder_args is not None:
+ for k, v in video_encoder_args.items():
+ cmd.extend([f'-{k}', str(v)])
+
  if duration is not None:
  cmd.extend(['-t', str(duration)])
+ cmd.extend(['-loglevel', 'error', output_path])
+ return cmd
+
+
+ def ffmpeg_segment_cmd(
+ input_path: str,
+ output_pattern: str,
+ segment_duration: float | None = None,
+ segment_times: list[float] | None = None,
+ video_encoder: str | None = None,
+ video_encoder_args: dict[str, Any] | None = None,
+ ) -> list[str]:
+ """Commandline for frame-accurate segmentation"""
+ assert (segment_duration is None) != (segment_times is None)
+ if video_encoder is None:
+ video_encoder = Env.get().default_video_encoder
+
+ cmd = [
+ 'ffmpeg',
+ '-i',
+ input_path,
+ '-map',
+ '0', # Copy all streams from input
+ '-c:a',
+ 'copy', # don't re-encode audio
+ '-c:v',
+ video_encoder, # re-encode video
+ ]
+ if video_encoder_args is not None:
+ for k, v in video_encoder_args.items():
+ cmd.extend([f'-{k}', str(v)])
+ cmd.extend(['-f', 'segment'])
+
+ # -force_key_frames needs to precede -f segment
+ if segment_duration is not None:
+ cmd.extend(
+ [
+ '-force_key_frames',
+ f'expr:gte(t,n_forced*{segment_duration})', # Force keyframe at each segment boundary
+ '-f',
+ 'segment',
+ '-segment_time',
+ str(segment_duration),
+ ]
+ )
+ else:
+ assert segment_times is not None
+ times_str = ','.join([str(t) for t in segment_times])
+ cmd.extend(['-force_key_frames', times_str, '-f', 'segment', '-segment_times', times_str])
+
  cmd.extend(
  [
- '-i', # Input file
- input_path,
- '-y', # Overwrite output file
+ '-reset_timestamps',
+ '1', # Reset timestamps for each segment
  '-loglevel',
  'error', # Only show errors
- '-c',
- 'copy', # Stream copy (no re-encoding)
- '-map',
- '0', # Copy all streams from input
- output_path,
+ output_pattern,
  ]
  )
  return cmd
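
Note: ffmpeg_clip_cmd now builds two flavors of command: fast (seek before the input and stream-copy, so cuts land on keyframes) and accurate (seek after the input and re-encode video for frame-exact cuts). A usage sketch, assuming the module path pixeltable.utils.av and placeholder file names; passing video_encoder explicitly sidesteps the Env.get().default_video_encoder lookup:

import subprocess
from pixeltable.utils.av import ffmpeg_clip_cmd

# Fast, keyframe-aligned clip: -ss precedes -i and all streams are copied.
fast_cmd = ffmpeg_clip_cmd('input.mp4', 'clip_fast.mp4', start_time=12.5, duration=3.0)

# Frame-accurate clip: -ss follows -i and video is re-encoded (audio/subtitles copied).
exact_cmd = ffmpeg_clip_cmd(
    'input.mp4', 'clip_exact.mp4', start_time=12.5, duration=3.0,
    fast=False, video_encoder='libx264', video_encoder_args={'crf': 18},
)

subprocess.run(fast_cmd, check=True)  # requires ffmpeg on PATH
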
@@ -22,6 +22,7 @@ class StorageTarget(enum.Enum):
  LOCAL_STORE = 'os' # Local file system
  S3_STORE = 's3' # Amazon S3
  R2_STORE = 'r2' # Cloudflare R2
+ B2_STORE = 'b2' # Backblaze B2
  GCS_STORE = 'gs' # Google Cloud Storage
  AZURE_STORE = 'az' # Azure Blob Storage
  HTTP_STORE = 'http' # HTTP/HTTPS
@@ -63,6 +64,7 @@ class StorageObjectAddress(NamedTuple):
  StorageTarget.LOCAL_STORE,
  StorageTarget.S3_STORE,
  StorageTarget.R2_STORE,
+ StorageTarget.B2_STORE,
  StorageTarget.GCS_STORE,
  StorageTarget.AZURE_STORE,
  StorageTarget.HTTP_STORE,
@@ -218,15 +220,23 @@ class ObjectPath:
  # Standard HTTP(S) URL format
  # https://account.blob.core.windows.net/container/<optional path>/<optional object>
  # https://account.r2.cloudflarestorage.com/container/<optional path>/<optional object>
+ # https://s3.us-west-004.backblazeb2.com/container/<optional path>/<optional object>
  # and possibly others
  key = parsed.path
  if 'cloudflare' in parsed.netloc:
  storage_target = StorageTarget.R2_STORE
+ elif 'backblazeb2' in parsed.netloc:
+ storage_target = StorageTarget.B2_STORE
  elif 'windows' in parsed.netloc:
  storage_target = StorageTarget.AZURE_STORE
  else:
  storage_target = StorageTarget.HTTP_STORE
- if storage_target in [StorageTarget.S3_STORE, StorageTarget.AZURE_STORE, StorageTarget.R2_STORE]:
+ if storage_target in (
+ StorageTarget.S3_STORE,
+ StorageTarget.AZURE_STORE,
+ StorageTarget.R2_STORE,
+ StorageTarget.B2_STORE,
+ ):
  account_name = parsed.netloc.split('.', 1)[0]
  account_extension = parsed.netloc.split('.', 1)[1]
  path_parts = key.lstrip('/').split('/', 1)
@@ -370,6 +380,11 @@ class ObjectOps:
  env.Env.get().require_package('boto3')
  from pixeltable.utils.s3_store import S3Store

+ return S3Store(soa)
+ if soa.storage_target == StorageTarget.B2_STORE:
+ env.Env.get().require_package('boto3')
+ from pixeltable.utils.s3_store import S3Store
+
  return S3Store(soa)
  if soa.storage_target == StorageTarget.GCS_STORE and soa.scheme == 'gs':
  env.Env.get().require_package('google.cloud.storage')
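
Note: Backblaze B2 is now recognized alongside S3/R2/Azure, keyed off a 'backblazeb2' host, and routed through the boto3-backed S3Store (B2 exposes an S3-compatible API). A minimal sketch of just the host-based classification (the helper name is illustrative; the real parsing in ObjectPath also splits out the account, container, and key):

import urllib.parse

def classify_https_store(url: str) -> str:
    netloc = urllib.parse.urlparse(url).netloc
    if 'cloudflare' in netloc:
        return 'r2'
    if 'backblazeb2' in netloc:
        return 'b2'  # handled by the boto3-based S3Store, like S3 and R2
    if 'windows' in netloc:
        return 'az'
    return 'http'

print(classify_https_store('https://s3.us-west-004.backblazeb2.com/my-bucket/media/clip.mp4'))  # 'b2'
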