pixeltable 0.2.25__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (97)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/dir.py +6 -0
  5. pixeltable/catalog/globals.py +25 -0
  6. pixeltable/catalog/named_function.py +4 -0
  7. pixeltable/catalog/path_dict.py +37 -11
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +421 -231
  10. pixeltable/catalog/table_version.py +22 -8
  11. pixeltable/catalog/view.py +5 -7
  12. pixeltable/dataframe.py +439 -105
  13. pixeltable/env.py +19 -5
  14. pixeltable/exec/__init__.py +1 -1
  15. pixeltable/exec/exec_node.py +6 -7
  16. pixeltable/exec/expr_eval_node.py +1 -1
  17. pixeltable/exec/sql_node.py +92 -45
  18. pixeltable/exprs/__init__.py +1 -0
  19. pixeltable/exprs/arithmetic_expr.py +1 -1
  20. pixeltable/exprs/array_slice.py +1 -1
  21. pixeltable/exprs/column_property_ref.py +1 -1
  22. pixeltable/exprs/column_ref.py +29 -2
  23. pixeltable/exprs/comparison.py +1 -1
  24. pixeltable/exprs/compound_predicate.py +1 -1
  25. pixeltable/exprs/expr.py +12 -5
  26. pixeltable/exprs/expr_set.py +8 -0
  27. pixeltable/exprs/function_call.py +147 -39
  28. pixeltable/exprs/in_predicate.py +1 -1
  29. pixeltable/exprs/inline_expr.py +25 -5
  30. pixeltable/exprs/is_null.py +1 -1
  31. pixeltable/exprs/json_mapper.py +1 -1
  32. pixeltable/exprs/json_path.py +1 -1
  33. pixeltable/exprs/method_ref.py +1 -1
  34. pixeltable/exprs/row_builder.py +1 -1
  35. pixeltable/exprs/rowid_ref.py +1 -1
  36. pixeltable/exprs/similarity_expr.py +14 -7
  37. pixeltable/exprs/sql_element_cache.py +4 -0
  38. pixeltable/exprs/type_cast.py +2 -2
  39. pixeltable/exprs/variable.py +3 -0
  40. pixeltable/func/__init__.py +5 -4
  41. pixeltable/func/aggregate_function.py +151 -68
  42. pixeltable/func/callable_function.py +48 -16
  43. pixeltable/func/expr_template_function.py +64 -23
  44. pixeltable/func/function.py +195 -27
  45. pixeltable/func/function_registry.py +2 -1
  46. pixeltable/func/query_template_function.py +51 -9
  47. pixeltable/func/signature.py +64 -7
  48. pixeltable/func/tools.py +153 -0
  49. pixeltable/func/udf.py +57 -35
  50. pixeltable/functions/__init__.py +2 -2
  51. pixeltable/functions/anthropic.py +51 -4
  52. pixeltable/functions/gemini.py +85 -0
  53. pixeltable/functions/globals.py +54 -34
  54. pixeltable/functions/huggingface.py +10 -28
  55. pixeltable/functions/json.py +3 -8
  56. pixeltable/functions/math.py +67 -0
  57. pixeltable/functions/ollama.py +8 -8
  58. pixeltable/functions/openai.py +51 -4
  59. pixeltable/functions/timestamp.py +1 -1
  60. pixeltable/functions/video.py +3 -9
  61. pixeltable/functions/vision.py +1 -1
  62. pixeltable/globals.py +354 -80
  63. pixeltable/index/embedding_index.py +106 -34
  64. pixeltable/io/__init__.py +1 -1
  65. pixeltable/io/label_studio.py +1 -1
  66. pixeltable/io/parquet.py +39 -19
  67. pixeltable/iterators/document.py +12 -0
  68. pixeltable/metadata/__init__.py +1 -1
  69. pixeltable/metadata/converters/convert_16.py +2 -1
  70. pixeltable/metadata/converters/convert_17.py +2 -1
  71. pixeltable/metadata/converters/convert_22.py +17 -0
  72. pixeltable/metadata/converters/convert_23.py +35 -0
  73. pixeltable/metadata/converters/convert_24.py +56 -0
  74. pixeltable/metadata/converters/convert_25.py +19 -0
  75. pixeltable/metadata/converters/util.py +4 -2
  76. pixeltable/metadata/notes.py +4 -0
  77. pixeltable/metadata/schema.py +1 -0
  78. pixeltable/plan.py +128 -50
  79. pixeltable/store.py +1 -1
  80. pixeltable/type_system.py +196 -54
  81. pixeltable/utils/arrow.py +8 -3
  82. pixeltable/utils/description_helper.py +89 -0
  83. pixeltable/utils/documents.py +14 -0
  84. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/METADATA +30 -20
  85. pixeltable-0.3.0.dist-info/RECORD +155 -0
  86. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/WHEEL +1 -1
  87. pixeltable-0.3.0.dist-info/entry_points.txt +3 -0
  88. pixeltable/tool/create_test_db_dump.py +0 -311
  89. pixeltable/tool/create_test_video.py +0 -81
  90. pixeltable/tool/doc_plugins/griffe.py +0 -50
  91. pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
  92. pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
  93. pixeltable/tool/embed_udf.py +0 -9
  94. pixeltable/tool/mypy_plugin.py +0 -55
  95. pixeltable-0.2.25.dist-info/RECORD +0 -154
  96. pixeltable-0.2.25.dist-info/entry_points.txt +0 -3
  97. {pixeltable-0.2.25.dist-info → pixeltable-0.3.0.dist-info}/LICENSE +0 -0
@@ -0,0 +1,153 @@
1
+ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar, Union
2
+
3
+ import pydantic
4
+
5
+ import pixeltable.exceptions as excs
6
+
7
+ from .function import Function
8
+ from .signature import Parameter
9
+ from .udf import udf
10
+
11
+ if TYPE_CHECKING:
12
+ from pixeltable import exprs
13
+
14
+
15
+ # The Tool and Tools classes are containers that hold Pixeltable UDFs and related metadata, so that they can be
16
+ # realized as LLM tools. They are implemented as Pydantic models in order to provide a canonical way of converting
17
+ # to JSON, via the Pydantic `model_serializer` interface. In this way, they can be passed directly as UDF
18
+ # parameters as described in the `pixeltable.tools` and `pixeltable.tool` docstrings.
19
+ #
20
+ # (The dataclass dict serializer is insufficiently flexible for this purpose: `Tool` contains a member of type
21
+ # `Function`, which is not natively JSON-serializable; Pydantic provides a way of customizing its default
22
+ # serialization behavior, whereas dataclasses do not.)
23
+
24
+ class Tool(pydantic.BaseModel):
25
+ # Allow arbitrary types so that we can include a Pixeltable function in the schema.
26
+ # We will implement a model_serializer to ensure the Tool model can be serialized.
27
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
28
+
29
+ fn: Function
30
+ name: Optional[str] = None
31
+ description: Optional[str] = None
32
+
33
+ @property
34
+ def parameters(self) -> dict[str, Parameter]:
35
+ return self.fn.signature.parameters
36
+
37
+ @pydantic.model_serializer
38
+ def ser_model(self) -> dict[str, Any]:
39
+ return {
40
+ 'name': self.name or self.fn.name,
41
+ 'description': self.description or self.fn._docstring(),
42
+ 'parameters': {
43
+ 'type': 'object',
44
+ 'properties': {
45
+ param.name: param.col_type._to_json_schema()
46
+ for param in self.parameters.values()
47
+ }
48
+ },
49
+ 'required': [
50
+ param.name for param in self.parameters.values() if not param.col_type.nullable
51
+ ],
52
+ 'additionalProperties': False, # TODO Handle kwargs?
53
+ }
54
+
55
+ # `tool_calls` must be in standardized tool invocation format:
56
+ # {tool_name: {'args': {name1: value1, name2: value2, ...}}, ...}
57
+ def invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.FunctionCall':
58
+ kwargs = {
59
+ param.name: self.__extract_tool_arg(param, tool_calls)
60
+ for param in self.parameters.values()
61
+ }
62
+ return self.fn(**kwargs)
63
+
64
+ def __extract_tool_arg(self, param: Parameter, tool_calls: 'exprs.Expr') -> 'exprs.Expr':
65
+ func_name = self.name or self.fn.name
66
+ if param.col_type.is_string_type():
67
+ return _extract_str_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
68
+ if param.col_type.is_int_type():
69
+ return _extract_int_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
70
+ if param.col_type.is_float_type():
71
+ return _extract_float_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
72
+ if param.col_type.is_bool_type():
73
+ return _extract_bool_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
74
+ assert False
75
+
76
+
77
+ class ToolChoice(pydantic.BaseModel):
78
+ auto: bool
79
+ required: bool
80
+ tool: Optional[str]
81
+ parallel_tool_calls: bool
82
+
83
+
84
+ class Tools(pydantic.BaseModel):
85
+ tools: list[Tool]
86
+
87
+ @pydantic.model_serializer
88
+ def ser_model(self) -> list[dict[str, Any]]:
89
+ return [tool.ser_model() for tool in self.tools]
90
+
91
+ # `tool_calls` must be in standardized tool invocation format:
92
+ # {tool_name: {'args': {name1: value1, name2: value2, ...}}, ...}
93
+ def _invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.InlineDict':
94
+ from pixeltable import exprs
95
+
96
+ return exprs.InlineDict({
97
+ tool.name or tool.fn.name: tool.invoke(tool_calls)
98
+ for tool in self.tools
99
+ })
100
+
101
+ def choice(
102
+ self,
103
+ auto: bool = False,
104
+ required: bool = False,
105
+ tool: Union[str, Function, None] = None,
106
+ parallel_tool_calls: bool = True,
107
+ ) -> ToolChoice:
108
+ if sum([auto, required, tool is not None]) != 1:
109
+ raise excs.Error('Exactly one of `auto`, `required`, or `tool` must be specified.')
110
+ tool_name: Optional[str] = None
111
+ if tool is not None:
112
+ try:
113
+ tool_obj = next(
114
+ t for t in self.tools
115
+ if (isinstance(tool, Function) and t.fn == tool)
116
+ or (isinstance(tool, str) and (t.name or t.fn.name) == tool)
117
+ )
118
+ tool_name = tool_obj.name or tool_obj.fn.name
119
+ except StopIteration:
120
+ raise excs.Error(f'That tool is not in the specified list of tools: {tool}')
121
+ return ToolChoice(auto=auto, required=required, tool=tool_name, parallel_tool_calls=parallel_tool_calls)
122
+
123
+
124
+ @udf
125
+ def _extract_str_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[str]:
126
+ return _extract_arg(str, tool_calls, func_name, param_name)
127
+
128
+
129
+ @udf
130
+ def _extract_int_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[int]:
131
+ return _extract_arg(int, tool_calls, func_name, param_name)
132
+
133
+
134
+ @udf
135
+ def _extract_float_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[float]:
136
+ return _extract_arg(float, tool_calls, func_name, param_name)
137
+
138
+
139
+ @udf
140
+ def _extract_bool_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[bool]:
141
+ return _extract_arg(bool, tool_calls, func_name, param_name)
142
+
143
+
144
+ T = TypeVar('T')
145
+
146
+
147
+ def _extract_arg(eval_fn: Callable[[Any], T], tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[T]:
148
+ if func_name in tool_calls:
149
+ arguments = tool_calls[func_name]['args']
150
+ if param_name in arguments:
151
+ return eval_fn(arguments[param_name])
152
+ return None
153
+ return None
pixeltable/func/udf.py CHANGED
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Callable, Optional, overload
3
+ from typing import Any, Callable, Optional, Sequence, overload
4
4
 
5
5
  import pixeltable.exceptions as excs
6
6
  import pixeltable.type_system as ts
7
7
 
8
8
  from .callable_function import CallableFunction
9
- from .expr_template_function import ExprTemplateFunction
9
+ from .expr_template_function import ExprTemplateFunction, ExprTemplate
10
10
  from .function import Function
11
11
  from .function_registry import FunctionRegistry
12
12
  from .globals import validate_symbol_path
@@ -21,13 +21,14 @@ def udf(decorated_fn: Callable) -> Function: ...
21
21
  # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
22
22
  @overload
23
23
  def udf(
24
- *,
25
- batch_size: Optional[int] = None,
26
- substitute_fn: Optional[Callable] = None,
27
- is_method: bool = False,
28
- is_property: bool = False,
29
- _force_stored: bool = False
30
- ) -> Callable[[Callable], Function]: ...
24
+ *,
25
+ batch_size: Optional[int] = None,
26
+ substitute_fn: Optional[Callable] = None,
27
+ is_method: bool = False,
28
+ is_property: bool = False,
29
+ type_substitutions: Optional[Sequence[dict]] = None,
30
+ _force_stored: bool = False
31
+ ) -> Callable[[Callable], CallableFunction]: ...
31
32
 
32
33
 
33
34
  def udf(*args, **kwargs):
@@ -52,6 +53,7 @@ def udf(*args, **kwargs):
52
53
  substitute_fn = kwargs.pop('substitute_fn', None)
53
54
  is_method = kwargs.pop('is_method', None)
54
55
  is_property = kwargs.pop('is_property', None)
56
+ type_substitutions = kwargs.pop('type_substitutions', None)
55
57
  force_stored = kwargs.pop('_force_stored', False)
56
58
  if len(kwargs) > 0:
57
59
  raise excs.Error(f'Invalid @udf decorator kwargs: {", ".join(kwargs.keys())}')
@@ -65,6 +67,7 @@ def udf(*args, **kwargs):
65
67
  substitute_fn=substitute_fn,
66
68
  is_method=is_method,
67
69
  is_property=is_property,
70
+ type_substitutions=type_substitutions,
68
71
  force_stored=force_stored
69
72
  )
70
73
 
@@ -79,9 +82,10 @@ def make_function(
79
82
  substitute_fn: Optional[Callable] = None,
80
83
  is_method: bool = False,
81
84
  is_property: bool = False,
85
+ type_substitutions: Optional[Sequence[dict]] = None,
82
86
  function_name: Optional[str] = None,
83
87
  force_stored: bool = False
84
- ) -> Function:
88
+ ) -> CallableFunction:
85
89
  """
86
90
  Constructs a `CallableFunction` from the specified parameters.
87
91
  If `substitute_fn` is specified, then `decorated_fn`
@@ -104,25 +108,43 @@ def make_function(
104
108
  # Display name to use for error messages
105
109
  errmsg_name = function_name if function_path is None else function_path
106
110
 
107
- sig = Signature.create(decorated_fn, param_types, return_type)
108
-
109
- # batched functions must have a batched return type
110
- # TODO: remove 'Python' from the error messages when we have full inference with Annotated types
111
- if batch_size is not None and not sig.is_batched:
112
- raise excs.Error(f'{errmsg_name}(): batch_size is specified; Python return type must be a `Batch`')
113
- if batch_size is not None and len(sig.batched_parameters) == 0:
114
- raise excs.Error(f'{errmsg_name}(): batch_size is specified; at least one Python parameter must be `Batch`')
115
- if batch_size is None and len(sig.batched_parameters) > 0:
116
- raise excs.Error(f'{errmsg_name}(): batched parameters in udf, but no `batch_size` given')
117
-
118
- if is_method and is_property:
119
- raise excs.Error(f'Cannot specify both `is_method` and `is_property` (in function `{function_name}`)')
120
- if is_property and len(sig.parameters) != 1:
121
- raise excs.Error(
122
- f"`is_property=True` expects a UDF with exactly 1 parameter, but `{function_name}` has {len(sig.parameters)}"
123
- )
124
- if (is_method or is_property) and function_path is None:
125
- raise excs.Error('Stored functions cannot be declared using `is_method` or `is_property`')
111
+ signatures: list[Signature]
112
+ if type_substitutions is None:
113
+ sig = Signature.create(decorated_fn, param_types, return_type)
114
+
115
+ # batched functions must have a batched return type
116
+ # TODO: remove 'Python' from the error messages when we have full inference with Annotated types
117
+ if batch_size is not None and not sig.is_batched:
118
+ raise excs.Error(f'{errmsg_name}(): batch_size is specified; Python return type must be a `Batch`')
119
+ if batch_size is not None and len(sig.batched_parameters) == 0:
120
+ raise excs.Error(f'{errmsg_name}(): batch_size is specified; at least one Python parameter must be `Batch`')
121
+ if batch_size is None and len(sig.batched_parameters) > 0:
122
+ raise excs.Error(f'{errmsg_name}(): batched parameters in udf, but no `batch_size` given')
123
+
124
+ if is_method and is_property:
125
+ raise excs.Error(f'Cannot specify both `is_method` and `is_property` (in function `{function_name}`)')
126
+ if is_property and len(sig.parameters) != 1:
127
+ raise excs.Error(
128
+ f"`is_property=True` expects a UDF with exactly 1 parameter, but `{function_name}` has {len(sig.parameters)}"
129
+ )
130
+ if (is_method or is_property) and function_path is None:
131
+ raise excs.Error('Stored functions cannot be declared using `is_method` or `is_property`')
132
+
133
+ signatures = [sig]
134
+ else:
135
+ if function_path is None:
136
+ raise excs.Error(
137
+ f'{errmsg_name}(): type substitutions can only be used with module UDFs (not locally defined UDFs)'
138
+ )
139
+ if batch_size is not None:
140
+ raise excs.Error(f'{errmsg_name}(): type substitutions cannot be used with batched functions')
141
+ if is_method is not None or is_property is not None:
142
+ # TODO: Support this for `is_method`?
143
+ raise excs.Error(f'{errmsg_name}(): type substitutions cannot be used with `is_method` or `is_property`')
144
+ signatures = [
145
+ Signature.create(decorated_fn, param_types, return_type, type_substitutions=subst)
146
+ for subst in type_substitutions
147
+ ]
126
148
 
127
149
  if substitute_fn is None:
128
150
  py_fn = decorated_fn
@@ -132,8 +154,8 @@ def make_function(
132
154
  py_fn = substitute_fn
133
155
 
134
156
  result = CallableFunction(
135
- signature=sig,
136
- py_fn=py_fn,
157
+ signatures=signatures,
158
+ py_fns=[py_fn] * len(signatures), # All signatures share the same Python function
137
159
  self_path=function_path,
138
160
  self_name=function_name,
139
161
  batch_size=batch_size,
@@ -171,12 +193,12 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
171
193
  import pixeltable.exprs as exprs
172
194
  var_exprs = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
173
195
  # call the function with the parameter expressions to construct an Expr with parameters
174
- template = py_fn(*var_exprs)
175
- assert isinstance(template, exprs.Expr)
176
- sig.return_type = template.col_type
196
+ expr = py_fn(*var_exprs)
197
+ assert isinstance(expr, exprs.Expr)
198
+ sig.return_type = expr.col_type
177
199
  if function_path is not None:
178
200
  validate_symbol_path(function_path)
179
- return ExprTemplateFunction(template, sig, self_path=function_path, name=py_fn.__name__)
201
+ return ExprTemplateFunction([ExprTemplate(expr, sig)], self_path=function_path, name=py_fn.__name__)
180
202
 
181
203
  if len(args) == 1:
182
204
  assert len(kwargs) == 0 and callable(args[0])
@@ -1,7 +1,7 @@
1
1
  from pixeltable.utils.code import local_public_names
2
2
 
3
- from . import (anthropic, audio, fireworks, huggingface, image, json, llama_cpp, mistralai, ollama, openai, string,
4
- timestamp, together, video, vision, whisper)
3
+ from . import (anthropic, audio, fireworks, gemini, huggingface, image, json, llama_cpp, math, mistralai, ollama,
4
+ openai, string, timestamp, together, video, vision, whisper)
5
5
  from .globals import *
6
6
 
7
7
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
@@ -10,7 +10,9 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
10
10
  import tenacity
11
11
 
12
12
  import pixeltable as pxt
13
- from pixeltable import env
13
+ import pixeltable.exceptions as excs
14
+ from pixeltable import env, exprs
15
+ from pixeltable.func import Tools
14
16
  from pixeltable.utils.code import local_public_names
15
17
 
16
18
  if TYPE_CHECKING:
@@ -46,8 +48,8 @@ def messages(
46
48
  stop_sequences: Optional[list[str]] = None,
47
49
  system: Optional[str] = None,
48
50
  temperature: Optional[float] = None,
49
- tool_choice: Optional[list[dict]] = None,
50
- tools: Optional[dict] = None,
51
+ tool_choice: Optional[dict] = None,
52
+ tools: Optional[list[dict]] = None,
51
53
  top_k: Optional[int] = None,
52
54
  top_p: Optional[float] = None,
53
55
  ) -> dict:
@@ -77,6 +79,33 @@ def messages(
77
79
  >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
78
80
  ... tbl['response'] = messages(msgs, model='claude-3-haiku-20240307')
79
81
  """
82
+ if tools is not None:
83
+ # Reformat `tools` into Anthropic format
84
+ tools = [
85
+ {
86
+ 'name': tool['name'],
87
+ 'description': tool['description'],
88
+ 'input_schema': {
89
+ 'type': 'object',
90
+ 'properties': tool['parameters']['properties'],
91
+ 'required': tool['required'],
92
+ },
93
+ }
94
+ for tool in tools
95
+ ]
96
+
97
+ tool_choice_: Optional[dict] = None
98
+ if tool_choice is not None:
99
+ if tool_choice['auto']:
100
+ tool_choice_ = {'type': 'auto'}
101
+ elif tool_choice['required']:
102
+ tool_choice_ = {'type': 'any'}
103
+ else:
104
+ assert tool_choice['tool'] is not None
105
+ tool_choice_ = {'type': 'tool', 'name': tool_choice['tool']}
106
+ if not tool_choice['parallel_tool_calls']:
107
+ tool_choice_['disable_parallel_tool_use'] = True
108
+
80
109
  return _retry(_anthropic_client().messages.create)(
81
110
  messages=messages,
82
111
  model=model,
@@ -85,13 +114,31 @@ def messages(
85
114
  stop_sequences=_opt(stop_sequences),
86
115
  system=_opt(system),
87
116
  temperature=_opt(temperature),
88
- tool_choice=_opt(tool_choice),
117
+ tool_choice=_opt(tool_choice_),
89
118
  tools=_opt(tools),
90
119
  top_k=_opt(top_k),
91
120
  top_p=_opt(top_p),
92
121
  ).dict()
93
122
 
94
123
 
124
+ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
125
+ """Converts an Anthropic response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
126
+ return tools._invoke(_anthropic_response_to_pxt_tool_calls(response))
127
+
128
+
129
+ @pxt.udf
130
+ def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
131
+ anthropic_tool_calls = [r for r in response['content'] if r['type'] == 'tool_use']
132
+ if len(anthropic_tool_calls) > 0:
133
+ return {
134
+ tool_call['name']: {
135
+ 'args': tool_call['input']
136
+ }
137
+ for tool_call in anthropic_tool_calls
138
+ }
139
+ return None
140
+
141
+
95
142
  _T = TypeVar('_T')
96
143
 
97
144
 
@@ -0,0 +1,85 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Google Gemini API. In order to use them, you must
4
+ first `pip install google-generativeai` and configure your Gemini credentials, as described in
5
+ the [Working with Gemini](https://pixeltable.readme.io/docs/working-with-gemini) tutorial.
6
+ """
7
+
8
+ from typing import Optional
9
+
10
+ import pixeltable as pxt
11
+ from pixeltable import env
12
+
13
+
14
+ @env.register_client('gemini')
15
+ def _(api_key: str) -> None:
16
+ import google.generativeai as genai # type: ignore[import-untyped]
17
+ genai.configure(api_key=api_key)
18
+
19
+
20
+ def _ensure_loaded() -> None:
21
+ env.Env.get().get_client('gemini')
22
+
23
+
24
+ @pxt.udf
25
+ def generate_content(
26
+ contents: str,
27
+ *,
28
+ model_name: str,
29
+ candidate_count: Optional[int] = None,
30
+ stop_sequences: Optional[list[str]] = None,
31
+ max_output_tokens: Optional[int] = None,
32
+ temperature: Optional[float] = None,
33
+ top_p: Optional[float] = None,
34
+ top_k: Optional[int] = None,
35
+ response_mime_type: Optional[str] = None,
36
+ response_schema: Optional[dict] = None,
37
+ presence_penalty: Optional[float] = None,
38
+ frequency_penalty: Optional[float] = None,
39
+ response_logprobs: Optional[bool] = None,
40
+ logprobs: Optional[int] = None,
41
+ ) -> dict:
42
+ """
43
+ Generate content from the specified model. For additional details, see:
44
+ <https://ai.google.dev/gemini-api/docs>
45
+
46
+ __Requirements:__
47
+
48
+ - `pip install google-generativeai`
49
+
50
+ Args:
51
+ contents: The input content to generate from.
52
+ model_name: The name of the model to use.
53
+
54
+ For details on the other parameters, see: <https://ai.google.dev/gemini-api/docs>
55
+
56
+ Returns:
57
+ A dictionary containing the response and other metadata.
58
+
59
+ Examples:
60
+ Add a computed column that applies the model `gemini-1.5-flash`
61
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
62
+
63
+ >>> tbl['response'] = generate_content(tbl.prompt, model_name='gemini-1.5-flash')
64
+ """
65
+ env.Env.get().require_package('google.generativeai')
66
+ _ensure_loaded()
67
+ import google.generativeai as genai
68
+
69
+ model = genai.GenerativeModel(model_name=model_name)
70
+ gc = genai.GenerationConfig(
71
+ candidate_count=candidate_count,
72
+ stop_sequences=stop_sequences,
73
+ max_output_tokens=max_output_tokens,
74
+ temperature=temperature,
75
+ top_p=top_p,
76
+ top_k=top_k,
77
+ response_mime_type=response_mime_type,
78
+ response_schema=response_schema,
79
+ presence_penalty=presence_penalty,
80
+ frequency_penalty=frequency_penalty,
81
+ response_logprobs=response_logprobs,
82
+ logprobs=logprobs,
83
+ )
84
+ response = model.generate_content(contents, generation_config=gc)
85
+ return response.to_dict()
@@ -1,6 +1,7 @@
1
1
  import builtins
2
2
  from typing import _GenericAlias # type: ignore[attr-defined]
3
3
  from typing import Optional, Union
4
+ import typing
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
@@ -16,23 +17,24 @@ def cast(expr: exprs.Expr, target_type: Union[ts.ColumnType, type, _GenericAlias
16
17
  return expr
17
18
 
18
19
 
19
- @func.uda(
20
- update_types=[ts.IntType(nullable=True)], value_type=ts.IntType(nullable=False),
21
- allows_window=True, requires_order_by=False)
22
- class sum(func.Aggregator):
20
+ T = typing.TypeVar('T')
21
+
22
+
23
+ @func.uda(allows_window=True, type_substitutions=({T: Optional[int]}, {T: Optional[float]})) # type: ignore[misc]
24
+ class sum(func.Aggregator, typing.Generic[T]):
23
25
  """Sums the selected integers or floats."""
24
26
  def __init__(self):
25
- self.sum: Optional[int] = None
27
+ self.sum: T = None
26
28
 
27
- def update(self, val: Optional[int]) -> None:
29
+ def update(self, val: T) -> None:
28
30
  if val is None:
29
31
  return
30
32
  if self.sum is None:
31
33
  self.sum = val
32
34
  else:
33
- self.sum += val
35
+ self.sum += val # type: ignore[operator]
34
36
 
35
- def value(self) -> Union[int, float]:
37
+ def value(self) -> T:
36
38
  return self.sum
37
39
 
38
40
 
@@ -43,12 +45,22 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
43
45
  return sql.sql.func.sum(val)
44
46
 
45
47
 
46
- @func.uda(update_types=[ts.IntType(nullable=True)], value_type=ts.IntType(), allows_window=True, requires_order_by=False)
47
- class count(func.Aggregator):
48
+ @func.uda(
49
+ allows_window=True,
50
+ # Allow counting non-null values of any type
51
+ # TODO: I couldn't include "Array" because we don't have a way to represent a generic array (of arbitrary dimension).
52
+ # TODO: should we have an "Any" type that can be used here?
53
+ type_substitutions=tuple(
54
+ {T: Optional[t]} # type: ignore[misc]
55
+ for t in (ts.String, ts.Int, ts.Float, ts.Bool, ts.Timestamp,
56
+ ts.Json, ts.Image, ts.Video, ts.Audio, ts.Document)
57
+ ),
58
+ )
59
+ class count(func.Aggregator, typing.Generic[T]):
48
60
  def __init__(self):
49
61
  self.count = 0
50
62
 
51
- def update(self, val: Optional[int]) -> None:
63
+ def update(self, val: T) -> None:
52
64
  if val is not None:
53
65
  self.count += 1
54
66
 
@@ -62,74 +74,82 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
62
74
 
63
75
 
64
76
  @func.uda(
65
- update_types=[ts.IntType(nullable=True)], value_type=ts.IntType(nullable=True), allows_window=True,
66
- requires_order_by=False)
67
- class min(func.Aggregator):
77
+ allows_window=True,
78
+ type_substitutions=tuple({T: Optional[t]} for t in (str, int, float, bool, ts.Timestamp)) # type: ignore[misc]
79
+ )
80
+ class min(func.Aggregator, typing.Generic[T]):
68
81
  def __init__(self):
69
- self.val: Optional[int] = None
82
+ self.val: T = None
70
83
 
71
- def update(self, val: Optional[int]) -> None:
84
+ def update(self, val: T) -> None:
72
85
  if val is None:
73
86
  return
74
87
  if self.val is None:
75
88
  self.val = val
76
89
  else:
77
- self.val = builtins.min(self.val, val)
90
+ self.val = builtins.min(self.val, val) # type: ignore[call-overload]
78
91
 
79
- def value(self) -> Optional[int]:
92
+ def value(self) -> T:
80
93
  return self.val
81
94
 
82
95
 
83
96
  @min.to_sql
84
97
  def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
98
+ if val.type.python_type == bool:
99
+ # TODO: min/max aggregation of booleans is not supported in Postgres (but it is in Python).
100
+ # Right now we simply force the computation to be done in Python; we might consider implementing an alternate
101
+ # way of doing it in SQL. (min/max of booleans is simply logical and/or, respectively.)
102
+ return None
85
103
  return sql.sql.func.min(val)
86
104
 
87
105
 
88
106
  @func.uda(
89
- update_types=[ts.IntType(nullable=True)], value_type=ts.IntType(nullable=True), allows_window=True,
90
- requires_order_by=False)
91
- class max(func.Aggregator):
107
+ allows_window=True,
108
+ type_substitutions=tuple({T: Optional[t]} for t in (str, int, float, bool, ts.Timestamp)) # type: ignore[misc]
109
+ )
110
+ class max(func.Aggregator, typing.Generic[T]):
92
111
  def __init__(self):
93
- self.val: Optional[int] = None
112
+ self.val: T = None
94
113
 
95
- def update(self, val: Optional[int]) -> None:
114
+ def update(self, val: T) -> None:
96
115
  if val is None:
97
116
  return
98
117
  if self.val is None:
99
118
  self.val = val
100
119
  else:
101
- self.val = builtins.max(self.val, val)
120
+ self.val = builtins.max(self.val, val) # type: ignore[call-overload]
102
121
 
103
- def value(self) -> Optional[int]:
122
+ def value(self) -> T:
104
123
  return self.val
105
124
 
106
125
 
107
126
  @max.to_sql
108
127
  def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
128
+ if val.type.python_type == bool:
129
+ # TODO: see comment in @min.to_sql.
130
+ return None
109
131
  return sql.sql.func.max(val)
110
132
 
111
133
 
112
- @func.uda(
113
- update_types=[ts.IntType(nullable=True)], value_type=ts.FloatType(nullable=True), allows_window=False,
114
- requires_order_by=False)
115
- class mean(func.Aggregator):
134
+ @func.uda(type_substitutions=({T: Optional[int]}, {T: Optional[float]})) # type: ignore[misc]
135
+ class mean(func.Aggregator, typing.Generic[T]):
116
136
  def __init__(self):
117
- self.sum: Optional[int] = None
137
+ self.sum: T = None
118
138
  self.count = 0
119
139
 
120
- def update(self, val: Optional[int]) -> None:
140
+ def update(self, val: T) -> None:
121
141
  if val is None:
122
142
  return
123
143
  if self.sum is None:
124
144
  self.sum = val
125
145
  else:
126
- self.sum += val
146
+ self.sum += val # type: ignore[operator]
127
147
  self.count += 1
128
148
 
129
- def value(self) -> Optional[float]:
149
+ def value(self) -> Optional[float]: # Always a float
130
150
  if self.count == 0:
131
151
  return None
132
- return self.sum / self.count
152
+ return self.sum / self.count # type: ignore[operator]
133
153
 
134
154
 
135
155
  @mean.to_sql