pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See the registry's release page for more details.

Files changed (63)
  1. pixeltable/catalog/column.py +26 -49
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +72 -6
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +52 -53
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +9 -9
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/fireworks.py +10 -37
  25. pixeltable/functions/huggingface.py +47 -19
  26. pixeltable/functions/openai.py +192 -24
  27. pixeltable/functions/together.py +104 -9
  28. pixeltable/functions/util.py +11 -0
  29. pixeltable/index/__init__.py +2 -0
  30. pixeltable/index/base.py +49 -0
  31. pixeltable/index/embedding_index.py +95 -0
  32. pixeltable/metadata/schema.py +45 -22
  33. pixeltable/plan.py +15 -34
  34. pixeltable/store.py +38 -41
  35. pixeltable/tests/conftest.py +8 -14
  36. pixeltable/tests/ext/test_yolox.py +21 -0
  37. pixeltable/tests/functions/test_fireworks.py +43 -0
  38. pixeltable/tests/functions/test_functions.py +60 -0
  39. pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
  40. pixeltable/tests/functions/test_openai.py +162 -0
  41. pixeltable/tests/functions/test_together.py +112 -0
  42. pixeltable/tests/test_component_view.py +14 -5
  43. pixeltable/tests/test_dataframe.py +23 -22
  44. pixeltable/tests/test_exprs.py +99 -102
  45. pixeltable/tests/test_function.py +51 -43
  46. pixeltable/tests/test_index.py +138 -0
  47. pixeltable/tests/test_migration.py +2 -1
  48. pixeltable/tests/test_snapshot.py +24 -1
  49. pixeltable/tests/test_table.py +205 -26
  50. pixeltable/tests/test_types.py +30 -0
  51. pixeltable/tests/test_video.py +16 -16
  52. pixeltable/tests/test_view.py +5 -0
  53. pixeltable/tests/utils.py +171 -14
  54. pixeltable/tool/create_test_db_dump.py +16 -0
  55. pixeltable/type_system.py +77 -128
  56. pixeltable/utils/arrow.py +98 -0
  57. pixeltable/utils/hf_datasets.py +157 -0
  58. pixeltable/utils/parquet.py +68 -27
  59. pixeltable/utils/pytorch.py +16 -97
  60. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
  61. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
  62. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  63. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/plan.py CHANGED
@@ -76,7 +76,8 @@ class Analyzer:
76
76
  f'order_by()'))
77
77
  self.similarity_clause = similarity_clauses[0]
78
78
  img_col = self.similarity_clause.img_col_ref.col
79
- if not img_col.is_indexed:
79
+ indexed_col_ids = {info.col.id for info in tbl.tbl_version.idxs_by_name.values()}
80
+ if img_col.id not in indexed_col_ids:
80
81
  raise excs.Error(f'nearest() not available for unindexed column {img_col.name}')
81
82
 
82
83
  # all exprs that are evaluated in Python; not executable
@@ -220,18 +221,11 @@ class Planner:
220
221
  ) -> exec.ExecNode:
221
222
  """Creates a plan for TableVersion.insert()"""
222
223
  assert not tbl.is_view()
223
- # things we need to materialize:
224
- # 1. stored_cols: all cols we need to store, incl computed cols (and indices)
224
+ # stored_cols: all cols we need to store, incl computed cols (and indices)
225
225
  stored_cols = [c for c in tbl.cols if c.is_stored]
226
226
  assert len(stored_cols) > 0
227
- # 2. values to insert into indices
228
- indexed_cols = [c for c in tbl.cols if c.is_indexed]
229
- index_info: List[Tuple[catalog.Column, func.Function]] = []
230
- if len(indexed_cols) > 0:
231
- from pixeltable.functions.nos.image_embedding import openai_clip
232
- index_info = [(c, openai_clip) for c in tbl.cols if c.is_indexed]
233
227
 
234
- row_builder = exprs.RowBuilder([], stored_cols, index_info, [])
228
+ row_builder = exprs.RowBuilder([], stored_cols, [])
235
229
 
236
230
  # create InMemoryDataNode for 'rows'
237
231
  stored_col_info = row_builder.output_slot_idxs()
@@ -260,7 +254,7 @@ class Planner:
260
254
  @classmethod
261
255
  def create_update_plan(
262
256
  cls, tbl: catalog.TableVersionPath,
263
- update_targets: List[Tuple[catalog.Column, exprs.Expr]],
257
+ update_targets: dict[catalog.Column, exprs.Expr],
264
258
  recompute_targets: List[catalog.Column],
265
259
  where_clause: Optional[exprs.Predicate], cascade: bool
266
260
  ) -> Tuple[exec.ExecNode, List[str], List[catalog.Column]]:
@@ -279,7 +273,7 @@ class Planner:
279
273
  # retrieve all stored cols and all target exprs
280
274
  assert isinstance(tbl, catalog.TableVersionPath)
281
275
  target = tbl.tbl_version # the one we need to update
282
- updated_cols = [col for col, _ in update_targets]
276
+ updated_cols = list(update_targets.keys())
283
277
  if len(recompute_targets) > 0:
284
278
  recomputed_cols = recompute_targets.copy()
285
279
  else:
@@ -291,12 +285,12 @@ class Planner:
291
285
  col for col in target.cols if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
292
286
  ]
293
287
  select_list = [exprs.ColumnRef(col) for col in copied_cols]
294
- select_list.extend([expr for _, expr in update_targets])
288
+ select_list.extend(update_targets.values())
295
289
 
296
290
  recomputed_exprs = \
297
291
  [c.value_expr.copy().resolve_computed_cols(resolve_cols=recomputed_base_cols) for c in recomputed_base_cols]
298
292
  # recomputed cols reference the new values of the updated cols
299
- for col, e in update_targets:
293
+ for col, e in update_targets.items():
300
294
  exprs.Expr.list_substitute(recomputed_exprs, exprs.ColumnRef(col), e)
301
295
  select_list.extend(recomputed_exprs)
302
296
 
@@ -375,16 +369,10 @@ class Planner:
375
369
  # the store
376
370
  target = view.tbl_version # the one we need to populate
377
371
  stored_cols = [c for c in target.cols if c.is_stored and (c.is_computed or target.is_iterator_column(c))]
378
- # 2. index values
379
- indexed_cols = [c for c in target.cols if c.is_indexed]
380
- index_info: List[Tuple[catalog.Column, func.Function]] = []
381
- if len(indexed_cols) > 0:
382
- from pixeltable.functions.nos.image_embedding import openai_clip
383
- index_info = [(c, openai_clip) for c in target.cols if c.is_indexed]
384
- # 3. for component views: iterator args
372
+ # 2. for component views: iterator args
385
373
  iterator_args = [target.iterator_args] if target.iterator_args is not None else []
386
374
 
387
- row_builder = exprs.RowBuilder(iterator_args, stored_cols, index_info, [])
375
+ row_builder = exprs.RowBuilder(iterator_args, stored_cols, [])
388
376
 
389
377
  # execution plan:
390
378
  # 1. materialize exprs computed from the base that are needed for stored view columns
@@ -548,7 +536,7 @@ class Planner:
548
536
  analyzer = Analyzer(
549
537
  tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
550
538
  order_by_clause=order_by_clause)
551
- row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [], analyzer.sql_exprs)
539
+ row_builder = exprs.RowBuilder(analyzer.all_exprs, [], analyzer.sql_exprs)
552
540
 
553
541
  analyzer.finalize(row_builder)
554
542
  # select_list: we need to materialize everything that's been collected
@@ -627,21 +615,15 @@ class Planner:
627
615
  @classmethod
628
616
  def create_add_column_plan(
629
617
  cls, tbl: catalog.TableVersionPath, col: catalog.Column
630
- ) -> Tuple[exec.ExecNode, Optional[int], Optional[int]]:
618
+ ) -> Tuple[exec.ExecNode, Optional[int]]:
631
619
  """Creates a plan for InsertableTable.add_column()
632
620
  Returns:
633
621
  plan: the plan to execute
634
- ctx: the context to use for the plan
635
622
  value_expr slot idx for the plan output (for computed cols)
636
- embedding slot idx for the plan output (for indexed image cols)
637
623
  """
638
624
  assert isinstance(tbl, catalog.TableVersionPath)
639
625
  index_info: List[Tuple[catalog.Column, func.Function]] = []
640
- if col.is_indexed:
641
- from pixeltable.functions.nos.image_embedding import openai_clip
642
- index_info = [(col, openai_clip)]
643
- row_builder = exprs.RowBuilder(
644
- output_exprs=[], columns=[col], indices=index_info, input_exprs=[])
626
+ row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
645
627
  analyzer = Analyzer(tbl, row_builder.default_eval_ctx.target_exprs)
646
628
  plan = cls._create_query_plan(tbl, row_builder=row_builder, analyzer=analyzer, with_pk=True)
647
629
  plan.ctx.batch_size = 16
@@ -651,6 +633,5 @@ class Planner:
651
633
  # we want to flush images
652
634
  if col.is_computed and col.is_stored and col.col_type.is_image_type():
653
635
  plan.set_stored_img_cols(row_builder.output_slot_idxs())
654
- value_expr_slot_idx: Optional[int] = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
655
- embedding_slot_idx: Optional[int] = row_builder.index_slot_idxs()[0].slot_idx if col.is_indexed else None
656
- return plan, value_expr_slot_idx, embedding_slot_idx
636
+ value_expr_slot_idx = row_builder.output_slot_idxs()[0].slot_idx if col.is_computed else None
637
+ return plan, value_expr_slot_idx
pixeltable/store.py CHANGED
@@ -38,7 +38,7 @@ class StoreBase:
38
38
  self.tbl_version = tbl_version
39
39
  self.sa_md = sql.MetaData()
40
40
  self.sa_tbl: Optional[sql.Table] = None
41
- self._create_sa_tbl()
41
+ self.create_sa_tbl()
42
42
 
43
43
  def pk_columns(self) -> List[sql.Column]:
44
44
  return self._pk_columns
@@ -62,7 +62,7 @@ class StoreBase:
62
62
  return [*rowid_cols, self.v_min_col, self.v_max_col]
63
63
 
64
64
 
65
- def _create_sa_tbl(self) -> None:
65
+ def create_sa_tbl(self) -> None:
66
66
  """Create self.sa_tbl from self.tbl_version."""
67
67
  system_cols = self._create_system_columns()
68
68
  all_cols = system_cols.copy()
@@ -76,9 +76,6 @@ class StoreBase:
76
76
  all_cols.append(col.sa_errormsg_col)
77
77
  all_cols.append(col.sa_errortype_col)
78
78
 
79
- if col.is_indexed:
80
- all_cols.append(col.sa_idx_col)
81
-
82
79
  # we create an index for:
83
80
  # - scalar columns (except for strings, because long strings can't be used for B-tree indices)
84
81
  # - non-computed video and image columns (they will contain external paths/urls that users might want to
@@ -145,8 +142,8 @@ class StoreBase:
145
142
  """Move tmp media files that we generated to a permanent location"""
146
143
  for c in media_cols:
147
144
  for table_row in table_rows:
148
- file_url = table_row[c.storage_name()]
149
- table_row[c.storage_name()] = self._move_tmp_media_file(file_url, c, v_min)
145
+ file_url = table_row[c.store_name()]
146
+ table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)
150
147
 
151
148
  def _create_table_row(
152
149
  self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
@@ -168,16 +165,19 @@ class StoreBase:
168
165
 
169
166
  return table_row, num_excs
170
167
 
171
- def count(self) -> None:
168
+ def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
172
169
  """Return the number of rows visible in self.tbl_version"""
173
170
  stmt = sql.select(sql.func.count('*'))\
174
171
  .select_from(self.sa_tbl)\
175
172
  .where(self.v_min_col <= self.tbl_version.version)\
176
173
  .where(self.v_max_col > self.tbl_version.version)
177
- with env.Env.get().engine.begin() as conn:
174
+ if conn is None:
175
+ with env.Env.get().engine.connect() as conn:
176
+ result = conn.execute(stmt).scalar_one()
177
+ else:
178
178
  result = conn.execute(stmt).scalar_one()
179
- assert isinstance(result, int)
180
- return result
179
+ assert isinstance(result, int)
180
+ return result
181
181
 
182
182
  def create(self, conn: sql.engine.Connection) -> None:
183
183
  self.sa_md.create_all(bind=conn)
@@ -193,38 +193,35 @@ class StoreBase:
193
193
  message).
194
194
  """
195
195
  assert col.is_stored
196
- stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.storage_name()} {col.col_type.to_sql()}')
196
+ col_type_str = col.get_sa_col_type().compile(dialect=conn.dialect)
197
+ stmt = sql.text(f'ALTER TABLE {self._storage_name()} ADD COLUMN {col.store_name()} {col_type_str} NULL')
197
198
  log_stmt(_logger, stmt)
198
199
  conn.execute(stmt)
199
- added_storage_cols = [col.storage_name()]
200
+ added_storage_cols = [col.store_name()]
200
201
  if col.records_errors:
201
202
  # we also need to create the errormsg and errortype storage cols
202
203
  stmt = (f'ALTER TABLE {self._storage_name()} '
203
- f'ADD COLUMN {col.errormsg_storage_name()} {StringType().to_sql()} DEFAULT NULL')
204
+ f'ADD COLUMN {col.errormsg_store_name()} {StringType().to_sql()} DEFAULT NULL')
204
205
  conn.execute(sql.text(stmt))
205
206
  stmt = (f'ALTER TABLE {self._storage_name()} '
206
- f'ADD COLUMN {col.errortype_storage_name()} {StringType().to_sql()} DEFAULT NULL')
207
+ f'ADD COLUMN {col.errortype_store_name()} {StringType().to_sql()} DEFAULT NULL')
207
208
  conn.execute(sql.text(stmt))
208
- added_storage_cols.extend([col.errormsg_storage_name(), col.errortype_storage_name()])
209
- self._create_sa_tbl()
209
+ added_storage_cols.extend([col.errormsg_store_name(), col.errortype_store_name()])
210
+ self.create_sa_tbl()
210
211
  _logger.info(f'Added columns {added_storage_cols} to storage table {self._storage_name()}')
211
212
 
212
- def drop_column(self, col: Optional[catalog.Column] = None, conn: Optional[sql.engine.Connection] = None) -> None:
213
- """Re-create self.sa_tbl and drop column, if one is given"""
214
- if col is not None:
215
- assert conn is not None
216
- stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.storage_name()}'
213
+ def drop_column(self, col: catalog.Column, conn: sql.engine.Connection) -> None:
214
+ """Execute Alter Table Drop Column statement"""
215
+ stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.store_name()}'
216
+ conn.execute(sql.text(stmt))
217
+ if col.records_errors:
218
+ stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_store_name()}'
219
+ conn.execute(sql.text(stmt))
220
+ stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_store_name()}'
217
221
  conn.execute(sql.text(stmt))
218
- if col.records_errors:
219
- stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errormsg_storage_name()}'
220
- conn.execute(sql.text(stmt))
221
- stmt = f'ALTER TABLE {self._storage_name()} DROP COLUMN {col.errortype_storage_name()}'
222
- conn.execute(sql.text(stmt))
223
- self._create_sa_tbl()
224
222
 
225
223
  def load_column(
226
- self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, embedding_slot_idx: int,
227
- conn: sql.engine.Connection
224
+ self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, conn: sql.engine.Connection
228
225
  ) -> int:
229
226
  """Update store column of a computed column with values produced by an execution plan
230
227
 
@@ -253,18 +250,11 @@ class StoreBase:
253
250
  col.sa_errormsg_col: error_msg
254
251
  }
255
252
  else:
256
- val = result_row.get_stored_val(value_expr_slot_idx)
253
+ val = result_row.get_stored_val(value_expr_slot_idx, col.sa_col.type)
257
254
  if col.col_type.is_media_type():
258
255
  val = self._move_tmp_media_file(val, col, result_row.pk[-1])
259
256
  values_dict = {col.sa_col: val}
260
257
 
261
- if col.is_indexed:
262
- # TODO: deal with exceptions
263
- assert not result_row.has_exc(embedding_slot_idx)
264
- # don't use get_stored_val() here, we need to pass the ndarray
265
- embedding = result_row[embedding_slot_idx]
266
- values_dict[col.sa_index_col] = embedding
267
-
268
258
  update_stmt = sql.update(self.sa_tbl).values(values_dict)
269
259
  for pk_col, pk_val in zip(self.pk_columns(), result_row.pk):
270
260
  update_stmt = update_stmt.where(pk_col == pk_val)
@@ -337,6 +327,7 @@ class StoreBase:
337
327
  self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
338
328
  where_clause: Optional[sql.ClauseElement], conn: sql.engine.Connection) -> int:
339
329
  """Mark rows as deleted that are live and were created prior to current_version.
330
+ Also: populate the undo columns
340
331
  Args:
341
332
  base_versions: if non-None, join only to base rows that were created at that version,
342
333
  otherwise join to rows that are live in the base's current version (which is distinct from the
@@ -354,8 +345,14 @@ class StoreBase:
354
345
  rowid_join_clause = self._rowid_join_predicate()
355
346
  base_versions_clause = sql.true() if len(base_versions) == 0 \
356
347
  else self.base._versions_clause(base_versions, match_on_vmin)
348
+ set_clause = {self.v_max_col: current_version}
349
+ for index_info in self.tbl_version.idxs_by_name.values():
350
+ # copy value column to undo column
351
+ set_clause[index_info.undo_col.sa_col] = index_info.val_col.sa_col
352
+ # set value column to NULL
353
+ set_clause[index_info.val_col.sa_col] = None
357
354
  stmt = sql.update(self.sa_tbl) \
358
- .values({self.v_max_col: current_version}) \
355
+ .values(set_clause) \
359
356
  .where(where_clause) \
360
357
  .where(rowid_join_clause) \
361
358
  .where(base_versions_clause)
@@ -416,8 +413,8 @@ class StoreComponentView(StoreView):
416
413
  self.rowid_cols.append(self.pos_col)
417
414
  return self.rowid_cols
418
415
 
419
- def _create_sa_tbl(self) -> None:
420
- super()._create_sa_tbl()
416
+ def create_sa_tbl(self) -> None:
417
+ super().create_sa_tbl()
421
418
  # we need to fix up the 'pos' column in TableVersion
422
419
  self.tbl_version.cols_by_name['pos'].sa_col = self.pos_col
423
420
 
pixeltable/tests/conftest.py CHANGED
@@ -6,11 +6,12 @@ from typing import List
6
6
 
7
7
  import numpy as np
8
8
  import pytest
9
+ import PIL.Image
9
10
 
10
11
  import pixeltable as pxt
11
12
  import pixeltable.catalog as catalog
12
13
  from pixeltable import exprs
13
- from pixeltable import functions as ptf
14
+ import pixeltable.functions as pxtf
14
15
  from pixeltable.exprs import RELATIVE_PATH_ROOT as R
15
16
  from pixeltable.metadata import SystemInfo, create_system_info
16
17
  from pixeltable.metadata.schema import TableSchemaVersion, TableVersion, Table, Function, Dir
@@ -80,8 +81,8 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
80
81
  return create_test_tbl(test_client)
81
82
 
82
83
  # @pytest.fixture(scope='function')
83
- # def test_stored_fn(test_client: pt.Client) -> pt.Function:
84
- # @pt.udf(return_type=pt.IntType(), param_types=[pt.IntType()])
84
+ # def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
85
+ # @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
85
86
  # def test_fn(x):
86
87
  # return x + 1
87
88
  # test_client.create_function('test_fn', test_fn)
@@ -89,7 +90,7 @@ def test_tbl(test_client: pxt.Client) -> catalog.Table:
89
90
 
90
91
  @pytest.fixture(scope='function')
91
92
  def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
92
- #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pt.Function) -> List[exprs.Expr]:
93
+ #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
93
94
 
94
95
  t = test_tbl
95
96
  return [
@@ -120,8 +121,7 @@ def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
120
121
  t.c1.apply(json.loads),
121
122
  t.c8.errortype,
122
123
  t.c8.errormsg,
123
- ptf.sum(t.c2, group_by=t.c4, order_by=t.c3),
124
- #test_stored_fn(t.c2),
124
+ pxtf.sum(t.c2, group_by=t.c4, order_by=t.c3),
125
125
  ]
126
126
 
127
127
  @pytest.fixture(scope='function')
@@ -153,17 +153,11 @@ def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
153
153
  img_t.img.localpath,
154
154
  ]
155
155
 
156
- # TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
157
- #@pytest.fixture(scope='session')
158
- #def indexed_img_tbl(init_env: None) -> catalog.Table:
159
- # cl = pt.Client()
160
- # db = cl.create_db('test_indexed')
161
156
  @pytest.fixture(scope='function')
162
- def indexed_img_tbl(test_client: pxt.Client) -> catalog.Table:
163
- skip_test_if_not_installed('nos')
157
+ def small_img_tbl(test_client: pxt.Client) -> catalog.Table:
164
158
  cl = test_client
165
159
  schema = {
166
- 'img': { 'type': ImageType(nullable=False), 'indexed': True },
160
+ 'img': ImageType(nullable=False),
167
161
  'category': StringType(nullable=False),
168
162
  'split': StringType(nullable=False),
169
163
  }
pixeltable/tests/ext/test_yolox.py ADDED
@@ -0,0 +1,21 @@
1
+ import pixeltable as pxt
2
+ from pixeltable.tests.utils import skip_test_if_not_installed, get_image_files, validate_update_status
3
+
4
+
5
+ class TestYolox:
6
+
7
+ def test_yolox(self, test_client: pxt.Client):
8
+ skip_test_if_not_installed('yolox')
9
+ from pixeltable.ext.functions.yolox import yolox
10
+ cl = test_client
11
+ t = cl.create_table('yolox_test', {'image': pxt.ImageType()})
12
+ t['detect_yolox_tiny'] = yolox(t.image, model_id='yolox_tiny')
13
+ t['detect_yolox_nano'] = yolox(t.image, model_id='yolox_nano', threshold=0.2)
14
+ t['yolox_nano_bboxes'] = t.detect_yolox_nano.bboxes
15
+ images = get_image_files()[:10]
16
+ validate_update_status(t.insert({'image': image} for image in images), expected_rows=10)
17
+ rows = t.collect()
18
+ # Verify correctly formed JSON
19
+ assert all(list(result.keys()) == ['bboxes', 'labels', 'scores'] for result in rows['detect_yolox_tiny'])
20
+ # Verify that bboxes are actually present in at least some of the rows.
21
+ assert any(len(bboxes) > 0 for bboxes in rows['yolox_nano_bboxes'])
pixeltable/tests/functions/test_fireworks.py ADDED
@@ -0,0 +1,43 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ import pixeltable.exceptions as excs
5
+ from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
+
7
+
8
+ @pytest.mark.remote_api
9
+ class TestFireworks:
10
+
11
+ def test_fireworks(self, test_client: pxt.Client) -> None:
12
+ skip_test_if_not_installed('fireworks')
13
+ TestFireworks.skip_test_if_no_fireworks_client()
14
+ cl = test_client
15
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
16
+ from pixeltable.functions.fireworks import chat_completions
17
+ messages = [{'role': 'user', 'content': t.input}]
18
+ t['output'] = chat_completions(
19
+ messages=messages,
20
+ model='accounts/fireworks/models/llama-v2-7b-chat'
21
+ )
22
+ t['output_2'] = chat_completions(
23
+ messages=messages,
24
+ model='accounts/fireworks/models/llama-v2-7b-chat',
25
+ max_tokens=300,
26
+ top_k=40,
27
+ top_p=0.9,
28
+ temperature=0.7
29
+ )
30
+ validate_update_status(t.insert(input="How's everything going today?"), 1)
31
+ results = t.collect()
32
+ assert len(results['output'][0]['choices'][0]['message']['content']) > 0
33
+ assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
34
+
35
+ # This ensures that the test will be skipped, rather than returning an error, when no API key is
36
+ # available (for example, when a PR runs in CI).
37
+ @staticmethod
38
+ def skip_test_if_no_fireworks_client() -> None:
39
+ try:
40
+ import pixeltable.functions.fireworks
41
+ _ = pixeltable.functions.fireworks.fireworks_client()
42
+ except excs.Error as exc:
43
+ pytest.skip(str(exc))
pixeltable/tests/functions/test_functions.py ADDED
@@ -0,0 +1,60 @@
1
+ import pixeltable as pxt
2
+ from pixeltable import catalog
3
+ from pixeltable.functions.pil.image import blend
4
+ from pixeltable.iterators import FrameIterator
5
+ from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
+ from pixeltable.type_system import VideoType, StringType
7
+
8
+
9
+ class TestFunctions:
10
+ def test_pil(self, img_tbl: catalog.Table) -> None:
11
+ t = img_tbl
12
+ _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
13
+
14
+ def test_eval_detections(self, test_client: pxt.Client) -> None:
15
+ skip_test_if_not_installed('nos')
16
+ cl = test_client
17
+ video_t = cl.create_table('video_tbl', {'video': VideoType()})
18
+ # create frame view
19
+ args = {'video': video_t.video, 'fps': 1}
20
+ v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
21
+
22
+ files = get_video_files()
23
+ video_t.insert(video=files[-1])
24
+ v.add_column(frame_s=v.frame.resize([640, 480]))
25
+ from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
26
+ v.add_column(detections_a=yolox_nano(v.frame_s))
27
+ v.add_column(detections_b=yolox_small(v.frame_s))
28
+ v.add_column(gt=yolox_large(v.frame_s))
29
+ from pixeltable.functions.eval import eval_detections, mean_ap
30
+ res = v.select(
31
+ eval_detections(
32
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
33
+ )).show()
34
+ v.add_column(
35
+ eval_a=eval_detections(
36
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
37
+ v.add_column(
38
+ eval_b=eval_detections(
39
+ v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
40
+ ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
41
+ ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
42
+ common_classes = set(ap_a.keys()) & set(ap_b.keys())
43
+
44
+ ## TODO: following assertion is failing on CI,
45
+ # It is not necessarily a bug, as assert codition is not expected to be always true
46
+ # for k in common_classes:
47
+ # assert ap_a[k] <= ap_b[k]
48
+
49
+ def test_str(self, test_client: pxt.Client) -> None:
50
+ cl = test_client
51
+ t = cl.create_table('test_tbl', {'input': StringType()})
52
+ from pixeltable.functions.string import str_format
53
+ t.add_column(s1=str_format('ABC {0}', t.input))
54
+ t.add_column(s2=str_format('DEF {this}', this=t.input))
55
+ t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
56
+ status = t.insert(input='MNO')
57
+ assert status.num_rows == 1
58
+ assert status.num_excs == 0
59
+ row = t.head()[0]
60
+ assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
pixeltable/tests/{test_functions.py → functions/test_huggingface.py} RENAMED
@@ -3,144 +3,12 @@ from typing import Dict, Any
3
3
  import pytest
4
4
 
5
5
  import pixeltable as pxt
6
- from pixeltable import catalog
7
- from pixeltable.env import Env
8
- import pixeltable.exceptions as excs
9
- from pixeltable.functions.pil.image import blend
10
- from pixeltable.iterators import FrameIterator
11
- from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed, get_sentences, get_image_files
12
- from pixeltable.type_system import VideoType, StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
6
+ from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
7
+ SAMPLE_IMAGE_URL
8
+ from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
13
9
 
14
10
 
15
- class TestFunctions:
16
- def test_pil(self, img_tbl: catalog.Table) -> None:
17
- t = img_tbl
18
- _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
19
-
20
- def test_eval_detections(self, test_client: pxt.Client) -> None:
21
- skip_test_if_not_installed('nos')
22
- cl = test_client
23
- video_t = cl.create_table('video_tbl', {'video': VideoType()})
24
- # create frame view
25
- args = {'video': video_t.video, 'fps': 1}
26
- v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
27
-
28
- files = get_video_files()
29
- video_t.insert(video=files[-1])
30
- v.add_column(frame_s=v.frame.resize([640, 480]))
31
- from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
32
- v.add_column(detections_a=yolox_nano(v.frame_s))
33
- v.add_column(detections_b=yolox_small(v.frame_s))
34
- v.add_column(gt=yolox_large(v.frame_s))
35
- from pixeltable.functions.eval import eval_detections, mean_ap
36
- res = v.select(
37
- eval_detections(
38
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
39
- )).show()
40
- v.add_column(
41
- eval_a=eval_detections(
42
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
43
- v.add_column(
44
- eval_b=eval_detections(
45
- v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
46
- ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
47
- ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
48
- common_classes = set(ap_a.keys()) & set(ap_b.keys())
49
-
50
- ## TODO: following assertion is failing on CI,
51
- # It is not necessarily a bug, as assert codition is not expected to be always true
52
- # for k in common_classes:
53
- # assert ap_a[k] <= ap_b[k]
54
-
55
- def test_str(self, test_client: pxt.Client) -> None:
56
- cl = test_client
57
- t = cl.create_table('test_tbl', {'input': StringType()})
58
- from pixeltable.functions.string import str_format
59
- t.add_column(s1=str_format('ABC {0}', t.input))
60
- t.add_column(s2=str_format('DEF {this}', this=t.input))
61
- t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
62
- status = t.insert(input='MNO')
63
- assert status.num_rows == 1
64
- assert status.num_excs == 0
65
- row = t.head()[0]
66
- assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
67
-
68
- def test_openai(self, test_client: pxt.Client) -> None:
69
- skip_test_if_not_installed('openai')
70
- TestFunctions.skip_test_if_no_openai_client()
71
- cl = test_client
72
- t = cl.create_table('test_tbl', {'input': StringType()})
73
- from pixeltable.functions.openai import chat_completions, embeddings, moderations
74
- msgs = [
75
- {"role": "system", "content": "You are a helpful assistant."},
76
- {"role": "user", "content": t.input}
77
- ]
78
- t.add_column(input_msgs=msgs)
79
- t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
80
- # with inlined messages
81
- t.add_column(chat_output2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
82
- t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
83
- t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input))
84
- t.add_column(moderation=moderations(input=t.input))
85
- t.insert(input='I find you really annoying')
86
- _ = t.head()
87
-
88
- def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
89
- skip_test_if_not_installed('openai')
90
- TestFunctions.skip_test_if_no_openai_client()
91
- cl = test_client
92
- t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
93
- from pixeltable.functions.openai import chat_completions
94
- from pixeltable.functions.string import str_format
95
- msgs = [
96
- {'role': 'user',
97
- 'content': [
98
- {'type': 'text', 'text': t.prompt},
99
- {'type': 'image_url', 'image_url': {
100
- 'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
101
- }}
102
- ]}
103
- ]
104
- t.add_column(response=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300))
105
- t.add_column(response_content=t.response.choices[0].message.content)
106
- t.insert(prompt="What's in this image?", img=_sample_image_url)
107
- result = t.collect()['response_content'][0]
108
- assert len(result) > 0
109
-
110
- @staticmethod
111
- def skip_test_if_no_openai_client() -> None:
112
- try:
113
- _ = Env.get().openai_client
114
- except excs.Error as exc:
115
- pytest.skip(str(exc))
116
-
117
- def test_together(self, test_client: pxt.Client) -> None:
118
- skip_test_if_not_installed('together')
119
- if not Env.get().has_together_client:
120
- pytest.skip(f'Together client does not exist (missing API key?)')
121
- cl = test_client
122
- t = cl.create_table('test_tbl', {'input': StringType()})
123
- from pixeltable.functions.together import completions
124
- t.add_column(output=completions(prompt=t.input, model='mistralai/Mixtral-8x7B-v0.1', stop=['\n']))
125
- t.add_column(output_text=t.output.output.choices[0].text)
126
- t.insert(input='I am going to the ')
127
- result = t.select(t.output_text).collect()['output_text'][0]
128
- assert len(result) > 0
129
-
130
- def test_fireworks(self, test_client: pxt.Client) -> None:
131
- skip_test_if_not_installed('fireworks')
132
- try:
133
- from pixeltable.functions.fireworks import initialize
134
- initialize()
135
- except:
136
- pytest.skip(f'Fireworks client does not exist (missing API key?)')
137
- cl = test_client
138
- t = cl.create_table('test_tbl', {'input': StringType()})
139
- from pixeltable.functions.fireworks import chat_completions
140
- t['output'] = chat_completions(prompt=t.input, model='accounts/fireworks/models/llama-v2-7b-chat', max_tokens=256).choices[0].text
141
- t.insert(input='I am going to the ')
142
- result = t.select(t.output).collect()['output'][0]
143
- assert len(result) > 0
11
+ class TestHuggingface:
144
12
 
145
13
  def test_hf_function(self, test_client: pxt.Client) -> None:
146
14
  skip_test_if_not_installed('sentence_transformers')
@@ -255,10 +123,10 @@ class TestFunctions:
255
123
  for idx, model_id in enumerate(model_ids):
256
124
  col_name = f'embed_text{idx}'
257
125
  t[col_name] = clip_text(t.text, model_id=model_id)
258
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
126
+ assert t.column_types()[col_name].is_array_type()
259
127
  col_name = f'embed_img{idx}'
260
128
  t[col_name] = clip_image(t.img, model_id=model_id)
261
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
129
+ assert t.column_types()[col_name].is_array_type()
262
130
 
263
131
  def verify_row(row: Dict[str, Any]) -> None:
264
132
  for idx, _ in enumerate(model_ids):
@@ -281,14 +149,10 @@ class TestFunctions:
281
149
  t = cl.create_table('test_tbl', {'img': ImageType()})
282
150
  from pixeltable.functions.huggingface import detr_for_object_detection
283
151
  t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
284
- status = t.insert(img=_sample_image_url)
152
+ status = t.insert(img=SAMPLE_IMAGE_URL)
285
153
  assert status.num_rows == 1
286
154
  assert status.num_excs == 0
287
155
  result = t.select(t.detect).collect()[0]['detect']
288
156
  assert 'orange' in result['label_text']
289
157
  assert 'bowl' in result['label_text']
290
158
  assert 'broccoli' in result['label_text']
291
-
292
-
293
- _sample_image_url = \
294
- 'https://raw.githubusercontent.com/pixeltable/pixeltable/master/docs/source/data/images/000000000009.jpg'