pixeltable 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (139) hide show
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +520 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +373 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +113 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +187 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +61 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +88 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +27 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +413 -182
  88. pixeltable/tests/conftest.py +143 -87
  89. pixeltable/tests/test_audio.py +65 -0
  90. pixeltable/tests/test_catalog.py +27 -0
  91. pixeltable/tests/test_client.py +14 -14
  92. pixeltable/tests/test_component_view.py +372 -0
  93. pixeltable/tests/test_dataframe.py +433 -0
  94. pixeltable/tests/test_dirs.py +78 -62
  95. pixeltable/tests/test_document.py +117 -0
  96. pixeltable/tests/test_exprs.py +591 -135
  97. pixeltable/tests/test_function.py +297 -67
  98. pixeltable/tests/test_functions.py +283 -1
  99. pixeltable/tests/test_migration.py +43 -0
  100. pixeltable/tests/test_nos.py +54 -0
  101. pixeltable/tests/test_snapshot.py +208 -0
  102. pixeltable/tests/test_table.py +1085 -262
  103. pixeltable/tests/test_transactional_directory.py +42 -0
  104. pixeltable/tests/test_types.py +5 -11
  105. pixeltable/tests/test_video.py +149 -34
  106. pixeltable/tests/test_view.py +530 -0
  107. pixeltable/tests/utils.py +186 -45
  108. pixeltable/tool/create_test_db_dump.py +149 -0
  109. pixeltable/type_system.py +490 -126
  110. pixeltable/utils/__init__.py +17 -46
  111. pixeltable/utils/clip.py +12 -15
  112. pixeltable/utils/coco.py +136 -0
  113. pixeltable/utils/documents.py +39 -0
  114. pixeltable/utils/filecache.py +195 -0
  115. pixeltable/utils/help.py +11 -0
  116. pixeltable/utils/media_store.py +76 -0
  117. pixeltable/utils/parquet.py +126 -0
  118. pixeltable/utils/pytorch.py +172 -0
  119. pixeltable/utils/s3.py +13 -0
  120. pixeltable/utils/sql.py +17 -0
  121. pixeltable/utils/transactional_directory.py +35 -0
  122. pixeltable-0.2.0.dist-info/LICENSE +18 -0
  123. pixeltable-0.2.0.dist-info/METADATA +117 -0
  124. pixeltable-0.2.0.dist-info/RECORD +125 -0
  125. {pixeltable-0.1.1.dist-info → pixeltable-0.2.0.dist-info}/WHEEL +1 -1
  126. pixeltable/catalog.py +0 -1421
  127. pixeltable/exprs.py +0 -1745
  128. pixeltable/function.py +0 -269
  129. pixeltable/functions/clip.py +0 -10
  130. pixeltable/functions/pil/__init__.py +0 -23
  131. pixeltable/functions/tf.py +0 -21
  132. pixeltable/index.py +0 -57
  133. pixeltable/tests/test_dict.py +0 -24
  134. pixeltable/tests/test_tf.py +0 -69
  135. pixeltable/tf.py +0 -33
  136. pixeltable/utils/tf.py +0 -33
  137. pixeltable/utils/video.py +0 -32
  138. pixeltable-0.1.1.dist-info/METADATA +0 -31
  139. pixeltable-0.1.1.dist-info/RECORD +0 -36
pixeltable/func/udf.py ADDED
@@ -0,0 +1,163 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import List, Callable, Optional, overload, Any
5
+
6
+ import pixeltable as pxt
7
+ import pixeltable.exceptions as excs
8
+ import pixeltable.type_system as ts
9
+ from .batched_function import ExplicitBatchedFunction
10
+ from .callable_function import CallableFunction
11
+ from .expr_template_function import ExprTemplateFunction
12
+ from .function import Function
13
+ from .function_registry import FunctionRegistry
14
+ from .signature import Signature
15
+
16
+
17
+ # Decorator invoked without parentheses: @pxt.udf
18
@overload
def udf(decorated_fn: Callable) -> Function: ...


# Decorator invoked with parentheses: @pxt.udf(**kwargs)
@overload
def udf(
        *,
        return_type: Optional[ts.ColumnType] = None,
        param_types: Optional[List[ts.ColumnType]] = None,
        batch_size: Optional[int] = None,
        substitute_fn: Optional[Callable] = None,
        _force_stored: bool = False
) -> Callable: ...


def udf(*args, **kwargs):
    """A decorator to create a Function from a function definition.

    Usable either bare (``@pxt.udf``) or with keyword arguments
    (``@pxt.udf(return_type=..., param_types=..., batch_size=...,
    substitute_fn=...)``). In the bare form the decorated function is passed
    straight to `make_function` with defaults; in the keyword form a decorator
    is returned that forwards the supplied options to `make_function`.

    Examples:
        >>> @pxt.udf
        ... def my_function(x: int) -> int:
        ...    return x + 1

        >>> @pxt.udf(param_types=[pxt.IntType()], return_type=pxt.IntType())
        ... def my_function(x):
        ...    return x + 1
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # Decorator invoked without parentheses: @pxt.udf
        # Simply call make_function with defaults.
        return make_function(decorated_fn=args[0])

    # Decorator invoked with parentheses: @pxt.udf(**kwargs)
    # Create a decorator for the specified options.
    return_type = kwargs.pop('return_type', None)
    param_types = kwargs.pop('param_types', None)
    batch_size = kwargs.pop('batch_size', None)
    # The documented keyword is `substitute_fn` (see the overload above). Earlier code
    # only looked for 'py_fn', which silently dropped `substitute_fn=`; 'py_fn' is
    # still honoured as a fallback so existing callers keep working.
    substitute_fn = kwargs.pop('substitute_fn', None)
    legacy_substitute_fn = kwargs.pop('py_fn', None)
    if substitute_fn is None:
        substitute_fn = legacy_substitute_fn
    force_stored = kwargs.pop('_force_stored', False)

    def decorator(decorated_fn: Callable):
        return make_function(
            decorated_fn, return_type, param_types, batch_size, substitute_fn=substitute_fn,
            force_stored=force_stored)

    return decorator
68
+
69
+
70
def make_function(
        decorated_fn: Callable,
        return_type: Optional[ts.ColumnType] = None,
        param_types: Optional[List[ts.ColumnType]] = None,
        batch_size: Optional[int] = None,
        substitute_fn: Optional[Callable] = None,
        function_name: Optional[str] = None,
        force_stored: bool = False
) -> Function:
    """
    Build a `CallableFunction` (no `batch_size`) or an `ExplicitBatchedFunction`
    (`batch_size` given) from `decorated_fn`.

    When `substitute_fn` is supplied, `decorated_fn` contributes only its
    signature and `substitute_fn` is what gets executed; this is only allowed
    for functions that have a module path. Functions with a module path are
    also registered with the `FunctionRegistry`.

    Raises:
        excs.Error: if `batch_size` and the batched-ness of the signature
            disagree, or if `substitute_fn` is used without a module path.
    """
    # A module path exists only for identifier-named functions outside '__main__',
    # and never when the caller forces the function to be stored in the db.
    has_module_path = (
        not force_stored
        and decorated_fn.__module__ != '__main__'
        and decorated_fn.__name__.isidentifier()
    )
    function_path = f'{decorated_fn.__module__}.{decorated_fn.__qualname__}' if has_module_path else None

    if function_name is None:
        function_name = decorated_fn.__name__

    # Name shown in error messages: the full path when there is one
    errmsg_name = function_path if function_path is not None else function_name

    sig = Signature.create(decorated_fn, param_types, return_type)

    # TODO: remove 'Python' from the error messages when we have full inference with Annotated types
    is_batch_requested = batch_size is not None
    if is_batch_requested:
        # batched functions must have a batched return type and at least one batched parameter
        if not sig.is_batched:
            raise excs.Error(f'{errmsg_name}(): batch_size is specified; Python return type must be a `Batch`')
        if len(sig.batched_parameters) == 0:
            raise excs.Error(f'{errmsg_name}(): batch_size is specified; at least one Python parameter must be `Batch`')
    elif len(sig.batched_parameters) > 0:
        raise excs.Error(f'{errmsg_name}(): batched parameters in udf, but no `batch_size` given')

    py_fn = decorated_fn
    if substitute_fn is not None:
        if function_path is None:
            raise excs.Error(f'{errmsg_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
        py_fn = substitute_fn

    if is_batch_requested:
        result = ExplicitBatchedFunction(
            signature=sig, batch_size=batch_size, invoker_fn=py_fn, self_path=function_path)
    else:
        result = CallableFunction(signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name)

    # If this function is part of a module, register it
    if function_path is not None:
        FunctionRegistry.get().register_function(function_path, result)

    return result
130
+
131
@overload
def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...


@overload
def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable: ...


def expr_udf(*args: Any, **kwargs: Any) -> Any:
    """Decorator that turns an expression-building function into an `ExprTemplateFunction`.

    Usable bare (``@expr_udf``) or as ``@expr_udf(param_types=[...])``. The
    decorated function is called once, at decoration time, with one
    `exprs.Variable` per parameter; the `Expr` it returns becomes the template.
    """
    def build_template_function(
            py_fn: Callable, param_types: Optional[List[ts.ColumnType]]
    ) -> ExprTemplateFunction:
        # only a named function that lives in a module gets a path
        is_named_module_fn = py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier()
        function_path = f'{py_fn.__module__}.{py_fn.__qualname__}' if is_named_module_fn else None

        sig = Signature.create(py_fn, param_types=param_types, return_type=None)
        # TODO: verify that the inferred return type matches that of the template
        # TODO: verify that the signature doesn't contain batched parameters

        # one Variable per declared parameter, then invoke the function on them
        # to obtain the parameterized Expr
        import pixeltable.exprs as exprs
        placeholders = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
        template = py_fn(*placeholders)
        assert isinstance(template, exprs.Expr)
        py_sig = inspect.signature(py_fn)
        return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)

    if len(args) != 1:
        # parenthesized form: the only accepted option is `param_types`
        assert len(args) == 0 and len(kwargs) == 1 and 'param_types' in kwargs
        return lambda py_fn: build_template_function(py_fn, kwargs['param_types'])

    # bare form: the single positional argument is the decorated function
    assert len(kwargs) == 0 and callable(args[0])
    return build_template_function(args[0], None)
@@ -1,89 +1,57 @@
1
- import os
2
- from typing import Callable, List, Optional, Union
3
- import inspect
4
- from pathlib import Path
5
1
  import tempfile
2
+ from pathlib import Path
3
+ from typing import Optional, Union
6
4
 
7
- import PIL, cv2
5
+ import PIL.Image
6
+ import av
7
+ import av.container
8
+ import av.stream
8
9
  import numpy as np
9
10
 
10
- from pixeltable.type_system import StringType, IntType, JsonType, ColumnType, FloatType, ImageType, VideoType
11
- from pixeltable.function import Function
12
- from pixeltable import catalog
11
+ import pixeltable.env as env
12
+ import pixeltable.func as func
13
+ # import all standard function modules here so they get registered with the FunctionRegistry
14
+ import pixeltable.functions.pil.image
13
15
  from pixeltable import exprs
14
- from pixeltable import env
15
- import pixeltable.exceptions as exc
16
-
17
-
18
- def udf_call(eval_fn: Callable, return_type: ColumnType, tbl: Optional[catalog.Table]) -> exprs.FunctionCall:
19
- """
20
- Interprets eval_fn's parameters to be references to columns in 'tbl' and construct ColumnRefs as args.
21
- """
22
- params = inspect.signature(eval_fn).parameters
23
- if len(params) > 0 and tbl is None:
24
- raise exc.OperationalError(f'udf_call() is missing tbl parameter')
25
- args: List[exprs.ColumnRef] = []
26
- for param_name in params:
27
- if param_name not in tbl.cols_by_name:
28
- raise exc.OperationalError(
29
- (f'udf_call(): lambda argument names need to be valid column names in table {tbl.name}: '
30
- f'column {param_name} unknown'))
31
- args.append(exprs.ColumnRef(tbl.cols_by_name[param_name]))
32
- fn = Function(return_type, [arg.col_type for arg in args], eval_fn=eval_fn)
33
- return exprs.FunctionCall(fn, args)
16
+ from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
17
+ # automatically import all submodules so that the udfs get registered
18
+ from . import image, string, video, openai, together, fireworks, huggingface
34
19
 
20
+ # TODO: remove and replace calls with astype()
35
21
def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
    """Retype `expr` by overwriting its `col_type` with `target_type`.

    The expression is modified in place and the very same object is returned;
    no new expression is created.
    """
    expr.col_type = target_type
    return expr
38
24
 
39
- dict_map = Function(IntType(), [StringType(), JsonType()], eval_fn=lambda s, d: d[s])
40
-
41
-
42
- class SumAggregator:
25
+ @func.uda(
26
+ update_types=[IntType()], value_type=IntType(), name='sum', allows_window=True, requires_order_by=False)
27
+ class SumAggregator(func.Aggregator):
43
28
  def __init__(self):
44
29
  self.sum: Union[int, float] = 0
45
- @classmethod
46
- def make_aggregator(cls) -> 'SumAggregator':
47
- return cls()
48
30
  def update(self, val: Union[int, float]) -> None:
49
31
  if val is not None:
50
32
  self.sum += val
51
33
  def value(self) -> Union[int, float]:
52
34
  return self.sum
53
35
 
54
- sum = Function(
55
- IntType(), [IntType()],
56
- module_name='pixeltable.functions',
57
- init_symbol='SumAggregator.make_aggregator',
58
- update_symbol='SumAggregator.update',
59
- value_symbol='SumAggregator.value')
60
36
 
61
- class CountAggregator:
37
+ @func.uda(
38
+ update_types=[IntType()], value_type=IntType(), name='count', allows_window = True, requires_order_by = False)
39
+ class CountAggregator(func.Aggregator):
62
40
  def __init__(self):
63
41
  self.count = 0
64
- @classmethod
65
- def make_aggregator(cls) -> 'CountAggregator':
66
- return cls()
67
42
  def update(self, val: int) -> None:
68
43
  if val is not None:
69
44
  self.count += 1
70
45
  def value(self) -> int:
71
46
  return self.count
72
47
 
73
- count = Function(
74
- IntType(), [IntType()],
75
- module_name = 'pixeltable.functions',
76
- init_symbol = 'CountAggregator.make_aggregator',
77
- update_symbol = 'CountAggregator.update',
78
- value_symbol = 'CountAggregator.value')
79
48
 
80
- class MeanAggregator:
49
+ @func.uda(
50
+ update_types=[IntType()], value_type=FloatType(), name='mean', allows_window=False, requires_order_by=False)
51
+ class MeanAggregator(func.Aggregator):
81
52
  def __init__(self):
82
53
  self.sum = 0
83
54
  self.count = 0
84
- @classmethod
85
- def make_aggregator(cls) -> 'MeanAggregator':
86
- return cls()
87
55
  def update(self, val: int) -> None:
88
56
  if val is not None:
89
57
  self.sum += val
@@ -93,54 +61,35 @@ class MeanAggregator:
93
61
  return None
94
62
  return self.sum / self.count
95
63
 
96
- mean = Function(
97
- FloatType(), [IntType()],
98
- module_name = 'pixeltable.functions',
99
- init_symbol = 'MeanAggregator.make_aggregator',
100
- update_symbol = 'MeanAggregator.update',
101
- value_symbol = 'MeanAggregator.value')
102
64
 
103
- class VideoAggregator:
104
- def __init__(self):
105
- self.video_writer = None
106
- self.size = None
107
-
108
- @classmethod
109
- def make_aggregator(cls) -> 'VideoAggregator':
110
- return cls()
111
-
112
- def update(self, frame_idx: int, frame: PIL.Image.Image) -> None:
113
- if self.video_writer is None:
114
- self.size = (frame.width, frame.height)
115
- self.out_file = Path(os.getcwd()) / f'{Path(tempfile.mktemp()).name}.mp4'
116
- self.tmp_file = Path(os.getcwd()) / f'{Path(tempfile.mktemp()).name}.mp4'
117
- self.video_writer = cv2.VideoWriter(str(self.tmp_file), cv2.VideoWriter_fourcc(*'MP4V'), 25, self.size)
118
-
119
- frame_array = np.array(frame)
120
- frame_array = cv2.cvtColor(frame_array, cv2.COLOR_RGB2BGR)
121
- self.video_writer.write(frame_array)
65
+ @func.uda(
66
+ init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(), name='make_video',
67
+ requires_order_by=True, allows_window=False)
68
+ class VideoAggregator(func.Aggregator):
69
    def __init__(self, fps: int = 25):
        """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video

        Args:
            fps: frame rate handed to the output stream when it is created.
        """
        # Both stay None until update() sees the first non-None frame.
        self.container: Optional[av.container.OutputContainer] = None
        self.stream: Optional[av.stream.Stream] = None
        self.fps = fps
74
+
75
+ def update(self, frame: PIL.Image.Image) -> None:
76
+ if frame is None:
77
+ return
78
+ if self.container is None:
79
+ (_, output_filename) = tempfile.mkstemp(suffix='.mp4', dir=str(env.Env.get().tmp_dir))
80
+ self.out_file = Path(output_filename)
81
+ self.container = av.open(str(self.out_file), mode='w')
82
+ self.stream = self.container.add_stream('h264', rate=self.fps)
83
+ self.stream.pix_fmt = 'yuv420p'
84
+ self.stream.width = frame.width
85
+ self.stream.height = frame.height
86
+
87
+ av_frame = av.VideoFrame.from_ndarray(np.array(frame.convert('RGB')), format='rgb24')
88
+ for packet in self.stream.encode(av_frame):
89
+ self.container.mux(packet)
122
90
 
123
91
  def value(self) -> str:
124
- self.video_writer.release()
125
- os.system(f'ffmpeg -i {self.tmp_file} -vcodec libx264 {self.out_file}')
126
- os.remove(self.tmp_file)
127
- return self.out_file
128
-
129
- make_video = Function(
130
- VideoType(), [IntType(), ImageType()], # params: frame_idx, frame
131
- order_by=[0], # update() wants frames in frame_idx order
132
- module_name = 'pixeltable.functions',
133
- init_symbol = 'VideoAggregator.make_aggregator',
134
- update_symbol = 'VideoAggregator.update',
135
- value_symbol = 'VideoAggregator.value')
136
-
137
-
138
- __all__ = [
139
- udf_call,
140
- cast,
141
- dict_map,
142
- sum,
143
- count,
144
- mean,
145
- make_video
146
- ]
92
+ for packet in self.stream.encode():
93
+ self.container.mux(packet)
94
+ self.container.close()
95
+ return str(self.out_file)
@@ -0,0 +1,216 @@
1
+ from __future__ import annotations
2
+ from typing import List, Tuple, Dict
3
+ from collections import defaultdict
4
+ import sys
5
+
6
+ import numpy as np
7
+
8
+ import pixeltable.type_system as ts
9
+ import pixeltable.func as func
10
+
11
+
12
+ # TODO: figure out a better submodule structure
13
+
14
+ # the following function has been adapted from MMEval
15
+ # (sources at https://github.com/open-mmlab/mmeval)
16
+ # Copyright (c) OpenMMLab. All rights reserved.
17
def calculate_bboxes_area(bboxes: np.ndarray) -> np.ndarray:
    """Return the area of each box in `bboxes`.

    Args:
        bboxes (numpy.ndarray): Boxes in 'xyxy' format, shape (n, 4) or (4, ).
    Returns:
        numpy.ndarray: Width times height of every box (no clamping is applied).
    """
    x_min, y_min = bboxes[..., 0], bboxes[..., 1]
    x_max, y_max = bboxes[..., 2], bboxes[..., 3]
    return (x_max - x_min) * (y_max - y_min)
29
+
30
+ # the following function has been adapted from MMEval
31
+ # (sources at https://github.com/open-mmlab/mmeval)
32
+ # Copyright (c) OpenMMLab. All rights reserved.
33
def calculate_overlaps(bboxes1: np.ndarray, bboxes2: np.ndarray) -> np.ndarray:
    """Compute the pairwise IoU between two sets of boxes.

    Args:
        bboxes1 (numpy.ndarray): Boxes in 'xyxy' format, shape (n, 4).
        bboxes2 (numpy.ndarray): Boxes in 'xyxy' format, shape (k, 4).
    Returns:
        numpy.ndarray: float32 array of shape (n, k); entry [i, j] is the IoU
        of bboxes1[i] and bboxes2[j]. All zeros if either set is empty.
    """
    boxes_a = bboxes1.astype(np.float32)
    boxes_b = bboxes2.astype(np.float32)
    num_a, num_b = boxes_a.shape[0], boxes_b.shape[0]

    if num_a * num_b == 0:
        return np.zeros((num_a, num_b), dtype=np.float32)

    area_a = calculate_bboxes_area(boxes_a)
    area_b = calculate_bboxes_area(boxes_b)
    eps = np.finfo(np.float32).eps

    # Broadcast every box of the first set (rows) against every box of the second (columns).
    x_start = np.maximum(boxes_a[:, None, 0], boxes_b[None, :, 0])
    y_start = np.maximum(boxes_a[:, None, 1], boxes_b[None, :, 1])
    x_end = np.minimum(boxes_a[:, None, 2], boxes_b[None, :, 2])
    y_end = np.minimum(boxes_a[:, None, 3], boxes_b[None, :, 3])

    # Disjoint boxes yield a negative extent, which is clamped to an empty intersection.
    intersection = np.maximum(x_end - x_start, 0) * np.maximum(y_end - y_start, 0)

    # eps keeps the division well-defined when the union is zero or negative
    union = np.maximum(area_a[:, None] + area_b[None, :] - intersection, eps)
    return intersection / union
77
+
78
+
79
+ # the following function has been adapted from MMEval
80
+ # (sources at https://github.com/open-mmlab/mmeval)
81
+ # Copyright (c) OpenMMLab. All rights reserved.
82
def calculate_image_tpfp(
        pred_bboxes: np.ndarray, pred_scores: np.ndarray, gt_bboxes: np.ndarray, min_iou: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Flag every predicted box of one image as a true or a false positive.

    Predictions are visited in order of decreasing score. A prediction is a
    true positive when its best-overlapping ground-truth box reaches `min_iou`
    and that ground-truth box has not already been claimed by a higher-scoring
    prediction; otherwise it is a false positive.

    Args:
        pred_bboxes (numpy.ndarray): Predicted boxes of this image; the first
            four columns are read as 'xyxy' coordinates.
        pred_scores (numpy.ndarray): Score of each predicted box, shape (N,).
        gt_bboxes (numpy.ndarray): Ground truth boxes of this image, shape (M, 4).
        min_iou (float): The IoU threshold.

    Returns:
        tuple (tp, fp):

        - tp (numpy.ndarray): Shape (N,), int8 true positive flag per predicted
          box, in the same order as `pred_bboxes`.
        - fp (numpy.ndarray): Shape (N,), int8 false positive flag per predicted
          box, in the same order as `pred_bboxes`.
    """
    num_preds = pred_bboxes.shape[0]
    num_gts = gt_bboxes.shape[0]
    tp = np.zeros(num_preds, dtype=np.int8)
    fp = np.zeros(num_preds, dtype=np.int8)

    # Without any ground truth, every prediction is a false positive.
    if num_gts == 0:
        fp[...] = 1
        return tp, fp

    # For each prediction: its highest IoU with any gt box, and which gt box that is.
    ious = calculate_overlaps(pred_bboxes, gt_bboxes)
    best_iou = ious.max(axis=1)
    best_gt = ious.argmax(axis=1)

    # Tracks which gt boxes have already been matched by an earlier prediction.
    gt_claimed = np.zeros(num_gts, dtype=bool)

    for pred_idx in np.argsort(-pred_scores):
        gt_idx = best_gt[pred_idx]
        if best_iou[pred_idx] >= min_iou and not gt_claimed[gt_idx]:
            tp[pred_idx] = 1
            gt_claimed[gt_idx] = True
        else:
            # either the overlap is too small or the gt box was already claimed
            fp[pred_idx] = 1

    return tp, fp
149
+
150
@func.udf(
    return_type=ts.JsonType(nullable=False),
    param_types=[
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False)
    ])
def eval_detections(
        pred_bboxes: List[List[int]], pred_classes: List[int], pred_scores: List[float],
        gt_bboxes: List[List[int]], gt_classes: List[int]
) -> Dict:
    """Evaluate the detections of one image against its ground truth, per class.

    For every class present in either the predictions or the ground truth, the
    predictions of that class are matched against the ground truth boxes of that
    class at an IoU threshold of 0.5.

    Returns:
        A list with one dict per class, with keys 'min_iou', 'class', 'tp', 'fp',
        'scores' and 'num_gts'. 'tp', 'fp' and 'scores' are parallel lists: entry
        i of each describes the same prediction.
    """
    # NOTE(review): the return annotation says Dict although a list of dicts is
    # returned; it is left untouched because the annotation may take part in
    # signature inference -- confirm before changing.
    min_iou = 0.5
    class_idxs = list(set(pred_classes + gt_classes))
    result: List[Dict] = []
    pred_bboxes_arr = np.asarray(pred_bboxes)
    pred_classes_arr = np.asarray(pred_classes)
    pred_scores_arr = np.asarray(pred_scores)
    gt_bboxes_arr = np.asarray(gt_bboxes)
    gt_classes_arr = np.asarray(gt_classes)
    for class_idx in class_idxs:
        pred_filter = pred_classes_arr == class_idx
        gt_filter = gt_classes_arr == class_idx
        class_pred_scores = pred_scores_arr[pred_filter]
        # pass the threshold as a scalar, matching calculate_image_tpfp's `min_iou: float`
        tp, fp = calculate_image_tpfp(
            pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou)
        result.append({
            'min_iou': min_iou, 'class': class_idx, 'tp': tp.tolist(), 'fp': fp.tolist(),
            # tp/fp are indexed by prediction position, so the scores must keep that
            # same order; sorting them here would pair each score with another
            # prediction's flags once a consumer argsorts by score.
            'scores': class_pred_scores.tolist(), 'num_gts': gt_filter.sum().item(),
        })
    return result
182
+
183
@func.uda(
    update_types=[ts.JsonType()], value_type=ts.JsonType(), name='mean_ap', allows_std_agg=True, allows_window=False)
class MeanAPAggregator:
    """Aggregates per-image detection evaluations into one average precision per class.

    Each update receives a list of per-class dicts with the keys 'class', 'tp',
    'fp', 'scores' and 'num_gts'; value() combines everything collected for a
    class into a single AP number.
    """

    def __init__(self):
        # class index -> every evaluation dict seen for that class
        self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)

    def update(self, eval_dicts: List[Dict]) -> None:
        """File each evaluation dict under its 'class' entry."""
        for eval_dict in eval_dicts:
            class_idx = eval_dict['class']
            self.class_tpfp[class_idx].append(eval_dict)

    def value(self) -> Dict:
        """Return a dict mapping each class index to its average precision."""
        eps = np.finfo(np.float32).eps
        result: Dict[int, float] = {}
        for class_idx, tpfp in self.class_tpfp.items():
            tp = np.concatenate([x['tp'] for x in tpfp], axis=0)
            fp = np.concatenate([x['fp'] for x in tpfp], axis=0)
            num_gts = np.sum([x['num_gts'] for x in tpfp])
            scores = np.concatenate([np.asarray(x['scores']) for x in tpfp])

            # rank all predictions of this class by decreasing score
            sorted_idxs = np.argsort(-scores)
            tp_cumsum = tp[sorted_idxs].cumsum()
            fp_cumsum = fp[sorted_idxs].cumsum()
            # eps guards both divisions against a zero denominator
            precision = tp_cumsum / np.maximum(tp_cumsum + fp_cumsum, eps)
            recall = tp_cumsum / np.maximum(num_gts, eps)

            # pad the curves with sentinel end points
            mrec = np.hstack((0, recall, 1))
            mpre = np.hstack((0, precision, 0))
            # make precision non-increasing from left to right: each entry becomes
            # the maximum of itself and everything to its right
            mpre = np.maximum.accumulate(mpre[::-1])[::-1]
            # sum the rectangle areas at the points where recall changes
            ind = np.where(mrec[1:] != mrec[:-1])[0]
            ap = np.sum((mrec[ind + 1] - mrec[ind]) * mpre[ind + 1])
            result[class_idx] = ap.item()
        return result
@@ -0,0 +1,61 @@
1
+ import logging
2
+ import os
3
+ from typing import Optional
4
+
5
+ import pixeltable as pxt
6
+ import pixeltable.exceptions as excs
7
+ from pixeltable import env
8
+
9
+
10
@pxt.udf
def chat_completions(
        prompt: str,
        model: str,
        *,
        max_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: Optional[float] = None
) -> dict:
    """Request a completion for `prompt` from the Fireworks model `model`.

    Only the optional sampling parameters that are not None are forwarded to
    the Fireworks client. The response object is returned as a dict.

    Raises:
        excs.Error: raised by `initialize()` when no API key is configured.
    """
    initialize()
    # `initialize()` imports the SDK into its own local scope only, so the name
    # `fireworks` has to be bound here as well before it can be used.
    import fireworks.client

    kwargs = {
        'max_tokens': max_tokens,
        'repetition_penalty': repetition_penalty,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature
    }
    kwargs_not_none = {name: value for name, value in kwargs.items() if value is not None}
    # diagnostic output goes to the logger rather than stdout
    _logger.debug('Fireworks completion options: %s', kwargs_not_none)
    return fireworks.client.Completion.create(
        model=model,
        prompt_or_messages=prompt,
        **kwargs_not_none
    ).dict()
36
+
37
+
38
def initialize():
    """Install the Fireworks API key on the client module; does nothing after the first success.

    The key is read from the `fireworks.api_key` entry of the Pixeltable config
    when present, otherwise from the FIREWORKS_API_KEY environment variable.

    Raises:
        excs.Error: if the resulting key is missing or empty.
    """
    global _is_fireworks_initialized
    if _is_fireworks_initialized:
        return

    _logger.info('Initializing Fireworks client.')

    config = pxt.env.Env.get().config

    # the config entry wins over the environment variable
    key_is_configured = 'fireworks' in config and 'api_key' in config['fireworks']
    api_key = config['fireworks']['api_key'] if key_is_configured else os.environ.get('FIREWORKS_API_KEY')
    if api_key is None or api_key == '':
        raise excs.Error('Fireworks client not initialized (no API key configured).')

    # imported only once a key is known to exist
    import fireworks.client

    fireworks.client.api_key = api_key
    _is_fireworks_initialized = True
58
+
59
+
60
# Logger registered under the name 'pixeltable'; used by initialize() in this module.
_logger = logging.getLogger('pixeltable')
# Flipped to True by initialize() after the API key has been set on the Fireworks
# client; while False, the next initialize() call performs the setup.
_is_fireworks_initialized = False