pixeltable 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (56)
  1. pixeltable/catalog/column.py +25 -48
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +0 -4
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +16 -1
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +8 -8
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/huggingface.py +47 -19
  25. pixeltable/functions/openai.py +2 -2
  26. pixeltable/functions/util.py +11 -0
  27. pixeltable/index/__init__.py +2 -0
  28. pixeltable/index/base.py +49 -0
  29. pixeltable/index/embedding_index.py +95 -0
  30. pixeltable/metadata/schema.py +45 -22
  31. pixeltable/plan.py +15 -34
  32. pixeltable/store.py +38 -41
  33. pixeltable/tests/conftest.py +5 -11
  34. pixeltable/tests/ext/test_yolox.py +21 -0
  35. pixeltable/tests/functions/test_fireworks.py +1 -0
  36. pixeltable/tests/functions/test_huggingface.py +2 -2
  37. pixeltable/tests/functions/test_openai.py +15 -5
  38. pixeltable/tests/functions/test_together.py +1 -0
  39. pixeltable/tests/test_component_view.py +14 -5
  40. pixeltable/tests/test_dataframe.py +19 -18
  41. pixeltable/tests/test_exprs.py +99 -102
  42. pixeltable/tests/test_function.py +51 -43
  43. pixeltable/tests/test_index.py +138 -0
  44. pixeltable/tests/test_migration.py +2 -1
  45. pixeltable/tests/test_snapshot.py +24 -1
  46. pixeltable/tests/test_table.py +101 -25
  47. pixeltable/tests/test_types.py +30 -0
  48. pixeltable/tests/test_video.py +16 -16
  49. pixeltable/tests/test_view.py +5 -0
  50. pixeltable/tests/utils.py +43 -9
  51. pixeltable/tool/create_test_db_dump.py +16 -0
  52. pixeltable/type_system.py +37 -45
  53. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/METADATA +5 -4
  54. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/RECORD +56 -49
  55. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  56. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
@@ -1,29 +1,39 @@
1
- from typing import Optional
2
- from types import ModuleType
3
1
  import importlib
4
2
  import inspect
3
+ from types import ModuleType
4
+ from typing import Optional
5
5
 
6
+ import pixeltable.exceptions as excs
6
7
 
7
- def resolve_symbol(symbol_path: str) -> object:
8
+
9
+ def resolve_symbol(symbol_path: str) -> Optional[object]:
8
10
  path_elems = symbol_path.split('.')
9
11
  module: Optional[ModuleType] = None
10
- if path_elems[0:2] == ['pixeltable', 'functions'] and len(path_elems) > 2:
11
- # if this is a pixeltable.functions submodule, it cannot be resolved via pixeltable.functions;
12
- # try to import the submodule directly
13
- submodule_path = '.'.join(path_elems[0:3])
12
+ i = len(path_elems) - 1
13
+ while i > 0 and module is None:
14
14
  try:
15
- module = importlib.import_module(submodule_path)
16
- path_elems = path_elems[3:]
15
+ module = importlib.import_module('.'.join(path_elems[:i]))
17
16
  except ModuleNotFoundError:
18
- pass
19
- if module is None:
20
- module = importlib.import_module(path_elems[0])
21
- path_elems = path_elems[1:]
17
+ i -= 1
18
+ if i == 0:
19
+ return None # Not resolvable
22
20
  obj = module
23
- for el in path_elems:
21
+ for el in path_elems[i:]:
24
22
  obj = getattr(obj, el)
25
23
  return obj
26
24
 
25
+
26
+ def validate_symbol_path(fn_path: str) -> None:
27
+ path_elems = fn_path.split('.')
28
+ fn_name = path_elems[-1]
29
+ if any(el == '<locals>' for el in path_elems):
30
+ raise excs.Error(
31
+ f'{fn_name}(): nested functions are not supported. Move the function to the module level or into a class.')
32
+ if any(not el.isidentifier() for el in path_elems):
33
+ raise excs.Error(
34
+ f'{fn_name}(): cannot resolve symbol path {fn_path}. Move the function to the module level or into a class.')
35
+
36
+
27
37
  def get_caller_module_path() -> str:
28
38
  """Return the module path of our caller's caller"""
29
39
  stack = inspect.stack()
@@ -114,20 +114,12 @@ class Signature:
114
114
  return (col_type, is_batched)
115
115
 
116
116
  @classmethod
117
- def create(
118
- cls, c: Callable,
119
- param_types: Optional[List[ts.ColumnType]] = None,
120
- return_type: Optional[Union[ts.ColumnType, Callable]] = None
121
- ) -> Signature:
122
- """Create a signature for the given Callable.
123
- Infer the parameter and return types, if none are specified.
124
- Raises an exception if the types cannot be inferred.
125
- """
117
+ def create_parameters(
118
+ cls, c: Callable, param_types: Optional[List[ts.ColumnType]] = None) -> List[Parameter]:
126
119
  sig = inspect.signature(c)
127
120
  py_parameters = list(sig.parameters.values())
128
-
129
- # check non-var parameters for name collisions and default value compatibility
130
121
  parameters: List[Parameter] = []
122
+
131
123
  for idx, param in enumerate(py_parameters):
132
124
  if param.name in cls.SPECIAL_PARAM_NAMES:
133
125
  raise excs.Error(f"'{param.name}' is a reserved parameter name")
@@ -135,6 +127,7 @@ class Signature:
135
127
  parameters.append(Parameter(param.name, None, param.kind, False))
136
128
  continue
137
129
 
130
+ # check non-var parameters for name collisions and default value compatibility
138
131
  if param_types is not None:
139
132
  if idx >= len(param_types):
140
133
  raise excs.Error(f'Missing type for parameter {param.name}')
@@ -155,7 +148,20 @@ class Signature:
155
148
 
156
149
  parameters.append(Parameter(param.name, param_type, param.kind, is_batched))
157
150
 
158
- return_is_batched = False
151
+ return parameters
152
+
153
+ @classmethod
154
+ def create(
155
+ cls, c: Callable,
156
+ param_types: Optional[List[ts.ColumnType]] = None,
157
+ return_type: Optional[Union[ts.ColumnType, Callable]] = None
158
+ ) -> Signature:
159
+ """Create a signature for the given Callable.
160
+ Infer the parameter and return types, if none are specified.
161
+ Raises an exception if the types cannot be inferred.
162
+ """
163
+ parameters = cls.create_parameters(c, param_types)
164
+ sig = inspect.signature(c)
159
165
  if return_type is None:
160
166
  return_type, return_is_batched = cls._infer_type(sig.return_annotation)
161
167
  if return_type is None:
pixeltable/func/udf.py CHANGED
@@ -11,6 +11,7 @@ from .callable_function import CallableFunction
11
11
  from .expr_template_function import ExprTemplateFunction
12
12
  from .function import Function
13
13
  from .function_registry import FunctionRegistry
14
+ from .globals import validate_symbol_path
14
15
  from .signature import Signature
15
16
 
16
17
 
@@ -124,6 +125,8 @@ def make_function(
124
125
 
125
126
  # If this function is part of a module, register it
126
127
  if function_path is not None:
128
+ # do the validation at the very end, so it's easier to write tests for other failure scenarios
129
+ validate_symbol_path(function_path)
127
130
  FunctionRegistry.get().register_function(function_path, result)
128
131
 
129
132
  return result
@@ -142,17 +145,19 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
142
145
  else:
143
146
  function_path = None
144
147
 
145
- sig = Signature.create(py_fn, param_types=param_types, return_type=None)
146
148
  # TODO: verify that the inferred return type matches that of the template
147
149
  # TODO: verify that the signature doesn't contain batched parameters
148
150
 
149
151
  # construct Parameters from the function signature
152
+ params = Signature.create_parameters(py_fn, param_types=param_types)
150
153
  import pixeltable.exprs as exprs
151
- var_exprs = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
154
+ var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
152
155
  # call the function with the parameter expressions to construct an Expr with parameters
153
156
  template = py_fn(*var_exprs)
154
157
  assert isinstance(template, exprs.Expr)
155
158
  py_sig = inspect.signature(py_fn)
159
+ if function_path is not None:
160
+ validate_symbol_path(function_path)
156
161
  return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)
157
162
 
158
163
  if len(args) == 1:
@@ -23,8 +23,8 @@ def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
23
23
  return expr
24
24
 
25
25
  @func.uda(
26
- update_types=[IntType()], value_type=IntType(), name='sum', allows_window=True, requires_order_by=False)
27
- class SumAggregator(func.Aggregator):
26
+ update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
27
+ class sum(func.Aggregator):
28
28
  def __init__(self):
29
29
  self.sum: Union[int, float] = 0
30
30
  def update(self, val: Union[int, float]) -> None:
@@ -35,8 +35,8 @@ class SumAggregator(func.Aggregator):
35
35
 
36
36
 
37
37
  @func.uda(
38
- update_types=[IntType()], value_type=IntType(), name='count', allows_window = True, requires_order_by = False)
39
- class CountAggregator(func.Aggregator):
38
+ update_types=[IntType()], value_type=IntType(), allows_window = True, requires_order_by = False)
39
+ class count(func.Aggregator):
40
40
  def __init__(self):
41
41
  self.count = 0
42
42
  def update(self, val: int) -> None:
@@ -47,8 +47,8 @@ class CountAggregator(func.Aggregator):
47
47
 
48
48
 
49
49
  @func.uda(
50
- update_types=[IntType()], value_type=FloatType(), name='mean', allows_window=False, requires_order_by=False)
51
- class MeanAggregator(func.Aggregator):
50
+ update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
51
+ class mean(func.Aggregator):
52
52
  def __init__(self):
53
53
  self.sum = 0
54
54
  self.count = 0
@@ -63,9 +63,9 @@ class MeanAggregator(func.Aggregator):
63
63
 
64
64
 
65
65
  @func.uda(
66
- init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(), name='make_video',
66
+ init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
67
67
  requires_order_by=True, allows_window=False)
68
- class VideoAggregator(func.Aggregator):
68
+ class make_video(func.Aggregator):
69
69
  def __init__(self, fps: int = 25):
70
70
  """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
71
71
  self.container: Optional[av.container.OutputContainer] = None
@@ -1,4 +1,3 @@
1
- from __future__ import annotations
2
1
  from typing import List, Tuple, Dict
3
2
  from collections import defaultdict
4
3
  import sys
@@ -157,16 +156,16 @@ def calculate_image_tpfp(
157
156
  ts.JsonType(nullable=False)
158
157
  ])
159
158
  def eval_detections(
160
- pred_bboxes: List[List[int]], pred_classes: List[int], pred_scores: List[float],
161
- gt_bboxes: List[List[int]], gt_classes: List[int]
159
+ pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
160
+ gt_bboxes: List[List[int]], gt_labels: List[int]
162
161
  ) -> Dict:
163
- class_idxs = list(set(pred_classes + gt_classes))
162
+ class_idxs = list(set(pred_labels + gt_labels))
164
163
  result: List[Dict] = []
165
164
  pred_bboxes_arr = np.asarray(pred_bboxes)
166
- pred_classes_arr = np.asarray(pred_classes)
165
+ pred_classes_arr = np.asarray(pred_labels)
167
166
  pred_scores_arr = np.asarray(pred_scores)
168
167
  gt_bboxes_arr = np.asarray(gt_bboxes)
169
- gt_classes_arr = np.asarray(gt_classes)
168
+ gt_classes_arr = np.asarray(gt_labels)
170
169
  for class_idx in class_idxs:
171
170
  pred_filter = pred_classes_arr == class_idx
172
171
  gt_filter = gt_classes_arr == class_idx
@@ -181,8 +180,8 @@ def eval_detections(
181
180
  return result
182
181
 
183
182
  @func.uda(
184
- update_types=[ts.JsonType()], value_type=ts.JsonType(), name='mean_ap', allows_std_agg=True, allows_window=False)
185
- class MeanAPAggregator:
183
+ update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
184
+ class mean_ap(func.Aggregator):
186
185
  def __init__(self):
187
186
  self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)
188
187
 
@@ -1,4 +1,4 @@
1
- from typing import Any, Callable
1
+ from typing import Callable, TypeVar, Optional
2
2
 
3
3
  import PIL.Image
4
4
  import numpy as np
@@ -7,10 +7,13 @@ import pixeltable as pxt
7
7
  import pixeltable.env as env
8
8
  import pixeltable.type_system as ts
9
9
  from pixeltable.func import Batch
10
+ from pixeltable.functions.util import resolve_torch_device
10
11
 
11
12
 
12
13
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
13
- def sentence_transformer(sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False) -> Batch[np.ndarray]:
14
+ def sentence_transformer(
15
+ sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
16
+ ) -> Batch[np.ndarray]:
14
17
  env.Env.get().require_package('sentence_transformers')
15
18
  from sentence_transformers import SentenceTransformer
16
19
 
@@ -53,44 +56,60 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
53
56
  return array.tolist()
54
57
 
55
58
 
56
- @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
59
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
57
60
  def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
58
61
  env.Env.get().require_package('transformers')
62
+ device = resolve_torch_device('auto')
63
+ import torch
59
64
  from transformers import CLIPModel, CLIPProcessor
60
65
 
61
- model = _lookup_model(model_id, CLIPModel.from_pretrained)
66
+ model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
67
+ assert model.config.projection_dim == 512
62
68
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
63
69
 
64
- inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
65
- embeddings = model.get_text_features(**inputs).detach().numpy()
70
+ with torch.no_grad():
71
+ inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
72
+ embeddings = model.get_text_features(**inputs.to(device)).detach().to('cpu').numpy()
73
+
66
74
  return [embeddings[i] for i in range(embeddings.shape[0])]
67
75
 
68
76
 
69
- @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
77
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
70
78
  def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
71
79
  env.Env.get().require_package('transformers')
80
+ device = resolve_torch_device('auto')
81
+ import torch
72
82
  from transformers import CLIPModel, CLIPProcessor
73
83
 
74
- model = _lookup_model(model_id, CLIPModel.from_pretrained)
84
+ model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
85
+ assert model.config.projection_dim == 512
75
86
  processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
76
87
 
77
- inputs = processor(images=image, return_tensors='pt', padding=True)
78
- embeddings = model.get_image_features(**inputs).detach().numpy()
88
+ with torch.no_grad():
89
+ inputs = processor(images=image, return_tensors='pt', padding=True)
90
+ embeddings = model.get_image_features(**inputs.to(device)).detach().to('cpu').numpy()
91
+
79
92
  return [embeddings[i] for i in range(embeddings.shape[0])]
80
93
 
81
94
 
82
- @pxt.udf(batch_size=32)
95
+ @pxt.udf(batch_size=4)
83
96
  def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
84
97
  env.Env.get().require_package('transformers')
98
+ device = resolve_torch_device('auto')
99
+ import torch
85
100
  from transformers import DetrImageProcessor, DetrForObjectDetection
86
101
 
87
- model = _lookup_model(model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'))
102
+ model = _lookup_model(
103
+ model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
88
104
  processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))
89
105
 
90
- inputs = processor(images=image, return_tensors='pt')
91
- outputs = model(**inputs)
106
+ with torch.no_grad():
107
+ inputs = processor(images=image, return_tensors='pt')
108
+ outputs = model(**inputs.to(device))
109
+ results = processor.post_process_object_detection(
110
+ outputs, threshold=threshold, target_sizes=[(img.height, img.width) for img in image]
111
+ )
92
112
 
93
- results = processor.post_process_object_detection(outputs, threshold=threshold)
94
113
  return [
95
114
  {
96
115
  'scores': [score.item() for score in result['scores']],
@@ -102,14 +121,23 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
102
121
  ]
103
122
 
104
123
 
105
- def _lookup_model(model_id: str, create: Callable) -> Any:
106
- key = (model_id, create) # For safety, include the `create` callable in the cache key
124
+ T = TypeVar('T')
125
+
126
+
127
+ def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
128
+ from torch import nn
129
+ key = (model_id, create, device) # For safety, include the `create` callable in the cache key
107
130
  if key not in _model_cache:
108
- _model_cache[key] = create(model_id)
131
+ model = create(model_id)
132
+ if device is not None:
133
+ model.to(device)
134
+ if isinstance(model, nn.Module):
135
+ model.eval()
136
+ _model_cache[key] = model
109
137
  return _model_cache[key]
110
138
 
111
139
 
112
- def _lookup_processor(model_id: str, create: Callable) -> Any:
140
+ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
113
141
  key = (model_id, create) # For safety, include the `create` callable in the cache key
114
142
  if key not in _processor_cache:
115
143
  _processor_cache[key] = create(model_id)
@@ -26,8 +26,8 @@ def openai_client() -> openai.OpenAI:
26
26
  def _retry(fn: Callable) -> Callable:
27
27
  return tenacity.retry(
28
28
  retry=tenacity.retry_if_exception_type(openai.RateLimitError),
29
- wait=tenacity.wait_random_exponential(min=1, max=60),
30
- stop=tenacity.stop_after_attempt(6)
29
+ wait=tenacity.wait_random_exponential(multiplier=3, max=180),
30
+ stop=tenacity.stop_after_attempt(20)
31
31
  )(fn)
32
32
 
33
33
 
@@ -39,3 +39,14 @@ def create_nos_modules() -> List[types.ModuleType]:
39
39
  setattr(sub_module, model_id, pt_func)
40
40
 
41
41
  return new_modules
42
+
43
+
44
+ def resolve_torch_device(device: str) -> str:
45
+ import torch
46
+ if device == 'auto':
47
+ if torch.cuda.is_available():
48
+ return 'cuda'
49
+ if torch.backends.mps.is_available():
50
+ return 'mps'
51
+ return 'cpu'
52
+ return device
@@ -0,0 +1,2 @@
1
+ from .base import IndexBase
2
+ from .embedding_index import EmbeddingIndex
@@ -0,0 +1,49 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ from typing import Any
5
+
6
+ import sqlalchemy as sql
7
+
8
+ import pixeltable.catalog as catalog
9
+
10
+
11
+ class IndexBase(abc.ABC):
12
+ """
13
+ Internal interface used by the catalog and runtime system to interact with indices:
14
+ - types and expressions needed to create and populate the index value column
15
+ - creating/dropping the index
16
+ - TODO: translating queries into sqlalchemy predicates
17
+ """
18
+ @abc.abstractmethod
19
+ def __init__(self, c: catalog.Column, **kwargs: Any):
20
+ pass
21
+
22
+ @abc.abstractmethod
23
+ def index_value_expr(self) -> 'pixeltable.exprs.Expr':
24
+ """Return expression that computes the value that goes into the index"""
25
+ pass
26
+
27
+ @abc.abstractmethod
28
+ def index_sa_type(self) -> sql.sqltypes.TypeEngine:
29
+ """Return the sqlalchemy type of the index value column"""
30
+ pass
31
+
32
+ @abc.abstractmethod
33
+ def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
34
+ """Create the index on the index value column"""
35
+ pass
36
+
37
+ @classmethod
38
+ @abc.abstractmethod
39
+ def display_name(cls) -> str:
40
+ pass
41
+
42
+ @abc.abstractmethod
43
+ def as_dict(self) -> dict:
44
+ pass
45
+
46
+ @classmethod
47
+ @abc.abstractmethod
48
+ def from_dict(cls, c: catalog.Column, d: dict) -> IndexBase:
49
+ pass
@@ -0,0 +1,95 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ import pgvector.sqlalchemy
6
+ import sqlalchemy as sql
7
+
8
+ import pixeltable.catalog as catalog
9
+ import pixeltable.exceptions as excs
10
+ import pixeltable.func as func
11
+ import pixeltable.type_system as ts
12
+ from .base import IndexBase
13
+
14
+
15
+ class EmbeddingIndex(IndexBase):
16
+ """
17
+ Internal interface used by the catalog and runtime system to interact with (embedding) indices:
18
+ - types and expressions needed to create and populate the index value column
19
+ - creating/dropping the index
20
+ - translating 'matches' queries into sqlalchemy predicates
21
+ """
22
+
23
+ def __init__(
24
+ self, c: catalog.Column, text_embed: Optional[func.Function] = None,
25
+ img_embed: Optional[func.Function] = None):
26
+ if not c.col_type.is_string_type() and not c.col_type.is_image_type():
27
+ raise excs.Error(f'Embedding index requires string or image column')
28
+ if c.col_type.is_string_type() and text_embed is None:
29
+ raise excs.Error(f'Text embedding function is required for column {c.name} (parameter `txt_embed`)')
30
+ if c.col_type.is_image_type() and img_embed is None:
31
+ raise excs.Error(f'Image embedding function is required for column {c.name} (parameter `img_embed`)')
32
+ if text_embed is not None:
33
+ # verify signature
34
+ self._validate_embedding_fn(text_embed, 'txt_embed', ts.ColumnType.Type.STRING)
35
+ if img_embed is not None:
36
+ # verify signature
37
+ self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
38
+
39
+ from pixeltable.exprs import ColumnRef
40
+ self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
41
+ assert self.value_expr.col_type.is_array_type()
42
+ self.txt_embed = text_embed
43
+ self.img_embed = img_embed
44
+ vector_size = self.value_expr.col_type.shape[0]
45
+ assert vector_size is not None
46
+ self.index_col_type = pgvector.sqlalchemy.Vector(vector_size)
47
+
48
+ def index_value_expr(self) -> 'pixeltable.exprs.Expr':
49
+ """Return expression that computes the value that goes into the index"""
50
+ return self.value_expr
51
+
52
+ def index_sa_type(self) -> sql.sqltypes.TypeEngine:
53
+ """Return the sqlalchemy type of the index value column"""
54
+ return self.index_col_type
55
+
56
+ def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
57
+ """Create the index on the index value column"""
58
+ idx = sql.Index(
59
+ index_name, index_value_col.sa_col,
60
+ postgresql_using='hnsw',
61
+ postgresql_with={'m': 16, 'ef_construction': 64},
62
+ postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
63
+ )
64
+ idx.create(bind=conn)
65
+
66
+ @classmethod
67
+ def display_name(cls) -> str:
68
+ return 'embedding'
69
+
70
+ @classmethod
71
+ def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> None:
72
+ """Validate the signature"""
73
+ assert isinstance(embed_fn, func.Function)
74
+ sig = embed_fn.signature
75
+ if not sig.return_type.is_array_type():
76
+ raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
77
+ else:
78
+ shape = sig.return_type.shape
79
+ if len(shape) != 1 or shape[0] == None:
80
+ raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
81
+ if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
82
+ raise excs.Error(
83
+ f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
84
+
85
+ def as_dict(self) -> dict:
86
+ return {
87
+ 'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
88
+ 'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
89
+ }
90
+
91
+ @classmethod
92
+ def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
93
+ txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
94
+ img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
95
+ return cls(c, text_embed=txt_embed, img_embed=img_embed)
@@ -1,4 +1,4 @@
1
- from typing import Optional, List, Dict, get_type_hints, Type, Any, TypeVar, Tuple, Union
1
+ from typing import Optional, List, get_type_hints, Type, Any, TypeVar, Tuple, Union
2
2
  import platform
3
3
  import uuid
4
4
  import dataclasses
@@ -71,16 +71,43 @@ class Dir(Base):
71
71
 
72
72
 
73
73
  @dataclasses.dataclass
74
- class ColumnHistory:
74
+ class ColumnMd:
75
75
  """
76
- Records when a column was added/dropped, which is needed to GC unreachable storage columns
77
- (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
78
- from the stored table).
79
- One record per column (across all schema versions).
76
+ Records the non-versioned metadata of a column.
77
+ - immutable attributes: type, primary key, etc.
78
+ - when a column was added/dropped, which is needed to GC unreachable storage columns
79
+ (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
80
+ from the stored table).
80
81
  """
81
- col_id: int
82
+ id: int
82
83
  schema_version_add: int
83
84
  schema_version_drop: Optional[int]
85
+ col_type: dict
86
+
87
+ # if True, is part of the primary key
88
+ is_pk: bool
89
+
90
+ # if set, this is a computed column
91
+ value_expr: Optional[dict]
92
+
93
+ # if True, the column is present in the stored table
94
+ stored: Optional[bool]
95
+
96
+
97
+ @dataclasses.dataclass
98
+ class IndexMd:
99
+ """
100
+ Metadata needed to instantiate an EmbeddingIndex
101
+ """
102
+ id: int
103
+ name: str
104
+ indexed_col_id: int # column being indexed
105
+ index_val_col_id: int # column holding the values to be indexed
106
+ index_val_undo_col_id: int # column holding index values for deleted rows
107
+ schema_version_add: int
108
+ schema_version_drop: Optional[int]
109
+ class_fqn: str
110
+ init_args: dict[str, Any]
84
111
 
85
112
 
86
113
  @dataclasses.dataclass
@@ -91,13 +118,13 @@ class ViewMd:
91
118
  base_versions: List[Tuple[str, Optional[int]]]
92
119
 
93
120
  # filter predicate applied to the base table; view-only
94
- predicate: Optional[Dict[str, Any]]
121
+ predicate: Optional[dict[str, Any]]
95
122
 
96
123
  # ComponentIterator subclass; only for component views
97
124
  iterator_class_fqn: Optional[str]
98
125
 
99
126
  # args to pass to the iterator class constructor; only for component views
100
- iterator_args: Optional[Dict[str, Any]]
127
+ iterator_args: Optional[dict[str, Any]]
101
128
 
102
129
 
103
130
  @dataclasses.dataclass
@@ -109,15 +136,15 @@ class TableMd:
109
136
  # each version has a corresponding schema version (current_version >= current_schema_version)
110
137
  current_schema_version: int
111
138
 
112
- # used to assign Column.id
113
- next_col_id: int
139
+ next_col_id: int # used to assign Column.id
140
+ next_idx_id: int # used to assign IndexMd.id
114
141
 
115
142
  # - used to assign the rowid column in the storage table
116
143
  # - every row is assigned a unique and immutable rowid on insertion
117
144
  next_row_id: int
118
145
 
119
- column_history: Dict[int, ColumnHistory] # col_id -> ColumnHistory
120
-
146
+ column_md: dict[int, ColumnMd] # col_id -> ColumnMd
147
+ index_md: dict[int, IndexMd] # index_id -> IndexMd
121
148
  view_md: Optional[ViewMd]
122
149
 
123
150
 
@@ -155,24 +182,20 @@ class TableVersion(Base):
155
182
  @dataclasses.dataclass
156
183
  class SchemaColumn:
157
184
  """
158
- Records the logical (user-visible) schema of a table.
159
- Contains the full set of columns for each new schema version: one record per (column x schema version).
185
+ Records the versioned metadata of a column.
160
186
  """
161
187
  pos: int
162
188
  name: str
163
- col_type: dict
164
- is_pk: bool
165
- value_expr: Optional[dict]
166
- stored: Optional[bool]
167
- # if True, creates vector index for this column
168
- is_indexed: bool
169
189
 
170
190
 
171
191
  @dataclasses.dataclass
172
192
  class TableSchemaVersionMd:
193
+ """
194
+ Records all versioned table metadata.
195
+ """
173
196
  schema_version: int
174
197
  preceding_schema_version: Optional[int]
175
- columns: Dict[int, SchemaColumn] # col_id -> SchemaColumn
198
+ columns: dict[int, SchemaColumn] # col_id -> SchemaColumn
176
199
  num_retained_versions: int
177
200
  comment: str
178
201