pixeltable 0.2.29__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

pixeltable/__init__.py CHANGED
@@ -4,7 +4,7 @@ from .exceptions import Error
4
4
  from .exprs import RELATIVE_PATH_ROOT
5
5
  from .func import Aggregator, Function, expr_udf, uda, udf
6
6
  from .globals import (array, configure_logging, create_dir, create_snapshot, create_table, create_view, drop_dir,
7
- drop_table, get_table, init, list_dirs, list_functions, list_tables, move)
7
+ drop_table, get_table, init, list_dirs, list_functions, list_tables, move, tool, tools)
8
8
  from .type_system import (Array, ArrayType, Audio, AudioType, Bool, BoolType, ColumnType, Document, DocumentType, Float,
9
9
  FloatType, Image, ImageType, Int, IntType, Json, JsonType, Required, String, StringType,
10
10
  Timestamp, TimestampType, Video, VideoType)
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # These version placeholders will be replaced during build.
2
- __version__ = "0.2.29"
3
- __version_tuple__ = (0, 2, 29)
2
+ __version__ = "0.2.30"
3
+ __version_tuple__ = (0, 2, 30)
@@ -1,7 +1,7 @@
1
1
  from .catalog import Catalog
2
2
  from .column import Column
3
3
  from .dir import Dir
4
- from .globals import UpdateStatus, is_valid_identifier, is_valid_path, MediaValidation, IfExistsParam
4
+ from .globals import UpdateStatus, is_valid_identifier, is_valid_path, MediaValidation, IfExistsParam, IfNotExistsParam
5
5
  from .insertable_table import InsertableTable
6
6
  from .named_function import NamedFunction
7
7
  from .path import Path
@@ -65,6 +65,18 @@ class IfExistsParam(enum.Enum):
65
65
  val_strs = ', '.join(f'{s.lower()!r}' for s in cls.__members__.keys())
66
66
  raise excs.Error(f'{param_name} must be one of: [{val_strs}]')
67
67
 
68
class IfNotExistsParam(enum.Enum):
    """Directive for operations whose target may not exist (e.g. `if_not_exists=` on drop operations).

    ERROR: the caller raises an error when the target is missing.
    IGNORE: the caller returns without doing anything when the target is missing.
    """
    ERROR = 0
    IGNORE = 1

    @classmethod
    def validated(cls, param_val: str, param_name: str) -> 'IfNotExistsParam':
        """Convert a user-supplied, case-insensitive string into a member of this enum.

        Args:
            param_val: The user-supplied value, e.g. 'error' or 'ignore'.
            param_name: Name of the parameter being validated; used in the error message.

        Raises:
            excs.Error: If `param_val` does not name a member; the message lists the accepted values.
        """
        try:
            member = cls[param_val.upper()]
        except KeyError:
            allowed = ', '.join(repr(name.lower()) for name in cls.__members__)
            raise excs.Error(f'{param_name} must be one of: [{allowed}]')
        return member
79
+
68
80
  def is_valid_identifier(name: str) -> bool:
69
81
  return name.isidentifier() and not name.startswith('_')
70
82
 
@@ -25,7 +25,7 @@ from ..exprs import ColumnRef
25
25
  from ..utils.description_helper import DescriptionHelper
26
26
  from ..utils.filecache import FileCache
27
27
  from .column import Column
28
- from .globals import _ROWID_COLUMN_NAME, MediaValidation, UpdateStatus, is_system_column_name, is_valid_identifier
28
+ from .globals import _ROWID_COLUMN_NAME, MediaValidation, UpdateStatus, is_system_column_name, is_valid_identifier, IfNotExistsParam
29
29
  from .schema_object import SchemaObject
30
30
  from .table_version import TableVersion
31
31
  from .table_version_path import TableVersionPath
@@ -712,14 +712,19 @@ class Table(SchemaObject):
712
712
  if not exists:
713
713
  raise excs.Error(f'Unknown column: {col_ref.col.qualified_name}')
714
714
 
715
- def drop_column(self, column: Union[str, ColumnRef]) -> None:
715
+ def drop_column(self, column: Union[str, ColumnRef], if_not_exists: Literal['error', 'ignore'] = 'error') -> None:
716
716
  """Drop a column from the table.
717
717
 
718
718
  Args:
719
719
  column: The name or reference of the column to drop.
720
+ if_not_exists: Directive for handling a non-existent column. Must be one of the following:
721
+
722
+ - `'error'`: raise an error if the column does not exist.
723
+ - `'ignore'`: do nothing if the column does not exist.
720
724
 
721
725
  Raises:
722
- Error: If the column does not exist or if it is referenced by a dependent computed column.
726
+ Error: If the column does not exist and `if_exists='error'`,
727
+ or if it is referenced by a dependent computed column.
723
728
 
724
729
  Examples:
725
730
  Drop the column `col` from the table `my_table` by column name:
@@ -731,14 +736,32 @@ class Table(SchemaObject):
731
736
 
732
737
  >>> tbl = pxt.get_table('my_table')
733
738
  ... tbl.drop_column(tbl.col)
739
+
740
+ Drop the column `col` from the table `my_table` if it exists, otherwise do nothing:
741
+
742
+ >>> tbl = pxt.get_table('my_table')
743
+ ... tbl.drop_col(tbl.col, if_not_exists='ignore')
734
744
  """
735
745
  self._check_is_dropped()
746
+ if self._tbl_version_path.is_snapshot():
747
+ raise excs.Error('Cannot drop column from a snapshot.')
736
748
  col: Column = None
749
+ _if_not_exists = IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
737
750
  if isinstance(column, str):
738
- self.__check_column_name_exists(column)
751
+ col = self._tbl_version_path.get_column(column, include_bases=False)
752
+ if col is None:
753
+ if _if_not_exists == IfNotExistsParam.ERROR:
754
+ raise excs.Error(f'Column {column!r} unknown')
755
+ assert _if_not_exists == IfNotExistsParam.IGNORE
756
+ return
739
757
  col = self._tbl_version.cols_by_name[column]
740
758
  else:
741
- self.__check_column_ref_exists(column)
759
+ exists = self._tbl_version_path.has_column(column.col, include_bases=False)
760
+ if not exists:
761
+ if _if_not_exists == IfNotExistsParam.ERROR:
762
+ raise excs.Error(f'Unknown column: {column.col.qualified_name}')
763
+ assert _if_not_exists == IfNotExistsParam.IGNORE
764
+ return
742
765
  col = column.col
743
766
 
744
767
  dependent_user_cols = [c for c in col.dependent_cols if c.name is not None]
@@ -866,7 +889,9 @@ class Table(SchemaObject):
866
889
  def drop_embedding_index(
867
890
  self, *,
868
891
  column: Union[str, ColumnRef, None] = None,
869
- idx_name: Optional[str] = None) -> None:
892
+ idx_name: Optional[str] = None,
893
+ if_not_exists: Literal['error', 'ignore'] = 'error'
894
+ ) -> None:
870
895
  """
871
896
  Drop an embedding index from the table. Either a column name or an index name (but not both) must be
872
897
  specified. If a column name or reference is specified, it must be a column containing exactly one
@@ -876,11 +901,20 @@ class Table(SchemaObject):
876
901
  column: The name of, or reference to, the column from which to drop the index.
877
902
  The column must have only one embedding index.
878
903
  idx_name: The name of the index to drop.
904
+ if_not_exists: Directive for handling a non-existent index. Must be one of the following:
905
+
906
+ - `'error'`: raise an error if the index does not exist.
907
+ - `'ignore'`: do nothing if the index does not exist.
908
+
909
+ Note that `if_not_exists` parameter is only applicable when an `idx_name` is specified
910
+ and it does not exist, or when `column` is specified and it has no index.
911
+ `if_not_exists` does not apply to non-exisitng column.
879
912
 
880
913
  Raises:
881
914
  Error: If `column` is specified, but the column does not exist, or it contains no embedding
882
- indices or multiple embedding indices.
883
- Error: If `idx_name` is specified, but the index does not exist or is not an embedding index.
915
+ indices and `if_not_exists='error'`, or the column has multiple embedding indices.
916
+ Error: If `idx_name` is specified, but the index is not an embedding index, or
917
+ the index does not exist and `if_not_exists='error'`.
884
918
 
885
919
  Examples:
886
920
  Drop the embedding index on the `img` column of the table `my_table` by column name:
@@ -897,6 +931,9 @@ class Table(SchemaObject):
897
931
  >>> tbl = pxt.get_table('my_table')
898
932
  ... tbl.drop_embedding_index(idx_name='idx1')
899
933
 
934
+ Drop the embedding index `idx1` of the table `my_table` by index name, if it exists, otherwise do nothing:
935
+ >>> tbl = pxt.get_table('my_table')
936
+ ... tbl.drop_embedding_index(idx_name='idx1', if_not_exists='ignore')
900
937
  """
901
938
  if (column is None) == (idx_name is None):
902
939
  raise excs.Error("Exactly one of 'column' or 'idx_name' must be provided")
@@ -910,12 +947,14 @@ class Table(SchemaObject):
910
947
  self.__check_column_ref_exists(column, include_bases=True)
911
948
  col = column.col
912
949
  assert col is not None
913
- self._drop_index(col=col, idx_name=idx_name, _idx_class=index.EmbeddingIndex)
950
+ self._drop_index(col=col, idx_name=idx_name, _idx_class=index.EmbeddingIndex, if_not_exists=if_not_exists)
914
951
 
915
952
  def drop_index(
916
953
  self, *,
917
954
  column: Union[str, ColumnRef, None] = None,
918
- idx_name: Optional[str] = None) -> None:
955
+ idx_name: Optional[str] = None,
956
+ if_not_exists: Literal['error', 'ignore'] = 'error'
957
+ ) -> None:
919
958
  """
920
959
  Drop an index from the table. Either a column name or an index name (but not both) must be
921
960
  specified. If a column name or reference is specified, it must be a column containing exactly one index;
@@ -925,6 +964,14 @@ class Table(SchemaObject):
925
964
  column: The name of, or reference to, the column from which to drop the index.
926
965
  The column must have only one embedding index.
927
966
  idx_name: The name of the index to drop.
967
+ if_not_exists: Directive for handling a non-existent index. Must be one of the following:
968
+
969
+ - `'error'`: raise an error if the index does not exist.
970
+ - `'ignore'`: do nothing if the index does not exist.
971
+
972
+ Note that `if_not_exists` parameter is only applicable when an `idx_name` is specified
973
+ and it does not exist, or when `column` is specified and it has no index.
974
+ `if_not_exists` does not apply to non-exisitng column.
928
975
 
929
976
  Raises:
930
977
  Error: If `column` is specified, but the column does not exist, or it contains no
@@ -946,6 +993,10 @@ class Table(SchemaObject):
946
993
  >>> tbl = pxt.get_table('my_table')
947
994
  ... tbl.drop_index(idx_name='idx1')
948
995
 
996
+ Drop the index `idx1` of the table `my_table` by index name, if it exists, otherwise do nothing:
997
+ >>> tbl = pxt.get_table('my_table')
998
+ ... tbl.drop_index(idx_name='idx1', if_not_exists='ignore')
999
+
949
1000
  """
950
1001
  if (column is None) == (idx_name is None):
951
1002
  raise excs.Error("Exactly one of 'column' or 'idx_name' must be provided")
@@ -959,20 +1010,25 @@ class Table(SchemaObject):
959
1010
  self.__check_column_ref_exists(column, include_bases=True)
960
1011
  col = column.col
961
1012
  assert col is not None
962
- self._drop_index(col=col, idx_name=idx_name)
1013
+ self._drop_index(col=col, idx_name=idx_name, if_not_exists=if_not_exists)
963
1014
 
964
1015
  def _drop_index(
965
1016
  self, *, col: Optional[Column] = None,
966
1017
  idx_name: Optional[str] = None,
967
- _idx_class: Optional[type[index.IndexBase]] = None
1018
+ _idx_class: Optional[type[index.IndexBase]] = None,
1019
+ if_not_exists: Literal['error', 'ignore'] = 'error'
968
1020
  ) -> None:
969
1021
  if self._tbl_version_path.is_snapshot():
970
1022
  raise excs.Error('Cannot drop an index from a snapshot')
971
1023
  assert (col is None) != (idx_name is None)
972
1024
 
973
1025
  if idx_name is not None:
1026
+ _if_not_exists = IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
974
1027
  if idx_name not in self._tbl_version.idxs_by_name:
975
- raise excs.Error(f'Index {idx_name!r} does not exist')
1028
+ if _if_not_exists == IfNotExistsParam.ERROR:
1029
+ raise excs.Error(f'Index {idx_name!r} does not exist')
1030
+ assert _if_not_exists == IfNotExistsParam.IGNORE
1031
+ return
976
1032
  idx_id = self._tbl_version.idxs_by_name[idx_name].id
977
1033
  else:
978
1034
  if col.tbl.id != self._tbl_version.id:
@@ -982,7 +1038,11 @@ class Table(SchemaObject):
982
1038
  if _idx_class is not None:
983
1039
  idx_info = [info for info in idx_info if isinstance(info.idx, _idx_class)]
984
1040
  if len(idx_info) == 0:
985
- raise excs.Error(f'Column {col.name!r} does not have an index')
1041
+ _if_not_exists = IfNotExistsParam.validated(if_not_exists, 'if_not_exists')
1042
+ if _if_not_exists == IfNotExistsParam.ERROR:
1043
+ raise excs.Error(f'Column {col.name!r} does not have an index')
1044
+ assert _if_not_exists == IfNotExistsParam.IGNORE
1045
+ return
986
1046
  if len(idx_info) > 1:
987
1047
  raise excs.Error(f"Column {col.name!r} has multiple indices; specify 'idx_name' instead")
988
1048
  idx_id = idx_info[0].id
@@ -199,11 +199,6 @@ class FunctionCall(Expr):
199
199
  pass
200
200
 
201
201
  if not isinstance(arg, Expr):
202
- # make sure that non-Expr args are json-serializable and are literals of the correct type
203
- try:
204
- _ = json.dumps(arg)
205
- except TypeError:
206
- raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg} (of type {type(arg)})')
207
202
  if arg is not None:
208
203
  try:
209
204
  param_type = param.col_type
@@ -5,4 +5,5 @@ from .function import Function
5
5
  from .function_registry import FunctionRegistry
6
6
  from .query_template_function import QueryTemplateFunction
7
7
  from .signature import Signature, Parameter, Batch
8
+ from .tools import Tool, Tools
8
9
  from .udf import udf, make_function, expr_udf
@@ -139,6 +139,9 @@ class AggregateFunction(Function):
139
139
  self.init_param_names.append(init_param_names)
140
140
  return self
141
141
 
142
+ def _docstring(self) -> Optional[str]:
143
+ return inspect.getdoc(self.agg_classes[0])
144
+
142
145
  def help_str(self) -> str:
143
146
  res = super().help_str()
144
147
  # We need to reference agg_classes[0] rather than agg_class here, because we want this to work even if the
@@ -48,6 +48,9 @@ class CallableFunction(Function):
48
48
  def is_batched(self) -> bool:
49
49
  return self.batch_size is not None
50
50
 
51
+ def _docstring(self) -> Optional[str]:
52
+ return inspect.getdoc(self.py_fns[0])
53
+
51
54
  @property
52
55
  def py_fn(self) -> Callable:
53
56
  assert not self.is_polymorphic
@@ -109,11 +112,6 @@ class CallableFunction(Function):
109
112
  self.py_fns.append(fn)
110
113
  return self
111
114
 
112
- def help_str(self) -> str:
113
- res = super().help_str()
114
- res += '\n\n' + inspect.getdoc(self.py_fns[0])
115
- return res
116
-
117
115
  def _as_dict(self) -> dict:
118
116
  if self.self_path is None:
119
117
  # this is not a module function
@@ -87,6 +87,13 @@ class ExprTemplateFunction(Function):
87
87
  assert not result._contains(exprs.Variable)
88
88
  return result
89
89
 
90
def _docstring(self) -> Optional[str]:
    """Documentation for an expression-template function, borrowed from the function it wraps.

    If the first template's expression is a `FunctionCall`, returns that call's function's `_docstring()`;
    otherwise there is no underlying function to consult and None is returned.
    """
    from pixeltable import exprs

    template_expr = self.templates[0].expr
    if not isinstance(template_expr, exprs.FunctionCall):
        return None
    return template_expr.fn._docstring()
96
+
90
97
  def exec(self, args: Sequence[Any], kwargs: dict[str, Any]) -> Any:
91
98
  from pixeltable import exec, exprs
92
99
 
@@ -92,6 +92,16 @@ class Function(abc.ABC):
92
92
  assert not self.is_polymorphic
93
93
  return len(self.signature.parameters)
94
94
 
95
+ def _docstring(self) -> Optional[str]:
96
+ return None
97
+
98
def help_str(self) -> str:
    """Build help text: display name plus first signature, followed by the docstring when one exists.

    The docstring (from `_docstring()`) is separated from the signature line by a blank line.
    """
    doc = self._docstring()
    header = self.display_name + str(self.signatures[0])
    return header if doc is None else f'{header}\n\n{doc}'
104
+
95
105
  @property
96
106
  def _resolved_fns(self) -> list[Self]:
97
107
  """
@@ -129,9 +139,6 @@ class Function(abc.ABC):
129
139
  """
130
140
  raise NotImplementedError()
131
141
 
132
- def help_str(self) -> str:
133
- return self.display_name + str(self.signatures[0])
134
-
135
142
  def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
136
143
  from pixeltable import exprs
137
144
 
@@ -284,7 +291,7 @@ class Function(abc.ABC):
284
291
  """Print source code"""
285
292
  print('source not available')
286
293
 
287
- def as_dict(self) -> dict:
294
+ def as_dict(self) -> dict[str, Any]:
288
295
  """
289
296
  Return a serialized reference to the instance that can be passed to json.dumps() and converted back
290
297
  to an instance with from_dict().
@@ -0,0 +1,116 @@
1
+ from dataclasses import dataclass
2
+ import dataclasses
3
+ import json
4
+ from typing import TYPE_CHECKING, Any, Optional
5
+
6
+ import pydantic
7
+
8
+ from .function import Function
9
+ from .signature import Parameter
10
+ from .udf import udf
11
+
12
+ if TYPE_CHECKING:
13
+ from pixeltable import exprs
14
+
15
+
16
+ # The Tool and Tools classes are containers that hold Pixeltable UDFs and related metadata, so that they can be
17
+ # realized as LLM tools. They are implemented as Pydantic models in order to provide a canonical way of converting
18
+ # to JSON, via the Pydantic `model_serializer` interface. In this way, they can be passed directly as UDF
19
+ # parameters as described in the `pixeltable.tools` and `pixeltable.tool` docstrings.
20
+ #
21
+ # (The dataclass dict serializer is insufficiently flexible for this purpose: `Tool` contains a member of type
22
+ # `Function`, which is not natively JSON-serializable; Pydantic provides a way of customizing its default
23
+ # serialization behavior, whereas dataclasses do not.)
24
+
25
class Tool(pydantic.BaseModel):
    """A Pixeltable function plus optional metadata, packaged so it can be offered to an LLM as a tool.

    `ser_model` produces the JSON tool description; `invoke` turns standardized tool calls back into a
    call of `fn`.
    """

    # Allow arbitrary types so that we can include a Pixeltable function in the schema.
    # The model_serializer below ensures the Tool model can still be serialized.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    fn: Function
    name: Optional[str] = None  # tool name; falls back to fn.name when unset
    description: Optional[str] = None  # tool description; falls back to fn's docstring when unset

    @property
    def parameters(self) -> dict[str, Parameter]:
        """Parameters of `fn`'s signature, keyed by parameter name."""
        return self.fn.signature.parameters

    @pydantic.model_serializer
    def ser_model(self) -> dict[str, Any]:
        """Serialize this tool as a JSON-compatible description: name, description and parameter schema.

        Non-nullable parameters are listed under 'required'.
        """
        # NOTE(review): 'required' and 'additionalProperties' are emitted next to 'parameters' rather than
        # inside it, where JSON Schema would normally put them. The Anthropic adapter reads tool['required']
        # from this top level, so relocating them requires updating that adapter at the same time.
        return {
            'name': self.name or self.fn.name,
            'description': self.description or self.fn._docstring(),
            'parameters': {
                'type': 'object',
                'properties': {
                    param.name: param.col_type._to_json_schema()
                    for param in self.parameters.values()
                }
            },
            'required': [
                param.name for param in self.parameters.values() if not param.col_type.nullable
            ],
            'additionalProperties': False,  # TODO Handle kwargs?
        }

    # `tool_calls` must be in standardized tool invocation format:
    # {tool_name: {'args': {name1: value1, name2: value2, ...}}, ...}
    def invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.FunctionCall':
        """Build a call of `fn` whose keyword arguments are extracted from `tool_calls`.

        Raises:
            excs.Error: If a parameter of `fn` has a type other than string, int, float or bool.
        """
        kwargs = {
            param.name: self.__extract_tool_arg(param, tool_calls)
            for param in self.parameters.values()
        }
        return self.fn(**kwargs)

    def __extract_tool_arg(self, param: Parameter, tool_calls: 'exprs.Expr') -> 'exprs.Expr':
        """Build an expression extracting `param`'s argument from `tool_calls`, converted to `param`'s type."""
        from pixeltable import exceptions as excs

        func_name = self.name or self.fn.name
        col_type = param.col_type
        if col_type.is_string_type():
            extractor = _extract_str_tool_arg
        elif col_type.is_int_type():
            extractor = _extract_int_tool_arg
        elif col_type.is_float_type():
            extractor = _extract_float_tool_arg
        elif col_type.is_bool_type():
            extractor = _extract_bool_tool_arg
        else:
            # Raise rather than `assert False`: an assert is stripped under `python -O` and gives no diagnostic.
            raise excs.Error(
                f'Tool {func_name!r}: parameter {param.name!r} has unsupported type {col_type}; '
                'only string, int, float, and bool parameters are supported'
            )
        return extractor(tool_calls, func_name=func_name, param_name=param.name)
76
+
77
+
78
class Tools(pydantic.BaseModel):
    """An ordered collection of `Tool`s, serialized as a JSON list of the individual tool descriptions."""

    tools: list[Tool]

    @pydantic.model_serializer
    def ser_model(self) -> list[dict[str, Any]]:
        """Serialize as the list of each member tool's own serialization, in order."""
        return [member.ser_model() for member in self.tools]

    # `tool_calls` must be in standardized tool invocation format:
    # {tool_name: {'args': {name1: value1, name2: value2, ...}}, ...}
    def _invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.InlineDict':
        """Build a dict expression mapping each tool's name to that tool's invocation on `tool_calls`.

        The name is the tool's `name` if set, otherwise its function's name; on duplicate names the
        later tool wins.
        """
        from pixeltable import exprs

        invocations: dict[str, Any] = {}
        for member in self.tools:
            tool_name = member.name or member.fn.name
            invocations[tool_name] = member.invoke(tool_calls)
        return exprs.InlineDict(invocations)
94
+
95
+
96
@udf
def _extract_str_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[str]:
    """String argument `param_name` of tool `func_name` in `tool_calls`, or None if it was not supplied."""
    return _extract_converted_arg(str, tool_calls, func_name, param_name)


@udf
def _extract_int_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[int]:
    """Int argument `param_name` of tool `func_name` in `tool_calls`, or None if it was not supplied."""
    return _extract_converted_arg(int, tool_calls, func_name, param_name)


@udf
def _extract_float_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[float]:
    """Float argument `param_name` of tool `func_name` in `tool_calls`, or None if it was not supplied."""
    return _extract_converted_arg(float, tool_calls, func_name, param_name)


@udf
def _extract_bool_tool_arg(tool_calls: dict, func_name: str, param_name: str) -> Optional[bool]:
    """Bool argument `param_name` of tool `func_name` in `tool_calls`, or None if it was not supplied."""
    return _extract_converted_arg(bool, tool_calls, func_name, param_name)


def _extract_converted_arg(convert: type, tool_calls: dict, func_name: str, param_name: str) -> Any:
    """Extract an argument via `_extract_arg` and apply `convert` to it, passing None through unchanged.

    Without the None check, a missing argument (e.g. when the LLM invoked a different tool) would become
    the string 'None' for str, False for bool, and raise TypeError for int and float.
    """
    value = _extract_arg(tool_calls, func_name, param_name)
    return None if value is None else convert(value)


def _extract_arg(tool_calls: dict, func_name: str, param_name: str) -> Any:
    """Look up argument `param_name` of the call to `func_name` in `tool_calls`.

    `tool_calls` is in standardized tool invocation format: {tool_name: {'args': {name: value, ...}}, ...}.
    Returns None if `func_name` was not called or the argument is absent.
    """
    if func_name not in tool_calls:
        return None
    return tool_calls[func_name]['args'].get(param_name)
@@ -1,7 +1,7 @@
1
1
  from pixeltable.utils.code import local_public_names
2
2
 
3
- from . import (anthropic, audio, fireworks, gemini, huggingface, image, json, llama_cpp, mistralai, ollama, openai,
4
- string, timestamp, together, video, vision, whisper)
3
+ from . import (anthropic, audio, fireworks, gemini, huggingface, image, json, llama_cpp, math, mistralai, ollama,
4
+ openai, string, timestamp, together, video, vision, whisper)
5
5
  from .globals import *
6
6
 
7
7
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
@@ -10,7 +10,8 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
10
10
  import tenacity
11
11
 
12
12
  import pixeltable as pxt
13
- from pixeltable import env
13
+ from pixeltable import env, exprs
14
+ from pixeltable.func import Tools
14
15
  from pixeltable.utils.code import local_public_names
15
16
 
16
17
  if TYPE_CHECKING:
@@ -47,7 +48,7 @@ def messages(
47
48
  system: Optional[str] = None,
48
49
  temperature: Optional[float] = None,
49
50
  tool_choice: Optional[list[dict]] = None,
50
- tools: Optional[dict] = None,
51
+ tools: Optional[list[dict]] = None,
51
52
  top_k: Optional[int] = None,
52
53
  top_p: Optional[float] = None,
53
54
  ) -> dict:
@@ -77,6 +78,21 @@ def messages(
77
78
  >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
78
79
  ... tbl['response'] = messages(msgs, model='claude-3-haiku-20240307')
79
80
  """
81
+ if tools is not None:
82
+ # Reformat `tools` into Anthropic format
83
+ tools = [
84
+ {
85
+ 'name': tool['name'],
86
+ 'description': tool['description'],
87
+ 'input_schema': {
88
+ 'type': 'object',
89
+ 'properties': tool['parameters']['properties'],
90
+ 'required': tool['required'],
91
+ },
92
+ }
93
+ for tool in tools
94
+ ]
95
+
80
96
  return _retry(_anthropic_client().messages.create)(
81
97
  messages=messages,
82
98
  model=model,
@@ -92,6 +108,24 @@ def messages(
92
108
  ).dict()
93
109
 
94
110
 
111
def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
    """Converts an Anthropic response dict to Pixeltable tool invocation format and calls `tools._invoke()`.

    Args:
        tools: The tools that may be invoked.
        response: An expression producing an Anthropic `messages` response dict.

    Returns:
        An `InlineDict` expression mapping each tool's name to its invocation.
    """
    pxt_tool_calls = _anthropic_response_to_pxt_tool_calls(response)
    return tools._invoke(pxt_tool_calls)
114
+
115
+
116
@pxt.udf
def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
    """Translate the `tool_use` content blocks of an Anthropic response into Pixeltable tool-call format.

    Produces {tool_name: {'args': tool_input}, ...}. Returns None when the response has no `tool_use`
    blocks. If the same tool is used more than once, only its last use is kept.
    """
    tool_uses = [block for block in response['content'] if block['type'] == 'tool_use']
    if not tool_uses:
        return None
    return {use['name']: {'args': use['input']} for use in tool_uses}
127
+
128
+
95
129
  _T = TypeVar('_T')
96
130
 
97
131
 
@@ -0,0 +1,67 @@
1
+ import builtins
2
+ import math
3
+ from typing import Optional
4
+
5
+ import sqlalchemy as sql
6
+
7
+ import pixeltable as pxt
8
+ from pixeltable.utils.code import local_public_names
9
+
10
+
11
@pxt.udf(is_method=True)
def abs(self: float) -> float:
    """Absolute value of `self`."""
    # `abs` is shadowed by this UDF inside the module, so call the builtin explicitly.
    magnitude = builtins.abs(self)
    return magnitude


@abs.to_sql
def _(self: sql.ColumnElement) -> sql.ColumnElement:
    # SQL translation: the database's ABS() function.
    return sql.func.abs(self)
19
+
20
+
21
@pxt.udf(is_method=True)
def ceil(self: float) -> float:
    """Smallest integral value not less than `self`, returned as a float."""
    # Non-finite inputs (inf, -inf, nan) are returned unchanged, matching SQL; math.ceil would raise on them.
    if not math.isfinite(self):
        return self
    return float(math.ceil(self))


@ceil.to_sql
def _(self: sql.ColumnElement) -> sql.ColumnElement:
    # SQL translation: CEILING().
    return sql.func.ceiling(self)
33
+
34
+
35
@pxt.udf(is_method=True)
def floor(self: float) -> float:
    """Largest integral value not greater than `self`, returned as a float."""
    # Non-finite inputs (inf, -inf, nan) are returned unchanged, matching SQL; math.floor would raise on them.
    if not math.isfinite(self):
        return self
    return float(math.floor(self))


@floor.to_sql
def _(self: sql.ColumnElement) -> sql.ColumnElement:
    # SQL translation: FLOOR().
    return sql.func.floor(self)
47
+
48
+
49
@pxt.udf(is_method=True)
def round(self: float, digits: Optional[int] = None) -> float:
    """Round `self` to `digits` decimal places (0 when omitted), returning a float."""
    # An explicit ndigits (0 instead of None) makes builtins.round return a float rather than an int;
    # this ensures the same behavior as SQL.
    # NOTE(review): Python rounds ties to even, while SQL ROUND on NUMERIC commonly rounds ties away from
    # zero -- confirm whether exact parity on ties matters.
    ndigits = 0 if not digits else digits
    return builtins.round(self, ndigits)


@round.to_sql
def _(self: sql.ColumnElement, digits: Optional[sql.ColumnElement] = None) -> sql.ColumnElement:
    if digits is not None:
        # Two-argument ROUND is applied to NUMERIC with an INTEGER digit count
        # (presumably because the database lacks a float/integer overload -- confirm).
        return sql.func.round(sql.cast(self, sql.Numeric), sql.cast(digits, sql.Integer))
    return sql.func.round(self)
61
+
62
+
63
# Public API of this module, computed after all UDFs above are defined.
__all__ = local_public_names(__name__)


def __dir__():
    """Module-level `dir()` hook: expose only the public names collected in `__all__`."""
    return __all__
@@ -7,6 +7,7 @@ the [Working with OpenAI](https://pixeltable.readme.io/docs/working-with-openai)
7
7
 
8
8
  import base64
9
9
  import io
10
+ import json
10
11
  import pathlib
11
12
  import uuid
12
13
  from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
@@ -16,8 +17,8 @@ import PIL.Image
16
17
  import tenacity
17
18
 
18
19
  import pixeltable as pxt
19
- from pixeltable import env
20
- from pixeltable.func import Batch
20
+ from pixeltable import env, exprs
21
+ from pixeltable.func import Batch, Tools
21
22
  from pixeltable.utils.code import local_public_names
22
23
 
23
24
  if TYPE_CHECKING:
@@ -225,6 +226,16 @@ def chat_completions(
225
226
  ]
226
227
  tbl['response'] = chat_completions(messages, model='gpt-4o-mini')
227
228
  """
229
+
230
+ if tools is not None:
231
+ tools = [
232
+ {
233
+ 'type': 'function',
234
+ 'function': tool
235
+ }
236
+ for tool in tools
237
+ ]
238
+
228
239
  result = _retry(_openai_client().chat.completions.create)(
229
240
  messages=messages,
230
241
  model=model,
@@ -453,6 +464,24 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
453
464
  return result.dict()
454
465
 
455
466
 
467
def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
    """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`.

    Args:
        tools: The tools that may be invoked.
        response: An expression producing an OpenAI chat completions response dict.

    Returns:
        An `InlineDict` expression mapping each tool's name to its invocation.
    """
    pxt_tool_calls = _openai_response_to_pxt_tool_calls(response)
    return tools._invoke(pxt_tool_calls)
470
+
471
+
472
@pxt.udf
def _openai_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
    """Translate the tool calls of the first choice of an OpenAI chat response into Pixeltable format.

    Produces {tool_name: {'args': parsed_arguments}, ...}, decoding each call's JSON-encoded `arguments`.
    Returns None when the message's `tool_calls` is None. If the same tool is called more than once,
    only its last call is kept.
    """
    message = response['choices'][0]['message']
    raw_calls = message['tool_calls']
    if raw_calls is None:
        return None
    pxt_calls = {}
    for call in raw_calls:
        function = call['function']
        tool_name = function['name']
        pxt_calls[tool_name] = {'args': json.loads(function['arguments'])}
    return pxt_calls
483
+
484
+
456
485
  _T = TypeVar('_T')
457
486
 
458
487