pixeltable 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (110) hide show
  1. pixeltable/__init__.py +20 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +23 -7
  4. pixeltable/catalog/insertable_table.py +32 -19
  5. pixeltable/catalog/table.py +210 -20
  6. pixeltable/catalog/table_version.py +272 -111
  7. pixeltable/catalog/table_version_path.py +6 -1
  8. pixeltable/dataframe.py +184 -110
  9. pixeltable/datatransfer/__init__.py +1 -0
  10. pixeltable/datatransfer/label_studio.py +526 -0
  11. pixeltable/datatransfer/remote.py +113 -0
  12. pixeltable/env.py +213 -79
  13. pixeltable/exec/__init__.py +2 -1
  14. pixeltable/exec/data_row_batch.py +6 -7
  15. pixeltable/exec/expr_eval_node.py +28 -28
  16. pixeltable/exec/sql_scan_node.py +7 -6
  17. pixeltable/exprs/__init__.py +4 -3
  18. pixeltable/exprs/column_ref.py +11 -2
  19. pixeltable/exprs/comparison.py +39 -1
  20. pixeltable/exprs/data_row.py +7 -0
  21. pixeltable/exprs/expr.py +26 -19
  22. pixeltable/exprs/function_call.py +17 -18
  23. pixeltable/exprs/globals.py +14 -2
  24. pixeltable/exprs/image_member_access.py +9 -28
  25. pixeltable/exprs/in_predicate.py +96 -0
  26. pixeltable/exprs/inline_array.py +13 -11
  27. pixeltable/exprs/inline_dict.py +15 -13
  28. pixeltable/exprs/row_builder.py +7 -1
  29. pixeltable/exprs/similarity_expr.py +67 -0
  30. pixeltable/ext/functions/whisperx.py +30 -0
  31. pixeltable/ext/functions/yolox.py +16 -0
  32. pixeltable/func/__init__.py +0 -2
  33. pixeltable/func/aggregate_function.py +5 -2
  34. pixeltable/func/callable_function.py +57 -13
  35. pixeltable/func/expr_template_function.py +14 -3
  36. pixeltable/func/function.py +35 -4
  37. pixeltable/func/signature.py +5 -15
  38. pixeltable/func/udf.py +8 -12
  39. pixeltable/functions/fireworks.py +9 -4
  40. pixeltable/functions/huggingface.py +48 -5
  41. pixeltable/functions/openai.py +49 -11
  42. pixeltable/functions/pil/image.py +61 -64
  43. pixeltable/functions/together.py +32 -6
  44. pixeltable/functions/util.py +0 -43
  45. pixeltable/functions/video.py +46 -8
  46. pixeltable/globals.py +443 -0
  47. pixeltable/index/__init__.py +1 -0
  48. pixeltable/index/base.py +9 -2
  49. pixeltable/index/btree.py +54 -0
  50. pixeltable/index/embedding_index.py +91 -15
  51. pixeltable/io/__init__.py +4 -0
  52. pixeltable/io/globals.py +59 -0
  53. pixeltable/{utils → io}/hf_datasets.py +48 -17
  54. pixeltable/io/pandas.py +148 -0
  55. pixeltable/{utils → io}/parquet.py +58 -33
  56. pixeltable/iterators/__init__.py +1 -1
  57. pixeltable/iterators/base.py +8 -4
  58. pixeltable/iterators/document.py +225 -93
  59. pixeltable/iterators/video.py +16 -9
  60. pixeltable/metadata/__init__.py +8 -4
  61. pixeltable/metadata/converters/convert_12.py +3 -0
  62. pixeltable/metadata/converters/convert_13.py +41 -0
  63. pixeltable/metadata/converters/convert_14.py +13 -0
  64. pixeltable/metadata/converters/convert_15.py +29 -0
  65. pixeltable/metadata/converters/util.py +63 -0
  66. pixeltable/metadata/schema.py +12 -6
  67. pixeltable/plan.py +11 -24
  68. pixeltable/store.py +16 -23
  69. pixeltable/tool/create_test_db_dump.py +49 -14
  70. pixeltable/type_system.py +27 -58
  71. pixeltable/utils/coco.py +94 -0
  72. pixeltable/utils/documents.py +42 -12
  73. pixeltable/utils/http_server.py +70 -0
  74. pixeltable-0.2.7.dist-info/METADATA +137 -0
  75. pixeltable-0.2.7.dist-info/RECORD +126 -0
  76. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/WHEEL +1 -1
  77. pixeltable/client.py +0 -600
  78. pixeltable/exprs/image_similarity_predicate.py +0 -58
  79. pixeltable/func/batched_function.py +0 -53
  80. pixeltable/func/nos_function.py +0 -202
  81. pixeltable/tests/conftest.py +0 -171
  82. pixeltable/tests/ext/test_yolox.py +0 -21
  83. pixeltable/tests/functions/test_fireworks.py +0 -43
  84. pixeltable/tests/functions/test_functions.py +0 -60
  85. pixeltable/tests/functions/test_huggingface.py +0 -158
  86. pixeltable/tests/functions/test_openai.py +0 -162
  87. pixeltable/tests/functions/test_together.py +0 -112
  88. pixeltable/tests/test_audio.py +0 -65
  89. pixeltable/tests/test_catalog.py +0 -27
  90. pixeltable/tests/test_client.py +0 -21
  91. pixeltable/tests/test_component_view.py +0 -379
  92. pixeltable/tests/test_dataframe.py +0 -440
  93. pixeltable/tests/test_dirs.py +0 -107
  94. pixeltable/tests/test_document.py +0 -120
  95. pixeltable/tests/test_exprs.py +0 -802
  96. pixeltable/tests/test_function.py +0 -332
  97. pixeltable/tests/test_index.py +0 -138
  98. pixeltable/tests/test_migration.py +0 -44
  99. pixeltable/tests/test_nos.py +0 -54
  100. pixeltable/tests/test_snapshot.py +0 -231
  101. pixeltable/tests/test_table.py +0 -1343
  102. pixeltable/tests/test_transactional_directory.py +0 -42
  103. pixeltable/tests/test_types.py +0 -52
  104. pixeltable/tests/test_video.py +0 -159
  105. pixeltable/tests/test_view.py +0 -535
  106. pixeltable/tests/utils.py +0 -442
  107. pixeltable/utils/clip.py +0 -18
  108. pixeltable-0.2.5.dist-info/METADATA +0 -128
  109. pixeltable-0.2.5.dist-info/RECORD +0 -139
  110. {pixeltable-0.2.5.dist-info → pixeltable-0.2.7.dist-info}/LICENSE +0 -0
pixeltable/globals.py ADDED
@@ -0,0 +1,443 @@
1
+ import dataclasses
2
+ import logging
3
+ from typing import Any, Optional, Union, Type
4
+
5
+ import pandas as pd
6
+ import sqlalchemy as sql
7
+ from sqlalchemy.util.preloaded import orm
8
+
9
+ import pixeltable.exceptions as excs
10
+ from pixeltable import catalog, func
11
+ from pixeltable.catalog import Catalog
12
+ from pixeltable.env import Env
13
+ from pixeltable.exprs import Predicate
14
+ from pixeltable.iterators import ComponentIterator
15
+ from pixeltable.metadata import schema
16
+
17
+ _logger = logging.getLogger('pixeltable')
18
+
19
+
20
def init() -> None:
    """Initialize the Pixeltable environment.

    Calls `Catalog.get()` purely for its side effect (presumably lazy initialization of the catalog); the
    returned catalog instance is discarded.
    """
    Catalog.get()
23
+
24
+
25
def create_table(
    path_str: str,
    schema: dict[str, Any],
    *,
    primary_key: Optional[Union[str, list[str]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
) -> catalog.InsertableTable:
    """Create a new `InsertableTable` and register it under `path_str` in the catalog.

    Args:
        path_str: Path to the table.
        schema: Dictionary mapping column names to column types, value expressions, or column specifications.
        primary_key: Name of the primary key column, a list of column names, or None for no primary key.
        num_retained_versions: Number of versions of the table to retain.
        comment: Comment forwarded to `InsertableTable.create()`.

    Returns:
        The newly created table.

    Raises:
        Error: if the path already exists or is invalid, if `schema` is empty, or if `primary_key` is neither a
            string nor a list of strings.

    Examples:
        Create a table with an int and a string column:

        >>> table = create_table('my_table', schema={'col1': IntType(), 'col2': StringType()})
    """
    path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(path, expected=None)
    parent_dir = Catalog.get().paths[path.parent]

    if len(schema) == 0:
        raise excs.Error(f'Table schema is empty: `{path_str}`')

    # normalize primary_key into a list of column names
    if primary_key is None:
        pk_cols: list[str] = []
    elif isinstance(primary_key, str):
        pk_cols = [primary_key]
    elif isinstance(primary_key, list) and all(isinstance(col_name, str) for col_name in primary_key):
        pk_cols = primary_key
    else:
        raise excs.Error('primary_key must be a single column name or a list of column names')

    new_tbl = catalog.InsertableTable.create(
        parent_dir._id,
        path.name,
        schema,
        primary_key=pk_cols,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    Catalog.get().paths[path] = new_tbl
    _logger.info(f'Created table `{path_str}`.')
    return new_tbl
77
+
78
+
79
def create_view(
    path_str: str,
    base: catalog.Table,
    *,
    schema: Optional[dict[str, Any]] = None,
    filter: Optional[Predicate] = None,
    is_snapshot: bool = False,
    iterator: Optional[tuple[type[ComponentIterator], dict[str, Any]]] = None,
    num_retained_versions: int = 10,
    comment: str = '',
    ignore_errors: bool = False,
) -> Optional[catalog.View]:
    """Create a new `View` and register it under `path_str` in the catalog.

    Args:
        path_str: Path to the view.
        base: Table (i.e., table or view or snapshot) to base the view on.
        schema: Dictionary mapping column names to column types, value expressions, or column specifications.
        filter: Predicate to filter rows of the base table.
        is_snapshot: Whether the view is a snapshot.
        iterator: Tuple of (iterator class, iterator arguments). If specified, this view is a one-to-many view
            of the base table.
        num_retained_versions: Number of versions of the view to retain.
        comment: Comment forwarded to `View.create()`.
        ignore_errors: If True, return None instead of raising when the path already exists or is invalid.

    Returns:
        The newly created view, or None if `ignore_errors` is True and the path check failed.

    Raises:
        Error: if `base` is not a `Table`, or if the path already exists or is invalid and `ignore_errors`
            is False.

    Examples:
        Create a view with an additional int and a string column and a filter:

        >>> view = create_view(
            'my_view', base, schema={'col3': IntType(), 'col4': StringType()}, filter=base.col1 > 10)

        Create a table snapshot:

        >>> snapshot_view = create_view('my_snapshot_view', base, is_snapshot=True)

        Create an immutable view with additional computed columns and a filter:

        >>> snapshot_view = create_view(
            'my_snapshot', base, schema={'col3': base.col2 + 1}, filter=base.col1 > 10, is_snapshot=True)
    """
    # explicit check rather than `assert`, which is stripped when Python runs with -O
    if not isinstance(base, catalog.Table):
        raise excs.Error(f'create_view(): `base` must be a table, view, or snapshot, got {type(base).__name__}')
    path = catalog.Path(path_str)
    try:
        Catalog.get().paths.check_is_valid(path, expected=None)
    except excs.Error:
        # only path-validation failures are ignorable; unexpected errors still propagate
        if ignore_errors:
            return None
        raise
    parent_dir = Catalog.get().paths[path.parent]

    if schema is None:
        schema = {}
    if iterator is None:
        iterator_class, iterator_args = None, None
    else:
        iterator_class, iterator_args = iterator
    view = catalog.View.create(
        parent_dir._id,
        path.name,
        base=base,
        schema=schema,
        predicate=filter,
        is_snapshot=is_snapshot,
        iterator_cls=iterator_class,
        iterator_args=iterator_args,
        num_retained_versions=num_retained_versions,
        comment=comment,
    )
    Catalog.get().paths[path] = view
    _logger.info(f'Created view `{path_str}`.')
    return view
157
+
158
+
159
def get_table(path: str) -> catalog.Table:
    """Return a handle to an existing table, view, or snapshot.

    Args:
        path: Dotted path to the table.

    Returns:
        The catalog object stored at `path` (an `InsertableTable` or `View`).

    Raises:
        Error: If the path does not exist or does not designate a table.

    Examples:
        Get handle for a table in the top-level directory:

        >>> table = get_table('my_table')

        For a table in a subdirectory:

        >>> table = get_table('subdir.my_table')

        For a snapshot in the top-level directory:

        >>> table = get_table('my_snapshot')
    """
    tbl_path = catalog.Path(path)
    Catalog.get().paths.check_is_valid(tbl_path, expected=catalog.Table)
    return Catalog.get().paths[tbl_path]
188
+
189
+
190
def move(path: str, new_path: str) -> None:
    """Move a schema object to a new directory and/or rename it.

    Args:
        path: absolute path to the existing schema object.
        new_path: absolute new path for the schema object.

    Raises:
        Error: If `path` does not exist, `new_path` already exists or is invalid, or `new_path` lies inside
            `path` (a directory cannot be moved into its own subtree).

    Examples:
        Move a table to a different directory:

        >>> move('dir1.my_table', 'dir2.my_table')

        Rename a table:

        >>> move('dir1.my_table', 'dir1.new_name')
    """
    p = catalog.Path(path)
    Catalog.get().paths.check_is_valid(p, expected=catalog.SchemaObject)
    new_p = catalog.Path(new_path)
    Catalog.get().paths.check_is_valid(new_p, expected=None)
    # Reject moving a directory into its own descendant: after paths.move() the subtree has been relocated, so the
    # new parent (which lives inside that subtree) could no longer be resolved and the catalog would be left
    # half-updated.
    if str(new_p).startswith(f'{p}.'):
        raise excs.Error(f'Cannot move `{path}` into its own subdirectory `{new_path}`')
    obj = Catalog.get().paths[p]
    Catalog.get().paths.move(p, new_p)
    new_dir = Catalog.get().paths[new_p.parent]
    obj.move(new_p.name, new_dir._id)
217
+
218
+
219
def drop_table(path: str, force: bool = False, ignore_errors: bool = False) -> None:
    """Drop a table (or view/snapshot) that has no dependents.

    Args:
        path: Path to the table.
        force: Currently not used by this implementation; accepted for API compatibility.
        ignore_errors: If True, log and return instead of raising when `path` does not designate a table.

    Raises:
        Error: If the path does not exist or does not designate a table (and `ignore_errors` is False), or if
            other tables/views depend on this table.

    Examples:
        >>> drop_table('my_table')
    """
    path_obj = catalog.Path(path)
    try:
        Catalog.get().paths.check_is_valid(path_obj, expected=catalog.Table)
    except excs.Error:
        # only path-validation failures are ignorable; unexpected errors still propagate
        if ignore_errors:
            _logger.info(f'Skipped table `{path}` (does not exist).')
            return
        raise
    tbl = Catalog.get().paths[path_obj]
    dependents = Catalog.get().tbl_dependents[tbl._id]
    if len(dependents) > 0:
        dependent_paths = [get_path(dep) for dep in dependents]
        raise excs.Error(f'Table {path} has dependents: {", ".join(dependent_paths)}')
    tbl._drop()
    del Catalog.get().paths[path_obj]
    _logger.info(f'Dropped table `{path}`.')
249
+
250
+
251
def list_tables(dir_path: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the tables contained in a directory.

    Args:
        dir_path: Path to the directory; '' (the default) designates the root directory.
        recursive: If True, also include tables in subdirectories.

    Returns:
        A list of table paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        List tables in top-level directory:

        >>> list_tables()
        ['my_table', ...]

        List tables in 'dir1':

        >>> list_tables('dir1')
        [...]
    """
    assert dir_path is not None
    dir_obj_path = catalog.Path(dir_path, empty_is_valid=True)
    Catalog.get().paths.check_is_valid(dir_obj_path, expected=catalog.Dir)
    tables = Catalog.get().paths.get_children(dir_obj_path, child_type=catalog.Table, recursive=recursive)
    return list(map(str, tables))
279
+
280
+
281
def create_dir(path_str: str, ignore_errors: bool = False) -> None:
    """Create a directory.

    Args:
        path_str: Path to the directory.
        ignore_errors: If True, return silently instead of raising a Pixeltable `Error`.

    Raises:
        Error: If the path already exists or the parent is not a directory (unless `ignore_errors` is True).

    Examples:
        >>> create_dir('my_dir')

        Create a subdirectory:

        >>> create_dir('my_dir.sub_dir')
    """
    try:
        path = catalog.Path(path_str)
        Catalog.get().paths.check_is_valid(path, expected=None)
        parent = Catalog.get().paths[path.parent]
        assert parent is not None
        with orm.Session(Env.get().engine, future=True) as session:
            dir_md = schema.DirMd(name=path.name)
            dir_record = schema.Dir(parent_id=parent._id, md=dataclasses.asdict(dir_md))
            session.add(dir_record)
            # flush so that the database assigns the directory id
            session.flush()
            # read the id before commit (commit expires ORM attributes by default)
            dir_id = dir_record.id
            assert dir_id is not None
            session.commit()
        # Register in the in-memory catalog only after the record is committed, so a failed commit cannot leave
        # a phantom directory behind in the catalog.
        Catalog.get().paths[path] = catalog.Dir(dir_id, parent._id, path.name)
        _logger.info(f'Created directory `{path_str}`.')
        print(f'Created directory `{path_str}`.')
    except excs.Error:
        if ignore_errors:
            return
        raise
318
+
319
+
320
def rm_dir(path_str: str) -> None:
    """Remove an empty directory.

    The directory record is deleted in a database transaction, then the directory is removed from the
    in-memory catalog.

    Args:
        path_str: Path to the directory.

    Raises:
        Error: If the path does not exist or does not designate a directory or if the directory is not empty.

    Examples:
        >>> rm_dir('my_dir')

        Remove a subdirectory:

        >>> rm_dir('my_dir.sub_dir')
    """
    path = catalog.Path(path_str)
    Catalog.get().paths.check_is_valid(path, expected=catalog.Dir)

    # only empty directories can be removed
    # TODO: support force=True (recursively dropping contained tables and subdirectories); snapshots make this tricky
    contents = Catalog.get().paths.get_children(path, child_type=None, recursive=True)
    if len(contents) > 0:
        raise excs.Error(f'Directory {path_str} is not empty')

    with Env.get().engine.begin() as conn:
        dir_obj = Catalog.get().paths[path]
        delete_stmt = sql.delete(schema.Dir.__table__).where(schema.Dir.id == dir_obj._id)
        conn.execute(delete_stmt)
    del Catalog.get().paths[path]
    _logger.info(f'Removed directory {path_str}')
355
+
356
+
357
def list_dirs(path_str: str = '', recursive: bool = True) -> list[str]:
    """Return the paths of the directories contained in a directory.

    Args:
        path_str: Path to the directory; '' (the default) designates the root directory.
        recursive: If True, also include nested subdirectories.

    Returns:
        List of directory paths.

    Raises:
        Error: If the path does not exist or does not designate a directory.

    Examples:
        >>> list_dirs('my_dir', recursive=True)
        ['my_dir', 'my_dir.sub_dir1']
    """
    parent_path = catalog.Path(path_str, empty_is_valid=True)
    Catalog.get().paths.check_is_valid(parent_path, expected=catalog.Dir)
    subdirs = Catalog.get().paths.get_children(parent_path, child_type=catalog.Dir, recursive=recursive)
    return [str(subdir) for subdir in subdirs]
377
+
378
+
379
def list_functions() -> 'pd.io.formats.style.Styler':
    """Return a styled table describing all registered functions.

    Returns:
        A pandas `Styler` (not a plain `DataFrame`) wrapping a table with the columns 'Path', 'Function Name',
        'Parameters' and 'Return Type'. Cells are left-aligned, headings centered, and the index is hidden.
    """
    functions = func.FunctionRegistry.get().list_functions()
    # module path of each function: its fully qualified path without the trailing function name
    paths = [f.self_path.rpartition('.')[0] for f in functions]
    names = [f.name for f in functions]
    params = [
        ', '.join([param_name + ': ' + str(param_type) for param_name, param_type in f.signature.parameters.items()])
        for f in functions
    ]
    return_types = [str(f.signature.get_return_type()) for f in functions]
    df = pd.DataFrame(
        {
            'Path': paths,
            'Function Name': names,
            'Parameters': params,
            'Return Type': return_types,
        }
    )
    styler = df.style.set_properties(**{'text-align': 'left'}).set_table_styles(
        [dict(selector='th', props=[('text-align', 'center')])]
    )  # center-align headings
    return styler.hide(axis='index')
404
+
405
+
406
def get_path(schema_obj: catalog.SchemaObject) -> str:
    """Return the dotted path of a SchemaObject.

    Walks up the chain of parent directories, stopping at the root directory (the one without a parent),
    which does not contribute a path element.

    Args:
        schema_obj: SchemaObject to get the path for.

    Returns:
        Path to the SchemaObject.
    """
    ancestor_names: list[str] = []  # collected from the innermost directory outwards
    parent_id = schema_obj._dir_id
    while parent_id is not None:
        parent = Catalog.get().paths.get_schema_obj(parent_id)
        if parent._dir_id is None:
            # the root dir has name '', which must not appear in the path
            break
        ancestor_names.append(parent._name)
        parent_id = parent._dir_id
    ancestor_names.reverse()
    return '.'.join([*ancestor_names, schema_obj._name])
426
+
427
+
428
def configure_logging(
    *,
    to_stdout: Optional[bool] = None,
    level: Optional[int] = None,
    add: Optional[str] = None,
    remove: Optional[str] = None,
) -> None:
    """Configure logging by forwarding all arguments unchanged to `Env.configure_logging()`.

    Args:
        to_stdout: if True, also log to stdout
        level: default log level
        add: comma-separated list of 'module name:log level' pairs; ex.: add='video:10'
        remove: comma-separated list of module names
    """
    env = Env.get()
    return env.configure_logging(to_stdout=to_stdout, level=level, add=add, remove=remove)
@@ -1,2 +1,3 @@
1
1
  from .base import IndexBase
2
2
  from .embedding_index import EmbeddingIndex
3
+ from .btree import BtreeIndex
pixeltable/index/base.py CHANGED
@@ -13,8 +13,10 @@ class IndexBase(abc.ABC):
13
13
  Internal interface used by the catalog and runtime system to interact with indices:
14
14
  - types and expressions needed to create and populate the index value column
15
15
  - creating/dropping the index
16
- - TODO: translating queries into sqlalchemy predicates
16
+ This doesn't cover querying the index, which is dependent on the index semantics and handled by
17
+ the specific subclass.
17
18
  """
19
+
18
20
  @abc.abstractmethod
19
21
  def __init__(self, c: catalog.Column, **kwargs: Any):
20
22
  pass
@@ -25,7 +27,12 @@ class IndexBase(abc.ABC):
25
27
  pass
26
28
 
27
29
  @abc.abstractmethod
28
- def index_sa_type(self) -> sql.sqltypes.TypeEngine:
30
+ def records_value_errors(self) -> bool:
31
+ """True if index_value_expr() can raise errors"""
32
+ pass
33
+
34
+ @abc.abstractmethod
35
+ def index_sa_type(self) -> sql.types.TypeEngine:
29
36
  """Return the sqlalchemy type of the index value column"""
30
37
  pass
31
38
 
@@ -0,0 +1,54 @@
1
+ from typing import Optional
2
+
3
+ import sqlalchemy as sql
4
+
5
+ # TODO: why does this import result in a circular import, but the one in embedding_index.py doesn't?
6
+ #import pixeltable.catalog as catalog
7
+ import pixeltable.exceptions as excs
8
+ import pixeltable.func as func
9
+ from .base import IndexBase
10
+
11
+
12
class BtreeIndex(IndexBase):
    """
    Interface to B-tree indices in Postgres.

    The indexed value is the column value itself, except for string columns, which are indexed on their
    first MAX_STRING_LEN characters (computed by the str_filter UDF).
    """
    MAX_STRING_LEN = 256

    @func.udf
    def str_filter(s: Optional[str]) -> Optional[str]:
        # keep the leading MAX_STRING_LEN characters; None passes through unchanged
        return None if s is None else s[:BtreeIndex.MAX_STRING_LEN]

    def __init__(self, c: 'catalog.Column'):
        col_type = c.col_type
        if not (col_type.is_scalar_type() or col_type.is_media_type()):
            raise excs.Error(f'Index on column {c.name}: B-tree index requires scalar or media type, got {c.col_type}')
        # deferred import, as in the rest of the index package (presumably to avoid an import cycle)
        from pixeltable.exprs import ColumnRef
        if col_type.is_string_type():
            self.value_expr = self.str_filter(ColumnRef(c))
        else:
            self.value_expr = ColumnRef(c)

    def index_value_expr(self) -> 'pixeltable.exprs.Expr':
        """Return the expression whose value is stored in the index value column"""
        return self.value_expr

    def records_value_errors(self) -> bool:
        # False: index_value_expr() is a column ref or a string truncation and is not expected to raise errors
        return False

    def index_sa_type(self) -> sql.types.TypeEngine:
        """Return the sqlalchemy type of the index value column"""
        return self.value_expr.col_type.to_sa_type()

    def create_index(self, index_name: str, index_value_col: 'catalog.Column', conn: sql.engine.Connection) -> None:
        """Create the index on the index value column"""
        btree_idx = sql.Index(index_name, index_value_col.sa_col, postgresql_using='btree')
        btree_idx.create(bind=conn)

    @classmethod
    def display_name(cls) -> str:
        return 'btree'

    def as_dict(self) -> dict:
        # a B-tree index has no parameters that need to be persisted
        return {}

    @classmethod
    def from_dict(cls, c: 'catalog.Column', d: dict) -> 'BtreeIndex':
        # `d` carries no state (see as_dict()); the index is reconstructed from the column alone
        return cls(c)
@@ -1,8 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
3
+ from typing import Optional, Any
4
+ import enum
4
5
 
6
+ import PIL.Image
7
+ import numpy as np
5
8
  import pgvector.sqlalchemy
9
+ import PIL.Image
6
10
  import sqlalchemy as sql
7
11
 
8
12
  import pixeltable.catalog as catalog
@@ -14,15 +18,31 @@ from .base import IndexBase
14
18
 
15
19
  class EmbeddingIndex(IndexBase):
16
20
  """
17
- Internal interface used by the catalog and runtime system to interact with (embedding) indices:
18
- - types and expressions needed to create and populate the index value column
19
- - creating/dropping the index
20
- - translating 'matches' queries into sqlalchemy predicates
21
+ Interface to the pgvector access method in Postgres.
22
+ - pgvector converts the cosine metric to '1 - metric' and the inner product metric to '-metric', in order to
23
+ satisfy the Postgres requirement that an index scan requires an ORDER BY ... ASC clause
24
+ - similarity_clause() converts those metrics back to their original form; it is used in expressions outside
25
+ the Order By clause
26
+ - order_by_clause() is used exclusively in the ORDER BY clause
21
27
  """
22
28
 
29
+ class Metric(enum.Enum):
30
+ COSINE = 1
31
+ IP = 2
32
+ L2 = 3
33
+
34
+ PGVECTOR_OPS = {
35
+ Metric.COSINE: 'vector_cosine_ops',
36
+ Metric.IP: 'vector_ip_ops',
37
+ Metric.L2: 'vector_l2_ops'
38
+ }
39
+
23
40
  def __init__(
24
- self, c: catalog.Column, text_embed: Optional[func.Function] = None,
41
+ self, c: catalog.Column, metric: str, text_embed: Optional[func.Function] = None,
25
42
  img_embed: Optional[func.Function] = None):
43
+ metric_names = [m.name.lower() for m in self.Metric]
44
+ if metric.lower() not in metric_names:
45
+ raise excs.Error(f'Invalid metric {metric}, must be one of {metric_names}')
26
46
  if not c.col_type.is_string_type() and not c.col_type.is_image_type():
27
47
  raise excs.Error(f'Embedding index requires string or image column')
28
48
  if c.col_type.is_string_type() and text_embed is None:
@@ -36,6 +56,7 @@ class EmbeddingIndex(IndexBase):
36
56
  # verify signature
37
57
  self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
38
58
 
59
+ self.metric = self.Metric[metric.upper()]
39
60
  from pixeltable.exprs import ColumnRef
40
61
  self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
41
62
  assert self.value_expr.col_type.is_array_type()
@@ -49,7 +70,10 @@ class EmbeddingIndex(IndexBase):
49
70
  """Return expression that computes the value that goes into the index"""
50
71
  return self.value_expr
51
72
 
52
- def index_sa_type(self) -> sql.sqltypes.TypeEngine:
73
+ def records_value_errors(self) -> bool:
74
+ return True
75
+
76
+ def index_sa_type(self) -> sql.types.TypeEngine:
53
77
  """Return the sqlalchemy type of the index value column"""
54
78
  return self.index_col_type
55
79
 
@@ -59,10 +83,51 @@ class EmbeddingIndex(IndexBase):
59
83
  index_name, index_value_col.sa_col,
60
84
  postgresql_using='hnsw',
61
85
  postgresql_with={'m': 16, 'ef_construction': 64},
62
- postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
86
+ postgresql_ops={index_value_col.sa_col.name: self.PGVECTOR_OPS[self.metric]}
63
87
  )
64
88
  idx.create(bind=conn)
65
89
 
90
def similarity_clause(self, val_column: catalog.Column, item: Any) -> sql.ClauseElement:
    """Create a ClauseElement that represents '<val_column> <op> <item>'.

    `item` is embedded with the text embedding function (for a str) or the image embedding function (for a
    PIL image). The pgvector distance is then converted back to the metric's original form, for use outside
    the ORDER BY clause:
    - COSINE: 1 - cosine distance
    - IP: -1 * pgvector's max_inner_product (pgvector negates the inner product)
    - L2: the L2 distance unchanged
    """
    assert isinstance(item, (str, PIL.Image.Image))
    if isinstance(item, str):
        assert self.txt_embed is not None
        embedding = self.txt_embed.exec(item)
    elif isinstance(item, PIL.Image.Image):
        assert self.img_embed is not None
        embedding = self.img_embed.exec(item)

    sa_col = val_column.sa_col
    if self.metric == self.Metric.COSINE:
        return sa_col.cosine_distance(embedding) * -1 + 1
    if self.metric == self.Metric.IP:
        return sa_col.max_inner_product(embedding) * -1
    assert self.metric == self.Metric.L2
    return sa_col.l2_distance(embedding)
107
+
108
def order_by_clause(self, val_column: catalog.Column, item: Any, is_asc: bool) -> sql.ClauseElement:
    """Create a ClauseElement that is used in an ORDER BY clause.

    Orders by the raw pgvector distance to the embedding of `item`. For COSINE and IP, pgvector's distance
    decreases as similarity increases, so ascending similarity (`is_asc=True`) maps to a descending distance
    sort, and `is_asc=False` keeps the ascending distance sort.
    """
    assert isinstance(item, (str, PIL.Image.Image))
    embedding: Optional[np.ndarray] = None
    if isinstance(item, str):
        assert self.txt_embed is not None
        embedding = self.txt_embed.exec(item)
    elif isinstance(item, PIL.Image.Image):
        assert self.img_embed is not None
        embedding = self.img_embed.exec(item)
    assert embedding is not None

    sa_col = val_column.sa_col
    if self.metric == self.Metric.COSINE:
        distance = sa_col.cosine_distance(embedding)
    elif self.metric == self.Metric.IP:
        distance = sa_col.max_inner_product(embedding)
    else:
        assert self.metric == self.Metric.L2
        # NOTE(review): is_asc is not applied for L2; results are always sorted by ascending distance.
        # Confirm this is intended.
        return sa_col.l2_distance(embedding)
    return distance.desc() if is_asc else distance
130
+
66
131
  @classmethod
67
132
  def display_name(cls) -> str:
68
133
  return 'embedding'
@@ -72,18 +137,29 @@ class EmbeddingIndex(IndexBase):
72
137
  """Validate the signature"""
73
138
  assert isinstance(embed_fn, func.Function)
74
139
  sig = embed_fn.signature
75
- if not sig.return_type.is_array_type():
76
- raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
77
- else:
78
- shape = sig.return_type.shape
79
- if len(shape) != 1 or shape[0] == None:
80
- raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
81
140
  if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
82
141
  raise excs.Error(
83
142
  f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
84
143
 
144
+ # validate return type
145
+ param_name = sig.parameters_by_pos[0].name
146
+ if expected_type == ts.ColumnType.Type.STRING:
147
+ return_type = embed_fn.call_return_type({param_name: 'dummy'})
148
+ else:
149
+ assert expected_type == ts.ColumnType.Type.IMAGE
150
+ img = PIL.Image.new('RGB', (512, 512))
151
+ return_type = embed_fn.call_return_type({param_name: img})
152
+ assert return_type is not None
153
+ if not return_type.is_array_type():
154
+ raise excs.Error(f'{name} must return an array, but returns {return_type}')
155
+ else:
156
+ shape = return_type.shape
157
+ if len(shape) != 1 or shape[0] == None:
158
+ raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {return_type}')
159
+
85
160
  def as_dict(self) -> dict:
86
161
  return {
162
+ 'metric': self.metric.name.lower(),
87
163
  'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
88
164
  'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
89
165
  }
@@ -92,4 +168,4 @@ class EmbeddingIndex(IndexBase):
92
168
  def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
93
169
  txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
94
170
  img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
95
- return cls(c, text_embed=txt_embed, img_embed=img_embed)
171
+ return cls(c, metric=d['metric'], text_embed=txt_embed, img_embed=img_embed)
@@ -0,0 +1,4 @@
1
+ from .globals import create_label_studio_project
2
+ from .hf_datasets import import_huggingface_dataset
3
+ from .pandas import import_csv, import_excel, import_pandas
4
+ from .parquet import import_parquet