pixeltable 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. pixeltable/__init__.py +2 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -1
  4. pixeltable/catalog/column.py +41 -29
  5. pixeltable/catalog/globals.py +18 -0
  6. pixeltable/catalog/insertable_table.py +30 -10
  7. pixeltable/catalog/table.py +198 -86
  8. pixeltable/catalog/table_version.py +47 -53
  9. pixeltable/catalog/table_version_path.py +2 -2
  10. pixeltable/catalog/view.py +17 -18
  11. pixeltable/dataframe.py +27 -36
  12. pixeltable/env.py +7 -0
  13. pixeltable/exec/__init__.py +0 -1
  14. pixeltable/exec/aggregation_node.py +6 -3
  15. pixeltable/exec/cache_prefetch_node.py +189 -43
  16. pixeltable/exec/data_row_batch.py +5 -22
  17. pixeltable/exec/exec_context.py +2 -2
  18. pixeltable/exec/exec_node.py +3 -2
  19. pixeltable/exec/expr_eval_node.py +23 -16
  20. pixeltable/exec/in_memory_data_node.py +6 -3
  21. pixeltable/exec/sql_node.py +24 -25
  22. pixeltable/exprs/arithmetic_expr.py +12 -5
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +97 -14
  26. pixeltable/exprs/comparison.py +10 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +27 -18
  29. pixeltable/exprs/expr.py +53 -52
  30. pixeltable/exprs/expr_set.py +5 -0
  31. pixeltable/exprs/function_call.py +32 -16
  32. pixeltable/exprs/globals.py +4 -1
  33. pixeltable/exprs/in_predicate.py +8 -7
  34. pixeltable/exprs/inline_expr.py +4 -4
  35. pixeltable/exprs/is_null.py +4 -4
  36. pixeltable/exprs/json_mapper.py +11 -12
  37. pixeltable/exprs/json_path.py +6 -11
  38. pixeltable/exprs/literal.py +5 -5
  39. pixeltable/exprs/method_ref.py +5 -4
  40. pixeltable/exprs/object_ref.py +2 -1
  41. pixeltable/exprs/row_builder.py +88 -36
  42. pixeltable/exprs/rowid_ref.py +12 -11
  43. pixeltable/exprs/similarity_expr.py +12 -7
  44. pixeltable/exprs/sql_element_cache.py +7 -5
  45. pixeltable/exprs/type_cast.py +8 -6
  46. pixeltable/exprs/variable.py +5 -4
  47. pixeltable/func/aggregate_function.py +9 -9
  48. pixeltable/func/expr_template_function.py +6 -5
  49. pixeltable/func/function.py +11 -10
  50. pixeltable/func/udf.py +6 -11
  51. pixeltable/functions/__init__.py +2 -2
  52. pixeltable/functions/globals.py +5 -7
  53. pixeltable/functions/huggingface.py +155 -45
  54. pixeltable/functions/llama_cpp.py +107 -0
  55. pixeltable/functions/mistralai.py +1 -1
  56. pixeltable/functions/ollama.py +147 -0
  57. pixeltable/functions/openai.py +1 -1
  58. pixeltable/functions/replicate.py +72 -0
  59. pixeltable/functions/string.py +9 -0
  60. pixeltable/functions/together.py +1 -1
  61. pixeltable/functions/util.py +5 -2
  62. pixeltable/globals.py +67 -26
  63. pixeltable/index/btree.py +16 -3
  64. pixeltable/index/embedding_index.py +4 -4
  65. pixeltable/io/__init__.py +1 -2
  66. pixeltable/io/fiftyone.py +178 -0
  67. pixeltable/io/globals.py +96 -2
  68. pixeltable/iterators/base.py +3 -2
  69. pixeltable/iterators/document.py +1 -1
  70. pixeltable/iterators/video.py +120 -63
  71. pixeltable/metadata/__init__.py +1 -1
  72. pixeltable/metadata/converters/convert_21.py +34 -0
  73. pixeltable/metadata/converters/util.py +45 -4
  74. pixeltable/metadata/notes.py +1 -0
  75. pixeltable/metadata/schema.py +8 -0
  76. pixeltable/plan.py +17 -15
  77. pixeltable/py.typed +0 -0
  78. pixeltable/store.py +7 -2
  79. pixeltable/tool/create_test_db_dump.py +1 -1
  80. pixeltable/tool/create_test_video.py +1 -1
  81. pixeltable/tool/embed_udf.py +1 -1
  82. pixeltable/tool/mypy_plugin.py +28 -5
  83. pixeltable/type_system.py +100 -36
  84. pixeltable/utils/coco.py +5 -5
  85. pixeltable/utils/documents.py +15 -1
  86. pixeltable/utils/formatter.py +12 -13
  87. pixeltable/utils/s3.py +6 -3
  88. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/METADATA +158 -49
  89. pixeltable-0.2.23.dist-info/RECORD +153 -0
  90. pixeltable/exec/media_validation_node.py +0 -43
  91. pixeltable-0.2.21.dist-info/RECORD +0 -148
  92. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/LICENSE +0 -0
  93. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/WHEEL +0 -0
  94. {pixeltable-0.2.21.dist-info → pixeltable-0.2.23.dist-info}/entry_points.txt +0 -0
pixeltable/func/aggregate_function.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import abc
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
@@ -36,8 +36,8 @@ class AggregateFunction(Function):
     RESERVED_PARAMS = {ORDER_BY_PARAM, GROUP_BY_PARAM}
 
     def __init__(
-            self, aggregator_class: Type[Aggregator], self_path: str,
-            init_types: List[ts.ColumnType], update_types: List[ts.ColumnType], value_type: ts.ColumnType,
+            self, aggregator_class: type[Aggregator], self_path: str,
+            init_types: list[ts.ColumnType], update_types: list[ts.ColumnType], value_type: ts.ColumnType,
             requires_order_by: bool, allows_std_agg: bool, allows_window: bool):
         self.agg_cls = aggregator_class
         self.requires_order_by = requires_order_by
@@ -86,7 +86,7 @@ class AggregateFunction(Function):
             res += '\n\n' + inspect.getdoc(self.agg_cls.update)
         return res
 
-    def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.Expr':
+    def __call__(self, *args: object, **kwargs: object) -> 'pixeltable.exprs.FunctionCall':
        from pixeltable import exprs
 
        # perform semantic analysis of special parameters 'order_by' and 'group_by'
@@ -128,7 +128,7 @@ class AggregateFunction(Function):
            order_by_clause=[order_by_clause] if order_by_clause is not None else [],
            group_by_clause=[group_by_clause] if group_by_clause is not None else [])
 
-    def validate_call(self, bound_args: Dict[str, Any]) -> None:
+    def validate_call(self, bound_args: dict[str, Any]) -> None:
        # check that init parameters are not Exprs
        # TODO: do this in the planner (check that init parameters are either constants or only refer to grouping exprs)
        import pixeltable.exprs as exprs
@@ -146,10 +146,10 @@ class AggregateFunction(Function):
 def uda(
        *,
        value_type: ts.ColumnType,
-        update_types: List[ts.ColumnType],
-        init_types: Optional[List[ts.ColumnType]] = None,
+        update_types: list[ts.ColumnType],
+        init_types: Optional[list[ts.ColumnType]] = None,
        requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
-) -> Callable[[Type[Aggregator]], AggregateFunction]:
+) -> Callable[[type[Aggregator]], AggregateFunction]:
    """Decorator for user-defined aggregate functions.
 
    The decorated class must inherit from Aggregator and implement the following methods:
@@ -171,7 +171,7 @@ def uda(
    if init_types is None:
        init_types = []
 
-    def decorator(cls: Type[Aggregator]) -> AggregateFunction:
+    def decorator(cls: type[Aggregator]) -> AggregateFunction:
        # validate type parameters
        num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
        if num_init_params > 0:
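
Note: a minimal sketch of a UDA under the new annotations, mirroring the builtin aggregators in pixeltable/functions/globals.py (shown further below); the `sum_of_squares` name is hypothetical, and this assumes `uda` and `Aggregator` are exported from `pixeltable.func`:

```python
# A minimal UDA sketch using the new builtin-generic annotations
# (list[ts.ColumnType] instead of typing.List); names are hypothetical.
from typing import Optional

import pixeltable.func as func
import pixeltable.type_system as ts

@func.uda(value_type=ts.FloatType(nullable=True), update_types=[ts.FloatType(nullable=True)])
class sum_of_squares(func.Aggregator):
    def __init__(self):
        self.total: Optional[float] = None  # running sum of squares

    def update(self, val: Optional[float]) -> None:
        if val is None:
            return
        self.total = (self.total or 0.0) + val * val

    def value(self) -> Optional[float]:
        return self.total
```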
pixeltable/func/expr_template_function.py CHANGED
@@ -1,10 +1,11 @@
 import inspect
-from typing import Dict, Optional, Any
+from typing import Any, Optional
 
 import pixeltable
 import pixeltable.exceptions as excs
+
 from .function import Function
-from .signature import Signature, Parameter
+from .signature import Signature
 
 
 class ExprTemplateFunction(Function):
@@ -22,7 +23,7 @@ class ExprTemplateFunction(Function):
        self.param_exprs_by_name = {p.name: p for p in self.param_exprs}
 
        # verify default values
-        self.defaults: Dict[str, exprs.Literal] = {}  # key: param name, value: default value converted to a Literal
+        self.defaults: dict[str, exprs.Literal] = {}  # key: param name, value: default value converted to a Literal
        for param in signature.parameters.values():
            if param.default is inspect.Parameter.empty:
                continue
@@ -77,7 +78,7 @@ class ExprTemplateFunction(Function):
    def name(self) -> str:
        return self.self_name
 
-    def _as_dict(self) -> Dict:
+    def _as_dict(self) -> dict:
        if self.self_path is not None:
            return super()._as_dict()
        return {
@@ -87,7 +88,7 @@ class ExprTemplateFunction(Function):
        }
 
    @classmethod
-    def _from_dict(cls, d: Dict) -> Function:
+    def _from_dict(cls, d: dict) -> Function:
        if 'expr' not in d:
            return super()._from_dict(d)
        assert 'signature' in d and 'name' in d
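
Note: the defaults handled above are what get converted to `Literal`s when a template is built from a decorated function; a minimal sketch, assuming `expr_udf` is exported at the package level as `pxt.expr_udf` and that arithmetic over the parameter variables yields a valid Pixeltable expression:

```python
# Sketch (hypothetical names): the default for `factor` becomes an
# exprs.Literal entry in self.defaults when the template is constructed.
import pixeltable as pxt

@pxt.expr_udf
def scaled(x: float, factor: float = 2.0) -> float:
    return x * factor
```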
pixeltable/func/function.py CHANGED
@@ -3,12 +3,13 @@ from __future__ import annotations
 import abc
 import importlib
 import inspect
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Optional
 
 import sqlalchemy as sql
 
-import pixeltable
+import pixeltable as pxt
 import pixeltable.type_system as ts
+
 from .globals import resolve_symbol
 from .signature import Signature
 
@@ -66,13 +67,13 @@ class Function(abc.ABC):
    def help_str(self) -> str:
        return self.display_name + str(self.signature)
 
-    def __call__(self, *args: Any, **kwargs: Any) -> 'pixeltable.exprs.Expr':
+    def __call__(self, *args: Any, **kwargs: Any) -> 'pxt.exprs.FunctionCall':
        from pixeltable import exprs
        bound_args = self.signature.py_signature.bind(*args, **kwargs)
        self.validate_call(bound_args.arguments)
        return exprs.FunctionCall(self, bound_args.arguments)
 
-    def validate_call(self, bound_args: Dict[str, Any]) -> None:
+    def validate_call(self, bound_args: dict[str, Any]) -> None:
        """Override this to do custom validation of the arguments"""
        pass
 
@@ -121,7 +122,7 @@ class Function(abc.ABC):
        """Print source code"""
        print('source not available')
 
-    def as_dict(self) -> Dict:
+    def as_dict(self) -> dict:
        """
        Return a serialized reference to the instance that can be passed to json.dumps() and converted back
        to an instance with from_dict().
@@ -130,13 +131,13 @@ class Function(abc.ABC):
        classpath = f'{self.__class__.__module__}.{self.__class__.__qualname__}'
        return {'_classpath': classpath, **self._as_dict()}
 
-    def _as_dict(self) -> Dict:
+    def _as_dict(self) -> dict:
        """Default serialization: store the path to self (which includes the module path)"""
        assert self.self_path is not None
        return {'path': self.self_path}
 
    @classmethod
-    def from_dict(cls, d: Dict) -> Function:
+    def from_dict(cls, d: dict) -> Function:
        """
        Turn dict that was produced by calling as_dict() into an instance of the correct Function subclass.
        """
@@ -147,14 +148,14 @@ class Function(abc.ABC):
        return func_class._from_dict(d)
 
    @classmethod
-    def _from_dict(cls, d: Dict) -> Function:
+    def _from_dict(cls, d: dict) -> Function:
        """Default deserialization: load the symbol indicated by the stored symbol_path"""
        assert 'path' in d and d['path'] is not None
        instance = resolve_symbol(d['path'])
        assert isinstance(instance, Function)
        return instance
 
-    def to_store(self) -> Tuple[Dict, bytes]:
+    def to_store(self) -> tuple[dict, bytes]:
        """
        Serialize the function to a format that can be stored in the Pixeltable store
        Returns:
@@ -165,7 +166,7 @@ class Function(abc.ABC):
        raise NotImplementedError()
 
    @classmethod
-    def from_store(cls, name: Optional[str], md: Dict, binary_obj: bytes) -> Function:
+    def from_store(cls, name: Optional[str], md: dict, binary_obj: bytes) -> Function:
        """
        Create a Function instance from the serialized representation returned by to_store()
        """
pixeltable/func/udf.py CHANGED
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
-from typing import List, Callable, Optional, overload, Any
+from typing import Any, Callable, Optional, overload
 
 import pixeltable.exceptions as excs
 import pixeltable.type_system as ts
+
 from .callable_function import CallableFunction
 from .expr_template_function import ExprTemplateFunction
 from .function import Function
@@ -21,8 +22,6 @@ def udf(decorated_fn: Callable) -> Function: ...
 @overload
 def udf(
    *,
-    return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[List[ts.ColumnType]] = None,
    batch_size: Optional[int] = None,
    substitute_fn: Optional[Callable] = None,
    is_method: bool = False,
@@ -49,8 +48,6 @@ def udf(*args, **kwargs):
 
    # Decorator schema invoked with parentheses: @pxt.udf(**kwargs)
    # Create a decorator for the specified schema.
-    return_type = kwargs.pop('return_type', None)
-    param_types = kwargs.pop('param_types', None)
    batch_size = kwargs.pop('batch_size', None)
    substitute_fn = kwargs.pop('substitute_fn', None)
    is_method = kwargs.pop('is_method', None)
@@ -64,9 +61,7 @@ def udf(*args, **kwargs):
    def decorator(decorated_fn: Callable):
        return make_function(
            decorated_fn,
-            return_type,
-            param_types,
-            batch_size,
+            batch_size=batch_size,
            substitute_fn=substitute_fn,
            is_method=is_method,
            is_property=is_property,
@@ -79,7 +74,7 @@ def udf(*args, **kwargs):
 def make_function(
    decorated_fn: Callable,
    return_type: Optional[ts.ColumnType] = None,
-    param_types: Optional[List[ts.ColumnType]] = None,
+    param_types: Optional[list[ts.ColumnType]] = None,
    batch_size: Optional[int] = None,
    substitute_fn: Optional[Callable] = None,
    is_method: bool = False,
@@ -158,10 +153,10 @@ def make_function(
 def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
 
 @overload
-def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
+def expr_udf(*, param_types: Optional[list[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
 
 def expr_udf(*args: Any, **kwargs: Any) -> Any:
-    def make_expr_template(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
+    def make_expr_template(py_fn: Callable, param_types: Optional[list[ts.ColumnType]]) -> ExprTemplateFunction:
        if py_fn.__module__ != '__main__' and py_fn.__name__.isidentifier():
            # this is a named function in a module
            function_path = f'{py_fn.__module__}.{py_fn.__qualname__}'
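
Note: with `return_type` and `param_types` removed from the `@udf` overload, a UDF's signature now comes from its ordinary Python annotations (`make_function` still accepts explicit types for internal callers); a minimal sketch:

```python
# Sketch: parameter and return types are inferred from the annotations,
# so nothing needs to be passed to the decorator.
import pixeltable as pxt

@pxt.udf
def repeat(s: str, n: int) -> str:
    return s * n
```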
pixeltable/functions/__init__.py CHANGED
@@ -1,7 +1,7 @@
 from pixeltable.utils.code import local_public_names
 
-from . import (anthropic, audio, fireworks, huggingface, image, json, mistralai, openai, string, timestamp, together,
-               video, vision, whisper)
+from . import (anthropic, audio, fireworks, huggingface, image, json, llama_cpp, mistralai, ollama, openai, string,
+               timestamp, together, video, vision, whisper)
 from .globals import *
 
 __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
pixeltable/functions/globals.py CHANGED
@@ -36,9 +36,7 @@ class sum(func.Aggregator):
        return self.sum
 
 
-# disable type checking: mypy doesn't seem to understand that 'sum' is an instance of Function
-# TODO: find a way to have this type-checked
-@sum.to_sql  # type: ignore
+@sum.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
    # This can produce a Decimal. We are deliberately avoiding an explicit cast to a Bigint here, because that can
    # cause overflows in Postgres. We're instead doing the conversion to the target type in SqlNode.__iter__().
@@ -58,7 +56,7 @@ class count(func.Aggregator):
        return self.count
 
 
-@count.to_sql  # type: ignore
+@count.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
    return sql.sql.func.count(val)
 
@@ -82,7 +80,7 @@ class min(func.Aggregator):
        return self.val
 
 
-@min.to_sql  # type: ignore
+@min.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
    return sql.sql.func.min(val)
 
@@ -106,7 +104,7 @@ class max(func.Aggregator):
        return self.val
 
 
-@max.to_sql  # type: ignore
+@max.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
    return sql.sql.func.max(val)
 
@@ -134,7 +132,7 @@ class mean(func.Aggregator):
        return self.sum / self.count
 
 
-@mean.to_sql  # type: ignore
+@mean.to_sql
 def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
    return sql.sql.func.avg(val)
 
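
Note: because each of these aggregates registers a `to_sql` translation, calls over stored columns can be evaluated directly in Postgres; a usage sketch with hypothetical table and column names:

```python
# Usage sketch (hypothetical table/column names): these aggregate calls can be
# pushed down to SQL via the to_sql translations registered above.
import pixeltable as pxt

t = pxt.get_table('sales')
t.select(pxt.functions.sum(t.amount), pxt.functions.count(t.amount)).collect()
```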
pixeltable/functions/huggingface.py CHANGED
@@ -7,21 +7,22 @@ first `pip install transformers` (or in some cases, `sentence-transformers`, as
 UDFs).
 """
 
-from typing import Callable, TypeVar, Optional, Any
+from typing import Any, Callable, Optional, TypeVar
 
 import PIL.Image
 
 import pixeltable as pxt
 import pixeltable.env as env
+import pixeltable.exceptions as excs
 from pixeltable.func import Batch
-from pixeltable.functions.util import resolve_torch_device, normalize_image_mode
+from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
 from pixeltable.utils.code import local_public_names
 
 
 @pxt.udf(batch_size=32)
 def sentence_transformer(
    sentence: Batch[str], *, model_id: str, normalize_embeddings: bool = False
-) -> Batch[pxt.Array[(None,), float]]:
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
    """
    Computes sentence embeddings. `model_id` should be a pretrained Sentence Transformers model, as described
    in the [Sentence Transformers Pretrained Models](https://sbert.net/docs/sentence_transformer/pretrained_models.html)
@@ -29,7 +30,7 @@ def sentence_transformer(
 
    __Requirements:__
 
-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`
 
    Args:
        sentence: The sentence to embed.
@@ -48,11 +49,15 @@ def sentence_transformer(
        >>> tbl['result'] = sentence_transformer(tbl.sentence, model_id='all-mpnet-base-v2')
    """
    env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
    from sentence_transformers import SentenceTransformer  # type: ignore
 
-    model = _lookup_model(model_id, SentenceTransformer)
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
 
-    array = model.encode(sentence, normalize_embeddings=normalize_embeddings)
+    # specifying the device, uses it for computation
+    array = model.encode(sentence, device=device, normalize_embeddings=normalize_embeddings)
    return [array[i] for i in range(array.shape[0])]
@@ -70,11 +75,15 @@ def _(model_id: str) -> pxt.ArrayType:
 @pxt.udf
 def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
    env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
    from sentence_transformers import SentenceTransformer
 
-    model = _lookup_model(model_id, SentenceTransformer)
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    model = _lookup_model(model_id, SentenceTransformer, device=device, pass_device_to_create=True)
 
-    array = model.encode(sentences, normalize_embeddings=normalize_embeddings)
+    # specifying the device, uses it for computation
+    array = model.encode(sentences, device=device, normalize_embeddings=normalize_embeddings)
    return [array[i].tolist() for i in range(array.shape[0])]
 
 
@@ -88,7 +97,7 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
 
    __Requirements:__
 
-    - `pip install sentence-transformers`
+    - `pip install torch sentence-transformers`
 
    Parameters:
        sentences1: The first sentence to be paired.
@@ -107,9 +116,13 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
    )
    """
    env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
    from sentence_transformers import CrossEncoder
 
-    model = _lookup_model(model_id, CrossEncoder)
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)
 
    array = model.predict([[s1, s2] for s1, s2 in zip(sentences1, sentences2)], convert_to_numpy=True)
    return array.tolist()
@@ -118,23 +131,27 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
 @pxt.udf
 def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
    env.Env.get().require_package('sentence_transformers')
+    device = resolve_torch_device('auto')
+    import torch
    from sentence_transformers import CrossEncoder
 
-    model = _lookup_model(model_id, CrossEncoder)
+    # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
+    # and uses the device for predict computation
+    model = _lookup_model(model_id, CrossEncoder, device=device, pass_device_to_create=True)
 
    array = model.predict([[sentence1, s2] for s2 in sentences2], convert_to_numpy=True)
    return array.tolist()
 
 
 @pxt.udf(batch_size=32)
-def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), float]]:
+def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
    """
    Computes a CLIP embedding for the specified text. `model_id` should be a reference to a pretrained
    [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
    __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
    Args:
        text: The string to embed.
@@ -165,14 +182,14 @@ def clip_text(text: Batch[str], *, model_id: str) -> Batch[pxt.Array[(None,), fl
 
 
 @pxt.udf(batch_size=32)
-def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), float]]:
+def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
    """
    Computes a CLIP embedding for the specified image. `model_id` should be a reference to a pretrained
    [CLIP Model](https://huggingface.co/docs/transformers/model_doc/clip).
 
    __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
    Args:
        image: The image to embed.
@@ -215,14 +232,20 @@ def _(model_id: str) -> pxt.ArrayType:
 
 
 @pxt.udf(batch_size=4)
-def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
+def detr_for_object_detection(
+    image: Batch[PIL.Image.Image],
+    *,
+    model_id: str,
+    threshold: float = 0.5,
+    revision: str = 'no_timm',
+) -> Batch[dict]:
    """
    Computes DETR object detections for the specified image. `model_id` should be a reference to a pretrained
    [DETR Model](https://huggingface.co/docs/transformers/model_doc/detr).
 
    __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
    Args:
        image: The image to embed.
@@ -254,12 +277,12 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
-    from transformers import DetrImageProcessor, DetrForObjectDetection
+    from transformers import DetrForObjectDetection, DetrImageProcessor
 
    model = _lookup_model(
-        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision='no_timm'), device=device
+        model_id, lambda x: DetrForObjectDetection.from_pretrained(x, revision=revision), device=device
    )
-    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision='no_timm'))
+    processor = _lookup_processor(model_id, lambda x: DetrImageProcessor.from_pretrained(x, revision=revision))
    normalized_images = [normalize_image_mode(img) for img in image]
 
    with torch.no_grad():
@@ -286,7 +309,7 @@ def vit_for_image_classification(
    *,
    model_id: str,
    top_k: int = 5
-) -> Batch[list[dict[str, Any]]]:
+) -> Batch[dict[str, Any]]:
    """
    Computes image classifications for the specified image using a Vision Transformer (ViT) model.
    `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
@@ -299,7 +322,7 @@ def vit_for_image_classification(
 
    __Requirements:__
 
-    - `pip install transformers`
+    - `pip install torch transformers`
 
    Args:
        image: The image to classify.
@@ -307,30 +330,30 @@ def vit_for_image_classification(
        top_k: The number of classes to return.
 
    Returns:
-        A list of the `top_k` highest-scoring classes for each image. Each element in the list is a dictionary
-        in the following format:
+        A dictionary containing the output of the image classification model, in the following format:
 
-        ```python
-        {
-            'p': 0.230,  # class probability
-            'class': 935,  # class ID
-            'label': 'mashed potato',  # class label
-        }
-        ```
+        ```python
+        {
+            'scores': [0.325, 0.198, 0.105],  # list of probabilities of the top-k most likely classes
+            'labels': [340, 353, 386],  # list of class IDs for the top-k most likely classes
+            'label_text': ['zebra', 'gazelle', 'African elephant, Loxodonta africana'],
+                # corresponding text names of the top-k most likely classes
+        ```
 
    Examples:
        Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
-        Pixeltable column `image` of the table `tbl`:
+        Pixeltable column `image` of the table `tbl`, returning the 10 most likely classes for each image:
 
        >>> tbl['image_class'] = vit_for_image_classification(
        ...     tbl.image,
-        ...     model_id='google/vit-base-patch16-224'
+        ...     model_id='google/vit-base-patch16-224',
+        ...     top_k=10
        ... )
    """
    env.Env.get().require_package('transformers')
    device = resolve_torch_device('auto')
    import torch
-    from transformers import ViTImageProcessor, ViTForImageClassification
+    from transformers import ViTForImageClassification, ViTImageProcessor
 
    model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
    processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
@@ -344,19 +367,98 @@ def vit_for_image_classification(
    probs = torch.softmax(logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
 
+    # There is no official post_process method for ViT models; for consistency, we structure the output
+    # the same way as the output of the DETR model given by `post_process_object_detection`.
    return [
-        [
-            {
-                'p': top_k_probs[n, k].item(),
-                'class': top_k_indices[n, k].item(),
-                'label': model.config.id2label[top_k_indices[n, k].item()],
-            }
-            for k in range(top_k_probs.shape[1])
-        ]
+        {
+            'scores': [top_k_probs[n, k].item() for k in range(top_k_probs.shape[1])],
+            'labels': [top_k_indices[n, k].item() for k in range(top_k_probs.shape[1])],
+            'label_text': [model.config.id2label[top_k_indices[n, k].item()] for k in range(top_k_probs.shape[1])],
+        }
        for n in range(top_k_probs.shape[0])
    ]
 
 
+@pxt.udf
+def speech2text_for_conditional_generation(
+    audio: pxt.Audio,
+    *,
+    model_id: str,
+    language: Optional[str] = None,
+) -> str:
+    """
+    Transcribes or translates speech to text using a Speech2Text model. `model_id` should be a reference to a
+    pretrained [Speech2Text](https://huggingface.co/docs/transformers/en/model_doc/speech_to_text) model.
+
+    __Requirements:__
+
+    - `pip install torch torchaudio sentencepiece transformers`
+
+    Args:
+        audio: The audio clip to transcribe or translate.
+        model_id: The pretrained model to use for the transcription or translation.
+        language: If using a multilingual translation model, the language code to translate to. If not provided,
+            the model's default language will be used. If the model is not a translation model, is not a
+            multilingual model, or does not support the specified language, an error will be raised.
+
+    Returns:
+        The transcribed or translated text.
+
+    Examples:
+        Add a computed column that applies the model `facebook/s2t-small-librispeech-asr` to an existing
+        Pixeltable column `audio` of the table `tbl`:
+
+        >>> tbl['transcription'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-small-librispeech-asr'
+        ... )
+
+        Add a computed column that applies the model `facebook/s2t-medium-mustc-multilingual-st` to an existing
+        Pixeltable column `audio` of the table `tbl`, translating the audio to French:
+
+        >>> tbl['translation'] = speech2text_for_conditional_generation(
+        ...     tbl.audio,
+        ...     model_id='facebook/s2t-medium-mustc-multilingual-st',
+        ...     language='fr'
+        ... )
+    """
+    env.Env.get().require_package('transformers')
+    env.Env.get().require_package('torchaudio')
+    env.Env.get().require_package('sentencepiece')
+    device = resolve_torch_device('auto', allow_mps=False)  # Doesn't seem to work on 'mps'; use 'cpu' instead
+    import librosa
+    import torch
+    from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
+
+    # facebook/s2t-small-librispeech-asr
+    # facebook/s2t-small-mustc-en-fr-st
+    model = _lookup_model(model_id, Speech2TextForConditionalGeneration.from_pretrained, device=device)
+    processor = _lookup_processor(model_id, Speech2TextProcessor.from_pretrained)
+    assert isinstance(processor, Speech2TextProcessor)
+
+    if language is not None and language not in processor.tokenizer.lang_code_to_id:
+        raise excs.Error(
+            f"Language code '{language}' is not supported by the model '{model_id}'. "
+            f"Supported languages are: {list(processor.tokenizer.lang_code_to_id.keys())}")
+
+    forced_bos_token_id: Optional[int] = None if language is None else processor.tokenizer.lang_code_to_id[language]
+
+    # Get the model's sampling rate. Default to 16 kHz (the standard) if not in config
+    model_sampling_rate = getattr(model.config, 'sampling_rate', 16_000)
+    waveform, sampling_rate = librosa.load(audio, sr=model_sampling_rate, mono=True)
+
+    with torch.no_grad():
+        inputs = processor(
+            waveform,
+            sampling_rate=sampling_rate,
+            return_tensors='pt'
+        )
+        generated_ids = model.generate(**inputs.to(device), forced_bos_token_id=forced_bos_token_id).to('cpu')
+
+    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return transcription
+
+
 @pxt.udf
 def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
    """
@@ -386,14 +488,22 @@ def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str,
 T = TypeVar('T')
 
 
-def _lookup_model(model_id: str, create: Callable[[str], T], device: Optional[str] = None) -> T:
+def _lookup_model(
+    model_id: str,
+    create: Callable[..., T],
+    device: Optional[str] = None,
+    pass_device_to_create: bool = False
+) -> T:
    from torch import nn
 
    key = (model_id, create, device)  # For safety, include the `create` callable in the cache key
    if key not in _model_cache:
-        model = create(model_id)
+        if pass_device_to_create:
+            model = create(model_id, device=device)
+        else:
+            model = create(model_id)
        if isinstance(model, nn.Module):
-            if device is not None:
+            if not pass_device_to_create and device is not None:
                model.to(device)
            model.eval()
        _model_cache[key] = model
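
Note: a sketch of the two construction paths in `_lookup_model`, as called from within pixeltable/functions/huggingface.py; the model IDs are taken from the docstrings above, and `device='cpu'` is just for illustration:

```python
# SentenceTransformer's constructor accepts device= directly, so
# pass_device_to_create=True; from_pretrained-style factories are created
# first and then moved with .to(device) inside _lookup_model.
from sentence_transformers import SentenceTransformer
from transformers import ViTForImageClassification

st = _lookup_model('all-mpnet-base-v2', SentenceTransformer, device='cpu', pass_device_to_create=True)
vit = _lookup_model('google/vit-base-patch16-224', ViTForImageClassification.from_pretrained, device='cpu')
```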