pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/catalog/column.py +26 -49
- pixeltable/catalog/insertable_table.py +7 -4
- pixeltable/catalog/table.py +163 -57
- pixeltable/catalog/table_version.py +416 -140
- pixeltable/catalog/table_version_path.py +2 -2
- pixeltable/client.py +72 -6
- pixeltable/dataframe.py +65 -21
- pixeltable/env.py +52 -53
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/in_memory_data_node.py +11 -7
- pixeltable/exprs/comparison.py +3 -3
- pixeltable/exprs/data_row.py +5 -1
- pixeltable/exprs/literal.py +16 -4
- pixeltable/exprs/row_builder.py +8 -40
- pixeltable/ext/__init__.py +5 -0
- pixeltable/ext/functions/yolox.py +92 -0
- pixeltable/func/aggregate_function.py +15 -15
- pixeltable/func/expr_template_function.py +9 -1
- pixeltable/func/globals.py +24 -14
- pixeltable/func/signature.py +18 -12
- pixeltable/func/udf.py +7 -2
- pixeltable/functions/__init__.py +9 -9
- pixeltable/functions/eval.py +7 -8
- pixeltable/functions/fireworks.py +10 -37
- pixeltable/functions/huggingface.py +47 -19
- pixeltable/functions/openai.py +192 -24
- pixeltable/functions/together.py +104 -9
- pixeltable/functions/util.py +11 -0
- pixeltable/index/__init__.py +2 -0
- pixeltable/index/base.py +49 -0
- pixeltable/index/embedding_index.py +95 -0
- pixeltable/metadata/schema.py +45 -22
- pixeltable/plan.py +15 -34
- pixeltable/store.py +38 -41
- pixeltable/tests/conftest.py +8 -14
- pixeltable/tests/ext/test_yolox.py +21 -0
- pixeltable/tests/functions/test_fireworks.py +43 -0
- pixeltable/tests/functions/test_functions.py +60 -0
- pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
- pixeltable/tests/functions/test_openai.py +162 -0
- pixeltable/tests/functions/test_together.py +112 -0
- pixeltable/tests/test_component_view.py +14 -5
- pixeltable/tests/test_dataframe.py +23 -22
- pixeltable/tests/test_exprs.py +99 -102
- pixeltable/tests/test_function.py +51 -43
- pixeltable/tests/test_index.py +138 -0
- pixeltable/tests/test_migration.py +2 -1
- pixeltable/tests/test_snapshot.py +24 -1
- pixeltable/tests/test_table.py +205 -26
- pixeltable/tests/test_types.py +30 -0
- pixeltable/tests/test_video.py +16 -16
- pixeltable/tests/test_view.py +5 -0
- pixeltable/tests/utils.py +171 -14
- pixeltable/tool/create_test_db_dump.py +16 -0
- pixeltable/type_system.py +77 -128
- pixeltable/utils/arrow.py +98 -0
- pixeltable/utils/hf_datasets.py +157 -0
- pixeltable/utils/parquet.py +68 -27
- pixeltable/utils/pytorch.py +16 -97
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/plan.py
CHANGED
|
@@ -76,7 +76,8 @@ class Analyzer:
|
|
|
76
76
|
f'order_by()'))
|
|
77
77
|
self.similarity_clause = similarity_clauses[0]
|
|
78
78
|
img_col = self.similarity_clause.img_col_ref.col
|
|
79
|
-
|
|
79
|
+
indexed_col_ids = {info.col.id for info in tbl.tbl_version.idxs_by_name.values()}
|
|
80
|
+
if img_col.id not in indexed_col_ids:
|
|
80
81
|
raise excs.Error(f'nearest() not available for unindexed column {img_col.name}')
|
|
81
82
|
|
|
82
83
|
# all exprs that are evaluated in Python; not executable
|
|
@@ -220,18 +221,11 @@ class Planner:
|
|
|
220
221
|
) -> exec.ExecNode:
|
|
221
222
|
"""Creates a plan for TableVersion.insert()"""
|
|
222
223
|
assert not tbl.is_view()
|
|
223
|
-
#
|
|
224
|
-
# 1. stored_cols: all cols we need to store, incl computed cols (and indices)
|
|
224
|
+
# stored_cols: all cols we need to store, incl computed cols (and indices)
|
|
225
225
|
stored_cols = [c for c in tbl.cols if c.is_stored]
|
|
226
226
|
assert len(stored_cols) > 0
|
|
227
|
-
# 2. values to insert into indices
|
|
228
|
-
indexed_cols = [c for c in tbl.cols if c.is_indexed]
|
|
229
|
-
index_info: List[Tuple[catalog.Column, func.Function]] = []
|
|
230
|
-
if len(indexed_cols) > 0:
|
|
231
|
-
from pixeltable.functions.nos.image_embedding import openai_clip
|
|
232
|
-
index_info = [(c, openai_clip) for c in tbl.cols if c.is_indexed]
|
|
233
227
|
|
|
234
|
-
row_builder = exprs.RowBuilder([], stored_cols,
|
|
228
|
+
row_builder = exprs.RowBuilder([], stored_cols, [])
|
|
235
229
|
|
|
236
230
|
# create InMemoryDataNode for 'rows'
|
|
237
231
|
stored_col_info = row_builder.output_slot_idxs()
|
|
@@ -260,7 +254,7 @@ class Planner:
|
|
|
260
254
|
@classmethod
|
|
261
255
|
def create_update_plan(
|
|
262
256
|
cls, tbl: catalog.TableVersionPath,
|
|
263
|
-
update_targets:
|
|
257
|
+
update_targets: dict[catalog.Column, exprs.Expr],
|
|
264
258
|
recompute_targets: List[catalog.Column],
|
|
265
259
|
where_clause: Optional[exprs.Predicate], cascade: bool
|
|
266
260
|
) -> Tuple[exec.ExecNode, List[str], List[catalog.Column]]:
|
|
@@ -279,7 +273,7 @@ class Planner:
|
|
|
279
273
|
# retrieve all stored cols and all target exprs
|
|
280
274
|
assert isinstance(tbl, catalog.TableVersionPath)
|
|
281
275
|
target = tbl.tbl_version # the one we need to update
|
|
282
|
-
updated_cols =
|
|
276
|
+
updated_cols = list(update_targets.keys())
|
|
283
277
|
if len(recompute_targets) > 0:
|
|
284
278
|
recomputed_cols = recompute_targets.copy()
|
|
285
279
|
else:
|
|
@@ -291,12 +285,12 @@ class Planner:
|
|
|
291
285
|
col for col in target.cols if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
|
|
292
286
|
]
|
|
293
287
|
select_list = [exprs.ColumnRef(col) for col in copied_cols]
|
|
294
|
-
select_list.extend(
|
|
288
|
+
select_list.extend(update_targets.values())
|
|
295
289
|
|
|
296
290
|
recomputed_exprs = \
|
|
297
291
|
[c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
|
|
298
292
|
# recomputed cols reference the new values of the updated cols
|
|
299
|
-
for col, e in update_targets:
|
|
293
|
+
for col, e in update_targets.items():
|
|
300
294
|
exprs.Expr.list_substitute(recomputed_exprs, exprs.ColumnRef(col), e)
|
|
301
295
|
select_list.extend(recomputed_exprs)
|
|
302
296
|
|
|
@@ -375,16 +369,10 @@ class Planner:
|
|
|
375
369
|
# the store
|
|
376
370
|
target = view.tbl_version # the one we need to populate
|
|
377
371
|
stored_cols = [c for c in target.cols if c.is_stored and (c.is_computed or target.is_iterator_column(c))]
|
|
378
|
-
# 2.
|
|
379
|
-
indexed_cols = [c for c in target.cols if c.is_indexed]
|
|
380
|
-
index_info: List[Tuple[catalog.Column, func.Function]] = []
|
|
381
|
-
if len(indexed_cols) > 0:
|
|
382
|
-
from pixeltable.functions.nos.image_embedding import openai_clip
|
|
383
|
-
index_info = [(c, openai_clip) for c in target.cols if c.is_indexed]
|
|
384
|
-
# 3. for component views: iterator args
|
|
372
|
+
# 2. for component views: iterator args
|
|
385
373
|
iterator_args = [target.iterator_args] if target.iterator_args is not None else []
|
|
386
374
|
|
|
387
|
-
row_builder = exprs.RowBuilder(iterator_args, stored_cols,
|
|
375
|
+
row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
|
|
388
376
|
|
|
389
377
|
# execution plan:
|
|
390
378
|
# 1. materialize exprs computed from the base that are needed for stored view columns
|
|
@@ -548,7 +536,7 @@ class Planner:
|
|
|
548
536
|
analyzer = Analyzer(
|
|
549
537
|
tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
|
|
550
538
|
order_by_clause=order_by_clause)
|
|
551
|
-
row_builder = exprs.RowBuilder(analyzer.all_exprs, [],
|
|
539
|
+
row_builder = exprs.RowBuilder(analyzer.all_exprs, [], analyzer.sql_exprs)
|
|
552
540
|
|
|
553
541
|
analyzer.finalize(row_builder)
|
|
554
542
|
# select_list: we need to materialize everything that's been collected
|
|
@@ -627,21 +615,15 @@ class Planner:
|
|
|
627
615
|
@classmethod
|
|
628
616
|
def create_add_column_plan(
|
|
629
617
|
cls, tbl: catalog.TableVersionPath, col: catalog.Column
|
|
630
|
-
) -> Tuple[exec.ExecNode, Optional[int]
|
|
618
|
+
) -> Tuple[exec.ExecNode, Optional[int]]:
|
|
631
619
|
"""Creates a plan for InsertableTable.add_column()
|
|
632
620
|
Returns:
|
|
633
621
|
plan: the plan to execute
|
|
634
|
-
ctx: the context to use for the plan
|
|
635
622
|
value_expr slot idx for the plan output (for computed cols)
|
|
636
|
-
embedding slot idx for the plan output (for indexed image cols)
|
|
637
623
|
"""
|
|
638
624
|
assert isinstance(tbl, catalog.TableVersionPath)
|
|
639
625
|
index_info: List[Tuple[catalog.Column, func.Function]] = []
|
|
640
|
-
|
|
641
|
-
from pixeltable.functions.nos.image_embedding import openai_clip
|
|
642
|
-
index_info = [(col, openai_clip)]
|
|
643
|
-
row_builder = exprs.RowBuilder(
|
|
644
|
-
output_exprs=[], columns=[col], indices=index_info, input_exprs=[])
|
|
626
|
+
row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
|
|
645
627
|
analyzer = Analyzer(tbl, row_builder.default_eval_ctx.target_exprs)
|
|
646
628
|
plan = cls._create_query_plan(tbl, row_builder=row_builder, analyzer=analyzer, with_pk=True)
|
|
647
629
|
plan.ctx.batch_size = 16
|
|
@@ -651,6 +633,5 @@ class Planner:
|
|
|
651
633
|
# we want to flush images
|
|
652
634
|
if col.is_computed and col.is_stored and col.col_type.is_image_type():
|
|
653
635
|
plan.set_stored_img_cols(row_builder.output_slot_idxs())
|
|
654
|
-
value_expr_slot_idx
|
|
655
|
-
|
|
656
|
-
return plan, value_expr_slot_idx, embedding_slot_idx
|
|
636
|
+
value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
|
|
637
|
+
return plan, value_expr_slot_idx
|
pixeltable/store.py
CHANGED
|
@@ -38,7 +38,7 @@ class StoreBase:
|
|
|
38
38
|
self.tbl_version = tbl_version
|
|
39
39
|
self.sa_md = sql.MetaData()
|
|
40
40
|
self.sa_tbl: Optional[sql.Table] = None
|
|
41
|
-
self.
|
|
41
|
+
self.create_sa_tbl()
|
|
42
42
|
|
|
43
43
|
def pk_columns(self) -> List[sql.Column]:
|
|
44
44
|
return self._pk_columns
|
|
@@ -62,7 +62,7 @@ class StoreBase:
|
|
|
62
62
|
return [*rowid_cols, self.v_min_col, self.v_max_col]
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
def
|
|
65
|
+
def create_sa_tbl(self) -> None:
|
|
66
66
|
"""Create self.sa_tbl from self.tbl_version."""
|
|
67
67
|
system_cols = self._create_system_columns()
|
|
68
68
|
all_cols = system_cols.copy()
|
|
@@ -76,9 +76,6 @@ class StoreBase:
|
|
|
76
76
|
all_cols.append(col.sa_errormsg_col)
|
|
77
77
|
all_cols.append(col.sa_errortype_col)
|
|
78
78
|
|
|
79
|
-
if col.is_indexed:
|
|
80
|
-
all_cols.append(col.sa_idx_col)
|
|
81
|
-
|
|
82
79
|
# we create an index for:
|
|
83
80
|
# - scalar columns (except for strings, because long strings can't be used for B-tree indices)
|
|
84
81
|
# - non-computed video and image columns (they will contain external paths/urls that users might want to
|
|
@@ -145,8 +142,8 @@ class StoreBase:
|
|
|
145
142
|
"""Move tmp media files that we generated to a permanent location"""
|
|
146
143
|
for c in media_cols:
|
|
147
144
|
for table_row in table_rows:
|
|
148
|
-
file_url = table_row[c.
|
|
149
|
-
table_row[c.
|
|
145
|
+
file_url = table_row[c.store_name()]
|
|
146
|
+
table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)
|
|
150
147
|
|
|
151
148
|
def _create_table_row(
|
|
152
149
|
self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
|
|
@@ -168,16 +165,19 @@ class StoreBase:
|
|
|
168
165
|
|
|
169
166
|
return table_row, num_excs
|
|
170
167
|
|
|
171
|
-
def count(self) ->
|
|
168
|
+
def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
|
|
172
169
|
"""Return the number of rows visible in self.tbl_version"""
|
|
173
170
|
stmt = sql.select(sql.func.count('*'))\
|
|
174
171
|
.select_from(self.sa_tbl)\
|
|
175
172
|
.where(self.v_min_col <= self.tbl_version.version)\
|
|
176
173
|
.where(self.v_max_col > self.tbl_version.version)
|
|
177
|
-
|
|
174
|
+
if conn is None:
|
|
175
|
+
with env.Env.get().engine.connect() as conn:
|
|
176
|
+
result = conn.execute(stmt).scalar_one()
|
|
177
|
+
else:
|
|
178
178
|
result = conn.execute(stmt).scalar_one()
|
|
179
|
-
|
|
180
|
-
|
|
179
|
+
assert isinstance(result, int)
|
|
180
|
+
return result
|
|
181
181
|
|
|
182
182
|
def create(self, conn: sql.engine.Connection) -> None:
|
|
183
183
|
self.sa_md.create_all(bind=conn)
|
|
@@ -193,38 +193,35 @@ class StoreBase:
|
|
|
193
193
|
message).
|
|
194
194
|
"""
|
|
195
195
|
assert col.is_stored
|
|
196
|
-
|
|
196
|
+
col_type_str = col.get_sa_col_type().compile(dialect=conn.dialect)
|
|
197
|
+
stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.store_name()} {col_type_str} NULL')
|
|
197
198
|
log_stmt(_logger, stmt)
|
|
198
199
|
conn.execute(stmt)
|
|
199
|
-
added_storage_cols = [col.
|
|
200
|
+
added_storage_cols = [col.store_name()]
|
|
200
201
|
if col.records_errors:
|
|
201
202
|
# we also need to create the errormsg and errortype storage cols
|
|
202
203
|
stmt = (f'ALTER TABLE {self._storage_name()} '
|
|
203
|
-
f'ADD COLUMN {col.
|
|
204
|
+
f'ADD COLUMN {col.errormsg_store_name()} {StringType().to_sql()} DEFAULT NULL')
|
|
204
205
|
conn.execute(sql.text(stmt))
|
|
205
206
|
stmt = (f'ALTER TABLE {self._storage_name()} '
|
|
206
|
-
f'ADD COLUMN {col.
|
|
207
|
+
f'ADD COLUMN {col.errortype_store_name()} {StringType().to_sql()} DEFAULT NULL')
|
|
207
208
|
conn.execute(sql.text(stmt))
|
|
208
|
-
|
|
209
|
-
self.
|
|
209
|
+
added_storage_cols.extend([col.errormsg_store_name(), col.errortype_store_name()])
|
|
210
|
+
self.create_sa_tbl()
|
|
210
211
|
_logger.info(f'Added columns {added_storage_cols} to storage table {self._storage_name()}')
|
|
211
212
|
|
|
212
|
-
def drop_column(self, col:
|
|
213
|
-
"""
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
213
|
+
def drop_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
|
|
214
|
+
"""Execute Alter Table Drop Column statement"""
|
|
215
|
+
stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.store_name()}'
|
|
216
|
+
conn.execute(sql.text(stmt))
|
|
217
|
+
if col.records_errors:
|
|
218
|
+
stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_store_name()}'
|
|
219
|
+
conn.execute(sql.text(stmt))
|
|
220
|
+
stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_store_name()}'
|
|
217
221
|
conn.execute(sql.text(stmt))
|
|
218
|
-
if col.records_errors:
|
|
219
|
-
stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_storage_name()}'
|
|
220
|
-
conn.execute(sql.text(stmt))
|
|
221
|
-
stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_storage_name()}'
|
|
222
|
-
conn.execute(sql.text(stmt))
|
|
223
|
-
self._create_sa_tbl()
|
|
224
222
|
|
|
225
223
|
def load_column(
|
|
226
|
-
self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int,
|
|
227
|
-
conn: sql.engine.Connection
|
|
224
|
+
self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, conn: sql.engine.Connection
|
|
228
225
|
) -> int:
|
|
229
226
|
"""Update store column of a computed column with values produced by an execution plan
|
|
230
227
|
|
|
@@ -253,18 +250,11 @@ class StoreBase:
|
|
|
253
250
|
col.sa_errormsg_col: error_msg
|
|
254
251
|
}
|
|
255
252
|
else:
|
|
256
|
-
val = result_row.get_stored_val(value_expr_slot_idx)
|
|
253
|
+
val = result_row.get_stored_val(value_expr_slot_idx, col.sa_col.type)
|
|
257
254
|
if col.col_type.is_media_type():
|
|
258
255
|
val = self._move_tmp_media_file(val, col, result_row.pk[-1])
|
|
259
256
|
values_dict = {col.sa_col: val}
|
|
260
257
|
|
|
261
|
-
if col.is_indexed:
|
|
262
|
-
# TODO: deal with exceptions
|
|
263
|
-
assert not result_row.has_exc(embedding_slot_idx)
|
|
264
|
-
# don't use get_stored_val() here, we need to pass the ndarray
|
|
265
|
-
embedding = result_row[embedding_slot_idx]
|
|
266
|
-
values_dict[col.sa_index_col] = embedding
|
|
267
|
-
|
|
268
258
|
update_stmt = sql.update(self.sa_tbl).values(values_dict)
|
|
269
259
|
for pk_col, pk_val in zip(self.pk_columns(), result_row.pk):
|
|
270
260
|
update_stmt = update_stmt.where(pk_col == pk_val)
|
|
@@ -337,6 +327,7 @@ class StoreBase:
|
|
|
337
327
|
self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
|
|
338
328
|
where_clause: Optional[sql.ClauseElement], conn: sql.engine.Connection) -> int:
|
|
339
329
|
"""Mark rows as deleted that are live and were created prior to current_version.
|
|
330
|
+
Also: populate the undo columns
|
|
340
331
|
Args:
|
|
341
332
|
base_versions: if non-None, join only to base rows that were created at that version,
|
|
342
333
|
otherwise join to rows that are live in the base's current version (which is distinct from the
|
|
@@ -354,8 +345,14 @@ class StoreBase:
|
|
|
354
345
|
rowid_join_clause = self._rowid_join_predicate()
|
|
355
346
|
base_versions_clause = sql.true() if len(base_versions) == 0 \
|
|
356
347
|
else self.base._versions_clause(base_versions, match_on_vmin)
|
|
348
|
+
set_clause = {self.v_max_col: current_version}
|
|
349
|
+
for index_info in self.tbl_version.idxs_by_name.values():
|
|
350
|
+
# copy value column to undo column
|
|
351
|
+
set_clause[index_info.undo_col.sa_col] = index_info.val_col.sa_col
|
|
352
|
+
# set value column to NULL
|
|
353
|
+
set_clause[index_info.val_col.sa_col] = None
|
|
357
354
|
stmt = sql.update(self.sa_tbl) \
|
|
358
|
-
.values(
|
|
355
|
+
.values(set_clause) \
|
|
359
356
|
.where(where_clause) \
|
|
360
357
|
.where(rowid_join_clause) \
|
|
361
358
|
.where(base_versions_clause)
|
|
@@ -416,8 +413,8 @@ class StoreComponentView(StoreView):
|
|
|
416
413
|
self.rowid_cols.append(self.pos_col)
|
|
417
414
|
return self.rowid_cols
|
|
418
415
|
|
|
419
|
-
def
|
|
420
|
-
super().
|
|
416
|
+
def create_sa_tbl(self) -> None:
|
|
417
|
+
super().create_sa_tbl()
|
|
421
418
|
# we need to fix up the 'pos' column in TableVersion
|
|
422
419
|
self.tbl_version.cols_by_name['pos'].sa_col = self.pos_col
|
|
423
420
|
|
pixeltable/tests/conftest.py
CHANGED
|
@@ -6,11 +6,12 @@ from typing import List
|
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import pytest
|
|
9
|
+
import PIL.Image
|
|
9
10
|
|
|
10
11
|
import pixeltable as pxt
|
|
11
12
|
import pixeltable.catalog as catalog
|
|
12
13
|
from pixeltable import exprs
|
|
13
|
-
|
|
14
|
+
import pixeltable.functions as pxtf
|
|
14
15
|
from pixeltable.exprs import RELATIVE_PATH_ROOT as R
|
|
15
16
|
from pixeltable.metadata import SystemInfo, create_system_info
|
|
16
17
|
from pixeltable.metadata.schema import TableSchemaVersion, TableVersion, Table, Function, Dir
|
|
@@ -80,8 +81,8 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
|
80
81
|
return create_test_tbl(test_client)
|
|
81
82
|
|
|
82
83
|
# @pytest.fixture(scope='function')
|
|
83
|
-
# def test_stored_fn(test_client:
|
|
84
|
-
# @
|
|
84
|
+
# def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
|
|
85
|
+
# @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
|
|
85
86
|
# def test_fn(x):
|
|
86
87
|
# return x + 1
|
|
87
88
|
# test_client.create_function('test_fn', test_fn)
|
|
@@ -89,7 +90,7 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
|
89
90
|
|
|
90
91
|
@pytest.fixture(scope='function')
|
|
91
92
|
def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
|
|
92
|
-
#def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn:
|
|
93
|
+
#def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
|
|
93
94
|
|
|
94
95
|
t = test_tbl
|
|
95
96
|
return [
|
|
@@ -120,8 +121,7 @@ def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
|
|
|
120
121
|
t.c1.apply(json.loads),
|
|
121
122
|
t.c8.errortype,
|
|
122
123
|
t.c8.errormsg,
|
|
123
|
-
|
|
124
|
-
#test_stored_fn(t.c2),
|
|
124
|
+
pxtf.sum(t.c2, group_by=t.c4, order_by=t.c3),
|
|
125
125
|
]
|
|
126
126
|
|
|
127
127
|
@pytest.fixture(scope='function')
|
|
@@ -153,17 +153,11 @@ def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
|
|
|
153
153
|
img_t.img.localpath,
|
|
154
154
|
]
|
|
155
155
|
|
|
156
|
-
# TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
|
|
157
|
-
#@pytest.fixture(scope='session')
|
|
158
|
-
#def indexed_img_tbl(init_env: None) -> catalog.Table:
|
|
159
|
-
# cl = pt.Client()
|
|
160
|
-
# db = cl.create_db('test_indexed')
|
|
161
156
|
@pytest.fixture(scope='function')
|
|
162
|
-
def
|
|
163
|
-
skip_test_if_not_installed('nos')
|
|
157
|
+
def small_img_tbl(test_client: pxt.Client) -> catalog.Table:
|
|
164
158
|
cl = test_client
|
|
165
159
|
schema = {
|
|
166
|
-
'img':
|
|
160
|
+
'img': ImageType(nullable=False),
|
|
167
161
|
'category': StringType(nullable=False),
|
|
168
162
|
'split': StringType(nullable=False),
|
|
169
163
|
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import pixeltable as pxt
|
|
2
|
+
from pixeltable.tests.utils import skip_test_if_not_installed, get_image_files, validate_update_status
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TestYolox:
|
|
6
|
+
|
|
7
|
+
def test_yolox(self, test_client: pxt.Client):
|
|
8
|
+
skip_test_if_not_installed('yolox')
|
|
9
|
+
from pixeltable.ext.functions.yolox import yolox
|
|
10
|
+
cl = test_client
|
|
11
|
+
t = cl.create_table('yolox_test', {'image': pxt.ImageType()})
|
|
12
|
+
t['detect_yolox_tiny'] = yolox(t.image, model_id='yolox_tiny')
|
|
13
|
+
t['detect_yolox_nano'] = yolox(t.image, model_id='yolox_nano', threshold=0.2)
|
|
14
|
+
t['yolox_nano_bboxes'] = t.detect_yolox_nano.bboxes
|
|
15
|
+
images = get_image_files()[:10]
|
|
16
|
+
validate_update_status(t.insert({'image': image} for image in images), expected_rows=10)
|
|
17
|
+
rows = t.collect()
|
|
18
|
+
# Verify correctly formed JSON
|
|
19
|
+
assert all(list(result.keys()) == ['bboxes', 'labels', 'scores'] for result in rows['detect_yolox_tiny'])
|
|
20
|
+
# Verify that bboxes are actually present in at least some of the rows.
|
|
21
|
+
assert any(len(bboxes) > 0 for bboxes in rows['yolox_nano_bboxes'])
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
import pixeltable as pxt
|
|
4
|
+
import pixeltable.exceptions as excs
|
|
5
|
+
from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.mark.remote_api
|
|
9
|
+
class TestFireworks:
|
|
10
|
+
|
|
11
|
+
def test_fireworks(self, test_client: pxt.Client) -> None:
|
|
12
|
+
skip_test_if_not_installed('fireworks')
|
|
13
|
+
TestFireworks.skip_test_if_no_fireworks_client()
|
|
14
|
+
cl = test_client
|
|
15
|
+
t = cl.create_table('test_tbl', {'input': pxt.StringType()})
|
|
16
|
+
from pixeltable.functions.fireworks import chat_completions
|
|
17
|
+
messages = [{'role': 'user', 'content': t.input}]
|
|
18
|
+
t['output'] = chat_completions(
|
|
19
|
+
messages=messages,
|
|
20
|
+
model='accounts/fireworks/models/llama-v2-7b-chat'
|
|
21
|
+
)
|
|
22
|
+
t['output_2'] = chat_completions(
|
|
23
|
+
messages=messages,
|
|
24
|
+
model='accounts/fireworks/models/llama-v2-7b-chat',
|
|
25
|
+
max_tokens=300,
|
|
26
|
+
top_k=40,
|
|
27
|
+
top_p=0.9,
|
|
28
|
+
temperature=0.7
|
|
29
|
+
)
|
|
30
|
+
validate_update_status(t.insert(input="How's everything going today?"), 1)
|
|
31
|
+
results = t.collect()
|
|
32
|
+
assert len(results['output'][0]['choices'][0]['message']['content']) > 0
|
|
33
|
+
assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
|
|
34
|
+
|
|
35
|
+
# This ensures that the test will be skipped, rather than returning an error, when no API key is
|
|
36
|
+
# available (for example, when a PR runs in CI).
|
|
37
|
+
@staticmethod
|
|
38
|
+
def skip_test_if_no_fireworks_client() -> None:
|
|
39
|
+
try:
|
|
40
|
+
import pixeltable.functions.fireworks
|
|
41
|
+
_ = pixeltable.functions.fireworks.fireworks_client()
|
|
42
|
+
except excs.Error as exc:
|
|
43
|
+
pytest.skip(str(exc))
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import pixeltable as pxt
|
|
2
|
+
from pixeltable import catalog
|
|
3
|
+
from pixeltable.functions.pil.image import blend
|
|
4
|
+
from pixeltable.iterators import FrameIterator
|
|
5
|
+
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
|
|
6
|
+
from pixeltable.type_system import VideoType, StringType
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TestFunctions:
|
|
10
|
+
def test_pil(self, img_tbl: catalog.Table) -> None:
|
|
11
|
+
t = img_tbl
|
|
12
|
+
_ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
|
|
13
|
+
|
|
14
|
+
def test_eval_detections(self, test_client: pxt.Client) -> None:
|
|
15
|
+
skip_test_if_not_installed('nos')
|
|
16
|
+
cl = test_client
|
|
17
|
+
video_t = cl.create_table('video_tbl', {'video': VideoType()})
|
|
18
|
+
# create frame view
|
|
19
|
+
args = {'video': video_t.video, 'fps': 1}
|
|
20
|
+
v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
|
|
21
|
+
|
|
22
|
+
files = get_video_files()
|
|
23
|
+
video_t.insert(video=files[-1])
|
|
24
|
+
v.add_column(frame_s=v.frame.resize([640, 480]))
|
|
25
|
+
from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
|
|
26
|
+
v.add_column(detections_a=yolox_nano(v.frame_s))
|
|
27
|
+
v.add_column(detections_b=yolox_small(v.frame_s))
|
|
28
|
+
v.add_column(gt=yolox_large(v.frame_s))
|
|
29
|
+
from pixeltable.functions.eval import eval_detections, mean_ap
|
|
30
|
+
res = v.select(
|
|
31
|
+
eval_detections(
|
|
32
|
+
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
|
|
33
|
+
)).show()
|
|
34
|
+
v.add_column(
|
|
35
|
+
eval_a=eval_detections(
|
|
36
|
+
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
|
|
37
|
+
v.add_column(
|
|
38
|
+
eval_b=eval_detections(
|
|
39
|
+
v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
|
|
40
|
+
ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
|
|
41
|
+
ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
|
|
42
|
+
common_classes = set(ap_a.keys()) & set(ap_b.keys())
|
|
43
|
+
|
|
44
|
+
## TODO: following assertion is failing on CI,
|
|
45
|
+
# It is not necessarily a bug, as assert codition is not expected to be always true
|
|
46
|
+
# for k in common_classes:
|
|
47
|
+
# assert ap_a[k] <= ap_b[k]
|
|
48
|
+
|
|
49
|
+
def test_str(self, test_client: pxt.Client) -> None:
|
|
50
|
+
cl = test_client
|
|
51
|
+
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
52
|
+
from pixeltable.functions.string import str_format
|
|
53
|
+
t.add_column(s1=str_format('ABC {0}', t.input))
|
|
54
|
+
t.add_column(s2=str_format('DEF {this}', this=t.input))
|
|
55
|
+
t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
|
|
56
|
+
status = t.insert(input='MNO')
|
|
57
|
+
assert status.num_rows == 1
|
|
58
|
+
assert status.num_excs == 0
|
|
59
|
+
row = t.head()[0]
|
|
60
|
+
assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
|
|
@@ -3,144 +3,12 @@ from typing import Dict, Any
|
|
|
3
3
|
import pytest
|
|
4
4
|
|
|
5
5
|
import pixeltable as pxt
|
|
6
|
-
from pixeltable import
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
from pixeltable.functions.pil.image import blend
|
|
10
|
-
from pixeltable.iterators import FrameIterator
|
|
11
|
-
from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
|
|
12
|
-
from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
|
|
6
|
+
from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
|
|
7
|
+
SAMPLE_IMAGE_URL
|
|
8
|
+
from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
|
|
13
9
|
|
|
14
10
|
|
|
15
|
-
class
|
|
16
|
-
def test_pil(self, img_tbl: catalog.Table) -> None:
|
|
17
|
-
t = img_tbl
|
|
18
|
-
_ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
|
|
19
|
-
|
|
20
|
-
def test_eval_detections(self, test_client: pxt.Client) -> None:
|
|
21
|
-
skip_test_if_not_installed('nos')
|
|
22
|
-
cl = test_client
|
|
23
|
-
video_t = cl.create_table('video_tbl', {'video': VideoType()})
|
|
24
|
-
# create frame view
|
|
25
|
-
args = {'video': video_t.video, 'fps': 1}
|
|
26
|
-
v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
|
|
27
|
-
|
|
28
|
-
files = get_video_files()
|
|
29
|
-
video_t.insert(video=files[-1])
|
|
30
|
-
v.add_column(frame_s=v.frame.resize([640, 480]))
|
|
31
|
-
from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
|
|
32
|
-
v.add_column(detections_a=yolox_nano(v.frame_s))
|
|
33
|
-
v.add_column(detections_b=yolox_small(v.frame_s))
|
|
34
|
-
v.add_column(gt=yolox_large(v.frame_s))
|
|
35
|
-
from pixeltable.functions.eval import eval_detections, mean_ap
|
|
36
|
-
res = v.select(
|
|
37
|
-
eval_detections(
|
|
38
|
-
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
|
|
39
|
-
)).show()
|
|
40
|
-
v.add_column(
|
|
41
|
-
eval_a=eval_detections(
|
|
42
|
-
v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
|
|
43
|
-
v.add_column(
|
|
44
|
-
eval_b=eval_detections(
|
|
45
|
-
v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
|
|
46
|
-
ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
|
|
47
|
-
ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
|
|
48
|
-
common_classes = set(ap_a.keys()) & set(ap_b.keys())
|
|
49
|
-
|
|
50
|
-
## TODO: following assertion is failing on CI,
|
|
51
|
-
# It is not necessarily a bug, as assert codition is not expected to be always true
|
|
52
|
-
# for k in common_classes:
|
|
53
|
-
# assert ap_a[k] <= ap_b[k]
|
|
54
|
-
|
|
55
|
-
def test_str(self, test_client: pxt.Client) -> None:
|
|
56
|
-
cl = test_client
|
|
57
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
58
|
-
from pixeltable.functions.string import str_format
|
|
59
|
-
t.add_column(s1=str_format('ABC {0}', t.input))
|
|
60
|
-
t.add_column(s2=str_format('DEF {this}', this=t.input))
|
|
61
|
-
t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
|
|
62
|
-
status = t.insert(input='MNO')
|
|
63
|
-
assert status.num_rows == 1
|
|
64
|
-
assert status.num_excs == 0
|
|
65
|
-
row = t.head()[0]
|
|
66
|
-
assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
|
|
67
|
-
|
|
68
|
-
def test_openai(self, test_client: pxt.Client) -> None:
|
|
69
|
-
skip_test_if_not_installed('openai')
|
|
70
|
-
TestFunctions.skip_test_if_no_openai_client()
|
|
71
|
-
cl = test_client
|
|
72
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
73
|
-
from pixeltable.functions.openai import chat_completions, embeddings, moderations
|
|
74
|
-
msgs = [
|
|
75
|
-
{"role": "system", "content": "You are a helpful assistant."},
|
|
76
|
-
{"role": "user", "content": t.input}
|
|
77
|
-
]
|
|
78
|
-
t.add_column(input_msgs=msgs)
|
|
79
|
-
t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
|
|
80
|
-
# with inlined messages
|
|
81
|
-
t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
|
|
82
|
-
t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
|
|
83
|
-
t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
|
|
84
|
-
t.add_column(moderation=moderations(input=t.input))
|
|
85
|
-
t.insert(input='I find you really annoying')
|
|
86
|
-
_ = t.head()
|
|
87
|
-
|
|
88
|
-
def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
|
|
89
|
-
skip_test_if_not_installed('openai')
|
|
90
|
-
TestFunctions.skip_test_if_no_openai_client()
|
|
91
|
-
cl = test_client
|
|
92
|
-
t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
|
|
93
|
-
from pixeltable.functions.openai import chat_completions
|
|
94
|
-
from pixeltable.functions.string import str_format
|
|
95
|
-
msgs = [
|
|
96
|
-
{'role': 'user',
|
|
97
|
-
'content': [
|
|
98
|
-
{'type': 'text', 'text': t.prompt},
|
|
99
|
-
{'type': 'image_url', 'image_url': {
|
|
100
|
-
'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
|
|
101
|
-
}}
|
|
102
|
-
]}
|
|
103
|
-
]
|
|
104
|
-
t.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300))
|
|
105
|
-
t.add_column(response_content=t.response.choices[0].message.content)
|
|
106
|
-
t.insert(prompt="What's in this image?", img=_sample_image_url)
|
|
107
|
-
result = t.collect()['response_content'][0]
|
|
108
|
-
assert len(result) > 0
|
|
109
|
-
|
|
110
|
-
@staticmethod
|
|
111
|
-
def skip_test_if_no_openai_client() -> None:
|
|
112
|
-
try:
|
|
113
|
-
_ = Env.get().openai_client
|
|
114
|
-
except excs.Error as exc:
|
|
115
|
-
pytest.skip(str(exc))
|
|
116
|
-
|
|
117
|
-
def test_together(self, test_client: pxt.Client) -> None:
|
|
118
|
-
skip_test_if_not_installed('together')
|
|
119
|
-
if not Env.get().has_together_client:
|
|
120
|
-
pytest.skip(f'Together client does not exist (missing API key?)')
|
|
121
|
-
cl = test_client
|
|
122
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
123
|
-
from pixeltable.functions.together import completions
|
|
124
|
-
t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
|
|
125
|
-
t.add_column(output_text=t.output.output.choices[0].text)
|
|
126
|
-
t.insert(input='I am going to the ')
|
|
127
|
-
result = t.select(t.output_text).collect()['output_text'][0]
|
|
128
|
-
assert len(result) > 0
|
|
129
|
-
|
|
130
|
-
def test_fireworks(self, test_client: pxt.Client) -> None:
|
|
131
|
-
skip_test_if_not_installed('fireworks')
|
|
132
|
-
try:
|
|
133
|
-
from pixeltable.functions.fireworks import initialize
|
|
134
|
-
initialize()
|
|
135
|
-
except:
|
|
136
|
-
pytest.skip(f'Fireworks client does not exist (missing API key?)')
|
|
137
|
-
cl = test_client
|
|
138
|
-
t = cl.create_table('test_tbl', {'input': StringType()})
|
|
139
|
-
from pixeltable.functions.fireworks import chat_completions
|
|
140
|
-
t['output'] = chat_completions(prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256).choices[0].text
|
|
141
|
-
t.insert(input='I am going to the ')
|
|
142
|
-
result = t.select(t.output).collect()['output'][0]
|
|
143
|
-
assert len(result) > 0
|
|
11
|
+
class TestHuggingface:
|
|
144
12
|
|
|
145
13
|
def test_hf_function(self, test_client: pxt.Client) -> None:
|
|
146
14
|
skip_test_if_not_installed('sentence_transformers')
|
|
@@ -255,10 +123,10 @@ class TestFunctions:
|
|
|
255
123
|
for idx, model_id in enumerate(model_ids):
|
|
256
124
|
col_name = f'embed_text{idx}'
|
|
257
125
|
t[col_name] = clip_text(t.text, model_id=model_id)
|
|
258
|
-
assert t.column_types()[col_name]
|
|
126
|
+
assert t.column_types()[col_name].is_array_type()
|
|
259
127
|
col_name = f'embed_img{idx}'
|
|
260
128
|
t[col_name] = clip_image(t.img, model_id=model_id)
|
|
261
|
-
assert t.column_types()[col_name]
|
|
129
|
+
assert t.column_types()[col_name].is_array_type()
|
|
262
130
|
|
|
263
131
|
def verify_row(row: Dict[str, Any]) -> None:
|
|
264
132
|
for idx, _ in enumerate(model_ids):
|
|
@@ -281,14 +149,10 @@ class TestFunctions:
|
|
|
281
149
|
t = cl.create_table('test_tbl', {'img': ImageType()})
|
|
282
150
|
from pixeltable.functions.huggingface import detr_for_object_detection
|
|
283
151
|
t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
|
|
284
|
-
status = t.insert(img=
|
|
152
|
+
status = t.insert(img=SAMPLE_IMAGE_URL)
|
|
285
153
|
assert status.num_rows == 1
|
|
286
154
|
assert status.num_excs == 0
|
|
287
155
|
result = t.select(t.detect).collect()[0]['detect']
|
|
288
156
|
assert 'orange' in result['label_text']
|
|
289
157
|
assert 'bowl' in result['label_text']
|
|
290
158
|
assert 'broccoli' in result['label_text']
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
_sample_image_url = \
|
|
294
|
-
'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'
|