pixeltable 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +34 -6
- pixeltable/catalog/__init__.py +13 -0
- pixeltable/catalog/catalog.py +159 -0
- pixeltable/catalog/column.py +200 -0
- pixeltable/catalog/dir.py +32 -0
- pixeltable/catalog/globals.py +33 -0
- pixeltable/catalog/insertable_table.py +191 -0
- pixeltable/catalog/named_function.py +36 -0
- pixeltable/catalog/path.py +58 -0
- pixeltable/catalog/path_dict.py +139 -0
- pixeltable/catalog/schema_object.py +39 -0
- pixeltable/catalog/table.py +581 -0
- pixeltable/catalog/table_version.py +749 -0
- pixeltable/catalog/table_version_path.py +133 -0
- pixeltable/catalog/view.py +203 -0
- pixeltable/client.py +520 -30
- pixeltable/dataframe.py +540 -349
- pixeltable/env.py +373 -45
- pixeltable/exceptions.py +12 -21
- pixeltable/exec/__init__.py +9 -0
- pixeltable/exec/aggregation_node.py +78 -0
- pixeltable/exec/cache_prefetch_node.py +113 -0
- pixeltable/exec/component_iteration_node.py +79 -0
- pixeltable/exec/data_row_batch.py +95 -0
- pixeltable/exec/exec_context.py +22 -0
- pixeltable/exec/exec_node.py +61 -0
- pixeltable/exec/expr_eval_node.py +217 -0
- pixeltable/exec/in_memory_data_node.py +69 -0
- pixeltable/exec/media_validation_node.py +43 -0
- pixeltable/exec/sql_scan_node.py +225 -0
- pixeltable/exprs/__init__.py +24 -0
- pixeltable/exprs/arithmetic_expr.py +102 -0
- pixeltable/exprs/array_slice.py +71 -0
- pixeltable/exprs/column_property_ref.py +77 -0
- pixeltable/exprs/column_ref.py +105 -0
- pixeltable/exprs/comparison.py +77 -0
- pixeltable/exprs/compound_predicate.py +98 -0
- pixeltable/exprs/data_row.py +187 -0
- pixeltable/exprs/expr.py +586 -0
- pixeltable/exprs/expr_set.py +39 -0
- pixeltable/exprs/function_call.py +380 -0
- pixeltable/exprs/globals.py +69 -0
- pixeltable/exprs/image_member_access.py +115 -0
- pixeltable/exprs/image_similarity_predicate.py +58 -0
- pixeltable/exprs/inline_array.py +107 -0
- pixeltable/exprs/inline_dict.py +101 -0
- pixeltable/exprs/is_null.py +38 -0
- pixeltable/exprs/json_mapper.py +121 -0
- pixeltable/exprs/json_path.py +159 -0
- pixeltable/exprs/literal.py +54 -0
- pixeltable/exprs/object_ref.py +41 -0
- pixeltable/exprs/predicate.py +44 -0
- pixeltable/exprs/row_builder.py +355 -0
- pixeltable/exprs/rowid_ref.py +94 -0
- pixeltable/exprs/type_cast.py +53 -0
- pixeltable/exprs/variable.py +45 -0
- pixeltable/func/__init__.py +9 -0
- pixeltable/func/aggregate_function.py +194 -0
- pixeltable/func/batched_function.py +53 -0
- pixeltable/func/callable_function.py +69 -0
- pixeltable/func/expr_template_function.py +82 -0
- pixeltable/func/function.py +110 -0
- pixeltable/func/function_registry.py +227 -0
- pixeltable/func/globals.py +36 -0
- pixeltable/func/nos_function.py +202 -0
- pixeltable/func/signature.py +166 -0
- pixeltable/func/udf.py +163 -0
- pixeltable/functions/__init__.py +52 -103
- pixeltable/functions/eval.py +216 -0
- pixeltable/functions/fireworks.py +61 -0
- pixeltable/functions/huggingface.py +120 -0
- pixeltable/functions/image.py +16 -0
- pixeltable/functions/openai.py +88 -0
- pixeltable/functions/pil/image.py +148 -7
- pixeltable/functions/string.py +13 -0
- pixeltable/functions/together.py +27 -0
- pixeltable/functions/util.py +41 -0
- pixeltable/functions/video.py +62 -0
- pixeltable/iterators/__init__.py +3 -0
- pixeltable/iterators/base.py +48 -0
- pixeltable/iterators/document.py +311 -0
- pixeltable/iterators/video.py +89 -0
- pixeltable/metadata/__init__.py +54 -0
- pixeltable/metadata/converters/convert_10.py +18 -0
- pixeltable/metadata/schema.py +211 -0
- pixeltable/plan.py +656 -0
- pixeltable/store.py +413 -182
- pixeltable/tests/conftest.py +143 -87
- pixeltable/tests/test_audio.py +65 -0
- pixeltable/tests/test_catalog.py +27 -0
- pixeltable/tests/test_client.py +14 -14
- pixeltable/tests/test_component_view.py +372 -0
- pixeltable/tests/test_dataframe.py +433 -0
- pixeltable/tests/test_dirs.py +78 -62
- pixeltable/tests/test_document.py +117 -0
- pixeltable/tests/test_exprs.py +591 -135
- pixeltable/tests/test_function.py +297 -67
- pixeltable/tests/test_functions.py +283 -1
- pixeltable/tests/test_migration.py +43 -0
- pixeltable/tests/test_nos.py +54 -0
- pixeltable/tests/test_snapshot.py +208 -0
- pixeltable/tests/test_table.py +1085 -262
- pixeltable/tests/test_transactional_directory.py +42 -0
- pixeltable/tests/test_types.py +5 -11
- pixeltable/tests/test_video.py +149 -34
- pixeltable/tests/test_view.py +530 -0
- pixeltable/tests/utils.py +186 -45
- pixeltable/tool/create_test_db_dump.py +149 -0
- pixeltable/type_system.py +490 -126
- pixeltable/utils/__init__.py +17 -46
- pixeltable/utils/clip.py +12 -15
- pixeltable/utils/coco.py +136 -0
- pixeltable/utils/documents.py +39 -0
- pixeltable/utils/filecache.py +195 -0
- pixeltable/utils/help.py +11 -0
- pixeltable/utils/media_store.py +76 -0
- pixeltable/utils/parquet.py +126 -0
- pixeltable/utils/pytorch.py +172 -0
- pixeltable/utils/s3.py +13 -0
- pixeltable/utils/sql.py +17 -0
- pixeltable/utils/transactional_directory.py +35 -0
- pixeltable-0.2.0.dist-info/LICENSE +18 -0
- pixeltable-0.2.0.dist-info/METADATA +117 -0
- pixeltable-0.2.0.dist-info/RECORD +125 -0
- {pixeltable-0.1.1.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
- pixeltable/catalog.py +0 -1421
- pixeltable/exprs.py +0 -1745
- pixeltable/function.py +0 -269
- pixeltable/functions/clip.py +0 -10
- pixeltable/functions/pil/__init__.py +0 -23
- pixeltable/functions/tf.py +0 -21
- pixeltable/index.py +0 -57
- pixeltable/tests/test_dict.py +0 -24
- pixeltable/tests/test_tf.py +0 -69
- pixeltable/tf.py +0 -33
- pixeltable/utils/tf.py +0 -33
- pixeltable/utils/video.py +0 -32
- pixeltable-0.1.1.dist-info/METADATA +0 -31
- pixeltable-0.1.1.dist-info/RECORD +0 -36
pixeltable/func/udf.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from typing import List, Callable, Optional, overload, Any
|
|
5
|
+
|
|
6
|
+
import pixeltable as pxt
|
|
7
|
+
import pixeltable.exceptions as excs
|
|
8
|
+
import pixeltable.type_system as ts
|
|
9
|
+
from .batched_function import ExplicitBatchedFunction
|
|
10
|
+
from .callable_function import CallableFunction
|
|
11
|
+
from .expr_template_function import ExprTemplateFunction
|
|
12
|
+
from .function import Function
|
|
13
|
+
from .function_registry import FunctionRegistry
|
|
14
|
+
from .signature import Signature
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Decorator invoked without parentheses: @pxt.udf
@overload
def udf(decorated_fn: Callable) -> Function: ...


# Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
@overload
def udf(
        *,
        return_type: Optional[ts.ColumnType] = None,
        param_types: Optional[List[ts.ColumnType]] = None,
        batch_size: Optional[int] = None,
        substitute_fn: Optional[Callable] = None,
        _force_stored: bool = False
) -> Callable: ...


def udf(*args, **kwargs):
    """A decorator to create a Function from a function definition.

    Examples:
        >>> @pxt.udf
        ... def my_function(x: int) -> int:
        ...     return x + 1

        >>> @pxt.udf(param_types=[pxt.IntType()], return_type=pxt.IntType())
        ... def my_function(x):
        ...     return x + 1

    When invoked with parentheses, the recognized keyword arguments are `return_type`, `param_types`,
    `batch_size`, `substitute_fn` (also accepted under the legacy name `py_fn`) and `_force_stored`;
    they are forwarded to `make_function()`.
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):

        # Decorator invoked without parentheses: @pxt.udf
        # Simply call make_function with defaults.
        return make_function(decorated_fn=args[0])

    else:

        # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
        # Create a decorator for the specified schema.
        return_type = kwargs.pop('return_type', None)
        param_types = kwargs.pop('param_types', None)
        batch_size = kwargs.pop('batch_size', None)
        # `substitute_fn` is the keyword declared by the overload; 'py_fn' is still honored as an alias
        # so that callers using that spelling keep working.
        substitute_fn = kwargs.pop('substitute_fn', None)
        legacy_py_fn = kwargs.pop('py_fn', None)
        if substitute_fn is None:
            substitute_fn = legacy_py_fn
        force_stored = kwargs.pop('_force_stored', False)

        def decorator(decorated_fn: Callable):
            return make_function(
                decorated_fn, return_type, param_types, batch_size, substitute_fn=substitute_fn,
                force_stored=force_stored)

        return decorator
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def make_function(
        decorated_fn: Callable,
        return_type: Optional[ts.ColumnType] = None,
        param_types: Optional[List[ts.ColumnType]] = None,
        batch_size: Optional[int] = None,
        substitute_fn: Optional[Callable] = None,
        function_name: Optional[str] = None,
        force_stored: bool = False
) -> Function:
    """
    Build the Function object behind a @udf.

    The result is a `CallableFunction` when `batch_size` is None and an `ExplicitBatchedFunction`
    otherwise. If `substitute_fn` is given, `decorated_fn` only contributes its signature and the
    resulting function executes `substitute_fn` instead.

    Unless `force_stored` is set, a function defined outside `__main__` whose name is a valid
    identifier gets the path `<module>.<qualname>` and is registered with the FunctionRegistry.

    Raises:
        excs.Error: if `batch_size` and the batched parameters/return type of the signature disagree,
            or if `substitute_fn` is used for a function that has no module path.
    """
    # force_stored suppresses the module path, so the function ends up stored in the db
    has_module_path = (
        not force_stored
        and decorated_fn.__module__ != '__main__'
        and decorated_fn.__name__.isidentifier())
    function_path = f'{decorated_fn.__module__}.{decorated_fn.__qualname__}' if has_module_path else None

    if function_name is None:
        function_name = decorated_fn.__name__
    # name shown in error messages: the full path when there is one
    display_name = function_name if function_path is None else function_path

    sig = Signature.create(decorated_fn, param_types, return_type)

    # batch_size and `Batch`-typed parameters/return must be used together
    # TODO: remove 'Python' from the error messages when we have full inference with Annotated types
    if batch_size is not None:
        if not sig.is_batched:
            raise excs.Error(f'{display_name}(): batch_size is specified; Python return type must be a `Batch`')
        if len(sig.batched_parameters) == 0:
            raise excs.Error(
                f'{display_name}(): batch_size is specified; at least one Python parameter must be `Batch`')
    elif len(sig.batched_parameters) > 0:
        raise excs.Error(f'{display_name}(): batched parameters in udf, but no `batch_size` given')

    if substitute_fn is not None and function_path is None:
        raise excs.Error(f'{display_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
    py_fn = decorated_fn if substitute_fn is None else substitute_fn

    result: Function
    if batch_size is None:
        result = CallableFunction(signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name)
    else:
        result = ExplicitBatchedFunction(
            signature=sig, batch_size=batch_size, invoker_fn=py_fn, self_path=function_path)

    # functions that live in a module are registered under their path
    if function_path is not None:
        FunctionRegistry.get().register_function(function_path, result)
    return result
|
|
130
|
+
|
|
131
|
+
@overload
def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...


@overload
def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable: ...


def expr_udf(*args: Any, **kwargs: Any) -> Any:
    """A decorator to create an `ExprTemplateFunction` from a function that builds an expression.

    The decorated function is called once, with one `exprs.Variable` per parameter of its signature;
    the expression it returns becomes the template of the resulting function.

    Supported forms:
        @expr_udf                       # signature derived by Signature.create() from the function
        @expr_udf()                     # same as above
        @expr_udf(param_types=[...])    # explicit parameter types

    Raises:
        TypeError: if the decorator is invoked in any other form.
        excs.Error: if the decorated function does not return an `Expr`.
    """
    def decorator(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
        if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
            # this is a named function in a module
            function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
        else:
            function_path = None

        sig = Signature.create(py_fn, param_types=param_types, return_type=None)
        # TODO: verify that the inferred return type matches that of the template
        # TODO: verify that the signature doesn't contain batched parameters

        # construct Parameters from the function signature
        import pixeltable.exprs as exprs
        var_exprs = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
        # call the function with the parameter expressions to construct an Expr with parameters
        template = py_fn(*var_exprs)
        # explicit check instead of `assert`, which is stripped under `python -O`
        if not isinstance(template, exprs.Expr):
            raise excs.Error(
                f'{py_fn.__name__}(): an @expr_udf function must return an expression, '
                f'but it returned {type(template).__name__}')
        py_sig = inspect.signature(py_fn)
        return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)

    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # invoked without parentheses: @expr_udf
        return decorator(args[0], None)
    if len(args) > 0:
        raise TypeError(
            'expr_udf(): expected either a single function or only the keyword argument `param_types`')
    unexpected = sorted(set(kwargs) - {'param_types'})
    if len(unexpected) > 0:
        raise TypeError(f'expr_udf(): unexpected keyword argument(s): {", ".join(unexpected)}')
    # invoked with parentheses: @expr_udf() or @expr_udf(param_types=[...])
    param_types = kwargs.get('param_types')
    return lambda py_fn: decorator(py_fn, param_types)
|
pixeltable/functions/__init__.py
CHANGED
|
@@ -1,89 +1,57 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from typing import Callable, List, Optional, Union
|
|
3
|
-
import inspect
|
|
4
|
-
from pathlib import Path
|
|
5
1
|
import tempfile
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional, Union
|
|
6
4
|
|
|
7
|
-
import PIL
|
|
5
|
+
import PIL.Image
|
|
6
|
+
import av
|
|
7
|
+
import av.container
|
|
8
|
+
import av.stream
|
|
8
9
|
import numpy as np
|
|
9
10
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
11
|
+
import pixeltable.env as env
|
|
12
|
+
import pixeltable.func as func
|
|
13
|
+
# import all standard function modules here so they get registered with the FunctionRegistry
|
|
14
|
+
import pixeltable.functions.pil.image
|
|
13
15
|
from pixeltable import exprs
|
|
14
|
-
from pixeltable import
|
|
15
|
-
import
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def udf_call(eval_fn: Callable, return_type: ColumnType, tbl: Optional[catalog.Table]) -> exprs.FunctionCall:
|
|
19
|
-
"""
|
|
20
|
-
Interprets eval_fn's parameters to be references to columns in 'tbl' and construct ColumnRefs as args.
|
|
21
|
-
"""
|
|
22
|
-
params = inspect.signature(eval_fn).parameters
|
|
23
|
-
if len(params) > 0 and tbl is None:
|
|
24
|
-
raise exc.OperationalError(f'udf_call() is missing tbl parameter')
|
|
25
|
-
args: List[exprs.ColumnRef] = []
|
|
26
|
-
for param_name in params:
|
|
27
|
-
if param_name not in tbl.cols_by_name:
|
|
28
|
-
raise exc.OperationalError(
|
|
29
|
-
(f'udf_call(): lambda argument names need to be valid column names in table {tbl.name}: '
|
|
30
|
-
f'column {param_name} unknown'))
|
|
31
|
-
args.append(exprs.ColumnRef(tbl.cols_by_name[param_name]))
|
|
32
|
-
fn = Function(return_type, [arg.col_type for arg in args], eval_fn=eval_fn)
|
|
33
|
-
return exprs.FunctionCall(fn, args)
|
|
16
|
+
from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
|
|
17
|
+
# automatically import all submodules so that the udfs get registered
|
|
18
|
+
from . import image, string, video, openai, together, fireworks, huggingface
|
|
34
19
|
|
|
20
|
+
# TODO: remove and replace calls with astype()
|
|
35
21
|
def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
    """Relabel `expr` with `target_type` and return it.

    No value conversion takes place: the type is overwritten on the given object itself (it is
    mutated, not copied), and that same object is returned.
    """
    relabeled = expr
    relabeled.col_type = target_type
    return relabeled
|
|
38
24
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
@func.uda(
    update_types=[IntType()], value_type=IntType(), name='sum', allows_window=True, requires_order_by=False)
class SumAggregator(func.Aggregator):
    """Running total of all non-null inputs; registered as the `sum` aggregate."""

    def __init__(self):
        self.sum: Union[int, float] = 0

    def update(self, val: Union[int, float]) -> None:
        # nulls are skipped, not treated as zero
        if val is None:
            return
        self.sum += val

    def value(self) -> Union[int, float]:
        # 0 if no non-null value was seen
        return self.sum
|
|
60
36
|
|
|
61
|
-
|
|
37
|
+
@func.uda(
    update_types=[IntType()], value_type=IntType(), name='count', allows_window=True, requires_order_by=False)
class CountAggregator(func.Aggregator):
    """Counts the non-null inputs; registered as the `count` aggregate."""

    def __init__(self):
        self.count = 0

    def update(self, val: int) -> None:
        # only non-null values are counted
        if val is None:
            return
        self.count += 1

    def value(self) -> int:
        return self.count
|
|
79
48
|
|
|
80
|
-
|
|
49
|
+
@func.uda(
|
|
50
|
+
update_types=[IntType()], value_type=FloatType(), name='mean', allows_window=False, requires_order_by=False)
|
|
51
|
+
class MeanAggregator(func.Aggregator):
|
|
81
52
|
def __init__(self):
|
|
82
53
|
self.sum = 0
|
|
83
54
|
self.count = 0
|
|
84
|
-
@classmethod
|
|
85
|
-
def make_aggregator(cls) -> 'MeanAggregator':
|
|
86
|
-
return cls()
|
|
87
55
|
def update(self, val: int) -> None:
|
|
88
56
|
if val is not None:
|
|
89
57
|
self.sum += val
|
|
@@ -93,54 +61,35 @@ class MeanAggregator:
|
|
|
93
61
|
return None
|
|
94
62
|
return self.sum / self.count
|
|
95
63
|
|
|
96
|
-
mean = Function(
|
|
97
|
-
FloatType(), [IntType()],
|
|
98
|
-
module_name = 'pixeltable.functions',
|
|
99
|
-
init_symbol = 'MeanAggregator.make_aggregator',
|
|
100
|
-
update_symbol = 'MeanAggregator.update',
|
|
101
|
-
value_symbol = 'MeanAggregator.value')
|
|
102
64
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
65
|
+
@func.uda(
    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(), name='make_video',
    requires_order_by=True, allows_window=False)
class VideoAggregator(func.Aggregator):
    """Encodes the incoming frames, in update() order, into an h264 mp4 file in the tmp dir."""

    def __init__(self, fps: int = 25):
        """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
        self.container: Optional[av.container.OutputContainer] = None
        self.stream: Optional[av.stream.Stream] = None
        self.out_file: Optional[Path] = None
        self.fps = fps

    def update(self, frame: PIL.Image.Image) -> None:
        if frame is None:
            return
        if self.container is None:
            # the output is opened lazily, since the stream dimensions come from the first frame
            import os
            fd, output_filename = tempfile.mkstemp(suffix='.mp4', dir=str(env.Env.get().tmp_dir))
            # mkstemp() hands back an open OS-level descriptor; PyAV reopens the file by path,
            # so close ours to avoid leaking one descriptor per video
            os.close(fd)
            self.out_file = Path(output_filename)
            self.container = av.open(str(self.out_file), mode='w')
            self.stream = self.container.add_stream('h264', rate=self.fps)
            self.stream.pix_fmt = 'yuv420p'
            self.stream.width = frame.width
            self.stream.height = frame.height

        av_frame = av.VideoFrame.from_ndarray(np.array(frame.convert('RGB')), format='rgb24')
        for packet in self.stream.encode(av_frame):
            self.container.mux(packet)

    def value(self) -> Optional[str]:
        """Finalize the file and return its path, or None if no non-null frame was received."""
        if self.container is None:
            return None
        # encode() without a frame flushes the packets still buffered in the encoder
        for packet in self.stream.encode():
            self.container.mux(packet)
        self.container.close()
        return str(self.out_file)
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import List, Tuple, Dict
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
import pixeltable.type_system as ts
|
|
9
|
+
import pixeltable.func as func
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# TODO: figure out a better submodule structure
|
|
13
|
+
|
|
14
|
+
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_bboxes_area(bboxes: np.ndarray) -> np.ndarray:
    """Compute the area of each bbox.

    Args:
        bboxes (numpy.ndarray): The bboxes with shape (n, 4) or (4, ) in 'xyxy' format.
    Returns:
        numpy.ndarray: (x_max - x_min) * (y_max - y_min) for each bbox.
    """
    x_min, y_min, x_max, y_max = (bboxes[..., idx] for idx in range(4))
    return (x_max - x_min) * (y_max - y_min)
|
|
29
|
+
|
|
30
|
+
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_overlaps(bboxes1: np.ndarray, bboxes2: np.ndarray) -> np.ndarray:
    """Compute the pairwise IoU between two sets of bboxes.

    Args:
        bboxes1 (numpy.ndarray): The bboxes with shape (n, 4) in 'xyxy' format.
        bboxes2 (numpy.ndarray): The bboxes with shape (k, 4) in 'xyxy' format.
    Returns:
        numpy.ndarray: float32 IoUs with shape (n, k); entry (i, j) compares bboxes1[i] with bboxes2[j].
    """
    boxes_a = bboxes1.astype(np.float32)
    boxes_b = bboxes2.astype(np.float32)
    n, k = boxes_a.shape[0], boxes_b.shape[0]
    if n * k == 0:
        return np.zeros((n, k), dtype=np.float32)

    # broadcast every box of boxes_a (rows) against every box of boxes_b (columns)
    a = boxes_a[:, np.newaxis, :]
    b = boxes_b[np.newaxis, :, :]
    inter_w = np.maximum(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0)
    inter_h = np.maximum(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0)
    inter = inter_w * inter_h

    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    # the union is clamped to eps so that degenerate (zero-area) boxes don't divide by zero
    union = np.maximum(area_a[:, np.newaxis] + area_b[np.newaxis, :] - inter, np.finfo(np.float32).eps)
    return inter / union
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_image_tpfp(
        pred_bboxes: np.ndarray, pred_scores: np.ndarray, gt_bboxes: np.ndarray, min_iou: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Calculate the true positive and false positive flags of the predictions on an image.

    Predictions are processed in order of decreasing score. A prediction is a true positive if the
    ground truth bbox it overlaps most has IoU >= `min_iou` and has not already been matched by a
    higher-scoring prediction; otherwise it is a false positive.

    Args:
        pred_bboxes (numpy.ndarray): Predicted bboxes of this image, with shape (N, 4) in 'xyxy' format.
        pred_scores (numpy.ndarray): Scores of the predicted bboxes, with shape (N,).
        gt_bboxes (numpy.ndarray): Ground truth bboxes of this image, with shape (M, 4).
        min_iou (float): The IoU threshold.

    Returns:
        tuple (tp, fp): two int8 arrays of shape (N,), indexed like `pred_bboxes` (prediction order,
        not score order), holding the true/false positive flag of each predicted bbox.
    """
    # Step 1. Initialize the `tp` and `fp` arrays.
    num_preds = pred_bboxes.shape[0]
    tp = np.zeros(num_preds, dtype=np.int8)
    fp = np.zeros(num_preds, dtype=np.int8)

    # Step 2. If there are no gt bboxes in this image, then all pred bboxes are false positives.
    if gt_bboxes.shape[0] == 0:
        fp[...] = 1
        return tp, fp

    # Step 3. Calculate the IoUs between the predicted bboxes and the ground truth bboxes.
    ious = calculate_overlaps(pred_bboxes, gt_bboxes)
    # For each pred bbox, the max iou with all gts.
    ious_max = ious.max(axis=1)
    # For each pred bbox, which gt overlaps most with it.
    ious_argmax = ious.argmax(axis=1)
    # Sort all pred bbox in descending order by scores.
    sorted_indices = np.argsort(-pred_scores)

    # Step 4. Count the `tp` and `fp` of each prediction.
    # The flags that gt bboxes have been matched.
    gt_covered_flags = np.zeros(gt_bboxes.shape[0], dtype=bool)

    # Count the prediction bboxes in order of decreasing score.
    for pred_bbox_idx in sorted_indices:
        if ious_max[pred_bbox_idx] >= min_iou:
            matched_gt_idx = ious_argmax[pred_bbox_idx]
            if not gt_covered_flags[matched_gt_idx]:
                tp[pred_bbox_idx] = 1
                gt_covered_flags[matched_gt_idx] = True
            else:
                # This gt bbox has already been matched; the duplicate counts as fp.
                fp[pred_bbox_idx] = 1
        else:
            fp[pred_bbox_idx] = 1

    return tp, fp
|
|
149
|
+
|
|
150
|
+
@func.udf(
    return_type=ts.JsonType(nullable=False),
    param_types=[
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False)
    ])
def eval_detections(
        pred_bboxes: List[List[int]], pred_classes: List[int], pred_scores: List[float],
        gt_bboxes: List[List[int]], gt_classes: List[int]
) -> Dict:
    """Evaluate the detections of a single image against its ground truth, per class.

    Returns a list with one dict per class that occurs in the predictions or the ground truth, with
    keys 'min_iou', 'class', 'tp', 'fp', 'scores' and 'num_gts'. 'scores' is sorted in descending
    order, and 'tp'/'fp' are in that same order, so that entry i of all three refers to the same
    prediction. This alignment is what the `mean_ap` aggregate relies on.
    """
    min_iou = 0.5
    class_idxs = list(set(pred_classes + gt_classes))
    result: List[Dict] = []
    pred_bboxes_arr = np.asarray(pred_bboxes)
    pred_classes_arr = np.asarray(pred_classes)
    pred_scores_arr = np.asarray(pred_scores)
    gt_bboxes_arr = np.asarray(gt_bboxes)
    gt_classes_arr = np.asarray(gt_classes)
    for class_idx in class_idxs:
        pred_filter = pred_classes_arr == class_idx
        gt_filter = gt_classes_arr == class_idx
        class_pred_scores = pred_scores_arr[pred_filter]
        # the threshold is a scalar (previously passed as the list [0.5])
        tp, fp = calculate_image_tpfp(
            pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou)
        # calculate_image_tpfp() returns flags in prediction order; apply the same descending-score
        # permutation to tp/fp as to the scores so that all three stay aligned element-wise
        order = np.argsort(-class_pred_scores, kind='stable')
        result.append({
            'min_iou': min_iou, 'class': class_idx, 'tp': tp[order].tolist(), 'fp': fp[order].tolist(),
            'scores': class_pred_scores[order].tolist(), 'num_gts': gt_filter.sum().item(),
        })
    return result
|
|
182
|
+
|
|
183
|
+
@func.uda(
    update_types=[ts.JsonType()], value_type=ts.JsonType(), name='mean_ap', allows_std_agg=True, allows_window=False)
class MeanAPAggregator(func.Aggregator):
    """Combines per-image `eval_detections()` results into an average precision value per class."""

    def __init__(self):
        # class index -> the eval dicts seen for that class, across all images
        self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)

    def update(self, eval_dicts: List[Dict]) -> None:
        # null inputs are skipped, as in the other aggregators
        if eval_dicts is None:
            return
        for eval_dict in eval_dicts:
            self.class_tpfp[eval_dict['class']].append(eval_dict)

    def value(self) -> Dict:
        """Return {class index: AP}, computed over the detections of all images."""
        eps = np.finfo(np.float32).eps
        result: Dict[int, float] = {}
        for class_idx, tpfp in self.class_tpfp.items():
            tp = np.concatenate([x['tp'] for x in tpfp], axis=0)
            fp = np.concatenate([x['fp'] for x in tpfp], axis=0)
            num_gts = np.sum([x['num_gts'] for x in tpfp])
            scores = np.concatenate([np.asarray(x['scores']) for x in tpfp])
            # merge the detections of all images in order of decreasing score
            sorted_idxs = np.argsort(-scores)
            tp_cumsum = tp[sorted_idxs].cumsum()
            fp_cumsum = fp[sorted_idxs].cumsum()
            precision = tp_cumsum / np.maximum(tp_cumsum + fp_cumsum, eps)
            recall = tp_cumsum / np.maximum(num_gts, eps)

            mrec = np.hstack((0, recall, 1))
            mpre = np.hstack((0, precision, 0))
            # replace each precision with the max precision at any later position (suffix maximum)
            mpre = np.maximum.accumulate(mpre[::-1])[::-1]
            # sum precision * recall increment over the points where recall changes
            ind = np.where(mrec[1:] != mrec[:-1])[0]
            ap = np.sum((mrec[ind + 1] - mrec[ind]) * mpre[ind + 1])
            result[class_idx] = ap.item()
        return result
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import pixeltable as pxt
|
|
6
|
+
import pixeltable.exceptions as excs
|
|
7
|
+
from pixeltable import env
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pxt.udf
def chat_completions(
    prompt: str,
    model: str,
    *,
    max_tokens: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None
) -> dict:
    """Run a Fireworks completion of `prompt` with `model` and return the response as a dict.

    Only the optional sampling arguments that are not None are forwarded to the API.
    """
    initialize()
    # initialize() imports the client into its own local scope only; `fireworks` is not a global
    # of this module, so it has to be imported here as well
    import fireworks.client

    kwargs = {
        'max_tokens': max_tokens,
        'repetition_penalty': repetition_penalty,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature
    }
    kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
    return fireworks.client.Completion.create(
        model=model,
        prompt_or_messages=prompt,
        **kwargs_not_none
    ).dict()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def initialize():
    """Set the Fireworks client's API key once; later calls return immediately.

    The key is read from `config['fireworks']['api_key']` when present, otherwise from the
    FIREWORKS_API_KEY environment variable.

    Raises:
        excs.Error: if the selected source yields no key or an empty key.
    """
    global _is_fireworks_initialized
    if _is_fireworks_initialized:
        return

    _logger.info('Initializing Fireworks client.')
    config = pxt.env.Env.get().config
    key_in_config = 'fireworks' in config and 'api_key' in config['fireworks']
    api_key = config['fireworks']['api_key'] if key_in_config else os.environ.get('FIREWORKS_API_KEY')
    if api_key in (None, ''):
        raise excs.Error('Fireworks client not initialized (no API key configured).')

    # imported here rather than at module level, so importing this module doesn't require the package
    import fireworks.client as fireworks_client
    fireworks_client.api_key = api_key
    _is_fireworks_initialized = True


# module logger, shared under the 'pixeltable' name
_logger = logging.getLogger('pixeltable')
# flipped to True by initialize() once the API key has been set
_is_fireworks_initialized = False
|