pixeltable 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (119) hide show
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/globals.py ADDED
@@ -0,0 +1,425 @@
1
+ import dataclasses
2
+ import logging
3
+ from typing import Any, Optional, Union, Type
4
+
5
+ import pandas as pd
6
+ import sqlalchemy as sql
7
+ from sqlalchemy.util.preloaded import orm
8
+
9
+ import pixeltable.exceptions as excs
10
+ from pixeltable import catalog, func
11
+ from pixeltable.catalog import Catalog
12
+ from pixeltable.env import Env
13
+ from pixeltable.exprs import Predicate
14
+ from pixeltable.iterators import ComponentIterator
15
+ from pixeltable.metadata import schema
16
+
17
+ _logger = logging.getLogger('pixeltable')
18
+
19
+
20
def init() -> None:
    """Initialize the Pixeltable environment.

    Calls `Catalog.get()` and discards the returned instance.
    """
    Catalog.get()
23
+
24
+
25
def create_table(
    path_str: str,
    schema: dict[str, Any],
    *,
    primary_key: Optional[Union[str, list[str]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
) -> catalog.InsertableTable:
    """Create a new `InsertableTable` and register it under `path_str`.

    Args:
        path_str: Path to the table.
        schema: dictionary mapping column names to column types, value expressions, or to column specifications.
        primary_key: A single column name or a list of column names; `None` means no primary key columns.
        num_retained_versions: Number of versions of the table to retain.
        comment: Comment that is passed through to the new table.

    Returns:
        The newly created table.

    Raises:
        Error: if the path already exists or is invalid, if `schema` is empty, or if `primary_key` is neither
            a string nor a list of strings.

    Examples:
        Create a table with an int and a string column:

        >>> table = cl.create_table('my_table', schema={'col1': IntType(), 'col2': StringType()})
    """
    tbl_path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(tbl_path, expected=None)
    parent_dir = Catalog.get().paths[tbl_path.parent]

    if len(schema) == 0:
        raise excs.Error(f'Table schema is empty: `{path_str}`')

    # normalize primary_key into a list of column names
    if primary_key is None:
        pk_cols = []
    elif isinstance(primary_key, str):
        pk_cols = [primary_key]
    elif isinstance(primary_key, list) and all(isinstance(pk, str) for pk in primary_key):
        pk_cols = primary_key
    else:
        raise excs.Error('primary_key must be a single column name or a list of column names')

    new_tbl = catalog.InsertableTable.create(
        parent_dir._id,
        tbl_path.name,
        schema,
        primary_key=pk_cols,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    Catalog.get().paths[tbl_path] = new_tbl
    _logger.info(f'Created table `{path_str}`.')
    return new_tbl
77
+
78
+
79
def create_view(
    path_str: str,
    base: catalog.Table,
    *,
    schema: Optional[dict[str, Any]] = None,
    filter: Optional[Predicate] = None,
    is_snapshot: bool = False,
    iterator: Optional[tuple[type[ComponentIterator], dict[str, Any]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
    ignore_errors: bool = False,
) -> Optional[catalog.View]:
    """Create a new `View`.

    Args:
        path_str: Path to the view.
        base: Table (ie, table or view or snapshot) to base the view on.
        schema: dictionary mapping column names to column types, value expressions, or to column specifications.
            `None` is treated as an empty schema.
        filter: Predicate to filter rows of the base table.
        is_snapshot: Whether the view is a snapshot.
        iterator: Optional `(iterator_class, iterator_args)` tuple: the `ComponentIterator` subclass to use for
            the view and the dictionary of arguments to pass to it.
        num_retained_versions: Number of versions of the view to retain.
        comment: Comment that is passed through to the new view.
        ignore_errors: if True, fail silently if the path already exists or is invalid.

    Returns:
        The newly created view, or `None` if the path check failed and `ignore_errors` is True.

    Raises:
        Error: if the path already exists or is invalid (and `ignore_errors` is False).

    Examples:
        Create a view with an additional int and a string column and a filter:

        >>> view = cl.create_view(
            'my_view', base, schema={'col3': IntType(), 'col4': StringType()}, filter=base.col1 > 10)

        Create a table snapshot:

        >>> snapshot_view = cl.create_view('my_snapshot_view', base, is_snapshot=True)

        Create an immutable view with additional computed columns and a filter:

        >>> snapshot_view = cl.create_view(
            'my_snapshot', base, schema={'col3': base.col2 + 1}, filter=base.col1 > 10, is_snapshot=True)
    """
    assert isinstance(base, catalog.Table)
    path = catalog.Path(path_str)
    try:
        Catalog.get().paths.check_is_valid(path, expected=None)
    except Exception:
        if ignore_errors:
            # only the path check is covered by ignore_errors; failures during creation still propagate
            return None
        raise
    parent_dir = Catalog.get().paths[path.parent]

    if schema is None:
        schema = {}
    if iterator is None:
        iterator_class, iterator_args = None, None
    else:
        iterator_class, iterator_args = iterator
    view = catalog.View.create(
        parent_dir._id,
        path.name,
        base=base,
        schema=schema,
        predicate=filter,
        is_snapshot=is_snapshot,
        iterator_cls=iterator_class,
        iterator_args=iterator_args,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    Catalog.get().paths[path] = view
    _logger.info(f'Created view `{path_str}`.')
    return view
157
+
158
+
159
def get_table(path: str) -> catalog.Table:
    """Look up a table (including views and snapshots) by path and return its handle.

    Args:
        path: Path to the table.

    Returns:
        A `InsertableTable` or `View` object.

    Raises:
        Error: If the path does not exist or does not designate a table.

    Examples:
        Get handle for a table in the top-level directory:

        >>> table = cl.get_table('my_table')

        For a table in a subdirectory:

        >>> table = cl.get_table('subdir.my_table')

        For a snapshot in the top-level directory:

        >>> table = cl.get_table('my_snapshot')
    """
    tbl_path = catalog.Path(path)
    # validate first so that a missing or non-table path is reported by check_is_valid()
    Catalog.get().paths.check_is_valid(tbl_path, expected=catalog.Table)
    return Catalog.get().paths[tbl_path]
188
+
189
+
190
def move(path: str, new_path: str) -> None:
    """Move a schema object to a new directory and/or rename it.

    Args:
        path: absolute path to the existing schema object.
        new_path: absolute new path for the schema object.

    Raises:
        Error: If path does not exist or new_path already exists.

    Examples:
        Move a table to a different directory:

        >>> cl.move('dir1.my_table', 'dir2.my_table')

        Rename a table:

        >>> cl.move('dir1.my_table', 'dir1.new_name')
    """
    src_path = catalog.Path(path)
    Catalog.get().paths.check_is_valid(src_path, expected=catalog.SchemaObject)
    dest_path = catalog.Path(new_path)
    Catalog.get().paths.check_is_valid(dest_path, expected=None)

    schema_obj = Catalog.get().paths[src_path]
    # update the path mapping first, then tell the object about its new name and parent directory
    Catalog.get().paths.move(src_path, dest_path)
    dest_dir = Catalog.get().paths[dest_path.parent]
    schema_obj.move(dest_path.name, dest_dir._id)
217
+
218
+
219
def drop_table(path: str, force: bool = False, ignore_errors: bool = False) -> None:
    """Drop a table.

    Args:
        path: Path to the table.
        force: Whether to drop the table even if it has unsaved changes. This argument is not read by the
            current implementation.
        ignore_errors: Whether to ignore errors if the table does not exist.

    Raises:
        Error: If the path does not exist or does not designate a table and ignore_errors is False, or if
            other tables depend on this table.

    Examples:
        >>> cl.drop_table('my_table')
    """
    tbl_path = catalog.Path(path)
    try:
        Catalog.get().paths.check_is_valid(tbl_path, expected=catalog.Table)
    except Exception as e:
        if not ignore_errors:
            raise e
        _logger.info(f'Skipped table `{path}` (does not exist).')
        return

    tbl = Catalog.get().paths[tbl_path]
    # refuse to drop a table that still has dependents
    dependents = Catalog.get().tbl_dependents[tbl._id]
    if len(dependents) > 0:
        dependent_paths = [get_path(dep) for dep in dependents]
        raise excs.Error(f'Table {path} has dependents: {", ".join(dependent_paths)}')

    tbl._drop()
    del Catalog.get().paths[tbl_path]
    _logger.info(f'Dropped table `{path}`.')
249
+
250
+
251
def list_tables(dir_path: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the tables contained in a directory.

    Args:
        dir_path: Path to the directory. Defaults to the root directory.
        recursive: Whether to list tables in subdirectories as well.

    Returns:
        A list of table paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        List tables in top-level directory:

        >>> cl.list_tables()
        ['my_table', ...]

        List tables in 'dir1':

        >>> cl.list_tables('dir1')
        [...]
    """
    assert dir_path is not None
    # the empty string designates the root directory, so it has to be accepted as a valid path
    parent = catalog.Path(dir_path, empty_is_valid=True)
    Catalog.get().paths.check_is_valid(parent, expected=catalog.Dir)
    table_paths = Catalog.get().paths.get_children(parent, child_type=catalog.Table, recursive=recursive)
    return [str(tbl_path) for tbl_path in table_paths]
279
+
280
+
281
def create_dir(path_str: str, ignore_errors: bool = False) -> None:
    """Create a directory.

    The directory record is written and committed to the metadata store first; only then is the directory
    registered in the in-memory catalog, so that a failed commit cannot leave behind a catalog entry without
    a persisted record.

    Args:
        path_str: Path to the directory.
        ignore_errors: if True, silently returns on error

    Raises:
        Error: If the path already exists or the parent is not a directory.

    Examples:
        >>> cl.create_dir('my_dir')

        Create a subdirectory:

        >>> cl.create_dir('my_dir.sub_dir')
    """
    try:
        path = catalog.Path(path_str)
        Catalog.get().paths.check_is_valid(path, expected=None)
        parent = Catalog.get().paths[path.parent]
        assert parent is not None
        with orm.Session(Env.get().engine, future=True) as session:
            dir_md = schema.DirMd(name=path.name)
            dir_record = schema.Dir(parent_id=parent._id, md=dataclasses.asdict(dir_md))
            session.add(dir_record)
            # flush so that the database assigns the record's id
            session.flush()
            assert dir_record.id is not None
            # capture the id before commit, so it doesn't need to be re-read from the record afterwards
            dir_id = dir_record.id
            session.commit()
        # register the directory only after the record has been committed
        Catalog.get().paths[path] = catalog.Dir(dir_id, parent._id, path.name)
        _logger.info(f'Created directory `{path_str}`.')
        print(f'Created directory `{path_str}`.')
    except excs.Error:
        if ignore_errors:
            return
        raise
318
+
319
+
320
def rm_dir(path_str: str) -> None:
    """Remove an empty directory.

    Args:
        path_str: Path to the directory.

    Raises:
        Error: If the path does not exist or does not designate a directory or if the directory is not empty.

    Examples:
        >>> cl.rm_dir('my_dir')

        Remove a subdirectory:

        >>> cl.rm_dir('my_dir.sub_dir')
    """
    dir_path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(dir_path, expected=catalog.Dir)

    # only empty directories can be removed
    children = Catalog.get().paths.get_children(dir_path, child_type=None, recursive=True)
    if len(children) > 0:
        raise excs.Error(f'Directory {path_str} is not empty')
    # TODO: figure out how to make force=True work in the presence of snapshots
    # (this would first drop the contained tables and then remove the subdirectories)

    with Env.get().engine.begin() as conn:
        dir_obj = Catalog.get().paths[dir_path]
        delete_stmt = sql.delete(schema.Dir.__table__).where(schema.Dir.id == dir_obj._id)
        conn.execute(delete_stmt)
    del Catalog.get().paths[dir_path]
    _logger.info(f'Removed directory {path_str}')
355
+
356
+
357
def list_dirs(path_str: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the directories contained in a directory.

    Args:
        path_str: Path to the directory.
        recursive: Whether to list subdirectories recursively.

    Returns:
        List of directory paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        >>> cl.list_dirs('my_dir', recursive=True)
        ['my_dir', 'my_dir.sub_dir1']
    """
    # the empty string is accepted as a valid path here (it is the default)
    parent = catalog.Path(path_str, empty_is_valid=True)
    Catalog.get().paths.check_is_valid(parent, expected=catalog.Dir)
    subdir_paths = Catalog.get().paths.get_children(parent, child_type=catalog.Dir, recursive=recursive)
    return [str(subdir_path) for subdir_path in subdir_paths]
377
+
378
+
379
def list_functions() -> 'pd.io.formats.style.Styler':
    """Returns information about all registered functions.

    Returns:
        A pandas `Styler` (not a plain DataFrame) wrapping a DataFrame with the columns 'Path',
        'Function Name', 'Parameters' and 'Return Type'. Cell text is left-aligned, headings are
        center-aligned and the index is hidden.
    """
    functions = func.FunctionRegistry.get().list_functions()
    # the path is the function's self_path without its last component
    paths = ['.'.join(f.self_path.split('.')[:-1]) for f in functions]
    names = [f.name for f in functions]
    params = [
        ', '.join(f'{param_name}: {param_type}' for param_name, param_type in f.signature.parameters.items())
        for f in functions
    ]
    return_types = [str(f.signature.get_return_type()) for f in functions]
    pd_df = pd.DataFrame(
        {
            'Path': paths,
            'Function Name': names,
            'Parameters': params,
            'Return Type': return_types,
        }
    )
    styler = pd_df.style.set_properties(**{'text-align': 'left'}).set_table_styles(
        [dict(selector='th', props=[('text-align', 'center')])]
    )  # center-align headings
    return styler.hide(axis='index')
404
+
405
+
406
def get_path(schema_obj: catalog.SchemaObject) -> str:
    """Build the dotted path of a SchemaObject from the names of its ancestor directories.

    Args:
        schema_obj: SchemaObject to get the path for.

    Returns:
        Path to the SchemaObject.
    """
    # collect the directory names from the innermost directory outwards
    ancestor_names: list[str] = []
    dir_id = schema_obj._dir_id
    while dir_id is not None:
        parent = Catalog.get().paths.get_schema_obj(dir_id)
        if parent._dir_id is None:
            # this is the root dir with name '', which we don't want to include in the path
            break
        ancestor_names.append(parent._name)
        dir_id = parent._dir_id
    return '.'.join([*reversed(ancestor_names), schema_obj._name])
@@ -0,0 +1,2 @@
1
+ from .base import IndexBase
2
+ from .embedding_index import EmbeddingIndex
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import Any
5
+
6
+ import sqlalchemy as sql
7
+
8
+ import pixeltable.catalog as catalog
9
+
10
+
11
class IndexBase(abc.ABC):
    """
    Internal interface used by the catalog and runtime system to interact with indices:
    - types and expressions needed to create and populate the index value column
    - creating/dropping the index
    This doesn't cover querying the index, which is dependent on the index semantics and handled by
    the specific subclass.
    """

    @abc.abstractmethod
    def __init__(self, c: catalog.Column, **kwargs: Any):
        """Set up the index for column `c`; `kwargs` are specific to the subclass"""

    @abc.abstractmethod
    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
        """Return expression that computes the value that goes into the index"""

    @abc.abstractmethod
    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
        """Return the sqlalchemy type of the index value column"""

    @abc.abstractmethod
    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
        """Create the index on the index value column"""

    @classmethod
    @abc.abstractmethod
    def display_name(cls) -> str:
        """Return the display name of this index class"""

    @abc.abstractmethod
    def as_dict(self) -> dict:
        """Return a dict representation of this index; counterpart of from_dict()"""

    @classmethod
    @abc.abstractmethod
    def from_dict(cls, c: catalog.Column, d: dict) -> IndexBase:
        """Create an index for column `c` from a dict; counterpart of as_dict()"""
@@ -0,0 +1,168 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Any
4
+ import enum
5
+
6
+ import PIL.Image
7
+ import numpy as np
8
+ import pgvector.sqlalchemy
9
+ import PIL.Image
10
+ import sqlalchemy as sql
11
+
12
+ import pixeltable.catalog as catalog
13
+ import pixeltable.exceptions as excs
14
+ import pixeltable.func as func
15
+ import pixeltable.type_system as ts
16
+ from .base import IndexBase
17
+
18
+
19
class EmbeddingIndex(IndexBase):
    """
    Interface to the pgvector access method in Postgres.
    - pgvector converts the cosine metric to '1 - metric' and the inner product metric to '-metric', in order to
      satisfy the Postgres requirement that an index scan requires an ORDER BY ... ASC clause
    - similarity_clause() converts those metrics back to their original form; it is used in expressions outside
      the Order By clause
    - order_by_clause() is used exclusively in the ORDER BY clause
    """

    class Metric(enum.Enum):
        COSINE = 1
        IP = 2
        L2 = 3

    # pgvector operator class to use for each metric when creating the index
    PGVECTOR_OPS = {
        Metric.COSINE: 'vector_cosine_ops',
        Metric.IP: 'vector_ip_ops',
        Metric.L2: 'vector_l2_ops',
    }

    def __init__(
            self, c: catalog.Column, metric: str, text_embed: Optional[func.Function] = None,
            img_embed: Optional[func.Function] = None):
        """
        Args:
            c: the column to index; must be a string or image column
            metric: name of a Metric member (case-insensitive)
            text_embed: embedding function for strings; required if `c` is a string column
            img_embed: embedding function for images; required if `c` is an image column

        Raises:
            Error: if the metric is invalid, the column type is unsupported, the embedding function required
                for the column type is missing, or an embedding function has an unsuitable signature
        """
        metric_names = [m.name.lower() for m in self.Metric]
        if metric.lower() not in metric_names:
            raise excs.Error(f'Invalid metric {metric}, must be one of {metric_names}')
        if not c.col_type.is_string_type() and not c.col_type.is_image_type():
            raise excs.Error('Embedding index requires string or image column')
        if c.col_type.is_string_type() and text_embed is None:
            raise excs.Error(f'Text embedding function is required for column {c.name} (parameter `text_embed`)')
        if c.col_type.is_image_type() and img_embed is None:
            raise excs.Error(f'Image embedding function is required for column {c.name} (parameter `img_embed`)')
        if text_embed is not None:
            # verify signature
            self._validate_embedding_fn(text_embed, 'text_embed', ts.ColumnType.Type.STRING)
        if img_embed is not None:
            # verify signature
            self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)

        self.metric = self.Metric[metric.upper()]
        from pixeltable.exprs import ColumnRef
        self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
        assert self.value_expr.col_type.is_array_type()
        self.txt_embed = text_embed
        self.img_embed = img_embed
        # the length of the embedding determines the size of the pgvector column
        vector_size = self.value_expr.col_type.shape[0]
        assert vector_size is not None
        self.index_col_type = pgvector.sqlalchemy.Vector(vector_size)

    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
        """Return expression that computes the value that goes into the index"""
        return self.value_expr

    def index_sa_type(self) -> sql.sqltypes.TypeEngine:
        """Return the sqlalchemy type of the index value column"""
        return self.index_col_type

    def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
        """Create the index on the index value column"""
        idx = sql.Index(
            index_name, index_value_col.sa_col,
            postgresql_using='hnsw',
            postgresql_with={'m': 16, 'ef_construction': 64},
            postgresql_ops={index_value_col.sa_col.name: self.PGVECTOR_OPS[self.metric]}
        )
        idx.create(bind=conn)

    def _embed_item(self, item: Any) -> np.ndarray:
        """Compute the embedding of a string or image with the matching embedding function"""
        assert isinstance(item, (str, PIL.Image.Image))
        if isinstance(item, str):
            assert self.txt_embed is not None
            embedding = self.txt_embed.exec(item)
        else:
            assert self.img_embed is not None
            embedding = self.img_embed.exec(item)
        assert embedding is not None
        return embedding

    def similarity_clause(self, val_column: catalog.Column, item: Any) -> sql.ClauseElement:
        """Create a ClauseElement that represents '<val_column> <op> <item>'"""
        embedding = self._embed_item(item)
        if self.metric == self.Metric.COSINE:
            # undo pgvector's '1 - metric' conversion
            return val_column.sa_col.cosine_distance(embedding) * -1 + 1
        if self.metric == self.Metric.IP:
            # undo pgvector's '-metric' conversion
            return val_column.sa_col.max_inner_product(embedding) * -1
        assert self.metric == self.Metric.L2
        return val_column.sa_col.l2_distance(embedding)

    def order_by_clause(self, val_column: catalog.Column, item: Any, is_asc: bool) -> sql.ClauseElement:
        """Create a ClauseElement that is used in an ORDER BY clause"""
        embedding = self._embed_item(item)
        if self.metric == self.Metric.COSINE:
            # pgvector returns '1 - metric': ascending order of the metric is descending order of this value
            result = val_column.sa_col.cosine_distance(embedding)
            return result.desc() if is_asc else result
        if self.metric == self.Metric.IP:
            # pgvector returns '-metric': ascending order of the metric is descending order of this value
            result = val_column.sa_col.max_inner_product(embedding)
            return result.desc() if is_asc else result
        assert self.metric == self.Metric.L2
        # NOTE(review): is_asc is not applied for L2, the clause is always in ascending order of distance;
        # confirm that this is intended
        return val_column.sa_col.l2_distance(embedding)

    @classmethod
    def display_name(cls) -> str:
        return 'embedding'

    @classmethod
    def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> None:
        """Validate the signature

        The function needs to take a single parameter of `expected_type` and return a 1D array of a specific
        length. `name` is the parameter name that is reported in error messages.
        """
        assert isinstance(embed_fn, func.Function)
        sig = embed_fn.signature
        if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
            raise excs.Error(
                f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')

        # validate return type: determine it for a dummy argument of the expected type
        param_name = sig.parameters_by_pos[0].name
        if expected_type == ts.ColumnType.Type.STRING:
            return_type = embed_fn.call_return_type({param_name: 'dummy'})
        else:
            assert expected_type == ts.ColumnType.Type.IMAGE
            img = PIL.Image.new('RGB', (512, 512))
            return_type = embed_fn.call_return_type({param_name: img})
        assert return_type is not None
        if not return_type.is_array_type():
            raise excs.Error(f'{name} must return an array, but returns {return_type}')
        shape = return_type.shape
        if len(shape) != 1 or shape[0] is None:
            raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {return_type}')

    def as_dict(self) -> dict:
        """Return the metric name and the dicts of the embedding functions; counterpart of from_dict()"""
        return {
            'metric': self.metric.name.lower(),
            'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
            'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
        }

    @classmethod
    def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
        """Create an index for column `c` from a dict with the keys written by as_dict()"""
        txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
        img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
        return cls(c, metric=d['metric'], text_embed=txt_embed, img_embed=img_embed)
@@ -0,0 +1,3 @@
1
+ from .hf_datasets import import_huggingface_dataset
2
+ from .pandas import import_csv, import_excel, import_pandas
3
+ from .parquet import import_parquet