pixeltable-0.2.20-py3-none-any.whl → pixeltable-0.2.22-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic; see the registry's advisory page for this release for more details.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/column.py +37 -11
- pixeltable/catalog/globals.py +21 -0
- pixeltable/catalog/insertable_table.py +6 -4
- pixeltable/catalog/table.py +227 -148
- pixeltable/catalog/table_version.py +66 -28
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +18 -19
- pixeltable/dataframe.py +16 -32
- pixeltable/env.py +6 -1
- pixeltable/exec/__init__.py +1 -2
- pixeltable/exec/aggregation_node.py +27 -17
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/data_row_batch.py +9 -26
- pixeltable/exec/exec_node.py +36 -7
- pixeltable/exec/expr_eval_node.py +19 -11
- pixeltable/exec/in_memory_data_node.py +14 -11
- pixeltable/exec/sql_node.py +266 -138
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +3 -1
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +93 -14
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +56 -36
- pixeltable/exprs/expr.py +65 -63
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +26 -15
- pixeltable/exprs/function_call.py +53 -24
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +5 -10
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +14 -13
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +12 -6
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/aggregate_function.py +1 -1
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function.py +11 -10
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +84 -42
- pixeltable/functions/huggingface.py +31 -34
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/llama_cpp.py +106 -0
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +59 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +2 -2
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +65 -74
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +20 -7
- pixeltable/index/embedding_index.py +12 -14
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +98 -2
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +10 -8
- pixeltable/iterators/video.py +126 -60
- pixeltable/metadata/__init__.py +4 -3
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +54 -12
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +40 -21
- pixeltable/plan.py +149 -165
- pixeltable/py.typed +0 -0
- pixeltable/store.py +57 -37
- pixeltable/tool/create_test_db_dump.py +6 -6
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +55 -0
- pixeltable/type_system.py +260 -61
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +16 -2
- pixeltable/utils/filecache.py +9 -9
- pixeltable/utils/formatter.py +10 -11
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
- pixeltable-0.2.22.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.20.dist-info/RECORD +0 -147
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/store.py
CHANGED
|
@@ -7,18 +7,19 @@ import sys
|
|
|
7
7
|
import urllib.parse
|
|
8
8
|
import urllib.request
|
|
9
9
|
import warnings
|
|
10
|
-
from typing import
|
|
10
|
+
from typing import Any, Iterator, Literal, Optional, Union
|
|
11
11
|
|
|
12
12
|
import sqlalchemy as sql
|
|
13
|
-
from tqdm import
|
|
13
|
+
from tqdm import TqdmWarning, tqdm
|
|
14
14
|
|
|
15
15
|
import pixeltable.catalog as catalog
|
|
16
16
|
import pixeltable.env as env
|
|
17
|
+
import pixeltable.exceptions as excs
|
|
17
18
|
from pixeltable import exprs
|
|
18
19
|
from pixeltable.exec import ExecNode
|
|
19
20
|
from pixeltable.metadata import schema
|
|
20
21
|
from pixeltable.utils.media_store import MediaStore
|
|
21
|
-
from pixeltable.utils.sql import
|
|
22
|
+
from pixeltable.utils.sql import log_explain, log_stmt
|
|
22
23
|
|
|
23
24
|
_logger = logging.getLogger('pixeltable')
|
|
24
25
|
|
|
@@ -31,35 +32,42 @@ class StoreBase:
|
|
|
31
32
|
- v_min: version at which the row was created
|
|
32
33
|
- v_max: version at which the row was deleted (or MAX_VERSION if it's still live)
|
|
33
34
|
"""
|
|
35
|
+
tbl_version: catalog.TableVersion
|
|
36
|
+
sa_md: sql.MetaData
|
|
37
|
+
sa_tbl: Optional[sql.Table]
|
|
38
|
+
_pk_cols: list[sql.Column]
|
|
39
|
+
v_min_col: sql.Column
|
|
40
|
+
v_max_col: sql.Column
|
|
41
|
+
base: Optional[StoreBase]
|
|
34
42
|
|
|
35
43
|
__INSERT_BATCH_SIZE = 1000
|
|
36
44
|
|
|
37
45
|
def __init__(self, tbl_version: catalog.TableVersion):
|
|
38
46
|
self.tbl_version = tbl_version
|
|
39
47
|
self.sa_md = sql.MetaData()
|
|
40
|
-
self.sa_tbl
|
|
48
|
+
self.sa_tbl = None
|
|
41
49
|
# We need to declare a `base` variable here, even though it's only defined for instances of `StoreView`,
|
|
42
50
|
# since it's referenced by various methods of `StoreBase`
|
|
43
51
|
self.base = None if tbl_version.base is None else tbl_version.base.store_tbl
|
|
44
52
|
self.create_sa_tbl()
|
|
45
53
|
|
|
46
|
-
def pk_columns(self) ->
|
|
47
|
-
return self.
|
|
54
|
+
def pk_columns(self) -> list[sql.Column]:
|
|
55
|
+
return self._pk_cols
|
|
48
56
|
|
|
49
|
-
def rowid_columns(self) ->
|
|
50
|
-
return self.
|
|
57
|
+
def rowid_columns(self) -> list[sql.Column]:
|
|
58
|
+
return self._pk_cols[:-1]
|
|
51
59
|
|
|
52
60
|
@abc.abstractmethod
|
|
53
|
-
def _create_rowid_columns(self) ->
|
|
61
|
+
def _create_rowid_columns(self) -> list[sql.Column]:
|
|
54
62
|
"""Create and return rowid columns"""
|
|
55
63
|
|
|
56
|
-
def _create_system_columns(self) ->
|
|
64
|
+
def _create_system_columns(self) -> list[sql.Column]:
|
|
57
65
|
"""Create and return system columns"""
|
|
58
66
|
rowid_cols = self._create_rowid_columns()
|
|
59
67
|
self.v_min_col = sql.Column('v_min', sql.BigInteger, nullable=False)
|
|
60
68
|
self.v_max_col = \
|
|
61
69
|
sql.Column('v_max', sql.BigInteger, nullable=False, server_default=str(schema.Table.MAX_VERSION))
|
|
62
|
-
self.
|
|
70
|
+
self._pk_cols = [*rowid_cols, self.v_min_col]
|
|
63
71
|
return [*rowid_cols, self.v_min_col, self.v_max_col]
|
|
64
72
|
|
|
65
73
|
def create_sa_tbl(self) -> None:
|
|
@@ -79,7 +87,7 @@ class StoreBase:
|
|
|
79
87
|
# if we're called in response to a schema change, we need to remove the old table first
|
|
80
88
|
self.sa_md.remove(self.sa_tbl)
|
|
81
89
|
|
|
82
|
-
idxs:
|
|
90
|
+
idxs: list[sql.Index] = []
|
|
83
91
|
# index for all system columns:
|
|
84
92
|
# - base x view joins can be executed as merge joins
|
|
85
93
|
# - speeds up ORDER BY rowid DESC
|
|
@@ -126,7 +134,7 @@ class StoreBase:
|
|
|
126
134
|
return new_file_url
|
|
127
135
|
|
|
128
136
|
def _move_tmp_media_files(
|
|
129
|
-
self, table_rows:
|
|
137
|
+
self, table_rows: list[dict[str, Any]], media_cols: list[catalog.Column], v_min: int
|
|
130
138
|
) -> None:
|
|
131
139
|
"""Move tmp media files that we generated to a permanent location"""
|
|
132
140
|
for c in media_cols:
|
|
@@ -135,23 +143,17 @@ class StoreBase:
|
|
|
135
143
|
table_row[c.store_name()] = self._move_tmp_media_file(file_url, c, v_min)
|
|
136
144
|
|
|
137
145
|
def _create_table_row(
|
|
138
|
-
self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder,
|
|
139
|
-
|
|
140
|
-
) -> Tuple[Dict[str, Any], int]:
|
|
146
|
+
self, input_row: exprs.DataRow, row_builder: exprs.RowBuilder, exc_col_ids: set[int], pk: tuple[int, ...]
|
|
147
|
+
) -> tuple[dict[str, Any], int]:
|
|
141
148
|
"""Return Tuple[complete table row, # of exceptions] for insert()
|
|
142
149
|
Creates a row that includes the PK columns, with the values from input_row.pk.
|
|
143
150
|
Returns:
|
|
144
151
|
Tuple[complete table row, # of exceptions]
|
|
145
152
|
"""
|
|
146
153
|
table_row, num_excs = row_builder.create_table_row(input_row, exc_col_ids)
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
if pk_col == self.v_min_col:
|
|
151
|
-
table_row[pk_col.name] = v_min
|
|
152
|
-
else:
|
|
153
|
-
table_row[pk_col.name] = pk_val
|
|
154
|
-
|
|
154
|
+
assert len(pk) == len(self._pk_cols)
|
|
155
|
+
for pk_col, pk_val in zip(self._pk_cols, pk):
|
|
156
|
+
table_row[pk_col.name] = pk_val
|
|
155
157
|
return table_row, num_excs
|
|
156
158
|
|
|
157
159
|
def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
|
|
@@ -212,14 +214,20 @@ class StoreBase:
|
|
|
212
214
|
conn.execute(sql.text(stmt))
|
|
213
215
|
|
|
214
216
|
def load_column(
|
|
215
|
-
|
|
217
|
+
self,
|
|
218
|
+
col: catalog.Column,
|
|
219
|
+
exec_plan: ExecNode,
|
|
220
|
+
value_expr_slot_idx: int,
|
|
221
|
+
conn: sql.engine.Connection,
|
|
222
|
+
on_error: Literal['abort', 'ignore']
|
|
216
223
|
) -> int:
|
|
217
224
|
"""Update store column of a computed column with values produced by an execution plan
|
|
218
225
|
|
|
219
226
|
Returns:
|
|
220
227
|
number of rows with exceptions
|
|
221
228
|
Raises:
|
|
222
|
-
sql.exc.DBAPIError if there was
|
|
229
|
+
sql.exc.DBAPIError if there was a SQL error during execution
|
|
230
|
+
excs.Error if on_error='abort' and there was an exception during row evaluation
|
|
223
231
|
"""
|
|
224
232
|
num_excs = 0
|
|
225
233
|
num_rows = 0
|
|
@@ -253,6 +261,10 @@ class StoreBase:
|
|
|
253
261
|
if result_row.has_exc(value_expr_slot_idx):
|
|
254
262
|
num_excs += 1
|
|
255
263
|
value_exc = result_row.get_exc(value_expr_slot_idx)
|
|
264
|
+
if on_error == 'abort':
|
|
265
|
+
raise excs.Error(
|
|
266
|
+
f'Error while evaluating computed column `{col.name}`:\n{value_exc}'
|
|
267
|
+
) from value_exc
|
|
256
268
|
# we store a NULL value and record the exception/exc type
|
|
257
269
|
error_type = type(value_exc).__name__
|
|
258
270
|
error_msg = str(value_exc)
|
|
@@ -291,8 +303,8 @@ class StoreBase:
|
|
|
291
303
|
|
|
292
304
|
def insert_rows(
|
|
293
305
|
self, exec_plan: ExecNode, conn: sql.engine.Connection, v_min: Optional[int] = None,
|
|
294
|
-
show_progress: bool = True
|
|
295
|
-
) ->
|
|
306
|
+
show_progress: bool = True, rowids: Optional[Iterator[int]] = None, abort_on_exc: bool = False
|
|
307
|
+
) -> tuple[int, int, set[int]]:
|
|
296
308
|
"""Insert rows into the store table and update the catalog table's md
|
|
297
309
|
Returns:
|
|
298
310
|
number of inserted rows, number of exceptions, set of column ids that have exceptions
|
|
@@ -302,7 +314,7 @@ class StoreBase:
|
|
|
302
314
|
# TODO: total?
|
|
303
315
|
num_excs = 0
|
|
304
316
|
num_rows = 0
|
|
305
|
-
cols_with_excs:
|
|
317
|
+
cols_with_excs: set[int] = set()
|
|
306
318
|
progress_bar: Optional[tqdm] = None # create this only after we started executing
|
|
307
319
|
row_builder = exec_plan.row_builder
|
|
308
320
|
media_cols = [info.col for info in row_builder.table_columns if info.col.col_type.is_media_type()]
|
|
@@ -312,13 +324,21 @@ class StoreBase:
|
|
|
312
324
|
num_rows += len(row_batch)
|
|
313
325
|
for batch_start_idx in range(0, len(row_batch), self.__INSERT_BATCH_SIZE):
|
|
314
326
|
# compute batch of rows and convert them into table rows
|
|
315
|
-
table_rows:
|
|
316
|
-
|
|
327
|
+
table_rows: list[dict[str, Any]] = []
|
|
328
|
+
batch_stop_idx = min(batch_start_idx + self.__INSERT_BATCH_SIZE, len(row_batch))
|
|
329
|
+
for row_idx in range(batch_start_idx, batch_stop_idx):
|
|
317
330
|
row = row_batch[row_idx]
|
|
318
|
-
|
|
319
|
-
|
|
331
|
+
# if abort_on_exc == True, we need to check for media validation exceptions
|
|
332
|
+
if abort_on_exc and row.has_exc():
|
|
333
|
+
exc = row.get_first_exc()
|
|
334
|
+
raise exc
|
|
335
|
+
|
|
336
|
+
rowid = (next(rowids),) if rowids is not None else row.pk[:-1]
|
|
337
|
+
pk = rowid + (v_min,)
|
|
338
|
+
table_row, num_row_exc = self._create_table_row(row, row_builder, cols_with_excs, pk=pk)
|
|
320
339
|
num_excs += num_row_exc
|
|
321
340
|
table_rows.append(table_row)
|
|
341
|
+
|
|
322
342
|
if show_progress:
|
|
323
343
|
if progress_bar is None:
|
|
324
344
|
warnings.simplefilter("ignore", category=TqdmWarning)
|
|
@@ -353,7 +373,7 @@ class StoreBase:
|
|
|
353
373
|
return sql.and_(clause, self.base._versions_clause(versions[1:], match_on_vmin))
|
|
354
374
|
|
|
355
375
|
def delete_rows(
|
|
356
|
-
self, current_version: int, base_versions:
|
|
376
|
+
self, current_version: int, base_versions: list[Optional[int]], match_on_vmin: bool,
|
|
357
377
|
where_clause: Optional[sql.ColumnElement[bool]], conn: sql.engine.Connection) -> int:
|
|
358
378
|
"""Mark rows as deleted that are live and were created prior to current_version.
|
|
359
379
|
Also: populate the undo columns
|
|
@@ -397,7 +417,7 @@ class StoreTable(StoreBase):
|
|
|
397
417
|
assert not tbl_version.is_view()
|
|
398
418
|
super().__init__(tbl_version)
|
|
399
419
|
|
|
400
|
-
def _create_rowid_columns(self) ->
|
|
420
|
+
def _create_rowid_columns(self) -> list[sql.Column]:
|
|
401
421
|
self.rowid_col = sql.Column('rowid', sql.BigInteger, nullable=False)
|
|
402
422
|
return [self.rowid_col]
|
|
403
423
|
|
|
@@ -413,7 +433,7 @@ class StoreView(StoreBase):
|
|
|
413
433
|
assert catalog_view.is_view()
|
|
414
434
|
super().__init__(catalog_view)
|
|
415
435
|
|
|
416
|
-
def _create_rowid_columns(self) ->
|
|
436
|
+
def _create_rowid_columns(self) -> list[sql.Column]:
|
|
417
437
|
# a view row corresponds directly to a single base row, which means it needs to duplicate its rowid columns
|
|
418
438
|
self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
|
|
419
439
|
return self.rowid_cols
|
|
@@ -439,7 +459,7 @@ class StoreComponentView(StoreView):
|
|
|
439
459
|
def __init__(self, catalog_view: catalog.TableVersion):
|
|
440
460
|
super().__init__(catalog_view)
|
|
441
461
|
|
|
442
|
-
def _create_rowid_columns(self) ->
|
|
462
|
+
def _create_rowid_columns(self) -> list[sql.Column]:
|
|
443
463
|
# each base row is expanded into n view rows
|
|
444
464
|
self.rowid_cols = [sql.Column(c.name, c.type) for c in self.base.rowid_columns()]
|
|
445
465
|
# name of pos column: avoid collisions with bases' pos columns
|
|
@@ -149,18 +149,18 @@ class Dumper:
|
|
|
149
149
|
pxt.create_dir('views')
|
|
150
150
|
|
|
151
151
|
# simple view
|
|
152
|
-
v = pxt.create_view('views.view', t
|
|
152
|
+
v = pxt.create_view('views.view', t.where(t.c2 < 50))
|
|
153
153
|
self.__add_expr_columns(v, 'view')
|
|
154
154
|
|
|
155
155
|
# snapshot
|
|
156
|
-
_ = pxt.create_view('views.snapshot', t
|
|
156
|
+
_ = pxt.create_view('views.snapshot', t.where(t.c2 >= 75), is_snapshot=True)
|
|
157
157
|
|
|
158
158
|
# view of views
|
|
159
|
-
vv = pxt.create_view('views.view_of_views', v
|
|
159
|
+
vv = pxt.create_view('views.view_of_views', v.where(t.c2 >= 25))
|
|
160
160
|
self.__add_expr_columns(vv, 'view_of_views')
|
|
161
161
|
|
|
162
162
|
# empty view
|
|
163
|
-
e = pxt.create_view('views.empty_view', t
|
|
163
|
+
e = pxt.create_view('views.empty_view', t.where(t.c2 == 4171780))
|
|
164
164
|
assert e.count() == 0
|
|
165
165
|
self.__add_expr_columns(e, 'empty_view', include_expensive_functions=True)
|
|
166
166
|
|
|
@@ -278,13 +278,13 @@ class Dumper:
|
|
|
278
278
|
# this breaks; TODO: why?
|
|
279
279
|
#return t.where(t.c2 < i)
|
|
280
280
|
return t.where(t.c2 < i).select(t.c1, t.c2)
|
|
281
|
-
add_column('query_output', t.q1(t.c2))
|
|
281
|
+
add_column('query_output', t.queries.q1(t.c2))
|
|
282
282
|
|
|
283
283
|
@t.query
|
|
284
284
|
def q2(s: str):
|
|
285
285
|
sim = t[f'{col_prefix}_function_call'].similarity(s)
|
|
286
286
|
return t.order_by(sim, asc=False).select(t[f'{col_prefix}_function_call']).limit(5)
|
|
287
|
-
add_column('sim_output', t.q2(t.c1))
|
|
287
|
+
add_column('sim_output', t.queries.q2(t.c1))
|
|
288
288
|
|
|
289
289
|
|
|
290
290
|
@pxt.udf(_force_stored=True)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import ast
|
|
2
|
-
from typing import Optional, Union
|
|
3
2
|
import warnings
|
|
3
|
+
from typing import Optional, Union
|
|
4
4
|
|
|
5
5
|
import griffe
|
|
6
6
|
import griffe.expressions
|
|
@@ -39,7 +39,7 @@ class PxtGriffeExtension(Extension):
|
|
|
39
39
|
udf = griffe.dynamic_import(func.path)
|
|
40
40
|
assert isinstance(udf, pxt.Function)
|
|
41
41
|
# Convert the return type to a Pixeltable type reference
|
|
42
|
-
func.returns =
|
|
42
|
+
func.returns = str(udf.signature.get_return_type())
|
|
43
43
|
# Convert the parameter types to Pixeltable type references
|
|
44
44
|
for griffe_param in func.parameters:
|
|
45
45
|
assert isinstance(griffe_param.annotation, griffe.expressions.Expr)
|
|
@@ -47,35 +47,4 @@ class PxtGriffeExtension(Extension):
|
|
|
47
47
|
logger.warning(f'Parameter `{griffe_param.name}` not found in signature for UDF: {udf.display_name}')
|
|
48
48
|
continue
|
|
49
49
|
pxt_param = udf.signature.parameters[griffe_param.name]
|
|
50
|
-
griffe_param.annotation =
|
|
51
|
-
|
|
52
|
-
def __column_type_to_display_str(self, column_type: Optional[pxt.ColumnType]) -> str:
|
|
53
|
-
# TODO: When we enhance the Pixeltable type system, we may want to refactor some of this logic out.
|
|
54
|
-
# I'm putting it here for now though.
|
|
55
|
-
if column_type is None:
|
|
56
|
-
return 'None'
|
|
57
|
-
if column_type.is_string_type():
|
|
58
|
-
base = 'str'
|
|
59
|
-
elif column_type.is_int_type():
|
|
60
|
-
base = 'int'
|
|
61
|
-
elif column_type.is_float_type():
|
|
62
|
-
base = 'float'
|
|
63
|
-
elif column_type.is_bool_type():
|
|
64
|
-
base = 'bool'
|
|
65
|
-
elif column_type.is_timestamp_type():
|
|
66
|
-
base = 'datetime'
|
|
67
|
-
elif column_type.is_array_type():
|
|
68
|
-
base = 'ArrayT'
|
|
69
|
-
elif column_type.is_json_type():
|
|
70
|
-
base = 'JsonT'
|
|
71
|
-
elif column_type.is_image_type():
|
|
72
|
-
base = 'ImageT'
|
|
73
|
-
elif column_type.is_video_type():
|
|
74
|
-
base = 'VideoT'
|
|
75
|
-
elif column_type.is_audio_type():
|
|
76
|
-
base = 'AudioT'
|
|
77
|
-
elif column_type.is_document_type():
|
|
78
|
-
base = 'DocumentT'
|
|
79
|
-
else:
|
|
80
|
-
assert False
|
|
81
|
-
return f'Optional[{base}]' if column_type.nullable else base
|
|
50
|
+
griffe_param.annotation = str(pxt_param.col_type)
|
pixeltable/tool/embed_udf.py
CHANGED
|
@@ -6,4 +6,4 @@ import pixeltable as pxt
|
|
|
6
6
|
# TODO This can go away once we have the ability to inline expr_udf's
|
|
7
7
|
@pxt.expr_udf
|
|
8
8
|
def clip_text_embed(txt: str) -> np.ndarray:
|
|
9
|
-
return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32')
|
|
9
|
+
return pxt.functions.huggingface.clip_text(txt, model_id='openai/clip-vit-base-patch32') # type: ignore[return-value]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from typing import Callable, Optional
|
|
2
|
+
|
|
3
|
+
from mypy import nodes
|
|
4
|
+
from mypy.plugin import AnalyzeTypeContext, ClassDefContext, Plugin
|
|
5
|
+
from mypy.plugins.common import add_method_to_class
|
|
6
|
+
from mypy.types import AnyType, Type, TypeOfAny
|
|
7
|
+
|
|
8
|
+
import pixeltable as pxt
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PxtPlugin(Plugin):
|
|
12
|
+
__UDA_FULLNAME = f'{pxt.uda.__module__}.{pxt.uda.__name__}'
|
|
13
|
+
__TYPE_MAP = {
|
|
14
|
+
pxt.Json: 'typing.Any',
|
|
15
|
+
pxt.Array: 'numpy.ndarray',
|
|
16
|
+
pxt.Image: 'PIL.Image.Image',
|
|
17
|
+
pxt.Video: 'builtins.str',
|
|
18
|
+
pxt.Audio: 'builtins.str',
|
|
19
|
+
pxt.Document: 'builtins.str',
|
|
20
|
+
}
|
|
21
|
+
__FULLNAME_MAP = {
|
|
22
|
+
f'{k.__module__}.{k.__name__}': v
|
|
23
|
+
for k, v in __TYPE_MAP.items()
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
def get_type_analyze_hook(self, fullname: str) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
|
|
27
|
+
if fullname in self.__FULLNAME_MAP:
|
|
28
|
+
subst_name = self.__FULLNAME_MAP[fullname]
|
|
29
|
+
return lambda ctx: pxt_hook(ctx, subst_name)
|
|
30
|
+
return None
|
|
31
|
+
|
|
32
|
+
def get_class_decorator_hook_2(self, fullname: str) -> Optional[Callable[[ClassDefContext], bool]]:
|
|
33
|
+
if fullname == self.__UDA_FULLNAME:
|
|
34
|
+
return pxt_decorator_hook
|
|
35
|
+
return None
|
|
36
|
+
|
|
37
|
+
def plugin(version: str) -> type:
|
|
38
|
+
return PxtPlugin
|
|
39
|
+
|
|
40
|
+
def pxt_hook(ctx: AnalyzeTypeContext, subst_name: str) -> Type:
|
|
41
|
+
if subst_name == 'typing.Any':
|
|
42
|
+
return AnyType(TypeOfAny.special_form)
|
|
43
|
+
return ctx.api.named_type(subst_name, [])
|
|
44
|
+
|
|
45
|
+
def pxt_decorator_hook(ctx: ClassDefContext) -> bool:
|
|
46
|
+
arg = nodes.Argument(nodes.Var('fn'), AnyType(TypeOfAny.special_form), None, nodes.ARG_POS)
|
|
47
|
+
add_method_to_class(
|
|
48
|
+
ctx.api,
|
|
49
|
+
ctx.cls,
|
|
50
|
+
"to_sql",
|
|
51
|
+
args=[arg],
|
|
52
|
+
return_type=AnyType(TypeOfAny.special_form),
|
|
53
|
+
is_staticmethod=True,
|
|
54
|
+
)
|
|
55
|
+
return True
|