pixeltable-0.2.17-py3-none-any.whl → pixeltable-0.2.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (87)
  1. pixeltable/__init__.py +1 -1
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +8 -7
  4. pixeltable/catalog/column.py +11 -8
  5. pixeltable/catalog/insertable_table.py +1 -1
  6. pixeltable/catalog/path_dict.py +8 -6
  7. pixeltable/catalog/table.py +20 -14
  8. pixeltable/catalog/table_version.py +92 -55
  9. pixeltable/catalog/table_version_path.py +7 -9
  10. pixeltable/catalog/view.py +3 -2
  11. pixeltable/dataframe.py +2 -2
  12. pixeltable/env.py +205 -86
  13. pixeltable/exceptions.py +5 -1
  14. pixeltable/exec/aggregation_node.py +2 -1
  15. pixeltable/exec/component_iteration_node.py +2 -2
  16. pixeltable/exec/sql_node.py +11 -8
  17. pixeltable/exprs/__init__.py +2 -2
  18. pixeltable/exprs/arithmetic_expr.py +4 -4
  19. pixeltable/exprs/array_slice.py +2 -1
  20. pixeltable/exprs/column_property_ref.py +9 -7
  21. pixeltable/exprs/column_ref.py +2 -1
  22. pixeltable/exprs/comparison.py +10 -7
  23. pixeltable/exprs/compound_predicate.py +3 -2
  24. pixeltable/exprs/data_row.py +19 -4
  25. pixeltable/exprs/expr.py +51 -41
  26. pixeltable/exprs/expr_set.py +32 -9
  27. pixeltable/exprs/function_call.py +62 -40
  28. pixeltable/exprs/in_predicate.py +3 -2
  29. pixeltable/exprs/inline_expr.py +200 -0
  30. pixeltable/exprs/is_null.py +3 -2
  31. pixeltable/exprs/json_mapper.py +5 -4
  32. pixeltable/exprs/json_path.py +7 -1
  33. pixeltable/exprs/literal.py +34 -7
  34. pixeltable/exprs/method_ref.py +3 -3
  35. pixeltable/exprs/object_ref.py +6 -5
  36. pixeltable/exprs/row_builder.py +25 -17
  37. pixeltable/exprs/rowid_ref.py +2 -1
  38. pixeltable/exprs/similarity_expr.py +2 -1
  39. pixeltable/exprs/sql_element_cache.py +30 -0
  40. pixeltable/exprs/type_cast.py +3 -3
  41. pixeltable/exprs/variable.py +2 -1
  42. pixeltable/ext/functions/whisperx.py +6 -4
  43. pixeltable/ext/functions/yolox.py +11 -9
  44. pixeltable/func/aggregate_function.py +1 -0
  45. pixeltable/func/function.py +28 -4
  46. pixeltable/functions/__init__.py +4 -2
  47. pixeltable/functions/anthropic.py +15 -5
  48. pixeltable/functions/fireworks.py +1 -1
  49. pixeltable/functions/globals.py +6 -1
  50. pixeltable/functions/huggingface.py +91 -14
  51. pixeltable/functions/image.py +20 -5
  52. pixeltable/functions/json.py +5 -5
  53. pixeltable/functions/mistralai.py +188 -0
  54. pixeltable/functions/openai.py +6 -10
  55. pixeltable/functions/string.py +3 -2
  56. pixeltable/functions/timestamp.py +95 -7
  57. pixeltable/functions/together.py +18 -11
  58. pixeltable/functions/video.py +2 -2
  59. pixeltable/functions/vision.py +69 -37
  60. pixeltable/functions/whisper.py +4 -1
  61. pixeltable/globals.py +5 -1
  62. pixeltable/io/hf_datasets.py +17 -15
  63. pixeltable/io/pandas.py +0 -2
  64. pixeltable/io/parquet.py +15 -14
  65. pixeltable/iterators/document.py +16 -15
  66. pixeltable/metadata/__init__.py +1 -1
  67. pixeltable/metadata/converters/convert_18.py +1 -1
  68. pixeltable/metadata/converters/convert_19.py +46 -0
  69. pixeltable/metadata/converters/convert_20.py +56 -0
  70. pixeltable/metadata/converters/util.py +29 -4
  71. pixeltable/metadata/notes.py +2 -0
  72. pixeltable/metadata/schema.py +5 -4
  73. pixeltable/plan.py +100 -78
  74. pixeltable/store.py +5 -1
  75. pixeltable/tool/create_test_db_dump.py +18 -6
  76. pixeltable/type_system.py +15 -15
  77. pixeltable/utils/documents.py +45 -42
  78. pixeltable/utils/formatter.py +2 -2
  79. pixeltable-0.2.19.dist-info/LICENSE +201 -0
  80. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/METADATA +84 -24
  81. pixeltable-0.2.19.dist-info/RECORD +147 -0
  82. pixeltable/exprs/inline_array.py +0 -116
  83. pixeltable/exprs/inline_dict.py +0 -103
  84. pixeltable-0.2.17.dist-info/LICENSE +0 -18
  85. pixeltable-0.2.17.dist-info/RECORD +0 -144
  86. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/WHEEL +0 -0
  87. {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/entry_points.txt +0 -0
pixeltable/env.py CHANGED
@@ -8,13 +8,15 @@ import importlib.util
8
8
  import inspect
9
9
  import logging
10
10
  import os
11
+ import subprocess
11
12
  import sys
12
13
  import threading
13
14
  import uuid
14
15
  import warnings
15
16
  from dataclasses import dataclass
16
17
  from pathlib import Path
17
- from typing import Callable, Optional, Dict, Any, List, TYPE_CHECKING
18
+ from typing import TYPE_CHECKING, Any, Callable, Optional
19
+ from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
18
20
 
19
21
  import pixeltable_pgserver
20
22
  import sqlalchemy as sql
@@ -37,6 +39,35 @@ class Env:
37
39
  _instance: Optional[Env] = None
38
40
  _log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
39
41
 
42
+ _home: Optional[Path]
43
+ _media_dir: Optional[Path]
44
+ _file_cache_dir: Optional[Path] # cached media files with external URL
45
+ _dataset_cache_dir: Optional[Path] # cached datasets (eg, pytorch or COCO)
46
+ _log_dir: Optional[Path] # log files
47
+ _tmp_dir: Optional[Path] # any tmp files
48
+ _sa_engine: Optional[sql.engine.base.Engine]
49
+ _pgdata_dir: Optional[Path]
50
+ _db_name: Optional[str]
51
+ _db_server: Optional[pixeltable_pgserver.PostgresServer]
52
+ _db_url: Optional[str]
53
+ _default_time_zone: Optional[ZoneInfo]
54
+
55
+ # info about optional packages that are utilized by some parts of the code
56
+ __optional_packages: dict[str, PackageInfo]
57
+
58
+ _spacy_nlp: Optional[spacy.Language]
59
+ _httpd: Optional[http.server.HTTPServer]
60
+ _http_address: Optional[str]
61
+ _logger: logging.Logger
62
+ _default_log_level: int
63
+ _logfilename: Optional[str]
64
+ _log_to_stdout: bool
65
+ _module_log_level: dict[str, int] # module name -> log level
66
+ _config_file: Optional[Path]
67
+ _config: Optional[dict[str, Any]]
68
+ _stdout_handler: logging.StreamHandler
69
+ _initialized: bool
70
+
40
71
  @classmethod
41
72
  def get(cls) -> Env:
42
73
  if cls._instance is None:
@@ -51,24 +82,23 @@ class Env:
51
82
  cls._instance = env
52
83
 
53
84
  def __init__(self):
54
- self._home: Optional[Path] = None
55
- self._media_dir: Optional[Path] = None # computed media files
56
- self._file_cache_dir: Optional[Path] = None # cached media files with external URL
57
- self._dataset_cache_dir: Optional[Path] = None # cached datasets (eg, pytorch or COCO)
58
- self._log_dir: Optional[Path] = None # log files
59
- self._tmp_dir: Optional[Path] = None # any tmp files
60
- self._sa_engine: Optional[sql.engine.base.Engine] = None
61
- self._pgdata_dir: Optional[Path] = None
62
- self._db_name: Optional[str] = None
63
- self._db_server: Optional[pixeltable_pgserver.PostgresServer] = None
64
- self._db_url: Optional[str] = None
65
-
66
- # info about installed packages that are utilized by some parts of the code;
67
- # package name -> version; version == []: package is installed, but we haven't determined the version yet
68
- self._installed_packages: Dict[str, Optional[List[int]]] = {}
69
- self._spacy_nlp: Optional[spacy.Language] = None
70
- self._httpd: Optional[http.server.HTTPServer] = None
71
- self._http_address: Optional[str] = None
85
+ self._home = None
86
+ self._media_dir = None # computed media files
87
+ self._file_cache_dir = None # cached media files with external URL
88
+ self._dataset_cache_dir = None # cached datasets (eg, pytorch or COCO)
89
+ self._log_dir = None # log files
90
+ self._tmp_dir = None # any tmp files
91
+ self._sa_engine = None
92
+ self._pgdata_dir = None
93
+ self._db_name = None
94
+ self._db_server = None
95
+ self._db_url = None
96
+ self._default_time_zone = None
97
+
98
+ self.__optional_packages = {}
99
+ self._spacy_nlp = None
100
+ self._httpd = None
101
+ self._http_address = None
72
102
 
73
103
  # logging-related state
74
104
  self._logger = logging.getLogger('pixeltable')
@@ -76,13 +106,12 @@ class Env:
76
106
  self._logger.propagate = False
77
107
  self._logger.addFilter(self._log_filter)
78
108
  self._default_log_level = logging.INFO
79
- self._logfilename: Optional[str] = None
109
+ self._logfilename = None
80
110
  self._log_to_stdout = False
81
- self._module_log_level: Dict[str, int] = {} # module name -> log level
111
+ self._module_log_level = {} # module name -> log level
82
112
 
83
- # config
84
- self._config_file: Optional[Path] = None
85
- self._config: Optional[Dict[str, Any]] = None
113
+ self._config_file = None
114
+ self._config = None
86
115
 
87
116
  # create logging handler to also log to stdout
88
117
  self._stdout_handler = logging.StreamHandler(stream=sys.stdout)
@@ -103,6 +132,19 @@ class Env:
103
132
  assert self._http_address is not None
104
133
  return self._http_address
105
134
 
135
+ @property
136
+ def default_time_zone(self) -> Optional[ZoneInfo]:
137
+ return self._default_time_zone
138
+
139
+ @default_time_zone.setter
140
+ def default_time_zone(self, tz: Optional[ZoneInfo]) -> None:
141
+ """
142
+ This is not a publicly visible setter; it is only for testing purposes.
143
+ """
144
+ tz_name = None if tz is None else tz.key
145
+ self.engine.dispose()
146
+ self._create_engine(time_zone_name=tz_name)
147
+
106
148
  def configure_logging(
107
149
  self,
108
150
  *,
@@ -158,7 +200,8 @@ class Env:
158
200
  self._module_log_level[module] = level
159
201
 
160
202
  def is_installed_package(self, package_name: str) -> bool:
161
- return self._installed_packages[package_name] is not None
203
+ assert package_name in self.__optional_packages
204
+ return self.__optional_packages[package_name].is_installed
162
205
 
163
206
  def _log_filter(self, record: logging.LogRecord) -> bool:
164
207
  if record.name == 'pixeltable':
@@ -270,20 +313,35 @@ class Env:
270
313
  self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
271
314
  self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
272
315
 
273
- if reinit_db:
274
- if self._store_db_exists():
275
- self._drop_store_db()
316
+ tz_name = os.environ.get('PXT_TIME_ZONE', self._config.get('pxt_time_zone', None))
317
+ if tz_name is not None:
318
+ # Validate tzname
319
+ if not isinstance(tz_name, str):
320
+ self._logger.error(f'Invalid time zone specified in configuration.')
321
+ else:
322
+ try:
323
+ _ = ZoneInfo(tz_name)
324
+ except ZoneInfoNotFoundError:
325
+ self._logger.error(f'Invalid time zone specified in configuration: {tz_name}')
326
+
327
+ if reinit_db and self._store_db_exists():
328
+ self._drop_store_db()
276
329
 
277
- if not self._store_db_exists():
278
- self._logger.info(f'creating database at {self.db_url}')
330
+ create_db = not self._store_db_exists()
331
+
332
+ if create_db:
333
+ self._logger.info(f'creating database at: {self.db_url}')
279
334
  self._create_store_db()
280
- self._create_engine(echo=echo)
335
+ else:
336
+ self._logger.info(f'found database at: {self.db_url}')
337
+
338
+ # Create the SQLAlchemy engine. This will also set the default time zone.
339
+ self._create_engine(time_zone_name=tz_name, echo=echo)
340
+
341
+ if create_db:
281
342
  from pixeltable.metadata import schema
282
343
  schema.Base.metadata.create_all(self._sa_engine)
283
344
  metadata.create_system_info(self._sa_engine)
284
- else:
285
- self._logger.info(f'found database {self.db_url}')
286
- self._create_engine(echo=echo)
287
345
 
288
346
  print(f'Connected to Pixeltable database at: {self.db_url}')
289
347
 
@@ -291,8 +349,21 @@ class Env:
291
349
  self._set_up_runtime()
292
350
  self.log_to_stdout(False)
293
351
 
294
- def _create_engine(self, echo: bool = False) -> None:
295
- self._sa_engine = sql.create_engine(self.db_url, echo=echo, future=True, isolation_level='AUTOCOMMIT')
352
+ def _create_engine(self, time_zone_name: Optional[str], echo: bool = False) -> None:
353
+ connect_args = {} if time_zone_name is None else {'options': f'-c timezone={time_zone_name}'}
354
+ self._sa_engine = sql.create_engine(
355
+ self.db_url,
356
+ echo=echo,
357
+ future=True,
358
+ isolation_level='AUTOCOMMIT',
359
+ connect_args=connect_args,
360
+ )
361
+ self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
362
+ with self.engine.begin() as conn:
363
+ tz_name = conn.execute(sql.text('SHOW TIME ZONE')).scalar()
364
+ assert isinstance(tz_name, str)
365
+ self._logger.info(f'Database time zone is now: {tz_name}')
366
+ self._default_time_zone = ZoneInfo(tz_name)
296
367
 
297
368
  def _store_db_exists(self) -> bool:
298
369
  assert self._db_name is not None
@@ -308,7 +379,6 @@ class Env:
308
379
  finally:
309
380
  engine.dispose()
310
381
 
311
-
312
382
  def _create_store_db(self) -> None:
313
383
  assert self._db_name is not None
314
384
  # create the db
@@ -416,63 +486,104 @@ class Env:
416
486
  def _set_up_runtime(self) -> None:
417
487
  """Check for and start runtime services"""
418
488
  self._start_web_server()
419
- self._check_installed_packages()
420
-
421
- def _check_installed_packages(self) -> None:
422
- def check(package: str) -> None:
423
- if importlib.util.find_spec(package) is not None:
424
- self._installed_packages[package] = []
425
- else:
426
- self._installed_packages[package] = None
427
-
428
- check('toml')
429
- check('datasets')
430
- check('torch')
431
- check('torchvision')
432
- check('transformers')
433
- check('sentence_transformers')
434
- check('whisper')
435
- check('yolox')
436
- check('whisperx')
437
- check('boto3')
438
- check('fitz') # pymupdf
439
- check('pyarrow')
440
- check('spacy') # TODO: deal with en-core-web-sm
489
+ self.__register_packages()
441
490
  if self.is_installed_package('spacy'):
442
- import spacy
443
-
444
- self._spacy_nlp = spacy.load('en_core_web_sm')
445
- check('tiktoken')
446
- check('openai')
447
- check('anthropic')
448
- check('together')
449
- check('fireworks')
450
- check('label_studio_sdk')
451
- check('openpyxl')
452
-
453
- def require_package(self, package: str, min_version: Optional[List[int]] = None) -> None:
454
- assert package in self._installed_packages
455
- if self._installed_packages[package] is None:
456
- raise excs.Error(f'Package {package} is not installed')
491
+ self.__init_spacy()
492
+
493
+ def __register_packages(self) -> None:
494
+ """Declare optional packages that are utilized by some parts of the code."""
495
+ self.__register_package('anthropic')
496
+ self.__register_package('boto3')
497
+ self.__register_package('datasets')
498
+ self.__register_package('fireworks', library_name='fireworks-ai')
499
+ self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
500
+ self.__register_package('mistralai')
501
+ self.__register_package('mistune')
502
+ self.__register_package('openai')
503
+ self.__register_package('openpyxl')
504
+ self.__register_package('pyarrow')
505
+ self.__register_package('sentence_transformers', library_name='sentence-transformers')
506
+ self.__register_package('spacy')
507
+ self.__register_package('tiktoken')
508
+ self.__register_package('together')
509
+ self.__register_package('toml')
510
+ self.__register_package('torch')
511
+ self.__register_package('torchvision')
512
+ self.__register_package('transformers')
513
+ self.__register_package('whisper', library_name='openai-whisper')
514
+ self.__register_package('whisperx')
515
+ self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
516
+
517
+ def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
518
+ self.__optional_packages[package_name] = PackageInfo(
519
+ is_installed=importlib.util.find_spec(package_name) is not None,
520
+ library_name=library_name or package_name # defaults to package_name unless specified otherwise
521
+ )
522
+
523
+ def require_package(self, package_name: str, min_version: Optional[list[int]] = None) -> None:
524
+ """
525
+ Checks whether the specified optional package is available. If not, raises an exception
526
+ with an error message informing the user how to install it.
527
+ """
528
+ assert package_name in self.__optional_packages
529
+ package_info = self.__optional_packages[package_name]
530
+
531
+ if not package_info.is_installed:
532
+ # Check again whether the package has been installed.
533
+ # We do this so that if a user gets an "optional library not found" error message, they can
534
+ # `pip install` the library and re-run the Pixeltable operation without having to restart
535
+ # their Python session.
536
+ package_info.is_installed = importlib.util.find_spec(package_name) is not None
537
+ if not package_info.is_installed:
538
+ # Still not found.
539
+ raise excs.Error(
540
+ f'This feature requires the `{package_name}` package. To install it, run: `pip install -U {package_info.library_name}`'
541
+ )
542
+
457
543
  if min_version is None:
458
544
  return
459
545
 
460
546
  # check whether we have a version >= the required one
461
- if not self._installed_packages[package]:
462
- m = importlib.import_module(package)
463
- module_version = [int(x) for x in m.__version__.split('.')]
464
- self._installed_packages[package] = module_version
465
- installed_version = self._installed_packages[package]
466
- if len(min_version) < len(installed_version):
467
- normalized_min_version = min_version + [0] * (len(installed_version) - len(min_version))
468
- if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
547
+ if package_info.version is None:
548
+ module = importlib.import_module(package_name)
549
+ package_info.version = [int(x) for x in module.__version__.split('.')]
550
+
551
+ if min_version > package_info.version:
469
552
  raise excs.Error(
470
- (
471
- f'The installed version of package {package} is {".".join(str(v) for v in installed_version)}, '
472
- f'but version >={".".join(str(v) for v in min_version)} is required'
473
- )
553
+ f'The installed version of package `{package_name}` is {".".join(str(v) for v in package_info.version)}, '
554
+ f'but version >={".".join(str(v) for v in min_version)} is required. '
555
+ f'To fix this, run: `pip install -U {package_info.library_name}`'
474
556
  )
475
557
 
558
+ def __init_spacy(self) -> None:
559
+ """
560
+ spaCy relies on a pip-installed model to operate. In order to avoid requiring the model as a separate
561
+ dependency, we install it programmatically here. This should cause no problems, since the model packages
562
+ have no sub-dependencies (in fact, this is how spaCy normally manages its model resources).
563
+ """
564
+ import spacy
565
+ from spacy.cli.download import get_model_filename
566
+ spacy_model = 'en_core_web_sm'
567
+ spacy_model_version = '3.7.1'
568
+ filename = get_model_filename(spacy_model, spacy_model_version, sdist=False)
569
+ url = f'{spacy.about.__download_url__}/{filename}'
570
+ # Try to `pip install` the model. We set check=False; if the pip command fails, it's not necessarily
571
+ # a problem, because the model have been installed on a previous attempt.
572
+ self._logger.info(f'Ensuring spaCy model is installed: {filename}')
573
+ ret = subprocess.run([sys.executable, '-m', 'pip', 'install', '-qU', url], check=False)
574
+ if ret.returncode != 0:
575
+ self._logger.warn(f'pip install failed for spaCy model: {filename}')
576
+ try:
577
+ self._logger.info(f'Loading spaCy model: {spacy_model}')
578
+ self._spacy_nlp = spacy.load(spacy_model)
579
+ except Exception as exc:
580
+ self._logger.warn(f'Failed to load spaCy model: {spacy_model}', exc_info=exc)
581
+ warnings.warn(
582
+ f"Failed to load spaCy model '{spacy_model}'. spaCy features will not be available.",
583
+ excs.PixeltableWarning
584
+ )
585
+ self.__optional_packages['spacy'].is_installed = False
586
+
476
587
  def num_tmp_files(self) -> int:
477
588
  return len(glob.glob(f'{self._tmp_dir}/*'))
478
589
 
@@ -511,6 +622,7 @@ class Env:
511
622
 
512
623
  @property
513
624
  def spacy_nlp(self) -> spacy.Language:
625
+ Env.get().require_package('spacy')
514
626
  assert self._spacy_nlp is not None
515
627
  return self._spacy_nlp
516
628
 
@@ -556,3 +668,10 @@ class ApiClient:
556
668
  init_fn: Callable
557
669
  param_names: list[str]
558
670
  client_obj: Optional[Any] = None
671
+
672
+
673
+ @dataclass
674
+ class PackageInfo:
675
+ is_installed: bool
676
+ library_name: str # pypi library name (may be different from package name)
677
+ version: Optional[list[int]] = None # installed version, as a list of components (such as [3,0,2] for "3.0.2")
pixeltable/exceptions.py CHANGED
@@ -14,4 +14,8 @@ class ExprEvalError(Exception):
14
14
  exc: Exception
15
15
  exc_tb: TracebackType
16
16
  input_vals: List[Any]
17
- row_num: int
17
+ row_num: int
18
+
19
+
20
+ class PixeltableWarning(Warning):
21
+ pass
@@ -21,8 +21,9 @@ class AggregationNode(ExecNode):
21
21
  self.input = input
22
22
  self.group_by = group_by
23
23
  self.input_exprs = list(input_exprs)
24
- self.agg_fn_calls = agg_fn_calls
25
24
  self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
25
+ # we need to make sure to refer to the same exprs that RowBuilder.eval() will use
26
+ self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
26
27
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
27
28
 
28
29
  def _reset_agg_state(self, row_num: int) -> None:
@@ -19,12 +19,12 @@ class ComponentIterationNode(ExecNode):
19
19
  super().__init__(input.row_builder, [], [], input)
20
20
  self.view = view
21
21
  iterator_args = [view.iterator_args.copy()]
22
- self.row_builder.substitute_exprs(iterator_args)
22
+ self.row_builder.set_slot_idxs(iterator_args)
23
23
  self.iterator_args = iterator_args[0]
24
24
  assert isinstance(self.iterator_args, exprs.InlineDict)
25
25
  self.iterator_args_ctx = self.row_builder.create_eval_ctx([self.iterator_args])
26
26
  self.iterator_output_schema, self.unstored_column_names = \
27
- self.view.iterator_cls.output_schema(**self.iterator_args.to_dict())
27
+ self.view.iterator_cls.output_schema(**self.iterator_args.to_kwargs())
28
28
  self.iterator_output_fields = list(self.iterator_output_schema.keys())
29
29
  self.iterator_output_cols = \
30
30
  {field_name: self.view.cols_by_name[field_name] for field_name in self.iterator_output_fields}
@@ -18,7 +18,7 @@ class SqlNode(ExecNode):
18
18
 
19
19
  def __init__(
20
20
  self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
21
- select_list: Iterable[exprs.Expr], set_pk: bool = False
21
+ select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
22
22
  ):
23
23
  """
24
24
  Initialize self.stmt with expressions derived from select_list.
@@ -35,8 +35,9 @@ class SqlNode(ExecNode):
35
35
  self.sql_exprs = exprs.ExprSet(select_list)
36
36
  # unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
37
37
  for iter_arg in row_builder.unstored_iter_args.values():
38
- sql_subexprs = iter_arg.subexprs(filter=lambda e: e.sql_expr() is not None, traverse_matches=False)
39
- [self.sql_exprs.append(e) for e in sql_subexprs]
38
+ sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
39
+ for e in sql_subexprs:
40
+ self.sql_exprs.add(e)
40
41
  super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
41
42
 
42
43
  # change rowid refs against a base table to rowid refs against the target table, so that we minimize
@@ -44,7 +45,7 @@ class SqlNode(ExecNode):
44
45
  for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
45
46
  rowid_ref.set_tbl(tbl)
46
47
 
47
- sql_select_list = [e.sql_expr() for e in self.sql_exprs]
48
+ sql_select_list = [sql_elements.get(e) for e in self.sql_exprs]
48
49
  assert len(sql_select_list) == len(self.sql_exprs)
49
50
  assert all(e is not None for e in sql_select_list)
50
51
  self.set_pk = set_pk
@@ -204,7 +205,8 @@ class SqlScanNode(SqlNode):
204
205
  set_pk: if True, sets the primary for each DataRow
205
206
  exact_version_only: tables for which we only want to see rows created at the current version
206
207
  """
207
- super().__init__(tbl, row_builder, select_list, set_pk=set_pk)
208
+ sql_elements = exprs.SqlElementCache()
209
+ super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
208
210
  # create Select stmt
209
211
  if order_by_items is None:
210
212
  order_by_items = []
@@ -230,10 +232,10 @@ class SqlScanNode(SqlNode):
230
232
  if isinstance(e, exprs.SimilarityExpr):
231
233
  order_by_clause.append(e.as_order_by_clause(asc))
232
234
  else:
233
- order_by_clause.append(e.sql_expr().desc() if not asc else e.sql_expr())
235
+ order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
234
236
 
235
237
  if where_clause is not None:
236
- sql_where_clause = where_clause.sql_expr()
238
+ sql_where_clause = sql_elements.get(where_clause)
237
239
  assert sql_where_clause is not None
238
240
  self.stmt = self.stmt.where(sql_where_clause)
239
241
  if len(order_by_clause) > 0:
@@ -272,7 +274,8 @@ class SqlLookupNode(SqlNode):
272
274
  sa_key_cols: list of key columns in the store table
273
275
  key_vals: list of key values to look up
274
276
  """
275
- super().__init__(tbl, row_builder, select_list, set_pk=True)
277
+ sql_elements = exprs.SqlElementCache()
278
+ super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
276
279
  target = tbl.tbl_version # the stored table we're scanning
277
280
  refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
278
281
  self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
@@ -9,8 +9,7 @@ from .expr import Expr
9
9
  from .expr_set import ExprSet
10
10
  from .function_call import FunctionCall
11
11
  from .in_predicate import InPredicate
12
- from .inline_array import InlineArray
13
- from .inline_dict import InlineDict
12
+ from .inline_expr import InlineArray, InlineDict, InlineList
14
13
  from .is_null import IsNull
15
14
  from .json_mapper import JsonMapper
16
15
  from .json_path import RELATIVE_PATH_ROOT, JsonPath
@@ -20,5 +19,6 @@ from .object_ref import ObjectRef
20
19
  from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
21
20
  from .rowid_ref import RowidRef
22
21
  from .similarity_expr import SimilarityExpr
22
+ from .sql_element_cache import SqlElementCache
23
23
  from .type_cast import TypeCast
24
24
  from .variable import Variable
@@ -6,11 +6,11 @@ import sqlalchemy as sql
6
6
 
7
7
  import pixeltable.exceptions as excs
8
8
  import pixeltable.type_system as ts
9
-
10
9
  from .data_row import DataRow
11
10
  from .expr import Expr
12
11
  from .globals import ArithmeticOperator
13
12
  from .row_builder import RowBuilder
13
+ from .sql_element_cache import SqlElementCache
14
14
 
15
15
 
16
16
  class ArithmeticExpr(Expr):
@@ -54,10 +54,10 @@ class ArithmeticExpr(Expr):
54
54
  def _op2(self) -> Expr:
55
55
  return self.components[1]
56
56
 
57
- def sql_expr(self) -> Optional[sql.ClauseElement]:
57
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
58
58
  assert self.col_type.is_int_type() or self.col_type.is_float_type() or self.col_type.is_json_type()
59
- left = self._op1.sql_expr()
60
- right = self._op2.sql_expr()
59
+ left = sql_elements.get(self._op1)
60
+ right = sql_elements.get(self._op2)
61
61
  if left is None or right is None:
62
62
  return None
63
63
  if self.operator == ArithmeticOperator.ADD:
@@ -8,6 +8,7 @@ from .data_row import DataRow
8
8
  from .expr import Expr
9
9
  from .globals import print_slice
10
10
  from .row_builder import RowBuilder
11
+ from .sql_element_cache import SqlElementCache
11
12
 
12
13
 
13
14
  class ArraySlice(Expr):
@@ -41,7 +42,7 @@ class ArraySlice(Expr):
41
42
  def _id_attrs(self) -> List[Tuple[str, Any]]:
42
43
  return super()._id_attrs() + [('index', self.index)]
43
44
 
44
- def sql_expr(self) -> Optional[sql.ClauseElement]:
45
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
45
46
  return None
46
47
 
47
48
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -1,15 +1,16 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple
2
+
3
3
  import enum
4
+ from typing import Optional, List, Any, Dict, Tuple
4
5
 
5
6
  import sqlalchemy as sql
6
7
 
7
- from .expr import Expr
8
+ import pixeltable.type_system as ts
8
9
  from .column_ref import ColumnRef
9
- from .row_builder import RowBuilder
10
10
  from .data_row import DataRow
11
- import pixeltable.catalog as catalog
12
- import pixeltable.type_system as ts
11
+ from .expr import Expr
12
+ from .row_builder import RowBuilder
13
+ from .sql_element_cache import SqlElementCache
13
14
 
14
15
 
15
16
  class ColumnPropertyRef(Expr):
@@ -45,7 +46,7 @@ class ColumnPropertyRef(Expr):
45
46
  def __str__(self) -> str:
46
47
  return f'{self._col_ref}.{self.prop.name.lower()}'
47
48
 
48
- def sql_expr(self) -> Optional[sql.ClauseElement]:
49
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
49
50
  if not self._col_ref.col.is_stored:
50
51
  return None
51
52
  if self.prop == self.Property.ERRORTYPE:
@@ -56,7 +57,7 @@ class ColumnPropertyRef(Expr):
56
57
  return self._col_ref.col.sa_errormsg_col
57
58
  if self.prop == self.Property.FILEURL:
58
59
  # the file url is stored as the column value
59
- return self._col_ref.sql_expr()
60
+ return sql_elements.get(self._col_ref)
60
61
  return None
61
62
 
62
63
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -73,5 +74,6 @@ class ColumnPropertyRef(Expr):
73
74
  @classmethod
74
75
  def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
75
76
  assert 'prop' in d
77
+ assert isinstance(components[0], ColumnRef)
76
78
  return cls(components[0], cls.Property(d['prop']))
77
79
 
@@ -7,6 +7,7 @@ import sqlalchemy as sql
7
7
  from .expr import Expr
8
8
  from .data_row import DataRow
9
9
  from .row_builder import RowBuilder
10
+ from .sql_element_cache import SqlElementCache
10
11
  import pixeltable.iterators as iters
11
12
  import pixeltable.exceptions as excs
12
13
  import pixeltable.catalog as catalog
@@ -92,7 +93,7 @@ class ColumnRef(Expr):
92
93
  def __repr__(self) -> str:
93
94
  return f'ColumnRef({self.col!r})'
94
95
 
95
- def sql_expr(self) -> Optional[sql.ClauseElement]:
96
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
96
97
  return self.col.sa_col
97
98
 
98
99
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -1,22 +1,25 @@
1
1
  from __future__ import annotations
2
2
 
3
- from datetime import datetime
4
3
  from typing import Optional, List, Any, Dict
5
4
 
6
5
  import sqlalchemy as sql
7
6
 
7
+ import pixeltable.exceptions as excs
8
+ import pixeltable.index as index
9
+ import pixeltable.type_system as ts
8
10
  from .column_ref import ColumnRef
9
11
  from .data_row import DataRow
10
12
  from .expr import Expr
11
13
  from .globals import ComparisonOperator
12
14
  from .literal import Literal
13
15
  from .row_builder import RowBuilder
14
- import pixeltable.exceptions as excs
15
- import pixeltable.index as index
16
- import pixeltable.type_system as ts
16
+ from .sql_element_cache import SqlElementCache
17
17
 
18
18
 
19
19
  class Comparison(Expr):
20
+ is_search_arg_comparison: bool
21
+ operator: ComparisonOperator
22
+
20
23
  def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
21
24
  super().__init__(ts.BoolType())
22
25
  self.operator = operator
@@ -62,8 +65,8 @@ class Comparison(Expr):
62
65
  def _op2(self) -> Expr:
63
66
  return self.components[1]
64
67
 
65
- def sql_expr(self) -> Optional[sql.ClauseElement]:
66
- left = self._op1.sql_expr()
68
+ def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ClauseElement]:
69
+ left = sql_elements.get(self._op1)
67
70
  if self.is_search_arg_comparison:
68
71
  # reference the index value column if there is an index and this is not a snapshot
69
72
  # (indices don't apply to snapshots)
@@ -76,7 +79,7 @@ class Comparison(Expr):
76
79
  assert len(idx_info) == 1
77
80
  left = idx_info[0].val_col.sa_col
78
81
 
79
- right = self._op2.sql_expr()
82
+ right = sql_elements.get(self._op2)
80
83
  if left is None or right is None:
81
84
  return None
82
85