pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +9 -1
- pixeltable/catalog/catalog.py +333 -99
- pixeltable/catalog/column.py +28 -26
- pixeltable/catalog/globals.py +12 -0
- pixeltable/catalog/insertable_table.py +8 -8
- pixeltable/catalog/schema_object.py +6 -0
- pixeltable/catalog/table.py +111 -116
- pixeltable/catalog/table_version.py +36 -50
- pixeltable/catalog/table_version_handle.py +4 -1
- pixeltable/catalog/table_version_path.py +28 -4
- pixeltable/catalog/view.py +10 -18
- pixeltable/config.py +4 -0
- pixeltable/dataframe.py +10 -9
- pixeltable/env.py +5 -11
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/exec_node.py +2 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +4 -4
- pixeltable/exec/sql_node.py +47 -30
- pixeltable/exprs/column_property_ref.py +2 -1
- pixeltable/exprs/column_ref.py +7 -6
- pixeltable/exprs/expr.py +4 -4
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +4 -2
- pixeltable/func/tools.py +12 -2
- pixeltable/func/udf.py +2 -2
- pixeltable/functions/__init__.py +1 -0
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +8 -6
- pixeltable/functions/mistralai.py +2 -13
- pixeltable/functions/openai.py +1 -6
- pixeltable/functions/replicate.py +2 -2
- pixeltable/functions/util.py +6 -1
- pixeltable/globals.py +0 -2
- pixeltable/io/external_store.py +2 -2
- pixeltable/io/label_studio.py +4 -4
- pixeltable/io/table_data_conduit.py +1 -1
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +5 -0
- pixeltable/plan.py +37 -121
- pixeltable/share/packager.py +2 -2
- pixeltable/type_system.py +30 -0
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/METADATA +1 -1
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/RECORD +51 -49
- pixeltable/utils/sample.py +0 -25
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/LICENSE +0 -0
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/WHEEL +0 -0
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.1.dist-info}/entry_points.txt +0 -0
|
pixeltable/exprs/column_property_ref.py
CHANGED
|
@@ -64,8 +64,9 @@ class ColumnPropertyRef(Expr):
|
|
|
64
64
|
# perform runtime checks and update state
|
|
65
65
|
tv = self._col_ref.tbl_version.get()
|
|
66
66
|
assert tv.is_validated
|
|
67
|
+
# we can assume at this point during query execution that the column exists
|
|
68
|
+
assert self._col_ref.col_id in tv.cols_by_id
|
|
67
69
|
col = tv.cols_by_id[self._col_ref.col_id]
|
|
68
|
-
# TODO: check for column being dropped
|
|
69
70
|
|
|
70
71
|
# the errortype/-msg properties of a read-validated media column need to be extracted from the DataRow
|
|
71
72
|
if (
|
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -239,7 +239,6 @@ class ColumnRef(Expr):
|
|
|
239
239
|
return helper
|
|
240
240
|
|
|
241
241
|
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
242
|
-
# return None if self.perform_validation else self.col.sa_col
|
|
243
242
|
if self.perform_validation:
|
|
244
243
|
return None
|
|
245
244
|
# we need to reestablish that we have the correct Column instance, there could have been a metadata
|
|
@@ -248,13 +247,10 @@ class ColumnRef(Expr):
|
|
|
248
247
|
# perform runtime checks and update state
|
|
249
248
|
tv = self.tbl_version.get()
|
|
250
249
|
assert tv.is_validated
|
|
250
|
+
# we can assume at this point during query execution that the column exists
|
|
251
|
+
assert self.col_id in tv.cols_by_id
|
|
251
252
|
self.col = tv.cols_by_id[self.col_id]
|
|
252
253
|
assert self.col.tbl is tv
|
|
253
|
-
# TODO: check for column being dropped
|
|
254
|
-
# print(
|
|
255
|
-
# f'ColumnRef.sql_expr: tbl={tv.id}:{tv.effective_version} sa_tbl={id(self.col.tbl.store_tbl.sa_tbl):x} '
|
|
256
|
-
# f'tv={id(tv):x}'
|
|
257
|
-
# )
|
|
258
254
|
return self.col.sa_col
|
|
259
255
|
|
|
260
256
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
@@ -315,6 +311,11 @@ class ColumnRef(Expr):
|
|
|
315
311
|
'perform_validation': self.perform_validation,
|
|
316
312
|
}
|
|
317
313
|
|
|
314
|
+
@classmethod
|
|
315
|
+
def get_column_id(cls, d: dict) -> catalog.QColumnId:
|
|
316
|
+
tbl_id, col_id = UUID(d['tbl_id']), d['col_id']
|
|
317
|
+
return catalog.QColumnId(tbl_id, col_id)
|
|
318
|
+
|
|
318
319
|
@classmethod
|
|
319
320
|
def get_column(cls, d: dict) -> catalog.Column:
|
|
320
321
|
tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
|
pixeltable/exprs/expr.py
CHANGED
|
@@ -394,17 +394,17 @@ class Expr(abc.ABC):
|
|
|
394
394
|
return {tbl_id for e in exprs_ for tbl_id in e.tbl_ids()}
|
|
395
395
|
|
|
396
396
|
@classmethod
|
|
397
|
-
def
|
|
397
|
+
def get_refd_column_ids(cls, expr_dict: dict[str, Any]) -> set[catalog.QColumnId]:
|
|
398
398
|
"""Return Columns referenced by expr_dict."""
|
|
399
|
-
result:
|
|
399
|
+
result: set[catalog.QColumnId] = set()
|
|
400
400
|
assert '_classname' in expr_dict
|
|
401
401
|
from .column_ref import ColumnRef
|
|
402
402
|
|
|
403
403
|
if expr_dict['_classname'] == 'ColumnRef':
|
|
404
|
-
result.
|
|
404
|
+
result.add(ColumnRef.get_column_id(expr_dict))
|
|
405
405
|
if 'components' in expr_dict:
|
|
406
406
|
for component_dict in expr_dict['components']:
|
|
407
|
-
result.
|
|
407
|
+
result.update(cls.get_refd_column_ids(component_dict))
|
|
408
408
|
return result
|
|
409
409
|
|
|
410
410
|
def as_literal(self) -> Optional[Expr]:
|
pixeltable/func/__init__.py
CHANGED
|
@@ -5,6 +5,7 @@ from .callable_function import CallableFunction
|
|
|
5
5
|
from .expr_template_function import ExprTemplateFunction
|
|
6
6
|
from .function import Function, InvalidFunction
|
|
7
7
|
from .function_registry import FunctionRegistry
|
|
8
|
+
from .mcp import mcp_udfs
|
|
8
9
|
from .query_template_function import QueryTemplateFunction, query, retrieval_udf
|
|
9
10
|
from .signature import Batch, Parameter, Signature
|
|
10
11
|
from .tools import Tool, ToolChoice, Tools
|
pixeltable/func/mcp.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
4
|
+
|
|
5
|
+
import pixeltable as pxt
|
|
6
|
+
from pixeltable import exceptions as excs, type_system as ts
|
|
7
|
+
from pixeltable.func.signature import Parameter
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import mcp
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def mcp_udfs(url: str) -> list['pxt.func.Function']:
|
|
14
|
+
return asyncio.run(mcp_udfs_async(url))
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def mcp_udfs_async(url: str) -> list['pxt.func.Function']:
|
|
18
|
+
import mcp
|
|
19
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
20
|
+
|
|
21
|
+
list_tools_result: Optional[mcp.types.ListToolsResult] = None
|
|
22
|
+
async with (
|
|
23
|
+
streamablehttp_client(url) as (read_stream, write_stream, _),
|
|
24
|
+
mcp.ClientSession(read_stream, write_stream) as session,
|
|
25
|
+
):
|
|
26
|
+
await session.initialize()
|
|
27
|
+
list_tools_result = await session.list_tools()
|
|
28
|
+
assert list_tools_result is not None
|
|
29
|
+
|
|
30
|
+
return [mcp_tool_to_udf(url, tool) for tool in list_tools_result.tools]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def mcp_tool_to_udf(url: str, mcp_tool: 'mcp.types.Tool') -> 'pxt.func.Function':
|
|
34
|
+
import mcp
|
|
35
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
36
|
+
|
|
37
|
+
async def invoke(**kwargs: Any) -> str:
|
|
38
|
+
# TODO: Cache session objects rather than creating a new one each time?
|
|
39
|
+
async with (
|
|
40
|
+
streamablehttp_client(url) as (read_stream, write_stream, _),
|
|
41
|
+
mcp.ClientSession(read_stream, write_stream) as session,
|
|
42
|
+
):
|
|
43
|
+
await session.initialize()
|
|
44
|
+
res = await session.call_tool(name=mcp_tool.name, arguments=kwargs)
|
|
45
|
+
# TODO Handle image/audio responses?
|
|
46
|
+
return res.content[0].text # type: ignore[union-attr]
|
|
47
|
+
|
|
48
|
+
if mcp_tool.description is not None:
|
|
49
|
+
invoke.__doc__ = mcp_tool.description
|
|
50
|
+
|
|
51
|
+
input_schema = mcp_tool.inputSchema
|
|
52
|
+
params = {
|
|
53
|
+
name: __mcp_param_to_pxt_type(mcp_tool.name, name, param) for name, param in input_schema['properties'].items()
|
|
54
|
+
}
|
|
55
|
+
required = input_schema.get('required', [])
|
|
56
|
+
|
|
57
|
+
# Ensure that any params not appearing in `required` are nullable.
|
|
58
|
+
# (A required param might or might not be nullable, since its type might be an 'anyOf' containing a null.)
|
|
59
|
+
for name in params.keys() - required:
|
|
60
|
+
params[name] = params[name].copy(nullable=True)
|
|
61
|
+
|
|
62
|
+
signature = pxt.func.Signature(
|
|
63
|
+
return_type=ts.StringType(), # Return type is always string
|
|
64
|
+
parameters=[Parameter(name, col_type, inspect.Parameter.KEYWORD_ONLY) for name, col_type in params.items()],
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
return pxt.func.CallableFunction(signatures=[signature], py_fns=[invoke], self_name=mcp_tool.name)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def __mcp_param_to_pxt_type(tool_name: str, name: str, param: dict[str, Any]) -> ts.ColumnType:
|
|
71
|
+
pxt_type = ts.ColumnType.from_json_schema(param)
|
|
72
|
+
if pxt_type is None:
|
|
73
|
+
raise excs.Error(f'Unknown type schema for MCP parameter {name!r} of tool {tool_name!r}: {param}')
|
|
74
|
+
return pxt_type
|
pixeltable/func/query_template_function.py
CHANGED
|
@@ -157,11 +157,13 @@ def retrieval_udf(
|
|
|
157
157
|
"""
|
|
158
158
|
# Argument validation
|
|
159
159
|
col_refs: list[exprs.ColumnRef]
|
|
160
|
+
# TODO: get rid of references to ColumnRef internals and replace instead with a public interface
|
|
161
|
+
col_names = table.columns()
|
|
160
162
|
if parameters is None:
|
|
161
|
-
col_refs = [table[col_name] for col_name in
|
|
163
|
+
col_refs = [table[col_name] for col_name in col_names if not table[col_name].col.is_computed]
|
|
162
164
|
else:
|
|
163
165
|
for param in parameters:
|
|
164
|
-
if isinstance(param, str) and param not in
|
|
166
|
+
if isinstance(param, str) and param not in col_names:
|
|
165
167
|
raise excs.Error(f'The specified parameter {param!r} is not a column of the table {table._path()!r}')
|
|
166
168
|
col_refs = [table[param] if isinstance(param, str) else param for param in parameters]
|
|
167
169
|
|
pixeltable/func/tools.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
|
2
3
|
|
|
3
4
|
import pydantic
|
|
4
5
|
|
|
5
|
-
from pixeltable import exceptions as excs
|
|
6
|
+
from pixeltable import exceptions as excs, type_system as ts
|
|
6
7
|
|
|
7
8
|
from .function import Function
|
|
8
9
|
from .signature import Parameter
|
|
@@ -69,7 +70,9 @@ class Tool(pydantic.BaseModel):
|
|
|
69
70
|
return _extract_float_tool_arg(kwargs, param_name=param.name)
|
|
70
71
|
if param.col_type.is_bool_type():
|
|
71
72
|
return _extract_bool_tool_arg(kwargs, param_name=param.name)
|
|
72
|
-
|
|
73
|
+
if param.col_type.is_json_type():
|
|
74
|
+
return _extract_json_tool_arg(kwargs, param_name=param.name)
|
|
75
|
+
raise AssertionError(param.col_type)
|
|
73
76
|
|
|
74
77
|
|
|
75
78
|
class ToolChoice(pydantic.BaseModel):
|
|
@@ -137,6 +140,13 @@ def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[
|
|
|
137
140
|
return _extract_arg(bool, kwargs, param_name)
|
|
138
141
|
|
|
139
142
|
|
|
143
|
+
@udf
|
|
144
|
+
def _extract_json_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[ts.Json]:
|
|
145
|
+
if param_name in kwargs:
|
|
146
|
+
return json.loads(kwargs[param_name])
|
|
147
|
+
return None
|
|
148
|
+
|
|
149
|
+
|
|
140
150
|
T = TypeVar('T')
|
|
141
151
|
|
|
142
152
|
|
pixeltable/func/udf.py
CHANGED
|
@@ -262,7 +262,7 @@ def from_table(
|
|
|
262
262
|
"""
|
|
263
263
|
from pixeltable import exprs
|
|
264
264
|
|
|
265
|
-
ancestors = [tbl, *tbl.
|
|
265
|
+
ancestors = [tbl, *tbl._get_base_tables()]
|
|
266
266
|
ancestors.reverse() # We must traverse the ancestors in order from base to derived
|
|
267
267
|
|
|
268
268
|
subst: dict[exprs.Expr, exprs.Expr] = {}
|
|
@@ -297,7 +297,7 @@ def from_table(
|
|
|
297
297
|
|
|
298
298
|
if description is None:
|
|
299
299
|
# Default description is the table comment
|
|
300
|
-
description = tbl.
|
|
300
|
+
description = tbl._get_comment()
|
|
301
301
|
if len(description) == 0:
|
|
302
302
|
description = f"UDF for table '{tbl._name}'"
|
|
303
303
|
|
pixeltable/functions/__init__.py
CHANGED
|
pixeltable/functions/groq.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
|
|
3
|
+
that wrap various endpoints from the Groq API. In order to use them, you must
|
|
4
|
+
first `pip install groq` and configure your Groq credentials, as described in
|
|
5
|
+
the [Working with Groq](https://pixeltable.readme.io/docs/working-with-groq) tutorial.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
9
|
+
|
|
10
|
+
import pixeltable as pxt
|
|
11
|
+
from pixeltable import exprs
|
|
12
|
+
from pixeltable.env import Env, register_client
|
|
13
|
+
from pixeltable.utils.code import local_public_names
|
|
14
|
+
|
|
15
|
+
from .openai import _openai_response_to_pxt_tool_calls
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
import groq
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_client('groq')
|
|
22
|
+
def _(api_key: str) -> 'groq.AsyncGroq':
|
|
23
|
+
import groq
|
|
24
|
+
|
|
25
|
+
return groq.AsyncGroq(api_key=api_key)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _groq_client() -> 'groq.AsyncGroq':
|
|
29
|
+
return Env.get().get_client('groq')
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@pxt.udf(resource_pool='request-rate:groq')
|
|
33
|
+
async def chat_completions(
|
|
34
|
+
messages: list[dict[str, str]],
|
|
35
|
+
*,
|
|
36
|
+
model: str,
|
|
37
|
+
model_kwargs: Optional[dict[str, Any]] = None,
|
|
38
|
+
tools: Optional[list[dict[str, Any]]] = None,
|
|
39
|
+
tool_choice: Optional[dict[str, Any]] = None,
|
|
40
|
+
) -> dict:
|
|
41
|
+
"""
|
|
42
|
+
Chat Completion API.
|
|
43
|
+
|
|
44
|
+
Equivalent to the Groq `chat/completions` API endpoint.
|
|
45
|
+
For additional details, see: <https://console.groq.com/docs/api-reference#chat-create>
|
|
46
|
+
|
|
47
|
+
Request throttling:
|
|
48
|
+
Applies the rate limit set in the config (section `groq`, key `rate_limit`). If no rate
|
|
49
|
+
limit is configured, uses a default of 600 RPM.
|
|
50
|
+
|
|
51
|
+
__Requirements:__
|
|
52
|
+
|
|
53
|
+
- `pip install groq`
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
messages: A list of messages comprising the conversation so far.
|
|
57
|
+
model: ID of the model to use. (See overview here: <https://console.groq.com/docs/models>)
|
|
58
|
+
model_kwargs: Additional keyword args for the Groq `chat/completions` API.
|
|
59
|
+
For details on the available parameters, see: <https://console.groq.com/docs/api-reference#chat-create>
|
|
60
|
+
|
|
61
|
+
Returns:
|
|
62
|
+
A dictionary containing the response and other metadata.
|
|
63
|
+
|
|
64
|
+
Examples:
|
|
65
|
+
Add a computed column that applies the model `llama3-8b-8192`
|
|
66
|
+
to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
|
|
67
|
+
|
|
68
|
+
>>> messages = [{'role': 'user', 'content': tbl.prompt}]
|
|
69
|
+
... tbl.add_computed_column(response=chat_completions(messages, model='llama3-8b-8192'))
|
|
70
|
+
"""
|
|
71
|
+
if model_kwargs is None:
|
|
72
|
+
model_kwargs = {}
|
|
73
|
+
|
|
74
|
+
Env.get().require_package('groq')
|
|
75
|
+
|
|
76
|
+
if tools is not None:
|
|
77
|
+
model_kwargs['tools'] = [{'type': 'function', 'function': tool} for tool in tools]
|
|
78
|
+
|
|
79
|
+
if tool_choice is not None:
|
|
80
|
+
if tool_choice['auto']:
|
|
81
|
+
model_kwargs['tool_choice'] = 'auto'
|
|
82
|
+
elif tool_choice['required']:
|
|
83
|
+
model_kwargs['tool_choice'] = 'required'
|
|
84
|
+
else:
|
|
85
|
+
assert tool_choice['tool'] is not None
|
|
86
|
+
model_kwargs['tool_choice'] = {'type': 'function', 'function': {'name': tool_choice['tool']}}
|
|
87
|
+
|
|
88
|
+
if tool_choice is not None and not tool_choice['parallel_tool_calls']:
|
|
89
|
+
model_kwargs['parallel_tool_calls'] = False
|
|
90
|
+
|
|
91
|
+
result = await _groq_client().chat.completions.create(
|
|
92
|
+
messages=messages, # type: ignore[arg-type]
|
|
93
|
+
model=model,
|
|
94
|
+
**model_kwargs,
|
|
95
|
+
)
|
|
96
|
+
return result.model_dump()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def invoke_tools(tools: pxt.func.Tools, response: exprs.Expr) -> exprs.InlineDict:
|
|
100
|
+
"""Converts an OpenAI response dict to Pixeltable tool invocation format and calls `tools._invoke()`."""
|
|
101
|
+
return tools._invoke(_openai_response_to_pxt_tool_calls(response))
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
__all__ = local_public_names(__name__)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def __dir__() -> list[str]:
|
|
108
|
+
return __all__
|
pixeltable/functions/huggingface.py
CHANGED
|
@@ -51,7 +51,7 @@ def sentence_transformer(
|
|
|
51
51
|
"""
|
|
52
52
|
env.Env.get().require_package('sentence_transformers')
|
|
53
53
|
device = resolve_torch_device('auto')
|
|
54
|
-
from sentence_transformers import SentenceTransformer
|
|
54
|
+
from sentence_transformers import SentenceTransformer
|
|
55
55
|
|
|
56
56
|
# specifying the device, moves the model to device (gpu:cuda/mps, cpu)
|
|
57
57
|
model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
|
|
@@ -170,7 +170,7 @@ def clip(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Flo
|
|
|
170
170
|
env.Env.get().require_package('transformers')
|
|
171
171
|
device = resolve_torch_device('auto')
|
|
172
172
|
import torch
|
|
173
|
-
from transformers import CLIPModel, CLIPProcessor
|
|
173
|
+
from transformers import CLIPModel, CLIPProcessor
|
|
174
174
|
|
|
175
175
|
model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
|
|
176
176
|
processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)
|
|
@@ -395,19 +395,21 @@ def speech2text_for_conditional_generation(audio: pxt.Audio, *, model_id: str, l
|
|
|
395
395
|
device = resolve_torch_device('auto', allow_mps=False) # Doesn't seem to work on 'mps'; use 'cpu' instead
|
|
396
396
|
import torch
|
|
397
397
|
import torchaudio # type: ignore[import-untyped]
|
|
398
|
-
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
|
398
|
+
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor, Speech2TextTokenizer
|
|
399
399
|
|
|
400
400
|
model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
|
|
401
401
|
processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
|
|
402
|
+
tokenizer = processor.tokenizer
|
|
402
403
|
assert isinstance(processor, Speech2TextProcessor)
|
|
404
|
+
assert isinstance(tokenizer, Speech2TextTokenizer)
|
|
403
405
|
|
|
404
|
-
if language is not None and language not in
|
|
406
|
+
if language is not None and language not in tokenizer.lang_code_to_id:
|
|
405
407
|
raise excs.Error(
|
|
406
408
|
f"Language code '{language}' is not supported by the model '{model_id}'. "
|
|
407
|
-
f'Supported languages are: {list(
|
|
409
|
+
f'Supported languages are: {list(tokenizer.lang_code_to_id.keys())}'
|
|
408
410
|
)
|
|
409
411
|
|
|
410
|
-
forced_bos_token_id: Optional[int] = None if language is None else
|
|
412
|
+
forced_bos_token_id: Optional[int] = None if language is None else tokenizer.lang_code_to_id[language]
|
|
411
413
|
|
|
412
414
|
# Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
|
|
413
415
|
model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
|
pixeltable/functions/mistralai.py
CHANGED
|
@@ -5,7 +5,7 @@ first `pip install mistralai` and configure your Mistral AI credentials, as desc
|
|
|
5
5
|
the [Working with Mistral AI](https://pixeltable.readme.io/docs/working-with-mistralai) tutorial.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
from typing import TYPE_CHECKING, Any, Optional
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
|
|
@@ -16,7 +16,7 @@ from pixeltable.func.signature import Batch
|
|
|
16
16
|
from pixeltable.utils.code import local_public_names
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
19
|
-
import mistralai
|
|
19
|
+
import mistralai
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
@register_client('mistral')
|
|
@@ -54,8 +54,6 @@ async def chat_completions(
|
|
|
54
54
|
model_kwargs: Additional keyword args for the Mistral `chat/completions` API.
|
|
55
55
|
For details on the available parameters, see: <https://docs.mistral.ai/api/#tag/chat>
|
|
56
56
|
|
|
57
|
-
For details on the other parameters, see: <https://docs.mistral.ai/api/#tag/chat>
|
|
58
|
-
|
|
59
57
|
Returns:
|
|
60
58
|
A dictionary containing the response and other metadata.
|
|
61
59
|
|
|
@@ -156,15 +154,6 @@ def _(model: str) -> ts.ArrayType:
|
|
|
156
154
|
return ts.ArrayType((dimensions,), dtype=ts.FloatType())
|
|
157
155
|
|
|
158
156
|
|
|
159
|
-
_T = TypeVar('_T')
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
def _opt(arg: Optional[_T]) -> Union[_T, 'mistralai.types.basemodel.Unset']:
|
|
163
|
-
from mistralai.types import UNSET
|
|
164
|
-
|
|
165
|
-
return arg if arg is not None else UNSET
|
|
166
|
-
|
|
167
|
-
|
|
168
157
|
__all__ = local_public_names(__name__)
|
|
169
158
|
|
|
170
159
|
|
pixeltable/functions/openai.py
CHANGED
|
@@ -205,12 +205,7 @@ async def speech(input: str, *, model: str, voice: str, model_kwargs: Optional[d
|
|
|
205
205
|
if model_kwargs is None:
|
|
206
206
|
model_kwargs = {}
|
|
207
207
|
|
|
208
|
-
content = await _openai_client().audio.speech.create(
|
|
209
|
-
input=input,
|
|
210
|
-
model=model,
|
|
211
|
-
voice=voice, # type: ignore
|
|
212
|
-
**model_kwargs,
|
|
213
|
-
)
|
|
208
|
+
content = await _openai_client().audio.speech.create(input=input, model=model, voice=voice, **model_kwargs)
|
|
214
209
|
ext = model_kwargs.get('response_format', 'mp3')
|
|
215
210
|
output_filename = str(env.Env.get().tmp_dir / f'{uuid.uuid4()}.{ext}')
|
|
216
211
|
content.write_to_file(output_filename)
|
pixeltable/functions/replicate.py
CHANGED
|
@@ -12,7 +12,7 @@ from pixeltable.env import Env, register_client
|
|
|
12
12
|
from pixeltable.utils.code import local_public_names
|
|
13
13
|
|
|
14
14
|
if TYPE_CHECKING:
|
|
15
|
-
import replicate
|
|
15
|
+
import replicate
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
@register_client('replicate')
|
|
@@ -27,7 +27,7 @@ def _replicate_client() -> 'replicate.Client':
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
@pxt.udf(resource_pool='request-rate:replicate')
|
|
30
|
-
async def run(input: dict[str, Any], *, ref: str) ->
|
|
30
|
+
async def run(input: dict[str, Any], *, ref: str) -> pxt.Json:
|
|
31
31
|
"""
|
|
32
32
|
Run a model on Replicate.
|
|
33
33
|
|
pixeltable/functions/util.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import PIL.Image
|
|
2
2
|
|
|
3
|
+
from pixeltable.config import Config
|
|
3
4
|
from pixeltable.env import Env
|
|
4
5
|
|
|
5
6
|
|
|
@@ -7,10 +8,14 @@ def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
|
|
|
7
8
|
Env.get().require_package('torch')
|
|
8
9
|
import torch
|
|
9
10
|
|
|
11
|
+
mps_enabled = Config.get().get_bool_value('enable_mps')
|
|
12
|
+
if mps_enabled is None:
|
|
13
|
+
mps_enabled = True # Default to True if not set in config
|
|
14
|
+
|
|
10
15
|
if device == 'auto':
|
|
11
16
|
if torch.cuda.is_available():
|
|
12
17
|
return 'cuda'
|
|
13
|
-
if allow_mps and torch.backends.mps.is_available():
|
|
18
|
+
if mps_enabled and allow_mps and torch.backends.mps.is_available():
|
|
14
19
|
return 'mps'
|
|
15
20
|
return 'cpu'
|
|
16
21
|
return device
|
pixeltable/globals.py
CHANGED
|
@@ -428,8 +428,6 @@ def get_table(path: str) -> catalog.Table:
|
|
|
428
428
|
"""
|
|
429
429
|
path_obj = catalog.Path(path)
|
|
430
430
|
tbl = Catalog.get().get_table(path_obj)
|
|
431
|
-
tv = tbl._tbl_version.get()
|
|
432
|
-
_logger.debug(f'get_table(): tbl={tv.id}:{tv.effective_version} sa_tbl={id(tv.store_tbl.sa_tbl):x} tv={id(tv):x}')
|
|
433
431
|
return tbl
|
|
434
432
|
|
|
435
433
|
|
pixeltable/io/external_store.py
CHANGED
|
@@ -202,7 +202,7 @@ class Project(ExternalStore, abc.ABC):
|
|
|
202
202
|
resolved_col_mapping: dict[Column, str] = {}
|
|
203
203
|
|
|
204
204
|
# Validate names
|
|
205
|
-
t_cols = set(table.
|
|
205
|
+
t_cols = set(table._get_schema().keys())
|
|
206
206
|
for t_col, ext_col in col_mapping.items():
|
|
207
207
|
if t_col not in t_cols:
|
|
208
208
|
if is_user_specified_col_mapping:
|
|
@@ -225,7 +225,7 @@ class Project(ExternalStore, abc.ABC):
|
|
|
225
225
|
assert isinstance(col_ref, exprs.ColumnRef)
|
|
226
226
|
resolved_col_mapping[col_ref.col] = ext_col
|
|
227
227
|
# Validate column specs
|
|
228
|
-
t_col_types = table.
|
|
228
|
+
t_col_types = table._get_schema()
|
|
229
229
|
for t_col, ext_col in col_mapping.items():
|
|
230
230
|
t_col_type = t_col_types[t_col]
|
|
231
231
|
if ext_col in export_cols:
|
pixeltable/io/label_studio.py
CHANGED
|
@@ -412,8 +412,8 @@ class LabelStudioProject(Project):
|
|
|
412
412
|
# TODO(aaron-siegel): Simplify this once propagation is properly implemented in batch_update
|
|
413
413
|
ancestor = t
|
|
414
414
|
while local_annotations_col not in ancestor._tbl_version.get().cols:
|
|
415
|
-
assert ancestor.
|
|
416
|
-
ancestor = ancestor.
|
|
415
|
+
assert ancestor._get_base_table is not None
|
|
416
|
+
ancestor = ancestor._get_base_table()
|
|
417
417
|
update_status = ancestor.batch_update(updates)
|
|
418
418
|
env.Env.get().console_logger.info(f'Updated annotation(s) from {len(updates)} task(s) in {self}.')
|
|
419
419
|
return SyncStatus(pxt_rows_updated=update_status.num_rows, num_excs=update_status.num_excs)
|
|
@@ -560,7 +560,7 @@ class LabelStudioProject(Project):
|
|
|
560
560
|
|
|
561
561
|
if name is None:
|
|
562
562
|
# Create a default name that's unique to the table
|
|
563
|
-
all_stores = t.external_stores
|
|
563
|
+
all_stores = t.external_stores()
|
|
564
564
|
n = 0
|
|
565
565
|
while f'ls_project_{n}' in all_stores:
|
|
566
566
|
n += 1
|
|
@@ -576,7 +576,7 @@ class LabelStudioProject(Project):
|
|
|
576
576
|
local_annotations_column = ANNOTATIONS_COLUMN
|
|
577
577
|
else:
|
|
578
578
|
local_annotations_column = next(k for k, v in col_mapping.items() if v == ANNOTATIONS_COLUMN)
|
|
579
|
-
if local_annotations_column not in t.
|
|
579
|
+
if local_annotations_column not in t._get_schema():
|
|
580
580
|
t.add_columns({local_annotations_column: ts.Json})
|
|
581
581
|
|
|
582
582
|
resolved_col_mapping = cls.validate_columns(
|
pixeltable/io/table_data_conduit.py
CHANGED
|
@@ -101,7 +101,7 @@ class TableDataConduit:
|
|
|
101
101
|
def add_table_info(self, table: pxt.Table) -> None:
|
|
102
102
|
"""Add information about the table into which we are inserting data"""
|
|
103
103
|
assert isinstance(table, pxt.Table)
|
|
104
|
-
self.pxt_schema = table.
|
|
104
|
+
self.pxt_schema = table._get_schema()
|
|
105
105
|
self.pxt_pk = table._tbl_version.get().primary_key
|
|
106
106
|
for col in table._tbl_version_path.columns():
|
|
107
107
|
if col.is_required_for_insert:
|
pixeltable/metadata/__init__.py
CHANGED
|
@@ -18,7 +18,7 @@ _console_logger = ConsoleLogger(logging.getLogger('pixeltable'))
|
|
|
18
18
|
_logger = logging.getLogger('pixeltable')
|
|
19
19
|
|
|
20
20
|
# current version of the metadata; this is incremented whenever the metadata schema changes
|
|
21
|
-
VERSION =
|
|
21
|
+
VERSION = 38
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def create_system_info(engine: sql.engine.Engine) -> None:
|
pixeltable/metadata/converters/convert_37.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from uuid import UUID
|
|
2
|
+
|
|
3
|
+
import sqlalchemy as sql
|
|
4
|
+
|
|
5
|
+
from pixeltable.metadata import register_converter
|
|
6
|
+
from pixeltable.metadata.converters.util import convert_table_md
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@register_converter(version=37)
|
|
10
|
+
def _(engine: sql.engine.Engine) -> None:
|
|
11
|
+
convert_table_md(engine, table_md_updater=__update_table_md)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def __update_table_md(table_md: dict, _: UUID) -> None:
|
|
15
|
+
table_md['view_sn'] = 0
|
pixeltable/metadata/notes.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# rather than as a comment, so that the existence of a description can be enforced by
|
|
3
3
|
# the unit tests when new versions are added.
|
|
4
4
|
VERSION_NOTES = {
|
|
5
|
+
38: 'Added TableMd.view_sn',
|
|
5
6
|
37: 'Add support for the sample() method on DataFrames',
|
|
6
7
|
36: 'Added Table.lock_dummy',
|
|
7
8
|
35: 'Track reference_tbl in ColumnRef',
|
pixeltable/metadata/schema.py
CHANGED
|
@@ -177,6 +177,11 @@ class TableMd:
|
|
|
177
177
|
# - every row is assigned a unique and immutable rowid on insertion
|
|
178
178
|
next_row_id: int
|
|
179
179
|
|
|
180
|
+
# sequence number to track changes in the set of mutable views of this table (ie, this table = the view base)
|
|
181
|
+
# - incremented for each add/drop of a mutable view
|
|
182
|
+
# - only maintained for mutable tables
|
|
183
|
+
view_sn: int
|
|
184
|
+
|
|
180
185
|
# Metadata format for external stores:
|
|
181
186
|
# {'class': 'pixeltable.io.label_studio.LabelStudioProject', 'md': {'project_id': 3}}
|
|
182
187
|
external_stores: list[dict[str, Any]]
|