pixeltable 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (56)
  1. pixeltable/catalog/column.py +25 -48
  2. pixeltable/catalog/insertable_table.py +7 -4
  3. pixeltable/catalog/table.py +163 -57
  4. pixeltable/catalog/table_version.py +416 -140
  5. pixeltable/catalog/table_version_path.py +2 -2
  6. pixeltable/client.py +0 -4
  7. pixeltable/dataframe.py +65 -21
  8. pixeltable/env.py +16 -1
  9. pixeltable/exec/cache_prefetch_node.py +1 -1
  10. pixeltable/exec/in_memory_data_node.py +11 -7
  11. pixeltable/exprs/comparison.py +3 -3
  12. pixeltable/exprs/data_row.py +5 -1
  13. pixeltable/exprs/literal.py +16 -4
  14. pixeltable/exprs/row_builder.py +8 -40
  15. pixeltable/ext/__init__.py +5 -0
  16. pixeltable/ext/functions/yolox.py +92 -0
  17. pixeltable/func/aggregate_function.py +15 -15
  18. pixeltable/func/expr_template_function.py +9 -1
  19. pixeltable/func/globals.py +24 -14
  20. pixeltable/func/signature.py +18 -12
  21. pixeltable/func/udf.py +7 -2
  22. pixeltable/functions/__init__.py +8 -8
  23. pixeltable/functions/eval.py +7 -8
  24. pixeltable/functions/huggingface.py +47 -19
  25. pixeltable/functions/openai.py +2 -2
  26. pixeltable/functions/util.py +11 -0
  27. pixeltable/index/__init__.py +2 -0
  28. pixeltable/index/base.py +49 -0
  29. pixeltable/index/embedding_index.py +95 -0
  30. pixeltable/metadata/schema.py +45 -22
  31. pixeltable/plan.py +15 -34
  32. pixeltable/store.py +38 -41
  33. pixeltable/tests/conftest.py +5 -11
  34. pixeltable/tests/ext/test_yolox.py +21 -0
  35. pixeltable/tests/functions/test_fireworks.py +1 -0
  36. pixeltable/tests/functions/test_huggingface.py +2 -2
  37. pixeltable/tests/functions/test_openai.py +15 -5
  38. pixeltable/tests/functions/test_together.py +1 -0
  39. pixeltable/tests/test_component_view.py +14 -5
  40. pixeltable/tests/test_dataframe.py +19 -18
  41. pixeltable/tests/test_exprs.py +99 -102
  42. pixeltable/tests/test_function.py +51 -43
  43. pixeltable/tests/test_index.py +138 -0
  44. pixeltable/tests/test_migration.py +2 -1
  45. pixeltable/tests/test_snapshot.py +24 -1
  46. pixeltable/tests/test_table.py +101 -25
  47. pixeltable/tests/test_types.py +30 -0
  48. pixeltable/tests/test_video.py +16 -16
  49. pixeltable/tests/test_view.py +5 -0
  50. pixeltable/tests/utils.py +43 -9
  51. pixeltable/tool/create_test_db_dump.py +16 -0
  52. pixeltable/type_system.py +37 -45
  53. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/METADATA +5 -4
  54. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/RECORD +56 -49
  55. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/LICENSE +0 -0
  56. {pixeltable-0.2.4.dist-info → pixeltable-0.2.5.dist-info}/WHEEL +0 -0
pixeltable/catalog/table_version_path.py CHANGED
@@ -101,8 +101,8 @@ class TableVersionPath:
101
101
  return DataFrame(self).__getitem__(index)
102
102
 
103
103
  def columns(self) -> List[Column]:
104
- """Return all columns visible in this tbl version path, including columns from bases"""
105
- result = self.tbl_version.cols.copy()
104
+ """Return all user columns visible in this tbl version path, including columns from bases"""
105
+ result = list(self.tbl_version.cols_by_name.values())
106
106
  if self.base is not None:
107
107
  base_cols = self.base.columns()
108
108
  # we only include base columns that don't conflict with one of our column names
pixeltable/client.py CHANGED
@@ -132,10 +132,6 @@ class Client:
132
132
  Create a table with an int and a string column:
133
133
 
134
134
  >>> table = cl.create_table('my_table', schema={'col1': IntType(), 'col2': StringType()})
135
-
136
- Create a table with a single indexed image column:
137
-
138
- >>> table = cl.create_table('my_table', schema={'col1': {'type': ImageType(), 'indexed': True}})
139
135
  """
140
136
  path = catalog.Path(path_str)
141
137
  self.catalog.paths.check_is_valid(path, expected=None)
pixeltable/dataframe.py CHANGED
@@ -11,6 +11,8 @@ import traceback
11
11
  from pathlib import Path
12
12
  from typing import List, Optional, Any, Dict, Generator, Tuple, Set
13
13
 
14
+ import PIL.Image
15
+ import cv2
14
16
  import pandas as pd
15
17
  import pandas.io.formats.style
16
18
  import sqlalchemy as sql
@@ -31,15 +33,6 @@ __all__ = [
31
33
 
32
34
  _logger = logging.getLogger('pixeltable')
33
35
 
34
- def _format_img(img: object) -> str:
35
- """
36
- Create <img> tag for Image object.
37
- """
38
- assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
39
- with io.BytesIO() as buffer:
40
- img.save(buffer, 'jpeg')
41
- img_base64 = base64.b64encode(buffer.getvalue()).decode()
42
- return f'<div style="width:200px;"><img src="data:image/jpeg;base64,{img_base64}" width="200" /></div>'
43
36
 
44
37
  def _create_source_tag(file_path: str) -> str:
45
38
  abs_path = Path(file_path)
@@ -50,21 +43,17 @@ def _create_source_tag(file_path: str) -> str:
50
43
  mime_attr = f'type="{mime}"' if mime is not None else ''
51
44
  return f'<source src="{src_url}" {mime_attr} />'
52
45
 
53
- def _format_video(file_path: str) -> str:
54
- return f'<video controls>{_create_source_tag(file_path)}</video>'
55
-
56
- def _format_audio(file_path: str) -> str:
57
- return f'<audio controls>{_create_source_tag(file_path)}</audio>'
58
46
 
59
47
  class DataFrameResultSet:
48
+
60
49
  def __init__(self, rows: List[List[Any]], col_names: List[str], col_types: List[ColumnType]):
61
50
  self._rows = rows
62
51
  self._col_names = col_names
63
52
  self._col_types = col_types
64
53
  self._formatters = {
65
- ts.ImageType: _format_img,
66
- ts.VideoType: _format_video,
67
- ts.AudioType: _format_audio,
54
+ ts.ImageType: self._format_img,
55
+ ts.VideoType: self._format_video,
56
+ ts.AudioType: self._format_audio,
68
57
  }
69
58
 
70
59
  def __len__(self) -> int:
@@ -85,9 +74,7 @@ class DataFrameResultSet:
85
74
  for col_name, col_type in zip(self._col_names, self._col_types)
86
75
  if col_type.__class__ in self._formatters
87
76
  }
88
-
89
- # TODO: why does mypy complain about formatters having an incorrect type?
90
- return self.to_pandas().to_html(formatters=formatters, escape=False, index=False) # type: ignore[arg-type]
77
+ return self.to_pandas().to_html(formatters=formatters, escape=False, index=False)
91
78
 
92
79
  def __str__(self) -> str:
93
80
  return self.to_pandas().to_string()
@@ -102,6 +89,64 @@ class DataFrameResultSet:
102
89
  def _row_to_dict(self, row_idx: int) -> Dict[str, Any]:
103
90
  return {self._col_names[i]: self._rows[row_idx][i] for i in range(len(self._col_names))}
104
91
 
92
+ # Formatters
93
+
94
+ def _format_img(self, img: Image.Image) -> str:
95
+ """
96
+ Create <img> tag for Image object.
97
+ """
98
+ assert isinstance(img, Image.Image), f'Wrong type: {type(img)}'
99
+ # Try to make it look decent in a variety of display scenarios
100
+ if len(self._rows) > 1:
101
+ width = 240 # Multiple rows: display small images
102
+ elif len(self._col_names) > 1:
103
+ width = 480 # Multiple columns: display medium images
104
+ else:
105
+ width = 640 # A single image: larger display
106
+ with io.BytesIO() as buffer:
107
+ img.save(buffer, 'jpeg')
108
+ img_base64 = base64.b64encode(buffer.getvalue()).decode()
109
+ return f'''
110
+ <div style="width:{width}px;">
111
+ <img src="data:image/jpeg;base64,{img_base64}" width="{width}" />
112
+ </div>
113
+ '''
114
+
115
+ def _format_video(self, file_path: str) -> str:
116
+ thumb_tag = ""
117
+ # Attempt to extract the first frame of the video to use as a thumbnail,
118
+ # so that the notebook can be exported as HTML and viewed in contexts where
119
+ # the video itself is not accessible.
120
+ # TODO(aaron-siegel): If the video is backed by a concrete external URL,
121
+ # should we link to that instead?
122
+ video_reader = cv2.VideoCapture(str(file_path))
123
+ if video_reader.isOpened():
124
+ status, img_array = video_reader.read()
125
+ if status:
126
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
127
+ thumb = PIL.Image.fromarray(img_array)
128
+ with io.BytesIO() as buffer:
129
+ thumb.save(buffer, 'jpeg')
130
+ thumb_base64 = base64.b64encode(buffer.getvalue()).decode()
131
+ thumb_tag = f'poster="data:image/jpeg;base64,{thumb_base64}"'
132
+ video_reader.release()
133
+ if len(self._rows) > 1:
134
+ width = 320
135
+ elif len(self._col_names) > 1:
136
+ width = 480
137
+ else:
138
+ width = 800
139
+ return f'''
140
+ <div style="width:{width}px;">
141
+ <video controls width="{width}" {thumb_tag}>
142
+ {_create_source_tag(file_path)}
143
+ </video>
144
+ </div>
145
+ '''
146
+
147
+ def _format_audio(self, file_path: str) -> str:
148
+ return f'<audio controls>{_create_source_tag(file_path)}</audio>'
149
+
105
150
  def __getitem__(self, index: Any) -> Any:
106
151
  if isinstance(index, str):
107
152
  if index not in self._col_names:
@@ -173,7 +218,6 @@ class AnalysisInfo:
173
218
  self.filter.release()
174
219
 
175
220
 
176
-
177
221
  class DataFrame:
178
222
  def __init__(
179
223
  self, tbl: catalog.TableVersionPath,
pixeltable/env.py CHANGED
@@ -10,8 +10,8 @@ import os
10
10
  import socketserver
11
11
  import sys
12
12
  import threading
13
- import typing
14
13
  import uuid
14
+ import warnings
15
15
  from pathlib import Path
16
16
  from typing import Callable, Optional, Dict, Any, List
17
17
 
@@ -19,6 +19,7 @@ import pgserver
19
19
  import sqlalchemy as sql
20
20
  import yaml
21
21
  from sqlalchemy_utils.functions import database_exists, create_database, drop_database
22
+ from tqdm import TqdmWarning
22
23
 
23
24
  import pixeltable.exceptions as excs
24
25
  from pixeltable import metadata
@@ -188,11 +189,21 @@ class Env:
188
189
  fh = logging.FileHandler(self._log_dir / self._logfilename, mode='w')
189
190
  fh.setFormatter(logging.Formatter(self._log_fmt_str))
190
191
  self._logger.addHandler(fh)
192
+
193
+ # configure sqlalchemy logging
191
194
  sql_logger = logging.getLogger('sqlalchemy.engine')
192
195
  sql_logger.setLevel(logging.INFO)
193
196
  sql_logger.addHandler(fh)
194
197
  sql_logger.propagate = False
195
198
 
199
+ # configure pyav logging
200
+ av_logfilename = self._logfilename.replace('.log', '_av.log')
201
+ av_fh = logging.FileHandler(self._log_dir / av_logfilename, mode='w')
202
+ av_fh.setFormatter(logging.Formatter(self._log_fmt_str))
203
+ av_logger = logging.getLogger('libav')
204
+ av_logger.addHandler(av_fh)
205
+ av_logger.propagate = False
206
+
196
207
  # empty tmp dir
197
208
  for path in glob.glob(f'{self._tmp_dir}/*'):
198
209
  os.remove(path)
@@ -229,6 +240,9 @@ class Env:
229
240
  self._set_up_runtime()
230
241
  self.log_to_stdout(False)
231
242
 
243
+ # Disable spurious warnings
244
+ warnings.simplefilter("ignore", category=TqdmWarning)
245
+
232
246
  def upgrade_metadata(self) -> None:
233
247
  metadata.upgrade_md(self._sa_engine)
234
248
 
@@ -320,6 +334,7 @@ class Env:
320
334
  check('torchvision')
321
335
  check('transformers')
322
336
  check('sentence_transformers')
337
+ check('yolox')
323
338
  check('boto3')
324
339
  check('pyarrow')
325
340
  check('spacy') # TODO: deal with en-core-web-sm
pixeltable/exec/cache_prefetch_node.py CHANGED
@@ -89,7 +89,7 @@ class CachePrefetchNode(ExecNode):
89
89
  # preserve the file extension, if there is one
90
90
  extension = ''
91
91
  if parsed.path != '':
92
- p = Path(urllib.parse.unquote(parsed.path))
92
+ p = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
93
93
  extension = p.suffix
94
94
  tmp_path = env.Env.get().create_tmp_path(extension=extension)
95
95
  try:
pixeltable/exec/in_memory_data_node.py CHANGED
@@ -29,18 +29,21 @@ class InMemoryDataNode(ExecNode):
29
29
 
30
30
  def _open(self) -> None:
31
31
  """Create row batch and populate with self.input_rows"""
32
- column_info = {info.col.name: info for info in self.row_builder.output_slot_idxs()}
32
+ column_info = {info.col.id: info for info in self.row_builder.output_slot_idxs()}
33
+ # exclude system columns
34
+ user_column_info = {info.col.name: info for _, info in column_info.items() if info.col.name is not None}
33
35
  # stored columns that are not computed
34
- inserted_column_names = set([
35
- info.col.name for info in self.row_builder.output_slot_idxs()
36
+ inserted_col_ids = set([
37
+ info.col.id for info in self.row_builder.output_slot_idxs()
36
38
  if info.col.is_stored and not info.col.is_computed
37
39
  ])
38
40
 
39
41
  self.output_rows = DataRowBatch(self.tbl, self.row_builder, len(self.input_rows))
40
42
  for row_idx, input_row in enumerate(self.input_rows):
41
43
  # populate the output row with the values provided in the input row
44
+ input_col_ids: List[int] = []
42
45
  for col_name, val in input_row.items():
43
- col_info = column_info.get(col_name)
46
+ col_info = user_column_info.get(col_name)
44
47
  assert col_info is not None
45
48
 
46
49
  if col_info.col.col_type.is_image_type() and isinstance(val, bytes):
@@ -49,11 +52,12 @@ class InMemoryDataNode(ExecNode):
49
52
  open(path, 'wb').write(val)
50
53
  val = path
51
54
  self.output_rows[row_idx][col_info.slot_idx] = val
55
+ input_col_ids.append(col_info.col.id)
52
56
 
53
57
  # set the remaining stored non-computed columns to null
54
- null_col_names = inserted_column_names - set(input_row.keys())
55
- for col_name in null_col_names:
56
- col_info = column_info.get(col_name)
58
+ null_col_ids = inserted_col_ids - set(input_col_ids)
59
+ for col_id in null_col_ids:
60
+ col_info = column_info.get(col_id)
57
61
  assert col_info is not None
58
62
  self.output_rows[row_idx][col_info.slot_idx] = None
59
63
 
pixeltable/exprs/comparison.py CHANGED
@@ -1,14 +1,14 @@
1
1
  from __future__ import annotations
2
+
2
3
  from typing import Optional, List, Any, Dict, Tuple
3
4
 
4
5
  import sqlalchemy as sql
5
6
 
6
- from .globals import ComparisonOperator
7
+ from .data_row import DataRow
7
8
  from .expr import Expr
9
+ from .globals import ComparisonOperator
8
10
  from .predicate import Predicate
9
- from .data_row import DataRow
10
11
  from .row_builder import RowBuilder
11
- import pixeltable.catalog as catalog
12
12
 
13
13
 
14
14
  class Comparison(Predicate):
pixeltable/exprs/data_row.py CHANGED
@@ -5,6 +5,8 @@ import urllib.parse
5
5
  import urllib.request
6
6
  from typing import Optional, List, Any, Tuple
7
7
 
8
+ import sqlalchemy as sql
9
+ import pgvector.sqlalchemy
8
10
  import PIL
9
11
  import numpy as np
10
12
 
@@ -110,7 +112,7 @@ class DataRow:
110
112
 
111
113
  return self.vals[index]
112
114
 
113
- def get_stored_val(self, index: object) -> Any:
115
+ def get_stored_val(self, index: object, sa_col_type: Optional[sql.types.TypeEngine] = None) -> Any:
114
116
  """Return the value that gets stored in the db"""
115
117
  assert self.excs[index] is None
116
118
  if not self.has_val[index]:
@@ -125,6 +127,8 @@ class DataRow:
125
127
  if self.vals[index] is not None and index in self.array_slot_idxs:
126
128
  assert isinstance(self.vals[index], np.ndarray)
127
129
  np_array = self.vals[index]
130
+ if sa_col_type is not None and isinstance(sa_col_type, pgvector.sqlalchemy.Vector):
131
+ return np_array
128
132
  buffer = io.BytesIO()
129
133
  np.save(buffer, np_array)
130
134
  return buffer.getvalue()
pixeltable/exprs/literal.py CHANGED
@@ -1,13 +1,16 @@
1
1
  from __future__ import annotations
2
+
3
+ import datetime
2
4
  from typing import Optional, List, Any, Dict, Tuple
3
5
 
4
6
  import sqlalchemy as sql
5
7
 
6
- from .expr import Expr
8
+ import pixeltable.exceptions as excs
9
+ import pixeltable.type_system as ts
7
10
  from .data_row import DataRow
11
+ from .expr import Expr
8
12
  from .row_builder import RowBuilder
9
- import pixeltable.catalog as catalog
10
- import pixeltable.type_system as ts
13
+
11
14
 
12
15
  class Literal(Expr):
13
16
  def __init__(self, val: Any, col_type: Optional[ts.ColumnType] = None):
@@ -46,9 +49,18 @@ class Literal(Expr):
46
49
  data_row[self.slot_idx] = self.val
47
50
 
48
51
  def _as_dict(self) -> Dict:
49
- return {'val': self.val, **super()._as_dict()}
52
+ # For some types, we need to explicitly record their type, because JSON does not know
53
+ # how to interpret them unambiguously
54
+ if self.col_type.is_timestamp_type():
55
+ return {'val': self.val.isoformat(), 'val_t': self.col_type._type.name, **super()._as_dict()}
56
+ else:
57
+ return {'val': self.val, **super()._as_dict()}
50
58
 
51
59
  @classmethod
52
60
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
53
61
  assert 'val' in d
62
+ if 'val_t' in d:
63
+ val_t = d['val_t']
64
+ assert val_t == ts.ColumnType.Type.TIMESTAMP.name
65
+ return cls(datetime.datetime.fromisoformat(d['val']))
54
66
  return cls(d['val'])
pixeltable/exprs/row_builder.py CHANGED
@@ -54,14 +54,12 @@ class RowBuilder:
54
54
  target_exprs: List[Expr] # exprs corresponding to target_slot_idxs
55
55
 
56
56
  def __init__(
57
- self, output_exprs: List[Expr], columns: List[catalog.Column],
58
- indices: List[Tuple[catalog.Column, func.Function]], input_exprs: List[Expr]
57
+ self, output_exprs: List[Expr], columns: List[catalog.Column], input_exprs: List[Expr]
59
58
  ):
60
59
  """
61
60
  Args:
62
61
  output_exprs: list of Exprs to be evaluated
63
62
  columns: list of columns to be materialized
64
- indices: list of embeddings to be materialized (Tuple[indexed column, embedding function])
65
63
  """
66
64
  self.unique_exprs = ExprSet() # dependencies precede their dependents
67
65
  self.next_slot_idx = 0
@@ -73,7 +71,6 @@ class RowBuilder:
73
71
  # output exprs: all exprs the caller wants to materialize
74
72
  # - explicitly requested output_exprs
75
73
  # - values for computed columns
76
- # - embedding values for indices
77
74
  resolve_cols = set(columns)
78
75
  self.output_exprs = [
79
76
  self._record_unique_expr(e.copy().resolve_computed_cols(resolve_cols=resolve_cols), recursive=True)
@@ -97,21 +94,6 @@ class RowBuilder:
97
94
  ref = self._record_unique_expr(ref, recursive=False)
98
95
  self.add_table_column(col, ref.slot_idx)
99
96
 
100
- # record indices; indexed by slot_idx
101
- self.index_columns: List[catalog.Column] = []
102
- for col, embedding_fn in indices:
103
- # we assume that the parameter of the embedding function is a ref to an image column
104
- assert col.col_type.is_image_type()
105
- # construct expr to compute embedding; explicitly resize images to the required size
106
- target_img_type = next(iter(embedding_fn.signature.parameters.values())).col_type
107
- expr = embedding_fn(ColumnRef(col).resize(target_img_type.size))
108
- expr = self._record_unique_expr(expr, recursive=True)
109
- self.output_exprs.append(expr)
110
- if len(self.index_columns) <= expr.slot_idx:
111
- # pad to slot_idx
112
- self.index_columns.extend([None] * (expr.slot_idx - len(self.index_columns) + 1))
113
- self.index_columns[expr.slot_idx] = col
114
-
115
97
  # default eval ctx: all output exprs
116
98
  self.default_eval_ctx = self.create_eval_ctx(self.output_exprs, exclude=unique_input_exprs)
117
99
 
@@ -170,13 +152,6 @@ class RowBuilder:
170
152
  """Return ColumnSlotIdx for output columns"""
171
153
  return self.table_columns
172
154
 
173
- def index_slot_idxs(self) -> List[ColumnSlotIdx]:
174
- """Return ColumnSlotIdx for index columns"""
175
- return [
176
- ColumnSlotIdx(self.output_columns[i], i) for i in range(len(self.index_columns))
177
- if self.output_columns[i] is not None
178
- ]
179
-
180
155
  @property
181
156
  def num_materialized(self) -> int:
182
157
  return self.next_slot_idx
@@ -334,22 +309,15 @@ class RowBuilder:
334
309
  exc = data_row.get_exc(slot_idx)
335
310
  num_excs += 1
336
311
  exc_col_ids.add(col.id)
337
- table_row[col.storage_name()] = None
338
- table_row[col.errortype_storage_name()] = type(exc).__name__
339
- table_row[col.errormsg_storage_name()] = str(exc)
312
+ table_row[col.store_name()] = None
313
+ table_row[col.errortype_store_name()] = type(exc).__name__
314
+ table_row[col.errormsg_store_name()] = str(exc)
340
315
  else:
341
- val = data_row.get_stored_val(slot_idx)
342
- table_row[col.storage_name()] = val
316
+ val = data_row.get_stored_val(slot_idx, col.sa_col.type)
317
+ table_row[col.store_name()] = val
343
318
  # we unfortunately need to set these, even if there are no errors
344
- table_row[col.errortype_storage_name()] = None
345
- table_row[col.errormsg_storage_name()] = None
346
-
347
- for slot_idx, col in enumerate(self.index_columns):
348
- if col is None:
349
- continue
350
- # don't use get_stored_val() here, we need to pass in the ndarray
351
- val = data_row[slot_idx]
352
- table_row[col.index_storage_name()] = val
319
+ table_row[col.errortype_store_name()] = None
320
+ table_row[col.errormsg_store_name()] = None
353
321
 
354
322
  return table_row, num_excs
355
323
 
pixeltable/ext/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """
2
+ Extended integrations for Pixeltable. This package contains experimental or demonstration features that
3
+ are not intended for production use. Long-term support cannot be guaranteed, usually because the features
4
+ have dependencies whose future support is unclear.
5
+ """
pixeltable/ext/functions/yolox.py ADDED
@@ -0,0 +1,92 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Iterable, Iterator
4
+ from urllib.request import urlretrieve
5
+
6
+ import PIL.Image
7
+ import numpy as np
8
+ import torch
9
+ from yolox.data import ValTransform
10
+ from yolox.exp import get_exp, Exp
11
+ from yolox.models import YOLOX
12
+ from yolox.utils import postprocess
13
+
14
+ import pixeltable as pxt
15
+ from pixeltable import env
16
+ from pixeltable.func import Batch
17
+ from pixeltable.functions.util import resolve_torch_device
18
+
19
+ _logger = logging.getLogger('pixeltable')
20
+
21
+
22
+ @pxt.udf(batch_size=4)
23
+ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0.5) -> Batch[dict]:
24
+ """
25
+ Runs the specified YOLOX object detection model on an image.
26
+
27
+ YOLOX support is part of the `pixeltable.ext` package: long-term support is not guaranteed, and it is not
28
+ intended for use in production applications.
29
+
30
+ Parameters:
31
+ - `model_id` - one of: `yolox_nano`, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
32
+ - `threshold` - the threshold for object detection
33
+ """
34
+ model, exp = _lookup_model(model_id, 'cpu')
35
+ image_tensors = list(_images_to_tensors(images, exp))
36
+ batch_tensor = torch.stack(image_tensors)
37
+
38
+ with torch.no_grad():
39
+ output_tensor = model(batch_tensor)
40
+
41
+ outputs = postprocess(
42
+ output_tensor, 80, threshold, exp.nmsthre, class_agnostic=False
43
+ )
44
+
45
+ results: list[dict] = []
46
+ for image in images:
47
+ ratio = min(exp.test_size[0] / image.height, exp.test_size[1] / image.width)
48
+ if outputs[0] is None:
49
+ results.append({'bboxes': [], 'scores': [], 'labels': []})
50
+ else:
51
+ results.append({
52
+ 'bboxes': [(output[:4] / ratio).tolist() for output in outputs[0]],
53
+ 'scores': [output[4].item() * output[5].item() for output in outputs[0]],
54
+ 'labels': [int(output[6]) for output in outputs[0]]
55
+ })
56
+ return results
57
+
58
+
59
+ def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
60
+ for image in images:
61
+ image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
62
+ yield torch.from_numpy(image_transform)
63
+
64
+
65
+ def _lookup_model(model_id: str, device: str) -> (YOLOX, Exp):
66
+ key = (model_id, device)
67
+ if key in _model_cache:
68
+ return _model_cache[key]
69
+
70
+ weights_url = f'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/{model_id}.pth'
71
+ weights_file = Path(f'{env.Env.get().tmp_dir}/{model_id}.pth')
72
+ if not weights_file.exists():
73
+ _logger.info(f'Downloading weights for YOLOX model {model_id}: from {weights_url} -> {weights_file}')
74
+ urlretrieve(weights_url, weights_file)
75
+
76
+ exp = get_exp(exp_name=model_id)
77
+ model = exp.get_model().to(device)
78
+
79
+ model.eval()
80
+ model.head.training = False
81
+ model.training = False
82
+
83
+ # Load in the weights from training
84
+ weights = torch.load(weights_file, map_location=torch.device(device))
85
+ model.load_state_dict(weights['model'])
86
+
87
+ _model_cache[key] = (model, exp)
88
+ return model, exp
89
+
90
+
91
+ _model_cache = {}
92
+ _val_transform = ValTransform(legacy=False)
pixeltable/func/aggregate_function.py CHANGED
@@ -3,13 +3,14 @@ from __future__ import annotations
3
3
  import abc
4
4
  import importlib
5
5
  import inspect
6
- from typing import Optional, Any, Type, List, Dict
6
+ from typing import Optional, Any, Type, List, Dict, Callable
7
7
  import itertools
8
8
 
9
9
  import pixeltable.exceptions as excs
10
10
  import pixeltable.type_system as ts
11
11
  from .function import Function
12
12
  from .signature import Signature, Parameter
13
+ from .globals import validate_symbol_path
13
14
 
14
15
 
15
16
  class Aggregator(abc.ABC):
@@ -136,8 +137,7 @@ def uda(
136
137
  update_types: List[ts.ColumnType],
137
138
  init_types: Optional[List[ts.ColumnType]] = None,
138
139
  requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
139
- name: Optional[str] = None
140
- ) -> Type[Aggregator]:
140
+ ) -> Callable:
141
141
  """Decorator for user-defined aggregate functions.
142
142
 
143
143
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -155,14 +155,11 @@ def uda(
155
155
  - requires_order_by: if True, the first parameter to the function is the order-by expression
156
156
  - allows_std_agg: if True, the function can be used as a standard aggregate function w/o a window
157
157
  - allows_window: if True, the function can be used with a window
158
- - name: name of the AggregateFunction instance; if None, the class name is used
159
158
  """
160
- if name is not None and not name.isidentifier():
161
- raise excs.Error(f'Invalid name: {name}')
162
159
  if init_types is None:
163
160
  init_types = []
164
161
 
165
- def decorator(cls: Type[Aggregator]) -> Type[Aggregator]:
162
+ def decorator(cls: Type[Aggregator]) -> Type[Function]:
166
163
  # validate type parameters
167
164
  num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
168
165
  if num_init_params > 0:
@@ -178,17 +175,20 @@ def uda(
178
175
  assert value_type is not None
179
176
 
180
177
  # the AggregateFunction instance resides in the same module as cls
181
- module_path = cls.__module__
182
- nonlocal name
183
- name = name or cls.__name__
184
- instance_path = f'{module_path}.{name}'
178
+ class_path = f'{cls.__module__}.{cls.__qualname__}'
179
+ # nonlocal name
180
+ # name = name or cls.__name__
181
+ # instance_path_elements = class_path.split('.')[:-1] + [name]
182
+ # instance_path = '.'.join(instance_path_elements)
185
183
 
186
184
  # create the corresponding AggregateFunction instance
187
185
  instance = AggregateFunction(
188
- cls, instance_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
189
- module = importlib.import_module(module_path)
190
- setattr(module, name, instance)
186
+ cls, class_path, init_types, update_types, value_type, requires_order_by, allows_std_agg, allows_window)
187
+ # do the path validation at the very end, in order to be able to write tests for the other failure cases
188
+ validate_symbol_path(class_path)
189
+ #module = importlib.import_module(cls.__module__)
190
+ #setattr(module, name, instance)
191
191
 
192
- return cls
192
+ return instance
193
193
 
194
194
  return decorator
pixeltable/func/expr_template_function.py CHANGED
@@ -50,9 +50,17 @@ class ExprTemplateFunction(Function):
50
50
  bound_args.update(
51
51
  {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
52
52
  result = self.expr.copy()
53
+ import pixeltable.exprs as exprs
53
54
  for param_name, arg in bound_args.items():
54
55
  param_expr = self.param_exprs_by_name[param_name]
55
- result = result.substitute(param_expr, arg)
56
+ if not isinstance(arg, exprs.Expr):
57
+ # TODO: use the available param_expr.col_type
58
+ arg_expr = exprs.Expr.from_object(arg)
59
+ if arg_expr is None:
60
+ raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
61
+ else:
62
+ arg_expr = arg
63
+ result = result.substitute(param_expr, arg_expr)
56
64
  import pixeltable.exprs as exprs
57
65
  assert not result.contains(exprs.Variable)
58
66
  return result