pixeltable 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (36)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/table.py +0 -1
  4. pixeltable/catalog/table_version.py +1 -1
  5. pixeltable/catalog/view.py +1 -1
  6. pixeltable/dataframe.py +1 -1
  7. pixeltable/env.py +34 -5
  8. pixeltable/exceptions.py +5 -1
  9. pixeltable/exec/component_iteration_node.py +1 -1
  10. pixeltable/exprs/__init__.py +1 -2
  11. pixeltable/exprs/expr.py +5 -6
  12. pixeltable/exprs/function_call.py +8 -10
  13. pixeltable/exprs/inline_expr.py +200 -0
  14. pixeltable/ext/functions/whisperx.py +2 -0
  15. pixeltable/ext/functions/yolox.py +5 -3
  16. pixeltable/functions/huggingface.py +89 -12
  17. pixeltable/functions/image.py +3 -3
  18. pixeltable/functions/together.py +15 -8
  19. pixeltable/functions/vision.py +43 -21
  20. pixeltable/functions/whisper.py +3 -0
  21. pixeltable/globals.py +5 -1
  22. pixeltable/metadata/__init__.py +1 -1
  23. pixeltable/metadata/converters/convert_18.py +1 -1
  24. pixeltable/metadata/converters/convert_20.py +56 -0
  25. pixeltable/metadata/converters/util.py +29 -4
  26. pixeltable/metadata/notes.py +1 -0
  27. pixeltable/tool/create_test_db_dump.py +14 -3
  28. pixeltable/type_system.py +3 -1
  29. pixeltable-0.2.19.dist-info/LICENSE +201 -0
  30. {pixeltable-0.2.18.dist-info → pixeltable-0.2.19.dist-info}/METADATA +6 -4
  31. {pixeltable-0.2.18.dist-info → pixeltable-0.2.19.dist-info}/RECORD +33 -33
  32. pixeltable/exprs/inline_array.py +0 -117
  33. pixeltable/exprs/inline_dict.py +0 -104
  34. pixeltable-0.2.18.dist-info/LICENSE +0 -18
  35. {pixeltable-0.2.18.dist-info → pixeltable-0.2.19.dist-info}/WHEEL +0 -0
  36. {pixeltable-0.2.18.dist-info → pixeltable-0.2.19.dist-info}/entry_points.txt +0 -0
pixeltable/__init__.py CHANGED
@@ -4,7 +4,7 @@ from .exceptions import Error
4
4
  from .exprs import RELATIVE_PATH_ROOT
5
5
  from .func import Function, udf, Aggregator, uda, expr_udf
6
6
  from .globals import init, create_table, create_view, get_table, move, drop_table, list_tables, create_dir, drop_dir, \
7
- list_dirs, list_functions, configure_logging
7
+ list_dirs, list_functions, configure_logging, array
8
8
  from .type_system import (
9
9
  ColumnType,
10
10
  StringType,
pixeltable/__version__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # These version placeholders will be replaced during build.
2
- __version__ = "0.2.18"
3
- __version_tuple__ = (0, 2, 18)
2
+ __version__ = "0.2.19"
3
+ __version_tuple__ = (0, 2, 19)
pixeltable/catalog/table.py CHANGED
@@ -2,7 +2,6 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import builtins
5
- import itertools
6
5
  import json
7
6
  import logging
8
7
  from pathlib import Path
pixeltable/catalog/table_version.py CHANGED
@@ -147,7 +147,7 @@ class TableVersion:
147
147
  module = importlib.import_module(module_name)
148
148
  self.iterator_cls = getattr(module, class_name)
149
149
  self.iterator_args = exprs.InlineDict.from_dict(tbl_md.view_md.iterator_args)
150
- output_schema, _ = self.iterator_cls.output_schema(**self.iterator_args.to_dict())
150
+ output_schema, _ = self.iterator_cls.output_schema(**self.iterator_args.to_kwargs())
151
151
  self.num_iterator_cols = len(output_schema)
152
152
  assert tbl_md.view_md.iterator_args is not None
153
153
 
pixeltable/catalog/view.py CHANGED
@@ -94,7 +94,7 @@ class View(Table):
94
94
  ]
95
95
  sig = func.Signature(InvalidType(), params)
96
96
  from pixeltable.exprs import FunctionCall
97
- FunctionCall.normalize_args(sig, bound_args)
97
+ FunctionCall.normalize_args(iterator_cls.__name__, sig, bound_args)
98
98
  except TypeError as e:
99
99
  raise Error(f'Cannot instantiate iterator with given arguments: {e}')
100
100
 
pixeltable/dataframe.py CHANGED
@@ -501,7 +501,7 @@ class DataFrame:
501
501
  elif isinstance(raw_expr, dict):
502
502
  select_list.append((exprs.InlineDict(raw_expr), name))
503
503
  elif isinstance(raw_expr, list):
504
- select_list.append((exprs.InlineArray(raw_expr), name))
504
+ select_list.append((exprs.InlineList(raw_expr), name))
505
505
  else:
506
506
  select_list.append((exprs.Literal(raw_expr), name))
507
507
  expr = select_list[-1][0]
pixeltable/env.py CHANGED
@@ -8,6 +8,7 @@ import importlib.util
8
8
  import inspect
9
9
  import logging
10
10
  import os
11
+ import subprocess
11
12
  import sys
12
13
  import threading
13
14
  import uuid
@@ -486,6 +487,8 @@ class Env:
486
487
  """Check for and start runtime services"""
487
488
  self._start_web_server()
488
489
  self.__register_packages()
490
+ if self.is_installed_package('spacy'):
491
+ self.__init_spacy()
489
492
 
490
493
  def __register_packages(self) -> None:
491
494
  """Declare optional packages that are utilized by some parts of the code."""
@@ -500,7 +503,7 @@ class Env:
500
503
  self.__register_package('openpyxl')
501
504
  self.__register_package('pyarrow')
502
505
  self.__register_package('sentence_transformers', library_name='sentence-transformers')
503
- self.__register_package('spacy') # TODO: deal with en-core-web-sm
506
+ self.__register_package('spacy')
504
507
  self.__register_package('tiktoken')
505
508
  self.__register_package('together')
506
509
  self.__register_package('toml')
@@ -511,10 +514,6 @@ class Env:
511
514
  self.__register_package('whisperx')
512
515
  self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
513
516
 
514
- if self.is_installed_package('spacy'):
515
- import spacy
516
- self._spacy_nlp = spacy.load('en_core_web_sm')
517
-
518
517
  def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
519
518
  self.__optional_packages[package_name] = PackageInfo(
520
519
  is_installed=importlib.util.find_spec(package_name) is not None,
@@ -556,6 +555,35 @@ class Env:
556
555
  f'To fix this, run: `pip install -U {package_info.library_name}`'
557
556
  )
558
557
 
558
+ def __init_spacy(self) -> None:
559
+ """
560
+ spaCy relies on a pip-installed model to operate. In order to avoid requiring the model as a separate
561
+ dependency, we install it programmatically here. This should cause no problems, since the model packages
562
+ have no sub-dependencies (in fact, this is how spaCy normally manages its model resources).
563
+ """
564
+ import spacy
565
+ from spacy.cli.download import get_model_filename
566
+ spacy_model = 'en_core_web_sm'
567
+ spacy_model_version = '3.7.1'
568
+ filename = get_model_filename(spacy_model, spacy_model_version, sdist=False)
569
+ url = f'{spacy.about.__download_url__}/{filename}'
570
+ # Try to `pip install` the model. We set check=False; if the pip command fails, it's not necessarily
571
+ # a problem, because the model may have been installed on a previous attempt.
572
+ self._logger.info(f'Ensuring spaCy model is installed: {filename}')
573
+ ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
574
+ if ret.returncode != 0:
575
+ self._logger.warn(f'pip install failed for spaCy model: {filename}')
576
+ try:
577
+ self._logger.info(f'Loading spaCy model: {spacy_model}')
578
+ self._spacy_nlp = spacy.load(spacy_model)
579
+ except Exception as exc:
580
+ self._logger.warn(f'Failed to load spaCy model: {spacy_model}', exc_info=exc)
581
+ warnings.warn(
582
+ f"Failed to load spaCy model '{spacy_model}'. spaCy features will not be available.",
583
+ excs.PixeltableWarning
584
+ )
585
+ self.__optional_packages['spacy'].is_installed = False
586
+
559
587
  def num_tmp_files(self) -> int:
560
588
  return len(glob.glob(f'{self._tmp_dir}/*'))
561
589
 
@@ -594,6 +622,7 @@ class Env:
594
622
 
595
623
  @property
596
624
  def spacy_nlp(self) -> spacy.Language:
625
+ Env.get().require_package('spacy')
597
626
  assert self._spacy_nlp is not None
598
627
  return self._spacy_nlp
599
628
 
pixeltable/exceptions.py CHANGED
@@ -14,4 +14,8 @@ class ExprEvalError(Exception):
14
14
  exc: Exception
15
15
  exc_tb: TracebackType
16
16
  input_vals: List[Any]
17
- row_num: int
17
+ row_num: int
18
+
19
+
20
+ class PixeltableWarning(Warning):
21
+ pass
pixeltable/exec/component_iteration_node.py CHANGED
@@ -24,7 +24,7 @@ class ComponentIterationNode(ExecNode):
24
24
  assert isinstance(self.iterator_args, exprs.InlineDict)
25
25
  self.iterator_args_ctx = self.row_builder.create_eval_ctx([self.iterator_args])
26
26
  self.iterator_output_schema, self.unstored_column_names = \
27
- self.view.iterator_cls.output_schema(**self.iterator_args.to_dict())
27
+ self.view.iterator_cls.output_schema(**self.iterator_args.to_kwargs())
28
28
  self.iterator_output_fields = list(self.iterator_output_schema.keys())
29
29
  self.iterator_output_cols = \
30
30
  {field_name: self.view.cols_by_name[field_name] for field_name in self.iterator_output_fields}
pixeltable/exprs/__init__.py CHANGED
@@ -9,8 +9,7 @@ from .expr import Expr
9
9
  from .expr_set import ExprSet
10
10
  from .function_call import FunctionCall
11
11
  from .in_predicate import InPredicate
12
- from .inline_array import InlineArray
13
- from .inline_dict import InlineDict
12
+ from .inline_expr import InlineArray, InlineDict, InlineList
14
13
  from .is_null import IsNull
15
14
  from .json_mapper import JsonMapper
16
15
  from .json_path import RELATIVE_PATH_ROOT, JsonPath
pixeltable/exprs/expr.py CHANGED
@@ -356,15 +356,14 @@ class Expr(abc.ABC):
356
356
  """
357
357
  if isinstance(o, Expr):
358
358
  return o
359
- # Try to create a literal. We need to check for InlineArray/InlineDict
360
- # first, to prevent arrays from inappropriately being interpreted as JsonType
359
+ # Try to create a literal. We need to check for InlineList/InlineDict
360
+ # first, to prevent them from inappropriately being interpreted as JsonType
361
361
  # literals.
362
- # TODO: general cleanup of InlineArray/InlineDict
363
362
  if isinstance(o, list):
364
- from .inline_array import InlineArray
365
- return InlineArray(tuple(o))
363
+ from .inline_expr import InlineList
364
+ return InlineList(o)
366
365
  if isinstance(o, dict):
367
- from .inline_dict import InlineDict
366
+ from .inline_expr import InlineDict
368
367
  return InlineDict(o)
369
368
  obj_type = ts.ColumnType.infer_literal_type(o)
370
369
  if obj_type is not None:
pixeltable/exprs/function_call.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import json
5
5
  import sys
6
- from typing import Optional, Any
6
+ from typing import Any, Optional
7
7
 
8
8
  import sqlalchemy as sql
9
9
 
@@ -11,10 +11,10 @@ import pixeltable.catalog as catalog
11
11
  import pixeltable.exceptions as excs
12
12
  import pixeltable.func as func
13
13
  import pixeltable.type_system as ts
14
+
14
15
  from .data_row import DataRow
15
16
  from .expr import Expr
16
- from .inline_array import InlineArray
17
- from .inline_dict import InlineDict
17
+ from .inline_expr import InlineDict, InlineList
18
18
  from .row_builder import RowBuilder
19
19
  from .rowid_ref import RowidRef
20
20
  from .sql_element_cache import SqlElementCache
@@ -53,7 +53,7 @@ class FunctionCall(Expr):
53
53
  super().__init__(fn.call_return_type(bound_args))
54
54
  self.fn = fn
55
55
  self.is_method_call = is_method_call
56
- self.normalize_args(signature, bound_args)
56
+ self.normalize_args(fn.name, signature, bound_args)
57
57
 
58
58
  self.agg_init_args = {}
59
59
  if self.is_agg_fn_call:
@@ -143,7 +143,7 @@ class FunctionCall(Expr):
143
143
  return super().default_column_name()
144
144
 
145
145
  @classmethod
146
- def normalize_args(cls, signature: func.Signature, bound_args: dict[str, Any]) -> None:
146
+ def normalize_args(cls, fn_name: str, signature: func.Signature, bound_args: dict[str, Any]) -> None:
147
147
  """Converts all args to Exprs and checks that they are compatible with signature.
148
148
 
149
149
  Updates bound_args in place, where necessary.
@@ -163,9 +163,7 @@ class FunctionCall(Expr):
163
163
 
164
164
  if isinstance(arg, list) or isinstance(arg, tuple):
165
165
  try:
166
- # If the column type is JsonType, force the literal to be JSON
167
- is_json = is_var_param or (param.col_type is not None and param.col_type.is_json_type())
168
- arg = InlineArray(arg, force_json=is_json)
166
+ arg = InlineList(arg)
169
167
  bound_args[param_name] = arg
170
168
  continue
171
169
  except excs.Error:
@@ -177,7 +175,7 @@ class FunctionCall(Expr):
177
175
  try:
178
176
  _ = json.dumps(arg)
179
177
  except TypeError:
180
- raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg}')
178
+ raise excs.Error(f'Argument for parameter {param_name!r} is not json-serializable: {arg} (of type {type(arg)})')
181
179
  if arg is not None:
182
180
  try:
183
181
  param_type = param.col_type
@@ -215,7 +213,7 @@ class FunctionCall(Expr):
215
213
  or (arg.col_type.is_json_type() and param.col_type.is_scalar_type())
216
214
  ):
217
215
  raise excs.Error(
218
- f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
216
+ f'Parameter {param_name} (in function {fn_name}): argument type {arg.col_type} does not match parameter type '
219
217
  f'{param.col_type}')
220
218
 
221
219
  def _equals(self, other: FunctionCall) -> bool:
pixeltable/exprs/inline_expr.py ADDED
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import Any, Iterable, Optional
5
+
6
+ import numpy as np
7
+ import sqlalchemy as sql
8
+
9
+ import pixeltable.exceptions as excs
10
+ import pixeltable.type_system as ts
11
+
12
+ from .data_row import DataRow
13
+ from .expr import Expr
14
+ from .literal import Literal
15
+ from .row_builder import RowBuilder
16
+ from .sql_element_cache import SqlElementCache
17
+
18
+
19
+ class InlineArray(Expr):
20
+ """
21
+ Array 'literal' which can use Exprs as values.
22
+ """
23
+
24
+ def __init__(self, elements: Iterable):
25
+ exprs = []
26
+ for el in elements:
27
+ if isinstance(el, Expr):
28
+ exprs.append(el)
29
+ elif isinstance(el, list) or isinstance(el, tuple):
30
+ exprs.append(InlineArray(el))
31
+ else:
32
+ exprs.append(Literal(el))
33
+
34
+ inferred_element_type: Optional[ts.ColumnType] = ts.InvalidType()
35
+ for i, expr in enumerate(exprs):
36
+ supertype = inferred_element_type.supertype(expr.col_type)
37
+ if supertype is None:
38
+ raise excs.Error(
39
+ f'Could not infer element type of array: element of type `{expr.col_type}` at index {i} '
40
+ f'is not compatible with type `{inferred_element_type}` of preceding elements'
41
+ )
42
+ inferred_element_type = supertype
43
+
44
+ if inferred_element_type.is_scalar_type():
45
+ col_type = ts.ArrayType((len(exprs),), inferred_element_type)
46
+ elif inferred_element_type.is_array_type():
47
+ assert isinstance(inferred_element_type, ts.ArrayType)
48
+ col_type = ts.ArrayType(
49
+ (len(exprs), *inferred_element_type.shape),
50
+ ts.ColumnType.make_type(inferred_element_type.dtype)
51
+ )
52
+ else:
53
+ raise excs.Error(f'Element type is not a valid dtype for an array: {inferred_element_type}')
54
+
55
+ super().__init__(col_type)
56
+ self.components.extend(exprs)
57
+ self.id = self._create_id()
58
+
59
+ def __str__(self) -> str:
60
+ elem_strs = [str(expr) for expr in self.components]
61
+ return f'[{", ".join(elem_strs)}]'
62
+
63
+ def _equals(self, _: InlineArray) -> bool:
64
+ return True # Always true if components match
65
+
66
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
67
+ return None
68
+
69
+ def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
70
+ data_row[self.slot_idx] = np.array([data_row[el.slot_idx] for el in self.components])
71
+
72
+ def _as_dict(self) -> dict:
73
+ return super()._as_dict()
74
+
75
+ @classmethod
76
+ def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
77
+ try:
78
+ return cls(components)
79
+ except excs.Error:
80
+ # For legacy compatibility reasons, we need to try constructing as an `InlineList`.
81
+ # This is because in schema versions <= 19, `InlineArray` was serialized incorrectly, and
82
+ # there is no way to determine the correct expression type until the subexpressions are
83
+ # loaded and their types are known.
84
+ return InlineList(components)
85
+
86
+
87
+ class InlineList(Expr):
88
+ """
89
+ List 'literal' which can use Exprs as values.
90
+ """
91
+
92
+ def __init__(self, elements: Iterable):
93
+ exprs = []
94
+ for el in elements:
95
+ if isinstance(el, Expr):
96
+ exprs.append(el)
97
+ elif isinstance(el, list) or isinstance(el, tuple):
98
+ exprs.append(InlineList(el))
99
+ elif isinstance(el, dict):
100
+ exprs.append(InlineDict(el))
101
+ else:
102
+ exprs.append(Literal(el))
103
+
104
+ super().__init__(ts.JsonType())
105
+ self.components.extend(exprs)
106
+ self.id = self._create_id()
107
+
108
+ def __str__(self) -> str:
109
+ elem_strs = [str(expr) for expr in self.components]
110
+ return f'[{", ".join(elem_strs)}]'
111
+
112
+ def _equals(self, _: InlineList) -> bool:
113
+ return True # Always true if components match
114
+
115
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
116
+ return None
117
+
118
+ def eval(self, data_row: DataRow, _: RowBuilder) -> None:
119
+ data_row[self.slot_idx] = [data_row[el.slot_idx] for el in self.components]
120
+
121
+ def _as_dict(self) -> dict:
122
+ return super()._as_dict()
123
+
124
+ @classmethod
125
+ def _from_dict(cls, _: dict, components: list[Expr]) -> Expr:
126
+ return cls(components)
127
+
128
+
129
+ class InlineDict(Expr):
130
+ """
131
+ Dictionary 'literal' which can use Exprs as values.
132
+ """
133
+
134
+ keys: list[str]
135
+
136
+ def __init__(self, d: dict[str, Any]):
137
+ self.keys = []
138
+ exprs: list[Expr] = []
139
+ for key, val in d.items():
140
+ if not isinstance(key, str):
141
+ raise excs.Error(f'Dictionary requires string keys; {key} has type {type(key)}')
142
+ self.keys.append(key)
143
+ if isinstance(val, Expr):
144
+ exprs.append(val)
145
+ elif isinstance(val, dict):
146
+ exprs.append(InlineDict(val))
147
+ elif isinstance(val, list) or isinstance(val, tuple):
148
+ exprs.append(InlineList(val))
149
+ else:
150
+ exprs.append(Literal(val))
151
+
152
+ super().__init__(ts.JsonType())
153
+ self.components.extend(exprs)
154
+ self.id = self._create_id()
155
+
156
+ def __str__(self) -> str:
157
+ item_strs = list(f"'{key}': {str(expr)}" for key, expr in zip(self.keys, self.components))
158
+ return '{' + ', '.join(item_strs) + '}'
159
+
160
+ def _equals(self, other: InlineDict) -> bool:
161
+ # The dict values are just the components, which have already been checked
162
+ return self.keys == other.keys
163
+
164
+ def _id_attrs(self) -> list[tuple[str, Any]]:
165
+ return super()._id_attrs() + [('keys', self.keys)]
166
+
167
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
168
+ return None
169
+
170
+ def eval(self, data_row: DataRow, _: RowBuilder) -> None:
171
+ assert len(self.keys) == len(self.components)
172
+ data_row[self.slot_idx] = {
173
+ key: data_row[expr.slot_idx]
174
+ for key, expr in zip(self.keys, self.components)
175
+ }
176
+
177
+ def to_kwargs(self) -> dict[str, Any]:
178
+ """Deconstructs this expression into a dictionary by recursively unwrapping all Literals,
179
+ InlineDicts, and InlineLists."""
180
+ return InlineDict._to_kwarg_element(self)
181
+
182
+ @classmethod
183
+ def _to_kwarg_element(cls, expr: Expr) -> Any:
184
+ if isinstance(expr, Literal):
185
+ return expr.val
186
+ if isinstance(expr, InlineDict):
187
+ return {key: cls._to_kwarg_element(val) for key, val in zip(expr.keys, expr.components)}
188
+ if isinstance(expr, InlineList):
189
+ return [cls._to_kwarg_element(el) for el in expr.components]
190
+ return expr
191
+
192
+ def _as_dict(self) -> dict[str, Any]:
193
+ return {'keys': self.keys, **super()._as_dict()}
194
+
195
+ @classmethod
196
+ def _from_dict(cls, d: dict, components: list[Expr]) -> Expr:
197
+ assert 'keys' in d
198
+ assert len(d['keys']) == len(components)
199
+ arg = dict(zip(d['keys'], components))
200
+ return InlineDict(arg)
pixeltable/ext/functions/whisperx.py CHANGED
@@ -19,6 +19,8 @@ def transcribe(
19
19
  equivalent to the WhisperX `transcribe` function, as described in the
20
20
  [WhisperX library documentation](https://github.com/m-bain/whisperX).
21
21
 
22
+ WhisperX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
23
+
22
24
  __Requirements:__
23
25
 
24
26
  - `pip install whisperx`
pixeltable/ext/functions/yolox.py CHANGED
@@ -26,8 +26,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
26
26
  Computes YOLOX object detections for the specified image. `model_id` should reference one of the models
27
27
  defined in the [YOLOX documentation](https://github.com/Megvii-BaseDetection/YOLOX).
28
28
 
29
- YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
30
- intended for use in production applications.
29
+ YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
31
30
 
32
31
  __Requirements__:
33
32
 
@@ -79,6 +78,8 @@ def yolo_to_coco(detections: dict) -> list:
79
78
  """
80
79
  Converts the output of a YOLOX object detection model to COCO format.
81
80
 
81
+ YOLOX is part of the `pixeltable.ext` package: long-term support in Pixeltable is not guaranteed.
82
+
82
83
  Args:
83
84
  detections: The output of a YOLOX object detection model, as returned by `yolox`.
84
85
 
@@ -89,7 +90,8 @@ def yolo_to_coco(detections: dict) -> list:
89
90
  Add a computed column that converts the output `tbl.detections` to COCO format, where `tbl.image`
90
91
  is the image for which detections were computed:
91
92
 
92
- >>> tbl['detections_coco'] = yolo_to_coco(tbl.detections)
93
+ >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
94
+ ... tbl['detections_coco'] = yolo_to_coco(tbl.detections)
93
95
  """
94
96
  bboxes, labels = detections['bboxes'], detections['labels']
95
97
  num_annotations = len(detections['bboxes'])
pixeltable/functions/huggingface.py CHANGED
@@ -185,7 +185,7 @@ def clip_image(image: Batch[PIL.Image.Image], *, model_id: str) -> Batch[np.ndar
185
185
 
186
186
  Examples:
187
187
  Add a computed column that applies the model `openai/clip-vit-base-patch32` to an existing
188
- Pixeltable column `tbl.image` of the table `tbl`:
188
+ Pixeltable column `image` of the table `tbl`:
189
189
 
190
190
  >>> tbl['result'] = clip_image(tbl.image, model_id='openai/clip-vit-base-patch32')
191
191
  """
@@ -228,24 +228,24 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
228
228
 
229
229
  Args:
230
230
  image: The image to embed.
231
- model_id: The pretrained model to use for the embedding.
231
+ model_id: The pretrained model to use for object detection.
232
232
 
233
233
  Returns:
234
234
  A dictionary containing the output of the object detection model, in the following format:
235
235
 
236
- ```python
237
- {
238
- 'scores': [0.99, 0.999], # list of confidence scores for each detected object
239
- 'labels': [25, 25], # list of COCO class labels for each detected object
240
- 'label_text': ['giraffe', 'giraffe'], # corresponding text names of class labels
241
- 'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
242
- # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
243
- }
244
- ```
236
+ ```python
237
+ {
238
+ 'scores': [0.99, 0.999], # list of confidence scores for each detected object
239
+ 'labels': [25, 25], # list of COCO class labels for each detected object
240
+ 'label_text': ['giraffe', 'giraffe'], # corresponding text names of class labels
241
+ 'boxes': [[51.942, 356.174, 181.481, 413.975], [383.225, 58.66, 605.64, 361.346]]
242
+ # list of bounding boxes for each detected object, as [x1, y1, x2, y2]
243
+ }
244
+ ```
245
245
 
246
246
  Examples:
247
247
  Add a computed column that applies the model `facebook/detr-resnet-50` to an existing
248
- Pixeltable column `tbl.image` of the table `tbl`:
248
+ Pixeltable column `image` of the table `tbl`:
249
249
 
250
250
  >>> tbl['detections'] = detr_for_object_detection(
251
251
  ... tbl.image,
@@ -282,6 +282,83 @@ def detr_for_object_detection(image: Batch[PIL.Image.Image], *, model_id: str, t
282
282
  ]
283
283
 
284
284
 
285
+ @pxt.udf(batch_size=4)
286
+ def vit_for_image_classification(
287
+ image: Batch[PIL.Image.Image],
288
+ *,
289
+ model_id: str,
290
+ top_k: int = 5
291
+ ) -> Batch[list[dict[str, Any]]]:
292
+ """
293
+ Computes image classifications for the specified image using a Vision Transformer (ViT) model.
294
+ `model_id` should be a reference to a pretrained [ViT Model](https://huggingface.co/docs/transformers/en/model_doc/vit).
295
+
296
+ __Note:__ Be sure the model is a ViT model that is trained for image classification (that is, a model designed for
297
+ use with the
298
+ [ViTForImageClassification](https://huggingface.co/docs/transformers/en/model_doc/vit#transformers.ViTForImageClassification)
299
+ class), such as `google/vit-base-patch16-224`. General feature-extraction models such as
300
+ `google/vit-base-patch16-224-in21k` will not produce the desired results.
301
+
302
+ __Requirements:__
303
+
304
+ - `pip install transformers`
305
+
306
+ Args:
307
+ image: The image to classify.
308
+ model_id: The pretrained model to use for the classification.
309
+ top_k: The number of classes to return.
310
+
311
+ Returns:
312
+ A list of the `top_k` highest-scoring classes for each image. Each element in the list is a dictionary
313
+ in the following format:
314
+
315
+ ```python
316
+ {
317
+ 'p': 0.230, # class probability
318
+ 'class': 935, # class ID
319
+ 'label': 'mashed potato', # class label
320
+ }
321
+ ```
322
+
323
+ Examples:
324
+ Add a computed column that applies the model `google/vit-base-patch16-224` to an existing
325
+ Pixeltable column `image` of the table `tbl`:
326
+
327
+ >>> tbl['image_class'] = vit_for_image_classification(
328
+ ... tbl.image,
329
+ ... model_id='google/vit-base-patch16-224'
330
+ ... )
331
+ """
332
+ env.Env.get().require_package('transformers')
333
+ device = resolve_torch_device('auto')
334
+ import torch
335
+ from transformers import ViTImageProcessor, ViTForImageClassification
336
+
337
+ model: ViTForImageClassification = _lookup_model(model_id, ViTForImageClassification.from_pretrained, device=device)
338
+ processor = _lookup_processor(model_id, ViTImageProcessor.from_pretrained)
339
+ normalized_images = [normalize_image_mode(img) for img in image]
340
+
341
+ with torch.no_grad():
342
+ inputs = processor(images=normalized_images, return_tensors='pt')
343
+ outputs = model(**inputs.to(device))
344
+ logits = outputs.logits
345
+
346
+ probs = torch.softmax(logits, dim=-1)
347
+ top_k_probs, top_k_indices = torch.topk(probs, top_k, dim=-1)
348
+
349
+ return [
350
+ [
351
+ {
352
+ 'p': top_k_probs[n, k].item(),
353
+ 'class': top_k_indices[n, k].item(),
354
+ 'label': model.config.id2label[top_k_indices[n, k].item()],
355
+ }
356
+ for k in range(top_k_probs.shape[1])
357
+ ]
358
+ for n in range(top_k_probs.shape[0])
359
+ ]
360
+
361
+
285
362
  @pxt.udf
286
363
  def detr_to_coco(image: PIL.Image.Image, detr_info: dict[str, Any]) -> dict[str, Any]:
287
364
  """
pixeltable/functions/image.py CHANGED
@@ -92,7 +92,7 @@ def _(self: Expr, mode: str) -> ts.ColumnType:
92
92
 
93
93
 
94
94
  # Image.crop()
95
- @func.udf(substitute_fn=PIL.Image.Image.crop, param_types=[ts.ImageType(), ts.ArrayType((4,), dtype=ts.IntType())], is_method=True)
95
+ @func.udf(substitute_fn=PIL.Image.Image.crop, is_method=True)
96
96
  def crop(self: PIL.Image.Image, box: tuple[int, int, int, int]) -> PIL.Image.Image:
97
97
  """
98
98
  Return a rectangular region from the image. The box is a 4-tuple defining the left, upper, right, and lower pixel
@@ -151,7 +151,7 @@ def _(self: Expr) -> ts.ColumnType:
151
151
 
152
152
 
153
153
  # Image.resize()
154
- @func.udf(param_types=[ts.ImageType(), ts.ArrayType((2,), dtype=ts.IntType())], is_method=True)
154
+ @func.udf(is_method=True)
155
155
  def resize(self: PIL.Image.Image, size: tuple[int, int]) -> PIL.Image.Image:
156
156
  """
157
157
  Return a resized copy of the image. The size parameter is a tuple containing the width and height of the new image.
@@ -366,7 +366,7 @@ def quantize(
366
366
 
367
367
 
368
368
  @func.udf(substitute_fn=PIL.Image.Image.reduce, is_method=True)
369
- def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int]] = None) -> PIL.Image.Image:
369
+ def reduce(self: PIL.Image.Image, factor: int, box: Optional[tuple[int, int, int, int]] = None) -> PIL.Image.Image:
370
370
  """
371
371
  Reduce the image by the given factor.
372
372