pixeltable 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (56) hide show
  1. pixeltable/__init__.py +3 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +14 -2
  4. pixeltable/catalog/insertable_table.py +32 -17
  5. pixeltable/catalog/table.py +194 -12
  6. pixeltable/catalog/table_version.py +270 -110
  7. pixeltable/catalog/table_version_path.py +6 -1
  8. pixeltable/datatransfer/__init__.py +1 -0
  9. pixeltable/datatransfer/label_studio.py +526 -0
  10. pixeltable/datatransfer/remote.py +113 -0
  11. pixeltable/env.py +156 -73
  12. pixeltable/exprs/column_ref.py +2 -2
  13. pixeltable/exprs/comparison.py +39 -1
  14. pixeltable/exprs/data_row.py +7 -0
  15. pixeltable/exprs/expr.py +11 -12
  16. pixeltable/exprs/function_call.py +0 -3
  17. pixeltable/exprs/globals.py +14 -2
  18. pixeltable/exprs/similarity_expr.py +5 -3
  19. pixeltable/ext/functions/whisperx.py +30 -0
  20. pixeltable/ext/functions/yolox.py +16 -0
  21. pixeltable/func/aggregate_function.py +2 -2
  22. pixeltable/func/expr_template_function.py +3 -1
  23. pixeltable/func/udf.py +2 -2
  24. pixeltable/functions/fireworks.py +9 -4
  25. pixeltable/functions/huggingface.py +25 -1
  26. pixeltable/functions/openai.py +15 -10
  27. pixeltable/functions/together.py +11 -6
  28. pixeltable/functions/util.py +0 -43
  29. pixeltable/functions/video.py +46 -8
  30. pixeltable/globals.py +20 -2
  31. pixeltable/index/__init__.py +1 -0
  32. pixeltable/index/base.py +6 -1
  33. pixeltable/index/btree.py +54 -0
  34. pixeltable/index/embedding_index.py +4 -1
  35. pixeltable/io/__init__.py +1 -0
  36. pixeltable/io/globals.py +59 -0
  37. pixeltable/iterators/base.py +4 -4
  38. pixeltable/iterators/document.py +26 -15
  39. pixeltable/iterators/video.py +9 -1
  40. pixeltable/metadata/__init__.py +2 -2
  41. pixeltable/metadata/converters/convert_14.py +13 -0
  42. pixeltable/metadata/converters/convert_15.py +29 -0
  43. pixeltable/metadata/converters/util.py +63 -0
  44. pixeltable/metadata/schema.py +12 -6
  45. pixeltable/plan.py +9 -5
  46. pixeltable/store.py +14 -21
  47. pixeltable/tool/create_test_db_dump.py +16 -0
  48. pixeltable/type_system.py +14 -4
  49. pixeltable/utils/coco.py +94 -0
  50. pixeltable-0.2.7.dist-info/METADATA +137 -0
  51. {pixeltable-0.2.6.dist-info → pixeltable-0.2.7.dist-info}/RECORD +53 -46
  52. pixeltable/func/nos_function.py +0 -202
  53. pixeltable/utils/clip.py +0 -18
  54. pixeltable-0.2.6.dist-info/METADATA +0 -131
  55. {pixeltable-0.2.6.dist-info → pixeltable-0.2.7.dist-info}/LICENSE +0 -0
  56. {pixeltable-0.2.6.dist-info → pixeltable-0.2.7.dist-info}/WHEEL +0 -0
pixeltable/env.py CHANGED
@@ -5,20 +5,20 @@ import glob
5
5
  import http.server
6
6
  import importlib
7
7
  import importlib.util
8
+ import inspect
8
9
  import logging
9
10
  import os
10
- import socketserver
11
11
  import sys
12
12
  import threading
13
13
  import uuid
14
14
  import warnings
15
+ from dataclasses import dataclass
15
16
  from pathlib import Path
16
17
  from typing import Callable, Optional, Dict, Any, List
17
18
 
18
19
  import pgserver
19
20
  import sqlalchemy as sql
20
21
  import yaml
21
- from sqlalchemy_utils.functions import database_exists, create_database, drop_database
22
22
  from tqdm import TqdmWarning
23
23
 
24
24
  import pixeltable.exceptions as excs
@@ -37,12 +37,16 @@ class Env:
37
37
  @classmethod
38
38
  def get(cls) -> Env:
39
39
  if cls._instance is None:
40
- env = Env()
41
- env._set_up()
42
- env._upgrade_metadata()
43
- cls._instance = env
40
+ cls._init_env()
44
41
  return cls._instance
45
42
 
43
+ @classmethod
44
+ def _init_env(cls, reinit_db: bool = False) -> None:
45
+ env = Env()
46
+ env._set_up(reinit_db=reinit_db)
47
+ env._upgrade_metadata()
48
+ cls._instance = env
49
+
46
50
  def __init__(self):
47
51
  self._home: Optional[Path] = None
48
52
  self._media_dir: Optional[Path] = None # computed media files
@@ -59,12 +63,11 @@ class Env:
59
63
  # info about installed packages that are utilized by some parts of the code;
60
64
  # package name -> version; version == []: package is installed, but we haven't determined the version yet
61
65
  self._installed_packages: Dict[str, Optional[List[int]]] = {}
62
- self._nos_client: Optional[Any] = None
63
66
  self._spacy_nlp: Optional[Any] = None # spacy.Language
64
- self._httpd: Optional[http.server.ThreadingHTTPServer] = None
67
+ self._httpd: Optional[http.server.HTTPServer] = None
65
68
  self._http_address: Optional[str] = None
66
69
 
67
- self._registered_clients: dict[str, Any] = {}
70
+ self._registered_clients: dict[str, ApiClient] = {}
68
71
 
69
72
  # logging-related state
70
73
  self._logger = logging.getLogger('pixeltable')
@@ -120,8 +123,8 @@ class Env:
120
123
  if level is not None:
121
124
  self.set_log_level(level)
122
125
  if add is not None:
123
- for module, level in [t.split(':') for t in add.split(',')]:
124
- self.set_module_log_level(module, int(level))
126
+ for module, level_str in [t.split(':') for t in add.split(',')]:
127
+ self.set_module_log_level(module, int(level_str))
125
128
  if remove is not None:
126
129
  for module in remove.split(','):
127
130
  self.set_module_log_level(module, None)
@@ -263,24 +266,19 @@ class Env:
263
266
  self._db_url = self._db_server.get_uri(database=self._db_name)
264
267
 
265
268
  if reinit_db:
266
- if database_exists(self.db_url):
267
- drop_database(self.db_url)
269
+ if self._store_db_exists():
270
+ self._drop_store_db()
268
271
 
269
- if not database_exists(self.db_url):
272
+ if not self._store_db_exists():
270
273
  self._logger.info(f'creating database at {self.db_url}')
271
- create_database(self.db_url)
272
- self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
274
+ self._create_store_db()
275
+ self._create_engine(echo=echo)
273
276
  from pixeltable.metadata import schema
274
-
275
277
  schema.Base.metadata.create_all(self._sa_engine)
276
278
  metadata.create_system_info(self._sa_engine)
277
- # enable pgvector
278
- with self._sa_engine.begin() as conn:
279
- conn.execute(sql.text('CREATE EXTENSION vector'))
280
279
  else:
281
280
  self._logger.info(f'found database {self.db_url}')
282
- if self._sa_engine is None:
283
- self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
281
+ self._create_engine(echo=echo)
284
282
 
285
283
  print(f'Connected to Pixeltable database at: {self.db_url}')
286
284
 
@@ -288,57 +286,110 @@ class Env:
288
286
  self._set_up_runtime()
289
287
  self.log_to_stdout(False)
290
288
 
291
- def _upgrade_metadata(self) -> None:
292
- metadata.upgrade_md(self._sa_engine)
293
-
294
- def _create_nos_client(self) -> None:
295
- import nos
289
+ def _create_engine(self, echo: bool = False) -> None:
290
+ self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True, isolation_level='AUTOCOMMIT')
296
291
 
297
- self._logger.info('connecting to NOS')
298
- nos.init(logging_level=logging.DEBUG)
299
- self._nos_client = nos.client.InferenceClient()
300
- self._logger.info('waiting for NOS')
301
- self._nos_client.WaitForServer()
302
-
303
- # now that we have a client, we can create the module
304
- import importlib
292
+ def _store_db_exists(self) -> bool:
293
+ assert self._db_name is not None
294
+ # don't try to connect to self.db_name, it may not exist
295
+ db_url = self._db_server.get_uri(database='postgres')
296
+ engine = sql.create_engine(db_url, future=True)
297
+ try:
298
+ with engine.begin() as conn:
299
+ stmt = f"SELECT COUNT(*) FROM pg_database WHERE datname = '{self._db_name}'"
300
+ result = conn.scalar(sql.text(stmt))
301
+ assert result <= 1
302
+ return result == 1
303
+ finally:
304
+ engine.dispose()
305
+
306
+
307
+ def _create_store_db(self) -> None:
308
+ assert self._db_name is not None
309
+ # create the db
310
+ pg_db_url = self._db_server.get_uri(database='postgres')
311
+ engine = sql.create_engine(pg_db_url, future=True, isolation_level='AUTOCOMMIT')
312
+ preparer = engine.dialect.identifier_preparer
313
+ try:
314
+ with engine.begin() as conn:
315
+ # use C collation to get standard C/Python-style sorting
316
+ stmt = (
317
+ f"CREATE DATABASE {preparer.quote(self._db_name)} "
318
+ "ENCODING 'utf-8' LC_COLLATE 'C' LC_CTYPE 'C' TEMPLATE template0"
319
+ )
320
+ conn.execute(sql.text(stmt))
321
+ finally:
322
+ engine.dispose()
305
323
 
324
+ # enable pgvector
325
+ store_db_url = self._db_server.get_uri(database=self._db_name)
326
+ engine = sql.create_engine(store_db_url, future=True, isolation_level='AUTOCOMMIT')
306
327
  try:
307
- importlib.import_module('pixeltable.functions.nos')
308
- # it's already been created
309
- return
310
- except ImportError:
311
- pass
312
- from pixeltable.functions.util import create_nos_modules
328
+ with engine.begin() as conn:
329
+ conn.execute(sql.text('CREATE EXTENSION vector'))
330
+ finally:
331
+ engine.dispose()
332
+
333
+ def _drop_store_db(self) -> None:
334
+ assert self._db_name is not None
335
+ db_url = self._db_server.get_uri(database='postgres')
336
+ engine = sql.create_engine(db_url, future=True, isolation_level='AUTOCOMMIT')
337
+ preparer = engine.dialect.identifier_preparer
338
+ try:
339
+ with engine.begin() as conn:
340
+ # terminate active connections
341
+ stmt = (f"""
342
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
343
+ FROM pg_stat_activity
344
+ WHERE pg_stat_activity.datname = '{self._db_name}'
345
+ AND pid <> pg_backend_pid()
346
+ """)
347
+ conn.execute(sql.text(stmt))
348
+ # drop db
349
+ stmt = f'DROP DATABASE {preparer.quote(self._db_name)}'
350
+ conn.execute(sql.text(stmt))
351
+ finally:
352
+ engine.dispose()
313
353
 
314
- _ = create_nos_modules()
354
+ def _upgrade_metadata(self) -> None:
355
+ metadata.upgrade_md(self._sa_engine)
315
356
 
316
- def get_client(self, name: str, init: Callable, environ: Optional[str] = None) -> Any:
317
- """
318
- Gets the client with the specified name, using `init` to construct one if necessary.
357
+ def _register_client(self, name: str, init_fn: Callable) -> None:
358
+ sig = inspect.signature(init_fn)
359
+ param_names = list(sig.parameters.keys())
360
+ self._registered_clients[name] = ApiClient(init_fn=init_fn, param_names=param_names)
319
361
 
320
- - name: The name of the client
321
- - init: A `Callable` with signature `fn(api_key: str) -> Any` that constructs a client object
322
- - environ: The name of the environment variable to use for the API key, if no API key is found in config
323
- (defaults to f'{name.upper()}_API_KEY')
362
+ def get_client(self, name: str) -> Any:
324
363
  """
325
- if name in self._registered_clients:
326
- return self._registered_clients[name]
327
-
328
- if environ is None:
329
- environ = f'{name.upper()}_API_KEY'
364
+ Gets the client with the specified name, initializing it if necessary.
330
365
 
331
- if name in self._config and 'api_key' in self._config[name]:
332
- api_key = self._config[name]['api_key']
333
- else:
334
- api_key = os.environ.get(environ)
335
- if api_key is None or api_key == '':
336
- raise excs.Error(f'`{name}` client not initialized (no API key configured).')
366
+ Args:
367
+ - name: The name of the client
368
+ """
369
+ cl = self._registered_clients[name]
370
+ if cl.client_obj is not None:
371
+ return cl.client_obj # Already initialized
372
+
373
+ # Construct a client. For each client parameter, first check if the parameter is in the environment;
374
+ # if not, look in Pixeltable config from `config.yaml`.
375
+
376
+ init_kwargs: dict[str, str] = {}
377
+ for param in cl.param_names:
378
+ environ = f'{name.upper()}_{param.upper()}'
379
+ if environ in os.environ:
380
+ init_kwargs[param] = os.environ[environ]
381
+ elif name.lower() in self._config and param in self._config[name.lower()]:
382
+ init_kwargs[param] = self._config[name.lower()][param.lower()]
383
+ if param not in init_kwargs or init_kwargs[param] == '':
384
+ raise excs.Error(
385
+ f'`{name}` client not initialized: parameter `{param}` is not configured.\n'
386
+ f'To fix this, specify the `{environ}` environment variable, or put `{param.lower()}` in '
387
+ f'the `{name.lower()}` section of $PIXELTABLE_HOME/config.yaml.'
388
+ )
337
389
 
338
- client = init(api_key)
339
- self._registered_clients[name] = client
390
+ cl.client_obj = cl.init_fn(**init_kwargs)
340
391
  self._logger.info(f'Initialized `{name}` client.')
341
- return client
392
+ return cl.client_obj
342
393
 
343
394
  def _start_web_server(self) -> None:
344
395
  """
@@ -380,6 +431,7 @@ class Env:
380
431
  check('transformers')
381
432
  check('sentence_transformers')
382
433
  check('yolox')
434
+ check('whisperx')
383
435
  check('boto3')
384
436
  check('fitz') # pymupdf
385
437
  check('pyarrow')
@@ -392,9 +444,7 @@ class Env:
392
444
  check('openai')
393
445
  check('together')
394
446
  check('fireworks')
395
- check('nos')
396
- if self.is_installed_package('nos'):
397
- self._create_nos_client()
447
+ check('label_studio_sdk')
398
448
  check('openpyxl')
399
449
 
400
450
  def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
@@ -405,7 +455,7 @@ class Env:
405
455
  return
406
456
 
407
457
  # check whether we have a version >= the required one
408
- if self._installed_packages[package] == []:
458
+ if not self._installed_packages[package]:
409
459
  m = importlib.import_module(package)
410
460
  module_version = [int(x) for x in m.__version__.split('.')]
411
461
  self._installed_packages[package] = module_version
@@ -415,8 +465,8 @@ class Env:
415
465
  if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
416
466
  raise excs.Error(
417
467
  (
418
- f'The installed version of package {package} is {".".join([str[v] for v in installed_version])}, '
419
- f'but version >={".".join([str[v] for v in min_version])} is required'
468
+ f'The installed version of package {package} is {".".join(str(v) for v in installed_version)}, '
469
+ f'but version >={".".join(str(v) for v in min_version)} is required'
420
470
  )
421
471
  )
422
472
 
@@ -456,11 +506,44 @@ class Env:
456
506
  assert self._sa_engine is not None
457
507
  return self._sa_engine
458
508
 
459
- @property
460
- def nos_client(self) -> Any:
461
- return self._nos_client
462
-
463
509
  @property
464
510
  def spacy_nlp(self) -> Any:
465
511
  assert self._spacy_nlp is not None
466
512
  return self._spacy_nlp
513
+
514
+
515
+ def register_client(name: str) -> Callable:
516
+ """Decorator that registers a third-party API client for use by Pixeltable.
517
+
518
+ The decorated function is an initialization wrapper for the client, and can have
519
+ any number of string parameters, with a signature such as:
520
+
521
+ ```
522
+ def my_client(api_key: str, url: str) -> my_client_sdk.Client:
523
+ return my_client_sdk.Client(api_key=api_key, url=url)
524
+ ```
525
+
526
+ The initialization wrapper will not be called immediately; initialization will
527
+ be deferred until the first time the client is used. At initialization time,
528
+ Pixeltable will attempt to load the client parameters from config. For each
529
+ config parameter:
530
+ - If an environment variable named MY_CLIENT_API_KEY (for example) is set, use it;
531
+ - Otherwise, look for 'api_key' in the 'my_client' section of config.yaml.
532
+
533
+ If all config parameters are found, Pixeltable calls the initialization function;
534
+ otherwise it throws an exception.
535
+
536
+ Args:
537
+ - name (str): The name of the API client (e.g., 'openai' or 'label-studio').
538
+ """
539
+ def decorator(fn: Callable) -> None:
540
+ Env.get()._register_client(name, fn)
541
+
542
+ return decorator
543
+
544
+
545
+ @dataclass
546
+ class ApiClient:
547
+ init_fn: Callable
548
+ param_names: list[str]
549
+ client_obj: Optional[Any] = None
@@ -108,7 +108,7 @@ class ColumnRef(Expr):
108
108
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
109
109
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
110
110
  tbl_version = catalog.Catalog.get().tbl_versions[(tbl_id, version)]
111
- assert col_id in tbl_version.cols_by_id
112
- col = tbl_version.cols_by_id[col_id]
111
+ # don't use tbl_version.cols_by_id here, this might be a snapshot reference to a column that was then dropped
112
+ col = next(col for col in tbl_version.cols if col.id == col_id)
113
113
  return cls(col)
114
114
 
@@ -4,18 +4,44 @@ from typing import Optional, List, Any, Dict, Tuple
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
7
+ from .column_ref import ColumnRef
7
8
  from .data_row import DataRow
8
9
  from .expr import Expr
9
10
  from .globals import ComparisonOperator
11
+ from .literal import Literal
10
12
  from .predicate import Predicate
11
13
  from .row_builder import RowBuilder
14
+ import pixeltable.exceptions as excs
15
+ import pixeltable.index as index
12
16
 
13
17
 
14
18
  class Comparison(Predicate):
15
19
  def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
16
20
  super().__init__()
17
21
  self.operator = operator
18
- self.components = [op1, op2]
22
+
23
+ # if this is a comparison of a column to a literal (ie, could be used as a search argument in an index lookup),
24
+ # normalize it to <column> <operator> <literal>.
25
+ if isinstance(op1, ColumnRef) and isinstance(op2, Literal):
26
+ self.is_search_arg_comparison = True
27
+ self.components = [op1, op2]
28
+ elif isinstance(op1, Literal) and isinstance(op2, ColumnRef):
29
+ self.is_search_arg_comparison = True
30
+ self.components = [op2, op1]
31
+ self.operator = self.operator.reverse()
32
+ else:
33
+ self.is_search_arg_comparison = False
34
+ self.components = [op1, op2]
35
+
36
+ import pixeltable.index as index
37
+ if self.is_search_arg_comparison and self._op2.col_type.is_string_type() \
38
+ and len(self._op2.val) >= index.BtreeIndex.MAX_STRING_LEN:
39
+ # we can't use an index for this after all
40
+ raise excs.Error(
41
+ f'String literal too long for comparison against indexed column {self._op1.col.name!r} '
42
+ f'(max length is {index.BtreeIndex.MAX_STRING_LEN - 1})'
43
+ )
44
+
19
45
  self.id = self._create_id()
20
46
 
21
47
  def __str__(self) -> str:
@@ -37,6 +63,18 @@ class Comparison(Predicate):
37
63
 
38
64
  def sql_expr(self) -> Optional[sql.ClauseElement]:
39
65
  left = self._op1.sql_expr()
66
+ if self.is_search_arg_comparison:
67
+ # reference the index value column if there is an index and this is not a snapshot
68
+ # (indices don't apply to snapshots)
69
+ tbl = self._op1.col.tbl
70
+ idx_info = [
71
+ info for info in self._op1.col.get_idx_info().values() if isinstance(info.idx, index.BtreeIndex)
72
+ ]
73
+ if len(idx_info) > 0 and not tbl.is_snapshot:
74
+ # there shouldn't be multiple B-tree indices on a column
75
+ assert len(idx_info) == 1
76
+ left = idx_info[0].val_col.sa_col
77
+
40
78
  right = self._op2.sql_expr()
41
79
  if left is None or right is None:
42
80
  return None
@@ -197,3 +197,10 @@ class DataRow:
197
197
  pass
198
198
  self.vals[index] = None
199
199
 
200
+ @property
201
+ def rowid(self) -> Tuple[int]:
202
+ return self.pk[:-1]
203
+
204
+ @property
205
+ def v_min(self) -> int:
206
+ return self.pk[-1]
pixeltable/exprs/expr.py CHANGED
@@ -169,16 +169,22 @@ class Expr(abc.ABC):
169
169
  memo[id(self)] = result
170
170
  return result
171
171
 
172
- def substitute(self, old: Expr, new: Expr) -> Expr:
172
+ def substitute(self, spec: dict[Expr, Expr]) -> Expr:
173
173
  """
174
174
  Replace 'old' with 'new' recursively.
175
175
  """
176
- if self.equals(old):
177
- return new.copy()
176
+ for old, new in spec.items():
177
+ if self.equals(old):
178
+ return new.copy()
178
179
  for i in range(len(self.components)):
179
- self.components[i] = self.components[i].substitute(old, new)
180
+ self.components[i] = self.components[i].substitute(spec)
180
181
  return self
181
182
 
183
+ @classmethod
184
+ def list_substitute(cls, expr_list: List[Expr], spec: dict[Expr, Expr]) -> None:
185
+ for i in range(len(expr_list)):
186
+ expr_list[i] = expr_list[i].substitute(spec)
187
+
182
188
  def resolve_computed_cols(self, resolve_cols: Optional[Set[catalog.Column]] = None) -> Expr:
183
189
  """
184
190
  Recursively replace ColRefs to unstored computed columns with their value exprs.
@@ -196,9 +202,7 @@ class Expr(abc.ABC):
196
202
  ])
197
203
  if len(target_col_refs) == 0:
198
204
  return result
199
- for ref in target_col_refs:
200
- assert ref.col.value_expr is not None
201
- result = result.substitute(ref, ref.col.value_expr)
205
+ result = result.substitute({ref: ref.col.value_expr for ref in target_col_refs})
202
206
 
203
207
  def is_bound_by(self, tbl: catalog.TableVersionPath) -> bool:
204
208
  """Returns True if this expr can be evaluated in the context of tbl."""
@@ -225,11 +229,6 @@ class Expr(abc.ABC):
225
229
  self.components[i] = self.components[i]._retarget(tbl_versions)
226
230
  return self
227
231
 
228
- @classmethod
229
- def list_substitute(cls, expr_list: List[Expr], old: Expr, new: Expr) -> None:
230
- for i in range(len(expr_list)):
231
- expr_list[i] = expr_list[i].substitute(old, new)
232
-
233
232
  @abc.abstractmethod
234
233
  def __str__(self) -> str:
235
234
  pass
@@ -174,9 +174,6 @@ class FunctionCall(Expr):
174
174
  f'Parameter {param_name}: argument type {arg.col_type} does not match parameter type '
175
175
  f'{param_type}')
176
176
 
177
- def is_nos_call(self) -> bool:
178
- return isinstance(self.fn, func.NOSFunction)
179
-
180
177
  def _equals(self, other: FunctionCall) -> bool:
181
178
  if self.fn != other.fn:
182
179
  return False
@@ -1,7 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import datetime
2
- from typing import Union
3
4
  import enum
4
-
5
+ from typing import Union
5
6
 
6
7
  # Python types corresponding to our literal types
7
8
  LiteralPythonTypes = Union[str, int, float, bool, datetime.datetime, datetime.date]
@@ -33,6 +34,17 @@ class ComparisonOperator(enum.Enum):
33
34
  if self == self.GE:
34
35
  return '>='
35
36
 
37
+ def reverse(self) -> ComparisonOperator:
38
+ if self == self.LT:
39
+ return self.GT
40
+ if self == self.LE:
41
+ return self.GE
42
+ if self == self.GT:
43
+ return self.LT
44
+ if self == self.GE:
45
+ return self.LE
46
+ return self
47
+
36
48
 
37
49
  class LogicalOperator(enum.Enum):
38
50
  AND = 0
@@ -23,13 +23,15 @@ class SimilarityExpr(Expr):
23
23
 
24
24
  # determine index to use
25
25
  idx_info = col_ref.col.get_idx_info()
26
- if len(idx_info) == 0:
26
+ import pixeltable.index as index
27
+ embedding_idx_info = [info for info in idx_info.values() if isinstance(info.idx, index.EmbeddingIndex)]
28
+ if len(embedding_idx_info) == 0:
27
29
  raise excs.Error(f'No index found for column {col_ref.col}')
28
- if len(idx_info) > 1:
30
+ if len(embedding_idx_info) > 1:
29
31
  raise excs.Error(
30
32
  f'Column {col_ref.col.name} has multiple indices; use the index name to disambiguate, '
31
33
  f'e.g., `{col_ref.col.name}.<index-name>.similarity(...)`')
32
- self.idx_info = next(iter(idx_info.values()))
34
+ self.idx_info = embedding_idx_info[0]
33
35
  idx = self.idx_info.idx
34
36
 
35
37
  if item.col_type.is_string_type() and idx.txt_embed is None:
@@ -0,0 +1,30 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import whisperx
5
+ from whisperx.asr import FasterWhisperPipeline
6
+
7
+ import pixeltable as pxt
8
+
9
+
10
+ @pxt.udf(param_types=[pxt.AudioType(), pxt.StringType(), pxt.StringType(), pxt.StringType(), pxt.IntType()])
11
+ def transcribe(
12
+ audio: str, *, model: str, compute_type: Optional[str] = None, language: Optional[str] = None, chunk_size: int = 30
13
+ ) -> dict:
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
16
+ model = _lookup_model(model, device, compute_type)
17
+ audio_array = whisperx.load_audio(audio)
18
+ result = model.transcribe(audio_array, batch_size=16, language=language, chunk_size=chunk_size)
19
+ return result
20
+
21
+
22
+ def _lookup_model(model_id: str, device: str, compute_type: str) -> FasterWhisperPipeline:
23
+ key = (model_id, device, compute_type)
24
+ if key not in _model_cache:
25
+ model = whisperx.load_model(model_id, device, compute_type=compute_type)
26
+ _model_cache[key] = model
27
+ return _model_cache[key]
28
+
29
+
30
+ _model_cache = {}
@@ -56,6 +56,22 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
56
56
  return results
57
57
 
58
58
 
59
+ @pxt.udf
60
+ def yolo_to_coco(detections: dict) -> list:
61
+ bboxes, labels = detections['bboxes'], detections['labels']
62
+ num_annotations = len(detections['bboxes'])
63
+ assert num_annotations == len(detections['labels'])
64
+ result = []
65
+ for i in range(num_annotations):
66
+ bbox = bboxes[i]
67
+ ann = {
68
+ 'bbox': [round(bbox[0]), round(bbox[1]), round(bbox[2] - bbox[0]), round(bbox[3] - bbox[1])],
69
+ 'category': labels[i],
70
+ }
71
+ result.append(ann)
72
+ return result
73
+
74
+
59
75
  def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: Exp) -> Iterator[torch.Tensor]:
60
76
  for image in images:
61
77
  image_transform, _ = _val_transform(np.array(image), None, exp.test_size)
@@ -140,7 +140,7 @@ def uda(
140
140
  update_types: List[ts.ColumnType],
141
141
  init_types: Optional[List[ts.ColumnType]] = None,
142
142
  requires_order_by: bool = False, allows_std_agg: bool = True, allows_window: bool = False,
143
- ) -> Callable:
143
+ ) -> Callable[[Type[Aggregator]], AggregateFunction]:
144
144
  """Decorator for user-defined aggregate functions.
145
145
 
146
146
  The decorated class must inherit from Aggregator and implement the following methods:
@@ -162,7 +162,7 @@ def uda(
162
162
  if init_types is None:
163
163
  init_types = []
164
164
 
165
- def decorator(cls: Type[Aggregator]) -> Type[Function]:
165
+ def decorator(cls: Type[Aggregator]) -> AggregateFunction:
166
166
  # validate type parameters
167
167
  num_init_params = len(inspect.signature(cls.__init__).parameters) - 1
168
168
  if num_init_params > 0:
@@ -50,6 +50,7 @@ class ExprTemplateFunction(Function):
50
50
  {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args})
51
51
  result = self.expr.copy()
52
52
  import pixeltable.exprs as exprs
53
+ arg_exprs: dict[exprs.Expr, exprs.Expr] = {}
53
54
  for param_name, arg in bound_args.items():
54
55
  param_expr = self.param_exprs_by_name[param_name]
55
56
  if not isinstance(arg, exprs.Expr):
@@ -59,7 +60,8 @@ class ExprTemplateFunction(Function):
59
60
  raise excs.Error(f'{self.self_name}(): cannot convert argument {arg} to a Pixeltable expression')
60
61
  else:
61
62
  arg_expr = arg
62
- result = result.substitute(param_expr, arg_expr)
63
+ arg_exprs[param_expr] = arg_expr
64
+ result = result.substitute(arg_exprs)
63
65
  import pixeltable.exprs as exprs
64
66
  assert not result.contains(exprs.Variable)
65
67
  return result
pixeltable/func/udf.py CHANGED
@@ -28,7 +28,7 @@ def udf(
28
28
  batch_size: Optional[int] = None,
29
29
  substitute_fn: Optional[Callable] = None,
30
30
  _force_stored: bool = False
31
- ) -> Callable: ...
31
+ ) -> Callable[[Callable], Function]: ...
32
32
 
33
33
 
34
34
  def udf(*args, **kwargs):
@@ -131,7 +131,7 @@ def make_function(
131
131
  def expr_udf(py_fn: Callable) -> ExprTemplateFunction: ...
132
132
 
133
133
  @overload
134
- def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable: ...
134
+ def expr_udf(*, param_types: Optional[List[ts.ColumnType]] = None) -> Callable[[Callable], ExprTemplateFunction]: ...
135
135
 
136
136
  def expr_udf(*args: Any, **kwargs: Any) -> Any:
137
137
  def decorator(py_fn: Callable, param_types: Optional[List[ts.ColumnType]]) -> ExprTemplateFunction: