pixeltable 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (119):
  1. pixeltable/__init__.py +53 -0
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/__init__.py +13 -0
  4. pixeltable/catalog/catalog.py +159 -0
  5. pixeltable/catalog/column.py +181 -0
  6. pixeltable/catalog/dir.py +32 -0
  7. pixeltable/catalog/globals.py +33 -0
  8. pixeltable/catalog/insertable_table.py +192 -0
  9. pixeltable/catalog/named_function.py +36 -0
  10. pixeltable/catalog/path.py +58 -0
  11. pixeltable/catalog/path_dict.py +139 -0
  12. pixeltable/catalog/schema_object.py +39 -0
  13. pixeltable/catalog/table.py +695 -0
  14. pixeltable/catalog/table_version.py +1026 -0
  15. pixeltable/catalog/table_version_path.py +133 -0
  16. pixeltable/catalog/view.py +203 -0
  17. pixeltable/dataframe.py +749 -0
  18. pixeltable/env.py +466 -0
  19. pixeltable/exceptions.py +17 -0
  20. pixeltable/exec/__init__.py +10 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +94 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +73 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +226 -0
  31. pixeltable/exprs/__init__.py +25 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +114 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +199 -0
  39. pixeltable/exprs/expr.py +594 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +382 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +96 -0
  44. pixeltable/exprs/in_predicate.py +96 -0
  45. pixeltable/exprs/inline_array.py +109 -0
  46. pixeltable/exprs/inline_dict.py +103 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +66 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +329 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/similarity_expr.py +65 -0
  56. pixeltable/exprs/type_cast.py +53 -0
  57. pixeltable/exprs/variable.py +45 -0
  58. pixeltable/ext/__init__.py +5 -0
  59. pixeltable/ext/functions/yolox.py +92 -0
  60. pixeltable/func/__init__.py +7 -0
  61. pixeltable/func/aggregate_function.py +197 -0
  62. pixeltable/func/callable_function.py +113 -0
  63. pixeltable/func/expr_template_function.py +99 -0
  64. pixeltable/func/function.py +141 -0
  65. pixeltable/func/function_registry.py +227 -0
  66. pixeltable/func/globals.py +46 -0
  67. pixeltable/func/nos_function.py +202 -0
  68. pixeltable/func/signature.py +162 -0
  69. pixeltable/func/udf.py +164 -0
  70. pixeltable/functions/__init__.py +95 -0
  71. pixeltable/functions/eval.py +215 -0
  72. pixeltable/functions/fireworks.py +34 -0
  73. pixeltable/functions/huggingface.py +167 -0
  74. pixeltable/functions/image.py +16 -0
  75. pixeltable/functions/openai.py +289 -0
  76. pixeltable/functions/pil/image.py +147 -0
  77. pixeltable/functions/string.py +13 -0
  78. pixeltable/functions/together.py +143 -0
  79. pixeltable/functions/util.py +52 -0
  80. pixeltable/functions/video.py +62 -0
  81. pixeltable/globals.py +425 -0
  82. pixeltable/index/__init__.py +2 -0
  83. pixeltable/index/base.py +51 -0
  84. pixeltable/index/embedding_index.py +168 -0
  85. pixeltable/io/__init__.py +3 -0
  86. pixeltable/io/hf_datasets.py +188 -0
  87. pixeltable/io/pandas.py +148 -0
  88. pixeltable/io/parquet.py +192 -0
  89. pixeltable/iterators/__init__.py +3 -0
  90. pixeltable/iterators/base.py +52 -0
  91. pixeltable/iterators/document.py +432 -0
  92. pixeltable/iterators/video.py +88 -0
  93. pixeltable/metadata/__init__.py +58 -0
  94. pixeltable/metadata/converters/convert_10.py +18 -0
  95. pixeltable/metadata/converters/convert_12.py +3 -0
  96. pixeltable/metadata/converters/convert_13.py +41 -0
  97. pixeltable/metadata/schema.py +234 -0
  98. pixeltable/plan.py +620 -0
  99. pixeltable/store.py +424 -0
  100. pixeltable/tool/create_test_db_dump.py +184 -0
  101. pixeltable/tool/create_test_video.py +81 -0
  102. pixeltable/type_system.py +846 -0
  103. pixeltable/utils/__init__.py +17 -0
  104. pixeltable/utils/arrow.py +98 -0
  105. pixeltable/utils/clip.py +18 -0
  106. pixeltable/utils/coco.py +136 -0
  107. pixeltable/utils/documents.py +69 -0
  108. pixeltable/utils/filecache.py +195 -0
  109. pixeltable/utils/help.py +11 -0
  110. pixeltable/utils/http_server.py +70 -0
  111. pixeltable/utils/media_store.py +76 -0
  112. pixeltable/utils/pytorch.py +91 -0
  113. pixeltable/utils/s3.py +13 -0
  114. pixeltable/utils/sql.py +17 -0
  115. pixeltable/utils/transactional_directory.py +35 -0
  116. pixeltable-0.0.0.dist-info/LICENSE +18 -0
  117. pixeltable-0.0.0.dist-info/METADATA +131 -0
  118. pixeltable-0.0.0.dist-info/RECORD +119 -0
  119. pixeltable-0.0.0.dist-info/WHEEL +4 -0
pixeltable/func/udf.py ADDED
@@ -0,0 +1,164 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import List, Callable, Optional, overload, Any
5
+
6
+ import pixeltable as pxt
7
+ import pixeltable.exceptions as excs
8
+ import pixeltable.type_system as ts
9
+ from .callable_function import CallableFunction
10
+ from .expr_template_function import ExprTemplateFunction
11
+ from .function import Function
12
+ from .function_registry import FunctionRegistry
13
+ from .globals import validate_symbol_path
14
+ from .signature import Signature
15
+
16
+
17
# Decorator invoked without parentheses: @pxt.udf
@overload
def udf(decorated_fn: Callable) -> Function:
    """Typing stub for the bare form: `@pxt.udf` applied directly to a function yields a `Function`."""
20
+
21
+
22
# Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
@overload
def udf(
    *,
    return_type: Optional[ts.ColumnType] = None,
    param_types: Optional[List[ts.ColumnType]] = None,
    batch_size: Optional[int] = None,
    substitute_fn: Optional[Callable] = None,
    _force_stored: bool = False,
) -> Callable:
    """Typing stub for the parameterized form: `@pxt.udf(...)` returns a decorator to apply to a function."""
32
+
33
+
34
def udf(*args, **kwargs):
    """A decorator to create a Function from a function definition.

    Supports a bare form (``@pxt.udf``) and a parameterized form
    (``@pxt.udf(return_type=..., param_types=..., batch_size=..., substitute_fn=...)``).

    Examples:
        >>> @pxt.udf
        ... def my_function(x: int) -> int:
        ...     return x + 1

        >>> @pxt.udf(param_types=[pxt.IntType()], return_type=pxt.IntType())
        ... def my_function(x):
        ...     return x + 1

    Raises:
        excs.Error: if the parameterized form receives positional arguments or
            keyword arguments it does not recognize.
    """
    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # Decorator invoked without parentheses: @pxt.udf
        # Simply call make_function with defaults.
        return make_function(decorated_fn=args[0])

    # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
    # Create a decorator for the specified schema.
    if len(args) > 0:
        raise excs.Error('Unexpected positional arguments to @udf decorator')

    return_type = kwargs.pop('return_type', None)
    param_types = kwargs.pop('param_types', None)
    batch_size = kwargs.pop('batch_size', None)
    # `substitute_fn` is the documented keyword; this implementation previously read
    # `py_fn` instead (so `substitute_fn=` was silently ignored). Accept both.
    substitute_fn = kwargs.pop('substitute_fn', None)
    legacy_substitute_fn = kwargs.pop('py_fn', None)
    if substitute_fn is None:
        substitute_fn = legacy_substitute_fn
    elif legacy_substitute_fn is not None:
        raise excs.Error('@udf decorator: specify only one of `substitute_fn` and `py_fn`')
    force_stored = kwargs.pop('_force_stored', False)

    # Anything left over is a typo or an unsupported option; don't silently ignore it.
    if len(kwargs) > 0:
        raise excs.Error(f'Invalid @udf decorator kwargs: {", ".join(kwargs.keys())}')

    def decorator(decorated_fn: Callable):
        return make_function(
            decorated_fn, return_type, param_types, batch_size,
            substitute_fn=substitute_fn, force_stored=force_stored)

    return decorator
68
+
69
+
70
def make_function(
    decorated_fn: Callable,
    return_type: Optional[ts.ColumnType] = None,
    param_types: Optional[List[ts.ColumnType]] = None,
    batch_size: Optional[int] = None,
    substitute_fn: Optional[Callable] = None,
    function_name: Optional[str] = None,
    force_stored: bool = False
) -> Function:
    """Build a `CallableFunction` for `decorated_fn`, registering it when it has a module path.

    Args:
        decorated_fn: function whose signature defines the UDF; it is also the function
            that gets executed unless `substitute_fn` is given.
        return_type: explicit return type, forwarded to `Signature.create`.
        param_types: explicit parameter types, forwarded to `Signature.create`.
        batch_size: batch size for batched UDFs; requires a `Batch` return type and at
            least one `Batch` parameter.
        substitute_fn: if given, executed in place of `decorated_fn`, whose signature is
            still used; only allowed for functions that have a module path.
        function_name: display name; defaults to `decorated_fn.__name__`.
        force_stored: if True, no module path is derived (force storing the function in
            the db), so the function is not registered with the `FunctionRegistry`.

    Returns:
        The constructed `CallableFunction`.

    Raises:
        excs.Error: on inconsistent batching declarations, or when `substitute_fn` is
            given for a function without a module path.
    """
    # A module path is only derived for named functions outside `__main__`, and never
    # when storage is forced.
    function_path: Optional[str] = None
    if (
        not force_stored
        and decorated_fn.__module__ != '__main__'
        and decorated_fn.__name__.isidentifier()
    ):
        function_path = f'{decorated_fn.__module__}.{decorated_fn.__qualname__}'

    if function_name is None:
        function_name = decorated_fn.__name__

    # Prefer the fully qualified path in error messages when there is one.
    errmsg_name = function_path if function_path is not None else function_name

    sig = Signature.create(decorated_fn, param_types, return_type)

    # batched functions must have a batched return type
    # TODO: remove 'Python' from the error messages when we have full inference with Annotated types
    if batch_size is not None:
        if not sig.is_batched:
            raise excs.Error(f'{errmsg_name}(): batch_size is specified; Python return type must be a `Batch`')
        if len(sig.batched_parameters) == 0:
            raise excs.Error(f'{errmsg_name}(): batch_size is specified; at least one Python parameter must be `Batch`')
    elif len(sig.batched_parameters) > 0:
        raise excs.Error(f'{errmsg_name}(): batched parameters in udf, but no `batch_size` given')

    if substitute_fn is not None and function_path is None:
        raise excs.Error(f'{errmsg_name}(): @udf decorator with a `substitute_fn` can only be used in a module')
    py_fn = decorated_fn if substitute_fn is None else substitute_fn

    result = CallableFunction(
        signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name, batch_size=batch_size)

    if function_path is not None:
        # Validation is deliberately the last step, so tests can exercise the other failure modes first.
        validate_symbol_path(function_path)
        FunctionRegistry.get().register_function(function_path, result)

    return result
129
+
130
@overload
def expr_udf(py_fn: Callable) -> ExprTemplateFunction:
    """Typing stub for the bare form: `@expr_udf` applied directly to a function."""
132
+
133
@overload
def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable:
    """Typing stub for the parameterized form: `@expr_udf(param_types=...)` returns a decorator."""
135
+
136
def expr_udf(*args: Any, **kwargs: Any) -> Any:
    """Decorator that turns an expression-building function into an `ExprTemplateFunction`.

    The decorated function is called once, at decoration time, with one `exprs.Variable`
    per parameter (parameters come from `Signature.create_parameters`, optionally typed by
    `param_types`). The `Expr` it returns becomes the template of the resulting function.

    Usage: ``@expr_udf``, ``@expr_udf(param_types=[...])``, or ``@expr_udf()``.

    Raises:
        excs.Error: if the decorator receives unexpected arguments, or if the decorated
            function does not return an `Expr`.
    """
    def decorator(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
        if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
            # this is a named function in a module
            function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
        else:
            function_path = None

        # TODO: verify that the inferred return type matches that of the template
        # TODO: verify that the signature doesn't contain batched parameters

        # construct Parameters from the function signature
        params = Signature.create_parameters(py_fn, param_types=param_types)
        import pixeltable.exprs as exprs
        var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
        # call the function with the parameter expressions to construct an Expr with parameters
        template = py_fn(*var_exprs)
        # explicit check instead of `assert`, which is stripped under `python -O`
        if not isinstance(template, exprs.Expr):
            raise excs.Error(
                f'{py_fn.__name__}(): @expr_udf function must return an expression, '
                f'got {type(template).__name__}')
        py_sig = inspect.signature(py_fn)
        if function_path is not None:
            validate_symbol_path(function_path)
        return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)

    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # bare decorator: @expr_udf
        return decorator(args[0], None)

    # parameterized decorator: @expr_udf(...) -- `param_types` is optional, matching the overload
    if len(args) > 0:
        raise excs.Error('Unexpected positional arguments to @expr_udf decorator')
    invalid_kwargs = [k for k in kwargs if k != 'param_types']
    if len(invalid_kwargs) > 0:
        raise excs.Error(f'Invalid @expr_udf decorator kwargs: {", ".join(invalid_kwargs)}')
    param_types = kwargs.get('param_types')
    return lambda py_fn: decorator(py_fn, param_types)
pixeltable/functions/__init__.py ADDED
@@ -0,0 +1,95 @@
1
+ import tempfile
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import PIL.Image
6
+ import av
7
+ import av.container
8
+ import av.stream
9
+ import numpy as np
10
+
11
+ import pixeltable.env as env
12
+ import pixeltable.func as func
13
+ # import all standard function modules here so they get registered with the FunctionRegistry
14
+ import pixeltable.functions.pil.image
15
+ from pixeltable import exprs
16
+ from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
17
+ # automatically import all submodules so that the udfs get registered
18
+ from . import image, string, video, huggingface
19
+
20
# TODO: remove and replace calls with astype()
def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
    """Relabel `expr` with `target_type` and hand back the same object.

    No value conversion happens: only the `col_type` attribute of `expr` is overwritten,
    in place, so the caller's expression is modified as well.
    """
    setattr(expr, 'col_type', target_type)
    return expr
24
+
25
@func.uda(
    update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
class sum(func.Aggregator):
    """Aggregate that totals its non-null inputs; the result is 0 when no non-null value was seen."""

    def __init__(self):
        self.sum: Union[int, float] = 0

    def update(self, val: Union[int, float]) -> None:
        # null inputs leave the running total untouched
        if val is None:
            return
        self.sum += val

    def value(self) -> Union[int, float]:
        return self.sum
35
+
36
+
37
@func.uda(
    update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
class count(func.Aggregator):
    """Aggregate that counts how many non-null inputs it received."""

    def __init__(self):
        self.count = 0

    def update(self, val: int) -> None:
        # nulls are not counted
        self.count += 0 if val is None else 1

    def value(self) -> int:
        return self.count
47
+
48
+
49
@func.uda(
    update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
class mean(func.Aggregator):
    """Aggregate returning the arithmetic mean of its non-null inputs, or None if there were none."""

    def __init__(self):
        self.sum = 0
        self.count = 0

    def update(self, val: int) -> None:
        if val is None:
            return  # nulls contribute neither to the total nor to the count
        self.sum += val
        self.count += 1

    def value(self) -> float:
        return None if self.count == 0 else self.sum / self.count
63
+
64
+
65
@func.uda(
    init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
    requires_order_by=True, allows_window=False)
class make_video(func.Aggregator):
    """Aggregate that encodes a sequence of images into an H.264 .mp4 file and returns its path.

    Follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video
    """

    def __init__(self, fps: int = 25):
        """
        Args:
            fps: frame rate passed to the output video stream.
        """
        self.container: Optional[av.container.OutputContainer] = None
        self.stream: Optional[av.stream.Stream] = None
        self.out_file: Optional[Path] = None
        self.fps = fps

    def update(self, frame: PIL.Image.Image) -> None:
        if frame is None:
            return
        if self.container is None:
            # Opened lazily so that the stream dimensions can be taken from the first frame.
            import os
            fd, output_filename = tempfile.mkstemp(suffix='.mp4', dir=str(env.Env.get().tmp_dir))
            # mkstemp returns an open OS-level descriptor; PyAV reopens the file by path, so
            # close ours to avoid leaking one descriptor per aggregated group.
            os.close(fd)
            self.out_file = Path(output_filename)
            self.container = av.open(str(self.out_file), mode='w')
            self.stream = self.container.add_stream('h264', rate=self.fps)
            self.stream.pix_fmt = 'yuv420p'
            self.stream.width = frame.width
            self.stream.height = frame.height

        av_frame = av.VideoFrame.from_ndarray(np.array(frame.convert('RGB')), format='rgb24')
        for packet in self.stream.encode(av_frame):
            self.container.mux(packet)

    def value(self) -> str:
        if self.container is None:
            # No non-null frame was ever received, so there is no video to produce
            # (previously this raised AttributeError on `self.stream.encode`).
            return None
        # flush any packets still buffered in the encoder
        for packet in self.stream.encode():
            self.container.mux(packet)
        self.container.close()
        return str(self.out_file)
pixeltable/functions/eval.py ADDED
@@ -0,0 +1,215 @@
1
+ from typing import List, Tuple, Dict
2
+ from collections import defaultdict
3
+ import sys
4
+
5
+ import numpy as np
6
+
7
+ import pixeltable.type_system as ts
8
+ import pixeltable.func as func
9
+
10
+
11
+ # TODO: figure out a better submodule structure
12
+
13
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_bboxes_area(bboxes: np.ndarray) -> np.ndarray:
    """Return the area of every box in `bboxes`.

    Args:
        bboxes (numpy.ndarray): boxes in 'xyxy' order along the last axis,
            e.g. shape (n, 4) or (4, ).
    Returns:
        numpy.ndarray: (x2 - x1) * (y2 - y1) for each box (negative extents are not clipped).
    """
    x_min, y_min, x_max, y_max = (bboxes[..., coord] for coord in range(4))
    return (x_max - x_min) * (y_max - y_min)
28
+
29
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_overlaps(bboxes1: np.ndarray, bboxes2: np.ndarray) -> np.ndarray:
    """Compute the pairwise IoU between the boxes of `bboxes1` and `bboxes2`.

    Args:
        bboxes1 (numpy.ndarray): boxes with shape (n, 4) in 'xyxy' format.
        bboxes2 (numpy.ndarray): boxes with shape (k, 4) in 'xyxy' format.
    Returns:
        numpy.ndarray: float32 IoUs with shape (n, k); all zeros-shaped when n or k is 0.
    """
    a = bboxes1.astype(np.float32)
    b = bboxes2.astype(np.float32)
    n, k = a.shape[0], b.shape[0]
    if n * k == 0:
        return np.zeros((n, k), dtype=np.float32)

    # Loop over the shorter side for speed; transpose back at the end.
    swapped = n > k
    if swapped:
        a, b = b, a

    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    eps = np.finfo(np.float32).eps

    ious = np.zeros((a.shape[0], b.shape[0]), dtype=np.float32)
    for row, box in enumerate(a):
        inter_w = np.maximum(np.minimum(box[2], b[:, 2]) - np.maximum(box[0], b[:, 0]), 0)
        inter_h = np.maximum(np.minimum(box[3], b[:, 3]) - np.maximum(box[1], b[:, 1]), 0)
        inter = inter_w * inter_h
        # eps guards against division by zero for degenerate boxes
        ious[row, :] = inter / np.maximum(area_a[row] + area_b - inter, eps)
    return ious.T if swapped else ious
76
+
77
+
78
# the following function has been adapted from MMEval
# (sources at https://github.com/open-mmlab/mmeval)
# Copyright (c) OpenMMLab. All rights reserved.
def calculate_image_tpfp(
    pred_bboxes: np.ndarray, pred_scores: np.ndarray, gt_bboxes: np.ndarray, min_iou: float
) -> Tuple[np.ndarray, np.ndarray]:
    """Flag each prediction on one image as a true positive or a false positive.

    Predictions are visited in order of decreasing score. Each one is compared with the
    ground-truth box it overlaps most (by IoU): it is a true positive if that IoU is at
    least `min_iou` and the ground-truth box has not already been claimed by a
    higher-scoring prediction; otherwise it is a false positive.

    Args:
        pred_bboxes (numpy.ndarray): predicted boxes of this image in 'xyxy' format,
            shape (N, 4) (only the first four columns are read).
        pred_scores (numpy.ndarray): score of each predicted box, shape (N,).
        gt_bboxes (numpy.ndarray): ground-truth boxes of this image, shape (M, 4).
        min_iou (float): the IoU threshold.

    Returns:
        tuple (tp, fp): two int8 arrays of shape (N,), holding 1 where the prediction is
        a true positive (resp. false positive) and 0 elsewhere.
    """
    n_preds = pred_bboxes.shape[0]
    tp = np.zeros(n_preds, dtype=np.int8)
    fp = np.zeros(n_preds, dtype=np.int8)

    # without any ground truth, every prediction is a false positive
    if gt_bboxes.shape[0] == 0:
        fp[...] = 1
        return tp, fp

    ious = calculate_overlaps(pred_bboxes, gt_bboxes)
    best_iou = ious.max(axis=1)      # per prediction: highest IoU with any gt box
    best_gt = ious.argmax(axis=1)    # per prediction: index of that gt box
    claimed = np.zeros(gt_bboxes.shape[0], dtype=bool)

    for idx in np.argsort(-pred_scores):
        gt_idx = best_gt[idx]
        if best_iou[idx] >= min_iou and not claimed[gt_idx]:
            tp[idx] = 1
            claimed[gt_idx] = True
        else:
            # either not enough overlap, or the gt box was already matched
            fp[idx] = 1
    return tp, fp
148
+
149
@func.udf(
    return_type=ts.JsonType(nullable=False),
    param_types=[
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False),
        ts.JsonType(nullable=False)
    ])
def eval_detections(
    pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
    gt_bboxes: List[List[int]], gt_labels: List[int]
) -> Dict:
    """Compute per-class true/false-positive flags for the detections of one image.

    Returns a list (despite the `Dict` annotation) with one dict per class that occurs in
    either the predictions or the ground truth, in ascending class order. Each dict has the keys
    'min_iou', 'class', 'tp', 'fp', 'scores' (the class's prediction scores, descending)
    and 'num_gts' (number of ground-truth boxes of that class); `mean_ap` aggregates these.
    """
    min_iou = 0.5
    # sorted so that the output order does not depend on set iteration order
    class_idxs = sorted(set(pred_labels + gt_labels))
    result: List[Dict] = []
    pred_bboxes_arr = np.asarray(pred_bboxes)
    pred_classes_arr = np.asarray(pred_labels)
    pred_scores_arr = np.asarray(pred_scores)
    gt_bboxes_arr = np.asarray(gt_bboxes)
    gt_classes_arr = np.asarray(gt_labels)
    for class_idx in class_idxs:
        pred_filter = pred_classes_arr == class_idx
        gt_filter = gt_classes_arr == class_idx
        class_pred_scores = pred_scores_arr[pred_filter]
        # pass the threshold as a scalar (it used to be passed as the list `[0.5]`,
        # which only worked through numpy broadcasting)
        tp, fp = calculate_image_tpfp(
            pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou)
        ordered_class_pred_scores = -np.sort(-class_pred_scores)
        result.append({
            'min_iou': min_iou, 'class': class_idx, 'tp': tp.tolist(), 'fp': fp.tolist(),
            'scores': ordered_class_pred_scores.tolist(), 'num_gts': gt_filter.sum().item(),
        })
    return result
181
+
182
@func.uda(
    update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
class mean_ap(func.Aggregator):
    """Aggregate that computes average precision per class from `eval_detections` outputs.

    The value is a dict mapping each class index to its average precision.
    """

    def __init__(self):
        # class index -> list of per-image eval dicts for that class
        self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)

    def update(self, eval_dicts: List[Dict]) -> None:
        # null inputs are skipped instead of raising a TypeError on iteration
        if eval_dicts is None:
            return
        for eval_dict in eval_dicts:
            self.class_tpfp[eval_dict['class']].append(eval_dict)

    def value(self) -> Dict:
        eps = np.finfo(np.float32).eps
        result: Dict[int, float] = {}
        for class_idx, tpfp in self.class_tpfp.items():
            tp = np.concatenate([x['tp'] for x in tpfp], axis=0)
            fp = np.concatenate([x['fp'] for x in tpfp], axis=0)
            num_gts = np.sum([x['num_gts'] for x in tpfp])
            scores = np.concatenate([np.asarray(x['scores']) for x in tpfp])
            # rank all predictions of this class across images by decreasing score
            sorted_idxs = np.argsort(-scores)
            tp_cumsum = tp[sorted_idxs].cumsum()
            fp_cumsum = fp[sorted_idxs].cumsum()
            precision = tp_cumsum / np.maximum(tp_cumsum + fp_cumsum, eps)
            recall = tp_cumsum / np.maximum(num_gts, eps)

            # pad with sentinels, then make precision monotonically non-increasing
            # (suffix maximum, equivalent to the backward max loop)
            mrec = np.hstack((0, recall, 1))
            mpre = np.hstack((0, precision, 0))
            mpre = np.maximum.accumulate(mpre[::-1])[::-1]
            # area under the envelope: sum precision * recall step where recall changes
            ind = np.where(mrec[1:] != mrec[:-1])[0]
            ap = np.sum((mrec[ind + 1] - mrec[ind]) * mpre[ind + 1])
            result[class_idx] = ap.item()
        return result
pixeltable/functions/fireworks.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import Optional
2
+
3
+ import fireworks.client
4
+
5
+ import pixeltable as pxt
6
+ from pixeltable import env
7
+
8
+
9
def fireworks_client() -> fireworks.client.Fireworks:
    """Obtain the Fireworks client through `env.Env.get().get_client('fireworks', ...)`.

    The factory below builds a client from an API key; when it is invoked (and whether
    the client is reused) is decided by `Env.get_client` -- presumably cached, confirm there.
    """
    def _create(api_key):
        return fireworks.client.Fireworks(api_key=api_key)

    return env.Env.get().get_client('fireworks', _create)
11
+
12
+
13
@pxt.udf
def chat_completions(
    messages: list[dict[str, str]],
    *,
    model: str,
    max_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None
) -> dict:
    """Call the Fireworks chat-completions endpoint and return the response as a dict.

    Optional sampling parameters are forwarded only when they are not None.
    """
    optional_params = {
        'max_tokens': max_tokens,
        'top_k': top_k,
        'top_p': top_p,
        'temperature': temperature,
    }
    request_kwargs = {name: val for name, val in optional_params.items() if val is not None}
    response = fireworks_client().chat.completions.create(model=model, messages=messages, **request_kwargs)
    return response.dict()
pixeltable/functions/huggingface.py ADDED
@@ -0,0 +1,167 @@
1
+ from typing import Callable, TypeVar, Optional
2
+
3
+ import PIL.Image
4
+ import numpy as np
5
+
6
+ import pixeltable as pxt
7
+ import pixeltable.env as env
8
+ import pixeltable.type_system as ts
9
+ from pixeltable.func import Batch
10
+ from pixeltable.functions.util import resolve_torch_device
11
+
12
+
13
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
def sentence_transformer(
    sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
) -> Batch[np.ndarray]:
    """Embed a batch of sentences with a `sentence_transformers.SentenceTransformer` model.

    Returns one embedding array per input sentence.
    """
    env.Env.get().require_package('sentence_transformers')
    from sentence_transformers import SentenceTransformer

    model = _lookup_model(model_id, SentenceTransformer)
    embeddings = model.encode(sentences, normalize_embeddings=normalize_embeddings)
    # iterating the result yields its rows -- presumably (num_sentences, dim); one row per sentence
    return list(embeddings)
24
+
25
+
26
@sentence_transformer.conditional_return_type
def _(model_id: str) -> ts.ArrayType:
    """Return type of `sentence_transformer`: a float array sized by the model's embedding dimension.

    Falls back to an unsized array when an ImportError occurs (e.g. sentence_transformers missing).
    """
    try:
        from sentence_transformers import SentenceTransformer
        embedding_dim = _lookup_model(model_id, SentenceTransformer).get_sentence_embedding_dimension()
        shape = (embedding_dim,)
    except ImportError:
        shape = (None,)
    return ts.ArrayType(shape, dtype=ts.FloatType(), nullable=False)
34
+
35
+
36
@pxt.udf
def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
    """Embed a list of sentences in a single call; each embedding is returned as a plain Python list."""
    env.Env.get().require_package('sentence_transformers')
    from sentence_transformers import SentenceTransformer

    model = _lookup_model(model_id, SentenceTransformer)
    embeddings = model.encode(sentences, normalize_embeddings=normalize_embeddings)
    return [row.tolist() for row in embeddings]
45
+
46
+
47
@pxt.udf(batch_size=32)
def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: str) -> Batch[float]:
    """Score sentence pairs (sentences1[i], sentences2[i]) with a `sentence_transformers.CrossEncoder`."""
    env.Env.get().require_package('sentence_transformers')
    from sentence_transformers import CrossEncoder

    model = _lookup_model(model_id, CrossEncoder)
    pairs = list(map(list, zip(sentences1, sentences2)))
    scores = model.predict(pairs, convert_to_numpy=True)
    return scores.tolist()
56
+
57
+
58
@pxt.udf
def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
    """Score `sentence1` against every sentence in `sentences2` with a `CrossEncoder`; one score per entry."""
    env.Env.get().require_package('sentence_transformers')
    from sentence_transformers import CrossEncoder

    model = _lookup_model(model_id, CrossEncoder)
    pairs = [[sentence1, other] for other in sentences2]
    scores = model.predict(pairs, convert_to_numpy=True)
    return scores.tolist()
67
+
68
+
69
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
    """Compute CLIP text embeddings for a batch of strings (one array per input, moved to CPU)."""
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
    processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

    with torch.no_grad():
        encoded = processor(text=text, return_tensors='pt', padding=True, truncation=True)
        features = model.get_text_features(**encoded.to(device))
        embeddings = features.detach().to('cpu').numpy()

    return list(embeddings)
84
+
85
+
86
@pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
    """Compute CLIP image embeddings for a batch of images (one array per input, moved to CPU)."""
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
    processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

    with torch.no_grad():
        encoded = processor(images=image, return_tensors='pt', padding=True)
        features = model.get_image_features(**encoded.to(device))
        embeddings = features.detach().to('cpu').numpy()

    return list(embeddings)
101
+
102
+
103
@clip_text.conditional_return_type
@clip_image.conditional_return_type
def _(model_id: str) -> ts.ArrayType:
    """Return type of `clip_text` / `clip_image`: a float array of length `projection_dim`.

    Only the model configuration is loaded to read `projection_dim`. Previously the full
    `CLIPModel` (weights included) was loaded here under a device-less cache key, while the
    UDFs cache the model under the resolved torch device -- typically producing a second,
    unused copy of the weights. Falls back to an unsized array when an ImportError occurs.
    """
    try:
        from transformers import CLIPConfig
        config = CLIPConfig.from_pretrained(model_id)
        return ts.ArrayType((config.projection_dim,), dtype=ts.FloatType(), nullable=False)
    except ImportError:
        return ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False)
112
+
113
+
114
@pxt.udf(batch_size=4)
def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
    """Run DETR object detection on a batch of images.

    For each image returns a dict with 'scores', 'labels', 'label_text' (via the model's
    `id2label`) and 'boxes', keeping only detections that pass `threshold` in
    `post_process_object_detection`; boxes are scaled to each image's (height, width).
    """
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
    from transformers import DetrImageProcessor, DetrForObjectDetection

    model = _lookup_model(
        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))

    with torch.no_grad():
        inputs = processor(images=image, return_tensors='pt')
        outputs = model(**inputs.to(device))
        target_sizes = [(img.height, img.width) for img in image]
        results = processor.post_process_object_detection(outputs, threshold=threshold, target_sizes=target_sizes)

    def _to_dict(detection) -> dict:
        label_ids = [label.item() for label in detection['labels']]
        return {
            'scores': [score.item() for score in detection['scores']],
            'labels': label_ids,
            'label_text': [model.config.id2label[label_id] for label_id in label_ids],
            'boxes': [box.tolist() for box in detection['boxes']],
        }

    return [_to_dict(detection) for detection in results]
141
+
142
+
143
# Type of the object produced by the `create` factory given to `_lookup_model` / `_lookup_processor`;
# lets those helpers return the factory's own result type.
T = TypeVar('T')
144
+
145
+
146
def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
    """Return the cached model for (`model_id`, `create`, `device`), building it on first use.

    On a cache miss the model is built with `create(model_id)`, moved with `.to(device)` when
    `device` is given, and put in eval mode if it is a `torch.nn.Module`.
    """
    from torch import nn
    # `create` is part of the key so different factories for the same id never share an entry
    key = (model_id, create, device)
    if key in _model_cache:
        return _model_cache[key]

    model = create(model_id)
    if device is not None:
        model.to(device)
    if isinstance(model, nn.Module):
        model.eval()
    _model_cache[key] = model
    return model
157
+
158
+
159
def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
    """Return the cached processor for (`model_id`, `create`), building it with `create(model_id)` on first use."""
    # `create` is part of the key so different factories for the same id never share an entry
    key = (model_id, create)
    if key in _processor_cache:
        return _processor_cache[key]

    processor = create(model_id)
    _processor_cache[key] = processor
    return processor
164
+
165
+
166
# Module-level caches filled by `_lookup_model` / `_lookup_processor`.
# Model keys are (model_id, create, device); processor keys are (model_id, create).
# Nothing in this module evicts entries, so cached objects live for the whole process.
_model_cache = {}
_processor_cache = {}
pixeltable/functions/image.py ADDED
@@ -0,0 +1,16 @@
1
+ import base64
2
+
3
+ import PIL.Image
4
+
5
+ from pixeltable.type_system import ImageType, StringType
6
+ import pixeltable.func as func
7
+
8
+
9
@func.udf
def b64_encode(img: PIL.Image.Image, image_format: str = 'png') -> str:
    """Serialize `img` in `image_format` ('png' by default) and return it as a base64 text string."""
    import io
    with io.BytesIO() as buffer:
        img.save(buffer, format=image_format)
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode('utf-8')