pixeltable 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (79)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/catalog.py +8 -7
  3. pixeltable/catalog/column.py +11 -8
  4. pixeltable/catalog/insertable_table.py +1 -1
  5. pixeltable/catalog/path_dict.py +8 -6
  6. pixeltable/catalog/table.py +20 -13
  7. pixeltable/catalog/table_version.py +91 -54
  8. pixeltable/catalog/table_version_path.py +7 -9
  9. pixeltable/catalog/view.py +2 -1
  10. pixeltable/dataframe.py +1 -1
  11. pixeltable/env.py +173 -83
  12. pixeltable/exec/aggregation_node.py +2 -1
  13. pixeltable/exec/component_iteration_node.py +1 -1
  14. pixeltable/exec/sql_node.py +11 -8
  15. pixeltable/exprs/__init__.py +1 -0
  16. pixeltable/exprs/arithmetic_expr.py +4 -4
  17. pixeltable/exprs/array_slice.py +2 -1
  18. pixeltable/exprs/column_property_ref.py +9 -7
  19. pixeltable/exprs/column_ref.py +2 -1
  20. pixeltable/exprs/comparison.py +10 -7
  21. pixeltable/exprs/compound_predicate.py +3 -2
  22. pixeltable/exprs/data_row.py +19 -4
  23. pixeltable/exprs/expr.py +46 -35
  24. pixeltable/exprs/expr_set.py +32 -9
  25. pixeltable/exprs/function_call.py +56 -32
  26. pixeltable/exprs/in_predicate.py +3 -2
  27. pixeltable/exprs/inline_array.py +2 -1
  28. pixeltable/exprs/inline_dict.py +2 -1
  29. pixeltable/exprs/is_null.py +3 -2
  30. pixeltable/exprs/json_mapper.py +5 -4
  31. pixeltable/exprs/json_path.py +7 -1
  32. pixeltable/exprs/literal.py +34 -7
  33. pixeltable/exprs/method_ref.py +3 -3
  34. pixeltable/exprs/object_ref.py +6 -5
  35. pixeltable/exprs/row_builder.py +25 -17
  36. pixeltable/exprs/rowid_ref.py +2 -1
  37. pixeltable/exprs/similarity_expr.py +2 -1
  38. pixeltable/exprs/sql_element_cache.py +30 -0
  39. pixeltable/exprs/type_cast.py +3 -3
  40. pixeltable/exprs/variable.py +2 -1
  41. pixeltable/ext/functions/whisperx.py +4 -4
  42. pixeltable/ext/functions/yolox.py +6 -6
  43. pixeltable/func/aggregate_function.py +1 -0
  44. pixeltable/func/function.py +28 -4
  45. pixeltable/functions/__init__.py +4 -2
  46. pixeltable/functions/anthropic.py +15 -5
  47. pixeltable/functions/fireworks.py +1 -1
  48. pixeltable/functions/globals.py +6 -1
  49. pixeltable/functions/huggingface.py +2 -2
  50. pixeltable/functions/image.py +17 -2
  51. pixeltable/functions/json.py +5 -5
  52. pixeltable/functions/mistralai.py +188 -0
  53. pixeltable/functions/openai.py +6 -10
  54. pixeltable/functions/string.py +3 -2
  55. pixeltable/functions/timestamp.py +95 -7
  56. pixeltable/functions/together.py +4 -4
  57. pixeltable/functions/video.py +2 -2
  58. pixeltable/functions/vision.py +27 -17
  59. pixeltable/functions/whisper.py +1 -1
  60. pixeltable/io/hf_datasets.py +17 -15
  61. pixeltable/io/pandas.py +0 -2
  62. pixeltable/io/parquet.py +15 -14
  63. pixeltable/iterators/document.py +16 -15
  64. pixeltable/metadata/__init__.py +1 -1
  65. pixeltable/metadata/converters/convert_19.py +46 -0
  66. pixeltable/metadata/notes.py +1 -0
  67. pixeltable/metadata/schema.py +5 -4
  68. pixeltable/plan.py +100 -78
  69. pixeltable/store.py +5 -1
  70. pixeltable/tool/create_test_db_dump.py +4 -3
  71. pixeltable/type_system.py +12 -14
  72. pixeltable/utils/documents.py +45 -42
  73. pixeltable/utils/formatter.py +2 -2
  74. {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/METADATA +79 -21
  75. pixeltable-0.2.18.dist-info/RECORD +147 -0
  76. pixeltable-0.2.17.dist-info/RECORD +0 -144
  77. {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/LICENSE +0 -0
  78. {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/WHEEL +0 -0
  79. {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/entry_points.txt +0 -0
pixeltable/env.py CHANGED
@@ -14,7 +14,8 @@ import uuid
14
14
  import warnings
15
15
  from dataclasses import dataclass
16
16
  from pathlib import Path
17
- from typing import Callable, Optional, Dict, Any, List, TYPE_CHECKING
17
+ from typing import TYPE_CHECKING, Any, Callable, Optional
18
+ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
18
19
 
19
20
  import pixeltable_pgserver
20
21
  import sqlalchemy as sql
@@ -37,6 +38,35 @@ class Env:
37
38
  _instance: Optional[Env] = None
38
39
  _log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
39
40
 
41
+ _home: Optional[Path]
42
+ _media_dir: Optional[Path]
43
+ _file_cache_dir: Optional[Path] # cached media files with external URL
44
+ _dataset_cache_dir: Optional[Path] # cached datasets (eg, pytorch or COCO)
45
+ _log_dir: Optional[Path] # log files
46
+ _tmp_dir: Optional[Path] # any tmp files
47
+ _sa_engine: Optional[sql.engine.base.Engine]
48
+ _pgdata_dir: Optional[Path]
49
+ _db_name: Optional[str]
50
+ _db_server: Optional[pixeltable_pgserver.PostgresServer]
51
+ _db_url: Optional[str]
52
+ _default_time_zone: Optional[ZoneInfo]
53
+
54
+ # info about optional packages that are utilized by some parts of the code
55
+ __optional_packages: dict[str, PackageInfo]
56
+
57
+ _spacy_nlp: Optional[spacy.Language]
58
+ _httpd: Optional[http.server.HTTPServer]
59
+ _http_address: Optional[str]
60
+ _logger: logging.Logger
61
+ _default_log_level: int
62
+ _logfilename: Optional[str]
63
+ _log_to_stdout: bool
64
+ _module_log_level: dict[str, int] # module name -> log level
65
+ _config_file: Optional[Path]
66
+ _config: Optional[dict[str, Any]]
67
+ _stdout_handler: logging.StreamHandler
68
+ _initialized: bool
69
+
40
70
  @classmethod
41
71
  def get(cls) -> Env:
42
72
  if cls._instance is None:
@@ -51,24 +81,23 @@ class Env:
51
81
  cls._instance = env
52
82
 
53
83
  def __init__(self):
54
- self._home: Optional[Path] = None
55
- self._media_dir: Optional[Path] = None # computed media files
56
- self._file_cache_dir: Optional[Path] = None # cached media files with external URL
57
- self._dataset_cache_dir: Optional[Path] = None # cached datasets (eg, pytorch or COCO)
58
- self._log_dir: Optional[Path] = None # log files
59
- self._tmp_dir: Optional[Path] = None # any tmp files
60
- self._sa_engine: Optional[sql.engine.base.Engine] = None
61
- self._pgdata_dir: Optional[Path] = None
62
- self._db_name: Optional[str] = None
63
- self._db_server: Optional[pixeltable_pgserver.PostgresServer] = None
64
- self._db_url: Optional[str] = None
65
-
66
- # info about installed packages that are utilized by some parts of the code;
67
- # package name -> version; version == []: package is installed, but we haven't determined the version yet
68
- self._installed_packages: Dict[str, Optional[List[int]]] = {}
69
- self._spacy_nlp: Optional[spacy.Language] = None
70
- self._httpd: Optional[http.server.HTTPServer] = None
71
- self._http_address: Optional[str] = None
84
+ self._home = None
85
+ self._media_dir = None # computed media files
86
+ self._file_cache_dir = None # cached media files with external URL
87
+ self._dataset_cache_dir = None # cached datasets (eg, pytorch or COCO)
88
+ self._log_dir = None # log files
89
+ self._tmp_dir = None # any tmp files
90
+ self._sa_engine = None
91
+ self._pgdata_dir = None
92
+ self._db_name = None
93
+ self._db_server = None
94
+ self._db_url = None
95
+ self._default_time_zone = None
96
+
97
+ self.__optional_packages = {}
98
+ self._spacy_nlp = None
99
+ self._httpd = None
100
+ self._http_address = None
72
101
 
73
102
  # logging-related state
74
103
  self._logger = logging.getLogger('pixeltable')
@@ -76,13 +105,12 @@ class Env:
76
105
  self._logger.propagate = False
77
106
  self._logger.addFilter(self._log_filter)
78
107
  self._default_log_level = logging.INFO
79
- self._logfilename: Optional[str] = None
108
+ self._logfilename = None
80
109
  self._log_to_stdout = False
81
- self._module_log_level: Dict[str, int] = {} # module name -> log level
110
+ self._module_log_level = {} # module name -> log level
82
111
 
83
- # config
84
- self._config_file: Optional[Path] = None
85
- self._config: Optional[Dict[str, Any]] = None
112
+ self._config_file = None
113
+ self._config = None
86
114
 
87
115
  # create logging handler to also log to stdout
88
116
  self._stdout_handler = logging.StreamHandler(stream=sys.stdout)
@@ -103,6 +131,19 @@ class Env:
103
131
  assert self._http_address is not None
104
132
  return self._http_address
105
133
 
134
+ @property
135
+ def default_time_zone(self) -> Optional[ZoneInfo]:
136
+ return self._default_time_zone
137
+
138
+ @default_time_zone.setter
139
+ def default_time_zone(self, tz: Optional[ZoneInfo]) -> None:
140
+ """
141
+ This is not a publicly visible setter; it is only for testing purposes.
142
+ """
143
+ tz_name = None if tz is None else tz.key
144
+ self.engine.dispose()
145
+ self._create_engine(time_zone_name=tz_name)
146
+
106
147
  def configure_logging(
107
148
  self,
108
149
  *,
@@ -158,7 +199,8 @@ class Env:
158
199
  self._module_log_level[module] = level
159
200
 
160
201
  def is_installed_package(self, package_name: str) -> bool:
161
- return self._installed_packages[package_name] is not None
202
+ assert package_name in self.__optional_packages
203
+ return self.__optional_packages[package_name].is_installed
162
204
 
163
205
  def _log_filter(self, record: logging.LogRecord) -> bool:
164
206
  if record.name == 'pixeltable':
@@ -270,20 +312,35 @@ class Env:
270
312
  self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
271
313
  self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
272
314
 
273
- if reinit_db:
274
- if self._store_db_exists():
275
- self._drop_store_db()
315
+ tz_name = os.environ.get('PXT_TIME_ZONE', self._config.get('pxt_time_zone', None))
316
+ if tz_name is not None:
317
+ # Validate tzname
318
+ if not isinstance(tz_name, str):
319
+ self._logger.error(f'Invalid time zone specified in configuration.')
320
+ else:
321
+ try:
322
+ _ = ZoneInfo(tz_name)
323
+ except ZoneInfoNotFoundError:
324
+ self._logger.error(f'Invalid time zone specified in configuration: {tz_name}')
325
+
326
+ if reinit_db and self._store_db_exists():
327
+ self._drop_store_db()
328
+
329
+ create_db = not self._store_db_exists()
276
330
 
277
- if not self._store_db_exists():
278
- self._logger.info(f'creating database at {self.db_url}')
331
+ if create_db:
332
+ self._logger.info(f'creating database at: {self.db_url}')
279
333
  self._create_store_db()
280
- self._create_engine(echo=echo)
334
+ else:
335
+ self._logger.info(f'found database at: {self.db_url}')
336
+
337
+ # Create the SQLAlchemy engine. This will also set the default time zone.
338
+ self._create_engine(time_zone_name=tz_name, echo=echo)
339
+
340
+ if create_db:
281
341
  from pixeltable.metadata import schema
282
342
  schema.Base.metadata.create_all(self._sa_engine)
283
343
  metadata.create_system_info(self._sa_engine)
284
- else:
285
- self._logger.info(f'found database {self.db_url}')
286
- self._create_engine(echo=echo)
287
344
 
288
345
  print(f'Connected to Pixeltable database at: {self.db_url}')
289
346
 
@@ -291,8 +348,21 @@ class Env:
291
348
  self._set_up_runtime()
292
349
  self.log_to_stdout(False)
293
350
 
294
- def _create_engine(self, echo: bool = False) -> None:
295
- self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True, isolation_level='AUTOCOMMIT')
351
+ def _create_engine(self, time_zone_name: Optional[str], echo: bool = False) -> None:
352
+ connect_args = {} if time_zone_name is None else {'options': f'-c timezone={time_zone_name}'}
353
+ self._sa_engine = sql.create_engine(
354
+ self.db_url,
355
+ echo=echo,
356
+ future=True,
357
+ isolation_level='AUTOCOMMIT',
358
+ connect_args=connect_args,
359
+ )
360
+ self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
361
+ with self.engine.begin() as conn:
362
+ tz_name = conn.execute(sql.text('SHOW TIME ZONE')).scalar()
363
+ assert isinstance(tz_name, str)
364
+ self._logger.info(f'Database time zone is now: {tz_name}')
365
+ self._default_time_zone = ZoneInfo(tz_name)
296
366
 
297
367
  def _store_db_exists(self) -> bool:
298
368
  assert self._db_name is not None
@@ -308,7 +378,6 @@ class Env:
308
378
  finally:
309
379
  engine.dispose()
310
380
 
311
-
312
381
  def _create_store_db(self) -> None:
313
382
  assert self._db_name is not None
314
383
  # create the db
@@ -416,61 +485,75 @@ class Env:
416
485
  def _set_up_runtime(self) -> None:
417
486
  """Check for and start runtime services"""
418
487
  self._start_web_server()
419
- self._check_installed_packages()
488
+ self.__register_packages()
489
+
490
+ def __register_packages(self) -> None:
491
+ """Declare optional packages that are utilized by some parts of the code."""
492
+ self.__register_package('anthropic')
493
+ self.__register_package('boto3')
494
+ self.__register_package('datasets')
495
+ self.__register_package('fireworks', library_name='fireworks-ai')
496
+ self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
497
+ self.__register_package('mistralai')
498
+ self.__register_package('mistune')
499
+ self.__register_package('openai')
500
+ self.__register_package('openpyxl')
501
+ self.__register_package('pyarrow')
502
+ self.__register_package('sentence_transformers', library_name='sentence-transformers')
503
+ self.__register_package('spacy') # TODO: deal with en-core-web-sm
504
+ self.__register_package('tiktoken')
505
+ self.__register_package('together')
506
+ self.__register_package('toml')
507
+ self.__register_package('torch')
508
+ self.__register_package('torchvision')
509
+ self.__register_package('transformers')
510
+ self.__register_package('whisper', library_name='openai-whisper')
511
+ self.__register_package('whisperx')
512
+ self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
420
513
 
421
- def _check_installed_packages(self) -> None:
422
- def check(package: str) -> None:
423
- if importlib.util.find_spec(package) is not None:
424
- self._installed_packages[package] = []
425
- else:
426
- self._installed_packages[package] = None
427
-
428
- check('toml')
429
- check('datasets')
430
- check('torch')
431
- check('torchvision')
432
- check('transformers')
433
- check('sentence_transformers')
434
- check('whisper')
435
- check('yolox')
436
- check('whisperx')
437
- check('boto3')
438
- check('fitz') # pymupdf
439
- check('pyarrow')
440
- check('spacy') # TODO: deal with en-core-web-sm
441
514
  if self.is_installed_package('spacy'):
442
515
  import spacy
443
-
444
516
  self._spacy_nlp = spacy.load('en_core_web_sm')
445
- check('tiktoken')
446
- check('openai')
447
- check('anthropic')
448
- check('together')
449
- check('fireworks')
450
- check('label_studio_sdk')
451
- check('openpyxl')
452
-
453
- def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
454
- assert package in self._installed_packages
455
- if self._installed_packages[package] is None:
456
- raise excs.Error(f'Package {package} is not installed')
517
+
518
+ def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
519
+ self.__optional_packages[package_name] = PackageInfo(
520
+ is_installed=importlib.util.find_spec(package_name) is not None,
521
+ library_name=library_name or package_name # defaults to package_name unless specified otherwise
522
+ )
523
+
524
+ def require_package(self, package_name: str, min_version: Optional[list[int]] = None) -> None:
525
+ """
526
+ Checks whether the specified optional package is available. If not, raises an exception
527
+ with an error message informing the user how to install it.
528
+ """
529
+ assert package_name in self.__optional_packages
530
+ package_info = self.__optional_packages[package_name]
531
+
532
+ if not package_info.is_installed:
533
+ # Check again whether the package has been installed.
534
+ # We do this so that if a user gets an "optional library not found" error message, they can
535
+ # `pip install` the library and re-run the Pixeltable operation without having to restart
536
+ # their Python session.
537
+ package_info.is_installed = importlib.util.find_spec(package_name) is not None
538
+ if not package_info.is_installed:
539
+ # Still not found.
540
+ raise excs.Error(
541
+ f'This feature requires the `{package_name}` package. To install it, run: `pip install -U {package_info.library_name}`'
542
+ )
543
+
457
544
  if min_version is None:
458
545
  return
459
546
 
460
547
  # check whether we have a version >= the required one
461
- if not self._installed_packages[package]:
462
- m = importlib.import_module(package)
463
- module_version = [int(x) for x in m.__version__.split('.')]
464
- self._installed_packages[package] = module_version
465
- installed_version = self._installed_packages[package]
466
- if len(min_version) < len(installed_version):
467
- normalized_min_version = min_version + [0] * (len(installed_version) - len(min_version))
468
- if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
548
+ if package_info.version is None:
549
+ module = importlib.import_module(package_name)
550
+ package_info.version = [int(x) for x in module.__version__.split('.')]
551
+
552
+ if min_version > package_info.version:
469
553
  raise excs.Error(
470
- (
471
- f'The installed version of package {package} is {".".join(str(v) for v in installed_version)}, '
472
- f'but version >={".".join(str(v) for v in min_version)} is required'
473
- )
554
+ f'The installed version of package `{package_name}` is {".".join(str(v) for v in package_info.version)}, '
555
+ f'but version >={".".join(str(v) for v in min_version)} is required. '
556
+ f'To fix this, run: `pip install -U {package_info.library_name}`'
474
557
  )
475
558
 
476
559
  def num_tmp_files(self) -> int:
@@ -556,3 +639,10 @@ class ApiClient:
556
639
  init_fn: Callable
557
640
  param_names: list[str]
558
641
  client_obj: Optional[Any] = None
642
+
643
+
644
+ @dataclass
645
+ class PackageInfo:
646
+ is_installed: bool
647
+ library_name: str # pypi library name (may be different from package name)
648
+ version: Optional[list[int]] = None # installed version, as a list of components (such as [3,0,2] for "3.0.2")
@@ -21,8 +21,9 @@ class AggregationNode(ExecNode):
21
21
  self.input = input
22
22
  self.group_by = group_by
23
23
  self.input_exprs = list(input_exprs)
24
- self.agg_fn_calls = agg_fn_calls
25
24
  self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
25
+ # we need to make sure to refer to the same exprs that RowBuilder.eval() will use
26
+ self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
26
27
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
27
28
 
28
29
  def _reset_agg_state(self, row_num: int) -> None:
@@ -19,7 +19,7 @@ class ComponentIterationNode(ExecNode):
19
19
  super().__init__(input.row_builder, [], [], input)
20
20
  self.view = view
21
21
  iterator_args = [view.iterator_args.copy()]
22
- self.row_builder.substitute_exprs(iterator_args)
22
+ self.row_builder.set_slot_idxs(iterator_args)
23
23
  self.iterator_args = iterator_args[0]
24
24
  assert isinstance(self.iterator_args, exprs.InlineDict)
25
25
  self.iterator_args_ctx = self.row_builder.create_eval_ctx([self.iterator_args])
@@ -18,7 +18,7 @@ class SqlNode(ExecNode):
18
18
 
19
19
  def __init__(
20
20
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
21
- select_list: Iterable[exprs.Expr], set_pk: bool = False
21
+ select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
22
22
  ):
23
23
  """
24
24
  Initialize self.stmt with expressions derived from select_list.
@@ -35,8 +35,9 @@ class SqlNode(ExecNode):
35
35
  self.sql_exprs = exprs.ExprSet(select_list)
36
36
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
37
37
  for iter_arg in row_builder.unstored_iter_args.values():
38
- sql_subexprs = iter_arg.subexprs(filter=lambda e: e.sql_expr() is not None, traverse_matches=False)
39
- [self.sql_exprs.append(e) for e in sql_subexprs]
38
+ sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
39
+ for e in sql_subexprs:
40
+ self.sql_exprs.add(e)
40
41
  super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
41
42
 
42
43
  # change rowid refs against a base table to rowid refs against the target table, so that we minimize
@@ -44,7 +45,7 @@ class SqlNode(ExecNode):
44
45
  for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
45
46
  rowid_ref.set_tbl(tbl)
46
47
 
47
- sql_select_list = [e.sql_expr() for e in self.sql_exprs]
48
+ sql_select_list = [sql_elements.get(e) for e in self.sql_exprs]
48
49
  assert len(sql_select_list) == len(self.sql_exprs)
49
50
  assert all(e is not None for e in sql_select_list)
50
51
  self.set_pk = set_pk
@@ -204,7 +205,8 @@ class SqlScanNode(SqlNode):
204
205
  set_pk: if True, sets the primary for each DataRow
205
206
  exact_version_only: tables for which we only want to see rows created at the current version
206
207
  """
207
- super().__init__(tbl, row_builder, select_list, set_pk=set_pk)
208
+ sql_elements = exprs.SqlElementCache()
209
+ super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
208
210
  # create Select stmt
209
211
  if order_by_items is None:
210
212
  order_by_items = []
@@ -230,10 +232,10 @@ class SqlScanNode(SqlNode):
230
232
  if isinstance(e, exprs.SimilarityExpr):
231
233
  order_by_clause.append(e.as_order_by_clause(asc))
232
234
  else:
233
- order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
235
+ order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
234
236
 
235
237
  if where_clause is not None:
236
- sql_where_clause = where_clause.sql_expr()
238
+ sql_where_clause = sql_elements.get(where_clause)
237
239
  assert sql_where_clause is not None
238
240
  self.stmt = self.stmt.where(sql_where_clause)
239
241
  if len(order_by_clause) > 0:
@@ -272,7 +274,8 @@ class SqlLookupNode(SqlNode):
272
274
  sa_key_cols: list of key columns in the store table
273
275
  key_vals: list of key values to look up
274
276
  """
275
- super().__init__(tbl, row_builder, select_list, set_pk=True)
277
+ sql_elements = exprs.SqlElementCache()
278
+ super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
276
279
  target = tbl.tbl_version # the stored table we're scanning
277
280
  refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
278
281
  self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
@@ -20,5 +20,6 @@ from .object_ref import ObjectRef
20
20
  from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
21
21
  from .rowid_ref import RowidRef
22
22
  from .similarity_expr import SimilarityExpr
23
+ from .sql_element_cache import SqlElementCache
23
24
  from .type_cast import TypeCast
24
25
  from .variable import Variable
@@ -6,11 +6,11 @@ import sqlalchemy as sql
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
-
10
9
  from .data_row import DataRow
11
10
  from .expr import Expr
12
11
  from .globals import ArithmeticOperator
13
12
  from .row_builder import RowBuilder
13
+ from .sql_element_cache import SqlElementCache
14
14
 
15
15
 
16
16
  class ArithmeticExpr(Expr):
@@ -54,10 +54,10 @@ class ArithmeticExpr(Expr):
54
54
  def _op2(self) -> Expr:
55
55
  return self.components[1]
56
56
 
57
- def sql_expr(self) -> Optional[sql.ClauseElement]:
57
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
58
58
  assert self.col_type.is_int_type() or self.col_type.is_float_type() or self.col_type.is_json_type()
59
- left = self._op1.sql_expr()
60
- right = self._op2.sql_expr()
59
+ left = sql_elements.get(self._op1)
60
+ right = sql_elements.get(self._op2)
61
61
  if left is None or right is None:
62
62
  return None
63
63
  if self.operator == ArithmeticOperator.ADD:
@@ -8,6 +8,7 @@ from .data_row import DataRow
8
8
  from .expr import Expr
9
9
  from .globals import print_slice
10
10
  from .row_builder import RowBuilder
11
+ from .sql_element_cache import SqlElementCache
11
12
 
12
13
 
13
14
  class ArraySlice(Expr):
@@ -41,7 +42,7 @@ class ArraySlice(Expr):
41
42
  def _id_attrs(self) -> List[Tuple[str, Any]]:
42
43
  return super()._id_attrs() + [('index', self.index)]
43
44
 
44
- def sql_expr(self) -> Optional[sql.ClauseElement]:
45
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
45
46
  return None
46
47
 
47
48
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -1,15 +1,16 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple
2
+
3
3
  import enum
4
+ from typing import Optional, List, Any, Dict, Tuple
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
7
- from .expr import Expr
8
+ import pixeltable.type_system as ts
8
9
  from .column_ref import ColumnRef
9
- from .row_builder import RowBuilder
10
10
  from .data_row import DataRow
11
- import pixeltable.catalog as catalog
12
- import pixeltable.type_system as ts
11
+ from .expr import Expr
12
+ from .row_builder import RowBuilder
13
+ from .sql_element_cache import SqlElementCache
13
14
 
14
15
 
15
16
  class ColumnPropertyRef(Expr):
@@ -45,7 +46,7 @@ class ColumnPropertyRef(Expr):
45
46
  def __str__(self) -> str:
46
47
  return f'{self._col_ref}.{self.prop.name.lower()}'
47
48
 
48
- def sql_expr(self) -> Optional[sql.ClauseElement]:
49
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
49
50
  if not self._col_ref.col.is_stored:
50
51
  return None
51
52
  if self.prop == self.Property.ERRORTYPE:
@@ -56,7 +57,7 @@ class ColumnPropertyRef(Expr):
56
57
  return self._col_ref.col.sa_errormsg_col
57
58
  if self.prop == self.Property.FILEURL:
58
59
  # the file url is stored as the column value
59
- return self._col_ref.sql_expr()
60
+ return sql_elements.get(self._col_ref)
60
61
  return None
61
62
 
62
63
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -73,5 +74,6 @@ class ColumnPropertyRef(Expr):
73
74
  @classmethod
74
75
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
75
76
  assert 'prop' in d
77
+ assert isinstance(components[0], ColumnRef)
76
78
  return cls(components[0], cls.Property(d['prop']))
77
79
 
@@ -7,6 +7,7 @@ import sqlalchemy as sql
7
7
  from .expr import Expr
8
8
  from .data_row import DataRow
9
9
  from .row_builder import RowBuilder
10
+ from .sql_element_cache import SqlElementCache
10
11
  import pixeltable.iterators as iters
11
12
  import pixeltable.exceptions as excs
12
13
  import pixeltable.catalog as catalog
@@ -92,7 +93,7 @@ class ColumnRef(Expr):
92
93
  def __repr__(self) -> str:
93
94
  return f'ColumnRef({self.col!r})'
94
95
 
95
- def sql_expr(self) -> Optional[sql.ClauseElement]:
96
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
96
97
  return self.col.sa_col
97
98
 
98
99
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -1,22 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
- from datetime import datetime
4
3
  from typing import Optional, List, Any, Dict
5
4
 
6
5
  import sqlalchemy as sql
7
6
 
7
+ import pixeltable.exceptions as excs
8
+ import pixeltable.index as index
9
+ import pixeltable.type_system as ts
8
10
  from .column_ref import ColumnRef
9
11
  from .data_row import DataRow
10
12
  from .expr import Expr
11
13
  from .globals import ComparisonOperator
12
14
  from .literal import Literal
13
15
  from .row_builder import RowBuilder
14
- import pixeltable.exceptions as excs
15
- import pixeltable.index as index
16
- import pixeltable.type_system as ts
16
+ from .sql_element_cache import SqlElementCache
17
17
 
18
18
 
19
19
  class Comparison(Expr):
20
+ is_search_arg_comparison: bool
21
+ operator: ComparisonOperator
22
+
20
23
  def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
21
24
  super().__init__(ts.BoolType())
22
25
  self.operator = operator
@@ -62,8 +65,8 @@ class Comparison(Expr):
62
65
  def _op2(self) -> Expr:
63
66
  return self.components[1]
64
67
 
65
- def sql_expr(self) -> Optional[sql.ClauseElement]:
66
- left = self._op1.sql_expr()
68
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ClauseElement]:
69
+ left = sql_elements.get(self._op1)
67
70
  if self.is_search_arg_comparison:
68
71
  # reference the index value column if there is an index and this is not a snapshot
69
72
  # (indices don't apply to snapshots)
@@ -76,7 +79,7 @@ class Comparison(Expr):
76
79
  assert len(idx_info) == 1
77
80
  left = idx_info[0].val_col.sa_col
78
81
 
79
- right = self._op2.sql_expr()
82
+ right = sql_elements.get(self._op2)
80
83
  if left is None or right is None:
81
84
  return None
82
85
 
@@ -9,6 +9,7 @@ from .data_row import DataRow
9
9
  from .expr import Expr
10
10
  from .globals import LogicalOperator
11
11
  from .row_builder import RowBuilder
12
+ from .sql_element_cache import SqlElementCache
12
13
  import pixeltable.type_system as ts
13
14
 
14
15
 
@@ -66,8 +67,8 @@ class CompoundPredicate(Expr):
66
67
  non_matches = [op for op in self.components if not condition(op)]
67
68
  return (matches, self.make_conjunction(non_matches))
68
69
 
69
- def sql_expr(self) -> Optional[sql.ClauseElement]:
70
- sql_exprs = [op.sql_expr() for op in self.components]
70
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
71
+ sql_exprs = [sql_elements.get(op) for op in self.components]
71
72
  if any(e is None for e in sql_exprs):
72
73
  return None
73
74
  if self.operator == LogicalOperator.NOT: