pixeltable 0.4.0rc3__py3-none-any.whl → 0.4.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +23 -5
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +5 -3
- pixeltable/catalog/catalog.py +1318 -404
- pixeltable/catalog/column.py +186 -115
- pixeltable/catalog/dir.py +1 -2
- pixeltable/catalog/globals.py +11 -43
- pixeltable/catalog/insertable_table.py +167 -79
- pixeltable/catalog/path.py +61 -23
- pixeltable/catalog/schema_object.py +9 -10
- pixeltable/catalog/table.py +626 -308
- pixeltable/catalog/table_metadata.py +101 -0
- pixeltable/catalog/table_version.py +713 -569
- pixeltable/catalog/table_version_handle.py +37 -6
- pixeltable/catalog/table_version_path.py +42 -29
- pixeltable/catalog/tbl_ops.py +50 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +108 -94
- pixeltable/config.py +128 -22
- pixeltable/dataframe.py +188 -100
- pixeltable/env.py +407 -136
- pixeltable/exceptions.py +6 -0
- pixeltable/exec/__init__.py +3 -0
- pixeltable/exec/aggregation_node.py +7 -8
- pixeltable/exec/cache_prefetch_node.py +83 -110
- pixeltable/exec/cell_materialization_node.py +231 -0
- pixeltable/exec/cell_reconstruction_node.py +135 -0
- pixeltable/exec/component_iteration_node.py +4 -3
- pixeltable/exec/data_row_batch.py +8 -65
- pixeltable/exec/exec_context.py +16 -4
- pixeltable/exec/exec_node.py +13 -36
- pixeltable/exec/expr_eval/evaluators.py +7 -6
- pixeltable/exec/expr_eval/expr_eval_node.py +27 -12
- pixeltable/exec/expr_eval/globals.py +8 -5
- pixeltable/exec/expr_eval/row_buffer.py +1 -2
- pixeltable/exec/expr_eval/schedulers.py +190 -30
- pixeltable/exec/globals.py +32 -0
- pixeltable/exec/in_memory_data_node.py +18 -18
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +16 -9
- pixeltable/exec/sql_node.py +206 -101
- pixeltable/exprs/__init__.py +1 -1
- pixeltable/exprs/arithmetic_expr.py +27 -22
- pixeltable/exprs/array_slice.py +3 -3
- pixeltable/exprs/column_property_ref.py +34 -30
- pixeltable/exprs/column_ref.py +92 -96
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +5 -4
- pixeltable/exprs/data_row.py +152 -55
- pixeltable/exprs/expr.py +62 -43
- pixeltable/exprs/expr_dict.py +3 -3
- pixeltable/exprs/expr_set.py +17 -10
- pixeltable/exprs/function_call.py +75 -37
- pixeltable/exprs/globals.py +1 -2
- pixeltable/exprs/in_predicate.py +4 -4
- pixeltable/exprs/inline_expr.py +10 -27
- pixeltable/exprs/is_null.py +1 -3
- pixeltable/exprs/json_mapper.py +8 -8
- pixeltable/exprs/json_path.py +56 -22
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +2 -2
- pixeltable/exprs/object_ref.py +2 -2
- pixeltable/exprs/row_builder.py +127 -53
- pixeltable/exprs/rowid_ref.py +8 -12
- pixeltable/exprs/similarity_expr.py +50 -25
- pixeltable/exprs/sql_element_cache.py +4 -4
- pixeltable/exprs/string_op.py +5 -5
- pixeltable/exprs/type_cast.py +3 -5
- pixeltable/func/__init__.py +1 -0
- pixeltable/func/aggregate_function.py +8 -8
- pixeltable/func/callable_function.py +9 -9
- pixeltable/func/expr_template_function.py +10 -10
- pixeltable/func/function.py +18 -20
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/globals.py +2 -3
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +20 -18
- pixeltable/func/signature.py +43 -16
- pixeltable/func/tools.py +23 -13
- pixeltable/func/udf.py +18 -20
- pixeltable/functions/__init__.py +6 -0
- pixeltable/functions/anthropic.py +93 -33
- pixeltable/functions/audio.py +114 -10
- pixeltable/functions/bedrock.py +13 -6
- pixeltable/functions/date.py +1 -1
- pixeltable/functions/deepseek.py +20 -9
- pixeltable/functions/fireworks.py +2 -2
- pixeltable/functions/gemini.py +28 -11
- pixeltable/functions/globals.py +13 -13
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1046 -23
- pixeltable/functions/image.py +9 -18
- pixeltable/functions/llama_cpp.py +23 -8
- pixeltable/functions/math.py +3 -4
- pixeltable/functions/mistralai.py +4 -15
- pixeltable/functions/ollama.py +16 -9
- pixeltable/functions/openai.py +104 -82
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +2 -2
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +21 -28
- pixeltable/functions/timestamp.py +13 -14
- pixeltable/functions/together.py +4 -6
- pixeltable/functions/twelvelabs.py +92 -0
- pixeltable/functions/util.py +6 -1
- pixeltable/functions/video.py +1388 -106
- pixeltable/functions/vision.py +7 -7
- pixeltable/functions/whisper.py +15 -7
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/{ext/functions → functions}/yolox.py +2 -4
- pixeltable/globals.py +332 -105
- pixeltable/index/base.py +13 -22
- pixeltable/index/btree.py +23 -22
- pixeltable/index/embedding_index.py +32 -44
- pixeltable/io/__init__.py +4 -2
- pixeltable/io/datarows.py +7 -6
- pixeltable/io/external_store.py +49 -77
- pixeltable/io/fiftyone.py +11 -11
- pixeltable/io/globals.py +29 -28
- pixeltable/io/hf_datasets.py +17 -9
- pixeltable/io/label_studio.py +70 -66
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +12 -11
- pixeltable/io/parquet.py +13 -93
- pixeltable/io/table_data_conduit.py +71 -47
- pixeltable/io/utils.py +3 -3
- pixeltable/iterators/__init__.py +2 -1
- pixeltable/iterators/audio.py +21 -11
- pixeltable/iterators/document.py +116 -55
- pixeltable/iterators/image.py +5 -2
- pixeltable/iterators/video.py +293 -13
- pixeltable/metadata/__init__.py +4 -2
- pixeltable/metadata/converters/convert_18.py +2 -2
- pixeltable/metadata/converters/convert_19.py +2 -2
- pixeltable/metadata/converters/convert_20.py +2 -2
- pixeltable/metadata/converters/convert_21.py +2 -2
- pixeltable/metadata/converters/convert_22.py +2 -2
- pixeltable/metadata/converters/convert_24.py +2 -2
- pixeltable/metadata/converters/convert_25.py +2 -2
- pixeltable/metadata/converters/convert_26.py +2 -2
- pixeltable/metadata/converters/convert_29.py +4 -4
- pixeltable/metadata/converters/convert_34.py +2 -2
- pixeltable/metadata/converters/convert_36.py +2 -2
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/util.py +13 -12
- pixeltable/metadata/notes.py +4 -0
- pixeltable/metadata/schema.py +79 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +274 -223
- pixeltable/share/__init__.py +1 -1
- pixeltable/share/packager.py +259 -129
- pixeltable/share/protocol/__init__.py +34 -0
- pixeltable/share/protocol/common.py +170 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +109 -0
- pixeltable/share/publish.py +213 -57
- pixeltable/store.py +238 -175
- pixeltable/type_system.py +104 -63
- pixeltable/utils/__init__.py +2 -3
- pixeltable/utils/arrow.py +108 -13
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +305 -0
- pixeltable/utils/code.py +3 -3
- pixeltable/utils/console_output.py +4 -1
- pixeltable/utils/coroutine.py +6 -23
- pixeltable/utils/dbms.py +31 -5
- pixeltable/utils/description_helper.py +4 -5
- pixeltable/utils/documents.py +5 -6
- pixeltable/utils/exception_handler.py +7 -30
- pixeltable/utils/filecache.py +6 -6
- pixeltable/utils/formatter.py +4 -6
- pixeltable/utils/gcs_store.py +283 -0
- pixeltable/utils/http_server.py +2 -3
- pixeltable/utils/iceberg.py +1 -2
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +88 -0
- pixeltable/utils/local_store.py +316 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +528 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +5 -6
- pixeltable/utils/s3_store.py +392 -0
- pixeltable-0.4.20.dist-info/METADATA +587 -0
- pixeltable-0.4.20.dist-info/RECORD +218 -0
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info}/WHEEL +1 -1
- pixeltable-0.4.20.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/ext/__init__.py +0 -17
- pixeltable/ext/functions/__init__.py +0 -11
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/utils/media_store.py +0 -77
- pixeltable/utils/s3.py +0 -17
- pixeltable/utils/sample.py +0 -25
- pixeltable-0.4.0rc3.dist-info/METADATA +0 -435
- pixeltable-0.4.0rc3.dist-info/RECORD +0 -189
- pixeltable-0.4.0rc3.dist-info/entry_points.txt +0 -3
- {pixeltable-0.4.0rc3.dist-info → pixeltable-0.4.20.dist-info/licenses}/LICENSE +0 -0
pixeltable/func/signature.py
CHANGED
|
@@ -4,7 +4,7 @@ import dataclasses
|
|
|
4
4
|
import inspect
|
|
5
5
|
import logging
|
|
6
6
|
import typing
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, ClassVar
|
|
8
8
|
|
|
9
9
|
import pixeltable.exceptions as excs
|
|
10
10
|
import pixeltable.type_system as ts
|
|
@@ -18,11 +18,11 @@ _logger = logging.getLogger('pixeltable')
|
|
|
18
18
|
@dataclasses.dataclass
|
|
19
19
|
class Parameter:
|
|
20
20
|
name: str
|
|
21
|
-
col_type:
|
|
21
|
+
col_type: ts.ColumnType | None # None for variable parameters
|
|
22
22
|
kind: inspect._ParameterKind
|
|
23
23
|
# for some reason, this needs to precede is_batched in the dataclass definition,
|
|
24
24
|
# otherwise Python complains that an argument with a default is followed by an argument without a default
|
|
25
|
-
default:
|
|
25
|
+
default: 'exprs.Literal' | None = None # default value for the parameter
|
|
26
26
|
is_batched: bool = False # True if the parameter is a batched parameter (eg, Batch[dict])
|
|
27
27
|
|
|
28
28
|
def __post_init__(self) -> None:
|
|
@@ -84,8 +84,28 @@ class Signature:
|
|
|
84
84
|
"""
|
|
85
85
|
|
|
86
86
|
SPECIAL_PARAM_NAMES: ClassVar[list[str]] = ['group_by', 'order_by']
|
|
87
|
-
|
|
88
|
-
|
|
87
|
+
SYSTEM_PARAM_NAMES: ClassVar[list[str]] = ['_runtime_ctx']
|
|
88
|
+
|
|
89
|
+
return_type: ts.ColumnType
|
|
90
|
+
is_batched: bool
|
|
91
|
+
parameters: dict[str, Parameter] # name -> Parameter
|
|
92
|
+
parameters_by_pos: list[Parameter] # ordered by position in the signature
|
|
93
|
+
constant_parameters: list[Parameter] # parameters that are not batched
|
|
94
|
+
batched_parameters: list[Parameter] # parameters that are batched
|
|
95
|
+
required_parameters: list[Parameter] # parameters that do not have a default value
|
|
96
|
+
|
|
97
|
+
# the names of recognized system parameters in the signature; these are excluded from self.parameters
|
|
98
|
+
system_parameters: list[str]
|
|
99
|
+
|
|
100
|
+
py_signature: inspect.Signature
|
|
101
|
+
|
|
102
|
+
def __init__(
|
|
103
|
+
self,
|
|
104
|
+
return_type: ts.ColumnType,
|
|
105
|
+
parameters: list[Parameter],
|
|
106
|
+
is_batched: bool = False,
|
|
107
|
+
system_parameters: list[str] | None = None,
|
|
108
|
+
):
|
|
89
109
|
assert isinstance(return_type, ts.ColumnType)
|
|
90
110
|
self.return_type = return_type
|
|
91
111
|
self.is_batched = is_batched
|
|
@@ -95,6 +115,7 @@ class Signature:
|
|
|
95
115
|
self.constant_parameters = [p for p in parameters if not p.is_batched]
|
|
96
116
|
self.batched_parameters = [p for p in parameters if p.is_batched]
|
|
97
117
|
self.required_parameters = [p for p in parameters if not p.has_default()]
|
|
118
|
+
self.system_parameters = system_parameters if system_parameters is not None else []
|
|
98
119
|
self.py_signature = inspect.Signature([p.to_py_param() for p in self.parameters_by_pos])
|
|
99
120
|
|
|
100
121
|
def get_return_type(self) -> ts.ColumnType:
|
|
@@ -151,7 +172,7 @@ class Signature:
|
|
|
151
172
|
|
|
152
173
|
return True
|
|
153
174
|
|
|
154
|
-
def validate_args(self, bound_args: dict[str,
|
|
175
|
+
def validate_args(self, bound_args: dict[str, 'exprs.Expr' | None], context: str = '') -> None:
|
|
155
176
|
if context:
|
|
156
177
|
context = f' ({context})'
|
|
157
178
|
|
|
@@ -210,11 +231,11 @@ class Signature:
|
|
|
210
231
|
return f'({", ".join(param_strs)}) -> {self.get_return_type()}'
|
|
211
232
|
|
|
212
233
|
@classmethod
|
|
213
|
-
def _infer_type(cls, annotation:
|
|
234
|
+
def _infer_type(cls, annotation: type | None) -> tuple[ts.ColumnType | None, bool | None]:
|
|
214
235
|
"""Returns: (column type, is_batched) or (None, ...) if the type cannot be inferred"""
|
|
215
236
|
if annotation is None:
|
|
216
237
|
return (None, None)
|
|
217
|
-
py_type:
|
|
238
|
+
py_type: type | None = None
|
|
218
239
|
is_batched = False
|
|
219
240
|
if typing.get_origin(annotation) == typing.Annotated:
|
|
220
241
|
type_args = typing.get_args(annotation)
|
|
@@ -231,12 +252,13 @@ class Signature:
|
|
|
231
252
|
@classmethod
|
|
232
253
|
def create_parameters(
|
|
233
254
|
cls,
|
|
234
|
-
py_fn:
|
|
235
|
-
py_params:
|
|
236
|
-
param_types:
|
|
237
|
-
type_substitutions:
|
|
255
|
+
py_fn: Callable | None = None,
|
|
256
|
+
py_params: list[inspect.Parameter] | None = None,
|
|
257
|
+
param_types: list[ts.ColumnType] | None = None,
|
|
258
|
+
type_substitutions: dict | None = None,
|
|
238
259
|
is_cls_method: bool = False,
|
|
239
260
|
) -> list[Parameter]:
|
|
261
|
+
"""Ignores parameters starting with '_'."""
|
|
240
262
|
from pixeltable import exprs
|
|
241
263
|
|
|
242
264
|
assert (py_fn is None) != (py_params is None)
|
|
@@ -251,6 +273,10 @@ class Signature:
|
|
|
251
273
|
for idx, param in enumerate(py_params):
|
|
252
274
|
if is_cls_method and idx == 0:
|
|
253
275
|
continue # skip 'self' or 'cls' parameter
|
|
276
|
+
if param.name in cls.SYSTEM_PARAM_NAMES:
|
|
277
|
+
continue # skip system parameters
|
|
278
|
+
if param.name.startswith('_'):
|
|
279
|
+
raise excs.Error(f"{param.name!r}: parameters starting with '_' are reserved")
|
|
254
280
|
if param.name in cls.SPECIAL_PARAM_NAMES:
|
|
255
281
|
raise excs.Error(f'{param.name!r} is a reserved parameter name')
|
|
256
282
|
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
|
@@ -284,9 +310,9 @@ class Signature:
|
|
|
284
310
|
def create(
|
|
285
311
|
cls,
|
|
286
312
|
py_fn: Callable,
|
|
287
|
-
param_types:
|
|
288
|
-
return_type:
|
|
289
|
-
type_substitutions:
|
|
313
|
+
param_types: list[ts.ColumnType] | None = None,
|
|
314
|
+
return_type: ts.ColumnType | None = None,
|
|
315
|
+
type_substitutions: dict | None = None,
|
|
290
316
|
is_cls_method: bool = False,
|
|
291
317
|
) -> Signature:
|
|
292
318
|
"""Create a signature for the given Callable.
|
|
@@ -308,5 +334,6 @@ class Signature:
|
|
|
308
334
|
raise excs.Error('Cannot infer pixeltable return type')
|
|
309
335
|
else:
|
|
310
336
|
_, return_is_batched = cls._infer_type(sig.return_annotation)
|
|
337
|
+
system_params = [param_name for param_name in sig.parameters if param_name in cls.SYSTEM_PARAM_NAMES]
|
|
311
338
|
|
|
312
|
-
return Signature(return_type, parameters, return_is_batched)
|
|
339
|
+
return Signature(return_type, parameters, return_is_batched, system_parameters=system_params)
|
pixeltable/func/tools.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
|
|
1
|
+
import json
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
|
2
3
|
|
|
3
4
|
import pydantic
|
|
4
5
|
|
|
5
|
-
from pixeltable import exceptions as excs
|
|
6
|
+
from pixeltable import exceptions as excs, type_system as ts
|
|
6
7
|
|
|
7
8
|
from .function import Function
|
|
8
9
|
from .signature import Parameter
|
|
@@ -28,8 +29,8 @@ class Tool(pydantic.BaseModel):
|
|
|
28
29
|
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
29
30
|
|
|
30
31
|
fn: Function
|
|
31
|
-
name:
|
|
32
|
-
description:
|
|
32
|
+
name: str | None = None
|
|
33
|
+
description: str | None = None
|
|
33
34
|
|
|
34
35
|
@property
|
|
35
36
|
def parameters(self) -> dict[str, Parameter]:
|
|
@@ -69,13 +70,15 @@ class Tool(pydantic.BaseModel):
|
|
|
69
70
|
return _extract_float_tool_arg(kwargs, param_name=param.name)
|
|
70
71
|
if param.col_type.is_bool_type():
|
|
71
72
|
return _extract_bool_tool_arg(kwargs, param_name=param.name)
|
|
72
|
-
|
|
73
|
+
if param.col_type.is_json_type():
|
|
74
|
+
return _extract_json_tool_arg(kwargs, param_name=param.name)
|
|
75
|
+
raise AssertionError(param.col_type)
|
|
73
76
|
|
|
74
77
|
|
|
75
78
|
class ToolChoice(pydantic.BaseModel):
|
|
76
79
|
auto: bool
|
|
77
80
|
required: bool
|
|
78
|
-
tool:
|
|
81
|
+
tool: str | None
|
|
79
82
|
parallel_tool_calls: bool
|
|
80
83
|
|
|
81
84
|
|
|
@@ -97,12 +100,12 @@ class Tools(pydantic.BaseModel):
|
|
|
97
100
|
self,
|
|
98
101
|
auto: bool = False,
|
|
99
102
|
required: bool = False,
|
|
100
|
-
tool:
|
|
103
|
+
tool: str | Function | None = None,
|
|
101
104
|
parallel_tool_calls: bool = True,
|
|
102
105
|
) -> ToolChoice:
|
|
103
106
|
if sum([auto, required, tool is not None]) != 1:
|
|
104
107
|
raise excs.Error('Exactly one of `auto`, `required`, or `tool` must be specified.')
|
|
105
|
-
tool_name:
|
|
108
|
+
tool_name: str | None = None
|
|
106
109
|
if tool is not None:
|
|
107
110
|
try:
|
|
108
111
|
tool_obj = next(
|
|
@@ -118,29 +121,36 @@ class Tools(pydantic.BaseModel):
|
|
|
118
121
|
|
|
119
122
|
|
|
120
123
|
@udf
|
|
121
|
-
def _extract_str_tool_arg(kwargs: dict[str, Any], param_name: str) ->
|
|
124
|
+
def _extract_str_tool_arg(kwargs: dict[str, Any], param_name: str) -> str | None:
|
|
122
125
|
return _extract_arg(str, kwargs, param_name)
|
|
123
126
|
|
|
124
127
|
|
|
125
128
|
@udf
|
|
126
|
-
def _extract_int_tool_arg(kwargs: dict[str, Any], param_name: str) ->
|
|
129
|
+
def _extract_int_tool_arg(kwargs: dict[str, Any], param_name: str) -> int | None:
|
|
127
130
|
return _extract_arg(int, kwargs, param_name)
|
|
128
131
|
|
|
129
132
|
|
|
130
133
|
@udf
|
|
131
|
-
def _extract_float_tool_arg(kwargs: dict[str, Any], param_name: str) ->
|
|
134
|
+
def _extract_float_tool_arg(kwargs: dict[str, Any], param_name: str) -> float | None:
|
|
132
135
|
return _extract_arg(float, kwargs, param_name)
|
|
133
136
|
|
|
134
137
|
|
|
135
138
|
@udf
|
|
136
|
-
def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) ->
|
|
139
|
+
def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> bool | None:
|
|
137
140
|
return _extract_arg(bool, kwargs, param_name)
|
|
138
141
|
|
|
139
142
|
|
|
143
|
+
@udf
|
|
144
|
+
def _extract_json_tool_arg(kwargs: dict[str, Any], param_name: str) -> ts.Json | None:
|
|
145
|
+
if param_name in kwargs:
|
|
146
|
+
return json.loads(kwargs[param_name])
|
|
147
|
+
return None
|
|
148
|
+
|
|
149
|
+
|
|
140
150
|
T = TypeVar('T')
|
|
141
151
|
|
|
142
152
|
|
|
143
|
-
def _extract_arg(eval_fn: Callable[[Any], T], kwargs: dict[str, Any], param_name: str) ->
|
|
153
|
+
def _extract_arg(eval_fn: Callable[[Any], T], kwargs: dict[str, Any], param_name: str) -> T | None:
|
|
144
154
|
if param_name in kwargs:
|
|
145
155
|
return eval_fn(kwargs[param_name])
|
|
146
156
|
return None
|
pixeltable/func/udf.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from typing import TYPE_CHECKING, Any, Callable,
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Sequence, overload
|
|
5
5
|
|
|
6
6
|
import pixeltable.exceptions as excs
|
|
7
7
|
import pixeltable.type_system as ts
|
|
@@ -26,12 +26,12 @@ def udf(decorated_fn: Callable) -> CallableFunction: ...
|
|
|
26
26
|
@overload
|
|
27
27
|
def udf(
|
|
28
28
|
*,
|
|
29
|
-
batch_size:
|
|
30
|
-
substitute_fn:
|
|
29
|
+
batch_size: int | None = None,
|
|
30
|
+
substitute_fn: Callable | None = None,
|
|
31
31
|
is_method: bool = False,
|
|
32
32
|
is_property: bool = False,
|
|
33
|
-
resource_pool:
|
|
34
|
-
type_substitutions:
|
|
33
|
+
resource_pool: str | None = None,
|
|
34
|
+
type_substitutions: Sequence[dict] | None = None,
|
|
35
35
|
_force_stored: bool = False,
|
|
36
36
|
) -> Callable[[Callable], CallableFunction]: ...
|
|
37
37
|
|
|
@@ -39,7 +39,7 @@ def udf(
|
|
|
39
39
|
# pxt.udf() called explicitly on a Table:
|
|
40
40
|
@overload
|
|
41
41
|
def udf(
|
|
42
|
-
table: catalog.Table, /, *, return_value: Any = None, description:
|
|
42
|
+
table: catalog.Table, /, *, return_value: Any = None, description: str | None = None
|
|
43
43
|
) -> ExprTemplateFunction: ...
|
|
44
44
|
|
|
45
45
|
|
|
@@ -96,15 +96,15 @@ def udf(*args, **kwargs): # type: ignore[no-untyped-def]
|
|
|
96
96
|
|
|
97
97
|
def make_function(
|
|
98
98
|
decorated_fn: Callable,
|
|
99
|
-
return_type:
|
|
100
|
-
param_types:
|
|
101
|
-
batch_size:
|
|
102
|
-
substitute_fn:
|
|
99
|
+
return_type: ts.ColumnType | None = None,
|
|
100
|
+
param_types: list[ts.ColumnType] | None = None,
|
|
101
|
+
batch_size: int | None = None,
|
|
102
|
+
substitute_fn: Callable | None = None,
|
|
103
103
|
is_method: bool = False,
|
|
104
104
|
is_property: bool = False,
|
|
105
|
-
resource_pool:
|
|
106
|
-
type_substitutions:
|
|
107
|
-
function_name:
|
|
105
|
+
resource_pool: str | None = None,
|
|
106
|
+
type_substitutions: Sequence[dict] | None = None,
|
|
107
|
+
function_name: str | None = None,
|
|
108
108
|
force_stored: bool = False,
|
|
109
109
|
) -> CallableFunction:
|
|
110
110
|
"""
|
|
@@ -201,11 +201,11 @@ def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
|
|
|
201
201
|
|
|
202
202
|
|
|
203
203
|
@overload
|
|
204
|
-
def expr_udf(*, param_types:
|
|
204
|
+
def expr_udf(*, param_types: list[ts.ColumnType] | None = None) -> Callable[[Callable], ExprTemplateFunction]: ...
|
|
205
205
|
|
|
206
206
|
|
|
207
207
|
def expr_udf(*args: Any, **kwargs: Any) -> Any:
|
|
208
|
-
def make_expr_template(py_fn: Callable, param_types:
|
|
208
|
+
def make_expr_template(py_fn: Callable, param_types: list[ts.ColumnType] | None) -> ExprTemplateFunction:
|
|
209
209
|
from pixeltable import exprs
|
|
210
210
|
|
|
211
211
|
if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
|
|
@@ -237,9 +237,7 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
|
|
|
237
237
|
return lambda py_fn: make_expr_template(py_fn, kwargs['param_types'])
|
|
238
238
|
|
|
239
239
|
|
|
240
|
-
def from_table(
|
|
241
|
-
tbl: catalog.Table, return_value: Optional['exprs.Expr'], description: Optional[str]
|
|
242
|
-
) -> ExprTemplateFunction:
|
|
240
|
+
def from_table(tbl: catalog.Table, return_value: 'exprs.Expr' | None, description: str | None) -> ExprTemplateFunction:
|
|
243
241
|
"""
|
|
244
242
|
Constructs an `ExprTemplateFunction` from a `Table`.
|
|
245
243
|
|
|
@@ -262,7 +260,7 @@ def from_table(
|
|
|
262
260
|
"""
|
|
263
261
|
from pixeltable import exprs
|
|
264
262
|
|
|
265
|
-
ancestors = [tbl, *tbl.
|
|
263
|
+
ancestors = [tbl, *tbl._get_base_tables()]
|
|
266
264
|
ancestors.reverse() # We must traverse the ancestors in order from base to derived
|
|
267
265
|
|
|
268
266
|
subst: dict[exprs.Expr, exprs.Expr] = {}
|
|
@@ -297,7 +295,7 @@ def from_table(
|
|
|
297
295
|
|
|
298
296
|
if description is None:
|
|
299
297
|
# Default description is the table comment
|
|
300
|
-
description = tbl.
|
|
298
|
+
description = tbl._get_comment()
|
|
301
299
|
if len(description) == 0:
|
|
302
300
|
description = f"UDF for table '{tbl._name}'"
|
|
303
301
|
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -10,6 +10,7 @@ from . import (
|
|
|
10
10
|
deepseek,
|
|
11
11
|
fireworks,
|
|
12
12
|
gemini,
|
|
13
|
+
groq,
|
|
13
14
|
huggingface,
|
|
14
15
|
image,
|
|
15
16
|
json,
|
|
@@ -18,13 +19,18 @@ from . import (
|
|
|
18
19
|
mistralai,
|
|
19
20
|
ollama,
|
|
20
21
|
openai,
|
|
22
|
+
openrouter,
|
|
21
23
|
replicate,
|
|
24
|
+
reve,
|
|
22
25
|
string,
|
|
23
26
|
timestamp,
|
|
24
27
|
together,
|
|
28
|
+
twelvelabs,
|
|
25
29
|
video,
|
|
26
30
|
vision,
|
|
27
31
|
whisper,
|
|
32
|
+
whisperx,
|
|
33
|
+
yolox,
|
|
28
34
|
)
|
|
29
35
|
from .globals import count, map, max, mean, min, sum
|
|
30
36
|
|
pixeltable/functions/anthropic.py
CHANGED
|
@@ -8,7 +8,7 @@ the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anth
|
|
|
8
8
|
import datetime
|
|
9
9
|
import json
|
|
10
10
|
import logging
|
|
11
|
-
from typing import TYPE_CHECKING, Any, Iterable,
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Iterable, cast
|
|
12
12
|
|
|
13
13
|
import httpx
|
|
14
14
|
|
|
@@ -38,6 +38,64 @@ def _anthropic_client() -> 'anthropic.AsyncAnthropic':
|
|
|
38
38
|
return env.Env.get().get_client('anthropic')
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
def _get_header_info(
|
|
42
|
+
headers: httpx.Headers,
|
|
43
|
+
) -> tuple[
|
|
44
|
+
tuple[int, int, datetime.datetime] | None,
|
|
45
|
+
tuple[int, int, datetime.datetime] | None,
|
|
46
|
+
tuple[int, int, datetime.datetime] | None,
|
|
47
|
+
]:
|
|
48
|
+
"""Extract rate limit info from Anthropic API response headers."""
|
|
49
|
+
requests_limit_str = headers.get('anthropic-ratelimit-requests-limit')
|
|
50
|
+
requests_limit = int(requests_limit_str) if requests_limit_str is not None else None
|
|
51
|
+
requests_remaining_str = headers.get('anthropic-ratelimit-requests-remaining')
|
|
52
|
+
requests_remaining = int(requests_remaining_str) if requests_remaining_str is not None else None
|
|
53
|
+
requests_reset_str = headers.get('anthropic-ratelimit-requests-reset')
|
|
54
|
+
requests_reset = (
|
|
55
|
+
datetime.datetime.fromisoformat(requests_reset_str.replace('Z', '+00:00')) if requests_reset_str else None
|
|
56
|
+
)
|
|
57
|
+
requests_info = (
|
|
58
|
+
(requests_limit, requests_remaining, requests_reset) if requests_reset and requests_remaining else None
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
input_tokens_limit_str = headers.get('anthropic-ratelimit-input-tokens-limit')
|
|
62
|
+
input_tokens_limit = int(input_tokens_limit_str) if input_tokens_limit_str is not None else None
|
|
63
|
+
input_tokens_remaining_str = headers.get('anthropic-ratelimit-input-tokens-remaining')
|
|
64
|
+
input_tokens_remaining = int(input_tokens_remaining_str) if input_tokens_remaining_str is not None else None
|
|
65
|
+
input_tokens_reset_str = headers.get('anthropic-ratelimit-input-tokens-reset')
|
|
66
|
+
input_tokens_reset = (
|
|
67
|
+
datetime.datetime.fromisoformat(input_tokens_reset_str.replace('Z', '+00:00'))
|
|
68
|
+
if input_tokens_reset_str
|
|
69
|
+
else None
|
|
70
|
+
)
|
|
71
|
+
input_tokens_info = (
|
|
72
|
+
(input_tokens_limit, input_tokens_remaining, input_tokens_reset)
|
|
73
|
+
if input_tokens_reset and input_tokens_remaining
|
|
74
|
+
else None
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
output_tokens_limit_str = headers.get('anthropic-ratelimit-output-tokens-limit')
|
|
78
|
+
output_tokens_limit = int(output_tokens_limit_str) if output_tokens_limit_str is not None else None
|
|
79
|
+
output_tokens_remaining_str = headers.get('anthropic-ratelimit-output-tokens-remaining')
|
|
80
|
+
output_tokens_remaining = int(output_tokens_remaining_str) if output_tokens_remaining_str is not None else None
|
|
81
|
+
output_tokens_reset_str = headers.get('anthropic-ratelimit-output-tokens-reset')
|
|
82
|
+
output_tokens_reset = (
|
|
83
|
+
datetime.datetime.fromisoformat(output_tokens_reset_str.replace('Z', '+00:00'))
|
|
84
|
+
if output_tokens_reset_str
|
|
85
|
+
else None
|
|
86
|
+
)
|
|
87
|
+
output_tokens_info = (
|
|
88
|
+
(output_tokens_limit, output_tokens_remaining, output_tokens_reset)
|
|
89
|
+
if output_tokens_reset and output_tokens_remaining
|
|
90
|
+
else None
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
if requests_info is None or input_tokens_info is None or output_tokens_info is None:
|
|
94
|
+
_logger.debug(f'get_header_info(): incomplete rate limit info: {headers}')
|
|
95
|
+
|
|
96
|
+
return requests_info, input_tokens_info, output_tokens_info
|
|
97
|
+
|
|
98
|
+
|
|
41
99
|
class AnthropicRateLimitsInfo(env.RateLimitsInfo):
|
|
42
100
|
def __init__(self) -> None:
|
|
43
101
|
super().__init__(self._get_request_resources)
|
|
@@ -51,7 +109,28 @@ class AnthropicRateLimitsInfo(env.RateLimitsInfo):
|
|
|
51
109
|
input_len += len(message['content'])
|
|
52
110
|
return {'requests': 1, 'input_tokens': int(input_len / 4), 'output_tokens': max_tokens}
|
|
53
111
|
|
|
54
|
-
def
|
|
112
|
+
def record_exc(self, exc: Exception) -> None:
|
|
113
|
+
import anthropic
|
|
114
|
+
|
|
115
|
+
if (
|
|
116
|
+
not isinstance(exc, anthropic.APIError)
|
|
117
|
+
or not hasattr(exc, 'response')
|
|
118
|
+
or not hasattr(exc.response, 'headers')
|
|
119
|
+
):
|
|
120
|
+
return
|
|
121
|
+
requests_info, input_tokens_info, output_tokens_info = _get_header_info(exc.response.headers)
|
|
122
|
+
_logger.debug(
|
|
123
|
+
f'record_exc(): requests_info={requests_info} input_tokens_info={input_tokens_info} '
|
|
124
|
+
f'output_tokens_info={output_tokens_info}'
|
|
125
|
+
)
|
|
126
|
+
self.record(requests=requests_info, input_tokens=input_tokens_info, output_tokens=output_tokens_info)
|
|
127
|
+
self.has_exc = True
|
|
128
|
+
|
|
129
|
+
retry_after_str = exc.response.headers.get('retry-after')
|
|
130
|
+
if retry_after_str is not None:
|
|
131
|
+
_logger.debug(f'retry-after: {retry_after_str}')
|
|
132
|
+
|
|
133
|
+
def get_retry_delay(self, exc: Exception) -> float | None:
|
|
55
134
|
import anthropic
|
|
56
135
|
|
|
57
136
|
# deal with timeouts separately, they don't come with headers
|
|
@@ -64,8 +143,7 @@ class AnthropicRateLimitsInfo(env.RateLimitsInfo):
|
|
|
64
143
|
should_retry_str = exc.response.headers.get('x-should-retry', '')
|
|
65
144
|
if should_retry_str.lower() != 'true':
|
|
66
145
|
return None
|
|
67
|
-
|
|
68
|
-
return int(retry_after_str)
|
|
146
|
+
return super().get_retry_delay(exc)
|
|
69
147
|
|
|
70
148
|
|
|
71
149
|
@pxt.udf
|
|
@@ -74,9 +152,10 @@ async def messages(
|
|
|
74
152
|
*,
|
|
75
153
|
model: str,
|
|
76
154
|
max_tokens: int,
|
|
77
|
-
model_kwargs:
|
|
78
|
-
tools:
|
|
79
|
-
tool_choice:
|
|
155
|
+
model_kwargs: dict[str, Any] | None = None,
|
|
156
|
+
tools: list[dict[str, Any]] | None = None,
|
|
157
|
+
tool_choice: dict[str, Any] | None = None,
|
|
158
|
+
_runtime_ctx: env.RuntimeCtx | None = None,
|
|
80
159
|
) -> dict:
|
|
81
160
|
"""
|
|
82
161
|
Create a Message.
|
|
@@ -151,32 +230,13 @@ async def messages(
|
|
|
151
230
|
messages=cast(Iterable[MessageParam], messages), model=model, max_tokens=max_tokens, **model_kwargs
|
|
152
231
|
)
|
|
153
232
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
requests_reset = datetime.datetime.fromisoformat(requests_reset_str.replace('Z', '+00:00'))
|
|
160
|
-
input_tokens_limit_str = result.headers.get('anthropic-ratelimit-input-tokens-limit')
|
|
161
|
-
input_tokens_limit = int(input_tokens_limit_str) if input_tokens_limit_str is not None else None
|
|
162
|
-
input_tokens_remaining_str = result.headers.get('anthropic-ratelimit-input-tokens-remaining')
|
|
163
|
-
input_tokens_remaining = int(input_tokens_remaining_str) if input_tokens_remaining_str is not None else None
|
|
164
|
-
input_tokens_reset_str = result.headers.get('anthropic-ratelimit-input-tokens-reset')
|
|
165
|
-
input_tokens_reset = datetime.datetime.fromisoformat(input_tokens_reset_str.replace('Z', '+00:00'))
|
|
166
|
-
output_tokens_limit_str = result.headers.get('anthropic-ratelimit-output-tokens-limit')
|
|
167
|
-
output_tokens_limit = int(output_tokens_limit_str) if output_tokens_limit_str is not None else None
|
|
168
|
-
output_tokens_remaining_str = result.headers.get('anthropic-ratelimit-output-tokens-remaining')
|
|
169
|
-
output_tokens_remaining = int(output_tokens_remaining_str) if output_tokens_remaining_str is not None else None
|
|
170
|
-
output_tokens_reset_str = result.headers.get('anthropic-ratelimit-output-tokens-reset')
|
|
171
|
-
output_tokens_reset = datetime.datetime.fromisoformat(output_tokens_reset_str.replace('Z', '+00:00'))
|
|
172
|
-
retry_after_str = result.headers.get('retry-after')
|
|
173
|
-
if retry_after_str is not None:
|
|
174
|
-
_logger.debug(f'retry-after: {retry_after_str}')
|
|
175
|
-
|
|
233
|
+
requests_info, input_tokens_info, output_tokens_info = _get_header_info(result.headers)
|
|
234
|
+
# retry_after_str = result.headers.get('retry-after')
|
|
235
|
+
# if retry_after_str is not None:
|
|
236
|
+
# _logger.debug(f'retry-after: {retry_after_str}')
|
|
237
|
+
is_retry = _runtime_ctx is not None and _runtime_ctx.is_retry
|
|
176
238
|
rate_limits_info.record(
|
|
177
|
-
requests=
|
|
178
|
-
input_tokens=(input_tokens_limit, input_tokens_remaining, input_tokens_reset),
|
|
179
|
-
output_tokens=(output_tokens_limit, output_tokens_remaining, output_tokens_reset),
|
|
239
|
+
requests=requests_info, input_tokens=input_tokens_info, output_tokens=output_tokens_info, reset_exc=is_retry
|
|
180
240
|
)
|
|
181
241
|
|
|
182
242
|
result_dict = json.loads(result.text)
|
|
@@ -194,7 +254,7 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
|
|
|
194
254
|
|
|
195
255
|
|
|
196
256
|
@pxt.udf
|
|
197
|
-
def _anthropic_response_to_pxt_tool_calls(response: dict) ->
|
|
257
|
+
def _anthropic_response_to_pxt_tool_calls(response: dict) -> dict | None:
|
|
198
258
|
anthropic_tool_calls = [r for r in response['content'] if r['type'] == 'tool_use']
|
|
199
259
|
if len(anthropic_tool_calls) == 0:
|
|
200
260
|
return None
|
pixeltable/functions/audio.py
CHANGED
|
@@ -1,26 +1,130 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs) for `AudioType`.
|
|
3
|
-
|
|
4
|
-
Example:
|
|
5
|
-
```python
|
|
6
|
-
import pixeltable as pxt
|
|
7
|
-
import pixeltable.functions as pxtf
|
|
8
|
-
|
|
9
|
-
t = pxt.get_table(...)
|
|
10
|
-
t.select(pxtf.audio.get_metadata()).collect()
|
|
11
|
-
```
|
|
12
3
|
"""
|
|
13
4
|
|
|
5
|
+
import av
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
14
8
|
import pixeltable as pxt
|
|
9
|
+
import pixeltable.utils.av as av_utils
|
|
15
10
|
from pixeltable.utils.code import local_public_names
|
|
11
|
+
from pixeltable.utils.local_store import TempStore
|
|
16
12
|
|
|
17
13
|
|
|
18
14
|
@pxt.udf(is_method=True)
|
|
19
15
|
def get_metadata(audio: pxt.Audio) -> dict:
|
|
20
16
|
"""
|
|
21
17
|
Gets various metadata associated with an audio file and returns it as a dictionary.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
audio: The audio to get metadata for.
|
|
21
|
+
|
|
22
|
+
Returns:
|
|
23
|
+
A `dict` such as the following:
|
|
24
|
+
|
|
25
|
+
```json
|
|
26
|
+
{
|
|
27
|
+
'size': 2568827,
|
|
28
|
+
'streams': [
|
|
29
|
+
{
|
|
30
|
+
'type': 'audio',
|
|
31
|
+
'frames': 0,
|
|
32
|
+
'duration': 2646000,
|
|
33
|
+
'metadata': {},
|
|
34
|
+
'time_base': 2.2675736961451248e-05,
|
|
35
|
+
'codec_context': {
|
|
36
|
+
'name': 'flac',
|
|
37
|
+
'profile': None,
|
|
38
|
+
'channels': 1,
|
|
39
|
+
'codec_tag': '\\x00\\x00\\x00\\x00',
|
|
40
|
+
},
|
|
41
|
+
'duration_seconds': 60.0,
|
|
42
|
+
}
|
|
43
|
+
],
|
|
44
|
+
'bit_rate': 342510,
|
|
45
|
+
'metadata': {'encoder': 'Lavf61.1.100'},
|
|
46
|
+
'bit_exact': False,
|
|
47
|
+
}
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
Examples:
|
|
51
|
+
Extract metadata for files in the `audio_col` column of the table `tbl`:
|
|
52
|
+
|
|
53
|
+
>>> tbl.select(tbl.audio_col.get_metadata()).collect()
|
|
54
|
+
"""
|
|
55
|
+
return av_utils.get_metadata(audio)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@pxt.udf()
|
|
59
|
+
def encode_audio(
|
|
60
|
+
audio_data: pxt.Array[pxt.Float], *, input_sample_rate: int, format: str, output_sample_rate: int | None = None
|
|
61
|
+
) -> pxt.Audio:
|
|
62
|
+
"""
|
|
63
|
+
Encodes an audio clip represented as an array into a specified audio format.
|
|
64
|
+
|
|
65
|
+
Parameters:
|
|
66
|
+
audio_data: An array of sampled amplitudes. The accepted array shapes are `(N,)` or `(1, N)` for mono audio
|
|
67
|
+
or `(2, N)` for stereo.
|
|
68
|
+
input_sample_rate: The sample rate of the input audio data.
|
|
69
|
+
format: The desired output audio format. The supported formats are 'wav', 'mp3', 'flac', and 'mp4'.
|
|
70
|
+
output_sample_rate: The desired sample rate for the output audio. Defaults to the input sample rate if
|
|
71
|
+
unspecified.
|
|
72
|
+
|
|
73
|
+
Examples:
|
|
74
|
+
Add a computed column with encoded FLAC audio files to a table with audio data (as arrays of floats) and sample
|
|
75
|
+
rates:
|
|
76
|
+
|
|
77
|
+
```
|
|
78
|
+
t.add_computed_column(
|
|
79
|
+
audio_file=encode_audio(
|
|
80
|
+
t.audio_data, input_sample_rate=t.sample_rate, format='flac'
|
|
81
|
+
)
|
|
82
|
+
)
|
|
83
|
+
```
|
|
22
84
|
"""
|
|
23
|
-
|
|
85
|
+
if format not in av_utils.AUDIO_FORMATS:
|
|
86
|
+
raise pxt.Error(f'Only the following formats are supported: {av_utils.AUDIO_FORMATS.keys()}')
|
|
87
|
+
if output_sample_rate is None:
|
|
88
|
+
output_sample_rate = input_sample_rate
|
|
89
|
+
|
|
90
|
+
codec, ext = av_utils.AUDIO_FORMATS[format]
|
|
91
|
+
output_path = str(TempStore.create_path(extension=f'.{ext}'))
|
|
92
|
+
|
|
93
|
+
match audio_data.shape:
|
|
94
|
+
case (_,):
|
|
95
|
+
# Mono audio as 1D array, reshape for pyav
|
|
96
|
+
layout = 'mono'
|
|
97
|
+
audio_data_transformed = audio_data[None, :]
|
|
98
|
+
case (1, _):
|
|
99
|
+
# Mono audio as 2D array, simply reshape and transpose the input for pyav
|
|
100
|
+
layout = 'mono'
|
|
101
|
+
audio_data_transformed = audio_data.reshape(-1, 1).transpose()
|
|
102
|
+
case (2, _):
|
|
103
|
+
# Stereo audio. Input layout: [[L0, L1, L2, ...],[R0, R1, R2, ...]],
|
|
104
|
+
# pyav expects: [L0, R0, L1, R1, L2, R2, ...]
|
|
105
|
+
layout = 'stereo'
|
|
106
|
+
audio_data_transformed = np.empty(audio_data.shape[1] * 2, dtype=audio_data.dtype)
|
|
107
|
+
audio_data_transformed[0::2] = audio_data[0]
|
|
108
|
+
audio_data_transformed[1::2] = audio_data[1]
|
|
109
|
+
audio_data_transformed = audio_data_transformed.reshape(1, -1)
|
|
110
|
+
case _:
|
|
111
|
+
raise pxt.Error(
|
|
112
|
+
f'Supported input array shapes are (N,), (1, N) for mono and (2, N) for stereo, got {audio_data.shape}'
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
with av.open(output_path, mode='w') as output_container:
|
|
116
|
+
stream = output_container.add_stream(codec, rate=output_sample_rate)
|
|
117
|
+
assert isinstance(stream, av.AudioStream)
|
|
118
|
+
|
|
119
|
+
frame = av.AudioFrame.from_ndarray(audio_data_transformed, format='flt', layout=layout)
|
|
120
|
+
frame.sample_rate = input_sample_rate
|
|
121
|
+
|
|
122
|
+
for packet in stream.encode(frame):
|
|
123
|
+
output_container.mux(packet)
|
|
124
|
+
for packet in stream.encode():
|
|
125
|
+
output_container.mux(packet)
|
|
126
|
+
|
|
127
|
+
return output_path
|
|
24
128
|
|
|
25
129
|
|
|
26
130
|
__all__ = local_public_names(__name__)
|