pixeltable 0.4.0rc2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (59)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +333 -99
  5. pixeltable/catalog/column.py +28 -26
  6. pixeltable/catalog/globals.py +12 -0
  7. pixeltable/catalog/insertable_table.py +8 -8
  8. pixeltable/catalog/schema_object.py +6 -0
  9. pixeltable/catalog/table.py +111 -116
  10. pixeltable/catalog/table_version.py +36 -50
  11. pixeltable/catalog/table_version_handle.py +4 -1
  12. pixeltable/catalog/table_version_path.py +28 -4
  13. pixeltable/catalog/view.py +10 -18
  14. pixeltable/config.py +4 -0
  15. pixeltable/dataframe.py +10 -9
  16. pixeltable/env.py +5 -11
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/exec_node.py +2 -0
  19. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  20. pixeltable/exec/sql_node.py +47 -30
  21. pixeltable/exprs/column_property_ref.py +2 -1
  22. pixeltable/exprs/column_ref.py +7 -6
  23. pixeltable/exprs/expr.py +4 -4
  24. pixeltable/func/__init__.py +1 -0
  25. pixeltable/func/mcp.py +74 -0
  26. pixeltable/func/query_template_function.py +4 -2
  27. pixeltable/func/tools.py +12 -2
  28. pixeltable/func/udf.py +2 -2
  29. pixeltable/functions/__init__.py +1 -0
  30. pixeltable/functions/anthropic.py +19 -45
  31. pixeltable/functions/deepseek.py +19 -38
  32. pixeltable/functions/fireworks.py +9 -18
  33. pixeltable/functions/gemini.py +2 -2
  34. pixeltable/functions/groq.py +108 -0
  35. pixeltable/functions/huggingface.py +8 -6
  36. pixeltable/functions/llama_cpp.py +6 -6
  37. pixeltable/functions/mistralai.py +16 -53
  38. pixeltable/functions/ollama.py +1 -1
  39. pixeltable/functions/openai.py +82 -170
  40. pixeltable/functions/replicate.py +2 -2
  41. pixeltable/functions/together.py +22 -80
  42. pixeltable/functions/util.py +6 -1
  43. pixeltable/globals.py +0 -2
  44. pixeltable/io/external_store.py +2 -2
  45. pixeltable/io/label_studio.py +4 -4
  46. pixeltable/io/table_data_conduit.py +1 -1
  47. pixeltable/metadata/__init__.py +1 -1
  48. pixeltable/metadata/converters/convert_37.py +15 -0
  49. pixeltable/metadata/notes.py +1 -0
  50. pixeltable/metadata/schema.py +5 -0
  51. pixeltable/plan.py +37 -121
  52. pixeltable/share/packager.py +2 -2
  53. pixeltable/type_system.py +30 -0
  54. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/METADATA +1 -1
  55. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/RECORD +58 -56
  56. pixeltable/utils/sample.py +0 -25
  57. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/LICENSE +0 -0
  58. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/WHEEL +0 -0
  59. {pixeltable-0.4.0rc2.dist-info → pixeltable-0.4.1.dist-info}/entry_points.txt +0 -0
@@ -64,8 +64,9 @@ class ColumnPropertyRef(Expr):
64
64
  # perform runtime checks and update state
65
65
  tv = self._col_ref.tbl_version.get()
66
66
  assert tv.is_validated
67
+ # we can assume at this point during query execution that the column exists
68
+ assert self._col_ref.col_id in tv.cols_by_id
67
69
  col = tv.cols_by_id[self._col_ref.col_id]
68
- # TODO: check for column being dropped
69
70
 
70
71
  # the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
71
72
  if (
@@ -239,7 +239,6 @@ class ColumnRef(Expr):
239
239
  return helper
240
240
 
241
241
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
242
- # return None if self.perform_validation else self.col.sa_col
243
242
  if self.perform_validation:
244
243
  return None
245
244
  # we need to reestablish that we have the correct Column instance, there could have been a metadata
@@ -248,13 +247,10 @@ class ColumnRef(Expr):
248
247
  # perform runtime checks and update state
249
248
  tv = self.tbl_version.get()
250
249
  assert tv.is_validated
250
+ # we can assume at this point during query execution that the column exists
251
+ assert self.col_id in tv.cols_by_id
251
252
  self.col = tv.cols_by_id[self.col_id]
252
253
  assert self.col.tbl is tv
253
- # TODO: check for column being dropped
254
- # print(
255
- # f'ColumnRef.sql_expr: tbl={tv.id}:{tv.effective_version} sa_tbl={id(self.col.tbl.store_tbl.sa_tbl):x} '
256
- # f'tv={id(tv):x}'
257
- # )
258
254
  return self.col.sa_col
259
255
 
260
256
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -315,6 +311,11 @@ class ColumnRef(Expr):
315
311
  'perform_validation': self.perform_validation,
316
312
  }
317
313
 
314
@classmethod
def get_column_id(cls, d: dict) -> catalog.QColumnId:
    """Return the qualified id (table id + column id) of the column that a serialized ColumnRef refers to.

    Only the 'tbl_id' and 'col_id' entries of `d` are read; no Column object is looked up.
    """
    return catalog.QColumnId(UUID(d['tbl_id']), d['col_id'])
318
+
318
319
  @classmethod
319
320
  def get_column(cls, d: dict) -> catalog.Column:
320
321
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
pixeltable/exprs/expr.py CHANGED
@@ -394,17 +394,17 @@ class Expr(abc.ABC):
394
394
  return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
395
395
 
396
396
  @classmethod
397
- def get_refd_columns(cls, expr_dict: dict[str, Any]) -> list[catalog.Column]:
397
+ def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
398
398
  """Return the set of qualified ids (catalog.QColumnId) of all columns referenced by expr_dict, including those referenced by nested components."""
399
- result: list[catalog.Column] = []
399
+ result: set[catalog.QColumnId] = set()
400
400
  assert '_classname' in expr_dict
401
401
  from .column_ref import ColumnRef
402
402
 
403
403
  if expr_dict['_classname'] == 'ColumnRef':
404
- result.append(ColumnRef.get_column(expr_dict))
404
+ result.add(ColumnRef.get_column_id(expr_dict))
405
405
  if 'components' in expr_dict:
406
406
  for component_dict in expr_dict['components']:
407
- result.extend(cls.get_refd_columns(component_dict))
407
+ result.update(cls.get_refd_column_ids(component_dict))
408
408
  return result
409
409
 
410
410
  def as_literal(self) -> Optional[Expr]:
@@ -5,6 +5,7 @@ from .callable_function import CallableFunction
5
5
  from .expr_template_function import ExprTemplateFunction
6
6
  from .function import Function, InvalidFunction
7
7
  from .function_registry import FunctionRegistry
8
+ from .mcp import mcp_udfs
8
9
  from .query_template_function import QueryTemplateFunction, query, retrieval_udf
9
10
  from .signature import Batch, Parameter, Signature
10
11
  from .tools import Tool, ToolChoice, Tools
pixeltable/func/mcp.py ADDED
@@ -0,0 +1,74 @@
1
+ import asyncio
2
+ import inspect
3
+ from typing import TYPE_CHECKING, Any, Optional
4
+
5
+ import pixeltable as pxt
6
+ from pixeltable import exceptions as excs, type_system as ts
7
+ from pixeltable.func.signature import Parameter
8
+
9
+ if TYPE_CHECKING:
10
+ import mcp
11
+
12
+
13
def mcp_udfs(url: str) -> list['pxt.func.Function']:
    """Fetch the tools exposed by the MCP server at `url` and wrap each one as a Pixeltable UDF.

    Blocking wrapper around `mcp_udfs_async()`, driven to completion with `asyncio.run()`.
    NOTE(review): `asyncio.run()` raises RuntimeError when the calling thread already runs an event loop
    (e.g. some notebook environments); such callers should await `mcp_udfs_async()` directly.
    """
    pending = mcp_udfs_async(url)
    return asyncio.run(pending)
15
+
16
+
17
async def mcp_udfs_async(url: str) -> list['pxt.func.Function']:
    """Connect to the MCP server at `url` (streamable HTTP transport), list its tools, and wrap each as a UDF.

    The session is held open only for the `initialize()` / `list_tools()` exchange. Each UDF returned by
    `mcp_tool_to_udf()` opens a session of its own whenever it is invoked.
    """
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    async with (
        streamablehttp_client(url) as (read_stream, write_stream, _),
        mcp.ClientSession(read_stream, write_stream) as session,
    ):
        await session.initialize()
        tools_listing = await session.list_tools()
    assert tools_listing is not None

    # The UDFs are built after the session is closed; building them requires no server round-trip.
    return [mcp_tool_to_udf(url, mcp_tool) for mcp_tool in tools_listing.tools]
31
+
32
+
33
def mcp_tool_to_udf(url: str, mcp_tool: 'mcp.types.Tool') -> 'pxt.func.Function':
    """Wrap a single MCP tool as a Pixeltable UDF.

    The UDF's signature is derived from the tool's JSON Schema `inputSchema`:
    - every entry under 'properties' becomes a keyword-only parameter;
    - parameters not listed under 'required' are made nullable.
    The UDF always returns a string: the text of the first content item in the tool's result.

    Args:
        url: URL of the MCP server (streamable HTTP transport) that hosts the tool.
        mcp_tool: The tool description, as returned by `ClientSession.list_tools()`.

    Returns:
        A `CallableFunction` named after the tool.

    Raises:
        excs.Error: If the type schema of one of the tool's parameters cannot be mapped to a Pixeltable type.
    """
    import mcp
    from mcp.client.streamable_http import streamablehttp_client

    async def invoke(**kwargs: Any) -> str:
        # TODO: Cache session objects rather than creating a new one each time?
        async with (
            streamablehttp_client(url) as (read_stream, write_stream, _),
            mcp.ClientSession(read_stream, write_stream) as session,
        ):
            await session.initialize()
            res = await session.call_tool(name=mcp_tool.name, arguments=kwargs)
            # TODO Handle image/audio responses?
            # Only the first content item is used, and it is assumed to be text content.
            return res.content[0].text  # type: ignore[union-attr]

    if mcp_tool.description is not None:
        invoke.__doc__ = mcp_tool.description

    input_schema = mcp_tool.inputSchema
    # In JSON Schema, both 'properties' and 'required' are optional. A tool that takes no arguments may omit
    # them entirely, so fall back to empty defaults instead of indexing directly.
    properties: dict[str, Any] = input_schema.get('properties') or {}
    params = {name: __mcp_param_to_pxt_type(mcp_tool.name, name, param) for name, param in properties.items()}
    required = input_schema.get('required') or []

    # Ensure that any params not appearing in `required` are nullable.
    # (A required param might or might not be nullable, since its type might be an 'anyOf' containing a null.)
    for name in params.keys() - required:
        params[name] = params[name].copy(nullable=True)

    signature = pxt.func.Signature(
        return_type=ts.StringType(),  # Return type is always string
        parameters=[Parameter(name, col_type, inspect.Parameter.KEYWORD_ONLY) for name, col_type in params.items()],
    )

    return pxt.func.CallableFunction(signatures=[signature], py_fns=[invoke], self_name=mcp_tool.name)
68
+
69
+
70
def __mcp_param_to_pxt_type(tool_name: str, name: str, param: dict[str, Any]) -> ts.ColumnType:
    """Translate the JSON Schema of one MCP tool parameter into a Pixeltable column type.

    Raises:
        excs.Error: If `ts.ColumnType.from_json_schema()` cannot interpret the schema (i.e. returns None).
    """
    col_type = ts.ColumnType.from_json_schema(param)
    if col_type is not None:
        return col_type
    raise excs.Error(f'Unknown type schema for MCP parameter {name!r} of tool {tool_name!r}: {param}')
@@ -157,11 +157,13 @@ def retrieval_udf(
157
157
  """
158
158
  # Argument validation
159
159
  col_refs: list[exprs.ColumnRef]
160
+ # TODO: get rid of references to ColumnRef internals and replace instead with a public interface
161
+ col_names = table.columns()
160
162
  if parameters is None:
161
- col_refs = [table[col_name] for col_name in table.columns if not table[col_name].col.is_computed]
163
+ col_refs = [table[col_name] for col_name in col_names if not table[col_name].col.is_computed]
162
164
  else:
163
165
  for param in parameters:
164
- if isinstance(param, str) and param not in table.columns:
166
+ if isinstance(param, str) and param not in col_names:
165
167
  raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path()!r}')
166
168
  col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
167
169
 
pixeltable/func/tools.py CHANGED
@@ -1,8 +1,9 @@
1
+ import json
1
2
  from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
2
3
 
3
4
  import pydantic
4
5
 
5
- from pixeltable import exceptions as excs
6
+ from pixeltable import exceptions as excs, type_system as ts
6
7
 
7
8
  from .function import Function
8
9
  from .signature import Parameter
@@ -69,7 +70,9 @@ class Tool(pydantic.BaseModel):
69
70
  return _extract_float_tool_arg(kwargs, param_name=param.name)
70
71
  if param.col_type.is_bool_type():
71
72
  return _extract_bool_tool_arg(kwargs, param_name=param.name)
72
- raise AssertionError()
73
+ if param.col_type.is_json_type():
74
+ return _extract_json_tool_arg(kwargs, param_name=param.name)
75
+ raise AssertionError(param.col_type)
73
76
 
74
77
 
75
78
  class ToolChoice(pydantic.BaseModel):
@@ -137,6 +140,13 @@ def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[
137
140
  return _extract_arg(bool, kwargs, param_name)
138
141
 
139
142
 
143
@udf
def _extract_json_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[ts.Json]:
    """Extract a JSON-typed tool argument from the tool-call `kwargs`.

    Returns None if `param_name` is absent. String values are decoded with `json.loads()`. Any other value
    is returned unchanged.
    """
    if param_name not in kwargs:
        return None
    value = kwargs[param_name]
    # NOTE(review): depending on how the parameter's schema was presented to the model, a JSON argument may
    # arrive either as an encoded JSON string or already decoded (dict/list/number/...). `json.loads()` raises
    # TypeError on non-string input, so only strings are decoded.
    if isinstance(value, str):
        return json.loads(value)
    return value
148
+
149
+
140
150
  T = TypeVar('T')
141
151
 
142
152
 
pixeltable/func/udf.py CHANGED
@@ -262,7 +262,7 @@ def from_table(
262
262
  """
263
263
  from pixeltable import exprs
264
264
 
265
- ancestors = [tbl, *tbl._base_tables]
265
+ ancestors = [tbl, *tbl._get_base_tables()]
266
266
  ancestors.reverse() # We must traverse the ancestors in order from base to derived
267
267
 
268
268
  subst: dict[exprs.Expr, exprs.Expr] = {}
@@ -297,7 +297,7 @@ def from_table(
297
297
 
298
298
  if description is None:
299
299
  # Default description is the table comment
300
- description = tbl._comment
300
+ description = tbl._get_comment()
301
301
  if len(description) == 0:
302
302
  description = f"UDF for table '{tbl._name}'"
303
303
 
@@ -10,6 +10,7 @@ from . import (
10
10
  deepseek,
11
11
  fireworks,
12
12
  gemini,
13
+ groq,
13
14
  huggingface,
14
15
  image,
15
16
  json,
@@ -8,7 +8,7 @@ the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anth
8
8
  import datetime
9
9
  import json
10
10
  import logging
11
- from typing import TYPE_CHECKING, Any, Iterable, Optional, TypeVar, Union, cast
11
+ from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
12
12
 
13
13
  import httpx
14
14
 
@@ -73,16 +73,10 @@ async def messages(
73
73
  messages: list[dict[str, str]],
74
74
  *,
75
75
  model: str,
76
- max_tokens: int = 1024,
77
- metadata: Optional[dict[str, Any]] = None,
78
- stop_sequences: Optional[list[str]] = None,
79
- system: Optional[str] = None,
80
- temperature: Optional[float] = None,
81
- tool_choice: Optional[dict] = None,
82
- tools: Optional[list[dict]] = None,
83
- top_k: Optional[int] = None,
84
- top_p: Optional[float] = None,
85
- timeout: Optional[float] = None,
76
+ max_tokens: int,
77
+ model_kwargs: Optional[dict[str, Any]] = None,
78
+ tools: Optional[list[dict[str, Any]]] = None,
79
+ tool_choice: Optional[dict[str, Any]] = None,
86
80
  ) -> dict:
87
81
  """
88
82
  Create a Message.
@@ -101,25 +95,27 @@ async def messages(
101
95
  Args:
102
96
  messages: Input messages.
103
97
  model: The model that will complete your prompt.
104
-
105
- For details on the other parameters, see: <https://docs.anthropic.com/en/api/messages>
98
+ model_kwargs: Additional keyword args for the Anthropic `messages` API.
99
+ For details on the available parameters, see: <https://docs.anthropic.com/en/api/messages>
100
+ tools: An optional list of Pixeltable tools to use for the request.
101
+ tool_choice: An optional tool choice configuration.
106
102
 
107
103
  Returns:
108
104
  A dictionary containing the response and other metadata.
109
105
 
110
106
  Examples:
111
- Add a computed column that applies the model `claude-3-haiku-20240307`
107
+ Add a computed column that applies the model `claude-3-5-sonnet-20241022`
112
108
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
113
109
 
114
110
  >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
115
- ... tbl.add_computed_column(response=messages(msgs, model='claude-3-haiku-20240307'))
111
+ ... tbl.add_computed_column(response=messages(msgs, model='claude-3-5-sonnet-20241022'))
116
112
  """
117
-
118
- # it doesn't look like count_tokens() actually exists in the current version of the library
113
+ if model_kwargs is None:
114
+ model_kwargs = {}
119
115
 
120
116
  if tools is not None:
121
117
  # Reformat `tools` into Anthropic format
122
- tools = [
118
+ model_kwargs['tools'] = [
123
119
  {
124
120
  'name': tool['name'],
125
121
  'description': tool['description'],
@@ -132,17 +128,16 @@ async def messages(
132
128
  for tool in tools
133
129
  ]
134
130
 
135
- tool_choice_: Optional[dict] = None
136
131
  if tool_choice is not None:
137
132
  if tool_choice['auto']:
138
- tool_choice_ = {'type': 'auto'}
133
+ model_kwargs['tool_choice'] = {'type': 'auto'}
139
134
  elif tool_choice['required']:
140
- tool_choice_ = {'type': 'any'}
135
+ model_kwargs['tool_choice'] = {'type': 'any'}
141
136
  else:
142
137
  assert tool_choice['tool'] is not None
143
- tool_choice_ = {'type': 'tool', 'name': tool_choice['tool']}
138
+ model_kwargs['tool_choice'] = {'type': 'tool', 'name': tool_choice['tool']}
144
139
  if not tool_choice['parallel_tool_calls']:
145
- tool_choice_['disable_parallel_tool_use'] = True
140
+ model_kwargs['tool_choice']['disable_parallel_tool_use'] = True
146
141
 
147
142
  # make sure the pool info exists prior to making the request
148
143
  resource_pool_id = f'rate-limits:anthropic:{model}'
@@ -152,20 +147,8 @@ async def messages(
152
147
  # TODO: timeouts should be set system-wide and be user-configurable
153
148
  from anthropic.types import MessageParam
154
149
 
155
- # cast(Any, ...): avoid mypy errors
156
150
  result = await _anthropic_client().messages.with_raw_response.create(
157
- messages=cast(Iterable[MessageParam], messages),
158
- model=model,
159
- max_tokens=max_tokens,
160
- metadata=_opt(cast(Any, metadata)),
161
- stop_sequences=_opt(stop_sequences),
162
- system=_opt(system),
163
- temperature=_opt(cast(Any, temperature)),
164
- tools=_opt(cast(Any, tools)),
165
- tool_choice=_opt(cast(Any, tool_choice_)),
166
- top_k=_opt(top_k),
167
- top_p=_opt(top_p),
168
- timeout=_opt(timeout),
151
+ messages=cast(Iterable[MessageParam], messages), model=model, max_tokens=max_tokens, **model_kwargs
169
152
  )
170
153
 
171
154
  requests_limit_str = result.headers.get('anthropic-ratelimit-requests-limit')
@@ -224,15 +207,6 @@ def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
224
207
  return pxt_tool_calls
225
208
 
226
209
 
227
- _T = TypeVar('_T')
228
-
229
-
230
- def _opt(arg: _T) -> Union[_T, 'anthropic.NotGiven']:
231
- import anthropic
232
-
233
- return arg if arg is not None else anthropic.NOT_GIVEN
234
-
235
-
236
210
  __all__ = local_public_names(__name__)
237
211
 
238
212
 
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
2
+ from typing import TYPE_CHECKING, Any, Optional
3
3
 
4
4
  import httpx
5
5
 
@@ -7,8 +7,6 @@ import pixeltable as pxt
7
7
  from pixeltable import env
8
8
  from pixeltable.utils.code import local_public_names
9
9
 
10
- from .openai import _opt
11
-
12
10
  if TYPE_CHECKING:
13
11
  import openai
14
12
 
@@ -33,17 +31,9 @@ async def chat_completions(
33
31
  messages: list,
34
32
  *,
35
33
  model: str,
36
- frequency_penalty: Optional[float] = None,
37
- logprobs: Optional[bool] = None,
38
- top_logprobs: Optional[int] = None,
39
- max_tokens: Optional[int] = None,
40
- presence_penalty: Optional[float] = None,
41
- response_format: Optional[dict] = None,
42
- stop: Optional[list[str]] = None,
43
- temperature: Optional[float] = None,
44
- tools: Optional[list[dict]] = None,
45
- tool_choice: Optional[dict] = None,
46
- top_p: Optional[float] = None,
34
+ model_kwargs: Optional[dict[str, Any]] = None,
35
+ tools: Optional[list[dict[str, Any]]] = None,
36
+ tool_choice: Optional[dict[str, Any]] = None,
47
37
  ) -> dict:
48
38
  """
49
39
  Creates a model response for the given chat conversation.
@@ -60,8 +50,10 @@ async def chat_completions(
60
50
  Args:
61
51
  messages: A list of messages to use for chat completion, as described in the Deepseek API documentation.
62
52
  model: The model to use for chat completion.
63
-
64
- For details on the other parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
53
+ model_kwargs: Additional keyword args for the Deepseek `chat/completions` API.
54
+ For details on the available parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
55
+ tools: An optional list of Pixeltable tools to use for the request.
56
+ tool_choice: An optional tool choice configuration.
65
57
 
66
58
  Returns:
67
59
  A dictionary containing the response and other metadata.
@@ -76,39 +68,28 @@ async def chat_completions(
76
68
  ]
77
69
  tbl.add_computed_column(response=chat_completions(messages, model='deepseek-chat'))
78
70
  """
71
+ if model_kwargs is None:
72
+ model_kwargs = {}
73
+
79
74
  if tools is not None:
80
- tools = [{'type': 'function', 'function': tool} for tool in tools]
75
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
81
76
 
82
- tool_choice_: Union[str, dict, None] = None
83
77
  if tool_choice is not None:
84
78
  if tool_choice['auto']:
85
- tool_choice_ = 'auto'
79
+ model_kwargs['tool_choice'] = 'auto'
86
80
  elif tool_choice['required']:
87
- tool_choice_ = 'required'
81
+ model_kwargs['tool_choice'] = 'required'
88
82
  else:
89
83
  assert tool_choice['tool'] is not None
90
- tool_choice_ = {'type': 'function', 'function': {'name': tool_choice['tool']}}
84
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
91
85
 
92
- extra_body: Optional[dict[str, Any]] = None
93
86
  if tool_choice is not None and not tool_choice['parallel_tool_calls']:
94
- extra_body = {'parallel_tool_calls': False}
87
+ if 'extra_body' not in model_kwargs:
88
+ model_kwargs['extra_body'] = {}
89
+ model_kwargs['extra_body']['parallel_tool_calls'] = False
95
90
 
96
- # cast(Any, ...): avoid mypy errors
97
91
  result = await _deepseek_client().chat.completions.with_raw_response.create(
98
- messages=messages,
99
- model=model,
100
- frequency_penalty=_opt(frequency_penalty),
101
- logprobs=_opt(logprobs),
102
- top_logprobs=_opt(top_logprobs),
103
- max_tokens=_opt(max_tokens),
104
- presence_penalty=_opt(presence_penalty),
105
- response_format=_opt(cast(Any, response_format)),
106
- stop=_opt(stop),
107
- temperature=_opt(temperature),
108
- tools=_opt(cast(Any, tools)),
109
- tool_choice=_opt(cast(Any, tool_choice_)),
110
- top_p=_opt(top_p),
111
- extra_body=extra_body,
92
+ messages=messages, model=model, **model_kwargs
112
93
  )
113
94
 
114
95
  return json.loads(result.text)
@@ -5,7 +5,7 @@ first `pip install fireworks-ai` and configure your Fireworks AI credentials, as
5
5
  the [Working with Fireworks](https://pixeltable.readme.io/docs/working-with-fireworks) tutorial.
6
6
  """
7
7
 
8
- from typing import TYPE_CHECKING, Optional
8
+ from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
10
  import pixeltable as pxt
11
11
  from pixeltable import env
@@ -29,14 +29,7 @@ def _fireworks_client() -> 'fireworks.client.Fireworks':
29
29
 
30
30
  @pxt.udf(resource_pool='request-rate:fireworks')
31
31
  async def chat_completions(
32
- messages: list[dict[str, str]],
33
- *,
34
- model: str,
35
- max_tokens: Optional[int] = None,
36
- top_k: Optional[int] = None,
37
- top_p: Optional[float] = None,
38
- temperature: Optional[float] = None,
39
- request_timeout: Optional[int] = None,
32
+ messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
40
33
  ) -> dict:
41
34
  """
42
35
  Creates a model response for the given chat conversation.
@@ -55,8 +48,8 @@ async def chat_completions(
55
48
  Args:
56
49
  messages: A list of messages comprising the conversation so far.
57
50
  model: The name of the model to use.
58
-
59
- For details on the other parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
51
+ model_kwargs: Additional keyword args for the Fireworks `chat_completions` API. For details on the available
52
+ parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
60
53
 
61
54
  Returns:
62
55
  A dictionary containing the response and other metadata.
@@ -70,20 +63,18 @@ async def chat_completions(
70
63
  ... response=chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
71
64
  ... )
72
65
  """
73
- kwargs = {'max_tokens': max_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature}
74
- kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
66
+ if model_kwargs is None:
67
+ model_kwargs = {}
75
68
 
76
69
  # for debugging purposes:
77
70
  # res_sync = _fireworks_client().chat.completions.create(model=model, messages=messages, **kwargs_not_none)
78
71
  # res_sync_dict = res_sync.dict()
79
72
 
80
- if request_timeout is None:
81
- request_timeout = Config.get().get_int_value('timeout', section='fireworks') or 600
73
+ if 'request_timeout' not in model_kwargs:
74
+ model_kwargs['request_timeout'] = Config.get().get_int_value('timeout', section='fireworks') or 600
82
75
  # TODO: this timeout doesn't really work, I think it only applies to returning the stream, but not to the timing
83
76
  # of the chunks; addressing this would require a timeout for the task running this udf
84
- stream = _fireworks_client().chat.completions.acreate(
85
- model=model, messages=messages, request_timeout=request_timeout, **kwargs_not_none
86
- )
77
+ stream = _fireworks_client().chat.completions.acreate(model=model, messages=messages, **model_kwargs)
87
78
  chunks = []
88
79
  async for chunk in stream:
89
80
  chunks.append(chunk)
@@ -53,8 +53,8 @@ async def generate_content(
53
53
  config: Configuration for generation, corresponding to keyword arguments of
54
54
  `genai.types.GenerateContentConfig`. For details on the parameters, see:
55
55
  <https://googleapis.github.io/python-genai/genai.html#module-genai.types>
56
- tools: Optional list of Pixeltable tools to use. It is also possible to specify tools manually via the
57
- `config.tools` parameter, but at most one of `config.tools` or `tools` may be used.
56
+ tools: An optional list of Pixeltable tools to use. It is also possible to specify tools manually via the
57
+ `config['tools']` parameter, but at most one of `config['tools']` or `tools` may be used.
58
58
 
59
59
  Returns:
60
60
  A dictionary containing the response and other metadata.
@@ -0,0 +1,108 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Groq API. In order to use them, you must
4
+ first `pip install groq` and configure your Groq credentials, as described in
5
+ the [Working with Groq](https://pixeltable.readme.io/docs/working-with-groq) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Optional
9
+
10
+ import pixeltable as pxt
11
+ from pixeltable import exprs
12
+ from pixeltable.env import Env, register_client
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ from .openai import _openai_response_to_pxt_tool_calls
16
+
17
+ if TYPE_CHECKING:
18
+ import groq
19
+
20
+
21
@register_client('groq')
def _(api_key: str) -> 'groq.AsyncGroq':
    """Client factory registered under the name 'groq': build an async Groq client for `api_key`.

    The `groq` package is imported here rather than at module level (where it is only imported under
    TYPE_CHECKING), so it is needed only once a client is actually created.
    """
    import groq as groq_sdk

    return groq_sdk.AsyncGroq(api_key=api_key)
26
+
27
+
28
def _groq_client() -> 'groq.AsyncGroq':
    """Return the 'groq' client from the current Pixeltable environment.

    Presumably created by `Env` through the factory registered with `@register_client('groq')`.
    """
    env = Env.get()
    return env.get_client('groq')
30
+
31
+
32
@pxt.udf(resource_pool='request-rate:groq')
async def chat_completions(
    messages: list[dict[str, str]],
    *,
    model: str,
    model_kwargs: Optional[dict[str, Any]] = None,
    tools: Optional[list[dict[str, Any]]] = None,
    tool_choice: Optional[dict[str, Any]] = None,
) -> dict:
    """
    Chat Completion API.

    Equivalent to the Groq `chat/completions` API endpoint.
    For additional details, see: <https://console.groq.com/docs/api-reference#chat-create>

    Request throttling:
    Applies the rate limit set in the config (section `groq`, key `rate_limit`). If no rate
    limit is configured, uses a default of 600 RPM.

    __Requirements:__

    - `pip install groq`

    Args:
        messages: A list of messages comprising the conversation so far.
        model: ID of the model to use. (See overview here: <https://console.groq.com/docs/models>)
        model_kwargs: Additional keyword args for the Groq `chat/completions` API.
            For details on the available parameters, see: <https://console.groq.com/docs/api-reference#chat-create>
        tools: An optional list of Pixeltable tools to use for the request.
        tool_choice: An optional tool choice configuration.

    Returns:
        A dictionary containing the response and other metadata.

    Examples:
        Add a computed column that applies the model `llama3-8b-8192`
        to an existing Pixeltable column `tbl.prompt` of the table `tbl`:

        >>> messages = [{'role': 'user', 'content': tbl.prompt}]
        ... tbl.add_computed_column(response=chat_completions(messages, model='llama3-8b-8192'))
    """
    # Work on a shallow copy. The caller's dict (possibly a constant shared by every row of a computed
    # column) must not be mutated by the tool settings added below.
    model_kwargs = {} if model_kwargs is None else dict(model_kwargs)

    Env.get().require_package('groq')

    if tools is not None:
        # Wrap each Pixeltable tool spec in the function-tool format expected by the endpoint.
        model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]

    if tool_choice is not None:
        if tool_choice['auto']:
            model_kwargs['tool_choice'] = 'auto'
        elif tool_choice['required']:
            model_kwargs['tool_choice'] = 'required'
        else:
            assert tool_choice['tool'] is not None
            model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
        if not tool_choice['parallel_tool_calls']:
            model_kwargs['parallel_tool_calls'] = False

    result = await _groq_client().chat.completions.create(
        messages=messages,  # type: ignore[arg-type]
        model=model,
        **model_kwargs,
    )
    return result.model_dump()
97
+
98
+
99
def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDict:
    """Convert a Groq chat completion response to Pixeltable tool invocation format and call `tools._invoke()`.

    Groq responses are parsed with the OpenAI response helper, i.e. they are treated as OpenAI-compatible.
    """
    tool_calls = _openai_response_to_pxt_tool_calls(response)
    return tools._invoke(tool_calls)
102
+
103
+
104
# Public API of this module: the names returned by `local_public_names(__name__)`.
__all__ = local_public_names(__name__)


def __dir__() -> list[str]:
    """Module-level `dir()` hook (PEP 562): make `dir()` on this module report only the names in `__all__`."""
    return __all__