pixeltable 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic.

Files changed (63)
  1. pixeltable/catalog/column.py +26 -49
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +72 -6
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +52 -53
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +9 -9
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/fireworks.py +10 -37
  25. pixeltable/functions/huggingface.py +47 -19
  26. pixeltable/functions/openai.py +192 -24
  27. pixeltable/functions/together.py +104 -9
  28. pixeltable/functions/util.py +11 -0
  29. pixeltable/index/__init__.py +2 -0
  30. pixeltable/index/base.py +49 -0
  31. pixeltable/index/embedding_index.py +95 -0
  32. pixeltable/metadata/schema.py +45 -22
  33. pixeltable/plan.py +15 -34
  34. pixeltable/store.py +38 -41
  35. pixeltable/tests/conftest.py +8 -14
  36. pixeltable/tests/ext/test_yolox.py +21 -0
  37. pixeltable/tests/functions/test_fireworks.py +43 -0
  38. pixeltable/tests/functions/test_functions.py +60 -0
  39. pixeltable/tests/{test_functions.py → functions/test_huggingface.py} +7 -143
  40. pixeltable/tests/functions/test_openai.py +162 -0
  41. pixeltable/tests/functions/test_together.py +112 -0
  42. pixeltable/tests/test_component_view.py +14 -5
  43. pixeltable/tests/test_dataframe.py +23 -22
  44. pixeltable/tests/test_exprs.py +99 -102
  45. pixeltable/tests/test_function.py +51 -43
  46. pixeltable/tests/test_index.py +138 -0
  47. pixeltable/tests/test_migration.py +2 -1
  48. pixeltable/tests/test_snapshot.py +24 -1
  49. pixeltable/tests/test_table.py +205 -26
  50. pixeltable/tests/test_types.py +30 -0
  51. pixeltable/tests/test_video.py +16 -16
  52. pixeltable/tests/test_view.py +5 -0
  53. pixeltable/tests/utils.py +171 -14
  54. pixeltable/tool/create_test_db_dump.py +16 -0
  55. pixeltable/type_system.py +77 -128
  56. pixeltable/utils/arrow.py +98 -0
  57. pixeltable/utils/hf_datasets.py +157 -0
  58. pixeltable/utils/parquet.py +68 -27
  59. pixeltable/utils/pytorch.py +16 -97
  60. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/METADATA +35 -28
  61. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/RECORD +63 -50
  62. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  63. {pixeltable-0.2.3.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0

pixeltable/ext/functions/yolox.py ADDED
@@ -0,0 +1,92 @@
+ import logging
+ from pathlib import Path
+ from typing import Iterable, Iterator
+ from urllib.request import urlretrieve
+
+ import PIL.Image
+ import numpy as np
+ import torch
+ from yolox.data import ValTransform
+ from yolox.exp import get_exp, Exp
+ from yolox.models import YOLOX
+ from yolox.utils import postprocess
+
+ import pixeltable as pxt
+ from pixeltable import env
+ from pixeltable.func import Batch
+ from pixeltable.functions.util import resolve_torch_device
+
+ _logger = logging.getLogger('pixeltable')
+
+
+ @pxt.udf(batch_size=4)
+ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
+     """
+     Runs the specified YOLOX object detection model on an image.
+
+     YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
+     intended for use in production applications.
+
+     Parameters:
+     - `model_id` - one of: `yolox_nano, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
+     - `threshold` - the threshold for object detection
+     """
+     model, exp = _lookup_model(model_id, 'cpu')
+     image_tensors = list(_images_to_tensors(images, exp))
+     batch_tensor = torch.stack(image_tensors)
+
+     with torch.no_grad():
+         output_tensor = model(batch_tensor)
+
+     outputs = postprocess(
+         output_tensor, 80, threshold, exp.nmsthre, class_agnostic=False
+     )
+
+     results: list[dict] = []
+     for image in images:
+         ratio = min(exp.test_size[0] / image.height, exp.test_size[1] / image.width)
+         if outputs[0] is None:
+             results.append({'bboxes': [], 'scores': [], 'labels': []})
+         else:
+             results.append({
+                 'bboxes': [(output[:4] / ratio).tolist() for output in outputs[0]],
+                 'scores': [output[4].item() * output[5].item() for output in outputs[0]],
+                 'labels': [int(output[6]) for output in outputs[0]]
+             })
+     return results
+
+
+ def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
+     for image in images:
+         image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
+         yield torch.from_numpy(image_transform)
+
+
+ def _lookup_model(model_id: str, device: str) -> (YOLOX, Exp):
+     key = (model_id, device)
+     if key in _model_cache:
+         return _model_cache[key]
+
+     weights_url = f'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/{model_id}.pth'
+     weights_file = Path(f'{env.Env.get().tmp_dir}/{model_id}.pth')
+     if not weights_file.exists():
+         _logger.info(f'Downloading weights for YOLOX model {model_id}: from {weights_url} -> {weights_file}')
+         urlretrieve(weights_url, weights_file)
+
+     exp = get_exp(exp_name=model_id)
+     model = exp.get_model().to(device)
+
+     model.eval()
+     model.head.training = False
+     model.training = False
+
+     # Load in the weights from training
+     weights = torch.load(weights_file, map_location=torch.device(device))
+     model.load_state_dict(weights['model'])
+
+     _model_cache[key] = (model, exp)
+     return model, exp
+
+
+ _model_cache = {}
+ _val_transform = ValTransform(legacy=False)
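
For orientation, a minimal usage sketch of the new extension UDF (illustrative only, not part of the diff; it assumes an existing table with an image column named `img` and the 0.2.x client API):

    import pixeltable as pxt
    from pixeltable.ext.functions.yolox import yolox

    cl = pxt.Client()
    tbl = cl.get_table('demo.images')  # hypothetical table with an image column `img`
    # run the detector as part of a query; model_id is one of the variants listed in the docstring
    tbl.select(yolox(tbl.img, model_id='yolox_s', threshold=0.5)).show()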

pixeltable/func/aggregate_function.py CHANGED
@@ -3,13 +3,14 @@ from __future__ import annotations
  import abc
  import importlib
  import inspect
- from typing import Optional, Any, Type, List, Dict
+ from typing import Optional, Any, Type, List, Dict, Callable
  import itertools

  import pixeltable.exceptions as excs
  import pixeltable.type_system as ts
  from .function import Function
  from .signature import Signature, Parameter
+ from .globals import validate_symbol_path


  class Aggregator(abc.ABC):
@@ -136,8 +137,7 @@ def uda(
      update_types: List[ts.ColumnType],
      init_types: Optional[List[ts.ColumnType]] = None,
      requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
-     name: Optional[str] = None
- ) -> Type[Aggregator]:
+ ) -> Callable:
      """Decorator for user-defined aggregate functions.

      The decorated class must inherit from Aggregator and implement the following methods:
@@ -155,14 +155,11 @@ def uda(
      - requires_order_by: if True, the first parameter to the function is the order-by expression
      - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
      - allows_window: if True, the function can be used with a window
-     - name: name of the AggregateFunction instance; if None, the class name is used
      """
-     if name is not None and not name.isidentifier():
-         raise excs.Error(f'Invalid name: {name}')
      if init_types is None:
          init_types = []

-     def decorator(cls: Type[Aggregator]) -> Type[Aggregator]:
+     def decorator(cls: Type[Aggregator]) -> Type[Function]:
          # validate type parameters
          num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
          if num_init_params > 0:
@@ -178,17 +175,20 @@ def uda(
          assert value_type is not None

          # the AggregateFunction instance resides in the same module as cls
-         module_path = cls.__module__
-         nonlocal name
-         name = name or cls.__name__
-         instance_path = f'{module_path}.{name}'
+         class_path = f'{cls.__module__}.{cls.__qualname__}'
+         # nonlocal name
+         # name = name or cls.__name__
+         # instance_path_elements = class_path.split('.')[:-1] + [name]
+         # instance_path = '.'.join(instance_path_elements)

          # create the corresponding AggregateFunction instance
          instance = AggregateFunction(
-             cls, instance_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
-         module = importlib.import_module(module_path)
-         setattr(module, name, instance)
+             cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
+         # do the path validation at the very end, in order to be able to write tests for the other failure cases
+         validate_symbol_path(class_path)
+         #module = importlib.import_module(cls.__module__)
+         #setattr(module, name, instance)

-         return cls
+         return instance

      return decorator
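
The practical effect of these changes (illustrative sketch, not from the diff): the decorator now returns the AggregateFunction instance itself and registers it under the class's module-qualified path, so the class name doubles as the function name and the separate name= argument goes away.

    @func.uda(update_types=[IntType()], value_type=IntType(), allows_window=True)
    class my_sum(func.Aggregator):
        def __init__(self):
            self.total = 0
        def update(self, val: int) -> None:
            if val is not None:
                self.total += val
        def value(self) -> int:
            return self.total

    # `my_sum` is now bound to an AggregateFunction instance, not the original class

The built-in aggregates in pixeltable/functions/__init__.py and pixeltable/functions/eval.py are renamed accordingly below.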

pixeltable/func/expr_template_function.py CHANGED
@@ -50,9 +50,17 @@ class ExprTemplateFunction(Function):
          bound_args.update(
              {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
          result = self.expr.copy()
+         import pixeltable.exprs as exprs
          for param_name, arg in bound_args.items():
              param_expr = self.param_exprs_by_name[param_name]
-             result = result.substitute(param_expr, arg)
+             if not isinstance(arg, exprs.Expr):
+                 # TODO: use the available param_expr.col_type
+                 arg_expr = exprs.Expr.from_object(arg)
+                 if arg_expr is None:
+                     raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
+             else:
+                 arg_expr = arg
+             result = result.substitute(param_expr, arg_expr)
          import pixeltable.exprs as exprs
          assert not result.contains(exprs.Variable)
          return result

pixeltable/func/globals.py CHANGED
@@ -1,29 +1,39 @@
- from typing import Optional
- from types import ModuleType
  import importlib
  import inspect
+ from types import ModuleType
+ from typing import Optional

+ import pixeltable.exceptions as excs

- def resolve_symbol(symbol_path: str) -> object:
+
+ def resolve_symbol(symbol_path: str) -> Optional[object]:
      path_elems = symbol_path.split('.')
      module: Optional[ModuleType] = None
-     if path_elems[0:2] == ['pixeltable', 'functions'] and len(path_elems) > 2:
-         # if this is a pixeltable.functions submodule, it cannot be resolved via pixeltable.functions;
-         # try to import the submodule directly
-         submodule_path = '.'.join(path_elems[0:3])
+     i = len(path_elems) - 1
+     while i > 0 and module is None:
          try:
-             module = importlib.import_module(submodule_path)
-             path_elems = path_elems[3:]
+             module = importlib.import_module('.'.join(path_elems[:i]))
          except ModuleNotFoundError:
-             pass
-     if module is None:
-         module = importlib.import_module(path_elems[0])
-         path_elems = path_elems[1:]
+             i -= 1
+     if i == 0:
+         return None # Not resolvable
      obj = module
-     for el in path_elems:
+     for el in path_elems[i:]:
          obj = getattr(obj, el)
      return obj

+
+ def validate_symbol_path(fn_path: str) -> None:
+     path_elems = fn_path.split('.')
+     fn_name = path_elems[-1]
+     if any(el == '<locals>' for el in path_elems):
+         raise excs.Error(
+             f'{fn_name}(): nested functions are not supported. Move the function to the module level or into a class.')
+     if any(not el.isidentifier() for el in path_elems):
+         raise excs.Error(
+             f'{fn_name}(): cannot resolve symbol path {fn_path}. Move the function to the module level or into a class.')
+
+
  def get_caller_module_path() -> str:
      """Return the module path of our caller's caller"""
      stack = inspect.stack()
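
In short, the reworked resolver imports the longest importable module prefix, resolves the remaining path elements via getattr, and returns None (rather than raising) when nothing is importable; validate_symbol_path rejects paths that could not be re-resolved later. A small sketch of the behavior (illustrative, not from the diff):

    from pixeltable.func.globals import resolve_symbol, validate_symbol_path

    fn = resolve_symbol('pixeltable.functions.huggingface.clip_text')  # imports the submodule, then getattr
    assert resolve_symbol('no.such.module.symbol') is None             # previously raised ModuleNotFoundError

    validate_symbol_path('my_module.my_udf')            # ok: every path element is an identifier
    # validate_symbol_path('my_module.f.<locals>.g')    # would raise excs.Error: nested functions are not supported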

pixeltable/func/signature.py CHANGED
@@ -114,20 +114,12 @@ class Signature:
          return (col_type, is_batched)

      @classmethod
-     def create(
-         cls, c: Callable,
-         param_types: Optional[List[ts.ColumnType]] = None,
-         return_type: Optional[Union[ts.ColumnType, Callable]] = None
-     ) -> Signature:
-         """Create a signature for the given Callable.
-         Infer the parameter and return types, if none are specified.
-         Raises an exception if the types cannot be inferred.
-         """
+     def create_parameters(
+         cls, c: Callable, param_types: Optional[List[ts.ColumnType]] = None) -> List[Parameter]:
          sig = inspect.signature(c)
          py_parameters = list(sig.parameters.values())
-
-         # check non-var parameters for name collisions and default value compatibility
          parameters: List[Parameter] = []
+
          for idx, param in enumerate(py_parameters):
              if param.name in cls.SPECIAL_PARAM_NAMES:
                  raise excs.Error(f"'{param.name}' is a reserved parameter name")
@@ -135,6 +127,7 @@ class Signature:
                  parameters.append(Parameter(param.name, None, param.kind, False))
                  continue

+             # check non-var parameters for name collisions and default value compatibility
              if param_types is not None:
                  if idx >= len(param_types):
                      raise excs.Error(f'Missing type for parameter {param.name}')
@@ -155,7 +148,20 @@ class Signature:

              parameters.append(Parameter(param.name, param_type, param.kind, is_batched))

-         return_is_batched = False
+         return parameters
+
+     @classmethod
+     def create(
+         cls, c: Callable,
+         param_types: Optional[List[ts.ColumnType]] = None,
+         return_type: Optional[Union[ts.ColumnType, Callable]] = None
+     ) -> Signature:
+         """Create a signature for the given Callable.
+         Infer the parameter and return types, if none are specified.
+         Raises an exception if the types cannot be inferred.
+         """
+         parameters = cls.create_parameters(c, param_types)
+         sig = inspect.signature(c)
          if return_type is None:
              return_type, return_is_batched = cls._infer_type(sig.return_annotation)
              if return_type is None:
pixeltable/func/udf.py CHANGED
@@ -11,6 +11,7 @@ from .callable_function import CallableFunction
  from .expr_template_function import ExprTemplateFunction
  from .function import Function
  from .function_registry import FunctionRegistry
+ from .globals import validate_symbol_path
  from .signature import Signature


@@ -124,6 +125,8 @@ def make_function(

      # If this function is part of a module, register it
      if function_path is not None:
+         # do the validation at the very end, so it's easier to write tests for other failure scenarios
+         validate_symbol_path(function_path)
          FunctionRegistry.get().register_function(function_path, result)

      return result
@@ -142,17 +145,19 @@ def expr_udf(*args: Any, **kwargs: Any) -> Any:
          else:
              function_path = None

-         sig = Signature.create(py_fn, param_types=param_types, return_type=None)
          # TODO: verify that the inferred return type matches that of the template
          # TODO: verify that the signature doesn't contain batched parameters

          # construct Parameters from the function signature
+         params = Signature.create_parameters(py_fn, param_types=param_types)
          import pixeltable.exprs as exprs
-         var_exprs = [exprs.Variable(param.name, param.col_type) for param in sig.parameters.values()]
+         var_exprs = [exprs.Variable(param.name, param.col_type) for param in params]
          # call the function with the parameter expressions to construct an Expr with parameters
          template = py_fn(*var_exprs)
          assert isinstance(template, exprs.Expr)
          py_sig = inspect.signature(py_fn)
+         if function_path is not None:
+             validate_symbol_path(function_path)
          return ExprTemplateFunction(template, py_signature=py_sig, self_path=function_path, name=py_fn.__name__)

      if len(args) == 1:

pixeltable/functions/__init__.py CHANGED
@@ -15,7 +15,7 @@ import pixeltable.functions.pil.image
  from pixeltable import exprs
  from pixeltable.type_system import IntType, ColumnType, FloatType, ImageType, VideoType
  # automatically import all submodules so that the udfs get registered
- from . import image, string, video, openai, together, fireworks, huggingface
+ from . import image, string, video, huggingface

  # TODO: remove and replace calls with astype()
  def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
@@ -23,8 +23,8 @@ def cast(expr: exprs.Expr, target_type: ColumnType) -> exprs.Expr:
      return expr

  @func.uda(
-     update_types=[IntType()], value_type=IntType(), name='sum', allows_window=True, requires_order_by=False)
- class SumAggregator(func.Aggregator):
+     update_types=[IntType()], value_type=IntType(), allows_window=True, requires_order_by=False)
+ class sum(func.Aggregator):
      def __init__(self):
          self.sum: Union[int, float] = 0
      def update(self, val: Union[int, float]) -> None:
@@ -35,8 +35,8 @@ class SumAggregator(func.Aggregator):


  @func.uda(
-     update_types=[IntType()], value_type=IntType(), name='count', allows_window = True, requires_order_by = False)
- class CountAggregator(func.Aggregator):
+     update_types=[IntType()], value_type=IntType(), allows_window = True, requires_order_by = False)
+ class count(func.Aggregator):
      def __init__(self):
          self.count = 0
      def update(self, val: int) -> None:
@@ -47,8 +47,8 @@ class CountAggregator(func.Aggregator):


  @func.uda(
-     update_types=[IntType()], value_type=FloatType(), name='mean', allows_window=False, requires_order_by=False)
- class MeanAggregator(func.Aggregator):
+     update_types=[IntType()], value_type=FloatType(), allows_window=False, requires_order_by=False)
+ class mean(func.Aggregator):
      def __init__(self):
          self.sum = 0
          self.count = 0
@@ -63,9 +63,9 @@ class MeanAggregator(func.Aggregator):


  @func.uda(
-     init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(), name='make_video',
+     init_types=[IntType()], update_types=[ImageType()], value_type=VideoType(),
      requires_order_by=True, allows_window=False)
- class VideoAggregator(func.Aggregator):
+ class make_video(func.Aggregator):
      def __init__(self, fps: int = 25):
          """follows https://pyav.org/docs/develop/cookbook/numpy.html#generating-video"""
          self.container: Optional[av.container.OutputContainer] = None
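
With the renames, the built-in aggregates are addressed by their lowercase names; a hedged usage sketch (table and column names are illustrative):

    import pixeltable as pxt
    from pixeltable.functions import sum as pxt_sum, count, mean

    t = pxt.Client().get_table('demo.metrics')  # hypothetical table with an int column `val`
    t.select(pxt_sum(t.val), count(t.val), mean(t.val)).show()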

pixeltable/functions/eval.py CHANGED
@@ -1,4 +1,3 @@
- from __future__ import annotations
  from typing import List, Tuple, Dict
  from collections import defaultdict
  import sys
@@ -157,16 +156,16 @@ def calculate_image_tpfp(
      ts.JsonType(nullable=False)
  ])
  def eval_detections(
-     pred_bboxes: List[List[int]], pred_classes: List[int], pred_scores: List[float],
-     gt_bboxes: List[List[int]], gt_classes: List[int]
+     pred_bboxes: List[List[int]], pred_labels: List[int], pred_scores: List[float],
+     gt_bboxes: List[List[int]], gt_labels: List[int]
  ) -> Dict:
-     class_idxs = list(set(pred_classes + gt_classes))
+     class_idxs = list(set(pred_labels + gt_labels))
      result: List[Dict] = []
      pred_bboxes_arr = np.asarray(pred_bboxes)
-     pred_classes_arr = np.asarray(pred_classes)
+     pred_classes_arr = np.asarray(pred_labels)
      pred_scores_arr = np.asarray(pred_scores)
      gt_bboxes_arr = np.asarray(gt_bboxes)
-     gt_classes_arr = np.asarray(gt_classes)
+     gt_classes_arr = np.asarray(gt_labels)
      for class_idx in class_idxs:
          pred_filter = pred_classes_arr == class_idx
          gt_filter = gt_classes_arr == class_idx
@@ -181,8 +180,8 @@ def eval_detections(
      return result

  @func.uda(
-     update_types=[ts.JsonType()], value_type=ts.JsonType(), name='mean_ap', allows_std_agg=True, allows_window=False)
- class MeanAPAggregator:
+     update_types=[ts.JsonType()], value_type=ts.JsonType(), allows_std_agg=True, allows_window=False)
+ class mean_ap(func.Aggregator):
      def __init__(self):
          self.class_tpfp: Dict[int, List[Dict]] = defaultdict(list)


pixeltable/functions/fireworks.py CHANGED
@@ -1,61 +1,34 @@
- import logging
- import os
  from typing import Optional

+ import fireworks.client
+
  import pixeltable as pxt
- import pixeltable.exceptions as excs
  from pixeltable import env


+ def fireworks_client() -> fireworks.client.Fireworks:
+     return env.Env.get().get_client('fireworks', lambda api_key: fireworks.client.Fireworks(api_key=api_key))
+
+
  @pxt.udf
  def chat_completions(
-     prompt: str,
-     model: str,
+     messages: list[dict[str, str]],
      *,
+     model: str,
      max_tokens: Optional[int] = None,
-     repetition_penalty: Optional[float] = None,
      top_k: Optional[int] = None,
      top_p: Optional[float] = None,
      temperature: Optional[float] = None
  ) -> dict:
-     initialize()
      kwargs = {
          'max_tokens': max_tokens,
-         'repetition_penalty': repetition_penalty,
          'top_k': top_k,
          'top_p': top_p,
          'temperature': temperature
      }
      kwargs_not_none = dict(filter(lambda x: x[1] is not None, kwargs.items()))
-     import fireworks.client
-     return fireworks.client.Completion.create(
+     return fireworks_client().chat.completions.create(
          model=model,
-         prompt_or_messages=prompt,
+         messages=messages,
          **kwargs_not_none
      ).dict()
-
-
- def initialize():
-     global _is_fireworks_initialized
-     if _is_fireworks_initialized:
-         return
-
-     _logger.info('Initializing Fireworks client.')
-
-     config = pxt.env.Env.get().config
-
-     if 'fireworks' in config and 'api_key' in config['fireworks']:
-         api_key = config['fireworks']['api_key']
-     else:
-         api_key = os.environ.get('FIREWORKS_API_KEY')
-     if api_key is None or api_key == '':
-         raise excs.Error('Fireworks client not initialized (no API key configured).')
-
-     import fireworks.client
-
-     fireworks.client.api_key = api_key
-     _is_fireworks_initialized = True
-
-
- _logger = logging.getLogger('pixeltable')
- _is_fireworks_initialized = False
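
chat_completions now takes an OpenAI-style message list, and the client is obtained lazily via Env.get().get_client() (which supplies the API key previously read from the pixeltable config or FIREWORKS_API_KEY) rather than through a module-level initialize(). A usage sketch (model name, table, and column are illustrative):

    from pixeltable.functions.fireworks import chat_completions

    # t is a hypothetical table with a string column `prompt`
    messages = [{'role': 'user', 'content': t.prompt}]
    t.select(chat_completions(messages, model='accounts/fireworks/models/mixtral-8x7b-instruct', max_tokens=64)).show()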

pixeltable/functions/huggingface.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Callable
+ from typing import Callable, TypeVar, Optional

  import PIL.Image
  import numpy as np
@@ -7,10 +7,13 @@ import pixeltable as pxt
  import pixeltable.env as env
  import pixeltable.type_system as ts
  from pixeltable.func import Batch
+ from pixeltable.functions.util import resolve_torch_device


  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType()))
- def sentence_transformer(sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False) -> Batch[np.ndarray]:
+ def sentence_transformer(
+     sentences: Batch[str], *, model_id: str, normalize_embeddings: bool = False
+ ) -> Batch[np.ndarray]:
      env.Env.get().require_package('sentence_transformers')
      from sentence_transformers import SentenceTransformer

@@ -53,44 +56,60 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
      return array.tolist()


- @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
  def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
      env.Env.get().require_package('transformers')
+     device = resolve_torch_device('auto')
+     import torch
      from transformers import CLIPModel, CLIPProcessor

-     model = _lookup_model(model_id, CLIPModel.from_pretrained)
+     model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
+     assert model.config.projection_dim == 512
      processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

-     inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
-     embeddings = model.get_text_features(**inputs).detach().numpy()
+     with torch.no_grad():
+         inputs = processor(text=text, return_tensors='pt', padding=True, truncation=True)
+         embeddings = model.get_text_features(**inputs.to(device)).detach().to('cpu').numpy()
+
      return [embeddings[i] for i in range(embeddings.shape[0])]


- @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
+ @pxt.udf(batch_size=32, return_type=ts.ArrayType((512,), dtype=ts.FloatType(), nullable=False))
  def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
      env.Env.get().require_package('transformers')
+     device = resolve_torch_device('auto')
+     import torch
      from transformers import CLIPModel, CLIPProcessor

-     model = _lookup_model(model_id, CLIPModel.from_pretrained)
+     model = _lookup_model(model_id, CLIPModel.from_pretrained, device=device)
+     assert model.config.projection_dim == 512
      processor = _lookup_processor(model_id, CLIPProcessor.from_pretrained)

-     inputs = processor(images=image, return_tensors='pt', padding=True)
-     embeddings = model.get_image_features(**inputs).detach().numpy()
+     with torch.no_grad():
+         inputs = processor(images=image, return_tensors='pt', padding=True)
+         embeddings = model.get_image_features(**inputs.to(device)).detach().to('cpu').numpy()
+
      return [embeddings[i] for i in range(embeddings.shape[0])]


- @pxt.udf(batch_size=32)
+ @pxt.udf(batch_size=4)
  def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
      env.Env.get().require_package('transformers')
+     device = resolve_torch_device('auto')
+     import torch
      from transformers import DetrImageProcessor, DetrForObjectDetection

-     model = _lookup_model(model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'))
+     model = _lookup_model(
+         model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device)
      processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))

-     inputs = processor(images=image, return_tensors='pt')
-     outputs = model(**inputs)
+     with torch.no_grad():
+         inputs = processor(images=image, return_tensors='pt')
+         outputs = model(**inputs.to(device))
+         results = processor.post_process_object_detection(
+             outputs, threshold=threshold, target_sizes=[(img.height, img.width) for img in image]
+         )

-     results = processor.post_process_object_detection(outputs, threshold=threshold)
      return [
          {
              'scores': [score.item() for score in result['scores']],
@@ -102,14 +121,23 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
      ]


- def _lookup_model(model_id: str, create: Callable) -> Any:
-     key = (model_id, create) # For safety, include the `create` callable in the cache key
+ T = TypeVar('T')
+
+
+ def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
+     from torch import nn
+     key = (model_id, create, device) # For safety, include the `create` callable in the cache key
      if key not in _model_cache:
-         _model_cache[key] = create(model_id)
+         model = create(model_id)
+         if device is not None:
+             model.to(device)
+         if isinstance(model, nn.Module):
+             model.eval()
+         _model_cache[key] = model
      return _model_cache[key]


- def _lookup_processor(model_id: str, create: Callable) -> Any:
+ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
      key = (model_id, create) # For safety, include the `create` callable in the cache key
      if key not in _processor_cache:
          _processor_cache[key] = create(model_id)
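
Since clip_text and clip_image now declare a fixed 512-dimensional return type and run under torch.no_grad() on the resolved device, they can serve directly as embedding producers (for example for the embedding index added elsewhere in this release). A hedged sketch (checkpoint, table, and column names are illustrative):

    from pixeltable.functions.huggingface import clip_image, clip_text

    # t is a hypothetical table with an image column `img` and a string column `caption`
    t.select(
        clip_image(t.img, model_id='openai/clip-vit-base-patch32'),
        clip_text(t.caption, model_id='openai/clip-vit-base-patch32'),
    ).show()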