pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (63)
  1. pixeltable/catalog/column.py +26 -49
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +72 -6
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +52 -53
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +9 -9
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/fireworks.py +10 -37
  25. pixeltable/functions/huggingface.py +47 -19
  26. pixeltable/functions/openai.py +192 -24
  27. pixeltable/functions/together.py +104 -9
  28. pixeltable/functions/util.py +11 -0
  29. pixeltable/index/__init__.py +2 -0
  30. pixeltable/index/base.py +49 -0
  31. pixeltable/index/embedding_index.py +95 -0
  32. pixeltable/metadata/schema.py +45 -22
  33. pixeltable/plan.py +15 -34
  34. pixeltable/store.py +38 -41
  35. pixeltable/tests/conftest.py +8 -14
  36. pixeltable/tests/ext/test_yolox.py +21 -0
  37. pixeltable/tests/functions/test_fireworks.py +43 -0
  38. pixeltable/tests/functions/test_functions.py +60 -0
  39. pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
  40. pixeltable/tests/functions/test_openai.py +162 -0
  41. pixeltable/tests/functions/test_together.py +112 -0
  42. pixeltable/tests/test_component_view.py +14 -5
  43. pixeltable/tests/test_dataframe.py +23 -22
  44. pixeltable/tests/test_exprs.py +99 -102
  45. pixeltable/tests/test_function.py +51 -43
  46. pixeltable/tests/test_index.py +138 -0
  47. pixeltable/tests/test_migration.py +2 -1
  48. pixeltable/tests/test_snapshot.py +24 -1
  49. pixeltable/tests/test_table.py +205 -26
  50. pixeltable/tests/test_types.py +30 -0
  51. pixeltable/tests/test_video.py +16 -16
  52. pixeltable/tests/test_view.py +5 -0
  53. pixeltable/tests/utils.py +171 -14
  54. pixeltable/tool/create_test_db_dump.py +16 -0
  55. pixeltable/type_system.py +77 -128
  56. pixeltable/utils/arrow.py +98 -0
  57. pixeltable/utils/hf_datasets.py +157 -0
  58. pixeltable/utils/parquet.py +68 -27
  59. pixeltable/utils/pytorch.py +16 -97
  60. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
  61. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
  62. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  63. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/functions/openai.py (+192 -24)
@@ -1,9 +1,14 @@
  import base64
  import io
- from typing import Optional
+ import pathlib
+ import uuid
+ from typing import Optional, TypeVar, Union, Callable

  import PIL.Image
  import numpy as np
+ import openai
+ import tenacity
+ from openai._types import NOT_GIVEN, NotGiven

  import pixeltable as pxt
  import pixeltable.type_system as ts
@@ -11,43 +16,148 @@ from pixeltable import env
  from pixeltable.func import Batch


+ def openai_client() -> openai.OpenAI:
+     return env.Env.get().get_client('openai', lambda api_key: openai.OpenAI(api_key=api_key))
+
+
+ # Exponential backoff decorator using tenacity.
+ # TODO(aaron-siegel): Right now this hardwires random exponential backoff with defaults suggested
+ # by OpenAI. Should we investigate making this more customizable in the future?
+ def _retry(fn: Callable) -> Callable:
+     return tenacity.retry(
+         retry=tenacity.retry_if_exception_type(openai.RateLimitError),
+         wait=tenacity.wait_random_exponential(multiplier=3, max=180),
+         stop=tenacity.stop_after_attempt(20)
+     )(fn)
+
+
+ #####################################
+ # Audio Endpoints
+
+ @pxt.udf(return_type=ts.AudioType())
+ @_retry
+ def speech(
+     input: str,
+     *,
+     model: str,
+     voice: str,
+     response_format: Optional[str] = None,
+     speed: Optional[float] = None
+ ) -> str:
+     content = openai_client().audio.speech.create(
+         input=input,
+         model=model,
+         voice=voice,
+         response_format=_opt(response_format),
+         speed=_opt(speed)
+     )
+     ext = response_format or 'mp3'
+     output_filename = str(env.Env.get().tmp_dir / f"{uuid.uuid4()}.{ext}")
+     content.stream_to_file(output_filename, chunk_size=1 << 20)
+     return output_filename
+
+
+ @pxt.udf(
+     param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True),
+                  ts.StringType(nullable=True), ts.FloatType(nullable=True)]
+ )
+ @_retry
+ def transcriptions(
+     audio: str,
+     *,
+     model: str,
+     language: Optional[str] = None,
+     prompt: Optional[str] = None,
+     temperature: Optional[float] = None
+ ) -> dict:
+     file = pathlib.Path(audio)
+     transcription = openai_client().audio.transcriptions.create(
+         file=file,
+         model=model,
+         language=_opt(language),
+         prompt=_opt(prompt),
+         temperature=_opt(temperature)
+     )
+     return transcription.dict()
+
+
+ @pxt.udf(
+     param_types=[ts.AudioType(), ts.StringType(), ts.StringType(nullable=True), ts.FloatType(nullable=True)]
+ )
+ @_retry
+ def translations(
+     audio: str,
+     *,
+     model: str,
+     prompt: Optional[str] = None,
+     temperature: Optional[float] = None
+ ) -> dict:
+     file = pathlib.Path(audio)
+     translation = openai_client().audio.translations.create(
+         file=file,
+         model=model,
+         prompt=_opt(prompt),
+         temperature=_opt(temperature)
+     )
+     return translation.dict()
+
+
+ #####################################
+ # Chat Endpoints
+
  @pxt.udf
+ @_retry
  def chat_completions(
      messages: list,
+     *,
      model: str,
      frequency_penalty: Optional[float] = None,
-     logit_bias: Optional[dict] = None,
+     logit_bias: Optional[dict[str, int]] = None,
+     logprobs: Optional[bool] = None,
+     top_logprobs: Optional[int] = None,
      max_tokens: Optional[int] = None,
      n: Optional[int] = None,
      presence_penalty: Optional[float] = None,
      response_format: Optional[dict] = None,
      seed: Optional[int] = None,
+     stop: Optional[list[str]] = None,
+     temperature: Optional[float] = None,
      top_p: Optional[float] = None,
-     temperature: Optional[float] = None
+     tools: Optional[list[dict]] = None,
+     tool_choice: Optional[dict] = None,
+     user: Optional[str] = None
  ) -> dict:
-     from openai._types import NOT_GIVEN
-     result = env.Env.get().openai_client.chat.completions.create(
+     result = openai_client().chat.completions.create(
          messages=messages,
          model=model,
-         frequency_penalty=frequency_penalty if frequency_penalty is not None else NOT_GIVEN,
-         logit_bias=logit_bias if logit_bias is not None else NOT_GIVEN,
-         max_tokens=max_tokens if max_tokens is not None else NOT_GIVEN,
-         n=n if n is not None else NOT_GIVEN,
-         presence_penalty=presence_penalty if presence_penalty is not None else NOT_GIVEN,
-         response_format=response_format if response_format is not None else NOT_GIVEN,
-         seed=seed if seed is not None else NOT_GIVEN,
-         top_p=top_p if top_p is not None else NOT_GIVEN,
-         temperature=temperature if temperature is not None else NOT_GIVEN
+         frequency_penalty=_opt(frequency_penalty),
+         logit_bias=_opt(logit_bias),
+         logprobs=_opt(logprobs),
+         top_logprobs=_opt(top_logprobs),
+         max_tokens=_opt(max_tokens),
+         n=_opt(n),
+         presence_penalty=_opt(presence_penalty),
+         response_format=_opt(response_format),
+         seed=_opt(seed),
+         stop=_opt(stop),
+         temperature=_opt(temperature),
+         top_p=_opt(top_p),
+         tools=_opt(tools),
+         tool_choice=_opt(tool_choice),
+         user=_opt(user)
      )
      return result.dict()


  @pxt.udf
+ @_retry
  def vision(
      prompt: str,
      image: PIL.Image.Image,
+     *,
      model: str = 'gpt-4-vision-preview'
  ) -> str:
+     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
      bytes_arr = io.BytesIO()
      image.save(bytes_arr, format='png')
      b64_bytes = base64.b64encode(bytes_arr.getvalue())
@@ -61,28 +171,86 @@ def vision(
              }}
          ]}
      ]
-     result = env.Env.get().openai_client.chat.completions.create(
+     result = openai_client().chat.completions.create(
          messages=messages,
          model=model
      )
      return result.choices[0].message.content


- @pxt.udf
- def moderations(input: str, model: Optional[str] = None) -> dict:
-     result = env.Env().get().openai_client.moderations.create(input=input, model=model)
-     return result.dict()
-
+ #####################################
+ # Embeddings Endpoints

  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
- def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
-     result = env.Env().get().openai_client.embeddings.create(
+ @_retry
+ def embeddings(
+     input: Batch[str],
+     *,
+     model: str,
+     user: Optional[str] = None
+ ) -> Batch[np.ndarray]:
+     result = openai_client().embeddings.create(
          input=input,
          model=model,
+         user=_opt(user),
          encoding_format='float'
      )
-     embeddings = [
+     return [
          np.array(data.embedding, dtype=np.float64)
          for data in result.data
      ]
-     return embeddings
+
+
+ #####################################
+ # Images Endpoints
+
+ @pxt.udf
+ @_retry
+ def image_generations(
+     prompt: str,
+     *,
+     model: Optional[str] = None,
+     quality: Optional[str] = None,
+     size: Optional[str] = None,
+     style: Optional[str] = None,
+     user: Optional[str] = None
+ ) -> PIL.Image.Image:
+     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
+     result = openai_client().images.generate(
+         prompt=prompt,
+         model=_opt(model),
+         quality=_opt(quality),
+         size=_opt(size),
+         style=_opt(style),
+         user=_opt(user),
+         response_format="b64_json"
+     )
+     b64_str = result.data[0].b64_json
+     b64_bytes = base64.b64decode(b64_str)
+     img = PIL.Image.open(io.BytesIO(b64_bytes))
+     img.load()
+     return img
+
+
+ #####################################
+ # Moderations Endpoints
+
+ @pxt.udf
+ @_retry
+ def moderations(
+     input: str,
+     *,
+     model: Optional[str] = None
+ ) -> dict:
+     result = openai_client().moderations.create(
+         input=input,
+         model=_opt(model)
+     )
+     return result.dict()
+
+
+ _T = TypeVar('_T')
+
+
+ def _opt(arg: _T) -> Union[_T, NotGiven]:
+     return arg if arg is not None else NOT_GIVEN
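
The new OpenAI UDFs are meant to be used as computed-column expressions. A minimal sketch of how that wiring might look (the `pxt.Client`/`add_column` calls below are illustrative assumptions about the 0.2.x table API and are not part of this diff; an `OPENAI_API_KEY` must be configured):

    import pixeltable as pxt
    from pixeltable.functions import openai as pxt_openai

    # Illustrative table setup; exact client/table syntax may differ across 0.2.x releases.
    cl = pxt.Client()
    t = cl.create_table('demo.prompts', {'prompt': pxt.StringType()})

    # Each inserted row triggers one chat.completions call; transient RateLimitErrors
    # are absorbed by the _retry decorator introduced in this release.
    t.add_column(response=pxt_openai.chat_completions(
        messages=[{'role': 'user', 'content': t.prompt}],
        model='gpt-3.5-turbo',
        temperature=0.7
    ))
    t.insert([{'prompt': 'What is Pixeltable?'}])

The `_opt` helper maps `None` to OpenAI's `NOT_GIVEN` sentinel, so parameters a caller leaves unset are omitted from the request instead of being sent as explicit nulls.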
pixeltable/functions/together.py (+104 -9)
@@ -1,27 +1,122 @@
+ import base64
+ import io
  from typing import Optional

+ import PIL.Image
+ import numpy as np
+ import together
+
  import pixeltable as pxt
+ from pixeltable import env
+ from pixeltable.func import Batch
+
+
+ def together_client() -> together.Together:
+     return env.Env.get().get_client('together', lambda api_key: together.Together(api_key=api_key))


  @pxt.udf
  def completions(
      prompt: str,
+     *,
      model: str,
      max_tokens: Optional[int] = None,
-     repetition_penalty: Optional[float] = None,
      stop: Optional[list] = None,
-     top_k: Optional[int] = None,
+     temperature: Optional[float] = None,
      top_p: Optional[float] = None,
-     temperature: Optional[float] = None
+     top_k: Optional[int] = None,
+     repetition_penalty: Optional[float] = None,
+     logprobs: Optional[int] = None,
+     echo: Optional[bool] = None,
+     n: Optional[int] = None,
+     safety_model: Optional[str] = None
  ) -> dict:
-     import together
-     return together.Complete.create(
-         prompt,
-         model,
+     return together_client().completions.create(
+         prompt=prompt,
+         model=model,
          max_tokens=max_tokens,
-         repetition_penalty=repetition_penalty,
          stop=stop,
+         temperature=temperature,
+         top_p=top_p,
          top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         logprobs=logprobs,
+         echo=echo,
+         n=n,
+         safety_model=safety_model
+     ).dict()
+
+
+ @pxt.udf
+ def chat_completions(
+     messages: list[dict[str, str]],
+     *,
+     model: str,
+     max_tokens: Optional[int] = None,
+     stop: Optional[list[str]] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     top_k: Optional[int] = None,
+     repetition_penalty: Optional[float] = None,
+     logprobs: Optional[int] = None,
+     echo: Optional[bool] = None,
+     n: Optional[int] = None,
+     safety_model: Optional[str] = None,
+     response_format: Optional[dict] = None,
+     tools: Optional[dict] = None,
+     tool_choice: Optional[dict] = None
+ ) -> dict:
+     return together_client().chat.completions.create(
+         messages=messages,
+         model=model,
+         max_tokens=max_tokens,
+         stop=stop,
+         temperature=temperature,
          top_p=top_p,
-         temperature=temperature
+         top_k=top_k,
+         repetition_penalty=repetition_penalty,
+         logprobs=logprobs,
+         echo=echo,
+         n=n,
+         safety_model=safety_model,
+         response_format=response_format,
+         tools=tools,
+         tool_choice=tool_choice
+     ).dict()
+
+
+ @pxt.udf(batch_size=32, return_type=pxt.ArrayType((None,), dtype=pxt.FloatType()))
+ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
+     result = together_client().embeddings.create(input=input, model=model)
+     return [
+         np.array(data.embedding, dtype=np.float64)
+         for data in result.data
+     ]
+
+
+ @pxt.udf
+ def image_generations(
+     prompt: str,
+     *,
+     model: str,
+     steps: Optional[int] = None,
+     seed: Optional[int] = None,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     negative_prompt: Optional[str] = None,
+ ) -> PIL.Image.Image:
+     # TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
+     result = together_client().images.generate(
+         prompt=prompt,
+         model=model,
+         steps=steps,
+         seed=seed,
+         height=height,
+         width=width,
+         negative_prompt=negative_prompt
      )
+     b64_str = result.data[0].b64_json
+     b64_bytes = base64.b64decode(b64_str)
+     img = PIL.Image.open(io.BytesIO(b64_bytes))
+     img.load()
+     return img
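
As with the OpenAI module, these Together UDFs are intended to appear in computed columns. A hedged sketch, assuming the 0.2.x table API and a configured `TOGETHER_API_KEY` (neither of which is shown in this diff; the model id is just an example):

    import pixeltable as pxt
    from pixeltable.functions import together as pxt_together

    cl = pxt.Client()  # illustrative; exact client/table syntax may vary across 0.2.x
    t = cl.create_table('demo.gen', {'prompt': pxt.StringType()})

    # One Together image-generation call per inserted row; the b64 result is decoded
    # into a PIL image by the UDF itself.
    t.add_column(image=pxt_together.image_generations(
        t.prompt,
        model='stabilityai/stable-diffusion-xl-base-1.0',  # example model id
        steps=20
    ))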
pixeltable/functions/util.py (+11 -0)
@@ -39,3 +39,14 @@ def create_nos_modules() -> List[types.ModuleType]:
          setattr(sub_module, model_id, pt_func)

      return new_modules
+
+
+ def resolve_torch_device(device: str) -> str:
+     import torch
+     if device == 'auto':
+         if torch.cuda.is_available():
+             return 'cuda'
+         if torch.backends.mps.is_available():
+             return 'mps'
+         return 'cpu'
+     return device
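
The new `resolve_torch_device` helper maps 'auto' to the best available torch backend and passes any explicit device string through unchanged, e.g.:

    import torch
    from pixeltable.functions.util import resolve_torch_device

    device = resolve_torch_device('auto')          # 'cuda', 'mps', or 'cpu', depending on hardware
    t = torch.zeros(3, device=device)              # place tensors/models on the resolved device
    assert resolve_torch_device('cpu') == 'cpu'    # explicit devices are returned as-is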
pixeltable/index/__init__.py (+2 -0)
@@ -0,0 +1,2 @@
+ from .base import IndexBase
+ from .embedding_index import EmbeddingIndex
pixeltable/index/base.py (+49 -0)
@@ -0,0 +1,49 @@
+ from __future__ import annotations
+
+ import abc
+ from typing import Any
+
+ import sqlalchemy as sql
+
+ import pixeltable.catalog as catalog
+
+
+ class IndexBase(abc.ABC):
+     """
+     Internal interface used by the catalog and runtime system to interact with indices:
+     - types and expressions needed to create and populate the index value column
+     - creating/dropping the index
+     - TODO: translating queries into sqlalchemy predicates
+     """
+     @abc.abstractmethod
+     def __init__(self, c: catalog.Column, **kwargs: Any):
+         pass
+
+     @abc.abstractmethod
+     def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+         """Return expression that computes the value that goes into the index"""
+         pass
+
+     @abc.abstractmethod
+     def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+         """Return the sqlalchemy type of the index value column"""
+         pass
+
+     @abc.abstractmethod
+     def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+         """Create the index on the index value column"""
+         pass
+
+     @classmethod
+     @abc.abstractmethod
+     def display_name(cls) -> str:
+         pass
+
+     @abc.abstractmethod
+     def as_dict(self) -> dict:
+         pass
+
+     @classmethod
+     @abc.abstractmethod
+     def from_dict(cls, c: catalog.Column, d: dict) -> IndexBase:
+         pass
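
The `as_dict`/`from_dict` pair exists so the catalog can persist an index definition as plain metadata and rebuild it on load (see `IndexMd.class_fqn`/`init_args` in the metadata/schema.py changes further down). A rough sketch of the forward half of that round trip, not actual catalog code:

    from pixeltable.index import IndexBase

    def index_to_md(idx: IndexBase) -> dict:
        # roughly the shape of IndexMd.class_fqn / IndexMd.init_args
        cls = type(idx)
        return {'class_fqn': f'{cls.__module__}.{cls.__qualname__}', 'init_args': idx.as_dict()}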
pixeltable/index/embedding_index.py (+95 -0)
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import pgvector.sqlalchemy
+ import sqlalchemy as sql
+
+ import pixeltable.catalog as catalog
+ import pixeltable.exceptions as excs
+ import pixeltable.func as func
+ import pixeltable.type_system as ts
+ from .base import IndexBase
+
+
+ class EmbeddingIndex(IndexBase):
+     """
+     Internal interface used by the catalog and runtime system to interact with (embedding) indices:
+     - types and expressions needed to create and populate the index value column
+     - creating/dropping the index
+     - translating 'matches' queries into sqlalchemy predicates
+     """
+
+     def __init__(
+             self, c: catalog.Column, text_embed: Optional[func.Function] = None,
+             img_embed: Optional[func.Function] = None):
+         if not c.col_type.is_string_type() and not c.col_type.is_image_type():
+             raise excs.Error(f'Embedding index requires string or image column')
+         if c.col_type.is_string_type() and text_embed is None:
+             raise excs.Error(f'Text embedding function is required for column {c.name} (parameter `txt_embed`)')
+         if c.col_type.is_image_type() and img_embed is None:
+             raise excs.Error(f'Image embedding function is required for column {c.name} (parameter `img_embed`)')
+         if text_embed is not None:
+             # verify signature
+             self._validate_embedding_fn(text_embed, 'txt_embed', ts.ColumnType.Type.STRING)
+         if img_embed is not None:
+             # verify signature
+             self._validate_embedding_fn(img_embed, 'img_embed', ts.ColumnType.Type.IMAGE)
+
+         from pixeltable.exprs import ColumnRef
+         self.value_expr = text_embed(ColumnRef(c)) if c.col_type.is_string_type() else img_embed(ColumnRef(c))
+         assert self.value_expr.col_type.is_array_type()
+         self.txt_embed = text_embed
+         self.img_embed = img_embed
+         vector_size = self.value_expr.col_type.shape[0]
+         assert vector_size is not None
+         self.index_col_type = pgvector.sqlalchemy.Vector(vector_size)
+
+     def index_value_expr(self) -> 'pixeltable.exprs.Expr':
+         """Return expression that computes the value that goes into the index"""
+         return self.value_expr
+
+     def index_sa_type(self) -> sql.sqltypes.TypeEngine:
+         """Return the sqlalchemy type of the index value column"""
+         return self.index_col_type
+
+     def create_index(self, index_name: str, index_value_col: catalog.Column, conn: sql.engine.Connection) -> None:
+         """Create the index on the index value column"""
+         idx = sql.Index(
+             index_name, index_value_col.sa_col,
+             postgresql_using='hnsw',
+             postgresql_with={'m': 16, 'ef_construction': 64},
+             postgresql_ops={index_value_col.sa_col.name: 'vector_cosine_ops'}
+         )
+         idx.create(bind=conn)
+
+     @classmethod
+     def display_name(cls) -> str:
+         return 'embedding'
+
+     @classmethod
+     def _validate_embedding_fn(cls, embed_fn: func.Function, name: str, expected_type: ts.ColumnType.Type) -> None:
+         """Validate the signature"""
+         assert isinstance(embed_fn, func.Function)
+         sig = embed_fn.signature
+         if not sig.return_type.is_array_type():
+             raise excs.Error(f'{name} must return an array, but returns {sig.return_type}')
+         else:
+             shape = sig.return_type.shape
+             if len(shape) != 1 or shape[0] == None:
+                 raise excs.Error(f'{name} must return a 1D array of a specific length, but returns {sig.return_type}')
+         if len(sig.parameters) != 1 or sig.parameters_by_pos[0].col_type.type_enum != expected_type:
+             raise excs.Error(
+                 f'{name} must take a single {expected_type.name.lower()} parameter, but has signature {sig}')
+
+     def as_dict(self) -> dict:
+         return {
+             'txt_embed': None if self.txt_embed is None else self.txt_embed.as_dict(),
+             'img_embed': None if self.img_embed is None else self.img_embed.as_dict()
+         }
+
+     @classmethod
+     def from_dict(cls, c: catalog.Column, d: dict) -> EmbeddingIndex:
+         txt_embed = func.Function.from_dict(d['txt_embed']) if d['txt_embed'] is not None else None
+         img_embed = func.Function.from_dict(d['img_embed']) if d['img_embed'] is not None else None
+         return cls(c, text_embed=txt_embed, img_embed=img_embed)
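
`_validate_embedding_fn` requires the embedding UDF to take a single string (or image) parameter and to return a 1-D array of known length. A minimal sketch of a conforming text embedder, using a placeholder computation rather than a real model (only the decorator pattern is taken from this diff):

    import numpy as np
    import pixeltable as pxt
    import pixeltable.type_system as ts

    @pxt.udf(
        return_type=ts.ArrayType((384,), dtype=ts.FloatType()),  # fixed 1-D shape, as the validator demands
        param_types=[ts.StringType()]
    )
    def toy_text_embed(text: str) -> np.ndarray:
        # placeholder: a real UDF would call a sentence-transformer or an embeddings API here
        rng = np.random.default_rng(abs(hash(text)) % (2 ** 32))
        return rng.random(384, dtype=np.float64)

A function with this signature could then be passed as `text_embed` when constructing an `EmbeddingIndex` on a string column.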
pixeltable/metadata/schema.py (+45 -22)
@@ -1,4 +1,4 @@
- from typing import Optional, List, Dict, get_type_hints, Type, Any, TypeVar, Tuple, Union
+ from typing import Optional, List, get_type_hints, Type, Any, TypeVar, Tuple, Union
  import platform
  import uuid
  import dataclasses
@@ -71,16 +71,43 @@ class Dir(Base):


  @dataclasses.dataclass
- class ColumnHistory:
+ class ColumnMd:
      """
-     Records when a column was added/dropped, which is needed to GC unreachable storage columns
-     (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
-     from the stored table).
-     One record per column (across all schema versions).
+     Records the non-versioned metadata of a column.
+     - immutable attributes: type, primary key, etc.
+     - when a column was added/dropped, which is needed to GC unreachable storage columns
+       (a column that was added after table snapshot n and dropped before table snapshot n+1 can be removed
+       from the stored table).
      """
-     col_id: int
+     id: int
      schema_version_add: int
      schema_version_drop: Optional[int]
+     col_type: dict
+
+     # if True, is part of the primary key
+     is_pk: bool
+
+     # if set, this is a computed column
+     value_expr: Optional[dict]
+
+     # if True, the column is present in the stored table
+     stored: Optional[bool]
+
+
+ @dataclasses.dataclass
+ class IndexMd:
+     """
+     Metadata needed to instantiate an EmbeddingIndex
+     """
+     id: int
+     name: str
+     indexed_col_id: int  # column being indexed
+     index_val_col_id: int  # column holding the values to be indexed
+     index_val_undo_col_id: int  # column holding index values for deleted rows
+     schema_version_add: int
+     schema_version_drop: Optional[int]
+     class_fqn: str
+     init_args: dict[str, Any]


  @dataclasses.dataclass
@@ -91,13 +118,13 @@ class ViewMd:
      base_versions: List[Tuple[str, Optional[int]]]

      # filter predicate applied to the base table; view-only
-     predicate: Optional[Dict[str, Any]]
+     predicate: Optional[dict[str, Any]]

      # ComponentIterator subclass; only for component views
      iterator_class_fqn: Optional[str]

      # args to pass to the iterator class constructor; only for component views
-     iterator_args: Optional[Dict[str, Any]]
+     iterator_args: Optional[dict[str, Any]]


  @dataclasses.dataclass
@@ -109,15 +136,15 @@ class TableMd:
      # each version has a corresponding schema version (current_version >= current_schema_version)
      current_schema_version: int

-     # used to assign Column.id
-     next_col_id: int
+     next_col_id: int  # used to assign Column.id
+     next_idx_id: int  # used to assign IndexMd.id

      # - used to assign the rowid column in the storage table
      # - every row is assigned a unique and immutable rowid on insertion
      next_row_id: int

-     column_history: Dict[int, ColumnHistory]  # col_id -> ColumnHistory
-
+     column_md: dict[int, ColumnMd]  # col_id -> ColumnMd
+     index_md: dict[int, IndexMd]  # index_id -> IndexMd
      view_md: Optional[ViewMd]


@@ -155,24 +182,20 @@ class TableVersion(Base):
  @dataclasses.dataclass
  class SchemaColumn:
      """
-     Records the logical (user-visible) schema of a table.
-     Contains the full set of columns for each new schema version: one record per (column x schema version).
+     Records the versioned metadata of a column.
      """
      pos: int
      name: str
-     col_type: dict
-     is_pk: bool
-     value_expr: Optional[dict]
-     stored: Optional[bool]
-     # if True, creates vector index for this column
-     is_indexed: bool


  @dataclasses.dataclass
  class TableSchemaVersionMd:
+     """
+     Records all versioned table metadata.
+     """
      schema_version: int
      preceding_schema_version: Optional[int]
-     columns: Dict[int, SchemaColumn]  # col_id -> SchemaColumn
+     columns: dict[int, SchemaColumn]  # col_id -> SchemaColumn
      num_retained_versions: int
      comment: str
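
For reference, this is roughly what one entry of `TableMd.column_md` looks like once populated; the field values (in particular the serialized `col_type` dict) are illustrative, not copied from a real catalog:

    import dataclasses
    from pixeltable.metadata.schema import ColumnMd

    col_md = ColumnMd(
        id=0,
        schema_version_add=0,
        schema_version_drop=None,       # still live in the current schema version
        col_type={'nullable': True},    # placeholder for the serialized ColumnType
        is_pk=False,
        value_expr=None,                # not a computed column
        stored=True,
    )
    print(dataclasses.asdict(col_md))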