pixeltable 0.2.13__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (58) hide show
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +8 -3
  4. pixeltable/catalog/globals.py +8 -0
  5. pixeltable/catalog/table.py +25 -9
  6. pixeltable/catalog/table_version.py +30 -55
  7. pixeltable/catalog/view.py +1 -1
  8. pixeltable/env.py +4 -4
  9. pixeltable/exec/__init__.py +2 -1
  10. pixeltable/exec/row_update_node.py +61 -0
  11. pixeltable/exec/{sql_scan_node.py → sql_node.py} +120 -56
  12. pixeltable/exprs/__init__.py +1 -1
  13. pixeltable/exprs/arithmetic_expr.py +41 -16
  14. pixeltable/exprs/expr.py +72 -22
  15. pixeltable/exprs/function_call.py +64 -29
  16. pixeltable/exprs/globals.py +5 -1
  17. pixeltable/exprs/inline_array.py +18 -11
  18. pixeltable/exprs/method_ref.py +63 -0
  19. pixeltable/ext/__init__.py +9 -0
  20. pixeltable/ext/functions/__init__.py +8 -0
  21. pixeltable/ext/functions/whisperx.py +45 -5
  22. pixeltable/ext/functions/yolox.py +60 -14
  23. pixeltable/func/callable_function.py +12 -4
  24. pixeltable/func/expr_template_function.py +1 -1
  25. pixeltable/func/function.py +12 -2
  26. pixeltable/func/function_registry.py +24 -9
  27. pixeltable/func/udf.py +32 -4
  28. pixeltable/functions/__init__.py +1 -1
  29. pixeltable/functions/fireworks.py +33 -0
  30. pixeltable/functions/huggingface.py +96 -6
  31. pixeltable/functions/image.py +226 -41
  32. pixeltable/functions/json.py +46 -0
  33. pixeltable/functions/openai.py +214 -0
  34. pixeltable/functions/string.py +195 -218
  35. pixeltable/functions/timestamp.py +210 -0
  36. pixeltable/functions/together.py +106 -0
  37. pixeltable/functions/video.py +2 -2
  38. pixeltable/functions/{eval.py → vision.py} +170 -27
  39. pixeltable/functions/whisper.py +32 -0
  40. pixeltable/io/__init__.py +1 -1
  41. pixeltable/io/external_store.py +2 -2
  42. pixeltable/io/globals.py +133 -1
  43. pixeltable/io/pandas.py +82 -31
  44. pixeltable/iterators/video.py +55 -23
  45. pixeltable/metadata/__init__.py +1 -1
  46. pixeltable/metadata/converters/convert_18.py +39 -0
  47. pixeltable/metadata/notes.py +10 -0
  48. pixeltable/plan.py +76 -1
  49. pixeltable/store.py +65 -28
  50. pixeltable/tool/create_test_db_dump.py +8 -9
  51. pixeltable/tool/doc_plugins/griffe.py +4 -0
  52. pixeltable/type_system.py +84 -63
  53. {pixeltable-0.2.13.dist-info → pixeltable-0.2.15.dist-info}/METADATA +2 -2
  54. {pixeltable-0.2.13.dist-info → pixeltable-0.2.15.dist-info}/RECORD +57 -51
  55. pixeltable/exprs/image_member_access.py +0 -96
  56. {pixeltable-0.2.13.dist-info → pixeltable-0.2.15.dist-info}/LICENSE +0 -0
  57. {pixeltable-0.2.13.dist-info → pixeltable-0.2.15.dist-info}/WHEEL +0 -0
  58. {pixeltable-0.2.13.dist-info → pixeltable-0.2.15.dist-info}/entry_points.txt +0 -0
@@ -1,8 +1,9 @@
1
- from typing import Optional
1
+ from typing import Optional, TYPE_CHECKING
2
2
 
3
- import torch
4
- import whisperx
5
- from whisperx.asr import FasterWhisperPipeline
3
+ from pixeltable.utils.code import local_public_names
4
+
5
+ if TYPE_CHECKING:
6
+ from whisperx.asr import FasterWhisperPipeline
6
7
 
7
8
  import pixeltable as pxt
8
9
 
@@ -11,6 +12,36 @@ import pixeltable as pxt
11
12
  def transcribe(
12
13
  audio: str, *, model: str, compute_type: Optional[str] = None, language: Optional[str] = None, chunk_size: int = 30
13
14
  ) -> dict:
15
+ """
16
+ Transcribe an audio file using WhisperX.
17
+
18
+ This UDF runs a transcription model _locally_ using the WhisperX library,
19
+ equivalent to the WhisperX `transcribe` function, as described in the
20
+ [WhisperX library documentation](https://github.com/m-bain/whisperX).
21
+
22
+ __Requirements:__
23
+
24
+ - `pip install whisperx`
25
+
26
+ Args:
27
+ audio: The audio file to transcribe.
28
+ model: The name of the model to use for transcription.
29
+
30
+ See the [WhisperX library documentation](https://github.com/m-bain/whisperX) for details
31
+ on the remaining parameters.
32
+
33
+ Returns:
34
+ A dictionary containing the transcription and various other metadata.
35
+
36
+ Examples:
37
+ Add a computed column that applies the model `tiny.en` to an existing Pixeltable column `tbl.audio`
38
+ of the table `tbl`:
39
+
40
+ >>> tbl['result'] = transcribe(tbl.audio, model='tiny.en')
41
+ """
42
+ import torch
43
+ import whisperx
44
+
14
45
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
46
  compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
16
47
  model = _lookup_model(model, device, compute_type)
@@ -19,7 +50,9 @@ def transcribe(
19
50
  return result
20
51
 
21
52
 
22
- def _lookup_model(model_id: str, device: str, compute_type: str) -> FasterWhisperPipeline:
53
+ def _lookup_model(model_id: str, device: str, compute_type: str) -> 'FasterWhisperPipeline':
54
+ import whisperx
55
+
23
56
  key = (model_id, device, compute_type)
24
57
  if key not in _model_cache:
25
58
  model = whisperx.load_model(model_id, device, compute_type=compute_type)
@@ -28,3 +61,10 @@ def _lookup_model(model_id: str, device: str, compute_type: str) -> FasterWhispe
28
61
 
29
62
 
30
63
  _model_cache = {}
64
+
65
+
66
+ __all__ = local_public_names(__name__)
67
+
68
+
69
+ def __dir__():
70
+ return __all__
@@ -1,20 +1,21 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import Iterable, Iterator
3
+ from typing import Iterable, Iterator, TYPE_CHECKING
4
4
  from urllib.request import urlretrieve
5
5
 
6
6
  import PIL.Image
7
7
  import numpy as np
8
- import torch
9
- from yolox.data import ValTransform
10
- from yolox.exp import get_exp, Exp
11
- from yolox.models import YOLOX
12
- from yolox.utils import postprocess
13
8
 
14
9
  import pixeltable as pxt
15
10
  from pixeltable import env
16
11
  from pixeltable.func import Batch
17
12
  from pixeltable.functions.util import normalize_image_mode
13
+ from pixeltable.utils.code import local_public_names
14
+
15
+ if TYPE_CHECKING:
16
+ import torch
17
+ from yolox.exp import Exp
18
+ from yolox.models import YOLOX
18
19
 
19
20
  _logger = logging.getLogger('pixeltable')
20
21
 
@@ -22,15 +23,32 @@ _logger = logging.getLogger('pixeltable')
22
23
  @pxt.udf(batch_size=4)
23
24
  def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
24
25
  """
25
- Runs the specified YOLOX object detection model on an image.
26
+ Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
27
+ defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
26
28
 
27
29
  YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
28
30
  intended for use in production applications.
29
31
 
30
- Parameters:
31
- - `model_id` - one of: `yolox_nano, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
32
- - `threshold` - the threshold for object detection
32
+ __Requirements:__
33
+
34
+ - `pip install git+https://github.com/Megvii-BaseDetection/YOLOX`
35
+
36
+ Args:
37
+ model_id: one of: `yolox_nano`, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
38
+ threshold: the threshold for object detection
39
+
40
+ Returns:
41
+ A dictionary containing the output of the object detection model.
42
+
43
+ Examples:
44
+ Add a computed column that applies the model `yolox_m` to an existing
45
+ Pixeltable column `tbl.image` of the table `tbl`:
46
+
47
+ >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
33
48
  """
49
+ import torch
50
+ from yolox.utils import postprocess
51
+
34
52
  model, exp = _lookup_model(model_id, 'cpu')
35
53
  image_tensors = list(_images_to_tensors(images, exp))
36
54
  batch_tensor = torch.stack(image_tensors)
@@ -58,6 +76,21 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
58
76
 
59
77
  @pxt.udf
60
78
  def yolo_to_coco(detections: dict) -> list:
79
+ """
80
+ Converts the output of a YOLOX object detection model to COCO format.
81
+
82
+ Args:
83
+ detections: The output of a YOLOX object detection model, as returned by `yolox`.
84
+
85
+ Returns:
86
+ A list containing the data from `detections`, converted to COCO format.
87
+
88
+ Examples:
89
+ Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.detections`
90
+ holds the detections previously computed by `yolox`:
91
+
92
+ >>> tbl['detections_coco'] = yolo_to_coco(tbl.detections)
93
+ """
61
94
  bboxes, labels = detections['bboxes'], detections['labels']
62
95
  num_annotations = len(detections['bboxes'])
63
96
  assert num_annotations == len(detections['labels'])
@@ -72,14 +105,21 @@ def yolo_to_coco(detections: dict) -> list:
72
105
  return result
73
106
 
74
107
 
75
- def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
108
+ def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: 'Exp') -> Iterator['torch.Tensor']:
109
+ import torch
110
+ from yolox.data import ValTransform
111
+
112
+ _val_transform = ValTransform(legacy=False)
76
113
  for image in images:
77
114
  image = normalize_image_mode(image)
78
115
  image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
79
116
  yield torch.from_numpy(image_transform)
80
117
 
81
118
 
82
- def _lookup_model(model_id: str, device: str) -> (YOLOX, Exp):
119
+ def _lookup_model(model_id: str, device: str) -> tuple['YOLOX', 'Exp']:
120
+ import torch
121
+ from yolox.exp import get_exp
122
+
83
123
  key = (model_id, device)
84
124
  if key in _model_cache:
85
125
  return _model_cache[key]
@@ -105,5 +145,11 @@ def _lookup_model(model_id: str, device: str) -> (YOLOX, Exp):
105
145
  return model, exp
106
146
 
107
147
 
108
- _model_cache = {}
109
- _val_transform = ValTransform(legacy=False)
148
+ _model_cache: dict[tuple[str, str], tuple['YOLOX', 'Exp']] = {}
149
+
150
+
151
+ __all__ = local_public_names(__name__)
152
+
153
+
154
+ def __dir__():
155
+ return __all__
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Optional, Callable, Tuple, Any
4
+ from typing import Any, Callable, Optional
5
5
  from uuid import UUID
6
6
 
7
7
  import cloudpickle
@@ -19,14 +19,21 @@ class CallableFunction(Function):
19
19
  """
20
20
 
21
21
  def __init__(
22
- self, signature: Signature, py_fn: Callable, self_path: Optional[str] = None,
23
- self_name: Optional[str] = None, batch_size: Optional[int] = None):
22
+ self,
23
+ signature: Signature,
24
+ py_fn: Callable,
25
+ self_path: Optional[str] = None,
26
+ self_name: Optional[str] = None,
27
+ batch_size: Optional[int] = None,
28
+ is_method: bool = False,
29
+ is_property: bool = False
30
+ ):
24
31
  assert py_fn is not None
25
32
  self.py_fn = py_fn
26
33
  self.self_name = self_name
27
34
  self.batch_size = batch_size
28
35
  self.__doc__ = py_fn.__doc__
29
- super().__init__(signature, self_path=self_path)
36
+ super().__init__(signature, self_path=self_path, is_method=is_method, is_property=is_property)
30
37
 
31
38
  @property
32
39
  def is_batched(self) -> bool:
@@ -78,6 +85,7 @@ class CallableFunction(Function):
78
85
  def _as_dict(self) -> dict:
79
86
  if self.self_path is None:
80
87
  # this is not a module function
88
+ assert not self.is_method and not self.is_property
81
89
  from .function_registry import FunctionRegistry
82
90
  id = FunctionRegistry.get().create_stored_function(self)
83
91
  return {'id': id.hex}
@@ -56,7 +56,7 @@ class ExprTemplateFunction(Function):
56
56
  arg_exprs[param_expr] = arg_expr
57
57
  result = result.substitute(arg_exprs)
58
58
  import pixeltable.exprs as exprs
59
- assert not result.contains(exprs.Variable)
59
+ assert not result._contains(exprs.Variable)
60
60
  return result
61
61
 
62
62
  def exec(self, *args: Any, **kwargs: Any) -> Any:
@@ -3,10 +3,12 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- from typing import Optional, Dict, Any, Tuple, Callable
6
+ from typing import Any, Callable, Dict, Optional, Tuple
7
7
 
8
8
  import pixeltable
9
+ import pixeltable.exceptions as excs
9
10
  import pixeltable.type_system as ts
11
+
10
12
  from .globals import resolve_symbol
11
13
  from .signature import Signature
12
14
 
@@ -19,9 +21,13 @@ class Function(abc.ABC):
19
21
  via the member self_path.
20
22
  """
21
23
 
22
- def __init__(self, signature: Signature, self_path: Optional[str] = None):
24
+ def __init__(self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False):
25
+ # Check that stored functions cannot be declared using `is_method` or `is_property`:
26
+ assert not ((is_method or is_property) and self_path is None)
23
27
  self.signature = signature
24
28
  self.self_path = self_path # fully-qualified path to self
29
+ self.is_method = is_method
30
+ self.is_property = is_property
25
31
  self._conditional_return_type: Optional[Callable[..., ts.ColumnType]] = None
26
32
 
27
33
  @property
@@ -38,6 +44,10 @@ class Function(abc.ABC):
38
44
  return self.self_path[len(ptf_prefix):]
39
45
  return self.self_path
40
46
 
47
+ @property
48
+ def arity(self) -> int:
49
+ return len(self.signature.parameters)
50
+
41
51
  def help_str(self) -> str:
42
52
  return self.display_name + str(self.signature)
43
53
 
@@ -4,11 +4,9 @@ import dataclasses
4
4
  import importlib
5
5
  import logging
6
6
  import sys
7
- import types
8
- from typing import Optional, Dict, List, Tuple
7
+ from typing import Optional, Dict, List
9
8
  from uuid import UUID
10
9
 
11
- import cloudpickle
12
10
  import sqlalchemy as sql
13
11
 
14
12
  import pixeltable.env as env
@@ -36,6 +34,7 @@ class FunctionRegistry:
36
34
  def __init__(self):
37
35
  self.stored_fns_by_id: Dict[UUID, Function] = {}
38
36
  self.module_fns: Dict[str, Function] = {} # fqn -> Function
37
+ self.type_methods: dict[ts.ColumnType.Type, dict[str, Function]] = {}
39
38
 
40
39
  def clear_cache(self) -> None:
41
40
  """
@@ -69,6 +68,13 @@ class FunctionRegistry:
69
68
  if fqn in self.module_fns:
70
69
  raise excs.Error(f'A UDF with that name already exists: {fqn}')
71
70
  self.module_fns[fqn] = fn
71
+ if fn.is_method or fn.is_property:
72
+ base_type = fn.signature.parameters_by_pos[0].col_type.type_enum
73
+ if base_type not in self.type_methods:
74
+ self.type_methods[base_type] = {}
75
+ if fn.name in self.type_methods[base_type]:
76
+ raise excs.Error(f'Duplicate method name for type {base_type}: {fn.name}')
77
+ self.type_methods[base_type][fn.name] = fn
72
78
 
73
79
  def list_functions(self) -> List[Function]:
74
80
  # retrieve Function.Metadata data for all existing stored functions from store directly
@@ -129,12 +135,21 @@ class FunctionRegistry:
129
135
  # assert fqn in self.module_fns, f'{fqn} not found'
130
136
  # return self.module_fns[fqn]
131
137
 
132
- def get_type_methods(self, name: str, base_type: ts.ColumnType.Type) -> List[Function]:
133
- return [
134
- fn for fn in self.module_fns.values()
135
- if fn.self_path is not None and fn.self_path.endswith('.' + name) \
136
- and fn.signature.parameters_by_pos[0].col_type.type_enum == base_type
137
- ]
138
+ def get_type_methods(self, base_type: ts.ColumnType.Type) -> list[Function]:
139
+ """
140
+ Get a list of all methods (and properties) registered for a given base type.
141
+ """
142
+ if base_type in self.type_methods:
143
+ return list(self.type_methods[base_type].values())
144
+ return []
145
+
146
+ def lookup_type_method(self, base_type: ts.ColumnType.Type, name: str) -> Optional[Function]:
147
+ """
148
+ Look up a method (or property) by name for a given base type. If no such method is registered, return None.
149
+ """
150
+ if base_type in self.type_methods and name in self.type_methods[base_type]:
151
+ return self.type_methods[base_type][name]
152
+ return None
138
153
 
139
154
  #def create_function(self, md: schema.FunctionMd, binary_obj: bytes, dir_id: Optional[UUID] = None) -> UUID:
140
155
  def create_stored_function(self, pxt_fn: Function, dir_id: Optional[UUID] = None) -> UUID:
pixeltable/func/udf.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  from typing import List, Callable, Optional, overload, Any
4
4
 
5
- import pixeltable as pxt
6
5
  import pixeltable.exceptions as excs
7
6
  import pixeltable.type_system as ts
8
7
  from .callable_function import CallableFunction
@@ -26,6 +25,8 @@ def udf(
26
25
  param_types: Optional[List[ts.ColumnType]] = None,
27
26
  batch_size: Optional[int] = None,
28
27
  substitute_fn: Optional[Callable] = None,
28
+ is_method: bool = False,
29
+ is_property: bool = False,
29
30
  _force_stored: bool = False
30
31
  ) -> Callable[[Callable], Function]: ...
31
32
 
@@ -56,6 +57,8 @@ def udf(*args, **kwargs):
56
57
  param_types = kwargs.pop('param_types', None)
57
58
  batch_size = kwargs.pop('batch_size', None)
58
59
  substitute_fn = kwargs.pop('substitute_fn', None)
60
+ is_method = kwargs.pop('is_method', None)
61
+ is_property = kwargs.pop('is_property', None)
59
62
  force_stored = kwargs.pop('_force_stored', False)
60
63
  if len(kwargs) > 0:
61
64
  raise excs.Error(f'Invalid @udf decorator kwargs: {", ".join(kwargs.keys())}')
@@ -64,8 +67,15 @@ def udf(*args, **kwargs):
64
67
 
65
68
  def decorator(decorated_fn: Callable):
66
69
  return make_function(
67
- decorated_fn, return_type, param_types, batch_size,
68
- substitute_fn=substitute_fn, force_stored=force_stored)
70
+ decorated_fn,
71
+ return_type,
72
+ param_types,
73
+ batch_size,
74
+ substitute_fn=substitute_fn,
75
+ is_method=is_method,
76
+ is_property=is_property,
77
+ force_stored=force_stored
78
+ )
69
79
 
70
80
  return decorator
71
81
 
@@ -76,6 +86,8 @@ def make_function(
76
86
  param_types: Optional[List[ts.ColumnType]] = None,
77
87
  batch_size: Optional[int] = None,
78
88
  substitute_fn: Optional[Callable] = None,
89
+ is_method: bool = False,
90
+ is_property: bool = False,
79
91
  function_name: Optional[str] = None,
80
92
  force_stored: bool = False
81
93
  ) -> Function:
@@ -112,6 +124,15 @@ def make_function(
112
124
  if batch_size is None and len(sig.batched_parameters) > 0:
113
125
  raise excs.Error(f'{errmsg_name}(): batched parameters in udf, but no `batch_size` given')
114
126
 
127
+ if is_method and is_property:
128
+ raise excs.Error(f'Cannot specify both `is_method` and `is_property` (in function `{function_name}`)')
129
+ if is_property and len(sig.parameters) != 1:
130
+ raise excs.Error(
131
+ f"`is_property=True` expects a UDF with exactly 1 parameter, but `{function_name}` has {len(sig.parameters)}"
132
+ )
133
+ if (is_method or is_property) and function_path is None:
134
+ raise excs.Error('Stored functions cannot be declared using `is_method` or `is_property`')
135
+
115
136
  if substitute_fn is None:
116
137
  py_fn = decorated_fn
117
138
  else:
@@ -120,7 +141,14 @@ def make_function(
120
141
  py_fn = substitute_fn
121
142
 
122
143
  result = CallableFunction(
123
- signature=sig, py_fn=py_fn, self_path=function_path, self_name=function_name, batch_size=batch_size)
144
+ signature=sig,
145
+ py_fn=py_fn,
146
+ self_path=function_path,
147
+ self_name=function_name,
148
+ batch_size=batch_size,
149
+ is_method=is_method,
150
+ is_property=is_property
151
+ )
124
152
 
125
153
  # If this function is part of a module, register it
126
154
  if function_path is not None:
@@ -1,4 +1,4 @@
1
- from . import fireworks, huggingface, image, openai, string, together, video
1
+ from . import fireworks, huggingface, image, openai, string, together, video, timestamp, json, vision
2
2
  from .globals import *
3
3
  from pixeltable.utils.code import local_public_names
4
4
 
@@ -1,3 +1,10 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Fireworks AI API. In order to use them, you must
4
+ first `pip install fireworks-ai` and configure your Fireworks AI credentials, as described in
5
+ the [Working with Fireworks](https://pixeltable.readme.io/docs/working-with-fireworks) tutorial.
6
+ """
7
+
1
8
  from typing import Optional, TYPE_CHECKING
2
9
 
3
10
  import pixeltable as pxt
@@ -29,6 +36,32 @@ def chat_completions(
29
36
  top_p: Optional[float] = None,
30
37
  temperature: Optional[float] = None,
31
38
  ) -> dict:
39
+ """
40
+ Creates a model response for the given chat conversation.
41
+
42
+ Equivalent to the Fireworks AI `chat/completions` API endpoint.
43
+ For additional details, see: [https://docs.fireworks.ai/api-reference/post-chatcompletions](https://docs.fireworks.ai/api-reference/post-chatcompletions)
44
+
45
+ __Requirements:__
46
+
47
+ - `pip install fireworks-ai`
48
+
49
+ Args:
50
+ messages: A list of messages comprising the conversation so far.
51
+ model: The name of the model to use.
52
+
53
+ For details on the other parameters, see: [https://docs.fireworks.ai/api-reference/post-chatcompletions](https://docs.fireworks.ai/api-reference/post-chatcompletions)
54
+
55
+ Returns:
56
+ A dictionary containing the response and other metadata.
57
+
58
+ Examples:
59
+ Add a computed column that applies the model `accounts/fireworks/models/mixtral-8x22b-instruct`
60
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
61
+
62
+ >>> messages = [{'role': 'user', 'content': tbl.prompt}]
63
+ ... tbl['response'] = chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
64
+ """
32
65
  kwargs = {'max_tokens': max_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature}
33
66
  kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
34
67
  return _fireworks_client().chat.completions.create(model=model, messages=messages, **kwargs_not_none).dict()
@@ -25,7 +25,7 @@ def sentence_transformer(
25
25
  sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
26
26
  ) -> Batch[np.ndarray]:
27
27
  """
28
- Runs the specified pretrained sentence-transformers model. `model_id` should be a pretrained model, as described
28
+ Computes sentence embeddings. `model_id` should be a pretrained Sentence Transformers model, as described
29
29
  in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
30
30
  documentation.
31
31
 
@@ -83,8 +83,8 @@ def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embed
83
83
  @pxt.udf(batch_size=32)
84
84
  def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: str) -> Batch[float]:
85
85
  """
86
- Runs the specified cross-encoder model to compute similarity scores for pairs of sentences.
87
- `model_id` should be a pretrained model, as described in the
86
+ Performs predictions on the given sentence pair.
87
+ `model_id` should be a pretrained Cross-Encoder model, as described in the
88
88
  [Cross-Encoder Pretrained Models](https://www.sbert.net/docs/cross_encoder/pretrained_models.html)
89
89
  documentation.
90
90
 
@@ -130,7 +130,27 @@ def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> li
130
130
 
131
131
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
132
132
  def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
133
- """Runs the specified CLIP model on text."""
133
+ """
134
+ Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
135
+ [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
136
+
137
+ __Requirements:__
138
+
139
+ - `pip install transformers`
140
+
141
+ Args:
142
+ text: The string to embed.
143
+ model_id: The pretrained model to use for the embedding.
144
+
145
+ Returns:
146
+ An array containing the output of the embedding model.
147
+
148
+ Examples:
149
+ Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
150
+ Pixeltable column `tbl.text` of the table `tbl`:
151
+
152
+ >>> tbl['result'] = clip_text(tbl.text, model_id='openai/clip-vit-base-patch32')
153
+ """
134
154
  env.Env.get().require_package('transformers')
135
155
  device = resolve_torch_device('auto')
136
156
  import torch
@@ -148,7 +168,27 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[np.ndarray]:
148
168
 
149
169
  @pxt.udf(batch_size=32, return_type=ts.ArrayType((None,), dtype=ts.FloatType(), nullable=False))
150
170
  def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndarray]:
151
- """Runs the specified CLIP model on images."""
171
+ """
172
+ Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
173
+ [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
174
+
175
+ __Requirements:__
176
+
177
+ - `pip install transformers`
178
+
179
+ Args:
180
+ image: The image to embed.
181
+ model_id: The pretrained model to use for the embedding.
182
+
183
+ Returns:
184
+ An array containing the output of the embedding model.
185
+
186
+ Examples:
187
+ Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
188
+ Pixeltable column `tbl.image` of the table `tbl`:
189
+
190
+ >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
191
+ """
152
192
  env.Env.get().require_package('transformers')
153
193
  device = resolve_torch_device('auto')
154
194
  import torch
@@ -178,7 +218,41 @@ def _(model_id: str) -> ts.ArrayType:
178
218
 
179
219
  @pxt.udf(batch_size=4)
180
220
  def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
181
- """Runs the specified DETR model."""
221
+ """
222
+ Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
223
+ [DETR Model](https://huggingface.co/docs/transformers/model_doc/detr).
224
+
225
+ __Requirements:__
226
+
227
+ - `pip install transformers`
228
+
229
+ Args:
230
+ image: The image in which to detect objects.
231
+ model_id: The pretrained model to use for object detection.
232
+
233
+ Returns:
234
+ A dictionary containing the output of the object detection model, in the following format:
235
+
236
+ ```python
237
+ {
238
+ 'scores': [0.99, 0.999], # list of confidence scores for each detected object
239
+ 'labels': [25, 25], # list of COCO class labels for each detected object
240
+ 'label_text': ['giraffe', 'giraffe'], # corresponding text names of class labels
241
+ 'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
242
+ # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
243
+ }
244
+ ```
245
+
246
+ Examples:
247
+ Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
248
+ Pixeltable column `tbl.image` of the table `tbl`:
249
+
250
+ >>> tbl['detections'] = detr_for_object_detection(
251
+ ... tbl.image,
252
+ ... model_id='facebook/detr-resnet-50',
253
+ ... threshold=0.8
254
+ ... )
255
+ """
182
256
  env.Env.get().require_package('transformers')
183
257
  device = resolve_torch_device('auto')
184
258
  import torch
@@ -210,6 +284,22 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
210
284
 
211
285
  @pxt.udf
212
286
  def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
287
+ """
288
+ Converts the output of a DETR object detection model to COCO format.
289
+
290
+ Args:
291
+ image: The image for which detections were computed.
292
+ detr_info: The output of a DETR object detection model, as returned by `detr_for_object_detection`.
293
+
294
+ Returns:
295
+ A dictionary containing the data from `detr_info`, converted to COCO format.
296
+
297
+ Examples:
298
+ Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
299
+ is the image for which detections were computed:
300
+
301
+ >>> tbl['detections_coco'] = detr_to_coco(tbl.image, tbl.detections)
302
+ """
213
303
  bboxes, labels = detr_info['boxes'], detr_info['labels']
214
304
  annotations = [
215
305
  {'bbox': [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]], 'category': label}