pixeltable 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +8 -22
  4. pixeltable/catalog/insertable_table.py +26 -8
  5. pixeltable/catalog/table.py +179 -83
  6. pixeltable/catalog/table_version.py +13 -39
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/catalog/view.py +2 -2
  9. pixeltable/dataframe.py +20 -28
  10. pixeltable/env.py +2 -0
  11. pixeltable/exec/cache_prefetch_node.py +189 -43
  12. pixeltable/exec/data_row_batch.py +3 -3
  13. pixeltable/exec/exec_context.py +2 -2
  14. pixeltable/exec/exec_node.py +2 -2
  15. pixeltable/exec/expr_eval_node.py +8 -8
  16. pixeltable/exprs/arithmetic_expr.py +9 -4
  17. pixeltable/exprs/column_ref.py +4 -0
  18. pixeltable/exprs/comparison.py +5 -0
  19. pixeltable/exprs/json_path.py +1 -1
  20. pixeltable/func/aggregate_function.py +8 -8
  21. pixeltable/func/expr_template_function.py +6 -5
  22. pixeltable/func/udf.py +6 -11
  23. pixeltable/functions/huggingface.py +136 -25
  24. pixeltable/functions/llama_cpp.py +3 -2
  25. pixeltable/functions/mistralai.py +1 -1
  26. pixeltable/functions/openai.py +1 -1
  27. pixeltable/functions/together.py +1 -1
  28. pixeltable/functions/util.py +5 -2
  29. pixeltable/globals.py +55 -6
  30. pixeltable/plan.py +1 -1
  31. pixeltable/tool/create_test_db_dump.py +1 -1
  32. pixeltable/type_system.py +83 -35
  33. pixeltable/utils/coco.py +5 -5
  34. pixeltable/utils/formatter.py +3 -3
  35. pixeltable/utils/s3.py +6 -3
  36. {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/METADATA +119 -46
  37. {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/RECORD +40 -40
  38. {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
  39. {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
  40. {pixeltable-0.2.22.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/exec/expr_eval_node.py CHANGED
@@ -3,7 +3,7 @@ import sys
  import time
  import warnings
  from dataclasses import dataclass
- from typing import Iterable, List, Optional
+ from typing import Iterable, Optional

  from tqdm import TqdmWarning, tqdm

@@ -22,10 +22,10 @@ class ExprEvalNode(ExecNode):
  @dataclass
  class Cohort:
  """List of exprs that form an evaluation context and contain calls to at most one external function"""
- exprs_: List[exprs.Expr]
+ exprs_: list[exprs.Expr]
  batched_fn: Optional[CallableFunction]
- segment_ctxs: List['exprs.RowBuilder.EvalCtx']
- target_slot_idxs: List[int]
+ segment_ctxs: list['exprs.RowBuilder.EvalCtx']
+ target_slot_idxs: list[int]
  batch_size: int = 8

  def __init__(
@@ -38,7 +38,7 @@ class ExprEvalNode(ExecNode):
  # we're only materializing exprs that are not already in the input
  self.target_exprs = [e for e in output_exprs if e.slot_idx not in input_slot_idxs]
  self.pbar: Optional[tqdm] = None
- self.cohorts: List[ExprEvalNode.Cohort] = []
+ self.cohorts: list[ExprEvalNode.Cohort] = []
  self._create_cohorts()

  def __next__(self) -> DataRowBatch:
@@ -83,7 +83,7 @@ class ExprEvalNode(ExecNode):
  all_exprs = self.row_builder.get_dependencies(self.target_exprs)
  # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
  # seed the cohorts with only the ext fn calls
- cohorts: List[List[exprs.Expr]] = []
+ cohorts: list[list[exprs.Expr]] = []
  current_batched_fn: Optional[CallableFunction] = None
  for e in all_exprs:
  if not self._is_batched_fn_call(e):
@@ -100,7 +100,7 @@ class ExprEvalNode(ExecNode):
  # cohorts are evaluated in order, so we can exclude the target slots from preceding cohorts and input slots
  exclude = set(e.slot_idx for e in self.input_exprs)
  all_target_slot_idxs = set(e.slot_idx for e in self.target_exprs)
- target_slot_idxs: List[List[int]] = [] # the ones materialized by each cohort
+ target_slot_idxs: list[list[int]] = [] # the ones materialized by each cohort
  for i in range(len(cohorts)):
  cohorts[i] = self.row_builder.get_dependencies(
  cohorts[i], exclude=[self.row_builder.unique_exprs[slot_idx] for slot_idx in exclude])
@@ -171,7 +171,7 @@ class ExprEvalNode(ExecNode):
  arg_batches: list[list[exprs.Expr]] = [[] for _ in range(len(fn_call.args))]
  kwarg_batches: dict[str, list[exprs.Expr]] = {k: [] for k in fn_call.kwargs.keys()}

- valid_batch_idxs: List[int] = [] # rows with exceptions are not valid
+ valid_batch_idxs: list[int] = [] # rows with exceptions are not valid
  for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
  row = rows[row_idx]
  if row.has_exc(fn_call.slot_idx):
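The comments in the hunks above describe how `_create_cohorts` groups expressions so that each cohort batches calls to at most one external function. The full loop body is elided in this diff; the following is only a rough sketch of the idea, with simplified names and an assumed grouping rule, not Pixeltable's actual code:

```python
# Rough sketch of the cohort-splitting idea; names and the exact grouping
# rule are simplifications, not Pixeltable's implementation.
def split_into_cohorts(all_exprs, is_batched_fn_call, fn_of):
    cohorts: list[list] = []
    current_fn = None
    for e in all_exprs:
        if not is_batched_fn_call(e):
            continue  # non-batched exprs are pulled in later as dependencies
        if current_fn is None or fn_of(e) is not current_fn:
            cohorts.append([])  # open a new cohort for this batched function
            current_fn = fn_of(e)
        cohorts[-1].append(e)
    return cohorts
```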
pixeltable/exprs/arithmetic_expr.py CHANGED
@@ -69,11 +69,15 @@ class ArithmeticExpr(Expr):
  return left * right
  if self.operator == ArithmeticOperator.DIV:
  assert self.col_type.is_float_type()
+ # Avoid DivisionByZero: if right is 0, make this a NULL
+ # TODO: Should we cast the NULLs to NaNs when they are retrieved back into Python?
+ nullif = sql.sql.func.nullif(right, 0)
  # We have to cast to a `float`, or else we'll get a `Decimal`
- return sql.sql.expression.cast(left / right, sql.Float)
+ return sql.sql.expression.cast(left / nullif, sql.Float)
  if self.operator == ArithmeticOperator.MOD:
  if self.col_type.is_int_type():
- return left % right
+ nullif = sql.sql.func.nullif(right, 0)
+ return left % nullif
  if self.col_type.is_float_type():
  # Postgres does not support modulus for floats
  return None
@@ -83,10 +87,11 @@ class ArithmeticExpr(Expr):
  # We need the behavior to be consistent, so that expressions will evaluate the same way
  # whether or not their operands can be translated to SQL. These SQL clauses should
  # mimic the behavior of Python's // operator.
+ nullif = sql.sql.func.nullif(right, 0)
  if self.col_type.is_int_type():
- return sql.sql.expression.cast(sql.func.floor(left / right), sql.Integer)
+ return sql.sql.expression.cast(sql.func.floor(left / nullif), sql.Integer)
  if self.col_type.is_float_type():
- return sql.sql.expression.cast(sql.func.floor(left / right), sql.Float)
+ return sql.sql.expression.cast(sql.func.floor(left / nullif), sql.Float)
  assert False

  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
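For context on the change above: `sql.func.nullif(right, 0)` compiles to SQL's `NULLIF(right, 0)`, which returns NULL when the divisor is 0, so Postgres evaluates the division to NULL instead of raising DivisionByZero. A standalone SQLAlchemy sketch of the same pattern (illustrative only, not Pixeltable code):

```python
import sqlalchemy as sql

left = sql.literal(10)
right = sql.literal(0)

# NULLIF(0, 0) -> NULL, so the whole division evaluates to NULL in SQL
# rather than raising a division-by-zero error.
expr = sql.cast(left / sql.func.nullif(right, 0), sql.Float)
print(expr.compile(compile_kwargs={'literal_binds': True}))
# CAST(10 / NULLIF(0, 0) AS FLOAT)
```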
pixeltable/exprs/column_ref.py CHANGED
@@ -135,6 +135,10 @@ class ColumnRef(Expr):
  def __repr__(self) -> str:
  return f'ColumnRef({self.col!r})'

+ def _repr_html_(self) -> str:
+ tbl = catalog.Catalog.get().tbls[self.col.tbl.id]
+ return tbl._description_html(cols=[self.col])._repr_html_() # type: ignore[attr-defined]
+
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
  return None if self.perform_validation else self.col.sa_col

pixeltable/exprs/comparison.py CHANGED
@@ -67,6 +67,11 @@ class Comparison(Expr):
  return self.components[1]

  def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
+ if str(self._op1.col_type.to_sa_type()) != str(self._op2.col_type.to_sa_type()):
+ # Comparing columns of different SQL types (e.g., string vs. json); this can only be done in Python
+ # TODO(aaron-siegel): We may be able to handle some cases in SQL by casting one side to the other's type
+ return None
+
  left = sql_elements.get(self._op1)
  if self.is_search_arg_comparison:
  # reference the index value column if there is an index and this is not a snapshot
pixeltable/exprs/json_path.py CHANGED
@@ -32,7 +32,7 @@ class JsonPath(Expr):
  """
  if path_elements is None:
  path_elements = []
- super().__init__(ts.JsonType())
+ super().__init__(ts.JsonType(nullable=True)) # JsonPath expressions are always nullable
  if anchor is not None:
  self.components = [anchor]
  self.path_elements: list[Union[str, int, slice]] = path_elements
pixeltable/func/aggregate_function.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations

  import abc
  import inspect
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
+ from typing import TYPE_CHECKING, Any, Callable, Optional

  import pixeltable.exceptions as excs
  import pixeltable.type_system as ts
@@ -36,8 +36,8 @@ class AggregateFunction(Function):
  RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}

  def __init__(
- self, aggregator_class: Type[Aggregator], self_path: str,
- init_types: List[ts.ColumnType], update_types: List[ts.ColumnType], value_type: ts.ColumnType,
+ self, aggregator_class: type[Aggregator], self_path: str,
+ init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
  requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
  self.agg_cls = aggregator_class
  self.requires_order_by = requires_order_by
@@ -128,7 +128,7 @@ class AggregateFunction(Function):
  order_by_clause=[order_by_clause] if order_by_clause is not None else [],
  group_by_clause=[group_by_clause] if group_by_clause is not None else [])

- def validate_call(self, bound_args: Dict[str, Any]) -> None:
+ def validate_call(self, bound_args: dict[str, Any]) -> None:
  # check that init parameters are not Exprs
  # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
  import pixeltable.exprs as exprs
@@ -146,10 +146,10 @@ class AggregateFunction(Function):
  def uda(
  *,
  value_type: ts.ColumnType,
- update_types: List[ts.ColumnType],
- init_types: Optional[List[ts.ColumnType]] = None,
+ update_types: list[ts.ColumnType],
+ init_types: Optional[list[ts.ColumnType]] = None,
  requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
- ) -> Callable[[Type[Aggregator]], AggregateFunction]:
+ ) -> Callable[[type[Aggregator]], AggregateFunction]:
  """Decorator for user-defined aggregate functions.

  The decorated class must inherit from Aggregator and implement the following methods:
@@ -171,7 +171,7 @@ def uda(
  if init_types is None:
  init_types = []

- def decorator(cls: Type[Aggregator]) -> AggregateFunction:
+ def decorator(cls: type[Aggregator]) -> AggregateFunction:
  # validate type parameters
  num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
  if num_init_params > 0:
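To make the `@uda` signature above concrete, here is a hedged sketch of a minimal aggregate. It assumes `uda` and `Aggregator` are importable from the top-level `pixeltable` package, which this diff does not itself show:

```python
import pixeltable as pxt
import pixeltable.type_system as ts

@pxt.uda(value_type=ts.FloatType(nullable=True), update_types=[ts.FloatType(nullable=True)])
class mean(pxt.Aggregator):
    def __init__(self) -> None:  # init parameters must not be Exprs (see validate_call above)
        self.total = 0.0
        self.count = 0

    def update(self, val: float) -> None:  # one parameter per entry in update_types
        if val is not None:
            self.total += val
            self.count += 1

    def value(self) -> float:  # produces a result of value_type (nullable: None when no rows)
        return self.total / self.count if self.count > 0 else None
```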
pixeltable/func/expr_template_function.py CHANGED
@@ -1,10 +1,11 @@
  import inspect
- from typing import Dict, Optional, Any
+ from typing import Any, Optional

  import pixeltable
  import pixeltable.exceptions as excs
+
  from .function import Function
- from .signature import Signature, Parameter
+ from .signature import Signature


  class ExprTemplateFunction(Function):
@@ -22,7 +23,7 @@ class ExprTemplateFunction(Function):
  self.param_exprs_by_name = {p.name: p for p in self.param_exprs}

  # verify default values
- self.defaults: Dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
+ self.defaults: dict[str, exprs.Literal] = {} # key: param name, value: default value converted to a Literal
  for param in signature.parameters.values():
  if param.default is inspect.Parameter.empty:
  continue
@@ -77,7 +78,7 @@ class ExprTemplateFunction(Function):
  def name(self) -> str:
  return self.self_name

- def _as_dict(self) -> Dict:
+ def _as_dict(self) -> dict:
  if self.self_path is not None:
  return super()._as_dict()
  return {
@@ -87,7 +88,7 @@ class ExprTemplateFunction(Function):
  }

  @classmethod
- def _from_dict(cls, d: Dict) -> Function:
+ def _from_dict(cls, d: dict) -> Function:
  if 'expr' not in d:
  return super()._from_dict(d)
  assert 'signature' in d and 'name' in d
pixeltable/func/udf.py CHANGED
@@ -1,9 +1,10 @@
  from __future__ import annotations

- from typing import List, Callable, Optional, overload, Any
+ from typing import Any, Callable, Optional, overload

  import pixeltable.exceptions as excs
  import pixeltable.type_system as ts
+
  from .callable_function import CallableFunction
  from .expr_template_function import ExprTemplateFunction
  from .function import Function
@@ -21,8 +22,6 @@ def udf(decorated_fn: Callable) -> Function: ...
  @overload
  def udf(
  *,
- return_type: Optional[ts.ColumnType] = None,
- param_types: Optional[List[ts.ColumnType]] = None,
  batch_size: Optional[int] = None,
  substitute_fn: Optional[Callable] = None,
  is_method: bool = False,
@@ -49,8 +48,6 @@ def udf(*args, **kwargs):

  # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
  # Create a decorator for the specified schema.
- return_type = kwargs.pop('return_type', None)
- param_types = kwargs.pop('param_types', None)
  batch_size = kwargs.pop('batch_size', None)
  substitute_fn = kwargs.pop('substitute_fn', None)
  is_method = kwargs.pop('is_method', None)
@@ -64,9 +61,7 @@ def udf(*args, **kwargs):
  def decorator(decorated_fn: Callable):
  return make_function(
  decorated_fn,
- return_type,
- param_types,
- batch_size,
+ batch_size=batch_size,
  substitute_fn=substitute_fn,
  is_method=is_method,
  is_property=is_property,
@@ -79,7 +74,7 @@ def udf(*args, **kwargs):
  def make_function(
  decorated_fn: Callable,
  return_type: Optional[ts.ColumnType] = None,
- param_types: Optional[List[ts.ColumnType]] = None,
+ param_types: Optional[list[ts.ColumnType]] = None,
  batch_size: Optional[int] = None,
  substitute_fn: Optional[Callable] = None,
  is_method: bool = False,
@@ -158,10 +153,10 @@ def make_function(
  def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...

  @overload
- def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
+ def expr_udf(*, param_types: Optional[list[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...

  def expr_udf(*args: Any, **kwargs: Any) -> Any:
- def make_expr_template(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
+ def make_expr_template(py_fn: Callable, param_types: Optional[list[ts.ColumnType]]) -> ExprTemplateFunction:
  if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
  # this is a named function in a module
  function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
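Note that `return_type` and `param_types` are removed from the `@udf` overload and its kwargs handling, while `make_function` still accepts them explicitly; decorated UDFs now take their types from the Python annotations. A minimal sketch (illustrative, not from the diff):

```python
import pixeltable as pxt

# Parameter and return types are inferred from the annotations; there is no
# longer a return_type=/param_types= keyword form of @udf.
@pxt.udf
def add_tax(price: float, rate: float = 0.1) -> float:
    return price * (1 + rate)
```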
pixeltable/functions/huggingface.py CHANGED
@@ -7,21 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
  UDFs).
  """

- from typing import Callable, TypeVar, Optional, Any
+ from typing import Any, Callable, Optional, TypeVar

  import PIL.Image

  import pixeltable as pxt
  import pixeltable.env as env
+ import pixeltable.exceptions as excs
  from pixeltable.func import Batch
- from pixeltable.functions.util import resolve_torch_device, normalize_image_mode
+ from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
  from pixeltable.utils.code import local_public_names


  @pxt.udf(batch_size=32)
  def sentence_transformer(
  sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
- ) -> Batch[pxt.Array[(None,), float]]:
+ ) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Computes sentence embeddings. `model_id` should be a pretrained Sentence Transformers model, as described
  in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
@@ -29,7 +30,7 @@ def sentence_transformer(

  __Requirements:__

- - `pip install sentence-transformers`
+ - `pip install torch sentence-transformers`

  Args:
  sentence: The sentence to embed.
@@ -48,11 +49,15 @@ def sentence_transformer(
  >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
  """
  env.Env.get().require_package('sentence_transformers')
+ device = resolve_torch_device('auto')
+ import torch
  from sentence_transformers import SentenceTransformer # type: ignore

- model = _lookup_model(model_id, SentenceTransformer)
+ # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+ model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)

- array = model.encode(sentence, normalize_embeddings=normalize_embeddings)
+ # specifying the device, uses it for computation
+ array = model.encode(sentence, device=device, normalize_embeddings=normalize_embeddings)
  return [array[i] for i in range(array.shape[0])]

@@ -70,11 +75,15 @@ def _(model_id: str) -> pxt.ArrayType:
  @pxt.udf
  def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
  env.Env.get().require_package('sentence_transformers')
+ device = resolve_torch_device('auto')
+ import torch
  from sentence_transformers import SentenceTransformer

- model = _lookup_model(model_id, SentenceTransformer)
+ # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+ model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)

- array = model.encode(sentences, normalize_embeddings=normalize_embeddings)
+ # specifying the device, uses it for computation
+ array = model.encode(sentences, device=device, normalize_embeddings=normalize_embeddings)
  return [array[i].tolist() for i in range(array.shape[0])]

@@ -88,7 +97,7 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s

  __Requirements:__

- - `pip install sentence-transformers`
+ - `pip install torch sentence-transformers`

  Parameters:
  sentences1: The first sentence to be paired.
@@ -107,9 +116,13 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
  )
  """
  env.Env.get().require_package('sentence_transformers')
+ device = resolve_torch_device('auto')
+ import torch
  from sentence_transformers import CrossEncoder

- model = _lookup_model(model_id, CrossEncoder)
+ # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+ # and uses the device for predict computation
+ model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)

  array = model.predict([[s1, s2] for s1, s2 in zip(sentences1, sentences2)], convert_to_numpy=True)
  return array.tolist()
@@ -118,23 +131,27 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
  @pxt.udf
  def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
  env.Env.get().require_package('sentence_transformers')
+ device = resolve_torch_device('auto')
+ import torch
  from sentence_transformers import CrossEncoder

- model = _lookup_model(model_id, CrossEncoder)
+ # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+ # and uses the device for predict computation
+ model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)

  array = model.predict([[sentence1, s2] for s2 in sentences2], convert_to_numpy=True)
  return array.tolist()


  @pxt.udf(batch_size=32)
- def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), float]]:
+ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
  [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).

  __Requirements:__

- - `pip install transformers`
+ - `pip install torch transformers`

  Args:
  text: The string to embed.
@@ -165,14 +182,14 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), fl


  @pxt.udf(batch_size=32)
- def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), float]]:
+ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
  [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).

  __Requirements:__

- - `pip install transformers`
+ - `pip install torch transformers`

  Args:
  image: The image to embed.
@@ -215,14 +232,20 @@ def _(model_id: str) -> pxt.ArrayType:


  @pxt.udf(batch_size=4)
- def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
+ def detr_for_object_detection(
+ image: Batch[PIL.Image.Image],
+ *,
+ model_id: str,
+ threshold: float = 0.5,
+ revision: str = 'no_timm',
+ ) -> Batch[dict]:
  """
  Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
  [DETR Model](https://huggingface.co/docs/transformers/model_doc/detr).

  __Requirements:__

- - `pip install transformers`
+ - `pip install torch transformers`

  Args:
  image: The image to embed.
@@ -254,12 +277,12 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
  import torch
- from transformers import DetrImageProcessor, DetrForObjectDetection
+ from transformers import DetrForObjectDetection, DetrImageProcessor

  model = _lookup_model(
- model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device
+ model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=revision), device=device
  )
- processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))
+ processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=revision))
  normalized_images = [normalize_image_mode(img) for img in image]

  with torch.no_grad():
@@ -299,7 +322,7 @@ def vit_for_image_classification(

  __Requirements:__

- - `pip install transformers`
+ - `pip install torch transformers`

  Args:
  image: The image to classify.
@@ -330,7 +353,7 @@ def vit_for_image_classification(
  env.Env.get().require_package('transformers')
  device = resolve_torch_device('auto')
  import torch
- from transformers import ViTImageProcessor, ViTForImageClassification
+ from transformers import ViTForImageClassification, ViTImageProcessor

  model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
  processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
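The new `revision` parameter on `detr_for_object_detection` defaults to the previously hardcoded `'no_timm'`, so existing calls behave as before, but it can now be overridden per call. A hedged usage sketch (`'main'` is the standard Hugging Face branch name, used here illustratively):

```python
# Default revision ('no_timm'), same behavior as 0.2.22:
tbl['detections'] = detr_for_object_detection(tbl.image, model_id='facebook/detr-resnet-50')

# Explicitly selecting a different model revision:
tbl['detections2'] = detr_for_object_detection(
    tbl.image, model_id='facebook/detr-resnet-50', threshold=0.8, revision='main'
)
```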
@@ -356,6 +379,86 @@ def vit_for_image_classification(
  ]


+ @pxt.udf
+ def speech2text_for_conditional_generation(
+ audio: pxt.Audio,
+ *,
+ model_id: str,
+ language: Optional[str] = None,
+ ) -> str:
+ """
+ Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
+ pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
+
+ __Requirements:__
+
+ - `pip install torch torchaudio sentencepiece transformers`
+
+ Args:
+ audio: The audio clip to transcribe or translate.
+ model_id: The pretrained model to use for the transcription or translation.
+ language: If using a multilingual translation model, the language code to translate to. If not provided,
+ the model's default language will be used. If the model is not a translation model, is not a
+ multilingual model, or does not support the specified language, an error will be raised.
+
+ Returns:
+ The transcribed or translated text.
+
+ Examples:
+ Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
+ Pixeltable column `audio` of the table `tbl`:
+
+ >>> tbl['transcription'] = speech2text_for_conditional_generation(
+ ... tbl.audio,
+ ... model_id='facebook/s2t-small-librispeech-asr'
+ ... )
+
+ Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
+ Pixeltable column `audio` of the table `tbl`, translating the audio to French:
+
+ >>> tbl['translation'] = speech2text_for_conditional_generation(
+ ... tbl.audio,
+ ... model_id='facebook/s2t-medium-mustc-multilingual-st',
+ ... language='fr'
+ ... )
+ """
+ env.Env.get().require_package('transformers')
+ env.Env.get().require_package('torchaudio')
+ env.Env.get().require_package('sentencepiece')
+ device = resolve_torch_device('auto', allow_mps=False) # Doesn't seem to work on 'mps'; use 'cpu' instead
+ import librosa
+ import torch
+ from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+ # facebook/s2t-small-librispeech-asr
+ # facebook/s2t-small-mustc-en-fr-st
+ model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+ processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+ assert isinstance(processor, Speech2TextProcessor)
+
+ if language is not None and language not in processor.tokenizer.lang_code_to_id:
+ raise excs.Error(
+ f"Language code '{language}' is not supported by the model '{model_id}'. "
+ f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+
+ forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
+
+ # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
+ model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+ waveform, sampling_rate = librosa.load(audio, sr=model_sampling_rate, mono=True)
+
+ with torch.no_grad():
+ inputs = processor(
+ waveform,
+ sampling_rate=sampling_rate,
+ return_tensors='pt'
+ )
+ generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
+
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+ return transcription
+
+
  @pxt.udf
  def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
  """
@@ -385,14 +488,22 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
  T = TypeVar('T')


- def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
+ def _lookup_model(
+ model_id: str,
+ create: Callable[..., T],
+ device: Optional[str] = None,
+ pass_device_to_create: bool = False
+ ) -> T:
  from torch import nn

  key = (model_id, create, device) # For safety, include the `create` callable in the cache key
  if key not in _model_cache:
- model = create(model_id)
+ if pass_device_to_create:
+ model = create(model_id, device=device)
+ else:
+ model = create(model_id)
  if isinstance(model, nn.Module):
- if device is not None:
+ if not pass_device_to_create and device is not None:
  model.to(device)
  model.eval()
  _model_cache[key] = model
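The new `pass_device_to_create` flag distinguishes the two placement styles seen in this file: constructors such as `SentenceTransformer` accept `device=` directly, while `from_pretrained`-style factories return an `nn.Module` that is moved afterwards with `.to(device)`. Mirroring the call sites above (assumes this module's imports and helpers are in scope):

```python
# Constructor takes device= itself (SentenceTransformer, CrossEncoder):
model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)

# Factory returns an nn.Module; _lookup_model then moves it with .to(device):
model = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
```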
pixeltable/functions/llama_cpp.py CHANGED
@@ -76,7 +76,7 @@ def _lookup_local_model(model_path: str, n_gpu_layers: int) -> 'llama_cpp.Llama'

  key = (model_path, None, n_gpu_layers)
  if key not in _model_cache:
- llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers)
+ llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers, verbose=False)
  _model_cache[key] = llm
  return _model_cache[key]

@@ -89,7 +89,8 @@ def _lookup_pretrained_model(repo_id: str, filename: Optional[str], n_gpu_layers
  llm = llama_cpp.Llama.from_pretrained(
  repo_id=repo_id,
  filename=filename,
- n_gpu_layers=n_gpu_layers
+ n_gpu_layers=n_gpu_layers,
+ verbose=False,
  )
  _model_cache[key] = llm
  return _model_cache[key]
pixeltable/functions/mistralai.py CHANGED
@@ -141,7 +141,7 @@ _embedding_dimensions_cache: dict[str, int] = {


  @pxt.udf(batch_size=16)
- def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), float]]:
+ def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Embeddings API.

pixeltable/functions/openai.py CHANGED
@@ -304,7 +304,7 @@ _embedding_dimensions_cache: dict[str, int] = {
  @pxt.udf(batch_size=32)
  def embeddings(
  input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
- ) -> Batch[pxt.Array[(None,), float]]:
+ ) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Creates an embedding vector representing the input text.

pixeltable/functions/together.py CHANGED
@@ -186,7 +186,7 @@ _embedding_dimensions_cache = {


  @pxt.udf(batch_size=32)
- def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), float]]:
+ def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
  """
  Query an embedding model for a given string of text.

pixeltable/functions/util.py CHANGED
@@ -1,13 +1,16 @@
  import PIL.Image

+ from pixeltable.env import Env

- def resolve_torch_device(device: str) -> str:
+
+ def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
+ Env.get().require_package('torch')
  import torch

  if device == 'auto':
  if torch.cuda.is_available():
  return 'cuda'
- if torch.backends.mps.is_available():
+ if allow_mps and torch.backends.mps.is_available():
  return 'mps'
  return 'cpu'
  return device
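Per the logic above, `'auto'` resolves in the order cuda → mps → cpu, `allow_mps=False` (used by the new speech2text UDF) skips the MPS branch, and any other string is passed through unchanged. A quick usage sketch:

```python
from pixeltable.functions.util import resolve_torch_device

device = resolve_torch_device('auto')                   # 'cuda', 'mps', or 'cpu'
no_mps = resolve_torch_device('auto', allow_mps=False)  # 'cuda' or 'cpu', never 'mps'
explicit = resolve_torch_device('cuda:1')               # returned as-is
```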