pixeltable 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.
Files changed (119)
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/store.py ADDED
@@ -0,0 +1,424 @@
+ from __future__ import annotations
+
+ import abc
+ import logging
+ import os
+ import sys
+ import urllib.parse
+ import urllib.request
+ import warnings
+ from typing import Optional, Dict, Any, List, Tuple, Set
+
+ import sqlalchemy as sql
+ from tqdm import tqdm, TqdmWarning
+
+ import pixeltable.catalog as catalog
+ import pixeltable.env as env
+ from pixeltable import exprs
+ import pixeltable.exceptions as excs
+ from pixeltable.exec import ExecNode
+ from pixeltable.metadata import schema
+ from pixeltable.type_system import StringType
+ from pixeltable.utils.media_store import MediaStore
+ from pixeltable.utils.sql import log_stmt, log_explain
+
+ _logger = logging.getLogger('pixeltable')
+
+
+ class StoreBase:
+     """Base class for stored tables
+
+     Each row has the following system columns:
+     - rowid columns: one or more columns that identify a user-visible row across all versions
+     - v_min: version at which the row was created
+     - v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
+     """
+
+     def __init__(self, tbl_version: catalog.TableVersion):
+         self.tbl_version = tbl_version
+         self.sa_md = sql.MetaData()
+         self.sa_tbl: Optional[sql.Table] = None
+         self.create_sa_tbl()
+
+     def pk_columns(self) -> List[sql.Column]:
+         return self._pk_columns
+
+     def rowid_columns(self) -> List[sql.Column]:
+         return self._pk_columns[:-1]
+
+     @abc.abstractmethod
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         """Create and return rowid columns"""
+         pass
+
+     def _create_system_columns(self) -> List[sql.Column]:
+         """Create and return the system columns (rowid columns plus v_min/v_max)"""
+         rowid_cols = self._create_rowid_columns()
+         self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
+         self.v_max_col = \
+             sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
+         self._pk_columns = [*rowid_cols, self.v_min_col]
+         return [*rowid_cols, self.v_min_col, self.v_max_col]
+
+     def create_sa_tbl(self) -> None:
+         """Create self.sa_tbl from self.tbl_version."""
+         system_cols = self._create_system_columns()
+         all_cols = system_cols.copy()
+         idxs: List[sql.Index] = []
+         for col in [c for c in self.tbl_version.cols if c.is_stored]:
+             # re-create sql.Column for each column, regardless of whether it already has sa_col set: it was bound
+             # to the last sql.Table version we created and cannot be reused
+             col.create_sa_cols()
+             all_cols.append(col.sa_col)
+             if col.records_errors:
+                 all_cols.append(col.sa_errormsg_col)
+                 all_cols.append(col.sa_errortype_col)
+
+             # we create an index for:
+             # - scalar columns (except for strings, because long strings can't be used for B-tree indices)
+             # - non-computed video and image columns (they will contain external paths/urls that users might want
+             #   to filter on)
+             if (col.col_type.is_scalar_type() and not col.col_type.is_string_type()) \
+                     or (col.col_type.is_media_type() and not col.is_computed):
+                 # index names need to be unique within the Postgres instance
+                 idx_name = f'idx_{col.id}_{self.tbl_version.id.hex}'
+                 idxs.append(sql.Index(idx_name, col.sa_col))
+
+         if self.sa_tbl is not None:
+             # if we're called in response to a schema change, we need to remove the old table first
+             self.sa_md.remove(self.sa_tbl)
+
+         # index for all system columns:
+         # - base x view joins can be executed as merge joins
+         # - speeds up ORDER BY rowid DESC
+         # - allows filtering for a particular table version in an index scan
+         idx_name = f'sys_cols_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, *system_cols))
+         # v_min/v_max indices: speed up the base table scans needed to propagate a base table insert or delete
+         idx_name = f'vmin_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, self.v_min_col, postgresql_using='brin'))
+         idx_name = f'vmax_idx_{self.tbl_version.id.hex}'
+         idxs.append(sql.Index(idx_name, self.v_max_col, postgresql_using='brin'))
+
+         self.sa_tbl = sql.Table(self._storage_name(), self.sa_md, *all_cols, *idxs)
+
+     @abc.abstractmethod
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         """Return predicate for rowid joins to all bases"""
+         pass
+
+     @abc.abstractmethod
+     def _storage_name(self) -> str:
+         """Return the name of the data store table"""
+         pass
+
+     def _move_tmp_media_file(self, file_url: Optional[str], col: catalog.Column, v_min: int) -> Optional[str]:
+         """Move a tmp media file with the given url to Env.media_dir and return the new url;
+         returns the given url unchanged if it doesn't reference a tmp_dir file."""
+         pxt_tmp_dir = str(env.Env.get().tmp_dir)
+         if file_url is None:
+             return None
+         parsed = urllib.parse.urlparse(file_url)
+         # We should never be passed a local file path here. The "len > 1" ensures that Windows
+         # file paths aren't mistaken for URLs with a single-character scheme.
+         assert len(parsed.scheme) > 1
+         if parsed.scheme != 'file':
+             # remote url
+             return file_url
+         file_path = urllib.parse.unquote(urllib.request.url2pathname(parsed.path))
+         if not file_path.startswith(pxt_tmp_dir):
+             # not a tmp file
+             return file_url
+         _, ext = os.path.splitext(file_path)
+         new_path = str(MediaStore.prepare_media_path(self.tbl_version.id, col.id, v_min, ext=ext))
+         os.rename(file_path, new_path)
+         new_file_url = urllib.parse.urljoin('file:', urllib.request.pathname2url(new_path))
+         return new_file_url
+
+     def _move_tmp_media_files(
+             self, table_rows: List[Dict[str, Any]], media_cols: List[catalog.Column], v_min: int
+     ) -> None:
+         """Move tmp media files that we generated to a permanent location"""
+         for c in media_cols:
+             for table_row in table_rows:
+                 file_url = table_row[c.store_name()]
+                 table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)
+
+     def _create_table_row(
+             self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
+             exc_col_ids: Set[int], v_min: int
+     ) -> Tuple[Dict[str, Any], int]:
+         """Create a complete table row for insert(), including the PK columns (populated from input_row.pk).
+         Returns:
+             Tuple[complete table row, # of exceptions]
+         """
+         table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
+
+         assert input_row.pk is not None and len(input_row.pk) == len(self._pk_columns)
+         for pk_col, pk_val in zip(self._pk_columns, input_row.pk):
+             if pk_col == self.v_min_col:
+                 table_row[pk_col.name] = v_min
+             else:
+                 table_row[pk_col.name] = pk_val
+
+         return table_row, num_excs
+
+     def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
+         """Return the number of rows visible in self.tbl_version"""
+         stmt = sql.select(sql.func.count('*'))\
+             .select_from(self.sa_tbl)\
+             .where(self.v_min_col <= self.tbl_version.version)\
+             .where(self.v_max_col > self.tbl_version.version)
+         if conn is None:
+             with env.Env.get().engine.connect() as conn:
+                 result = conn.execute(stmt).scalar_one()
+         else:
+             result = conn.execute(stmt).scalar_one()
+         assert isinstance(result, int)
+         return result
+
+     def create(self, conn: sql.engine.Connection) -> None:
+         self.sa_md.create_all(bind=conn)
+
+     def drop(self, conn: sql.engine.Connection) -> None:
+         """Drop store table"""
+         self.sa_md.drop_all(bind=conn)
+
+     def add_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
+         """Add column(s) to the store-resident table based on a catalog column
+
+         Note that a catalog column that records errors requires two extra columns (for the error message and the
+         error type).
+         """
+         assert col.is_stored
+         col_type_str = col.get_sa_col_type().compile(dialect=conn.dialect)
+         stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.store_name()} {col_type_str} NULL')
+         log_stmt(_logger, stmt)
+         conn.execute(stmt)
+         added_storage_cols = [col.store_name()]
+         if col.records_errors:
+             # we also need to create the errormsg and errortype storage cols
+             stmt = (f'ALTER TABLE {self._storage_name()} '
+                     f'ADD COLUMN {col.errormsg_store_name()} VARCHAR DEFAULT NULL')
+             conn.execute(sql.text(stmt))
+             stmt = (f'ALTER TABLE {self._storage_name()} '
+                     f'ADD COLUMN {col.errortype_store_name()} VARCHAR DEFAULT NULL')
+             conn.execute(sql.text(stmt))
+             added_storage_cols.extend([col.errormsg_store_name(), col.errortype_store_name()])
+         self.create_sa_tbl()
+         _logger.info(f'Added columns {added_storage_cols} to storage table {self._storage_name()}')
+
+     def drop_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
+         """Execute ALTER TABLE ... DROP COLUMN statement(s)"""
+         stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.store_name()}'
+         conn.execute(sql.text(stmt))
+         if col.records_errors:
+             stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_store_name()}'
+             conn.execute(sql.text(stmt))
+             stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_store_name()}'
+             conn.execute(sql.text(stmt))
+
+     def load_column(
+             self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, conn: sql.engine.Connection
+     ) -> int:
+         """Update the store column of a computed column with values produced by an execution plan
+
+         Returns:
+             number of rows with exceptions
+         Raises:
+             sql.exc.DBAPIError if there was an error during SQL execution
+         """
+         num_excs = 0
+         num_rows = 0
+         for row_batch in exec_plan:
+             num_rows += len(row_batch)
+             for result_row in row_batch:
+                 values_dict: Dict[sql.Column, Any] = {}
+
+                 if col.is_computed:
+                     if result_row.has_exc(value_expr_slot_idx):
+                         num_excs += 1
+                         value_exc = result_row.get_exc(value_expr_slot_idx)
+                         # we store a NULL value and record the exception message and type
+                         error_type = type(value_exc).__name__
+                         error_msg = str(value_exc)
+                         values_dict = {
+                             col.sa_col: None,
+                             col.sa_errortype_col: error_type,
+                             col.sa_errormsg_col: error_msg
+                         }
+                     else:
+                         val = result_row.get_stored_val(value_expr_slot_idx, col.sa_col.type)
+                         if col.col_type.is_media_type():
+                             val = self._move_tmp_media_file(val, col, result_row.pk[-1])
+                         values_dict = {col.sa_col: val}
+
+                 update_stmt = sql.update(self.sa_tbl).values(values_dict)
+                 for pk_col, pk_val in zip(self.pk_columns(), result_row.pk):
+                     update_stmt = update_stmt.where(pk_col == pk_val)
+                 log_stmt(_logger, update_stmt)
+                 conn.execute(update_stmt)
+
+         return num_excs
+
+     def insert_rows(
+             self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None
+     ) -> Tuple[int, int, Set[int]]:
+         """Insert rows into the store table and update the catalog table's md
+         Returns:
+             number of inserted rows, number of exceptions, set of column ids that have exceptions
+         """
+         assert v_min is not None
+         exec_plan.ctx.conn = conn
+         batch_size = 16  # TODO: is this a good batch size?
+         # TODO: total?
+         num_excs = 0
+         num_rows = 0
+         cols_with_excs: Set[int] = set()
+         progress_bar: Optional[tqdm] = None  # create this only after we started executing
+         row_builder = exec_plan.row_builder
+         media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
+         try:
+             exec_plan.open()
+             for row_batch in exec_plan:
+                 num_rows += len(row_batch)
+                 for batch_start_idx in range(0, len(row_batch), batch_size):
+                     # compute batch of rows and convert them into table rows
+                     table_rows: List[Dict[str, Any]] = []
+                     for row_idx in range(batch_start_idx, min(batch_start_idx + batch_size, len(row_batch))):
+                         row = row_batch[row_idx]
+                         table_row, num_row_exc = \
+                             self._create_table_row(row, row_builder, media_cols, cols_with_excs, v_min=v_min)
+                         num_excs += num_row_exc
+                         table_rows.append(table_row)
+                         if progress_bar is None:
+                             warnings.simplefilter("ignore", category=TqdmWarning)
+                             progress_bar = tqdm(
+                                 desc=f'Inserting rows into `{self.tbl_version.name}`',
+                                 unit=' rows',
+                                 ncols=100,
+                                 file=sys.stdout
+                             )
+                         progress_bar.update(1)
+                     self._move_tmp_media_files(table_rows, media_cols, v_min)
+                     conn.execute(sql.insert(self.sa_tbl), table_rows)
+             if progress_bar is not None:
+                 progress_bar.close()
+             return num_rows, num_excs, cols_with_excs
+         finally:
+             exec_plan.close()
+
+     def _versions_clause(self, versions: List[Optional[int]], match_on_vmin: bool) -> sql.ClauseElement:
+         """Return filter for base versions"""
+         v = versions[0]
+         if v is None:
+             # we're looking at live rows
+             clause = sql.and_(self.v_min_col <= self.tbl_version.version, self.v_max_col == schema.Table.MAX_VERSION)
+         else:
+             # we're looking at a specific version
+             clause = self.v_min_col == v if match_on_vmin else self.v_max_col == v
+         if len(versions) == 1:
+             return clause
+         return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))
+
+     def delete_rows(
+             self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
+             where_clause: Optional[sql.ClauseElement], conn: sql.engine.Connection) -> int:
+         """Mark rows as deleted that are live and were created prior to current_version.
+         Also populates the undo columns.
+         Args:
+             base_versions: per-base version spec; a non-None entry joins only to base rows that were created at
+                 that version, a None entry joins to rows that are live in the base's current version (which is
+                 distinct from the current_version parameter)
+             match_on_vmin: if True, match exact versions on v_min; if False, match on v_max
+             where_clause: if not None, also apply where_clause
+         Returns:
+             number of deleted rows
+         """
+         where_clause = sql.true() if where_clause is None else where_clause
+         where_clause = sql.and_(
+             self.v_min_col < current_version,
+             self.v_max_col == schema.Table.MAX_VERSION,
+             where_clause)
+         rowid_join_clause = self._rowid_join_predicate()
+         base_versions_clause = sql.true() if len(base_versions) == 0 \
+             else self.base._versions_clause(base_versions, match_on_vmin)
+         set_clause = {self.v_max_col: current_version}
+         for index_info in self.tbl_version.idxs_by_name.values():
+             # copy the value column to the undo column
+             set_clause[index_info.undo_col.sa_col] = index_info.val_col.sa_col
+             # set the value column to NULL
+             set_clause[index_info.val_col.sa_col] = None
+         stmt = sql.update(self.sa_tbl) \
+             .values(set_clause) \
+             .where(where_clause) \
+             .where(rowid_join_clause) \
+             .where(base_versions_clause)
+         log_explain(_logger, stmt, conn)
+         status = conn.execute(stmt)
+         return status.rowcount
+
+
+ class StoreTable(StoreBase):
+     def __init__(self, tbl_version: catalog.TableVersion):
+         assert not tbl_version.is_view()
+         super().__init__(tbl_version)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
+         return [self.rowid_col]
+
+     def _storage_name(self) -> str:
+         return f'tbl_{self.tbl_version.id.hex}'
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.true()
+
+
+ class StoreView(StoreBase):
+     def __init__(self, catalog_view: catalog.TableVersion):
+         assert catalog_view.is_view()
+         self.base = catalog_view.base.store_tbl
+         super().__init__(catalog_view)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         # a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
+         self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+         return self.rowid_cols
+
+     def _storage_name(self) -> str:
+         return f'view_{self.tbl_version.id.hex}'
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.and_(
+             self.base._rowid_join_predicate(),
+             *[c1 == c2 for c1, c2 in zip(self.rowid_columns(), self.base.rowid_columns())])
+
+
+ class StoreComponentView(StoreView):
+     """A view that stores components of its base, as produced by a ComponentIterator
+
+     PK: now also includes pos, the position returned by the ComponentIterator for the base row identified by
+     base_rowid
+     """
+     def __init__(self, catalog_view: catalog.TableVersion):
+         super().__init__(catalog_view)
+
+     def _create_rowid_columns(self) -> List[sql.Column]:
+         # each base row is expanded into n view rows
+         self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
+         # name of pos column: avoid collisions with bases' pos columns
+         self.pos_col = sql.Column(f'pos_{len(self.rowid_cols) - 1}', sql.BigInteger, nullable=False)
+         self.pos_col_idx = len(self.rowid_cols)
+         self.rowid_cols.append(self.pos_col)
+         return self.rowid_cols
+
+     def create_sa_tbl(self) -> None:
+         super().create_sa_tbl()
+         # we need to fix up the 'pos' column in TableVersion
+         self.tbl_version.cols_by_name['pos'].sa_col = self.pos_col
+
+     def _rowid_join_predicate(self) -> sql.ClauseElement:
+         return sql.and_(
+             self.base._rowid_join_predicate(),
+             *[c1 == c2 for c1, c2 in zip(self.rowid_columns()[:-1], self.base.rowid_columns())])
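
The v_min/v_max scheme above gives store tables snapshot semantics: a row is visible at table version V exactly when v_min <= V < v_max, and deleting a row only sets v_max (see delete_rows). A minimal standalone sketch of that visibility predicate in plain SQLAlchemy; the toy table and the MAX_VERSION constant are invented for illustration and are not part of the package:

import sqlalchemy as sql

MAX_VERSION = 2 ** 63 - 1  # stand-in for schema.Table.MAX_VERSION; illustrative only

md = sql.MetaData()
demo = sql.Table(
    'demo_tbl', md,
    sql.Column('rowid', sql.BigInteger, nullable=False),
    sql.Column('v_min', sql.BigInteger, nullable=False),
    sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(MAX_VERSION)),
)

def visible_at(version: int):
    # Same shape as the filter in StoreBase.count(): rows created at or before
    # `version` and not yet deleted at that version.
    return sql.select(demo).where(demo.c.v_min <= version, demo.c.v_max > version)
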
pixeltable/tool/create_test_db_dump.py ADDED
@@ -0,0 +1,184 @@
+ import datetime
+ import json
+ import logging
+ import os
+ import pathlib
+ import subprocess
+
+ import pgserver
+ import toml
+
+ import pixeltable as pxt
+ import pixeltable.metadata as metadata
+ from pixeltable.env import Env
+ from pixeltable.func import Batch
+ from pixeltable.type_system import \
+     StringType, IntType, FloatType, BoolType, TimestampType, JsonType
+
+ _logger = logging.getLogger('pixeltable')
+
+
+ class Dumper:
+
+     def __init__(self, output_dir='target', db_name='pxtdump') -> None:
+         self.output_dir = pathlib.Path(output_dir)
+         shared_home = pathlib.Path(os.environ.get('PIXELTABLE_HOME', '~/.pixeltable')).expanduser()
+         mock_home_dir = self.output_dir / '.pixeltable'
+         mock_home_dir.mkdir(parents=True, exist_ok=True)
+         os.environ['PIXELTABLE_HOME'] = str(mock_home_dir)
+         os.environ['PIXELTABLE_CONFIG'] = str(shared_home / 'config.yaml')
+         os.environ['PIXELTABLE_DB'] = db_name
+         os.environ['PIXELTABLE_PGDATA'] = str(shared_home / 'pgdata')
+
+         Env.get().configure_logging(level=logging.DEBUG, to_stdout=True)
+
+ def dump_db(self) -> None:
36
+ md_version = metadata.VERSION
37
+ dump_file = self.output_dir / f'pixeltable-v{md_version:03d}-test.dump.gz'
38
+ _logger.info(f'Creating database dump at: {dump_file}')
39
+ pg_package_dir = os.path.dirname(pgserver.__file__)
40
+ pg_dump_binary = f'{pg_package_dir}/pginstall/bin/pg_dump'
41
+ _logger.info(f'Using pg_dump binary at: {pg_dump_binary}')
42
+ with open(dump_file, 'wb') as dump:
43
+ pg_dump_process = subprocess.Popen(
44
+ [pg_dump_binary, Env.get().db_url, '-U', 'postgres', '-Fc'],
45
+ stdout=subprocess.PIPE
46
+ )
47
+ subprocess.run(
48
+ ["gzip", "-9"],
49
+ stdin=pg_dump_process.stdout,
50
+ stdout=dump,
51
+ check=True
52
+ )
53
+ info_file = self.output_dir / f'pixeltable-v{md_version:03d}-test-info.toml'
54
+ git_sha = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
55
+ user = os.environ.get('USER', os.environ.get('USERNAME'))
56
+ info_dict = {'pixeltable-dump': {
57
+ 'metadata-version': md_version,
58
+ 'git-sha': git_sha,
59
+ 'datetime': datetime.datetime.utcnow(),
60
+ 'user': user
61
+ }}
62
+ with open(info_file, 'w') as info:
63
+ toml.dump(info_dict, info)
64
+
65
+ # TODO: Add additional features to the test DB dump (ideally it should exercise
66
+ # every major pixeltable DB feature)
67
+ def create_tables(self) -> None:
68
+ schema = {
69
+ 'c1': StringType(nullable=False),
70
+ 'c1n': StringType(nullable=True),
71
+ 'c2': IntType(nullable=False),
72
+ 'c3': FloatType(nullable=False),
73
+ 'c4': BoolType(nullable=False),
74
+ 'c5': TimestampType(nullable=False),
75
+ 'c6': JsonType(nullable=False),
76
+ 'c7': JsonType(nullable=False),
77
+ }
78
+ t = pxt.create_table('sample_table', schema, primary_key='c2')
79
+
80
+ # Add columns for InlineArray and InlineDict
81
+ t.add_column(c8=[[1, 2, 3], [4, 5, 6]])
82
+ t.add_column(c9=[['a', 'b', 'c'], ['d', 'e', 'f']])
83
+ t.add_column(c10=[t.c1, [t.c1n, t.c2]])
84
+ t.add_column(c11={'int': 22, 'dict': {'key': 'val'}, 'expr': t.c1})
85
+
86
+ # InPredicate
87
+ t.add_column(isin_1=t.c1.isin(['test string 1', 'test string 2', 'test string 3']))
88
+ t.add_column(isin_2=t.c2.isin([1, 2, 3, 4, 5]))
89
+ t.add_column(isin_3=t.c2.isin(t.c6.f5))
90
+
91
+ # Add columns for .astype converters to ensure they're persisted properly
92
+ t.add_column(c2_as_float=t.c2.astype(FloatType()))
93
+
94
+ # Add columns for .apply
95
+ t.add_column(c2_to_string=t.c2.apply(str))
96
+ t.add_column(c6_to_string=t.c6.apply(json.dumps))
97
+ t.add_column(c6_back_to_json=t.c6_to_string.apply(json.loads))
98
+
99
+ num_rows = 100
100
+ d1 = {
101
+ 'f1': 'test string 1',
102
+ 'f2': 1,
103
+ 'f3': 1.0,
104
+ 'f4': True,
105
+ 'f5': [1.0, 2.0, 3.0, 4.0],
106
+ 'f6': {
107
+ 'f7': 'test string 2',
108
+ 'f8': [1.0, 2.0, 3.0, 4.0],
109
+ },
110
+ }
111
+ d2 = [d1, d1]
112
+
113
+ c1_data = [f'test string {i}' for i in range(num_rows)]
114
+ c2_data = [i for i in range(num_rows)]
115
+ c3_data = [float(i) for i in range(num_rows)]
116
+ c4_data = [bool(i % 2) for i in range(num_rows)]
117
+ c5_data = [datetime.datetime.now()] * num_rows
118
+ c6_data = []
119
+ for i in range(num_rows):
120
+ d = {
121
+ 'f1': f'test string {i}',
122
+ 'f2': i,
123
+ 'f3': float(i),
124
+ 'f4': bool(i % 2),
125
+ 'f5': [1.0, 2.0, 3.0, 4.0],
126
+ 'f6': {
127
+ 'f7': 'test string 2',
128
+ 'f8': [1.0, 2.0, 3.0, 4.0],
129
+ },
130
+ }
131
+ c6_data.append(d)
132
+
133
+ c7_data = [d2] * num_rows
134
+ rows = [
135
+ {
136
+ 'c1': c1_data[i],
137
+ 'c1n': c1_data[i] if i % 10 != 0 else None,
138
+ 'c2': c2_data[i],
139
+ 'c3': c3_data[i],
140
+ 'c4': c4_data[i],
141
+ 'c5': c5_data[i],
142
+ 'c6': c6_data[i],
143
+ 'c7': c7_data[i],
144
+ }
145
+ for i in range(num_rows)
146
+ ]
147
+ t.insert(rows)
148
+ pxt.create_dir('views')
149
+ v = pxt.create_view('views.sample_view', t, filter=(t.c2 < 50))
150
+ _ = pxt.create_view('views.sample_snapshot', t, filter=(t.c2 >= 75), is_snapshot=True)
151
+ e = pxt.create_view('views.empty_view', t, filter=t.c2 == 4171780)
152
+ assert e.count() == 0
153
+ # Computed column using a library function
154
+ v['str_format'] = pxt.functions.string.str_format('{0} {key}', t.c1, key=t.c1)
155
+ # Computed column using a bespoke stored udf
156
+ v['test_udf'] = test_udf_stored(t.c2)
157
+ # Computed column using a batched function
158
+ # (apply this to the empty view, since it's a "heavyweight" function)
159
+ e['batched'] = pxt.functions.huggingface.clip_text(t.c1, model_id='openai/clip-vit-base-patch32')
160
+ # computed column using a stored batched function
161
+ v['test_udf_batched'] = test_udf_stored_batched(t.c1, upper=False)
162
+ # astype
163
+ v['astype'] = t.c1.astype(pxt.FloatType())
164
+
165
+
166
+ @pxt.udf(_force_stored=True)
167
+ def test_udf_stored(n: int) -> int:
168
+ return n + 1
169
+
170
+
171
+ @pxt.udf(batch_size=4, _force_stored=True)
172
+ def test_udf_stored_batched(strings: Batch[str], *, upper: bool = True) -> Batch[str]:
173
+ return [string.upper() if upper else string.lower() for string in strings]
174
+
175
+
176
+ def main() -> None:
177
+ _logger.info("Creating pixeltable test artifact.")
178
+ dumper = Dumper()
179
+ dumper.create_tables()
180
+ dumper.dump_db()
181
+
182
+
183
+ if __name__ == "__main__":
184
+ main()
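
For reference, the inverse of Dumper.dump_db is an ordinary pg_restore of the gzipped custom-format dump. A hypothetical sketch; the helper name and the assumption that pg_restore is on PATH are mine, not part of this package:

import gzip
import subprocess

def restore_db(dump_path: str, db_url: str, pg_restore_binary: str = 'pg_restore') -> None:
    # Decompress the .dump.gz artifact and feed it to pg_restore on stdin.
    with gzip.open(dump_path, 'rb') as f:
        data = f.read()
    subprocess.run([pg_restore_binary, '--dbname', db_url, '--no-owner'], input=data, check=True)
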
pixeltable/tool/create_test_video.py ADDED
@@ -0,0 +1,81 @@
+ import math
+ import tempfile
+ from pathlib import Path
+ from typing import Optional
+
+ import av
+ import PIL.Image
+ import PIL.ImageDraw
+ import PIL.ImageFont
+
+
+ def create_test_video(
+         frame_count: int,
+         frame_rate: float = 1.0,
+         frame_width: int = 224,
+         aspect_ratio: str = '16:9',
+         frame_height: Optional[int] = None,
+         output_path: Optional[Path] = None,
+         font_file: str = '/Library/Fonts/Arial Unicode.ttf',
+ ) -> Path:
+     """
+     Creates an .mp4 video file such as the ones in /tests/data/test_videos.
+     Each frame displays its frame number (as a visual sanity check).
+
+     Args:
+         frame_count: Number of frames to create
+         frame_rate: Frame rate of the video
+         frame_width: Width in pixels of the video frame. Note: the cost of decoding increases dramatically
+             with frame width * frame height.
+         aspect_ratio: Aspect ratio (width/height) of the video frames, as a string of the form 'width:height'
+         frame_height: Height in pixels of the video frame; if given, aspect_ratio is ignored
+         output_path: Path to save the video file
+         font_file: Path to the font file used for the frame-number text
+     """
+
+     if output_path is None:
+         output_path = Path(tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name)
+
+     if frame_height is None:
+         parts = [int(p) for p in aspect_ratio.split(':')]
+         assert len(parts) == 2
+         ratio = parts[0] / parts[1]
+         frame_height = math.ceil(frame_width / ratio)
+
+     frame_size = (frame_width, frame_height)
+
+     font_size = min(frame_height, frame_width) // 4
+     font = PIL.ImageFont.truetype(font=font_file, size=font_size)
+     font_fill = 0xFFFFFF  # white
+     frame_color = 0xFFFFFF - font_fill  # black
+     # Create a video container
+     container = av.open(str(output_path), mode='w')
+
+     # Add a video stream
+     stream = container.add_stream('h264', rate=frame_rate)
+     stream.width, stream.height = frame_size
+     stream.pix_fmt = 'yuv420p'
+
+     for frame_number in range(frame_count):
+         # Draw the frame number onto a solid-color image
+         image = PIL.Image.new('RGB', frame_size, color=frame_color)
+         draw = PIL.ImageDraw.Draw(image)
+         text = str(frame_number)
+         _, _, text_width, text_height = draw.textbbox((0, 0), text, font=font)
+         text_position = ((frame_size[0] - text_width) // 2, (frame_size[1] - text_height) // 2)
+         draw.text(text_position, text, font=font, fill=font_fill)
+
+         # Convert the PIL image to an AVFrame
+         frame = av.VideoFrame.from_image(image)
+
+         # Encode and write the frame
+         for packet in stream.encode(frame):
+             container.mux(packet)
+
+     # Flush and close the stream
+     for packet in stream.encode():
+         container.mux(packet)
+
+     container.close()
+     return output_path
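
Example usage, e.g. from a test fixture (the output path is an assumption for this sketch, and the default font path is macOS-specific, so pass a TrueType font that exists on your system):

from pathlib import Path
from pixeltable.tool.create_test_video import create_test_video

# Render a 2-second clip: 20 frames at 10 fps, 224 pixels wide, 16:9.
video_path = create_test_video(
    frame_count=20,
    frame_rate=10.0,
    frame_width=224,
    output_path=Path('test_video.mp4'),
    font_file='/Library/Fonts/Arial Unicode.ttf',  # adjust to a font available locally
)
print(video_path)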