pixeltable 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (87) hide show
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +9 -5
  4. pixeltable/catalog/insertable_table.py +0 -2
  5. pixeltable/catalog/table.py +16 -8
  6. pixeltable/catalog/table_version.py +3 -2
  7. pixeltable/dataframe.py +184 -110
  8. pixeltable/env.py +69 -18
  9. pixeltable/exec/__init__.py +2 -1
  10. pixeltable/exec/data_row_batch.py +6 -7
  11. pixeltable/exec/expr_eval_node.py +28 -28
  12. pixeltable/exec/sql_scan_node.py +7 -6
  13. pixeltable/exprs/__init__.py +4 -3
  14. pixeltable/exprs/column_ref.py +9 -0
  15. pixeltable/exprs/expr.py +15 -7
  16. pixeltable/exprs/function_call.py +17 -15
  17. pixeltable/exprs/image_member_access.py +9 -28
  18. pixeltable/exprs/in_predicate.py +96 -0
  19. pixeltable/exprs/inline_array.py +13 -11
  20. pixeltable/exprs/inline_dict.py +15 -13
  21. pixeltable/exprs/row_builder.py +7 -1
  22. pixeltable/exprs/similarity_expr.py +65 -0
  23. pixeltable/func/__init__.py +0 -2
  24. pixeltable/func/aggregate_function.py +3 -0
  25. pixeltable/func/callable_function.py +57 -13
  26. pixeltable/func/expr_template_function.py +11 -2
  27. pixeltable/func/function.py +35 -4
  28. pixeltable/func/signature.py +5 -15
  29. pixeltable/func/udf.py +6 -10
  30. pixeltable/functions/huggingface.py +23 -4
  31. pixeltable/functions/openai.py +34 -1
  32. pixeltable/functions/pil/image.py +61 -64
  33. pixeltable/functions/together.py +21 -0
  34. pixeltable/globals.py +425 -0
  35. pixeltable/index/base.py +3 -1
  36. pixeltable/index/embedding_index.py +87 -14
  37. pixeltable/io/__init__.py +3 -0
  38. pixeltable/{utils → io}/hf_datasets.py +48 -17
  39. pixeltable/io/pandas.py +148 -0
  40. pixeltable/{utils → io}/parquet.py +58 -33
  41. pixeltable/iterators/__init__.py +1 -1
  42. pixeltable/iterators/base.py +4 -0
  43. pixeltable/iterators/document.py +218 -97
  44. pixeltable/iterators/video.py +8 -9
  45. pixeltable/metadata/__init__.py +7 -3
  46. pixeltable/metadata/converters/convert_12.py +3 -0
  47. pixeltable/metadata/converters/convert_13.py +41 -0
  48. pixeltable/plan.py +2 -19
  49. pixeltable/store.py +2 -2
  50. pixeltable/tool/create_test_db_dump.py +32 -13
  51. pixeltable/type_system.py +13 -54
  52. pixeltable/utils/documents.py +42 -12
  53. pixeltable/utils/http_server.py +70 -0
  54. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/METADATA +10 -7
  55. pixeltable-0.2.6.dist-info/RECORD +119 -0
  56. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  57. pixeltable/client.py +0 -600
  58. pixeltable/exprs/image_similarity_predicate.py +0 -58
  59. pixeltable/func/batched_function.py +0 -53
  60. pixeltable/tests/conftest.py +0 -171
  61. pixeltable/tests/ext/test_yolox.py +0 -21
  62. pixeltable/tests/functions/test_fireworks.py +0 -43
  63. pixeltable/tests/functions/test_functions.py +0 -60
  64. pixeltable/tests/functions/test_huggingface.py +0 -158
  65. pixeltable/tests/functions/test_openai.py +0 -162
  66. pixeltable/tests/functions/test_together.py +0 -112
  67. pixeltable/tests/test_audio.py +0 -65
  68. pixeltable/tests/test_catalog.py +0 -27
  69. pixeltable/tests/test_client.py +0 -21
  70. pixeltable/tests/test_component_view.py +0 -379
  71. pixeltable/tests/test_dataframe.py +0 -440
  72. pixeltable/tests/test_dirs.py +0 -107
  73. pixeltable/tests/test_document.py +0 -120
  74. pixeltable/tests/test_exprs.py +0 -802
  75. pixeltable/tests/test_function.py +0 -332
  76. pixeltable/tests/test_index.py +0 -138
  77. pixeltable/tests/test_migration.py +0 -44
  78. pixeltable/tests/test_nos.py +0 -54
  79. pixeltable/tests/test_snapshot.py +0 -231
  80. pixeltable/tests/test_table.py +0 -1343
  81. pixeltable/tests/test_transactional_directory.py +0 -42
  82. pixeltable/tests/test_types.py +0 -52
  83. pixeltable/tests/test_video.py +0 -159
  84. pixeltable/tests/test_view.py +0 -535
  85. pixeltable/tests/utils.py +0 -442
  86. pixeltable-0.2.5.dist-info/RECORD +0 -139
  87. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
pixeltable/env.py CHANGED
@@ -23,19 +23,24 @@ from tqdm import TqdmWarning
23
23
 
24
24
  import pixeltable.exceptions as excs
25
25
  from pixeltable import metadata
26
+ from pixeltable.utils.http_server import make_server
26
27
 
27
28
 
28
29
  class Env:
29
30
  """
30
31
  Store for runtime globals.
31
32
  """
33
+
32
34
  _instance: Optional[Env] = None
33
35
  _log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
34
36
 
35
37
  @classmethod
36
38
  def get(cls) -> Env:
37
39
  if cls._instance is None:
38
- cls._instance = Env()
40
+ env = Env()
41
+ env._set_up()
42
+ env._upgrade_metadata()
43
+ cls._instance = env
39
44
  return cls._instance
40
45
 
41
46
  def __init__(self):
@@ -46,7 +51,7 @@ class Env:
46
51
  self._log_dir: Optional[Path] = None # log files
47
52
  self._tmp_dir: Optional[Path] = None # any tmp files
48
53
  self._sa_engine: Optional[sql.engine.base.Engine] = None
49
- self._pgdata_dir : Optional[Path] = None
54
+ self._pgdata_dir: Optional[Path] = None
50
55
  self._db_name: Optional[str] = None
51
56
  self._db_server: Optional[pgserver.PostgresServer] = None
52
57
  self._db_url: Optional[str] = None
@@ -56,7 +61,7 @@ class Env:
56
61
  self._installed_packages: Dict[str, Optional[List[int]]] = {}
57
62
  self._nos_client: Optional[Any] = None
58
63
  self._spacy_nlp: Optional[Any] = None # spacy.Language
59
- self._httpd: Optional[socketserver.TCPServer] = None
64
+ self._httpd: Optional[http.server.ThreadingHTTPServer] = None
60
65
  self._http_address: Optional[str] = None
61
66
 
62
67
  self._registered_clients: dict[str, Any] = {}
@@ -94,13 +99,43 @@ class Env:
94
99
  assert self._http_address is not None
95
100
  return self._http_address
96
101
 
102
+ def configure_logging(
103
+ self,
104
+ *,
105
+ to_stdout: Optional[bool] = None,
106
+ level: Optional[int] = None,
107
+ add: Optional[str] = None,
108
+ remove: Optional[str] = None,
109
+ ) -> None:
110
+ """Configure logging.
111
+
112
+ Args:
113
+ to_stdout: if True, also log to stdout
114
+ level: default log level
115
+ add: comma-separated list of 'module name:log level' pairs; ex.: add='video:10'
116
+ remove: comma-separated list of module names
117
+ """
118
+ if to_stdout is not None:
119
+ self.log_to_stdout(to_stdout)
120
+ if level is not None:
121
+ self.set_log_level(level)
122
+ if add is not None:
123
+ for module, level in [t.split(':') for t in add.split(',')]:
124
+ self.set_module_log_level(module, int(level))
125
+ if remove is not None:
126
+ for module in remove.split(','):
127
+ self.set_module_log_level(module, None)
128
+ if to_stdout is None and level is None and add is None and remove is None:
129
+ self.print_log_config()
130
+
97
131
  def print_log_config(self) -> None:
98
132
  print(f'logging to {self._logfilename}')
99
133
  print(f'{"" if self._log_to_stdout else "not "}logging to stdout')
100
134
  print(f'default log level: {logging.getLevelName(self._default_log_level)}')
101
135
  print(
102
136
  f'module log levels: '
103
- f'{",".join([name + ":" + logging.getLevelName(val) for name, val in self._module_log_level.items()])}')
137
+ f'{",".join([name + ":" + logging.getLevelName(val) for name, val in self._module_log_level.items()])}'
138
+ )
104
139
 
105
140
  def log_to_stdout(self, enable: bool = True) -> None:
106
141
  self._log_to_stdout = enable
@@ -135,10 +170,14 @@ class Env:
135
170
  else:
136
171
  return False
137
172
 
138
- def set_up(self, echo: bool = False, reinit_db: bool = False) -> None:
173
+ def _set_up(self, echo: bool = False, reinit_db: bool = False) -> None:
139
174
  if self._initialized:
140
175
  return
141
176
 
177
+ # Disable spurious warnings
178
+ warnings.simplefilter('ignore', category=TqdmWarning)
179
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
180
+
142
181
  self._initialized = True
143
182
  home = Path(os.environ.get('PIXELTABLE_HOME', str(Path.home() / '.pixeltable')))
144
183
  assert self._home is None or self._home == home
@@ -204,6 +243,14 @@ class Env:
204
243
  av_logger.addHandler(av_fh)
205
244
  av_logger.propagate = False
206
245
 
246
+ # configure web-server logging
247
+ http_logfilename = self._logfilename.replace('.log', '_http.log')
248
+ http_fh = logging.FileHandler(self._log_dir / http_logfilename, mode='w')
249
+ http_fh.setFormatter(logging.Formatter(self._log_fmt_str))
250
+ http_logger = logging.getLogger('pixeltable.http.server')
251
+ http_logger.addHandler(http_fh)
252
+ http_logger.propagate = False
253
+
207
254
  # empty tmp dir
208
255
  for path in glob.glob(f'{self._tmp_dir}/*'):
209
256
  os.remove(path)
@@ -224,6 +271,7 @@ class Env:
224
271
  create_database(self.db_url)
225
272
  self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True)
226
273
  from pixeltable.metadata import schema
274
+
227
275
  schema.Base.metadata.create_all(self._sa_engine)
228
276
  metadata.create_system_info(self._sa_engine)
229
277
  # enable pgvector
@@ -240,14 +288,12 @@ class Env:
240
288
  self._set_up_runtime()
241
289
  self.log_to_stdout(False)
242
290
 
243
- # Disable spurious warnings
244
- warnings.simplefilter("ignore", category=TqdmWarning)
245
-
246
- def upgrade_metadata(self) -> None:
291
+ def _upgrade_metadata(self) -> None:
247
292
  metadata.upgrade_md(self._sa_engine)
248
293
 
249
294
  def _create_nos_client(self) -> None:
250
295
  import nos
296
+
251
297
  self._logger.info('connecting to NOS')
252
298
  nos.init(logging_level=logging.DEBUG)
253
299
  self._nos_client = nos.client.InferenceClient()
@@ -256,6 +302,7 @@ class Env:
256
302
 
257
303
  # now that we have a client, we can create the module
258
304
  import importlib
305
+
259
306
  try:
260
307
  importlib.import_module('pixeltable.functions.nos')
261
308
  # it's already been created
@@ -263,6 +310,7 @@ class Env:
263
310
  except ImportError:
264
311
  pass
265
312
  from pixeltable.functions.util import create_nos_modules
313
+
266
314
  _ = create_nos_modules()
267
315
 
268
316
  def get_client(self, name: str, init: Callable, environ: Optional[str] = None) -> Any:
@@ -296,16 +344,13 @@ class Env:
296
344
  """
297
345
  The http server root is the file system root.
298
346
  eg: /home/media/foo.mp4 is located at http://127.0.0.1:{port}/home/media/foo.mp4
347
+ in windows, the server will translate paths like http://127.0.0.1:{port}/c:/media/foo.mp4
299
348
  This arrangement enables serving media hosted within _home,
300
349
  as well as external media inserted into pixeltable or produced by pixeltable.
301
350
  The port is chosen dynamically to prevent conflicts.
302
351
  """
303
352
  # Port 0 means OS picks one for us.
304
- address = ("127.0.0.1", 0)
305
- class FixedRootHandler(http.server.SimpleHTTPRequestHandler):
306
- def __init__(self, *args, **kwargs):
307
- super().__init__(*args, directory='/', **kwargs)
308
- self._httpd = socketserver.TCPServer(address, FixedRootHandler)
353
+ self._httpd = make_server('127.0.0.1', 0)
309
354
  port = self._httpd.server_address[1]
310
355
  self._http_address = f'http://127.0.0.1:{port}'
311
356
 
@@ -336,10 +381,12 @@ class Env:
336
381
  check('sentence_transformers')
337
382
  check('yolox')
338
383
  check('boto3')
384
+ check('fitz') # pymupdf
339
385
  check('pyarrow')
340
386
  check('spacy') # TODO: deal with en-core-web-sm
341
387
  if self.is_installed_package('spacy'):
342
388
  import spacy
389
+
343
390
  self._spacy_nlp = spacy.load('en_core_web_sm')
344
391
  check('tiktoken')
345
392
  check('openai')
@@ -348,6 +395,7 @@ class Env:
348
395
  check('nos')
349
396
  if self.is_installed_package('nos'):
350
397
  self._create_nos_client()
398
+ check('openpyxl')
351
399
 
352
400
  def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
353
401
  assert package in self._installed_packages
@@ -365,9 +413,12 @@ class Env:
365
413
  if len(min_version) < len(installed_version):
366
414
  normalized_min_version = min_version + [0] * (len(installed_version) - len(min_version))
367
415
  if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
368
- raise excs.Error((
369
- f'The installed version of package {package} is {".".join([str[v] for v in installed_version])}, '
370
- f'but version >={".".join([str[v] for v in min_version])} is required'))
416
+ raise excs.Error(
417
+ (
418
+ f'The installed version of package {package} is {".".join([str[v] for v in installed_version])}, '
419
+ f'but version >={".".join([str[v] for v in min_version])} is required'
420
+ )
421
+ )
371
422
 
372
423
  def num_tmp_files(self) -> int:
373
424
  return len(glob.glob(f'{self._tmp_dir}/*'))
@@ -412,4 +463,4 @@ class Env:
412
463
  @property
413
464
  def spacy_nlp(self) -> Any:
414
465
  assert self._spacy_nlp is not None
415
- return self._spacy_nlp
466
+ return self._spacy_nlp
@@ -6,4 +6,5 @@ from .exec_node import ExecNode
6
6
  from .expr_eval_node import ExprEvalNode
7
7
  from .in_memory_data_node import InMemoryDataNode
8
8
  from .sql_scan_node import SqlScanNode
9
- from .media_validation_node import MediaValidationNode
9
+ from .media_validation_node import MediaValidationNode
10
+ from .data_row_batch import DataRowBatch
@@ -14,9 +14,8 @@ class DataRowBatch:
14
14
 
15
15
  Contains the metadata needed to initialize DataRows.
16
16
  """
17
- def __init__(self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, len: int = 0):
18
- self.tbl_id = tbl.id
19
- self.tbl_version = tbl.version
17
+ def __init__(self, tbl: Optional[catalog.TableVersion], row_builder: exprs.RowBuilder, len: int = 0):
18
+ self.tbl = tbl
20
19
  self.row_builder = row_builder
21
20
  self.img_slot_idxs = [e.slot_idx for e in row_builder.unique_exprs if e.col_type.is_image_type()]
22
21
  # non-image media slots
@@ -42,9 +41,10 @@ class DataRowBatch:
42
41
 
43
42
  def set_row_ids(self, row_ids: List[int]) -> None:
44
43
  """Sets pks for rows in batch"""
44
+ assert self.tbl is not None
45
45
  assert len(row_ids) == len(self.rows)
46
46
  for row, row_id in zip(self.rows, row_ids):
47
- row.set_pk((row_id, self.tbl_version))
47
+ row.set_pk((row_id, self.tbl))
48
48
 
49
49
  def __len__(self) -> int:
50
50
  return len(self.rows)
@@ -57,6 +57,7 @@ class DataRowBatch:
57
57
  flushed_slot_idxs: Optional[List[int]] = None
58
58
  ) -> None:
59
59
  """Flushes images in the given range of rows."""
60
+ assert self.tbl is not None
60
61
  if stored_img_info is None:
61
62
  stored_img_info = []
62
63
  if flushed_slot_idxs is None:
@@ -67,12 +68,10 @@ class DataRowBatch:
67
68
  idx_range = slice(0, len(self.rows))
68
69
  for row in self.rows[idx_range]:
69
70
  for info in stored_img_info:
70
- filepath = str(MediaStore.prepare_media_path(self.tbl_id, info.col.id, self.tbl_version))
71
+ filepath = str(MediaStore.prepare_media_path(self.tbl.id, info.col.id, self.tbl.version))
71
72
  row.flush_img(info.slot_idx, filepath)
72
73
  for slot_idx in flushed_slot_idxs:
73
74
  row.flush_img(slot_idx)
74
- #_logger.debug(
75
- #f'flushed images in range {idx_range}: slot_idxs={flushed_slot_idxs} stored_img_info={stored_img_info}')
76
75
 
77
76
  def __iter__(self) -> Iterator[exprs.DataRow]:
78
77
  return DataRowBatchIterator(self)
@@ -1,20 +1,20 @@
1
- import sys
2
- import warnings
3
- from typing import List, Optional, Tuple
4
- from dataclasses import dataclass, field
5
1
  import logging
2
+ import sys
6
3
  import time
4
+ import warnings
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
7
 
8
8
  from tqdm import tqdm, TqdmWarning
9
9
 
10
+ import pixeltable.exprs as exprs
11
+ from pixeltable.func import CallableFunction
10
12
  from .data_row_batch import DataRowBatch
11
13
  from .exec_node import ExecNode
12
- import pixeltable.exprs as exprs
13
- import pixeltable.func as func
14
-
15
14
 
16
15
  _logger = logging.getLogger('pixeltable')
17
16
 
17
+
18
18
  class ExprEvalNode(ExecNode):
19
19
  """Materializes expressions
20
20
  """
@@ -22,7 +22,7 @@ class ExprEvalNode(ExecNode):
22
22
  class Cohort:
23
23
  """List of exprs that form an evaluation context and contain calls to at most one external function"""
24
24
  exprs: List[exprs.Expr]
25
- ext_function: Optional[func.BatchedFunction]
25
+ batched_fn: Optional[CallableFunction]
26
26
  segment_ctxs: List[exprs.RowBuilder.EvalCtx]
27
27
  target_slot_idxs: List[int]
28
28
  batch_size: int = 8
@@ -63,12 +63,12 @@ class ExprEvalNode(ExecNode):
63
63
  if self.pbar is not None:
64
64
  self.pbar.close()
65
65
 
66
- def _get_batched_fn(self, expr: exprs.Expr) -> Optional[func.BatchedFunction]:
67
- if not isinstance(expr, exprs.FunctionCall):
68
- return None
69
- return expr.fn if isinstance(expr.fn, func.BatchedFunction) else None
66
+ def _get_batched_fn(self, expr: exprs.Expr) -> Optional[CallableFunction]:
67
+ if isinstance(expr, exprs.FunctionCall) and isinstance(expr.fn, CallableFunction) and expr.fn.is_batched:
68
+ return expr.fn
69
+ return None
70
70
 
71
- def _is_ext_call(self, expr: exprs.Expr) -> bool:
71
+ def _is_batched_fn_call(self, expr: exprs.Expr) -> bool:
72
72
  return self._get_batched_fn(expr) is not None
73
73
 
74
74
  def _create_cohorts(self) -> None:
@@ -76,14 +76,14 @@ class ExprEvalNode(ExecNode):
76
76
  # break up all_exprs into cohorts such that each cohort contains calls to at most one external function;
77
77
  # seed the cohorts with only the ext fn calls
78
78
  cohorts: List[List[exprs.Expr]] = []
79
- current_ext_function: Optional[func.BatchedFunction] = None
79
+ current_batched_fn: Optional[CallableFunction] = None
80
80
  for e in all_exprs:
81
- if not self._is_ext_call(e):
81
+ if not self._is_batched_fn_call(e):
82
82
  continue
83
- if current_ext_function is None or current_ext_function != e.fn:
83
+ if current_batched_fn is None or current_batched_fn != e.fn:
84
84
  # create a new cohort
85
85
  cohorts.append([])
86
- current_ext_function = e.fn
86
+ current_batched_fn = e.fn
87
87
  cohorts[-1].append(e)
88
88
 
89
89
  # expand the cohorts to include all exprs that are in the same evaluation context as the external calls;
@@ -115,18 +115,18 @@ class ExprEvalNode(ExecNode):
115
115
  assert len(cohort) > 0
116
116
  # create the first segment here, so we can avoid checking for an empty list in the loop
117
117
  segments = [[cohort[0]]]
118
- is_ext_segment = self._is_ext_call(cohort[0])
119
- ext_fn: Optional[func.BatchedFunction] = self._get_batched_fn(cohort[0])
118
+ is_batched_segment = self._is_batched_fn_call(cohort[0])
119
+ batched_fn: Optional[CallableFunction] = self._get_batched_fn(cohort[0])
120
120
  for e in cohort[1:]:
121
- if self._is_ext_call(e):
121
+ if self._is_batched_fn_call(e):
122
122
  segments.append([e])
123
- is_ext_segment = True
124
- ext_fn = self._get_batched_fn(e)
123
+ is_batched_segment = True
124
+ batched_fn = self._get_batched_fn(e)
125
125
  else:
126
- if is_ext_segment:
126
+ if is_batched_segment:
127
127
  # start a new segment
128
128
  segments.append([])
129
- is_ext_segment = False
129
+ is_batched_segment = False
130
130
  segments[-1].append(e)
131
131
 
132
132
  # we create the EvalCtxs manually because create_eval_ctx() would repeat the dependencies of each segment
@@ -135,21 +135,21 @@ class ExprEvalNode(ExecNode):
135
135
  slot_idxs=[e.slot_idx for e in s], exprs=s, target_slot_idxs=[], target_exprs=[])
136
136
  for s in segments
137
137
  ]
138
- cohort_info = self.Cohort(cohort, ext_fn, segment_ctxs, target_slot_idxs[i])
138
+ cohort_info = self.Cohort(cohort, batched_fn, segment_ctxs, target_slot_idxs[i])
139
139
  self.cohorts.append(cohort_info)
140
140
 
141
141
  def _exec_cohort(self, cohort: Cohort, rows: DataRowBatch) -> None:
142
142
  """Compute the cohort for the entire input batch by dividing it up into sub-batches"""
143
143
  batch_start_idx = 0 # start row of the current sub-batch
144
144
  # for multi-resolution models, we re-assess the correct ext fn batch size for each input batch
145
- ext_batch_size = cohort.ext_function.get_batch_size() if cohort.ext_function is not None else None
145
+ ext_batch_size = cohort.batched_fn.get_batch_size() if cohort.batched_fn is not None else None
146
146
  if ext_batch_size is not None:
147
147
  cohort.batch_size = ext_batch_size
148
148
 
149
149
  while batch_start_idx < len(rows):
150
150
  num_batch_rows = min(cohort.batch_size, len(rows) - batch_start_idx)
151
151
  for segment_ctx in cohort.segment_ctxs:
152
- if not self._is_ext_call(segment_ctx.exprs[0]):
152
+ if not self._is_batched_fn_call(segment_ctx.exprs[0]):
153
153
  # compute batch row-wise
154
154
  for row_idx in range(batch_start_idx, batch_start_idx + num_batch_rows):
155
155
  self.row_builder.eval(
@@ -193,7 +193,7 @@ class ExprEvalNode(ExecNode):
193
193
  for k in kwarg_batches.keys()
194
194
  }
195
195
  start_ts = time.perf_counter()
196
- result_batch = fn_call.fn.invoke(call_args, call_kwargs)
196
+ result_batch = fn_call.fn.exec_batch(*call_args, **call_kwargs)
197
197
  self.ctx.profile.eval_time[fn_call.slot_idx] += time.perf_counter() - start_ts
198
198
  self.ctx.profile.eval_count[fn_call.slot_idx] += num_ext_batch_rows
199
199
 
@@ -21,7 +21,6 @@ class SqlScanNode(ExecNode):
21
21
  select_list: Iterable[exprs.Expr],
22
22
  where_clause: Optional[exprs.Expr] = None, filter: Optional[exprs.Predicate] = None,
23
23
  order_by_items: Optional[List[Tuple[exprs.Expr, bool]]] = None,
24
- similarity_clause: Optional[exprs.ImageSimilarityPredicate] = None,
25
24
  limit: int = 0, set_pk: bool = False, exact_version_only: Optional[List[catalog.TableVersion]] = None
26
25
  ):
27
26
  """
@@ -77,15 +76,17 @@ class SqlScanNode(ExecNode):
77
76
  # the number of tables that need to be joined to the target table
78
77
  for rowid_ref in [e for e, _ in order_by_items if isinstance(e, exprs.RowidRef)]:
79
78
  rowid_ref.set_tbl(tbl)
80
- order_by_clause = [e.sql_expr().desc() if not asc else e.sql_expr() for e, asc in order_by_items]
79
+ order_by_clause: List[sql.ClauseElement] = []
80
+ for e, asc in order_by_items:
81
+ if isinstance(e, exprs.SimilarityExpr):
82
+ order_by_clause.append(e.as_order_by_clause(asc))
83
+ else:
84
+ order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
81
85
 
82
86
  if where_clause is not None:
83
87
  sql_where_clause = where_clause.sql_expr()
84
88
  assert sql_where_clause is not None
85
89
  self.stmt = self.stmt.where(sql_where_clause)
86
- if similarity_clause is not None:
87
- self.stmt = self.stmt.order_by(
88
- similarity_clause.img_col_ref.col.sa_idx_col.l2_distance(similarity_clause.embedding()))
89
90
  if len(order_by_clause) > 0:
90
91
  self.stmt = self.stmt.order_by(*order_by_clause)
91
92
  elif target.id in row_builder.unstored_iter_args:
@@ -201,7 +202,7 @@ class SqlScanNode(ExecNode):
201
202
  self.row_builder.eval(output_row, self.filter_eval_ctx, profile=self.ctx.profile)
202
203
  if output_row[self.filter.slot_idx]:
203
204
  needs_row = True
204
- if self.limit is not None and len(output_batch) >= self.limit:
205
+ if self.limit > 0 and len(output_batch) >= self.limit:
205
206
  self.has_more_rows = False
206
207
  break
207
208
  else:
@@ -6,9 +6,10 @@ from .comparison import Comparison
6
6
  from .compound_predicate import CompoundPredicate
7
7
  from .data_row import DataRow
8
8
  from .expr import Expr
9
+ from .expr_set import ExprSet
9
10
  from .function_call import FunctionCall
10
11
  from .image_member_access import ImageMemberAccess
11
- from .image_similarity_predicate import ImageSimilarityPredicate
12
+ from .in_predicate import InPredicate
12
13
  from .inline_array import InlineArray
13
14
  from .inline_dict import InlineDict
14
15
  from .is_null import IsNull
@@ -16,9 +17,9 @@ from .json_mapper import JsonMapper
16
17
  from .json_path import RELATIVE_PATH_ROOT, JsonPath
17
18
  from .literal import Literal
18
19
  from .object_ref import ObjectRef
19
- from .variable import Variable
20
20
  from .predicate import Predicate
21
21
  from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
22
22
  from .rowid_ref import RowidRef
23
- from .expr_set import ExprSet
23
+ from .similarity_expr import SimilarityExpr
24
24
  from .type_cast import TypeCast
25
+ from .variable import Variable
@@ -63,6 +63,15 @@ class ColumnRef(Expr):
63
63
 
64
64
  return super().__getattr__(name)
65
65
 
66
+ def similarity(self, other: Any) -> Expr:
67
+ if isinstance(other, Expr):
68
+ raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
69
+ item = Expr.from_object(other)
70
+ if item is None or not(item.col_type.is_string_type() or item.col_type.is_image_type()):
71
+ raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not a {type(other)}')
72
+ from .similarity_expr import SimilarityExpr
73
+ return SimilarityExpr(self, item)
74
+
66
75
  def default_column_name(self) -> Optional[str]:
67
76
  return str(self)
68
77
 
pixeltable/exprs/expr.py CHANGED
@@ -60,9 +60,9 @@ class Expr(abc.ABC):
60
60
 
61
61
  # index of the expr's value in the data row:
62
62
  # - set for all materialized exprs
63
- # - -1: not executable
63
+ # - None: not executable
64
64
  # - not set for subexprs that don't need to be materialized because the parent can be materialized via SQL
65
- self.slot_idx = -1
65
+ self.slot_idx: Optional[int] = None
66
66
  self.components: List[Expr] = [] # the subexprs that are needed to construct this expr
67
67
 
68
68
  def dependencies(self) -> List[Expr]:
@@ -110,6 +110,11 @@ class Expr(abc.ABC):
110
110
  return False
111
111
  return self._equals(other)
112
112
 
113
+ def _equals(self, other: Expr) -> bool:
114
+ # we already compared the type and components in equals(); subclasses that require additional comparisons
115
+ # override this
116
+ return True
117
+
113
118
  def _id_attrs(self) -> List[Tuple[str, Any]]:
114
119
  """Returns attribute name/value pairs that are used to construct the instance id.
115
120
 
@@ -148,7 +153,7 @@ class Expr(abc.ABC):
148
153
  cls = self.__class__
149
154
  result = cls.__new__(cls)
150
155
  result.__dict__.update(self.__dict__)
151
- result.slot_idx = -1
156
+ result.slot_idx = None
152
157
  result.components = [c.copy() for c in self.components]
153
158
  return result
154
159
 
@@ -313,10 +318,6 @@ class Expr(abc.ABC):
313
318
  return InlineArray(tuple(o))
314
319
  return None
315
320
 
316
- @abc.abstractmethod
317
- def _equals(self, other: Expr) -> bool:
318
- pass
319
-
320
321
  @abc.abstractmethod
321
322
  def sql_expr(self) -> Optional[sql.ClauseElement]:
322
323
  """
@@ -396,6 +397,13 @@ class Expr(abc.ABC):
396
397
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
397
398
  assert False, 'not implemented'
398
399
 
400
+ def isin(self, value_set: Any) -> 'pixeltable.exprs.InPredicate':
401
+ from .in_predicate import InPredicate
402
+ if isinstance(value_set, Expr):
403
+ return InPredicate(self, value_set_expr=value_set)
404
+ else:
405
+ return InPredicate(self, value_set_literal=value_set)
406
+
399
407
  def astype(self, new_type: ts.ColumnType) -> 'pixeltable.exprs.TypeCast':
400
408
  from pixeltable.exprs import TypeCast
401
409
  return TypeCast(self, new_type)
@@ -28,7 +28,7 @@ class FunctionCall(Expr):
28
28
  if group_by_clause is None:
29
29
  group_by_clause = []
30
30
  signature = fn.signature
31
- super().__init__(signature.get_return_type(bound_args))
31
+ super().__init__(fn.call_return_type(bound_args))
32
32
  self.fn = fn
33
33
  self.is_method_call = is_method_call
34
34
  self.check_args(signature, bound_args)
@@ -46,9 +46,9 @@ class FunctionCall(Expr):
46
46
 
47
47
  # Tuple[int, Any]:
48
48
  # - for Exprs: (index into components, None)
49
- # - otherwise: (-1, val)
50
- self.args: List[Tuple[int, Any]] = []
51
- self.kwargs: Dict[str, Tuple[int, Any]] = {}
49
+ # - otherwise: (None, val)
50
+ self.args: List[Tuple[Optional[int], Optional[Any]]] = []
51
+ self.kwargs: Dict[str, Tuple[Optional[int], Optional[Any]]] = {}
52
52
 
53
53
  # we record the types of non-variable parameters for runtime type checks
54
54
  self.arg_types: List[ts.ColumnType] = []
@@ -62,7 +62,7 @@ class FunctionCall(Expr):
62
62
  self.args.append((len(self.components), None))
63
63
  self.components.append(arg.copy())
64
64
  else:
65
- self.args.append((-1, arg))
65
+ self.args.append((None, arg))
66
66
  if param.kind != inspect.Parameter.VAR_POSITIONAL and param.kind != inspect.Parameter.VAR_KEYWORD:
67
67
  self.arg_types.append(signature.parameters[param.name].col_type)
68
68
 
@@ -74,7 +74,7 @@ class FunctionCall(Expr):
74
74
  self.kwargs[param_name] = (len(self.components), None)
75
75
  self.components.append(arg.copy())
76
76
  else:
77
- self.kwargs[param_name] = (-1, arg)
77
+ self.kwargs[param_name] = (None, arg)
78
78
  if fn.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
79
79
  self.kwarg_types[param_name] = signature.parameters[param_name].col_type
80
80
 
@@ -215,12 +215,12 @@ class FunctionCall(Expr):
215
215
 
216
216
  def _print_args(self, start_idx: int = 0, inline: bool = True) -> str:
217
217
  arg_strs = [
218
- str(arg) if idx == -1 else str(self.components[idx]) for idx, arg in self.args[start_idx:]
218
+ str(arg) if idx is None else str(self.components[idx]) for idx, arg in self.args[start_idx:]
219
219
  ]
220
220
  def print_arg(arg: Any) -> str:
221
221
  return f"'{arg}'" if isinstance(arg, str) else str(arg)
222
222
  arg_strs.extend([
223
- f'{param_name}={print_arg(arg) if idx == -1 else str(self.components[idx])}'
223
+ f'{param_name}={print_arg(arg) if idx is None else str(self.components[idx])}'
224
224
  for param_name, (idx, arg) in self.kwargs.items()
225
225
  ])
226
226
  if len(self.order_by) > 0:
@@ -287,7 +287,7 @@ class FunctionCall(Expr):
287
287
  """Return args and kwargs, constructed for data_row"""
288
288
  kwargs: Dict[str, Any] = {}
289
289
  for param_name, (component_idx, arg) in self.kwargs.items():
290
- val = arg if component_idx == -1 else data_row[self.components[component_idx].slot_idx]
290
+ val = arg if component_idx is None else data_row[self.components[component_idx].slot_idx]
291
291
  param = self.fn.signature.parameters[param_name]
292
292
  if param.kind == inspect.Parameter.VAR_KEYWORD:
293
293
  # expand **kwargs parameter
@@ -298,7 +298,7 @@ class FunctionCall(Expr):
298
298
 
299
299
  args: List[Any] = []
300
300
  for param_idx, (component_idx, arg) in enumerate(self.args):
301
- val = arg if component_idx == -1 else data_row[self.components[component_idx].slot_idx]
301
+ val = arg if component_idx is None else data_row[self.components[component_idx].slot_idx]
302
302
  param = self.fn.signature.parameters_by_pos[param_idx]
303
303
  if param.kind == inspect.Parameter.VAR_POSITIONAL:
304
304
  # expand *args parameter
@@ -333,7 +333,8 @@ class FunctionCall(Expr):
333
333
  # TODO: can we get rid of this extra copy?
334
334
  fn_expr = self.components[self.fn_expr_idx]
335
335
  data_row[self.slot_idx] = data_row[fn_expr.slot_idx]
336
- elif isinstance(self.fn, func.CallableFunction):
336
+ elif isinstance(self.fn, func.CallableFunction) and not self.fn.is_batched:
337
+ # optimization: avoid additional level of indirection we'd get from calling Function.exec()
337
338
  data_row[self.slot_idx] = self.fn.py_fn(*args, **kwargs)
338
339
  elif self.is_window_fn_call:
339
340
  if self.has_group_by():
@@ -348,9 +349,10 @@ class FunctionCall(Expr):
348
349
  self.aggregator = self.fn.agg_cls(**self.agg_init_args)
349
350
  self.aggregator.update(*args)
350
351
  data_row[self.slot_idx] = self.aggregator.value()
351
- else:
352
- assert self.is_agg_fn_call
352
+ elif self.is_agg_fn_call:
353
353
  data_row[self.slot_idx] = self.aggregator.value()
354
+ else:
355
+ data_row[self.slot_idx] = self.fn.exec(*args, **kwargs)
354
356
 
355
357
  def _as_dict(self) -> Dict:
356
358
  result = {
@@ -369,9 +371,9 @@ class FunctionCall(Expr):
369
371
  # reassemble bound args
370
372
  fn = func.Function.from_dict(d['fn'])
371
373
  param_names = list(fn.signature.parameters.keys())
372
- bound_args = {param_names[i]: arg if idx == -1 else components[idx] for i, (idx, arg) in enumerate(d['args'])}
374
+ bound_args = {param_names[i]: arg if idx is None else components[idx] for i, (idx, arg) in enumerate(d['args'])}
373
375
  bound_args.update(
374
- {param_name: val if idx == -1 else components[idx] for param_name, (idx, val) in d['kwargs'].items()})
376
+ {param_name: val if idx is None else components[idx] for param_name, (idx, val) in d['kwargs'].items()})
375
377
  group_by_exprs = components[d['group_by_start_idx']:d['group_by_stop_idx']]
376
378
  order_by_exprs = components[d['order_by_start_idx']:]
377
379
  fn_call = cls(