pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (99)
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
pixeltable/env.py CHANGED
@@ -10,8 +10,8 @@ import os
10
10
  import socketserver
11
11
  import sys
12
12
  import threading
13
- import typing
14
13
  import uuid
14
+ import warnings
15
15
  from pathlib import Path
16
16
  from typing import Callable, Optional, Dict, Any, List
17
17
 
@@ -19,22 +19,28 @@ import pgserver
19
19
  import sqlalchemy as sql
20
20
  import yaml
21
21
  from sqlalchemy_utils.functions import database_exists, create_database, drop_database
22
+ from tqdm import TqdmWarning
22
23
 
23
24
  import pixeltable.exceptions as excs
24
25
  from pixeltable import metadata
26
+ from pixeltable.utils.http_server import make_server
25
27
 
26
28
 
27
29
  class Env:
28
30
  """
29
31
  Store for runtime globals.
30
32
  """
33
+
31
34
  _instance: Optional[Env] = None
32
35
  _log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
33
36
 
34
37
  @classmethod
35
38
  def get(cls) -> Env:
36
39
  if cls._instance is None:
37
- cls._instance = Env()
40
+ env = Env()
41
+ env._set_up()
42
+ env._upgrade_metadata()
43
+ cls._instance = env
38
44
  return cls._instance
39
45
 
40
46
  def __init__(self):
@@ -45,7 +51,7 @@ class Env:
45
51
  self._log_dir: Optional[Path] = None # log files
46
52
  self._tmp_dir: Optional[Path] = None # any tmp files
47
53
  self._sa_engine: Optional[sql.engine.base.Engine] = None
48
- self._pgdata_dir : Optional[Path] = None
54
+ self._pgdata_dir: Optional[Path] = None
49
55
  self._db_name: Optional[str] = None
50
56
  self._db_server: Optional[pgserver.PostgresServer] = None
51
57
  self._db_url: Optional[str] = None
@@ -55,7 +61,7 @@ class Env:
55
61
  self._installed_packages: Dict[str, Optional[List[int]]] = {}
56
62
  self._nos_client: Optional[Any] = None
57
63
  self._spacy_nlp: Optional[Any] = None # spacy.Language
58
- self._httpd: Optional[socketserver.TCPServer] = None
64
+ self._httpd: Optional[http.server.ThreadingHTTPServer] = None
59
65
  self._http_address: Optional[str] = None
60
66
 
61
67
  self._registered_clients: dict[str, Any] = {}
@@ -93,13 +99,43 @@ class Env:
93
99
  assert self._http_address is not None
94
100
  return self._http_address
95
101
 
102
+ def configure_logging(
103
+ self,
104
+ *,
105
+ to_stdout: Optional[bool] = None,
106
+ level: Optional[int] = None,
107
+ add: Optional[str] = None,
108
+ remove: Optional[str] = None,
109
+ ) -> None:
110
+ """Configure logging.
111
+
112
+ Args:
113
+ to_stdout: if True, also log to stdout
114
+ level: default log level
115
+ add: comma-separated list of 'module name:log level' pairs; ex.: add='video:10'
116
+ remove: comma-separated list of module names
117
+ """
118
+ if to_stdout is not None:
119
+ self.log_to_stdout(to_stdout)
120
+ if level is not None:
121
+ self.set_log_level(level)
122
+ if add is not None:
123
+ for module, level in [t.split(':') for t in add.split(',')]:
124
+ self.set_module_log_level(module, int(level))
125
+ if remove is not None:
126
+ for module in remove.split(','):
127
+ self.set_module_log_level(module, None)
128
+ if to_stdout is None and level is None and add is None and remove is None:
129
+ self.print_log_config()
130
+
96
131
  def print_log_config(self) -> None:
97
132
  print(f'logging to {self._logfilename}')
98
133
  print(f'{"" if self._log_to_stdout else "not "}logging to stdout')
99
134
  print(f'default log level: {logging.getLevelName(self._default_log_level)}')
100
135
  print(
101
136
  f'module log levels: '
102
- f'{",".join([name + ":" + logging.getLevelName(val) for name, val in self._module_log_level.items()])}')
137
+ f'{",".join([name + ":" + logging.getLevelName(val) for name, val in self._module_log_level.items()])}'
138
+ )
103
139
 
104
140
  def log_to_stdout(self, enable: bool = True) -> None:
105
141
  self._log_to_stdout = enable
@@ -134,10 +170,14 @@ class Env:
134
170
  else:
135
171
  return False
136
172
 
137
- def set_up(self, echo: bool = False, reinit_db: bool = False) -> None:
173
+ def _set_up(self, echo: bool = False, reinit_db: bool = False) -> None:
138
174
  if self._initialized:
139
175
  return
140
176
 
177
+ # Disable spurious warnings
178
+ warnings.simplefilter('ignore', category=TqdmWarning)
179
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
180
+
141
181
  self._initialized = True
142
182
  home = Path(os.environ.get('PIXELTABLE_HOME', str(Path.home() / '.pixeltable')))
143
183
  assert self._home is None or self._home == home
@@ -188,11 +228,29 @@ class Env:
188
228
  fh = logging.FileHandler(self._log_dir / self._logfilename, mode='w')
189
229
  fh.setFormatter(logging.Formatter(self._log_fmt_str))
190
230
  self._logger.addHandler(fh)
231
+
232
+ # configure sqlalchemy logging
191
233
  sql_logger = logging.getLogger('sqlalchemy.engine')
192
234
  sql_logger.setLevel(logging.INFO)
193
235
  sql_logger.addHandler(fh)
194
236
  sql_logger.propagate = False
195
237
 
238
+ # configure pyav logging
239
+ av_logfilename = self._logfilename.replace('.log', '_av.log')
240
+ av_fh = logging.FileHandler(self._log_dir / av_logfilename, mode='w')
241
+ av_fh.setFormatter(logging.Formatter(self._log_fmt_str))
242
+ av_logger = logging.getLogger('libav')
243
+ av_logger.addHandler(av_fh)
244
+ av_logger.propagate = False
245
+
246
+ # configure web-server logging
247
+ http_logfilename = self._logfilename.replace('.log', '_http.log')
248
+ http_fh = logging.FileHandler(self._log_dir / http_logfilename, mode='w')
249
+ http_fh.setFormatter(logging.Formatter(self._log_fmt_str))
250
+ http_logger = logging.getLogger('pixeltable.http.server')
251
+ http_logger.addHandler(http_fh)
252
+ http_logger.propagate = False
253
+
196
254
  # empty tmp dir
197
255
  for path in glob.glob(f'{self._tmp_dir}/*'):
198
256
  os.remove(path)
@@ -213,6 +271,7 @@ class Env:
213
271
  create_database(self.db_url)
214
272
  self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
215
273
  from pixeltable.metadata import schema
274
+
216
275
  schema.Base.metadata.create_all(self._sa_engine)
217
276
  metadata.create_system_info(self._sa_engine)
218
277
  # enable pgvector
@@ -229,11 +288,12 @@ class Env:
229
288
  self._set_up_runtime()
230
289
  self.log_to_stdout(False)
231
290
 
232
- def upgrade_metadata(self) -> None:
291
+ def _upgrade_metadata(self) -> None:
233
292
  metadata.upgrade_md(self._sa_engine)
234
293
 
235
294
  def _create_nos_client(self) -> None:
236
295
  import nos
296
+
237
297
  self._logger.info('connecting to NOS')
238
298
  nos.init(logging_level=logging.DEBUG)
239
299
  self._nos_client = nos.client.InferenceClient()
@@ -242,6 +302,7 @@ class Env:
242
302
 
243
303
  # now that we have a client, we can create the module
244
304
  import importlib
305
+
245
306
  try:
246
307
  importlib.import_module('pixeltable.functions.nos')
247
308
  # it's already been created
@@ -249,6 +310,7 @@ class Env:
249
310
  except ImportError:
250
311
  pass
251
312
  from pixeltable.functions.util import create_nos_modules
313
+
252
314
  _ = create_nos_modules()
253
315
 
254
316
  def get_client(self, name: str, init: Callable, environ: Optional[str] = None) -> Any:
@@ -282,16 +344,13 @@ class Env:
282
344
  """
283
345
  The http server root is the file system root.
284
346
  eg: /home/media/foo.mp4 is located at http://127.0.0.1:{port}/home/media/foo.mp4
347
+ in windows, the server will translate paths like http://127.0.0.1:{port}/c:/media/foo.mp4
285
348
  This arrangement enables serving media hosted within _home,
286
349
  as well as external media inserted into pixeltable or produced by pixeltable.
287
350
  The port is chosen dynamically to prevent conflicts.
288
351
  """
289
352
  # Port 0 means OS picks one for us.
290
- address = ("127.0.0.1", 0)
291
- class FixedRootHandler(http.server.SimpleHTTPRequestHandler):
292
- def __init__(self, *args, **kwargs):
293
- super().__init__(*args, directory='/', **kwargs)
294
- self._httpd = socketserver.TCPServer(address, FixedRootHandler)
353
+ self._httpd = make_server('127.0.0.1', 0)
295
354
  port = self._httpd.server_address[1]
296
355
  self._http_address = f'http://127.0.0.1:{port}'
297
356
 
@@ -320,11 +379,14 @@ class Env:
320
379
  check('torchvision')
321
380
  check('transformers')
322
381
  check('sentence_transformers')
382
+ check('yolox')
323
383
  check('boto3')
384
+ check('fitz') # pymupdf
324
385
  check('pyarrow')
325
386
  check('spacy') # TODO: deal with en-core-web-sm
326
387
  if self.is_installed_package('spacy'):
327
388
  import spacy
389
+
328
390
  self._spacy_nlp = spacy.load('en_core_web_sm')
329
391
  check('tiktoken')
330
392
  check('openai')
@@ -333,6 +395,7 @@ class Env:
333
395
  check('nos')
334
396
  if self.is_installed_package('nos'):
335
397
  self._create_nos_client()
398
+ check('openpyxl')
336
399
 
337
400
  def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
338
401
  assert package in self._installed_packages
@@ -350,9 +413,12 @@ class Env:
350
413
  if len(min_version) < len(installed_version):
351
414
  normalized_min_version = min_version + [0] * (len(installed_version) - len(min_version))
352
415
  if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
353
- raise excs.Error((
354
- f'The installed version of package {package} is {".".join([str[v] for v in installed_version])}, '
355
- f'but version >={".".join([str[v] for v in min_version])} is required'))
416
+ raise excs.Error(
417
+ (
418
+ f'The installed version of package {package} is {".".join([str[v] for v in installed_version])}, '
419
+ f'but version >={".".join([str[v] for v in min_version])} is required'
420
+ )
421
+ )
356
422
 
357
423
  def num_tmp_files(self) -> int:
358
424
  return len(glob.glob(f'{self._tmp_dir}/*'))
@@ -397,4 +463,4 @@ class Env:
397
463
  @property
398
464
  def spacy_nlp(self) -> Any:
399
465
  assert self._spacy_nlp is not None
400
- return self._spacy_nlp
466
+ return self._spacy_nlp
pixeltable/exec/__init__.py CHANGED
@@ -6,4 +6,5 @@ from .exec_node import ExecNode
6
6
  from .expr_eval_node import ExprEvalNode
7
7
  from .in_memory_data_node import InMemoryDataNode
8
8
  from .sql_scan_node import SqlScanNode
9
- from .media_validation_node import MediaValidationNode
9
+ from .media_validation_node import MediaValidationNode
10
+ from .data_row_batch import DataRowBatch
pixeltable/exec/cache_prefetch_node.py CHANGED
@@ -89,7 +89,7 @@ class CachePrefetchNode(ExecNode):
89
89
  # preserve the file extension, if there is one
90
90
  extension = ''
91
91
  if parsed.path != '':
92
- p = Path(urllib.parse.unquote(parsed.path))
92
+ p = Path(urllib.parse.unquote(urllib.request.url2pathname(parsed.path)))
93
93
  extension = p.suffix
94
94
  tmp_path = env.Env.get().create_tmp_path(extension=extension)
95
95
  try:
pixeltable/exec/data_row_batch.py CHANGED
@@ -14,9 +14,8 @@ class DataRowBatch:
14
14
 
15
15
  Contains the metadata needed to initialize DataRows.
16
16
  """
17
- def __init__(self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, len: int = 0):
18
- self.tbl_id = tbl.id
19
- self.tbl_version = tbl.version
17
+ def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
18
+ self.tbl = tbl
20
19
  self.row_builder = row_builder
21
20
  self.img_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_image_type()]
22
21
  # non-image media slots
@@ -42,9 +41,10 @@ class DataRowBatch:
42
41
 
43
42
  def set_row_ids(self, row_ids: List[int]) -> None:
44
43
  """Sets pks for rows in batch"""
44
+ assert self.tbl is not None
45
45
  assert len(row_ids) == len(self.rows)
46
46
  for row, row_id in zip(self.rows, row_ids):
47
- row.set_pk((row_id, self.tbl_version))
47
+ row.set_pk((row_id, self.tbl))
48
48
 
49
49
  def __len__(self) -> int:
50
50
  return len(self.rows)
@@ -57,6 +57,7 @@ class DataRowBatch:
57
57
  flushed_slot_idxs: Optional[List[int]] = None
58
58
  ) -> None:
59
59
  """Flushes images in the given range of rows."""
60
+ assert self.tbl is not None
60
61
  if stored_img_info is None:
61
62
  stored_img_info = []
62
63
  if flushed_slot_idxs is None:
@@ -67,12 +68,10 @@ class DataRowBatch:
67
68
  idx_range = slice(0, len(self.rows))
68
69
  for row in self.rows[idx_range]:
69
70
  for info in stored_img_info:
70
- filepath = str(MediaStore.prepare_media_path(self.tbl_id, info.col.id, self.tbl_version))
71
+ filepath = str(MediaStore.prepare_media_path(self.tbl.id, info.col.id, self.tbl.version))
71
72
  row.flush_img(info.slot_idx, filepath)
72
73
  for slot_idx in flushed_slot_idxs:
73
74
  row.flush_img(slot_idx)
74
- #_logger.debug(
75
- #f'flushed images in range {idx_range}: slot_idxs={flushed_slot_idxs} stored_img_info={stored_img_info}')
76
75
 
77
76
  def __iter__(self) -> Iterator[exprs.DataRow]:
78
77
  return DataRowBatchIterator(self)
pixeltable/exec/expr_eval_node.py CHANGED
@@ -1,20 +1,20 @@
1
- import sys
2
- import warnings
3
- from typing import List, Optional, Tuple
4
- from dataclasses import dataclass, field
5
1
  import logging
2
+ import sys
6
3
  import time
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
7
 
8
8
  from tqdm import tqdm, TqdmWarning
9
9
 
10
+ import pixeltable.exprs as exprs
11
+ from pixeltable.func import CallableFunction
10
12
  from .data_row_batch import DataRowBatch
11
13
  from .exec_node import ExecNode
12
- import pixeltable.exprs as exprs
13
- import pixeltable.func as func
14
-
15
14
 
16
15
  _logger = logging.getLogger('pixeltable')
17
16
 
17
+
18
18
  class ExprEvalNode(ExecNode):
19
19
  """Materializes expressions
20
20
  """
@@ -22,7 +22,7 @@ class ExprEvalNode(ExecNode):
22
22
  class Cohort:
23
23
  """List of exprs that form an evaluation context and contain calls to at most one external function"""
24
24
  exprs: List[exprs.Expr]
25
- ext_function: Optional[func.BatchedFunction]
25
+ batched_fn: Optional[CallableFunction]
26
26
  segment_ctxs: List[exprs.RowBuilder.EvalCtx]
27
27
  target_slot_idxs: List[int]
28
28
  batch_size: int = 8
@@ -63,12 +63,12 @@ class ExprEvalNode(ExecNode):
63
63
  if self.pbar is not None:
64
64
  self.pbar.close()
65
65
 
66
- def _get_batched_fn(self, expr: exprs.Expr) -> Optional[func.BatchedFunction]:
67
- if not isinstance(expr, exprs.FunctionCall):
68
- return None
69
- return expr.fn if isinstance(expr.fn, func.BatchedFunction) else None
66
+ def _get_batched_fn(self, expr: exprs.Expr) -> Optional[CallableFunction]:
67
+ if isinstance(expr, exprs.FunctionCall) and isinstance(expr.fn, CallableFunction) and expr.fn.is_batched:
68
+ return expr.fn
69
+ return None
70
70
 
71
- def _is_ext_call(self, expr: exprs.Expr) -> bool:
71
+ def _is_batched_fn_call(self, expr: exprs.Expr) -> bool:
72
72
  return self._get_batched_fn(expr) is not None
73
73
 
74
74
  def _create_cohorts(self) -> None:
@@ -76,14 +76,14 @@ class ExprEvalNode(ExecNode):
76
76
  # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
77
77
  # seed the cohorts with only the ext fn calls
78
78
  cohorts: List[List[exprs.Expr]] = []
79
- current_ext_function: Optional[func.BatchedFunction] = None
79
+ current_batched_fn: Optional[CallableFunction] = None
80
80
  for e in all_exprs:
81
- if not self._is_ext_call(e):
81
+ if not self._is_batched_fn_call(e):
82
82
  continue
83
- if current_ext_function is None or current_ext_function != e.fn:
83
+ if current_batched_fn is None or current_batched_fn != e.fn:
84
84
  # create a new cohort
85
85
  cohorts.append([])
86
- current_ext_function = e.fn
86
+ current_batched_fn = e.fn
87
87
  cohorts[-1].append(e)
88
88
 
89
89
  # expand the cohorts to include all exprs that are in the same evaluation context as the external calls;
@@ -115,18 +115,18 @@ class ExprEvalNode(ExecNode):
115
115
  assert len(cohort) > 0
116
116
  # create the first segment here, so we can avoid checking for an empty list in the loop
117
117
  segments = [[cohort[0]]]
118
- is_ext_segment = self._is_ext_call(cohort[0])
119
- ext_fn: Optional[func.BatchedFunction] = self._get_batched_fn(cohort[0])
118
+ is_batched_segment = self._is_batched_fn_call(cohort[0])
119
+ batched_fn: Optional[CallableFunction] = self._get_batched_fn(cohort[0])
120
120
  for e in cohort[1:]:
121
- if self._is_ext_call(e):
121
+ if self._is_batched_fn_call(e):
122
122
  segments.append([e])
123
- is_ext_segment = True
124
- ext_fn = self._get_batched_fn(e)
123
+ is_batched_segment = True
124
+ batched_fn = self._get_batched_fn(e)
125
125
  else:
126
- if is_ext_segment:
126
+ if is_batched_segment:
127
127
  # start a new segment
128
128
  segments.append([])
129
- is_ext_segment = False
129
+ is_batched_segment = False
130
130
  segments[-1].append(e)
131
131
 
132
132
  # we create the EvalCtxs manually because create_eval_ctx() would repeat the dependencies of each segment
@@ -135,21 +135,21 @@ class ExprEvalNode(ExecNode):
135
135
  slot_idxs=[e.slot_idx for e in s], exprs=s, target_slot_idxs=[], target_exprs=[])
136
136
  for s in segments
137
137
  ]
138
- cohort_info = self.Cohort(cohort, ext_fn, segment_ctxs, target_slot_idxs[i])
138
+ cohort_info = self.Cohort(cohort, batched_fn, segment_ctxs, target_slot_idxs[i])
139
139
  self.cohorts.append(cohort_info)
140
140
 
141
141
  def _exec_cohort(self, cohort: Cohort, rows: DataRowBatch) -> None:
142
142
  """Compute the cohort for the entire input batch by dividing it up into sub-batches"""
143
143
  batch_start_idx = 0 # start row of the current sub-batch
144
144
  # for multi-resolution models, we re-assess the correct ext fn batch size for each input batch
145
- ext_batch_size = cohort.ext_function.get_batch_size() if cohort.ext_function is not None else None
145
+ ext_batch_size = cohort.batched_fn.get_batch_size() if cohort.batched_fn is not None else None
146
146
  if ext_batch_size is not None:
147
147
  cohort.batch_size = ext_batch_size
148
148
 
149
149
  while batch_start_idx < len(rows):
150
150
  num_batch_rows = min(cohort.batch_size, len(rows) - batch_start_idx)
151
151
  for segment_ctx in cohort.segment_ctxs:
152
- if not self._is_ext_call(segment_ctx.exprs[0]):
152
+ if not self._is_batched_fn_call(segment_ctx.exprs[0]):
153
153
  # compute batch row-wise
154
154
  for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
155
155
  self.row_builder.eval(
@@ -193,7 +193,7 @@ class ExprEvalNode(ExecNode):
193
193
  for k in kwarg_batches.keys()
194
194
  }
195
195
  start_ts = time.perf_counter()
196
- result_batch = fn_call.fn.invoke(call_args, call_kwargs)
196
+ result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
197
197
  self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
198
198
  self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
199
199
 
pixeltable/exec/in_memory_data_node.py CHANGED
@@ -29,18 +29,21 @@ class InMemoryDataNode(ExecNode):
29
29
 
30
30
  def _open(self) -> None:
31
31
  """Create row batch and populate with self.input_rows"""
32
- column_info = {info.col.name: info for info in self.row_builder.output_slot_idxs()}
32
+ column_info = {info.col.id: info for info in self.row_builder.output_slot_idxs()}
33
+ # exclude system columns
34
+ user_column_info = {info.col.name: info for _, info in column_info.items() if info.col.name is not None}
33
35
  # stored columns that are not computed
34
- inserted_column_names = set([
35
- info.col.name for info in self.row_builder.output_slot_idxs()
36
+ inserted_col_ids = set([
37
+ info.col.id for info in self.row_builder.output_slot_idxs()
36
38
  if info.col.is_stored and not info.col.is_computed
37
39
  ])
38
40
 
39
41
  self.output_rows = DataRowBatch(self.tbl, self.row_builder, len(self.input_rows))
40
42
  for row_idx, input_row in enumerate(self.input_rows):
41
43
  # populate the output row with the values provided in the input row
44
+ input_col_ids: List[int] = []
42
45
  for col_name, val in input_row.items():
43
- col_info = column_info.get(col_name)
46
+ col_info = user_column_info.get(col_name)
44
47
  assert col_info is not None
45
48
 
46
49
  if col_info.col.col_type.is_image_type() and isinstance(val, bytes):
@@ -49,11 +52,12 @@ class InMemoryDataNode(ExecNode):
49
52
  open(path, 'wb').write(val)
50
53
  val = path
51
54
  self.output_rows[row_idx][col_info.slot_idx] = val
55
+ input_col_ids.append(col_info.col.id)
52
56
 
53
57
  # set the remaining stored non-computed columns to null
54
- null_col_names = inserted_column_names - set(input_row.keys())
55
- for col_name in null_col_names:
56
- col_info = column_info.get(col_name)
58
+ null_col_ids = inserted_col_ids - set(input_col_ids)
59
+ for col_id in null_col_ids:
60
+ col_info = column_info.get(col_id)
57
61
  assert col_info is not None
58
62
  self.output_rows[row_idx][col_info.slot_idx] = None
59
63
 
pixeltable/exec/sql_scan_node.py CHANGED
@@ -21,7 +21,6 @@ class SqlScanNode(ExecNode):
21
21
  select_list: Iterable[exprs.Expr],
22
22
  where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Predicate] = None,
23
23
  order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
24
- similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None,
25
24
  limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
26
25
  ):
27
26
  """
@@ -77,15 +76,17 @@ class SqlScanNode(ExecNode):
77
76
  # the number of tables that need to be joined to the target table
78
77
  for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
79
78
  rowid_ref.set_tbl(tbl)
80
- order_by_clause = [e.sql_expr().desc() if not asc else e.sql_expr() for e, asc in order_by_items]
79
+ order_by_clause: List[sql.ClauseElement] = []
80
+ for e, asc in order_by_items:
81
+ if isinstance(e, exprs.SimilarityExpr):
82
+ order_by_clause.append(e.as_order_by_clause(asc))
83
+ else:
84
+ order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
81
85
 
82
86
  if where_clause is not None:
83
87
  sql_where_clause = where_clause.sql_expr()
84
88
  assert sql_where_clause is not None
85
89
  self.stmt = self.stmt.where(sql_where_clause)
86
- if similarity_clause is not None:
87
- self.stmt = self.stmt.order_by(
88
- similarity_clause.img_col_ref.col.sa_idx_col.l2_distance(similarity_clause.embedding()))
89
90
  if len(order_by_clause) > 0:
90
91
  self.stmt = self.stmt.order_by(*order_by_clause)
91
92
  elif target.id in row_builder.unstored_iter_args:
@@ -201,7 +202,7 @@ class SqlScanNode(ExecNode):
201
202
  self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
202
203
  if output_row[self.filter.slot_idx]:
203
204
  needs_row = True
204
- if self.limit is not None and len(output_batch) >= self.limit:
205
+ if self.limit > 0 and len(output_batch) >= self.limit:
205
206
  self.has_more_rows = False
206
207
  break
207
208
  else:
pixeltable/exprs/__init__.py CHANGED
@@ -6,9 +6,10 @@ from .comparison import Comparison
6
6
  from .compound_predicate import CompoundPredicate
7
7
  from .data_row import DataRow
8
8
  from .expr import Expr
9
+ from .expr_set import ExprSet
9
10
  from .function_call import FunctionCall
10
11
  from .image_member_access import ImageMemberAccess
11
- from .image_similarity_predicate import ImageSimilarityPredicate
12
+ from .in_predicate import InPredicate
12
13
  from .inline_array import InlineArray
13
14
  from .inline_dict import InlineDict
14
15
  from .is_null import IsNull
@@ -16,9 +17,9 @@ from .json_mapper import JsonMapper
16
17
  from .json_path import RELATIVE_PATH_ROOT, JsonPath
17
18
  from .literal import Literal
18
19
  from .object_ref import ObjectRef
19
- from .variable import Variable
20
20
  from .predicate import Predicate
21
21
  from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
22
22
  from .rowid_ref import RowidRef
23
- from .expr_set import ExprSet
23
+ from .similarity_expr import SimilarityExpr
24
24
  from .type_cast import TypeCast
25
+ from .variable import Variable
pixeltable/exprs/column_ref.py CHANGED
@@ -63,6 +63,15 @@ class ColumnRef(Expr):
63
63
 
64
64
  return super().__getattr__(name)
65
65
 
66
+ def similarity(self, other: Any) -> Expr:
67
+ if isinstance(other, Expr):
68
+ raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
69
+ item = Expr.from_object(other)
70
+ if item is None or not(item.col_type.is_string_type() or item.col_type.is_image_type()):
71
+ raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not a {type(other)}')
72
+ from .similarity_expr import SimilarityExpr
73
+ return SimilarityExpr(self, item)
74
+
66
75
  def default_column_name(self) -> Optional[str]:
67
76
  return str(self)
68
77
 
pixeltable/exprs/comparison.py CHANGED
@@ -1,14 +1,14 @@
1
1
  from __future__ import annotations
2
+
2
3
  from typing import Optional, List, Any, Dict, Tuple
3
4
 
4
5
  import sqlalchemy as sql
5
6
 
6
- from .globals import ComparisonOperator
7
+ from .data_row import DataRow
7
8
  from .expr import Expr
9
+ from .globals import ComparisonOperator
8
10
  from .predicate import Predicate
9
- from .data_row import DataRow
10
11
  from .row_builder import RowBuilder
11
- import pixeltable.catalog as catalog
12
12
 
13
13
 
14
14
  class Comparison(Predicate):
pixeltable/exprs/data_row.py CHANGED
@@ -5,6 +5,8 @@ import urllib.parse
5
5
  import urllib.request
6
6
  from typing import Optional, List, Any, Tuple
7
7
 
8
+ import sqlalchemy as sql
9
+ import pgvector.sqlalchemy
8
10
  import PIL
9
11
  import numpy as np
10
12
 
@@ -110,7 +112,7 @@ class DataRow:
110
112
 
111
113
  return self.vals[index]
112
114
 
113
- def get_stored_val(self, index: object) -> Any:
115
+ def get_stored_val(self, index: object, sa_col_type: Optional[sql.types.TypeEngine] = None) -> Any:
114
116
  """Return the value that gets stored in the db"""
115
117
  assert self.excs[index] is None
116
118
  if not self.has_val[index]:
@@ -125,6 +127,8 @@ class DataRow:
125
127
  if self.vals[index] is not None and index in self.array_slot_idxs:
126
128
  assert isinstance(self.vals[index], np.ndarray)
127
129
  np_array = self.vals[index]
130
+ if sa_col_type is not None and isinstance(sa_col_type, pgvector.sqlalchemy.Vector):
131
+ return np_array
128
132
  buffer = io.BytesIO()
129
133
  np.save(buffer, np_array)
130
134
  return buffer.getvalue()
pixeltable/exprs/expr.py CHANGED
@@ -60,9 +60,9 @@ class Expr(abc.ABC):
60
60
 
61
61
  # index of the expr's value in the data row:
62
62
  # - set for all materialized exprs
63
- # - -1: not executable
63
+ # - None: not executable
64
64
  # - not set for subexprs that don't need to be materialized because the parent can be materialized via SQL
65
- self.slot_idx = -1
65
+ self.slot_idx: Optional[int] = None
66
66
  self.components: List[Expr] = [] # the subexprs that are needed to construct this expr
67
67
 
68
68
  def dependencies(self) -> List[Expr]:
@@ -110,6 +110,11 @@ class Expr(abc.ABC):
110
110
  return False
111
111
  return self._equals(other)
112
112
 
113
+ def _equals(self, other: Expr) -> bool:
114
+ # we already compared the type and components in equals(); subclasses that require additional comparisons
115
+ # override this
116
+ return True
117
+
113
118
  def _id_attrs(self) -> List[Tuple[str, Any]]:
114
119
  """Returns attribute name/value pairs that are used to construct the instance id.
115
120
 
@@ -148,7 +153,7 @@ class Expr(abc.ABC):
148
153
  cls = self.__class__
149
154
  result = cls.__new__(cls)
150
155
  result.__dict__.update(self.__dict__)
151
- result.slot_idx = -1
156
+ result.slot_idx = None
152
157
  result.components = [c.copy() for c in self.components]
153
158
  return result
154
159
 
@@ -313,10 +318,6 @@ class Expr(abc.ABC):
313
318
  return InlineArray(tuple(o))
314
319
  return None
315
320
 
316
- @abc.abstractmethod
317
- def _equals(self, other: Expr) -> bool:
318
- pass
319
-
320
321
  @abc.abstractmethod
321
322
  def sql_expr(self) -> Optional[sql.ClauseElement]:
322
323
  """
@@ -396,6 +397,13 @@ class Expr(abc.ABC):
396
397
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
397
398
  assert False, 'not implemented'
398
399
 
400
+ def isin(self, value_set: Any) -> 'pixeltable.exprs.InPredicate':
401
+ from .in_predicate import InPredicate
402
+ if isinstance(value_set, Expr):
403
+ return InPredicate(self, value_set_expr=value_set)
404
+ else:
405
+ return InPredicate(self, value_set_literal=value_set)
406
+
399
407
  def astype(self, new_type: ts.ColumnType) -> 'pixeltable.exprs.TypeCast':
400
408
  from pixeltable.exprs import TypeCast
401
409
  return TypeCast(self, new_type)