pixeltable 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (53):
  1. pixeltable/__init__.py +2 -27
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +309 -59
  4. pixeltable/catalog/globals.py +5 -5
  5. pixeltable/catalog/insertable_table.py +13 -1
  6. pixeltable/catalog/path.py +13 -6
  7. pixeltable/catalog/table.py +28 -41
  8. pixeltable/catalog/table_version.py +100 -72
  9. pixeltable/catalog/view.py +35 -9
  10. pixeltable/dataframe.py +2 -2
  11. pixeltable/exceptions.py +20 -2
  12. pixeltable/exec/expr_eval/evaluators.py +0 -4
  13. pixeltable/exec/expr_eval/expr_eval_node.py +0 -1
  14. pixeltable/exec/sql_node.py +3 -3
  15. pixeltable/exprs/json_path.py +1 -5
  16. pixeltable/func/__init__.py +1 -1
  17. pixeltable/func/aggregate_function.py +1 -1
  18. pixeltable/func/callable_function.py +1 -1
  19. pixeltable/func/expr_template_function.py +2 -2
  20. pixeltable/func/function.py +3 -4
  21. pixeltable/func/query_template_function.py +87 -4
  22. pixeltable/func/tools.py +1 -1
  23. pixeltable/func/udf.py +1 -1
  24. pixeltable/functions/__init__.py +1 -0
  25. pixeltable/functions/anthropic.py +1 -1
  26. pixeltable/functions/bedrock.py +130 -0
  27. pixeltable/functions/huggingface.py +7 -6
  28. pixeltable/functions/image.py +15 -16
  29. pixeltable/functions/mistralai.py +3 -2
  30. pixeltable/functions/openai.py +9 -8
  31. pixeltable/functions/together.py +4 -3
  32. pixeltable/globals.py +7 -2
  33. pixeltable/io/datarows.py +4 -3
  34. pixeltable/io/label_studio.py +17 -17
  35. pixeltable/io/pandas.py +13 -12
  36. pixeltable/io/table_data_conduit.py +8 -2
  37. pixeltable/metadata/__init__.py +1 -1
  38. pixeltable/metadata/converters/convert_19.py +2 -2
  39. pixeltable/metadata/converters/convert_31.py +11 -0
  40. pixeltable/metadata/converters/convert_32.py +15 -0
  41. pixeltable/metadata/converters/convert_33.py +17 -0
  42. pixeltable/metadata/notes.py +3 -0
  43. pixeltable/metadata/schema.py +26 -1
  44. pixeltable/plan.py +2 -3
  45. pixeltable/share/packager.py +9 -25
  46. pixeltable/share/publish.py +20 -9
  47. pixeltable/store.py +7 -4
  48. pixeltable/utils/exception_handler.py +59 -0
  49. {pixeltable-0.3.11.dist-info → pixeltable-0.3.13.dist-info}/METADATA +1 -1
  50. {pixeltable-0.3.11.dist-info → pixeltable-0.3.13.dist-info}/RECORD +53 -48
  51. {pixeltable-0.3.11.dist-info → pixeltable-0.3.13.dist-info}/WHEEL +1 -1
  52. {pixeltable-0.3.11.dist-info → pixeltable-0.3.13.dist-info}/LICENSE +0 -0
  53. {pixeltable-0.3.11.dist-info → pixeltable-0.3.13.dist-info}/entry_points.txt +0 -0
@@ -208,10 +208,6 @@ class FnCallEvaluator(Evaluator):
208
208
  _logger.debug(f'Evaluated slot {self.fn_call.slot_idx} in {end_ts - start_ts}')
209
209
  self.dispatcher.dispatch([call_args.row], self.exec_ctx)
210
210
  except Exception as exc:
211
- import anthropic
212
-
213
- if isinstance(exc, anthropic.RateLimitError):
214
- _logger.debug(f'RateLimitError: {exc}')
215
211
  _, _, exc_tb = sys.exc_info()
216
212
  call_args.row.set_exc(self.fn_call.slot_idx, exc)
217
213
  self.dispatcher.dispatch_exc(call_args.rows, self.fn_call.slot_idx, exc_tb, self.exec_ctx)
@@ -282,7 +282,6 @@ class ExprEvalNode(ExecNode):
282
282
 
283
283
  if self.exc_event.is_set():
284
284
  # we got an exception that we need to propagate through __iter__()
285
- _logger.debug(f'Propagating exception {self.error}')
286
285
  if isinstance(self.error, excs.ExprEvalError):
287
286
  raise self.error from self.error.exc
288
287
  else:
@@ -103,7 +103,6 @@ class SqlNode(ExecNode):
103
103
  # create Select stmt
104
104
  self.sql_elements = sql_elements
105
105
  self.tbl = tbl
106
- assert all(not isinstance(e, exprs.Literal) for e in select_list) # we're never asked to materialize literals
107
106
  self.select_list = exprs.ExprSet(select_list)
108
107
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
109
108
  for iter_arg in row_builder.unstored_iter_args.values():
@@ -425,6 +424,7 @@ class SqlAggregationNode(SqlNode):
425
424
  """
426
425
 
427
426
  group_by_items: Optional[list[exprs.Expr]]
427
+ input_cte: Optional[sql.CTE]
428
428
 
429
429
  def __init__(
430
430
  self,
@@ -441,13 +441,13 @@ class SqlAggregationNode(SqlNode):
441
441
  group_by_items: list of expressions to group by
442
442
  limit: max number of rows to return: None = no limit
443
443
  """
444
- _, input_col_map = input.to_cte()
444
+ self.input_cte, input_col_map = input.to_cte()
445
445
  sql_elements = exprs.SqlElementCache(input_col_map)
446
446
  super().__init__(None, row_builder, select_list, sql_elements)
447
447
  self.group_by_items = group_by_items
448
448
 
449
449
  def _create_stmt(self) -> sql.Select:
450
- stmt = super()._create_stmt()
450
+ stmt = super()._create_stmt().select_from(self.input_cte)
451
451
  if self.group_by_items is not None:
452
452
  sql_group_by_items = [self.sql_elements.get(e) for e in self.group_by_items]
453
453
  assert all(e is not None for e in sql_group_by_items)
@@ -5,7 +5,6 @@ from typing import Any, Optional, Union
5
5
  import jmespath
6
6
  import sqlalchemy as sql
7
7
 
8
- import pixeltable as pxt
9
8
  from pixeltable import catalog, exceptions as excs, type_system as ts
10
9
 
11
10
  from .data_row import DataRow
@@ -19,10 +18,7 @@ from .sql_element_cache import SqlElementCache
19
18
 
20
19
  class JsonPath(Expr):
21
20
  def __init__(
22
- self,
23
- anchor: Optional['pxt.exprs.Expr'],
24
- path_elements: Optional[list[Union[str, int, slice]]] = None,
25
- scope_idx: int = 0,
21
+ self, anchor: Optional[Expr], path_elements: Optional[list[Union[str, int, slice]]] = None, scope_idx: int = 0
26
22
  ) -> None:
27
23
  """
28
24
  anchor can be None, in which case this is a relative JsonPath and the anchor is set later via set_anchor().
@@ -5,7 +5,7 @@ from .callable_function import CallableFunction
5
5
  from .expr_template_function import ExprTemplateFunction
6
6
  from .function import Function, InvalidFunction
7
7
  from .function_registry import FunctionRegistry
8
- from .query_template_function import QueryTemplateFunction, query
8
+ from .query_template_function import QueryTemplateFunction, query, retrieval_udf
9
9
  from .signature import Batch, Parameter, Signature
10
10
  from .tools import Tool, ToolChoice, Tools
11
11
  from .udf import expr_udf, make_function, udf
@@ -159,7 +159,7 @@ class AggregateFunction(Function):
159
159
  self.init_param_names.append(init_param_names)
160
160
  return self
161
161
 
162
- def _docstring(self) -> Optional[str]:
162
+ def comment(self) -> Optional[str]:
163
163
  return inspect.getdoc(self.agg_classes[0])
164
164
 
165
165
  def help_str(self) -> str:
@@ -60,7 +60,7 @@ class CallableFunction(Function):
60
60
  def is_async(self) -> bool:
61
61
  return inspect.iscoroutinefunction(self.py_fn)
62
62
 
63
- def _docstring(self) -> Optional[str]:
63
+ def comment(self) -> Optional[str]:
64
64
  return inspect.getdoc(self.py_fns[0])
65
65
 
66
66
  @property
@@ -95,9 +95,9 @@ class ExprTemplateFunction(Function):
95
95
  )
96
96
  return substituted_expr.col_type
97
97
 
98
- def _docstring(self) -> Optional[str]:
98
+ def comment(self) -> Optional[str]:
99
99
  if isinstance(self.templates[0].expr, exprs.FunctionCall):
100
- return self.templates[0].expr.fn._docstring()
100
+ return self.templates[0].expr.fn.comment()
101
101
  return None
102
102
 
103
103
  def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
@@ -10,8 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, cast
10
10
  import sqlalchemy as sql
11
11
  from typing_extensions import Self
12
12
 
13
- import pixeltable.exceptions as excs
14
- import pixeltable.type_system as ts
13
+ from pixeltable import exceptions as excs, type_system as ts
15
14
 
16
15
  from .globals import resolve_symbol
17
16
  from .signature import Signature
@@ -106,11 +105,11 @@ class Function(ABC):
106
105
  @abstractmethod
107
106
  def is_async(self) -> bool: ...
108
107
 
109
- def _docstring(self) -> Optional[str]:
108
+ def comment(self) -> Optional[str]:
110
109
  return None
111
110
 
112
111
  def help_str(self) -> str:
113
- docstring = self._docstring()
112
+ docstring = self.comment()
114
113
  display = self.display_name + str(self.signatures[0])
115
114
  if docstring is None:
116
115
  return display
@@ -1,9 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import TYPE_CHECKING, Any, Callable, Optional, overload
4
+ from functools import reduce
5
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Union, overload
5
6
 
6
- from pixeltable import exprs, type_system as ts
7
+ from pixeltable import catalog, exceptions as excs, exprs, func, type_system as ts
7
8
 
8
9
  from .function import Function
9
10
  from .signature import Signature
@@ -17,6 +18,7 @@ class QueryTemplateFunction(Function):
17
18
 
18
19
  template_df: Optional['DataFrame']
19
20
  self_name: Optional[str]
21
+ _comment: Optional[str]
20
22
 
21
23
  @classmethod
22
24
  def create(
@@ -34,15 +36,21 @@ class QueryTemplateFunction(Function):
34
36
  assert isinstance(template_df, DataFrame)
35
37
  # we take params and return json
36
38
  sig = Signature(return_type=ts.JsonType(), parameters=params)
37
- return QueryTemplateFunction(template_df, sig, path=path, name=name)
39
+ return QueryTemplateFunction(template_df, sig, path=path, name=name, comment=inspect.getdoc(template_callable))
38
40
 
39
41
  def __init__(
40
- self, template_df: Optional['DataFrame'], sig: Signature, path: Optional[str] = None, name: Optional[str] = None
42
+ self,
43
+ template_df: Optional['DataFrame'],
44
+ sig: Signature,
45
+ path: Optional[str] = None,
46
+ name: Optional[str] = None,
47
+ comment: Optional[str] = None,
41
48
  ):
42
49
  assert sig is not None
43
50
  super().__init__([sig], self_path=path)
44
51
  self.self_name = name
45
52
  self.template_df = template_df
53
+ self._comment = comment
46
54
 
47
55
  def _update_as_overload_resolution(self, signature_idx: int) -> None:
48
56
  pass # only one signature supported for QueryTemplateFunction
@@ -74,6 +82,9 @@ class QueryTemplateFunction(Function):
74
82
  def name(self) -> str:
75
83
  return self.self_name
76
84
 
85
+ def comment(self) -> Optional[str]:
86
+ return self._comment
87
+
77
88
  def _as_dict(self) -> dict:
78
89
  return {'name': self.name, 'signature': self.signature.as_dict(), 'df': self.template_df.as_dict()}
79
90
 
@@ -112,3 +123,75 @@ def query(*args: Any, **kwargs: Any) -> Any:
112
123
  else:
113
124
  assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
114
125
  return lambda py_fn: make_query_template(py_fn, kwargs['param_types'])
126
+
127
+
128
+ def retrieval_udf(
129
+ table: catalog.Table,
130
+ name: Optional[str] = None,
131
+ description: Optional[str] = None,
132
+ parameters: Optional[Iterable[Union[str, exprs.ColumnRef]]] = None,
133
+ limit: Optional[int] = 10,
134
+ ) -> func.QueryTemplateFunction:
135
+ """
136
+ Constructs a retrieval UDF for the given table. The retrieval UDF is a UDF whose parameters are
137
+ columns of the table and whose return value is a list of rows from the table. The return value of
138
+ ```python
139
+ f(col1=x, col2=y, ...)
140
+ ```
141
+ will be a list of all rows from the table that match the specified arguments.
142
+
143
+ Args:
144
+ table: The table to use as the dataset for the retrieval tool.
145
+ name: The name of the tool. If not specified, then the name of the table will be used by default.
146
+ description: The description of the tool. If not specified, then a default description will be generated.
147
+ parameters: The columns of the table to use as parameters. If not specified, all data columns
148
+ (non-computed columns) will be used as parameters.
149
+
150
+ All of the specified parameters will be required parameters of the tool, regardless of their status
151
+ as columns.
152
+ limit: The maximum number of rows to return. If not specified, then all matching rows will be returned.
153
+
154
+ Returns:
155
+ A list of dictionaries containing data from the table, one per row that matches the input arguments.
156
+ If there are no matching rows, an empty list will be returned.
157
+ """
158
+ # Argument validation
159
+ col_refs: list[exprs.ColumnRef]
160
+ if parameters is None:
161
+ col_refs = [table[col_name] for col_name in table.columns if not table[col_name].col.is_computed]
162
+ else:
163
+ for param in parameters:
164
+ if isinstance(param, str) and param not in table.columns:
165
+ raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path!r}')
166
+ col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
167
+
168
+ if len(col_refs) == 0:
169
+ raise excs.Error('Parameter list cannot be empty.')
170
+
171
+ # Construct the dataframe
172
+ predicates = [col_ref == exprs.Variable(col_ref.col.name, col_ref.col.col_type) for col_ref in col_refs]
173
+ where_clause = reduce(lambda c1, c2: c1 & c2, predicates)
174
+ df = table.select().where(where_clause)
175
+ if limit is not None:
176
+ df = df.limit(limit)
177
+
178
+ # Construct the signature
179
+ query_params = [
180
+ func.Parameter(col_ref.col.name, col_ref.col.col_type, inspect.Parameter.POSITIONAL_OR_KEYWORD)
181
+ for col_ref in col_refs
182
+ ]
183
+ query_signature = func.Signature(return_type=ts.JsonType(), parameters=query_params)
184
+
185
+ # Construct a name and/or description if not provided
186
+ if name is None:
187
+ name = table._name
188
+ if description is None:
189
+ description = (
190
+ f'Retrieves an entry from the dataset {name!r} that matches the given parameters.\n\nParameters:\n'
191
+ )
192
+ description += '\n'.join(
193
+ [f' {col_ref.col.name}: of type `{col_ref.col.col_type._to_base_str()}`' for col_ref in col_refs]
194
+ )
195
+
196
+ fn = func.QueryTemplateFunction(df, query_signature, name=name, comment=description)
197
+ return fn
pixeltable/func/tools.py CHANGED
@@ -39,7 +39,7 @@ class Tool(pydantic.BaseModel):
39
39
  def ser_model(self) -> dict[str, Any]:
40
40
  return {
41
41
  'name': self.name or self.fn.name,
42
- 'description': self.description or self.fn._docstring(),
42
+ 'description': self.description or self.fn.comment(),
43
43
  'parameters': {
44
44
  'type': 'object',
45
45
  'properties': {param.name: param.col_type._to_json_schema() for param in self.parameters.values()},
pixeltable/func/udf.py CHANGED
@@ -262,7 +262,7 @@ def from_table(
262
262
  """
263
263
  from pixeltable import exprs
264
264
 
265
- ancestors = [tbl, *tbl._bases]
265
+ ancestors = [tbl, *tbl._base_tables]
266
266
  ancestors.reverse() # We must traverse the ancestors in order from base to derived
267
267
 
268
268
  subst: dict[exprs.Expr, exprs.Expr] = {}
@@ -5,6 +5,7 @@ from pixeltable.utils.code import local_public_names
5
5
  from . import (
6
6
  anthropic,
7
7
  audio,
8
+ bedrock,
8
9
  deepseek,
9
10
  fireworks,
10
11
  gemini,
@@ -112,7 +112,7 @@ async def messages(
112
112
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
113
113
 
114
114
  >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
115
- ... tbl.add_computed_column(response= messages(msgs, model='claude-3-haiku-20240307'))
115
+ ... tbl.add_computed_column(response=messages(msgs, model='claude-3-haiku-20240307'))
116
116
  """
117
117
 
118
118
  # it doesn't look like count_tokens() actually exists in the current version of the library
@@ -0,0 +1,130 @@
1
+ import logging
2
+ from typing import TYPE_CHECKING, Any, Optional
3
+
4
+ import pixeltable as pxt
5
+ from pixeltable import env, exprs
6
+ from pixeltable.func import Tools
7
+ from pixeltable.utils.code import local_public_names
8
+
9
+ if TYPE_CHECKING:
10
+ from botocore.client import BaseClient
11
+
12
+ _logger = logging.getLogger('pixeltable')
13
+
14
+
15
+ @env.register_client('bedrock')
16
+ def _() -> 'BaseClient':
17
+ import boto3
18
+
19
+ return boto3.client(service_name='bedrock-runtime')
20
+
21
+
22
+ # boto3 typing is weird; type information is dynamically defined, so the best we can do for the static checker is `Any`
23
+ def _bedrock_client() -> Any:
24
+ return env.Env.get().get_client('bedrock')
25
+
26
+
27
+ @pxt.udf
28
+ def converse(
29
+ messages: list[dict[str, Any]],
30
+ *,
31
+ model_id: str,
32
+ system: Optional[list[dict[str, Any]]] = None,
33
+ inference_config: Optional[dict] = None,
34
+ additional_model_request_fields: Optional[dict] = None,
35
+ tool_config: Optional[list[dict]] = None,
36
+ ) -> dict:
37
+ """
38
+ Generate a conversation response.
39
+
40
+ Equivalent to the AWS Bedrock `converse` API endpoint.
41
+ For additional details, see: <https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html>
42
+
43
+ __Requirements:__
44
+
45
+ - `pip install boto3`
46
+
47
+ Args:
48
+ messages: Input messages.
49
+ model_id: The model that will complete your prompt.
50
+ system: An optional system prompt.
51
+ inference_config: Base inference parameters to use.
52
+ additional_model_request_fields: Additional inference parameters to use.
53
+
54
+ For details on the optional parameters, see:
55
+ <https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html>
56
+
57
+ Returns:
58
+ A dictionary containing the response and other metadata.
59
+
60
+ Examples:
61
+ Add a computed column that applies the model `anthropic.claude-3-haiku-20240307-v1:0`
62
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
63
+
64
+ >>> msgs = [{'role': 'user', 'content': [{'text': tbl.prompt}]}]
65
+ ... tbl.add_computed_column(response=messages(msgs, model_id='anthropic.claude-3-haiku-20240307-v1:0'))
66
+ """
67
+
68
+ kwargs: dict[str, Any] = {'messages': messages, 'modelId': model_id}
69
+
70
+ if system is not None:
71
+ kwargs['system'] = system
72
+ if inference_config is not None:
73
+ kwargs['inferenceConfig'] = inference_config
74
+ if additional_model_request_fields is not None:
75
+ kwargs['additionalModelRequestFields'] = additional_model_request_fields
76
+
77
+ if tool_config is not None:
78
+ tool_config_ = {
79
+ 'tools': [
80
+ {
81
+ 'toolSpec': {
82
+ 'name': tool['name'],
83
+ 'description': tool['description'],
84
+ 'inputSchema': {
85
+ 'json': {
86
+ 'type': 'object',
87
+ 'properties': tool['parameters']['properties'],
88
+ 'required': tool['required'],
89
+ }
90
+ },
91
+ }
92
+ }
93
+ for tool in tool_config
94
+ ]
95
+ }
96
+ kwargs['toolConfig'] = tool_config_
97
+
98
+ return _bedrock_client().converse(**kwargs)
99
+
100
+
101
+ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
102
+ """Converts an Anthropic response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
103
+ return tools._invoke(_bedrock_response_to_pxt_tool_calls(response))
104
+
105
+
106
+ @pxt.udf
107
+ def _bedrock_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
108
+ if response.get('stopReason') != 'tool_use':
109
+ return None
110
+
111
+ pxt_tool_calls: dict[str, list[dict[str, Any]]] = {}
112
+ for message in response['output']['message']['content']:
113
+ if 'toolUse' in message:
114
+ tool_call = message['toolUse']
115
+ tool_name = tool_call['name']
116
+ if tool_name not in pxt_tool_calls:
117
+ pxt_tool_calls[tool_name] = []
118
+ pxt_tool_calls[tool_name].append({'args': tool_call['input']})
119
+
120
+ if len(pxt_tool_calls) == 0:
121
+ return None
122
+
123
+ return pxt_tool_calls
124
+
125
+
126
+ __all__ = local_public_names(__name__)
127
+
128
+
129
+ def __dir__() -> list[str]:
130
+ return __all__
@@ -13,6 +13,7 @@ import PIL.Image
13
13
 
14
14
  import pixeltable as pxt
15
15
  import pixeltable.exceptions as excs
16
+ import pixeltable.type_system as ts
16
17
  from pixeltable import env
17
18
  from pixeltable.func import Batch
18
19
  from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
@@ -61,14 +62,14 @@ def sentence_transformer(
61
62
 
62
63
 
63
64
  @sentence_transformer.conditional_return_type
64
- def _(model_id: str) -> pxt.ArrayType:
65
+ def _(model_id: str) -> ts.ArrayType:
65
66
  try:
66
67
  from sentence_transformers import SentenceTransformer
67
68
 
68
69
  model = _lookup_model(model_id, SentenceTransformer)
69
- return pxt.ArrayType((model.get_sentence_embedding_dimension(),), dtype=pxt.FloatType(), nullable=False)
70
+ return ts.ArrayType((model.get_sentence_embedding_dimension(),), dtype=ts.FloatType(), nullable=False)
70
71
  except ImportError:
71
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
72
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
72
73
 
73
74
 
74
75
  @pxt.udf
@@ -199,14 +200,14 @@ def _(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,
199
200
 
200
201
 
201
202
  @clip.conditional_return_type
202
- def _(model_id: str) -> pxt.ArrayType:
203
+ def _(model_id: str) -> ts.ArrayType:
203
204
  try:
204
205
  from transformers import CLIPModel
205
206
 
206
207
  model = _lookup_model(model_id, CLIPModel.from_pretrained)
207
- return pxt.ArrayType((model.config.projection_dim,), dtype=pxt.FloatType(), nullable=False)
208
+ return ts.ArrayType((model.config.projection_dim,), dtype=ts.FloatType(), nullable=False)
208
209
  except ImportError:
209
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
210
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
210
211
 
211
212
 
212
213
  @pxt.udf(batch_size=4)
@@ -16,6 +16,7 @@ from typing import Optional
16
16
  import PIL.Image
17
17
 
18
18
  import pixeltable as pxt
19
+ import pixeltable.type_system as ts
19
20
  from pixeltable.exprs import Expr
20
21
  from pixeltable.utils.code import local_public_names
21
22
 
@@ -88,10 +89,10 @@ def convert(self: PIL.Image.Image, mode: str) -> PIL.Image.Image:
88
89
 
89
90
 
90
91
  @convert.conditional_return_type
91
- def _(self: Expr, mode: str) -> pxt.ColumnType:
92
+ def _(self: Expr, mode: str) -> ts.ColumnType:
92
93
  input_type = self.col_type
93
- assert isinstance(input_type, pxt.ImageType)
94
- return pxt.ImageType(size=input_type.size, mode=mode, nullable=input_type.nullable)
94
+ assert isinstance(input_type, ts.ImageType)
95
+ return ts.ImageType(size=input_type.size, mode=mode, nullable=input_type.nullable)
95
96
 
96
97
 
97
98
  # Image.crop()
@@ -108,14 +109,12 @@ def crop(self: PIL.Image.Image, box: tuple[int, int, int, int]) -> PIL.Image.Ima
108
109
 
109
110
 
110
111
  @crop.conditional_return_type
111
- def _(self: Expr, box: tuple[int, int, int, int]) -> pxt.ColumnType:
112
+ def _(self: Expr, box: tuple[int, int, int, int]) -> ts.ColumnType:
112
113
  input_type = self.col_type
113
- assert isinstance(input_type, pxt.ImageType)
114
+ assert isinstance(input_type, ts.ImageType)
114
115
  if (isinstance(box, (list, tuple))) and len(box) == 4 and all(isinstance(x, int) for x in box):
115
- return pxt.ImageType(
116
- size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable
117
- )
118
- return pxt.ImageType(mode=input_type.mode, nullable=input_type.nullable) # we can't compute the size statically
116
+ return ts.ImageType(size=(box[2] - box[0], box[3] - box[1]), mode=input_type.mode, nullable=input_type.nullable)
117
+ return ts.ImageType(mode=input_type.mode, nullable=input_type.nullable) # we can't compute the size statically
119
118
 
120
119
 
121
120
  # Image.getchannel()
@@ -134,10 +133,10 @@ def getchannel(self: PIL.Image.Image, channel: int) -> PIL.Image.Image:
134
133
 
135
134
 
136
135
  @getchannel.conditional_return_type
137
- def _(self: Expr) -> pxt.ColumnType:
136
+ def _(self: Expr) -> ts.ColumnType:
138
137
  input_type = self.col_type
139
- assert isinstance(input_type, pxt.ImageType)
140
- return pxt.ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
138
+ assert isinstance(input_type, ts.ImageType)
139
+ return ts.ImageType(size=input_type.size, mode='L', nullable=input_type.nullable)
141
140
 
142
141
 
143
142
  @pxt.udf(is_method=True)
@@ -183,10 +182,10 @@ def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
183
182
 
184
183
 
185
184
  @resize.conditional_return_type
186
- def _(self: Expr, size: tuple[int, int]) -> pxt.ColumnType:
185
+ def _(self: Expr, size: tuple[int, int]) -> ts.ColumnType:
187
186
  input_type = self.col_type
188
- assert isinstance(input_type, pxt.ImageType)
189
- return pxt.ImageType(size=size, mode=input_type.mode, nullable=input_type.nullable)
187
+ assert isinstance(input_type, ts.ImageType)
188
+ return ts.ImageType(size=size, mode=input_type.mode, nullable=input_type.nullable)
190
189
 
191
190
 
192
191
  # Image.rotate()
@@ -237,7 +236,7 @@ def transpose(self: PIL.Image.Image, method: int) -> PIL.Image.Image:
237
236
  @rotate.conditional_return_type
238
237
  @effect_spread.conditional_return_type
239
238
  @transpose.conditional_return_type
240
- def _(self: Expr) -> pxt.ColumnType:
239
+ def _(self: Expr) -> ts.ColumnType:
241
240
  return self.col_type
242
241
 
243
242
 
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar, Union
10
10
  import numpy as np
11
11
 
12
12
  import pixeltable as pxt
13
+ import pixeltable.type_system as ts
13
14
  from pixeltable.env import Env, register_client
14
15
  from pixeltable.func.signature import Batch
15
16
  from pixeltable.utils.code import local_public_names
@@ -176,9 +177,9 @@ async def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,
176
177
 
177
178
 
178
179
  @embeddings.conditional_return_type
179
- def _(model: str) -> pxt.ArrayType:
180
+ def _(model: str) -> ts.ArrayType:
180
181
  dimensions = _embedding_dimensions_cache.get(model) # `None` if unknown model
181
- return pxt.ArrayType((dimensions,), dtype=pxt.FloatType())
182
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType())
182
183
 
183
184
 
184
185
  _T = TypeVar('_T')
@@ -21,6 +21,7 @@ import numpy as np
21
21
  import PIL
22
22
 
23
23
  import pixeltable as pxt
24
+ import pixeltable.type_system as ts
24
25
  from pixeltable import env, exprs
25
26
  from pixeltable.func import Batch, Tools
26
27
  from pixeltable.utils.code import local_public_names
@@ -666,13 +667,13 @@ async def embeddings(
666
667
 
667
668
 
668
669
  @embeddings.conditional_return_type
669
- def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
670
+ def _(model: str, dimensions: Optional[int] = None) -> ts.ArrayType:
670
671
  if dimensions is None:
671
672
  if model not in _embedding_dimensions_cache:
672
673
  # TODO: find some other way to retrieve a sample
673
- return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
674
+ return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
674
675
  dimensions = _embedding_dimensions_cache.get(model)
675
- return pxt.ArrayType((dimensions,), dtype=pxt.FloatType(), nullable=False)
676
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType(), nullable=False)
676
677
 
677
678
 
678
679
  #####################################
@@ -738,17 +739,17 @@ async def image_generations(
738
739
 
739
740
 
740
741
  @image_generations.conditional_return_type
741
- def _(size: Optional[str] = None) -> pxt.ImageType:
742
+ def _(size: Optional[str] = None) -> ts.ImageType:
742
743
  if size is None:
743
- return pxt.ImageType(size=(1024, 1024))
744
+ return ts.ImageType(size=(1024, 1024))
744
745
  x_pos = size.find('x')
745
746
  if x_pos == -1:
746
- return pxt.ImageType()
747
+ return ts.ImageType()
747
748
  try:
748
749
  width, height = int(size[:x_pos]), int(size[x_pos + 1 :])
749
750
  except ValueError:
750
- return pxt.ImageType()
751
- return pxt.ImageType(size=(width, height))
751
+ return ts.ImageType()
752
+ return ts.ImageType(size=(width, height))
752
753
 
753
754
 
754
755
  #####################################
@@ -16,6 +16,7 @@ import tenacity
16
16
 
17
17
  import pixeltable as pxt
18
18
  import pixeltable.exceptions as excs
19
+ import pixeltable.type_system as ts
19
20
  from pixeltable import env
20
21
  from pixeltable.func import Batch
21
22
  from pixeltable.utils.code import local_public_names
@@ -225,12 +226,12 @@ async def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,
225
226
 
226
227
 
227
228
  @embeddings.conditional_return_type
228
- def _(model: str) -> pxt.ArrayType:
229
+ def _(model: str) -> ts.ArrayType:
229
230
  if model not in _embedding_dimensions_cache:
230
231
  # TODO: find some other way to retrieve a sample
231
- return pxt.ArrayType((None,), dtype=pxt.FloatType())
232
+ return ts.ArrayType((None,), dtype=ts.FloatType())
232
233
  dimensions = _embedding_dimensions_cache[model]
233
- return pxt.ArrayType((dimensions,), dtype=pxt.FloatType())
234
+ return ts.ArrayType((dimensions,), dtype=ts.FloatType())
234
235
 
235
236
 
236
237
  @pxt.udf(resource_pool='request-rate:together:images')