pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (120)
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/column.py +37 -11
  5. pixeltable/catalog/globals.py +21 -0
  6. pixeltable/catalog/insertable_table.py +6 -4
  7. pixeltable/catalog/table.py +227 -148
  8. pixeltable/catalog/table_version.py +66 -28
  9. pixeltable/catalog/table_version_path.py +0 -8
  10. pixeltable/catalog/view.py +18 -19
  11. pixeltable/dataframe.py +16 -32
  12. pixeltable/env.py +6 -1
  13. pixeltable/exec/__init__.py +1 -2
  14. pixeltable/exec/aggregation_node.py +27 -17
  15. pixeltable/exec/cache_prefetch_node.py +1 -1
  16. pixeltable/exec/data_row_batch.py +9 -26
  17. pixeltable/exec/exec_node.py +36 -7
  18. pixeltable/exec/expr_eval_node.py +19 -11
  19. pixeltable/exec/in_memory_data_node.py +14 -11
  20. pixeltable/exec/sql_node.py +266 -138
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/arithmetic_expr.py +3 -1
  23. pixeltable/exprs/array_slice.py +7 -7
  24. pixeltable/exprs/column_property_ref.py +37 -10
  25. pixeltable/exprs/column_ref.py +93 -14
  26. pixeltable/exprs/comparison.py +5 -5
  27. pixeltable/exprs/compound_predicate.py +8 -7
  28. pixeltable/exprs/data_row.py +56 -36
  29. pixeltable/exprs/expr.py +65 -63
  30. pixeltable/exprs/expr_dict.py +55 -0
  31. pixeltable/exprs/expr_set.py +26 -15
  32. pixeltable/exprs/function_call.py +53 -24
  33. pixeltable/exprs/globals.py +4 -1
  34. pixeltable/exprs/in_predicate.py +8 -7
  35. pixeltable/exprs/inline_expr.py +4 -4
  36. pixeltable/exprs/is_null.py +4 -4
  37. pixeltable/exprs/json_mapper.py +11 -12
  38. pixeltable/exprs/json_path.py +5 -10
  39. pixeltable/exprs/literal.py +5 -5
  40. pixeltable/exprs/method_ref.py +5 -4
  41. pixeltable/exprs/object_ref.py +2 -1
  42. pixeltable/exprs/row_builder.py +88 -36
  43. pixeltable/exprs/rowid_ref.py +14 -13
  44. pixeltable/exprs/similarity_expr.py +12 -7
  45. pixeltable/exprs/sql_element_cache.py +12 -6
  46. pixeltable/exprs/type_cast.py +8 -6
  47. pixeltable/exprs/variable.py +5 -4
  48. pixeltable/ext/functions/whisperx.py +7 -2
  49. pixeltable/func/aggregate_function.py +1 -1
  50. pixeltable/func/callable_function.py +2 -2
  51. pixeltable/func/function.py +11 -10
  52. pixeltable/func/function_registry.py +6 -7
  53. pixeltable/func/query_template_function.py +11 -12
  54. pixeltable/func/signature.py +17 -15
  55. pixeltable/func/udf.py +0 -4
  56. pixeltable/functions/__init__.py +2 -2
  57. pixeltable/functions/audio.py +4 -6
  58. pixeltable/functions/globals.py +84 -42
  59. pixeltable/functions/huggingface.py +31 -34
  60. pixeltable/functions/image.py +59 -45
  61. pixeltable/functions/json.py +0 -1
  62. pixeltable/functions/llama_cpp.py +106 -0
  63. pixeltable/functions/mistralai.py +2 -2
  64. pixeltable/functions/ollama.py +147 -0
  65. pixeltable/functions/openai.py +22 -25
  66. pixeltable/functions/replicate.py +72 -0
  67. pixeltable/functions/string.py +59 -50
  68. pixeltable/functions/timestamp.py +20 -20
  69. pixeltable/functions/together.py +2 -2
  70. pixeltable/functions/video.py +11 -20
  71. pixeltable/functions/whisper.py +2 -20
  72. pixeltable/globals.py +65 -74
  73. pixeltable/index/base.py +2 -2
  74. pixeltable/index/btree.py +20 -7
  75. pixeltable/index/embedding_index.py +12 -14
  76. pixeltable/io/__init__.py +1 -2
  77. pixeltable/io/external_store.py +11 -5
  78. pixeltable/io/fiftyone.py +178 -0
  79. pixeltable/io/globals.py +98 -2
  80. pixeltable/io/hf_datasets.py +1 -1
  81. pixeltable/io/label_studio.py +6 -6
  82. pixeltable/io/parquet.py +14 -13
  83. pixeltable/iterators/base.py +3 -2
  84. pixeltable/iterators/document.py +10 -8
  85. pixeltable/iterators/video.py +126 -60
  86. pixeltable/metadata/__init__.py +4 -3
  87. pixeltable/metadata/converters/convert_14.py +4 -2
  88. pixeltable/metadata/converters/convert_15.py +1 -1
  89. pixeltable/metadata/converters/convert_19.py +1 -0
  90. pixeltable/metadata/converters/convert_20.py +1 -1
  91. pixeltable/metadata/converters/convert_21.py +34 -0
  92. pixeltable/metadata/converters/util.py +54 -12
  93. pixeltable/metadata/notes.py +1 -0
  94. pixeltable/metadata/schema.py +40 -21
  95. pixeltable/plan.py +149 -165
  96. pixeltable/py.typed +0 -0
  97. pixeltable/store.py +57 -37
  98. pixeltable/tool/create_test_db_dump.py +6 -6
  99. pixeltable/tool/create_test_video.py +1 -1
  100. pixeltable/tool/doc_plugins/griffe.py +3 -34
  101. pixeltable/tool/embed_udf.py +1 -1
  102. pixeltable/tool/mypy_plugin.py +55 -0
  103. pixeltable/type_system.py +260 -61
  104. pixeltable/utils/arrow.py +10 -9
  105. pixeltable/utils/coco.py +4 -4
  106. pixeltable/utils/documents.py +16 -2
  107. pixeltable/utils/filecache.py +9 -9
  108. pixeltable/utils/formatter.py +10 -11
  109. pixeltable/utils/http_server.py +2 -5
  110. pixeltable/utils/media_store.py +6 -6
  111. pixeltable/utils/pytorch.py +10 -11
  112. pixeltable/utils/sql.py +2 -1
  113. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
  114. pixeltable-0.2.22.dist-info/RECORD +153 -0
  115. pixeltable/exec/media_validation_node.py +0 -43
  116. pixeltable/utils/help.py +0 -11
  117. pixeltable-0.2.20.dist-info/RECORD +0 -147
  118. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
  119. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
  120. {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/store.py CHANGED
@@ -7,18 +7,19 @@ import sys
7
7
  import urllib.parse
8
8
  import urllib.request
9
9
  import warnings
10
- from typing import Optional, Dict, Any, List, Tuple, Set, Union
10
+ from typing import Any, Iterator, Literal, Optional, Union
11
11
 
12
12
  import sqlalchemy as sql
13
- from tqdm import tqdm, TqdmWarning
13
+ from tqdm import TqdmWarning, tqdm
14
14
 
15
15
  import pixeltable.catalog as catalog
16
16
  import pixeltable.env as env
17
+ import pixeltable.exceptions as excs
17
18
  from pixeltable import exprs
18
19
  from pixeltable.exec import ExecNode
19
20
  from pixeltable.metadata import schema
20
21
  from pixeltable.utils.media_store import MediaStore
21
- from pixeltable.utils.sql import log_stmt, log_explain
22
+ from pixeltable.utils.sql import log_explain, log_stmt
22
23
 
23
24
  _logger = logging.getLogger('pixeltable')
24
25
 
@@ -31,35 +32,42 @@ class StoreBase:
31
32
  - v_min: version at which the row was created
32
33
  - v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
33
34
  """
35
+ tbl_version: catalog.TableVersion
36
+ sa_md: sql.MetaData
37
+ sa_tbl: Optional[sql.Table]
38
+ _pk_cols: list[sql.Column]
39
+ v_min_col: sql.Column
40
+ v_max_col: sql.Column
41
+ base: Optional[StoreBase]
34
42
 
35
43
  __INSERT_BATCH_SIZE = 1000
36
44
 
37
45
  def __init__(self, tbl_version: catalog.TableVersion):
38
46
  self.tbl_version = tbl_version
39
47
  self.sa_md = sql.MetaData()
40
- self.sa_tbl: Optional[sql.Table] = None
48
+ self.sa_tbl = None
41
49
  # We need to declare a `base` variable here, even though it's only defined for instances of `StoreView`,
42
50
  # since it's referenced by various methods of `StoreBase`
43
51
  self.base = None if tbl_version.base is None else tbl_version.base.store_tbl
44
52
  self.create_sa_tbl()
45
53
 
46
- def pk_columns(self) -> List[sql.Column]:
47
- return self._pk_columns
54
+ def pk_columns(self) -> list[sql.Column]:
55
+ return self._pk_cols
48
56
 
49
- def rowid_columns(self) -> List[sql.Column]:
50
- return self._pk_columns[:-1]
57
+ def rowid_columns(self) -> list[sql.Column]:
58
+ return self._pk_cols[:-1]
51
59
 
52
60
  @abc.abstractmethod
53
- def _create_rowid_columns(self) -> List[sql.Column]:
61
+ def _create_rowid_columns(self) -> list[sql.Column]:
54
62
  """Create and return rowid columns"""
55
63
 
56
- def _create_system_columns(self) -> List[sql.Column]:
64
+ def _create_system_columns(self) -> list[sql.Column]:
57
65
  """Create and return system columns"""
58
66
  rowid_cols = self._create_rowid_columns()
59
67
  self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
60
68
  self.v_max_col = \
61
69
  sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
62
- self._pk_columns = [*rowid_cols, self.v_min_col]
70
+ self._pk_cols = [*rowid_cols, self.v_min_col]
63
71
  return [*rowid_cols, self.v_min_col, self.v_max_col]
64
72
 
65
73
  def create_sa_tbl(self) -> None:
@@ -79,7 +87,7 @@ class StoreBase:
79
87
  # if we're called in response to a schema change, we need to remove the old table first
80
88
  self.sa_md.remove(self.sa_tbl)
81
89
 
82
- idxs: List[sql.Index] = []
90
+ idxs: list[sql.Index] = []
83
91
  # index for all system columns:
84
92
  # - base x view joins can be executed as merge joins
85
93
  # - speeds up ORDER BY rowid DESC
@@ -126,7 +134,7 @@ class StoreBase:
126
134
  return new_file_url
127
135
 
128
136
  def _move_tmp_media_files(
129
- self, table_rows: List[Dict[str, Any]], media_cols: List[catalog.Column], v_min: int
137
+ self, table_rows: list[dict[str, Any]], media_cols: list[catalog.Column], v_min: int
130
138
  ) -> None:
131
139
  """Move tmp media files that we generated to a permanent location"""
132
140
  for c in media_cols:
@@ -135,23 +143,17 @@ class StoreBase:
135
143
  table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)
136
144
 
137
145
  def _create_table_row(
138
- self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, media_cols: List[catalog.Column],
139
- exc_col_ids: Set[int], v_min: int
140
- ) -> Tuple[Dict[str, Any], int]:
146
+ self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, exc_col_ids: set[int], pk: tuple[int, ...]
147
+ ) -> tuple[dict[str, Any], int]:
141
148
  """Return Tuple[complete table row, # of exceptions] for insert()
142
149
  Creates a row that includes the PK columns, with the values from input_row.pk.
143
150
  Returns:
144
151
  Tuple[complete table row, # of exceptions]
145
152
  """
146
153
  table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
147
-
148
- assert input_row.pk is not None and len(input_row.pk) == len(self._pk_columns)
149
- for pk_col, pk_val in zip(self._pk_columns, input_row.pk):
150
- if pk_col == self.v_min_col:
151
- table_row[pk_col.name] = v_min
152
- else:
153
- table_row[pk_col.name] = pk_val
154
-
154
+ assert len(pk) == len(self._pk_cols)
155
+ for pk_col, pk_val in zip(self._pk_cols, pk):
156
+ table_row[pk_col.name] = pk_val
155
157
  return table_row, num_excs
156
158
 
157
159
  def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
@@ -212,14 +214,20 @@ class StoreBase:
212
214
  conn.execute(sql.text(stmt))
213
215
 
214
216
  def load_column(
215
- self, col: catalog.Column, exec_plan: ExecNode, value_expr_slot_idx: int, conn: sql.engine.Connection
217
+ self,
218
+ col: catalog.Column,
219
+ exec_plan: ExecNode,
220
+ value_expr_slot_idx: int,
221
+ conn: sql.engine.Connection,
222
+ on_error: Literal['abort', 'ignore']
216
223
  ) -> int:
217
224
  """Update store column of a computed column with values produced by an execution plan
218
225
 
219
226
  Returns:
220
227
  number of rows with exceptions
221
228
  Raises:
222
- sql.exc.DBAPIError if there was an error during SQL execution
229
+ sql.exc.DBAPIError if there was a SQL error during execution
230
+ excs.Error if on_error='abort' and there was an exception during row evaluation
223
231
  """
224
232
  num_excs = 0
225
233
  num_rows = 0
@@ -253,6 +261,10 @@ class StoreBase:
253
261
  if result_row.has_exc(value_expr_slot_idx):
254
262
  num_excs += 1
255
263
  value_exc = result_row.get_exc(value_expr_slot_idx)
264
+ if on_error == 'abort':
265
+ raise excs.Error(
266
+ f'Error while evaluating computed column `{col.name}`:\n{value_exc}'
267
+ ) from value_exc
256
268
  # we store a NULL value and record the exception/exc type
257
269
  error_type = type(value_exc).__name__
258
270
  error_msg = str(value_exc)
@@ -291,8 +303,8 @@ class StoreBase:
291
303
 
292
304
  def insert_rows(
293
305
  self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None,
294
- show_progress: bool = True
295
- ) -> Tuple[int, int, Set[int]]:
306
+ show_progress: bool = True, rowids: Optional[Iterator[int]] = None, abort_on_exc: bool = False
307
+ ) -> tuple[int, int, set[int]]:
296
308
  """Insert rows into the store table and update the catalog table's md
297
309
  Returns:
298
310
  number of inserted rows, number of exceptions, set of column ids that have exceptions
@@ -302,7 +314,7 @@ class StoreBase:
302
314
  # TODO: total?
303
315
  num_excs = 0
304
316
  num_rows = 0
305
- cols_with_excs: Set[int] = set()
317
+ cols_with_excs: set[int] = set()
306
318
  progress_bar: Optional[tqdm] = None # create this only after we started executing
307
319
  row_builder = exec_plan.row_builder
308
320
  media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
@@ -312,13 +324,21 @@ class StoreBase:
312
324
  num_rows += len(row_batch)
313
325
  for batch_start_idx in range(0, len(row_batch), self.__INSERT_BATCH_SIZE):
314
326
  # compute batch of rows and convert them into table rows
315
- table_rows: List[Dict[str, Any]] = []
316
- for row_idx in range(batch_start_idx, min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))):
327
+ table_rows: list[dict[str, Any]] = []
328
+ batch_stop_idx = min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))
329
+ for row_idx in range(batch_start_idx, batch_stop_idx):
317
330
  row = row_batch[row_idx]
318
- table_row, num_row_exc = \
319
- self._create_table_row(row, row_builder, media_cols, cols_with_excs, v_min=v_min)
331
+ # if abort_on_exc == True, we need to check for media validation exceptions
332
+ if abort_on_exc and row.has_exc():
333
+ exc = row.get_first_exc()
334
+ raise exc
335
+
336
+ rowid = (next(rowids),) if rowids is not None else row.pk[:-1]
337
+ pk = rowid + (v_min,)
338
+ table_row, num_row_exc = self._create_table_row(row, row_builder, cols_with_excs, pk=pk)
320
339
  num_excs += num_row_exc
321
340
  table_rows.append(table_row)
341
+
322
342
  if show_progress:
323
343
  if progress_bar is None:
324
344
  warnings.simplefilter("ignore", category=TqdmWarning)
@@ -353,7 +373,7 @@ class StoreBase:
353
373
  return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))
354
374
 
355
375
  def delete_rows(
356
- self, current_version: int, base_versions: List[Optional[int]], match_on_vmin: bool,
376
+ self, current_version: int, base_versions: list[Optional[int]], match_on_vmin: bool,
357
377
  where_clause: Optional[sql.ColumnElement[bool]], conn: sql.engine.Connection) -> int:
358
378
  """Mark rows as deleted that are live and were created prior to current_version.
359
379
  Also: populate the undo columns
@@ -397,7 +417,7 @@ class StoreTable(StoreBase):
397
417
  assert not tbl_version.is_view()
398
418
  super().__init__(tbl_version)
399
419
 
400
- def _create_rowid_columns(self) -> List[sql.Column]:
420
+ def _create_rowid_columns(self) -> list[sql.Column]:
401
421
  self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
402
422
  return [self.rowid_col]
403
423
 
@@ -413,7 +433,7 @@ class StoreView(StoreBase):
413
433
  assert catalog_view.is_view()
414
434
  super().__init__(catalog_view)
415
435
 
416
- def _create_rowid_columns(self) -> List[sql.Column]:
436
+ def _create_rowid_columns(self) -> list[sql.Column]:
417
437
  # a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
418
438
  self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
419
439
  return self.rowid_cols
@@ -439,7 +459,7 @@ class StoreComponentView(StoreView):
439
459
  def __init__(self, catalog_view: catalog.TableVersion):
440
460
  super().__init__(catalog_view)
441
461
 
442
- def _create_rowid_columns(self) -> List[sql.Column]:
462
+ def _create_rowid_columns(self) -> list[sql.Column]:
443
463
  # each base row is expanded into n view rows
444
464
  self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
445
465
  # name of pos column: avoid collisions with bases' pos columns
pixeltable/tool/create_test_db_dump.py CHANGED
@@ -149,18 +149,18 @@ class Dumper:
149
149
  pxt.create_dir('views')
150
150
 
151
151
  # simple view
152
- v = pxt.create_view('views.view', t, filter=(t.c2 < 50))
152
+ v = pxt.create_view('views.view', t.where(t.c2 < 50))
153
153
  self.__add_expr_columns(v, 'view')
154
154
 
155
155
  # snapshot
156
- _ = pxt.create_view('views.snapshot', t, filter=(t.c2 >= 75), is_snapshot=True)
156
+ _ = pxt.create_view('views.snapshot', t.where(t.c2 >= 75), is_snapshot=True)
157
157
 
158
158
  # view of views
159
- vv = pxt.create_view('views.view_of_views', v, filter=(t.c2 >= 25))
159
+ vv = pxt.create_view('views.view_of_views', v.where(t.c2 >= 25))
160
160
  self.__add_expr_columns(vv, 'view_of_views')
161
161
 
162
162
  # empty view
163
- e = pxt.create_view('views.empty_view', t, filter=t.c2 == 4171780)
163
+ e = pxt.create_view('views.empty_view', t.where(t.c2 == 4171780))
164
164
  assert e.count() == 0
165
165
  self.__add_expr_columns(e, 'empty_view', include_expensive_functions=True)
166
166
 
@@ -278,13 +278,13 @@ class Dumper:
278
278
  # this breaks; TODO: why?
279
279
  #return t.where(t.c2 < i)
280
280
  return t.where(t.c2 < i).select(t.c1, t.c2)
281
- add_column('query_output', t.q1(t.c2))
281
+ add_column('query_output', t.queries.q1(t.c2))
282
282
 
283
283
  @t.query
284
284
  def q2(s: str):
285
285
  sim = t[f'{col_prefix}_function_call'].similarity(s)
286
286
  return t.order_by(sim, asc=False).select(t[f'{col_prefix}_function_call']).limit(5)
287
- add_column('sim_output', t.q2(t.c1))
287
+ add_column('sim_output', t.queries.q2(t.c1))
288
288
 
289
289
 
290
290
  @pxt.udf(_force_stored=True)
pixeltable/tool/create_test_video.py CHANGED
@@ -1,4 +1,4 @@
1
- import av
1
+ import av # type: ignore[import-untyped]
2
2
  import PIL.Image
3
3
  import PIL.ImageDraw
4
4
  import PIL.ImageFont
pixeltable/tool/doc_plugins/griffe.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import ast
2
- from typing import Optional, Union
3
2
  import warnings
3
+ from typing import Optional, Union
4
4
 
5
5
  import griffe
6
6
  import griffe.expressions
@@ -39,7 +39,7 @@ class PxtGriffeExtension(Extension):
39
39
  udf = griffe.dynamic_import(func.path)
40
40
  assert isinstance(udf, pxt.Function)
41
41
  # Convert the return type to a Pixeltable type reference
42
- func.returns = self.__column_type_to_display_str(udf.signature.get_return_type())
42
+ func.returns = str(udf.signature.get_return_type())
43
43
  # Convert the parameter types to Pixeltable type references
44
44
  for griffe_param in func.parameters:
45
45
  assert isinstance(griffe_param.annotation, griffe.expressions.Expr)
@@ -47,35 +47,4 @@ class PxtGriffeExtension(Extension):
47
47
  logger.warning(f'Parameter `{griffe_param.name}` not found in signature for UDF: {udf.display_name}')
48
48
  continue
49
49
  pxt_param = udf.signature.parameters[griffe_param.name]
50
- griffe_param.annotation = self.__column_type_to_display_str(pxt_param.col_type)
51
-
52
- def __column_type_to_display_str(self, column_type: Optional[pxt.ColumnType]) -> str:
53
- # TODO: When we enhance the Pixeltable type system, we may want to refactor some of this logic out.
54
- # I'm putting it here for now though.
55
- if column_type is None:
56
- return 'None'
57
- if column_type.is_string_type():
58
- base = 'str'
59
- elif column_type.is_int_type():
60
- base = 'int'
61
- elif column_type.is_float_type():
62
- base = 'float'
63
- elif column_type.is_bool_type():
64
- base = 'bool'
65
- elif column_type.is_timestamp_type():
66
- base = 'datetime'
67
- elif column_type.is_array_type():
68
- base = 'ArrayT'
69
- elif column_type.is_json_type():
70
- base = 'JsonT'
71
- elif column_type.is_image_type():
72
- base = 'ImageT'
73
- elif column_type.is_video_type():
74
- base = 'VideoT'
75
- elif column_type.is_audio_type():
76
- base = 'AudioT'
77
- elif column_type.is_document_type():
78
- base = 'DocumentT'
79
- else:
80
- assert False
81
- return f'Optional[{base}]' if column_type.nullable else base
50
+ griffe_param.annotation = str(pxt_param.col_type)
pixeltable/tool/embed_udf.py CHANGED
@@ -6,4 +6,4 @@ import pixeltable as pxt
6
6
  # TODO This can go away once we have the ability to inline expr_udf's
7
7
  @pxt.expr_udf
8
8
  def clip_text_embed(txt: str) -> np.ndarray:
9
- return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32')
9
+ return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32') # type: ignore[return-value]
pixeltable/tool/mypy_plugin.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Callable, Optional
2
+
3
+ from mypy import nodes
4
+ from mypy.plugin import AnalyzeTypeContext, ClassDefContext, Plugin
5
+ from mypy.plugins.common import add_method_to_class
6
+ from mypy.types import AnyType, Type, TypeOfAny
7
+
8
+ import pixeltable as pxt
9
+
10
+
11
+ class PxtPlugin(Plugin):
12
+ __UDA_FULLNAME = f'{pxt.uda.__module__}.{pxt.uda.__name__}'
13
+ __TYPE_MAP = {
14
+ pxt.Json: 'typing.Any',
15
+ pxt.Array: 'numpy.ndarray',
16
+ pxt.Image: 'PIL.Image.Image',
17
+ pxt.Video: 'builtins.str',
18
+ pxt.Audio: 'builtins.str',
19
+ pxt.Document: 'builtins.str',
20
+ }
21
+ __FULLNAME_MAP = {
22
+ f'{k.__module__}.{k.__name__}': v
23
+ for k, v in __TYPE_MAP.items()
24
+ }
25
+
26
+ def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
27
+ if fullname in self.__FULLNAME_MAP:
28
+ subst_name = self.__FULLNAME_MAP[fullname]
29
+ return lambda ctx: pxt_hook(ctx, subst_name)
30
+ return None
31
+
32
+ def get_class_decorator_hook_2(self, fullname: str) -> Optional[Callable[[ClassDefContext], bool]]:
33
+ if fullname == self.__UDA_FULLNAME:
34
+ return pxt_decorator_hook
35
+ return None
36
+
37
+ def plugin(version: str) -> type:
38
+ return PxtPlugin
39
+
40
+ def pxt_hook(ctx: AnalyzeTypeContext, subst_name: str) -> Type:
41
+ if subst_name == 'typing.Any':
42
+ return AnyType(TypeOfAny.special_form)
43
+ return ctx.api.named_type(subst_name, [])
44
+
45
+ def pxt_decorator_hook(ctx: ClassDefContext) -> bool:
46
+ arg = nodes.Argument(nodes.Var('fn'), AnyType(TypeOfAny.special_form), None, nodes.ARG_POS)
47
+ add_method_to_class(
48
+ ctx.api,
49
+ ctx.cls,
50
+ "to_sql",
51
+ args=[arg],
52
+ return_type=AnyType(TypeOfAny.special_form),
53
+ is_staticmethod=True,
54
+ )
55
+ return True