pixeltable 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +3 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/column.py +8 -2
- pixeltable/catalog/insertable_table.py +32 -17
- pixeltable/catalog/table.py +167 -12
- pixeltable/catalog/table_version.py +185 -106
- pixeltable/datatransfer/__init__.py +1 -0
- pixeltable/datatransfer/label_studio.py +452 -0
- pixeltable/datatransfer/remote.py +85 -0
- pixeltable/env.py +148 -69
- pixeltable/exprs/column_ref.py +2 -2
- pixeltable/exprs/comparison.py +39 -1
- pixeltable/exprs/data_row.py +7 -0
- pixeltable/exprs/expr.py +11 -12
- pixeltable/exprs/function_call.py +0 -3
- pixeltable/exprs/globals.py +14 -2
- pixeltable/exprs/similarity_expr.py +5 -3
- pixeltable/ext/functions/whisperx.py +30 -0
- pixeltable/ext/functions/yolox.py +16 -0
- pixeltable/func/aggregate_function.py +2 -2
- pixeltable/func/expr_template_function.py +3 -1
- pixeltable/func/udf.py +2 -2
- pixeltable/functions/fireworks.py +9 -4
- pixeltable/functions/huggingface.py +25 -1
- pixeltable/functions/openai.py +15 -10
- pixeltable/functions/together.py +11 -6
- pixeltable/functions/util.py +0 -43
- pixeltable/functions/video.py +46 -8
- pixeltable/globals.py +20 -2
- pixeltable/index/__init__.py +1 -0
- pixeltable/index/base.py +6 -1
- pixeltable/index/btree.py +54 -0
- pixeltable/index/embedding_index.py +4 -1
- pixeltable/io/__init__.py +1 -0
- pixeltable/io/globals.py +58 -0
- pixeltable/iterators/base.py +4 -4
- pixeltable/iterators/document.py +26 -15
- pixeltable/iterators/video.py +9 -1
- pixeltable/metadata/__init__.py +2 -2
- pixeltable/metadata/converters/convert_14.py +13 -0
- pixeltable/metadata/schema.py +9 -6
- pixeltable/plan.py +9 -5
- pixeltable/store.py +14 -21
- pixeltable/tool/create_test_db_dump.py +14 -0
- pixeltable/type_system.py +14 -4
- pixeltable/utils/coco.py +94 -0
- pixeltable-0.2.8.dist-info/METADATA +137 -0
- {pixeltable-0.2.6.dist-info → pixeltable-0.2.8.dist-info}/RECORD +50 -45
- pixeltable/func/nos_function.py +0 -202
- pixeltable/utils/clip.py +0 -18
- pixeltable-0.2.6.dist-info/METADATA +0 -131
- {pixeltable-0.2.6.dist-info → pixeltable-0.2.8.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.6.dist-info → pixeltable-0.2.8.dist-info}/WHEEL +0 -0
pixeltable/env.py
CHANGED
|
@@ -5,20 +5,20 @@ import glob
|
|
|
5
5
|
import http.server
|
|
6
6
|
import importlib
|
|
7
7
|
import importlib.util
|
|
8
|
+
import inspect
|
|
8
9
|
import logging
|
|
9
10
|
import os
|
|
10
|
-
import socketserver
|
|
11
11
|
import sys
|
|
12
12
|
import threading
|
|
13
13
|
import uuid
|
|
14
14
|
import warnings
|
|
15
|
+
from dataclasses import dataclass
|
|
15
16
|
from pathlib import Path
|
|
16
17
|
from typing import Callable, Optional, Dict, Any, List
|
|
17
18
|
|
|
18
19
|
import pgserver
|
|
19
20
|
import sqlalchemy as sql
|
|
20
21
|
import yaml
|
|
21
|
-
from sqlalchemy_utils.functions import database_exists, create_database, drop_database
|
|
22
22
|
from tqdm import TqdmWarning
|
|
23
23
|
|
|
24
24
|
import pixeltable.exceptions as excs
|
|
@@ -59,12 +59,11 @@ class Env:
|
|
|
59
59
|
# info about installed packages that are utilized by some parts of the code;
|
|
60
60
|
# package name -> version; version == []: package is installed, but we haven't determined the version yet
|
|
61
61
|
self._installed_packages: Dict[str, Optional[List[int]]] = {}
|
|
62
|
-
self._nos_client: Optional[Any] = None
|
|
63
62
|
self._spacy_nlp: Optional[Any] = None # spacy.Language
|
|
64
|
-
self._httpd: Optional[http.server.
|
|
63
|
+
self._httpd: Optional[http.server.HTTPServer] = None
|
|
65
64
|
self._http_address: Optional[str] = None
|
|
66
65
|
|
|
67
|
-
self._registered_clients: dict[str,
|
|
66
|
+
self._registered_clients: dict[str, ApiClient] = {}
|
|
68
67
|
|
|
69
68
|
# logging-related state
|
|
70
69
|
self._logger = logging.getLogger('pixeltable')
|
|
@@ -120,8 +119,8 @@ class Env:
|
|
|
120
119
|
if level is not None:
|
|
121
120
|
self.set_log_level(level)
|
|
122
121
|
if add is not None:
|
|
123
|
-
for module,
|
|
124
|
-
self.set_module_log_level(module, int(
|
|
122
|
+
for module, level_str in [t.split(':') for t in add.split(',')]:
|
|
123
|
+
self.set_module_log_level(module, int(level_str))
|
|
125
124
|
if remove is not None:
|
|
126
125
|
for module in remove.split(','):
|
|
127
126
|
self.set_module_log_level(module, None)
|
|
@@ -263,24 +262,19 @@ class Env:
|
|
|
263
262
|
self._db_url = self._db_server.get_uri(database=self._db_name)
|
|
264
263
|
|
|
265
264
|
if reinit_db:
|
|
266
|
-
if
|
|
267
|
-
|
|
265
|
+
if self._store_db_exists():
|
|
266
|
+
self._drop_store_db()
|
|
268
267
|
|
|
269
|
-
if not
|
|
268
|
+
if not self._store_db_exists():
|
|
270
269
|
self._logger.info(f'creating database at {self.db_url}')
|
|
271
|
-
|
|
272
|
-
self.
|
|
270
|
+
self._create_store_db()
|
|
271
|
+
self._create_engine(echo=echo)
|
|
273
272
|
from pixeltable.metadata import schema
|
|
274
|
-
|
|
275
273
|
schema.Base.metadata.create_all(self._sa_engine)
|
|
276
274
|
metadata.create_system_info(self._sa_engine)
|
|
277
|
-
# enable pgvector
|
|
278
|
-
with self._sa_engine.begin() as conn:
|
|
279
|
-
conn.execute(sql.text('CREATE EXTENSION vector'))
|
|
280
275
|
else:
|
|
281
276
|
self._logger.info(f'found database {self.db_url}')
|
|
282
|
-
|
|
283
|
-
self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
|
|
277
|
+
self._create_engine(echo=echo)
|
|
284
278
|
|
|
285
279
|
print(f'Connected to Pixeltable database at: {self.db_url}')
|
|
286
280
|
|
|
@@ -288,57 +282,110 @@ class Env:
|
|
|
288
282
|
self._set_up_runtime()
|
|
289
283
|
self.log_to_stdout(False)
|
|
290
284
|
|
|
291
|
-
def
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
def _create_nos_client(self) -> None:
|
|
295
|
-
import nos
|
|
285
|
+
def _create_engine(self, echo: bool = False) -> None:
|
|
286
|
+
self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True, isolation_level='AUTOCOMMIT')
|
|
296
287
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
self.
|
|
300
|
-
self.
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
288
|
+
def _store_db_exists(self) -> bool:
|
|
289
|
+
assert self._db_name is not None
|
|
290
|
+
# don't try to connect to self.db_name, it may not exist
|
|
291
|
+
db_url = self._db_server.get_uri(database='postgres')
|
|
292
|
+
engine = sql.create_engine(db_url, future=True)
|
|
293
|
+
try:
|
|
294
|
+
with engine.begin() as conn:
|
|
295
|
+
stmt = f"SELECT COUNT(*) FROM pg_database WHERE datname = '{self._db_name}'"
|
|
296
|
+
result = conn.scalar(sql.text(stmt))
|
|
297
|
+
assert result <= 1
|
|
298
|
+
return result == 1
|
|
299
|
+
finally:
|
|
300
|
+
engine.dispose()
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _create_store_db(self) -> None:
|
|
304
|
+
assert self._db_name is not None
|
|
305
|
+
# create the db
|
|
306
|
+
pg_db_url = self._db_server.get_uri(database='postgres')
|
|
307
|
+
engine = sql.create_engine(pg_db_url, future=True, isolation_level='AUTOCOMMIT')
|
|
308
|
+
preparer = engine.dialect.identifier_preparer
|
|
309
|
+
try:
|
|
310
|
+
with engine.begin() as conn:
|
|
311
|
+
# use C collation to get standard C/Python-style sorting
|
|
312
|
+
stmt = (
|
|
313
|
+
f"CREATE DATABASE {preparer.quote(self._db_name)} "
|
|
314
|
+
"ENCODING 'utf-8' LC_COLLATE 'C' LC_CTYPE 'C' TEMPLATE template0"
|
|
315
|
+
)
|
|
316
|
+
conn.execute(sql.text(stmt))
|
|
317
|
+
finally:
|
|
318
|
+
engine.dispose()
|
|
305
319
|
|
|
320
|
+
# enable pgvector
|
|
321
|
+
store_db_url = self._db_server.get_uri(database=self._db_name)
|
|
322
|
+
engine = sql.create_engine(store_db_url, future=True, isolation_level='AUTOCOMMIT')
|
|
306
323
|
try:
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
324
|
+
with engine.begin() as conn:
|
|
325
|
+
conn.execute(sql.text('CREATE EXTENSION vector'))
|
|
326
|
+
finally:
|
|
327
|
+
engine.dispose()
|
|
328
|
+
|
|
329
|
+
def _drop_store_db(self) -> None:
|
|
330
|
+
assert self._db_name is not None
|
|
331
|
+
db_url = self._db_server.get_uri(database='postgres')
|
|
332
|
+
engine = sql.create_engine(db_url, future=True, isolation_level='AUTOCOMMIT')
|
|
333
|
+
preparer = engine.dialect.identifier_preparer
|
|
334
|
+
try:
|
|
335
|
+
with engine.begin() as conn:
|
|
336
|
+
# terminate active connections
|
|
337
|
+
stmt = (f"""
|
|
338
|
+
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
|
339
|
+
FROM pg_stat_activity
|
|
340
|
+
WHERE pg_stat_activity.datname = '{self._db_name}'
|
|
341
|
+
AND pid <> pg_backend_pid()
|
|
342
|
+
""")
|
|
343
|
+
conn.execute(sql.text(stmt))
|
|
344
|
+
# drop db
|
|
345
|
+
stmt = f'DROP DATABASE {preparer.quote(self._db_name)}'
|
|
346
|
+
conn.execute(sql.text(stmt))
|
|
347
|
+
finally:
|
|
348
|
+
engine.dispose()
|
|
313
349
|
|
|
314
|
-
|
|
350
|
+
def _upgrade_metadata(self) -> None:
|
|
351
|
+
metadata.upgrade_md(self._sa_engine)
|
|
315
352
|
|
|
316
|
-
def
|
|
317
|
-
|
|
318
|
-
|
|
353
|
+
def _register_client(self, name: str, init_fn: Callable) -> None:
|
|
354
|
+
sig = inspect.signature(init_fn)
|
|
355
|
+
param_names = list(sig.parameters.keys())
|
|
356
|
+
self._registered_clients[name] = ApiClient(init_fn=init_fn, param_names=param_names)
|
|
319
357
|
|
|
320
|
-
|
|
321
|
-
- init: A `Callable` with signature `fn(api_key: str) -> Any` that constructs a client object
|
|
322
|
-
- environ: The name of the environment variable to use for the API key, if no API key is found in config
|
|
323
|
-
(defaults to f'{name.upper()}_API_KEY')
|
|
358
|
+
def get_client(self, name: str) -> Any:
|
|
324
359
|
"""
|
|
325
|
-
|
|
326
|
-
return self._registered_clients[name]
|
|
360
|
+
Gets the client with the specified name, initializing it if necessary.
|
|
327
361
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
if
|
|
336
|
-
|
|
362
|
+
Args:
|
|
363
|
+
- name: The name of the client
|
|
364
|
+
"""
|
|
365
|
+
cl = self._registered_clients[name]
|
|
366
|
+
if cl.client_obj is not None:
|
|
367
|
+
return cl.client_obj # Already initialized
|
|
368
|
+
|
|
369
|
+
# Construct a client. For each client parameter, first check if the parameter is in the environment;
|
|
370
|
+
# if not, look in Pixeltable config from `config.yaml`.
|
|
371
|
+
|
|
372
|
+
init_kwargs: dict[str, str] = {}
|
|
373
|
+
for param in cl.param_names:
|
|
374
|
+
environ = f'{name.upper()}_{param.upper()}'
|
|
375
|
+
if environ in os.environ:
|
|
376
|
+
init_kwargs[param] = os.environ[environ]
|
|
377
|
+
elif name.lower() in self._config and param in self._config[name.lower()]:
|
|
378
|
+
init_kwargs[param] = self._config[name.lower()][param.lower()]
|
|
379
|
+
if param not in init_kwargs or init_kwargs[param] == '':
|
|
380
|
+
raise excs.Error(
|
|
381
|
+
f'`{name}` client not initialized: parameter `{param}` is not configured.\n'
|
|
382
|
+
f'To fix this, specify the `{environ}` environment variable, or put `{param.lower()}` in '
|
|
383
|
+
f'the `{name.lower()}` section of $PIXELTABLE_HOME/config.yaml.'
|
|
384
|
+
)
|
|
337
385
|
|
|
338
|
-
|
|
339
|
-
self._registered_clients[name] = client
|
|
386
|
+
cl.client_obj = cl.init_fn(**init_kwargs)
|
|
340
387
|
self._logger.info(f'Initialized `{name}` client.')
|
|
341
|
-
return
|
|
388
|
+
return cl.client_obj
|
|
342
389
|
|
|
343
390
|
def _start_web_server(self) -> None:
|
|
344
391
|
"""
|
|
@@ -380,6 +427,7 @@ class Env:
|
|
|
380
427
|
check('transformers')
|
|
381
428
|
check('sentence_transformers')
|
|
382
429
|
check('yolox')
|
|
430
|
+
check('whisperx')
|
|
383
431
|
check('boto3')
|
|
384
432
|
check('fitz') # pymupdf
|
|
385
433
|
check('pyarrow')
|
|
@@ -392,9 +440,7 @@ class Env:
|
|
|
392
440
|
check('openai')
|
|
393
441
|
check('together')
|
|
394
442
|
check('fireworks')
|
|
395
|
-
check('
|
|
396
|
-
if self.is_installed_package('nos'):
|
|
397
|
-
self._create_nos_client()
|
|
443
|
+
check('label_studio_sdk')
|
|
398
444
|
check('openpyxl')
|
|
399
445
|
|
|
400
446
|
def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
|
|
@@ -405,7 +451,7 @@ class Env:
|
|
|
405
451
|
return
|
|
406
452
|
|
|
407
453
|
# check whether we have a version >= the required one
|
|
408
|
-
if self._installed_packages[package]
|
|
454
|
+
if not self._installed_packages[package]:
|
|
409
455
|
m = importlib.import_module(package)
|
|
410
456
|
module_version = [int(x) for x in m.__version__.split('.')]
|
|
411
457
|
self._installed_packages[package] = module_version
|
|
@@ -415,8 +461,8 @@ class Env:
|
|
|
415
461
|
if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
|
|
416
462
|
raise excs.Error(
|
|
417
463
|
(
|
|
418
|
-
f'The installed version of package {package} is {".".join(
|
|
419
|
-
f'but version >={".".join(
|
|
464
|
+
f'The installed version of package {package} is {".".join(str(v) for v in installed_version)}, '
|
|
465
|
+
f'but version >={".".join(str(v) for v in min_version)} is required'
|
|
420
466
|
)
|
|
421
467
|
)
|
|
422
468
|
|
|
@@ -456,11 +502,44 @@ class Env:
|
|
|
456
502
|
assert self._sa_engine is not None
|
|
457
503
|
return self._sa_engine
|
|
458
504
|
|
|
459
|
-
@property
|
|
460
|
-
def nos_client(self) -> Any:
|
|
461
|
-
return self._nos_client
|
|
462
|
-
|
|
463
505
|
@property
|
|
464
506
|
def spacy_nlp(self) -> Any:
|
|
465
507
|
assert self._spacy_nlp is not None
|
|
466
508
|
return self._spacy_nlp
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def register_client(name: str) -> Callable:
|
|
512
|
+
"""Decorator that registers a third-party API client for use by Pixeltable.
|
|
513
|
+
|
|
514
|
+
The decorated function is an initialization wrapper for the client, and can have
|
|
515
|
+
any number of string parameters, with a signature such as:
|
|
516
|
+
|
|
517
|
+
```
|
|
518
|
+
def my_client(api_key: str, url: str) -> my_client_sdk.Client:
|
|
519
|
+
return my_client_sdk.Client(api_key=api_key, url=url)
|
|
520
|
+
```
|
|
521
|
+
|
|
522
|
+
The initialization wrapper will not be called immediately; initialization will
|
|
523
|
+
be deferred until the first time the client is used. At initialization time,
|
|
524
|
+
Pixeltable will attempt to load the client parameters from config. For each
|
|
525
|
+
config parameter:
|
|
526
|
+
- If an environment variable named MY_CLIENT_API_KEY (for example) is set, use it;
|
|
527
|
+
- Otherwise, look for 'api_key' in the 'my_client' section of config.yaml.
|
|
528
|
+
|
|
529
|
+
If all config parameters are found, Pixeltable calls the initialization function;
|
|
530
|
+
otherwise it throws an exception.
|
|
531
|
+
|
|
532
|
+
Args:
|
|
533
|
+
- name (str): The name of the API client (e.g., 'openai' or 'label-studio').
|
|
534
|
+
"""
|
|
535
|
+
def decorator(fn: Callable) -> None:
|
|
536
|
+
Env.get()._register_client(name, fn)
|
|
537
|
+
|
|
538
|
+
return decorator
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
@dataclass
|
|
542
|
+
class ApiClient:
|
|
543
|
+
init_fn: Callable
|
|
544
|
+
param_names: list[str]
|
|
545
|
+
client_obj: Optional[Any] = None
|
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -108,7 +108,7 @@ class ColumnRef(Expr):
|
|
|
108
108
|
def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
|
|
109
109
|
tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
|
|
110
110
|
tbl_version = catalog.Catalog.get().tbl_versions[(tbl_id, version)]
|
|
111
|
-
|
|
112
|
-
col = tbl_version.
|
|
111
|
+
# don't use tbl_version.cols_by_id here, this might be a snapshot reference to a column that was then dropped
|
|
112
|
+
col = next(col for col in tbl_version.cols if col.id == col_id)
|
|
113
113
|
return cls(col)
|
|
114
114
|
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -4,18 +4,44 @@ from typing import Optional, List, Any, Dict, Tuple
|
|
|
4
4
|
|
|
5
5
|
import sqlalchemy as sql
|
|
6
6
|
|
|
7
|
+
from .column_ref import ColumnRef
|
|
7
8
|
from .data_row import DataRow
|
|
8
9
|
from .expr import Expr
|
|
9
10
|
from .globals import ComparisonOperator
|
|
11
|
+
from .literal import Literal
|
|
10
12
|
from .predicate import Predicate
|
|
11
13
|
from .row_builder import RowBuilder
|
|
14
|
+
import pixeltable.exceptions as excs
|
|
15
|
+
import pixeltable.index as index
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class Comparison(Predicate):
|
|
15
19
|
def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
|
|
16
20
|
super().__init__()
|
|
17
21
|
self.operator = operator
|
|
18
|
-
|
|
22
|
+
|
|
23
|
+
# if this is a comparison of a column to a literal (ie, could be used as a search argument in an index lookup),
|
|
24
|
+
# normalize it to <column> <operator> <literal>.
|
|
25
|
+
if isinstance(op1, ColumnRef) and isinstance(op2, Literal):
|
|
26
|
+
self.is_search_arg_comparison = True
|
|
27
|
+
self.components = [op1, op2]
|
|
28
|
+
elif isinstance(op1, Literal) and isinstance(op2, ColumnRef):
|
|
29
|
+
self.is_search_arg_comparison = True
|
|
30
|
+
self.components = [op2, op1]
|
|
31
|
+
self.operator = self.operator.reverse()
|
|
32
|
+
else:
|
|
33
|
+
self.is_search_arg_comparison = False
|
|
34
|
+
self.components = [op1, op2]
|
|
35
|
+
|
|
36
|
+
import pixeltable.index as index
|
|
37
|
+
if self.is_search_arg_comparison and self._op2.col_type.is_string_type() \
|
|
38
|
+
and len(self._op2.val) >= index.BtreeIndex.MAX_STRING_LEN:
|
|
39
|
+
# we can't use an index for this after all
|
|
40
|
+
raise excs.Error(
|
|
41
|
+
f'String literal too long for comparison against indexed column {self._op1.col.name!r} '
|
|
42
|
+
f'(max length is {index.BtreeIndex.MAX_STRING_LEN - 1})'
|
|
43
|
+
)
|
|
44
|
+
|
|
19
45
|
self.id = self._create_id()
|
|
20
46
|
|
|
21
47
|
def __str__(self) -> str:
|
|
@@ -37,6 +63,18 @@ class Comparison(Predicate):
|
|
|
37
63
|
|
|
38
64
|
def sql_expr(self) -> Optional[sql.ClauseElement]:
|
|
39
65
|
left = self._op1.sql_expr()
|
|
66
|
+
if self.is_search_arg_comparison:
|
|
67
|
+
# reference the index value column if there is an index and this is not a snapshot
|
|
68
|
+
# (indices don't apply to snapshots)
|
|
69
|
+
tbl = self._op1.col.tbl
|
|
70
|
+
idx_info = [
|
|
71
|
+
info for info in self._op1.col.get_idx_info().values() if isinstance(info.idx, index.BtreeIndex)
|
|
72
|
+
]
|
|
73
|
+
if len(idx_info) > 0 and not tbl.is_snapshot:
|
|
74
|
+
# there shouldn't be multiple B-tree indices on a column
|
|
75
|
+
assert len(idx_info) == 1
|
|
76
|
+
left = idx_info[0].val_col.sa_col
|
|
77
|
+
|
|
40
78
|
right = self._op2.sql_expr()
|
|
41
79
|
if left is None or right is None:
|
|
42
80
|
return None
|
pixeltable/exprs/data_row.py
CHANGED
pixeltable/exprs/expr.py
CHANGED
|
@@ -169,16 +169,22 @@ class Expr(abc.ABC):
|
|
|
169
169
|
memo[id(self)] = result
|
|
170
170
|
return result
|
|
171
171
|
|
|
172
|
-
def substitute(self,
|
|
172
|
+
def substitute(self, spec: dict[Expr, Expr]) -> Expr:
|
|
173
173
|
"""
|
|
174
174
|
Replace 'old' with 'new' recursively.
|
|
175
175
|
"""
|
|
176
|
-
|
|
177
|
-
|
|
176
|
+
for old, new in spec.items():
|
|
177
|
+
if self.equals(old):
|
|
178
|
+
return new.copy()
|
|
178
179
|
for i in range(len(self.components)):
|
|
179
|
-
self.components[i] = self.components[i].substitute(
|
|
180
|
+
self.components[i] = self.components[i].substitute(spec)
|
|
180
181
|
return self
|
|
181
182
|
|
|
183
|
+
@classmethod
|
|
184
|
+
def list_substitute(cls, expr_list: List[Expr], spec: dict[Expr, Expr]) -> None:
|
|
185
|
+
for i in range(len(expr_list)):
|
|
186
|
+
expr_list[i] = expr_list[i].substitute(spec)
|
|
187
|
+
|
|
182
188
|
def resolve_computed_cols(self, resolve_cols: Optional[Set[catalog.Column]] = None) -> Expr:
|
|
183
189
|
"""
|
|
184
190
|
Recursively replace ColRefs to unstored computed columns with their value exprs.
|
|
@@ -196,9 +202,7 @@ class Expr(abc.ABC):
|
|
|
196
202
|
])
|
|
197
203
|
if len(target_col_refs) == 0:
|
|
198
204
|
return result
|
|
199
|
-
for ref in target_col_refs
|
|
200
|
-
assert ref.col.value_expr is not None
|
|
201
|
-
result = result.substitute(ref, ref.col.value_expr)
|
|
205
|
+
result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
|
|
202
206
|
|
|
203
207
|
def is_bound_by(self, tbl: catalog.TableVersionPath) -> bool:
|
|
204
208
|
"""Returns True if this expr can be evaluated in the context of tbl."""
|
|
@@ -225,11 +229,6 @@ class Expr(abc.ABC):
|
|
|
225
229
|
self.components[i] = self.components[i]._retarget(tbl_versions)
|
|
226
230
|
return self
|
|
227
231
|
|
|
228
|
-
@classmethod
|
|
229
|
-
def list_substitute(cls, expr_list: List[Expr], old: Expr, new: Expr) -> None:
|
|
230
|
-
for i in range(len(expr_list)):
|
|
231
|
-
expr_list[i] = expr_list[i].substitute(old, new)
|
|
232
|
-
|
|
233
232
|
@abc.abstractmethod
|
|
234
233
|
def __str__(self) -> str:
|
|
235
234
|
pass
|
|
@@ -174,9 +174,6 @@ class FunctionCall(Expr):
|
|
|
174
174
|
f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
|
|
175
175
|
f'{param_type}')
|
|
176
176
|
|
|
177
|
-
def is_nos_call(self) -> bool:
|
|
178
|
-
return isinstance(self.fn, func.NOSFunction)
|
|
179
|
-
|
|
180
177
|
def _equals(self, other: FunctionCall) -> bool:
|
|
181
178
|
if self.fn != other.fn:
|
|
182
179
|
return False
|
pixeltable/exprs/globals.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import datetime
|
|
2
|
-
from typing import Union
|
|
3
4
|
import enum
|
|
4
|
-
|
|
5
|
+
from typing import Union
|
|
5
6
|
|
|
6
7
|
# Python types corresponding to our literal types
|
|
7
8
|
LiteralPythonTypes = Union[str, int, float, bool, datetime.datetime, datetime.date]
|
|
@@ -33,6 +34,17 @@ class ComparisonOperator(enum.Enum):
|
|
|
33
34
|
if self == self.GE:
|
|
34
35
|
return '>='
|
|
35
36
|
|
|
37
|
+
def reverse(self) -> ComparisonOperator:
|
|
38
|
+
if self == self.LT:
|
|
39
|
+
return self.GT
|
|
40
|
+
if self == self.LE:
|
|
41
|
+
return self.GE
|
|
42
|
+
if self == self.GT:
|
|
43
|
+
return self.LT
|
|
44
|
+
if self == self.GE:
|
|
45
|
+
return self.LE
|
|
46
|
+
return self
|
|
47
|
+
|
|
36
48
|
|
|
37
49
|
class LogicalOperator(enum.Enum):
|
|
38
50
|
AND = 0
|
|
@@ -23,13 +23,15 @@ class SimilarityExpr(Expr):
|
|
|
23
23
|
|
|
24
24
|
# determine index to use
|
|
25
25
|
idx_info = col_ref.col.get_idx_info()
|
|
26
|
-
|
|
26
|
+
import pixeltable.index as index
|
|
27
|
+
embedding_idx_info = [info for info in idx_info.values() if isinstance(info.idx, index.EmbeddingIndex)]
|
|
28
|
+
if len(embedding_idx_info) == 0:
|
|
27
29
|
raise excs.Error(f'No index found for column {col_ref.col}')
|
|
28
|
-
if len(
|
|
30
|
+
if len(embedding_idx_info) > 1:
|
|
29
31
|
raise excs.Error(
|
|
30
32
|
f'Column {col_ref.col.name} has multiple indices; use the index name to disambiguate, '
|
|
31
33
|
f'e.g., `{col_ref.col.name}.<index-name>.similarity(...)`')
|
|
32
|
-
self.idx_info =
|
|
34
|
+
self.idx_info = embedding_idx_info[0]
|
|
33
35
|
idx = self.idx_info.idx
|
|
34
36
|
|
|
35
37
|
if item.col_type.is_string_type() and idx.txt_embed is None:
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import whisperx
|
|
5
|
+
from whisperx.asr import FasterWhisperPipeline
|
|
6
|
+
|
|
7
|
+
import pixeltable as pxt
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@pxt.udf(param_types=[pxt.AudioType(), pxt.StringType(), pxt.StringType(), pxt.StringType(), pxt.IntType()])
|
|
11
|
+
def transcribe(
|
|
12
|
+
audio: str, *, model: str, compute_type: Optional[str] = None, language: Optional[str] = None, chunk_size: int = 30
|
|
13
|
+
) -> dict:
|
|
14
|
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
15
|
+
compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
|
|
16
|
+
model = _lookup_model(model, device, compute_type)
|
|
17
|
+
audio_array = whisperx.load_audio(audio)
|
|
18
|
+
result = model.transcribe(audio_array, batch_size=16, language=language, chunk_size=chunk_size)
|
|
19
|
+
return result
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _lookup_model(model_id: str, device: str, compute_type: str) -> FasterWhisperPipeline:
|
|
23
|
+
key = (model_id, device, compute_type)
|
|
24
|
+
if key not in _model_cache:
|
|
25
|
+
model = whisperx.load_model(model_id, device, compute_type=compute_type)
|
|
26
|
+
_model_cache[key] = model
|
|
27
|
+
return _model_cache[key]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
_model_cache = {}
|
|
@@ -56,6 +56,22 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
|
|
|
56
56
|
return results
|
|
57
57
|
|
|
58
58
|
|
|
59
|
+
@pxt.udf
|
|
60
|
+
def yolo_to_coco(detections: dict) -> list:
|
|
61
|
+
bboxes, labels = detections['bboxes'], detections['labels']
|
|
62
|
+
num_annotations = len(detections['bboxes'])
|
|
63
|
+
assert num_annotations == len(detections['labels'])
|
|
64
|
+
result = []
|
|
65
|
+
for i in range(num_annotations):
|
|
66
|
+
bbox = bboxes[i]
|
|
67
|
+
ann = {
|
|
68
|
+
'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
|
|
69
|
+
'category': labels[i],
|
|
70
|
+
}
|
|
71
|
+
result.append(ann)
|
|
72
|
+
return result
|
|
73
|
+
|
|
74
|
+
|
|
59
75
|
def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
|
|
60
76
|
for image in images:
|
|
61
77
|
image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
|
|
@@ -140,7 +140,7 @@ def uda(
|
|
|
140
140
|
update_types: List[ts.ColumnType],
|
|
141
141
|
init_types: Optional[List[ts.ColumnType]] = None,
|
|
142
142
|
requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
|
|
143
|
-
) -> Callable:
|
|
143
|
+
) -> Callable[[Type[Aggregator]], AggregateFunction]:
|
|
144
144
|
"""Decorator for user-defined aggregate functions.
|
|
145
145
|
|
|
146
146
|
The decorated class must inherit from Aggregator and implement the following methods:
|
|
@@ -162,7 +162,7 @@ def uda(
|
|
|
162
162
|
if init_types is None:
|
|
163
163
|
init_types = []
|
|
164
164
|
|
|
165
|
-
def decorator(cls: Type[Aggregator]) ->
|
|
165
|
+
def decorator(cls: Type[Aggregator]) -> AggregateFunction:
|
|
166
166
|
# validate type parameters
|
|
167
167
|
num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
|
|
168
168
|
if num_init_params > 0:
|
|
@@ -50,6 +50,7 @@ class ExprTemplateFunction(Function):
|
|
|
50
50
|
{param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
|
|
51
51
|
result = self.expr.copy()
|
|
52
52
|
import pixeltable.exprs as exprs
|
|
53
|
+
arg_exprs: dict[exprs.Expr, exprs.Expr] = {}
|
|
53
54
|
for param_name, arg in bound_args.items():
|
|
54
55
|
param_expr = self.param_exprs_by_name[param_name]
|
|
55
56
|
if not isinstance(arg, exprs.Expr):
|
|
@@ -59,7 +60,8 @@ class ExprTemplateFunction(Function):
|
|
|
59
60
|
raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
|
|
60
61
|
else:
|
|
61
62
|
arg_expr = arg
|
|
62
|
-
|
|
63
|
+
arg_exprs[param_expr] = arg_expr
|
|
64
|
+
result = result.substitute(arg_exprs)
|
|
63
65
|
import pixeltable.exprs as exprs
|
|
64
66
|
assert not result.contains(exprs.Variable)
|
|
65
67
|
return result
|
pixeltable/func/udf.py
CHANGED
|
@@ -28,7 +28,7 @@ def udf(
|
|
|
28
28
|
batch_size: Optional[int] = None,
|
|
29
29
|
substitute_fn: Optional[Callable] = None,
|
|
30
30
|
_force_stored: bool = False
|
|
31
|
-
) -> Callable: ...
|
|
31
|
+
) -> Callable[[Callable], Function]: ...
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
def udf(*args, **kwargs):
|
|
@@ -131,7 +131,7 @@ def make_function(
|
|
|
131
131
|
def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
|
|
132
132
|
|
|
133
133
|
@overload
|
|
134
|
-
def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable: ...
|
|
134
|
+
def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
|
|
135
135
|
|
|
136
136
|
def expr_udf(*args: Any, **kwargs: Any) -> Any:
|
|
137
137
|
def decorator(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction:
|
|
@@ -6,8 +6,13 @@ import pixeltable as pxt
|
|
|
6
6
|
from pixeltable import env
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
|
|
10
|
-
|
|
9
|
+
@env.register_client('fireworks')
|
|
10
|
+
def _(api_key: str) -> fireworks.client.Fireworks:
|
|
11
|
+
return fireworks.client.Fireworks(api_key=api_key)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _fireworks_client() -> fireworks.client.Fireworks:
|
|
15
|
+
return env.Env.get().get_client('fireworks')
|
|
11
16
|
|
|
12
17
|
|
|
13
18
|
@pxt.udf
|
|
@@ -26,8 +31,8 @@ def chat_completions(
|
|
|
26
31
|
'top_p': top_p,
|
|
27
32
|
'temperature': temperature
|
|
28
33
|
}
|
|
29
|
-
kwargs_not_none =
|
|
30
|
-
return
|
|
34
|
+
kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
|
|
35
|
+
return _fireworks_client().chat.completions.create(
|
|
31
36
|
model=model,
|
|
32
37
|
messages=messages,
|
|
33
38
|
**kwargs_not_none
|