pixeltable 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (87)
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +9 -5
  4. pixeltable/catalog/insertable_table.py +0 -2
  5. pixeltable/catalog/table.py +16 -8
  6. pixeltable/catalog/table_version.py +3 -2
  7. pixeltable/dataframe.py +184 -110
  8. pixeltable/env.py +69 -18
  9. pixeltable/exec/__init__.py +2 -1
  10. pixeltable/exec/data_row_batch.py +6 -7
  11. pixeltable/exec/expr_eval_node.py +28 -28
  12. pixeltable/exec/sql_scan_node.py +7 -6
  13. pixeltable/exprs/__init__.py +4 -3
  14. pixeltable/exprs/column_ref.py +9 -0
  15. pixeltable/exprs/expr.py +15 -7
  16. pixeltable/exprs/function_call.py +17 -15
  17. pixeltable/exprs/image_member_access.py +9 -28
  18. pixeltable/exprs/in_predicate.py +96 -0
  19. pixeltable/exprs/inline_array.py +13 -11
  20. pixeltable/exprs/inline_dict.py +15 -13
  21. pixeltable/exprs/row_builder.py +7 -1
  22. pixeltable/exprs/similarity_expr.py +65 -0
  23. pixeltable/func/__init__.py +0 -2
  24. pixeltable/func/aggregate_function.py +3 -0
  25. pixeltable/func/callable_function.py +57 -13
  26. pixeltable/func/expr_template_function.py +11 -2
  27. pixeltable/func/function.py +35 -4
  28. pixeltable/func/signature.py +5 -15
  29. pixeltable/func/udf.py +6 -10
  30. pixeltable/functions/huggingface.py +23 -4
  31. pixeltable/functions/openai.py +34 -1
  32. pixeltable/functions/pil/image.py +61 -64
  33. pixeltable/functions/together.py +21 -0
  34. pixeltable/globals.py +425 -0
  35. pixeltable/index/base.py +3 -1
  36. pixeltable/index/embedding_index.py +87 -14
  37. pixeltable/io/__init__.py +3 -0
  38. pixeltable/{utils → io}/hf_datasets.py +48 -17
  39. pixeltable/io/pandas.py +148 -0
  40. pixeltable/{utils → io}/parquet.py +58 -33
  41. pixeltable/iterators/__init__.py +1 -1
  42. pixeltable/iterators/base.py +4 -0
  43. pixeltable/iterators/document.py +218 -97
  44. pixeltable/iterators/video.py +8 -9
  45. pixeltable/metadata/__init__.py +7 -3
  46. pixeltable/metadata/converters/convert_12.py +3 -0
  47. pixeltable/metadata/converters/convert_13.py +41 -0
  48. pixeltable/plan.py +2 -19
  49. pixeltable/store.py +2 -2
  50. pixeltable/tool/create_test_db_dump.py +32 -13
  51. pixeltable/type_system.py +13 -54
  52. pixeltable/utils/documents.py +42 -12
  53. pixeltable/utils/http_server.py +70 -0
  54. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/METADATA +10 -7
  55. pixeltable-0.2.6.dist-info/RECORD +119 -0
  56. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  57. pixeltable/client.py +0 -600
  58. pixeltable/exprs/image_similarity_predicate.py +0 -58
  59. pixeltable/func/batched_function.py +0 -53
  60. pixeltable/tests/conftest.py +0 -171
  61. pixeltable/tests/ext/test_yolox.py +0 -21
  62. pixeltable/tests/functions/test_fireworks.py +0 -43
  63. pixeltable/tests/functions/test_functions.py +0 -60
  64. pixeltable/tests/functions/test_huggingface.py +0 -158
  65. pixeltable/tests/functions/test_openai.py +0 -162
  66. pixeltable/tests/functions/test_together.py +0 -112
  67. pixeltable/tests/test_audio.py +0 -65
  68. pixeltable/tests/test_catalog.py +0 -27
  69. pixeltable/tests/test_client.py +0 -21
  70. pixeltable/tests/test_component_view.py +0 -379
  71. pixeltable/tests/test_dataframe.py +0 -440
  72. pixeltable/tests/test_dirs.py +0 -107
  73. pixeltable/tests/test_document.py +0 -120
  74. pixeltable/tests/test_exprs.py +0 -802
  75. pixeltable/tests/test_function.py +0 -332
  76. pixeltable/tests/test_index.py +0 -138
  77. pixeltable/tests/test_migration.py +0 -44
  78. pixeltable/tests/test_nos.py +0 -54
  79. pixeltable/tests/test_snapshot.py +0 -231
  80. pixeltable/tests/test_table.py +0 -1343
  81. pixeltable/tests/test_transactional_directory.py +0 -42
  82. pixeltable/tests/test_types.py +0 -52
  83. pixeltable/tests/test_video.py +0 -159
  84. pixeltable/tests/test_view.py +0 -535
  85. pixeltable/tests/utils.py +0 -442
  86. pixeltable-0.2.5.dist-info/RECORD +0 -139
  87. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
pixeltable/globals.py ADDED
@@ -0,0 +1,425 @@
1
+ import dataclasses
2
+ import logging
3
+ from typing import Any, Optional, Union, Type
4
+
5
+ import pandas as pd
6
+ import sqlalchemy as sql
7
+ from sqlalchemy.util.preloaded import orm
8
+
9
+ import pixeltable.exceptions as excs
10
+ from pixeltable import catalog, func
11
+ from pixeltable.catalog import Catalog
12
+ from pixeltable.env import Env
13
+ from pixeltable.exprs import Predicate
14
+ from pixeltable.iterators import ComponentIterator
15
+ from pixeltable.metadata import schema
16
+
17
# package-level logger used by the module-level API functions in this file
_logger = logging.getLogger('pixeltable')


def init() -> None:
    """Initialize the Pixeltable environment.

    The handle returned by `Catalog.get()` is not used; the call is made only for its side effects
    (presumably the lazy set-up of the catalog and its environment -- see `Catalog.get()`).
    """
    Catalog.get()
23
+
24
+
25
def create_table(
    path_str: str,
    schema: dict[str, Any],
    *,
    primary_key: Optional[Union[str, list[str]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
) -> catalog.InsertableTable:
    """Create a new `InsertableTable`.

    Args:
        path_str: Path to the table.
        schema: dictionary mapping column names to column types, value expressions, or to column specifications.
        primary_key: name of a single primary key column, or a list of column names; defaults to no primary key.
        num_retained_versions: Number of versions of the table to retain.
        comment: comment passed on to the new table.

    Returns:
        The newly created table.

    Raises:
        Error: if the path already exists or is invalid, if `schema` is empty, or if `primary_key` is neither a
            string nor a list of strings.

    Examples:
        Create a table with an int and a string column:

        >>> table = pxt.create_table('my_table', schema={'col1': pxt.IntType(), 'col2': pxt.StringType()})
    """
    path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(path, expected=None)
    parent_dir = Catalog.get().paths[path.parent]

    if not schema:
        raise excs.Error(f'Table schema is empty: `{path_str}`')

    # normalize primary_key into a (possibly empty) list of column names
    if primary_key is None:
        primary_key = []
    elif isinstance(primary_key, str):
        primary_key = [primary_key]
    elif not isinstance(primary_key, list) or not all(isinstance(pk, str) for pk in primary_key):
        raise excs.Error('primary_key must be a single column name or a list of column names')

    tbl = catalog.InsertableTable.create(
        parent_dir._id,
        path.name,
        schema,
        primary_key=primary_key,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    # register the new table under its path in the in-memory catalog
    Catalog.get().paths[path] = tbl
    _logger.info(f'Created table `{path_str}`.')
    return tbl
77
+
78
+
79
def create_view(
    path_str: str,
    base: catalog.Table,
    *,
    schema: Optional[dict[str, Any]] = None,
    filter: Optional[Predicate] = None,
    is_snapshot: bool = False,
    iterator: Optional[tuple[type[ComponentIterator], dict[str, Any]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
    ignore_errors: bool = False,
) -> Optional[catalog.View]:
    """Create a new `View`.

    Args:
        path_str: Path to the view.
        base: Table (i.e., table or view or snapshot) to base the view on.
        schema: dictionary mapping column names to column types, value expressions, or to column specifications.
        filter: Predicate to filter rows of the base table.
        is_snapshot: Whether the view is a snapshot.
        iterator: optional `(iterator_class, iterator_args)` pair: the `ComponentIterator` subclass to use for the
            view and the arguments to pass to it.
        num_retained_versions: Number of versions of the view to retain.
        comment: comment passed on to the new view.
        ignore_errors: if True, return None instead of raising if the path already exists or is invalid.

    Returns:
        The newly created view, or None if `ignore_errors` is True and the path already exists or is invalid.

    Raises:
        Error: if `base` is not a table, or if the path already exists or is invalid and `ignore_errors` is False.

    Examples:
        Create a view with an additional int and a string column and a filter:

        >>> view = pxt.create_view(
            'my_view', base, schema={'col3': pxt.IntType(), 'col4': pxt.StringType()}, filter=base.col1 > 10)

        Create a table snapshot:

        >>> snapshot_view = pxt.create_view('my_snapshot_view', base, is_snapshot=True)

        Create an immutable view with additional computed columns and a filter:

        >>> snapshot_view = pxt.create_view(
            'my_snapshot', base, schema={'col3': base.col2 + 1}, filter=base.col1 > 10, is_snapshot=True)
    """
    # explicit check instead of an assert, so that the validation survives `python -O`
    if not isinstance(base, catalog.Table):
        raise excs.Error(f'`base` must be a table, view, or snapshot, but is a {type(base).__name__}')
    try:
        # path parsing is inside the try so that ignore_errors also covers syntactically invalid paths
        path = catalog.Path(path_str)
        Catalog.get().paths.check_is_valid(path, expected=None)
    except excs.Error:
        if ignore_errors:
            _logger.info(f'Skipped view `{path_str}` (path is invalid or already exists).')
            return None
        raise
    parent_dir = Catalog.get().paths[path.parent]

    if schema is None:
        schema = {}
    if iterator is None:
        iterator_class, iterator_args = None, None
    else:
        iterator_class, iterator_args = iterator
    view = catalog.View.create(
        parent_dir._id,
        path.name,
        base=base,
        schema=schema,
        predicate=filter,
        is_snapshot=is_snapshot,
        iterator_cls=iterator_class,
        iterator_args=iterator_args,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    # register the new view under its path in the in-memory catalog
    Catalog.get().paths[path] = view
    _logger.info(f'Created view `{path_str}`.')
    return view
157
+
158
+
159
def get_table(path: str) -> catalog.Table:
    """Return a handle to an existing table, view, or snapshot.

    Args:
        path: Path to the table.

    Returns:
        An `InsertableTable` or `View` object.

    Raises:
        Error: If the path does not exist or does not designate a table.

    Examples:
        Get a handle for a table in the top-level directory:

        >>> table = pxt.get_table('my_table')

        For a table in a subdirectory:

        >>> table = pxt.get_table('subdir.my_table')

        For a snapshot in the top-level directory:

        >>> table = pxt.get_table('my_snapshot')
    """
    tbl_path = catalog.Path(path)
    path_dict = Catalog.get().paths
    path_dict.check_is_valid(tbl_path, expected=catalog.Table)
    return path_dict[tbl_path]
188
+
189
+
190
def move(path: str, new_path: str) -> None:
    """Move a schema object to a new directory and/or rename it.

    Args:
        path: absolute path to the existing schema object.
        new_path: absolute new path for the schema object.

    Raises:
        Error: If path does not exist or new_path already exists.

    Examples:
        Move a table to a different directory:

        >>> pxt.move('dir1.my_table', 'dir2.my_table')

        Rename a table:

        >>> pxt.move('dir1.my_table', 'dir1.new_name')
    """
    src = catalog.Path(path)
    path_dict = Catalog.get().paths
    path_dict.check_is_valid(src, expected=catalog.SchemaObject)
    dest = catalog.Path(new_path)
    path_dict.check_is_valid(dest, expected=None)

    schema_obj = path_dict[src]
    # update the in-memory path mapping first, then the object itself (name + new parent directory id)
    path_dict.move(src, dest)
    dest_dir = path_dict[dest.parent]
    schema_obj.move(dest.name, dest_dir._id)
217
+
218
+
219
def drop_table(path: str, force: bool = False, ignore_errors: bool = False) -> None:
    """Drop a table (or view or snapshot).

    Args:
        path: Path to the table.
        force: currently has no effect; a table with dependents is never dropped.
        ignore_errors: if True, log and return instead of raising if the path does not designate a table.

    Raises:
        Error: If the path does not exist or does not designate a table and `ignore_errors` is False, or if the
            table has dependents (as recorded in `Catalog.tbl_dependents`).

    Examples:
        >>> pxt.drop_table('my_table')
    """
    path_obj = catalog.Path(path)
    try:
        Catalog.get().paths.check_is_valid(path_obj, expected=catalog.Table)
    except excs.Error:
        if ignore_errors:
            _logger.info(f'Skipped table `{path}` (does not exist).')
            return
        raise
    tbl = Catalog.get().paths[path_obj]
    # look up the dependents once and reuse them for both the check and the error message
    dependents = Catalog.get().tbl_dependents[tbl._id]
    if len(dependents) > 0:
        dependent_paths = [get_path(dep) for dep in dependents]
        raise excs.Error(f'Table {path} has dependents: {", ".join(dependent_paths)}')
    tbl._drop()
    del Catalog.get().paths[path_obj]
    _logger.info(f'Dropped table `{path}`.')
249
+
250
+
251
def list_tables(dir_path: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the tables contained in a directory.

    Args:
        dir_path: Path to the directory. Defaults to the root directory.
        recursive: Whether to also include tables in subdirectories.

    Returns:
        A list of table paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        List tables in the top-level directory:

        >>> pxt.list_tables()
        ['my_table', ...]

        List tables in 'dir1':

        >>> pxt.list_tables('dir1')
        [...]
    """
    assert dir_path is not None
    root = catalog.Path(dir_path, empty_is_valid=True)
    path_dict = Catalog.get().paths
    path_dict.check_is_valid(root, expected=catalog.Dir)
    children = path_dict.get_children(root, child_type=catalog.Table, recursive=recursive)
    return list(map(str, children))
279
+
280
+
281
def create_dir(path_str: str, ignore_errors: bool = False) -> None:
    """Create a directory.

    Args:
        path_str: Path to the directory.
        ignore_errors: if True, return silently instead of raising an `Error`.

    Raises:
        Error: If the path already exists or the parent is not a directory (unless `ignore_errors` is True).

    Examples:
        >>> pxt.create_dir('my_dir')

        Create a subdirectory:

        >>> pxt.create_dir('my_dir.sub_dir')
    """
    try:
        path = catalog.Path(path_str)
        Catalog.get().paths.check_is_valid(path, expected=None)
        parent = Catalog.get().paths[path.parent]
        assert parent is not None
        with orm.Session(Env.get().engine, future=True) as session:
            dir_md = schema.DirMd(name=path.name)
            dir_record = schema.Dir(parent_id=parent._id, md=dataclasses.asdict(dir_md))
            session.add(dir_record)
            session.flush()
            # capture the id before committing; the catalog entry is created from it below
            dir_id = dir_record.id
            assert dir_id is not None
            session.commit()
        # register the directory in the in-memory catalog only after the record has been committed, so that a
        # failed commit does not leave the catalog referring to a directory that doesn't exist in the store
        Catalog.get().paths[path] = catalog.Dir(dir_id, parent._id, path.name)
        _logger.info(f'Created directory `{path_str}`.')
        print(f'Created directory `{path_str}`.')
    except excs.Error:
        if ignore_errors:
            return
        raise
318
+
319
+
320
def rm_dir(path_str: str) -> None:
    """Remove an empty directory.

    Args:
        path_str: Path to the directory.

    Raises:
        Error: If the path does not exist or does not designate a directory or if the directory is not empty.

    Examples:
        >>> pxt.rm_dir('my_dir')

        Remove a subdirectory:

        >>> pxt.rm_dir('my_dir.sub_dir')
    """
    path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(path, expected=catalog.Dir)

    # only empty directories can be removed
    if len(Catalog.get().paths.get_children(path, child_type=None, recursive=True)) > 0:
        raise excs.Error(f'Directory {path_str} is not empty')
    # TODO: support a force option that drops contained tables and subdirectories; this needs to work in the
    # presence of snapshots

    dir_obj = Catalog.get().paths[path]
    with Env.get().engine.begin() as conn:
        conn.execute(sql.delete(schema.Dir.__table__).where(schema.Dir.id == dir_obj._id))
    # update the in-memory catalog only after the delete has been committed by engine.begin()
    del Catalog.get().paths[path]
    _logger.info(f'Removed directory {path_str}')
355
+
356
+
357
def list_dirs(path_str: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the subdirectories of a directory.

    Args:
        path_str: Path to the directory; defaults to the root directory.
        recursive: Whether to list subdirectories recursively.

    Returns:
        List of directory paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        >>> pxt.list_dirs('my_dir', recursive=True)
        ['my_dir', 'my_dir.sub_dir1']
    """
    root = catalog.Path(path_str, empty_is_valid=True)
    path_dict = Catalog.get().paths
    path_dict.check_is_valid(root, expected=catalog.Dir)
    subdirs = path_dict.get_children(root, child_type=catalog.Dir, recursive=recursive)
    return list(map(str, subdirs))
377
+
378
+
379
def list_functions() -> 'pd.io.formats.style.Styler':
    """Returns information about all registered functions.

    Returns:
        A pandas `Styler` wrapping a DataFrame with the columns 'Path', 'Function Name', 'Parameters' and
        'Return Type' (one row per registered function), with left-aligned cells, centered headers and
        the index hidden.
    """
    functions = func.FunctionRegistry.get().list_functions()
    pd_df = pd.DataFrame(
        {
            # the function's fully-qualified path without the trailing function name
            'Path': [f.self_path.rpartition('.')[0] for f in functions],
            'Function Name': [f.name for f in functions],
            'Parameters': [
                ', '.join(f'{param_name}: {param_type}' for param_name, param_type in f.signature.parameters.items())
                for f in functions
            ],
            'Return Type': [str(f.signature.get_return_type()) for f in functions],
        }
    )
    styler = pd_df.style.set_properties(**{'text-align': 'left'}).set_table_styles(
        [dict(selector='th', props=[('text-align', 'center')])]  # center-align headings
    )
    return styler.hide(axis='index')
404
+
405
+
406
def get_path(schema_obj: catalog.SchemaObject) -> str:
    """Return the dot-separated path of a SchemaObject.

    The root directory (the ancestor without a parent directory) is not part of the path.

    Args:
        schema_obj: SchemaObject to get the path for.

    Returns:
        Path to the SchemaObject.
    """
    # collect names from the object upwards, then reverse to get root-to-leaf order
    names: list[str] = [schema_obj._name]
    parent_id = schema_obj._dir_id
    while parent_id is not None:
        parent = Catalog.get().paths.get_schema_obj(parent_id)
        if parent._dir_id is None:
            # reached the root dir (name ''), which is excluded from the path
            break
        names.append(parent._name)
        parent_id = parent._dir_id
    return '.'.join(reversed(names))
pixeltable/index/base.py CHANGED
@@ -13,8 +13,10 @@ class IndexBase(abc.ABC):
13
13
  Internal interface used by the catalog and runtime system to interact with indices:
14
14
  - types and expressions needed to create and populate the index value column
15
15
  - creating/dropping the index
16
- - TODO: translating queries into sqlalchemy predicates
16
+ This doesn't cover querying the index, which is dependent on the index semantics and handled by
17
+ the specific subclass.
17
18
  """
19
+
18
20
  @abc.abstractmethod
19
21
  def __init__(self, c: catalog.Column, **kwargs: Any):
20
22
  pass
@@ -1,8 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
3
+ from typing import Optional, Any
4
+ import enum
4
5
 
6
+ import PIL.Image
7
+ import numpy as np
5
8
  import pgvector.sqlalchemy
9
+ import PIL.Image
6
10
  import sqlalchemy as sql
7
11
 
8
12
  import pixeltable.catalog as catalog
@@ -14,15 +18,31 @@ from .base import IndexBase
14
18
 
15
19
  class EmbeddingIndex(IndexBase):
16
20
  """
17
- Internal interface used by the catalog and runtime system to interact with (embedding) indices:
18
- - types and expressions needed to create and populate the index value column
19
- - creating/dropping the index
20
- - translating 'matches' queries into sqlalchemy predicates
21
+ Interface to the pgvector access method in Postgres.
22
+ - pgvector converts the cosine metric to '1 - metric' and the inner product metric to '-metric', in order to
23
+ satisfy the Postgres requirement that an index scan requires an ORDER BY ... ASC clause
24
+ - similarity_clause() converts those metrics back to their original form; it is used in expressions outside
25
+ the Order By clause
26
+ - order_by_clause() is used exclusively in the ORDER BY clause
21
27
  """
22
28
 
29
class Metric(enum.Enum):
    """Metrics supported by the pgvector-backed embedding index."""
    COSINE = enum.auto()
    IP = enum.auto()  # inner product
    L2 = enum.auto()

# pgvector operator class used for each metric when creating the HNSW index on the index value column
PGVECTOR_OPS = {
    Metric.COSINE: 'vector_cosine_ops',
    Metric.IP: 'vector_ip_ops',
    Metric.L2: 'vector_l2_ops',
}
39
+
23
40
  def __init__(
24
- self, c: catalog.Column, text_embed: Optional[func.Function] = None,
41
+ self, c: catalog.Column, metric: str, text_embed: Optional[func.Function] = None,
25
42
  img_embed: Optional[func.Function] = None):
43
+ metric_names = [m.name.lower() for m in self.Metric]
44
+ if metric.lower() not in metric_names:
45
+ raise excs.Error(f'Invalid metric {metric}, must be one of {metric_names}')
26
46
  if not c.col_type.is_string_type() and not c.col_type.is_image_type():
27
47
  raise excs.Error(f'Embedding index requires string or image column')
28
48
  if c.col_type.is_string_type() and text_embed is None:
@@ -36,6 +56,7 @@ class EmbeddingIndex(IndexBase):
36
56
  # verify signature
37
57
  self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
38
58
 
59
+ self.metric = self.Metric[metric.upper()]
39
60
  from pixeltable.exprs import ColumnRef
40
61
  self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
41
62
  assert self.value_expr.col_type.is_array_type()
@@ -59,10 +80,51 @@ class EmbeddingIndex(IndexBase):
59
80
  index_name, index_value_col.sa_col,
60
81
  postgresql_using='hnsw',
61
82
  postgresql_with={'m': 16, 'ef_construction': 64},
62
- postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
83
+ postgresql_ops={index_value_col.sa_col.name: self.PGVECTOR_OPS[self.metric]}
63
84
  )
64
85
  idx.create(bind=conn)
65
86
 
87
def similarity_clause(self, val_column: catalog.Column, item: Any) -> sql.ClauseElement:
    """Create a ClauseElement that represents the similarity between `val_column` and the embedding of `item`.

    pgvector reports cosine as '1 - similarity' and inner product as '-similarity'; both are converted back so
    that larger values mean 'more similar'. For L2 the raw distance is returned (smaller means more similar).

    Args:
        val_column: the index value column holding the stored embeddings.
        item: a string (embedded with `txt_embed`) or a PIL image (embedded with `img_embed`).

    Raises:
        Error: if `item` is neither a string nor an image, or if the index has no embedding function for it.
    """
    # explicit checks instead of asserts: `embedding` must never be left unbound
    if isinstance(item, str):
        if self.txt_embed is None:
            raise excs.Error('This embedding index has no text embedding function; it cannot be queried with a string')
        embedding = self.txt_embed.exec(item)
    elif isinstance(item, PIL.Image.Image):
        if self.img_embed is None:
            raise excs.Error('This embedding index has no image embedding function; it cannot be queried with an image')
        embedding = self.img_embed.exec(item)
    else:
        raise excs.Error(f'Similarity queries require a string or an image, not {type(item).__name__}')

    sa_col = val_column.sa_col
    if self.metric == self.Metric.COSINE:
        # cosine_distance is 1 - cosine similarity
        return sa_col.cosine_distance(embedding) * -1 + 1
    if self.metric == self.Metric.IP:
        # max_inner_product is the negated inner product
        return sa_col.max_inner_product(embedding) * -1
    assert self.metric == self.Metric.L2
    return sa_col.l2_distance(embedding)
105
def order_by_clause(self, val_column: catalog.Column, item: Any, is_asc: bool) -> sql.ClauseElement:
    """Create the ClauseElement used in an ORDER BY clause for a similarity query against `item`.

    Args:
        val_column: the index value column holding the stored embeddings.
        item: a string (embedded with `txt_embed`) or a PIL image (embedded with `img_embed`).
        is_asc: for cosine and inner product, True orders by ascending similarity (descending pgvector distance).
    """
    assert isinstance(item, (str, PIL.Image.Image))
    embedding: Optional[np.ndarray] = None
    if isinstance(item, str):
        assert self.txt_embed is not None
        embedding = self.txt_embed.exec(item)
    elif isinstance(item, PIL.Image.Image):
        assert self.img_embed is not None
        embedding = self.img_embed.exec(item)
    assert embedding is not None

    sa_col = val_column.sa_col
    if self.metric == self.Metric.COSINE:
        distance = sa_col.cosine_distance(embedding)
    elif self.metric == self.Metric.IP:
        distance = sa_col.max_inner_product(embedding)
    else:
        assert self.metric == self.Metric.L2
        # NOTE(review): is_asc is not applied for L2, unlike for the other metrics -- confirm this is intended
        return sa_col.l2_distance(embedding)
    # these pgvector values shrink as similarity grows, so ascending similarity means descending distance
    return distance.desc() if is_asc else distance
127
+
66
128
  @classmethod
67
129
  def display_name(cls) -> str:
68
130
  return 'embedding'
@@ -72,18 +134,29 @@ class EmbeddingIndex(IndexBase):
72
134
  """Validate the signature"""
73
135
  assert isinstance(embed_fn, func.Function)
74
136
  sig = embed_fn.signature
75
- if not sig.return_type.is_array_type():
76
- raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
77
- else:
78
- shape = sig.return_type.shape
79
- if len(shape) != 1 or shape[0] == None:
80
- raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
81
137
  if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
82
138
  raise excs.Error(
83
139
  f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
84
140
 
141
+ # validate return type
142
+ param_name = sig.parameters_by_pos[0].name
143
+ if expected_type == ts.ColumnType.Type.STRING:
144
+ return_type = embed_fn.call_return_type({param_name: 'dummy'})
145
+ else:
146
+ assert expected_type == ts.ColumnType.Type.IMAGE
147
+ img = PIL.Image.new('RGB', (512, 512))
148
+ return_type = embed_fn.call_return_type({param_name: img})
149
+ assert return_type is not None
150
+ if not return_type.is_array_type():
151
+ raise excs.Error(f'{name} must return an array, but returns {return_type}')
152
+ else:
153
+ shape = return_type.shape
154
+ if len(shape) != 1 or shape[0] == None:
155
+ raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {return_type}')
156
+
85
157
  def as_dict(self) -> dict:
86
158
  return {
159
+ 'metric': self.metric.name.lower(),
87
160
  'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
88
161
  'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
89
162
  }
@@ -92,4 +165,4 @@ class EmbeddingIndex(IndexBase):
92
165
  def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
93
166
  txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
94
167
  img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
95
- return cls(c, text_embed=txt_embed, img_embed=img_embed)
168
+ return cls(c, metric=d['metric'], text_embed=txt_embed, img_embed=img_embed)
@@ -0,0 +1,3 @@
1
+ from .hf_datasets import import_huggingface_dataset
2
+ from .pandas import import_csv, import_excel, import_pandas
3
+ from .parquet import import_parquet