pixeltable 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (78)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +9 -1
  4. pixeltable/catalog/catalog.py +559 -134
  5. pixeltable/catalog/column.py +36 -32
  6. pixeltable/catalog/dir.py +1 -2
  7. pixeltable/catalog/globals.py +12 -0
  8. pixeltable/catalog/insertable_table.py +30 -25
  9. pixeltable/catalog/schema_object.py +9 -6
  10. pixeltable/catalog/table.py +334 -267
  11. pixeltable/catalog/table_version.py +358 -241
  12. pixeltable/catalog/table_version_handle.py +18 -2
  13. pixeltable/catalog/table_version_path.py +86 -16
  14. pixeltable/catalog/view.py +47 -23
  15. pixeltable/dataframe.py +198 -19
  16. pixeltable/env.py +6 -4
  17. pixeltable/exceptions.py +6 -0
  18. pixeltable/exec/__init__.py +1 -1
  19. pixeltable/exec/exec_node.py +2 -0
  20. pixeltable/exec/expr_eval/evaluators.py +4 -1
  21. pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
  22. pixeltable/exec/in_memory_data_node.py +1 -1
  23. pixeltable/exec/sql_node.py +188 -22
  24. pixeltable/exprs/column_property_ref.py +16 -6
  25. pixeltable/exprs/column_ref.py +33 -11
  26. pixeltable/exprs/comparison.py +1 -1
  27. pixeltable/exprs/data_row.py +5 -3
  28. pixeltable/exprs/expr.py +11 -4
  29. pixeltable/exprs/literal.py +2 -0
  30. pixeltable/exprs/row_builder.py +4 -6
  31. pixeltable/exprs/rowid_ref.py +8 -0
  32. pixeltable/exprs/similarity_expr.py +1 -0
  33. pixeltable/func/__init__.py +1 -0
  34. pixeltable/func/mcp.py +74 -0
  35. pixeltable/func/query_template_function.py +5 -3
  36. pixeltable/func/tools.py +12 -2
  37. pixeltable/func/udf.py +2 -2
  38. pixeltable/functions/__init__.py +1 -0
  39. pixeltable/functions/anthropic.py +19 -45
  40. pixeltable/functions/deepseek.py +19 -38
  41. pixeltable/functions/fireworks.py +9 -18
  42. pixeltable/functions/gemini.py +2 -3
  43. pixeltable/functions/groq.py +108 -0
  44. pixeltable/functions/llama_cpp.py +6 -6
  45. pixeltable/functions/mistralai.py +16 -53
  46. pixeltable/functions/ollama.py +1 -1
  47. pixeltable/functions/openai.py +82 -165
  48. pixeltable/functions/string.py +212 -58
  49. pixeltable/functions/together.py +22 -80
  50. pixeltable/globals.py +10 -4
  51. pixeltable/index/base.py +5 -0
  52. pixeltable/index/btree.py +5 -0
  53. pixeltable/index/embedding_index.py +5 -0
  54. pixeltable/io/external_store.py +10 -31
  55. pixeltable/io/label_studio.py +5 -5
  56. pixeltable/io/parquet.py +2 -2
  57. pixeltable/io/table_data_conduit.py +1 -32
  58. pixeltable/metadata/__init__.py +11 -2
  59. pixeltable/metadata/converters/convert_13.py +2 -2
  60. pixeltable/metadata/converters/convert_30.py +6 -11
  61. pixeltable/metadata/converters/convert_35.py +9 -0
  62. pixeltable/metadata/converters/convert_36.py +38 -0
  63. pixeltable/metadata/converters/convert_37.py +15 -0
  64. pixeltable/metadata/converters/util.py +3 -9
  65. pixeltable/metadata/notes.py +3 -0
  66. pixeltable/metadata/schema.py +13 -1
  67. pixeltable/plan.py +135 -12
  68. pixeltable/share/packager.py +138 -14
  69. pixeltable/share/publish.py +2 -2
  70. pixeltable/store.py +19 -13
  71. pixeltable/type_system.py +30 -0
  72. pixeltable/utils/dbms.py +1 -1
  73. pixeltable/utils/formatter.py +64 -42
  74. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/METADATA +2 -1
  75. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/RECORD +78 -73
  76. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/LICENSE +0 -0
  77. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/WHEEL +0 -0
  78. {pixeltable-0.3.15.dist-info → pixeltable-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -157,12 +157,14 @@ def retrieval_udf(
157
157
  """
158
158
  # Argument validation
159
159
  col_refs: list[exprs.ColumnRef]
160
+ # TODO: get rid of references to ColumnRef internals and replace instead with a public interface
161
+ col_names = table.columns()
160
162
  if parameters is None:
161
- col_refs = [table[col_name] for col_name in table.columns if not table[col_name].col.is_computed]
163
+ col_refs = [table[col_name] for col_name in col_names if not table[col_name].col.is_computed]
162
164
  else:
163
165
  for param in parameters:
164
- if isinstance(param, str) and param not in table.columns:
165
- raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path!r}')
166
+ if isinstance(param, str) and param not in col_names:
167
+ raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path()!r}')
166
168
  col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
167
169
 
168
170
  if len(col_refs) == 0:
pixeltable/func/tools.py CHANGED
@@ -1,8 +1,9 @@
1
+ import json
1
2
  from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
2
3
 
3
4
  import pydantic
4
5
 
5
- import pixeltable.exceptions as excs
6
+ from pixeltable import exceptions as excs, type_system as ts
6
7
 
7
8
  from .function import Function
8
9
  from .signature import Parameter
@@ -69,7 +70,9 @@ class Tool(pydantic.BaseModel):
69
70
  return _extract_float_tool_arg(kwargs, param_name=param.name)
70
71
  if param.col_type.is_bool_type():
71
72
  return _extract_bool_tool_arg(kwargs, param_name=param.name)
72
- raise AssertionError()
73
+ if param.col_type.is_json_type():
74
+ return _extract_json_tool_arg(kwargs, param_name=param.name)
75
+ raise AssertionError(param.col_type)
73
76
 
74
77
 
75
78
  class ToolChoice(pydantic.BaseModel):
@@ -137,6 +140,13 @@ def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[
137
140
  return _extract_arg(bool, kwargs, param_name)
138
141
 
139
142
 
143
+ @udf
144
+ def _extract_json_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[ts.Json]:
145
+ if param_name in kwargs:
146
+ return json.loads(kwargs[param_name])
147
+ return None
148
+
149
+
140
150
  T = TypeVar('T')
141
151
 
142
152
 
pixeltable/func/udf.py CHANGED
@@ -262,7 +262,7 @@ def from_table(
262
262
  """
263
263
  from pixeltable import exprs
264
264
 
265
- ancestors = [tbl, *tbl._base_tables]
265
+ ancestors = [tbl, *tbl._get_base_tables()]
266
266
  ancestors.reverse() # We must traverse the ancestors in order from base to derived
267
267
 
268
268
  subst: dict[exprs.Expr, exprs.Expr] = {}
@@ -297,7 +297,7 @@ def from_table(
297
297
 
298
298
  if description is None:
299
299
  # Default description is the table comment
300
- description = tbl._comment
300
+ description = tbl._get_comment()
301
301
  if len(description) == 0:
302
302
  description = f"UDF for table '{tbl._name}'"
303
303
 
@@ -10,6 +10,7 @@ from . import (
10
10
  deepseek,
11
11
  fireworks,
12
12
  gemini,
13
+ groq,
13
14
  huggingface,
14
15
  image,
15
16
  json,
@@ -8,7 +8,7 @@ the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anth
8
8
  import datetime
9
9
  import json
10
10
  import logging
11
- from typing import TYPE_CHECKING, Any, Iterable, Optional, TypeVar, Union, cast
11
+ from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
12
12
 
13
13
  import httpx
14
14
 
@@ -73,16 +73,10 @@ async def messages(
73
73
  messages: list[dict[str, str]],
74
74
  *,
75
75
  model: str,
76
- max_tokens: int = 1024,
77
- metadata: Optional[dict[str, Any]] = None,
78
- stop_sequences: Optional[list[str]] = None,
79
- system: Optional[str] = None,
80
- temperature: Optional[float] = None,
81
- tool_choice: Optional[dict] = None,
82
- tools: Optional[list[dict]] = None,
83
- top_k: Optional[int] = None,
84
- top_p: Optional[float] = None,
85
- timeout: Optional[float] = None,
76
+ max_tokens: int,
77
+ model_kwargs: Optional[dict[str, Any]] = None,
78
+ tools: Optional[list[dict[str, Any]]] = None,
79
+ tool_choice: Optional[dict[str, Any]] = None,
86
80
  ) -> dict:
87
81
  """
88
82
  Create a Message.
@@ -101,25 +95,27 @@ async def messages(
101
95
  Args:
102
96
  messages: Input messages.
103
97
  model: The model that will complete your prompt.
104
-
105
- For details on the other parameters, see: <https://docs.anthropic.com/en/api/messages>
98
+ model_kwargs: Additional keyword args for the Anthropic `messages` API.
99
+ For details on the available parameters, see: <https://docs.anthropic.com/en/api/messages>
100
+ tools: An optional list of Pixeltable tools to use for the request.
101
+ tool_choice: An optional tool choice configuration.
106
102
 
107
103
  Returns:
108
104
  A dictionary containing the response and other metadata.
109
105
 
110
106
  Examples:
111
- Add a computed column that applies the model `claude-3-haiku-20240307`
107
+ Add a computed column that applies the model `claude-3-5-sonnet-20241022`
112
108
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
113
109
 
114
110
  >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
115
- ... tbl.add_computed_column(response=messages(msgs, model='claude-3-haiku-20240307'))
111
+ ... tbl.add_computed_column(response=messages(msgs, model='claude-3-5-sonnet-20241022'))
116
112
  """
117
-
118
- # it doesn't look like count_tokens() actually exists in the current version of the library
113
+ if model_kwargs is None:
114
+ model_kwargs = {}
119
115
 
120
116
  if tools is not None:
121
117
  # Reformat `tools` into Anthropic format
122
- tools = [
118
+ model_kwargs['tools'] = [
123
119
  {
124
120
  'name': tool['name'],
125
121
  'description': tool['description'],
@@ -132,17 +128,16 @@ async def messages(
132
128
  for tool in tools
133
129
  ]
134
130
 
135
- tool_choice_: Optional[dict] = None
136
131
  if tool_choice is not None:
137
132
  if tool_choice['auto']:
138
- tool_choice_ = {'type': 'auto'}
133
+ model_kwargs['tool_choice'] = {'type': 'auto'}
139
134
  elif tool_choice['required']:
140
- tool_choice_ = {'type': 'any'}
135
+ model_kwargs['tool_choice'] = {'type': 'any'}
141
136
  else:
142
137
  assert tool_choice['tool'] is not None
143
- tool_choice_ = {'type': 'tool', 'name': tool_choice['tool']}
138
+ model_kwargs['tool_choice'] = {'type': 'tool', 'name': tool_choice['tool']}
144
139
  if not tool_choice['parallel_tool_calls']:
145
- tool_choice_['disable_parallel_tool_use'] = True
140
+ model_kwargs['tool_choice']['disable_parallel_tool_use'] = True
146
141
 
147
142
  # make sure the pool info exists prior to making the request
148
143
  resource_pool_id = f'rate-limits:anthropic:{model}'
@@ -152,20 +147,8 @@ async def messages(
152
147
  # TODO: timeouts should be set system-wide and be user-configurable
153
148
  from anthropic.types import MessageParam
154
149
 
155
- # cast(Any, ...): avoid mypy errors
156
150
  result = await _anthropic_client().messages.with_raw_response.create(
157
- messages=cast(Iterable[MessageParam], messages),
158
- model=model,
159
- max_tokens=max_tokens,
160
- metadata=_opt(cast(Any, metadata)),
161
- stop_sequences=_opt(stop_sequences),
162
- system=_opt(system),
163
- temperature=_opt(cast(Any, temperature)),
164
- tools=_opt(cast(Any, tools)),
165
- tool_choice=_opt(cast(Any, tool_choice_)),
166
- top_k=_opt(top_k),
167
- top_p=_opt(top_p),
168
- timeout=_opt(timeout),
151
+ messages=cast(Iterable[MessageParam], messages), model=model, max_tokens=max_tokens, **model_kwargs
169
152
  )
170
153
 
171
154
  requests_limit_str = result.headers.get('anthropic-ratelimit-requests-limit')
@@ -224,15 +207,6 @@ def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
224
207
  return pxt_tool_calls
225
208
 
226
209
 
227
- _T = TypeVar('_T')
228
-
229
-
230
- def _opt(arg: _T) -> Union[_T, 'anthropic.NotGiven']:
231
- import anthropic
232
-
233
- return arg if arg is not None else anthropic.NOT_GIVEN
234
-
235
-
236
210
  __all__ = local_public_names(__name__)
237
211
 
238
212
 
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
2
+ from typing import TYPE_CHECKING, Any, Optional
3
3
 
4
4
  import httpx
5
5
 
@@ -7,8 +7,6 @@ import pixeltable as pxt
7
7
  from pixeltable import env
8
8
  from pixeltable.utils.code import local_public_names
9
9
 
10
- from .openai import _opt
11
-
12
10
  if TYPE_CHECKING:
13
11
  import openai
14
12
 
@@ -33,17 +31,9 @@ async def chat_completions(
33
31
  messages: list,
34
32
  *,
35
33
  model: str,
36
- frequency_penalty: Optional[float] = None,
37
- logprobs: Optional[bool] = None,
38
- top_logprobs: Optional[int] = None,
39
- max_tokens: Optional[int] = None,
40
- presence_penalty: Optional[float] = None,
41
- response_format: Optional[dict] = None,
42
- stop: Optional[list[str]] = None,
43
- temperature: Optional[float] = None,
44
- tools: Optional[list[dict]] = None,
45
- tool_choice: Optional[dict] = None,
46
- top_p: Optional[float] = None,
34
+ model_kwargs: Optional[dict[str, Any]] = None,
35
+ tools: Optional[list[dict[str, Any]]] = None,
36
+ tool_choice: Optional[dict[str, Any]] = None,
47
37
  ) -> dict:
48
38
  """
49
39
  Creates a model response for the given chat conversation.
@@ -60,8 +50,10 @@ async def chat_completions(
60
50
  Args:
61
51
  messages: A list of messages to use for chat completion, as described in the Deepseek API documentation.
62
52
  model: The model to use for chat completion.
63
-
64
- For details on the other parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
53
+ model_kwargs: Additional keyword args for the Deepseek `chat/completions` API.
54
+ For details on the available parameters, see: <https://api-docs.deepseek.com/api/create-chat-completion>
55
+ tools: An optional list of Pixeltable tools to use for the request.
56
+ tool_choice: An optional tool choice configuration.
65
57
 
66
58
  Returns:
67
59
  A dictionary containing the response and other metadata.
@@ -76,39 +68,28 @@ async def chat_completions(
76
68
  ]
77
69
  tbl.add_computed_column(response=chat_completions(messages, model='deepseek-chat'))
78
70
  """
71
+ if model_kwargs is None:
72
+ model_kwargs = {}
73
+
79
74
  if tools is not None:
80
- tools = [{'type': 'function', 'function': tool} for tool in tools]
75
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
81
76
 
82
- tool_choice_: Union[str, dict, None] = None
83
77
  if tool_choice is not None:
84
78
  if tool_choice['auto']:
85
- tool_choice_ = 'auto'
79
+ model_kwargs['tool_choice'] = 'auto'
86
80
  elif tool_choice['required']:
87
- tool_choice_ = 'required'
81
+ model_kwargs['tool_choice'] = 'required'
88
82
  else:
89
83
  assert tool_choice['tool'] is not None
90
- tool_choice_ = {'type': 'function', 'function': {'name': tool_choice['tool']}}
84
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
91
85
 
92
- extra_body: Optional[dict[str, Any]] = None
93
86
  if tool_choice is not None and not tool_choice['parallel_tool_calls']:
94
- extra_body = {'parallel_tool_calls': False}
87
+ if 'extra_body' not in model_kwargs:
88
+ model_kwargs['extra_body'] = {}
89
+ model_kwargs['extra_body']['parallel_tool_calls'] = False
95
90
 
96
- # cast(Any, ...): avoid mypy errors
97
91
  result = await _deepseek_client().chat.completions.with_raw_response.create(
98
- messages=messages,
99
- model=model,
100
- frequency_penalty=_opt(frequency_penalty),
101
- logprobs=_opt(logprobs),
102
- top_logprobs=_opt(top_logprobs),
103
- max_tokens=_opt(max_tokens),
104
- presence_penalty=_opt(presence_penalty),
105
- response_format=_opt(cast(Any, response_format)),
106
- stop=_opt(stop),
107
- temperature=_opt(temperature),
108
- tools=_opt(cast(Any, tools)),
109
- tool_choice=_opt(cast(Any, tool_choice_)),
110
- top_p=_opt(top_p),
111
- extra_body=extra_body,
92
+ messages=messages, model=model, **model_kwargs
112
93
  )
113
94
 
114
95
  return json.loads(result.text)
@@ -5,7 +5,7 @@ first `pip install fireworks-ai` and configure your Fireworks AI credentials, as
5
5
  the [Working with Fireworks](https://pixeltable.readme.io/docs/working-with-fireworks) tutorial.
6
6
  """
7
7
 
8
- from typing import TYPE_CHECKING, Optional
8
+ from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
10
  import pixeltable as pxt
11
11
  from pixeltable import env
@@ -29,14 +29,7 @@ def _fireworks_client() -> 'fireworks.client.Fireworks':
29
29
 
30
30
  @pxt.udf(resource_pool='request-rate:fireworks')
31
31
  async def chat_completions(
32
- messages: list[dict[str, str]],
33
- *,
34
- model: str,
35
- max_tokens: Optional[int] = None,
36
- top_k: Optional[int] = None,
37
- top_p: Optional[float] = None,
38
- temperature: Optional[float] = None,
39
- request_timeout: Optional[int] = None,
32
+ messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
40
33
  ) -> dict:
41
34
  """
42
35
  Creates a model response for the given chat conversation.
@@ -55,8 +48,8 @@ async def chat_completions(
55
48
  Args:
56
49
  messages: A list of messages comprising the conversation so far.
57
50
  model: The name of the model to use.
58
-
59
- For details on the other parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
51
+ model_kwargs: Additional keyword args for the Fireworks `chat_completions` API. For details on the available
52
+ parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
60
53
 
61
54
  Returns:
62
55
  A dictionary containing the response and other metadata.
@@ -70,20 +63,18 @@ async def chat_completions(
70
63
  ... response=chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
71
64
  ... )
72
65
  """
73
- kwargs = {'max_tokens': max_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature}
74
- kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
66
+ if model_kwargs is None:
67
+ model_kwargs = {}
75
68
 
76
69
  # for debugging purposes:
77
70
  # res_sync = _fireworks_client().chat.completions.create(model=model, messages=messages, **kwargs_not_none)
78
71
  # res_sync_dict = res_sync.dict()
79
72
 
80
- if request_timeout is None:
81
- request_timeout = Config.get().get_int_value('timeout', section='fireworks') or 600
73
+ if 'request_timeout' not in model_kwargs:
74
+ model_kwargs['request_timeout'] = Config.get().get_int_value('timeout', section='fireworks') or 600
82
75
  # TODO: this timeout doesn't really work, I think it only applies to returning the stream, but not to the timing
83
76
  # of the chunks; addressing this would require a timeout for the task running this udf
84
- stream = _fireworks_client().chat.completions.acreate(
85
- model=model, messages=messages, request_timeout=request_timeout, **kwargs_not_none
86
- )
77
+ stream = _fireworks_client().chat.completions.acreate(model=model, messages=messages, **model_kwargs)
87
78
  chunks = []
88
79
  async for chunk in stream:
89
80
  chunks.append(chunk)
@@ -53,8 +53,8 @@ async def generate_content(
53
53
  config: Configuration for generation, corresponding to keyword arguments of
54
54
  `genai.types.GenerateContentConfig`. For details on the parameters, see:
55
55
  <https://googleapis.github.io/python-genai/genai.html#module-genai.types>
56
- tools: Optional list of Pixeltable tools to use. It is also possible to specify tools manually via the
57
- `config.tools` parameter, but at most one of `config.tools` or `tools` may be used.
56
+ tools: An optional list of Pixeltable tools to use. It is also possible to specify tools manually via the
57
+ `config['tools']` parameter, but at most one of `config['tools']` or `tools` may be used.
58
58
 
59
59
  Returns:
60
60
  A dictionary containing the response and other metadata.
@@ -103,7 +103,6 @@ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDic
103
103
 
104
104
  @pxt.udf
105
105
  def _gemini_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
106
- print(response)
107
106
  pxt_tool_calls: dict[str, list[dict]] = {}
108
107
  for part in response['candidates'][0]['content']['parts']:
109
108
  tool_call = part.get('function_call')
@@ -0,0 +1,108 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Groq API. In order to use them, you must
4
+ first `pip install groq` and configure your Groq credentials, as described in
5
+ the [Working with Groq](https://pixeltable.readme.io/docs/working-with-groq) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Optional
9
+
10
+ import pixeltable as pxt
11
+ from pixeltable import exprs
12
+ from pixeltable.env import Env, register_client
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ from .openai import _openai_response_to_pxt_tool_calls
16
+
17
+ if TYPE_CHECKING:
18
+ import groq
19
+
20
+
21
+ @register_client('groq')
22
+ def _(api_key: str) -> 'groq.AsyncGroq':
23
+ import groq
24
+
25
+ return groq.AsyncGroq(api_key=api_key)
26
+
27
+
28
+ def _groq_client() -> 'groq.AsyncGroq':
29
+ return Env.get().get_client('groq')
30
+
31
+
32
+ @pxt.udf(resource_pool='request-rate:groq')
33
+ async def chat_completions(
34
+ messages: list[dict[str, str]],
35
+ *,
36
+ model: str,
37
+ model_kwargs: Optional[dict[str, Any]] = None,
38
+ tools: Optional[list[dict[str, Any]]] = None,
39
+ tool_choice: Optional[dict[str, Any]] = None,
40
+ ) -> dict:
41
+ """
42
+ Chat Completion API.
43
+
44
+ Equivalent to the Groq `chat/completions` API endpoint.
45
+ For additional details, see: <https://console.groq.com/docs/api-reference#chat-create>
46
+
47
+ Request throttling:
48
+ Applies the rate limit set in the config (section `groq`, key `rate_limit`). If no rate
49
+ limit is configured, uses a default of 600 RPM.
50
+
51
+ __Requirements:__
52
+
53
+ - `pip install groq`
54
+
55
+ Args:
56
+ messages: A list of messages comprising the conversation so far.
57
+ model: ID of the model to use. (See overview here: <https://console.groq.com/docs/models>)
58
+ model_kwargs: Additional keyword args for the Groq `chat/completions` API.
59
+ For details on the available parameters, see: <https://console.groq.com/docs/api-reference#chat-create>
60
+
61
+ Returns:
62
+ A dictionary containing the response and other metadata.
63
+
64
+ Examples:
65
+ Add a computed column that applies the model `llama3-8b-8192`
66
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
67
+
68
+ >>> messages = [{'role': 'user', 'content': tbl.prompt}]
69
+ ... tbl.add_computed_column(response=chat_completions(messages, model='llama3-8b-8192'))
70
+ """
71
+ if model_kwargs is None:
72
+ model_kwargs = {}
73
+
74
+ Env.get().require_package('groq')
75
+
76
+ if tools is not None:
77
+ model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
78
+
79
+ if tool_choice is not None:
80
+ if tool_choice['auto']:
81
+ model_kwargs['tool_choice'] = 'auto'
82
+ elif tool_choice['required']:
83
+ model_kwargs['tool_choice'] = 'required'
84
+ else:
85
+ assert tool_choice['tool'] is not None
86
+ model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
87
+
88
+ if tool_choice is not None and not tool_choice['parallel_tool_calls']:
89
+ model_kwargs['parallel_tool_calls'] = False
90
+
91
+ result = await _groq_client().chat.completions.create(
92
+ messages=messages, # type: ignore[arg-type]
93
+ model=model,
94
+ **model_kwargs,
95
+ )
96
+ return result.model_dump()
97
+
98
+
99
+ def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDict:
100
+ """Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
101
+ return tools._invoke(_openai_response_to_pxt_tool_calls(response))
102
+
103
+
104
+ __all__ = local_public_names(__name__)
105
+
106
+
107
+ def __dir__() -> list[str]:
108
+ return __all__
@@ -17,7 +17,7 @@ def create_chat_completion(
17
17
  model_path: Optional[str] = None,
18
18
  repo_id: Optional[str] = None,
19
19
  repo_filename: Optional[str] = None,
20
- args: Optional[dict[str, Any]] = None,
20
+ model_kwargs: Optional[dict[str, Any]] = None,
21
21
  ) -> dict:
22
22
  """
23
23
  Generate a chat completion from a list of messages.
@@ -35,14 +35,14 @@ def create_chat_completion(
35
35
  repo_id: The Hugging Face model repo id (if using a pretrained model).
36
36
  repo_filename: A filename or glob pattern to match the model file in the repo (optional, if using a
37
37
  pretrained model).
38
- args: Additional arguments to pass to the `create_chat_completions` call, such as `max_tokens`, `temperature`,
39
- `top_p`, and `top_k`. For details, see the
38
+ model_kwargs: Additional keyword args for the llama_cpp `create_chat_completions` API, such as `max_tokens`,
39
+ `temperature`, `top_p`, and `top_k`. For details, see the
40
40
  [llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
41
41
  """
42
42
  Env.get().require_package('llama_cpp', min_version=[0, 3, 1])
43
43
 
44
- if args is None:
45
- args = {}
44
+ if model_kwargs is None:
45
+ model_kwargs = {}
46
46
 
47
47
  if (model_path is None) == (repo_id is None):
48
48
  raise excs.Error('Exactly one of `model_path` or `repo_id` must be provided.')
@@ -56,7 +56,7 @@ def create_chat_completion(
56
56
  else:
57
57
  Env.get().require_package('huggingface_hub')
58
58
  llm = _lookup_pretrained_model(repo_id, repo_filename, n_gpu_layers)
59
- return llm.create_chat_completion(messages, **args) # type: ignore
59
+ return llm.create_chat_completion(messages, **model_kwargs) # type: ignore
60
60
 
61
61
 
62
62
  def _is_gpu_available() -> bool:
@@ -5,7 +5,7 @@ first `pip install mistralai` and configure your Mistral AI credentials, as desc
5
5
  the [Working with Mistral AI](https://pixeltable.readme.io/docs/working-with-mistralai) tutorial.
6
6
  """
7
7
 
8
- from typing import TYPE_CHECKING, Optional, TypeVar, Union
8
+ from typing import TYPE_CHECKING, Any, Optional
9
9
 
10
10
  import numpy as np
11
11
 
@@ -16,7 +16,7 @@ from pixeltable.func.signature import Batch
16
16
  from pixeltable.utils.code import local_public_names
17
17
 
18
18
  if TYPE_CHECKING:
19
- import mistralai.types.basemodel
19
+ import mistralai
20
20
 
21
21
 
22
22
  @register_client('mistral')
@@ -32,16 +32,7 @@ def _mistralai_client() -> 'mistralai.Mistral':
32
32
 
33
33
  @pxt.udf(resource_pool='request-rate:mistral')
34
34
  async def chat_completions(
35
- messages: list[dict[str, str]],
36
- *,
37
- model: str,
38
- temperature: Optional[float] = 0.7,
39
- top_p: Optional[float] = 1.0,
40
- max_tokens: Optional[int] = None,
41
- stop: Optional[list[str]] = None,
42
- random_seed: Optional[int] = None,
43
- response_format: Optional[dict] = None,
44
- safe_prompt: Optional[bool] = False,
35
+ messages: list[dict[str, str]], *, model: str, model_kwargs: Optional[dict[str, Any]] = None
45
36
  ) -> dict:
46
37
  """
47
38
  Chat Completion API.
@@ -60,8 +51,8 @@ async def chat_completions(
60
51
  Args:
61
52
  messages: The prompt(s) to generate completions for.
62
53
  model: ID of the model to use. (See overview here: <https://docs.mistral.ai/getting-started/models/>)
63
-
64
- For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/chat>
54
+ model_kwargs: Additional keyword args for the Mistral `chat/completions` API.
55
+ For details on the available parameters, see: <https://docs.mistral.ai/api/#tag/chat>
65
56
 
66
57
  Returns:
67
58
  A dictionary containing the response and other metadata.
@@ -73,34 +64,20 @@ async def chat_completions(
73
64
  >>> messages = [{'role': 'user', 'content': tbl.prompt}]
74
65
  ... tbl.add_computed_column(response=completions(messages, model='mistral-latest-small'))
75
66
  """
67
+ if model_kwargs is None:
68
+ model_kwargs = {}
69
+
76
70
  Env.get().require_package('mistralai')
77
71
  result = await _mistralai_client().chat.complete_async(
78
72
  messages=messages, # type: ignore[arg-type]
79
73
  model=model,
80
- temperature=temperature,
81
- top_p=top_p,
82
- max_tokens=_opt(max_tokens),
83
- stop=stop,
84
- random_seed=_opt(random_seed),
85
- response_format=response_format, # type: ignore[arg-type]
86
- safe_prompt=safe_prompt,
74
+ **model_kwargs,
87
75
  )
88
76
  return result.dict()
89
77
 
90
78
 
91
79
  @pxt.udf(resource_pool='request-rate:mistral')
92
- async def fim_completions(
93
- prompt: str,
94
- *,
95
- model: str,
96
- temperature: Optional[float] = 0.7,
97
- top_p: Optional[float] = 1.0,
98
- max_tokens: Optional[int] = None,
99
- min_tokens: Optional[int] = None,
100
- stop: Optional[list[str]] = None,
101
- random_seed: Optional[int] = None,
102
- suffix: Optional[str] = None,
103
- ) -> dict:
80
+ async def fim_completions(prompt: str, *, model: str, model_kwargs: Optional[dict[str, Any]] = None) -> dict:
104
81
  """
105
82
  Fill-in-the-middle Completion API.
106
83
 
@@ -118,6 +95,8 @@ async def fim_completions(
118
95
  Args:
119
96
  prompt: The text/code to complete.
120
97
  model: ID of the model to use. (See overview here: <https://docs.mistral.ai/getting-started/models/>)
98
+ model_kwargs: Additional keyword args for the Mistral `fim/completions` API.
99
+ For details on the available parameters, see: <https://docs.mistral.ai/api/#tag/fim>
121
100
 
122
101
  For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/fim>
123
102
 
@@ -130,18 +109,11 @@ async def fim_completions(
130
109
 
131
110
  >>> tbl.add_computed_column(response=completions(tbl.prompt, model='codestral-latest'))
132
111
  """
112
+ if model_kwargs is None:
113
+ model_kwargs = {}
114
+
133
115
  Env.get().require_package('mistralai')
134
- result = await _mistralai_client().fim.complete_async(
135
- prompt=prompt,
136
- model=model,
137
- temperature=temperature,
138
- top_p=top_p,
139
- max_tokens=_opt(max_tokens),
140
- min_tokens=_opt(min_tokens),
141
- stop=stop,
142
- random_seed=_opt(random_seed),
143
- suffix=_opt(suffix),
144
- )
116
+ result = await _mistralai_client().fim.complete_async(prompt=prompt, model=model, **model_kwargs)
145
117
  return result.dict()
146
118
 
147
119
 
@@ -182,15 +154,6 @@ def _(model: str) -> ts.ArrayType:
182
154
  return ts.ArrayType((dimensions,), dtype=ts.FloatType())
183
155
 
184
156
 
185
- _T = TypeVar('_T')
186
-
187
-
188
- def _opt(arg: Optional[_T]) -> Union[_T, 'mistralai.types.basemodel.Unset']:
189
- from mistralai.types import UNSET
190
-
191
- return arg if arg is not None else UNSET
192
-
193
-
194
157
  __all__ = local_public_names(__name__)
195
158
 
196
159