pixeltable 0.2.17__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/catalog.py +8 -7
- pixeltable/catalog/column.py +11 -8
- pixeltable/catalog/insertable_table.py +1 -1
- pixeltable/catalog/path_dict.py +8 -6
- pixeltable/catalog/table.py +20 -13
- pixeltable/catalog/table_version.py +91 -54
- pixeltable/catalog/table_version_path.py +7 -9
- pixeltable/catalog/view.py +2 -1
- pixeltable/dataframe.py +1 -1
- pixeltable/env.py +173 -83
- pixeltable/exec/aggregation_node.py +2 -1
- pixeltable/exec/component_iteration_node.py +1 -1
- pixeltable/exec/sql_node.py +11 -8
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +4 -4
- pixeltable/exprs/array_slice.py +2 -1
- pixeltable/exprs/column_property_ref.py +9 -7
- pixeltable/exprs/column_ref.py +2 -1
- pixeltable/exprs/comparison.py +10 -7
- pixeltable/exprs/compound_predicate.py +3 -2
- pixeltable/exprs/data_row.py +19 -4
- pixeltable/exprs/expr.py +46 -35
- pixeltable/exprs/expr_set.py +32 -9
- pixeltable/exprs/function_call.py +56 -32
- pixeltable/exprs/in_predicate.py +3 -2
- pixeltable/exprs/inline_array.py +2 -1
- pixeltable/exprs/inline_dict.py +2 -1
- pixeltable/exprs/is_null.py +3 -2
- pixeltable/exprs/json_mapper.py +5 -4
- pixeltable/exprs/json_path.py +7 -1
- pixeltable/exprs/literal.py +34 -7
- pixeltable/exprs/method_ref.py +3 -3
- pixeltable/exprs/object_ref.py +6 -5
- pixeltable/exprs/row_builder.py +25 -17
- pixeltable/exprs/rowid_ref.py +2 -1
- pixeltable/exprs/similarity_expr.py +2 -1
- pixeltable/exprs/sql_element_cache.py +30 -0
- pixeltable/exprs/type_cast.py +3 -3
- pixeltable/exprs/variable.py +2 -1
- pixeltable/ext/functions/whisperx.py +4 -4
- pixeltable/ext/functions/yolox.py +6 -6
- pixeltable/func/aggregate_function.py +1 -0
- pixeltable/func/function.py +28 -4
- pixeltable/functions/__init__.py +4 -2
- pixeltable/functions/anthropic.py +15 -5
- pixeltable/functions/fireworks.py +1 -1
- pixeltable/functions/globals.py +6 -1
- pixeltable/functions/huggingface.py +2 -2
- pixeltable/functions/image.py +17 -2
- pixeltable/functions/json.py +5 -5
- pixeltable/functions/mistralai.py +188 -0
- pixeltable/functions/openai.py +6 -10
- pixeltable/functions/string.py +3 -2
- pixeltable/functions/timestamp.py +95 -7
- pixeltable/functions/together.py +4 -4
- pixeltable/functions/video.py +2 -2
- pixeltable/functions/vision.py +27 -17
- pixeltable/functions/whisper.py +1 -1
- pixeltable/io/hf_datasets.py +17 -15
- pixeltable/io/pandas.py +0 -2
- pixeltable/io/parquet.py +15 -14
- pixeltable/iterators/document.py +16 -15
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_19.py +46 -0
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +5 -4
- pixeltable/plan.py +100 -78
- pixeltable/store.py +5 -1
- pixeltable/tool/create_test_db_dump.py +4 -3
- pixeltable/type_system.py +12 -14
- pixeltable/utils/documents.py +45 -42
- pixeltable/utils/formatter.py +2 -2
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/METADATA +79 -21
- pixeltable-0.2.18.dist-info/RECORD +147 -0
- pixeltable-0.2.17.dist-info/RECORD +0 -144
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.18.dist-info}/entry_points.txt +0 -0
pixeltable/env.py
CHANGED
|
@@ -14,7 +14,8 @@ import uuid
|
|
|
14
14
|
import warnings
|
|
15
15
|
from dataclasses import dataclass
|
|
16
16
|
from pathlib import Path
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
18
|
+
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
|
18
19
|
|
|
19
20
|
import pixeltable_pgserver
|
|
20
21
|
import sqlalchemy as sql
|
|
@@ -37,6 +38,35 @@ class Env:
|
|
|
37
38
|
_instance: Optional[Env] = None
|
|
38
39
|
_log_fmt_str = '%(asctime)s %(levelname)s %(name)s %(filename)s:%(lineno)d: %(message)s'
|
|
39
40
|
|
|
41
|
+
_home: Optional[Path]
|
|
42
|
+
_media_dir: Optional[Path]
|
|
43
|
+
_file_cache_dir: Optional[Path] # cached media files with external URL
|
|
44
|
+
_dataset_cache_dir: Optional[Path] # cached datasets (eg, pytorch or COCO)
|
|
45
|
+
_log_dir: Optional[Path] # log files
|
|
46
|
+
_tmp_dir: Optional[Path] # any tmp files
|
|
47
|
+
_sa_engine: Optional[sql.engine.base.Engine]
|
|
48
|
+
_pgdata_dir: Optional[Path]
|
|
49
|
+
_db_name: Optional[str]
|
|
50
|
+
_db_server: Optional[pixeltable_pgserver.PostgresServer]
|
|
51
|
+
_db_url: Optional[str]
|
|
52
|
+
_default_time_zone: Optional[ZoneInfo]
|
|
53
|
+
|
|
54
|
+
# info about optional packages that are utilized by some parts of the code
|
|
55
|
+
__optional_packages: dict[str, PackageInfo]
|
|
56
|
+
|
|
57
|
+
_spacy_nlp: Optional[spacy.Language]
|
|
58
|
+
_httpd: Optional[http.server.HTTPServer]
|
|
59
|
+
_http_address: Optional[str]
|
|
60
|
+
_logger: logging.Logger
|
|
61
|
+
_default_log_level: int
|
|
62
|
+
_logfilename: Optional[str]
|
|
63
|
+
_log_to_stdout: bool
|
|
64
|
+
_module_log_level: dict[str, int] # module name -> log level
|
|
65
|
+
_config_file: Optional[Path]
|
|
66
|
+
_config: Optional[dict[str, Any]]
|
|
67
|
+
_stdout_handler: logging.StreamHandler
|
|
68
|
+
_initialized: bool
|
|
69
|
+
|
|
40
70
|
@classmethod
|
|
41
71
|
def get(cls) -> Env:
|
|
42
72
|
if cls._instance is None:
|
|
@@ -51,24 +81,23 @@ class Env:
|
|
|
51
81
|
cls._instance = env
|
|
52
82
|
|
|
53
83
|
def __init__(self):
|
|
54
|
-
self._home
|
|
55
|
-
self._media_dir
|
|
56
|
-
self._file_cache_dir
|
|
57
|
-
self._dataset_cache_dir
|
|
58
|
-
self._log_dir
|
|
59
|
-
self._tmp_dir
|
|
60
|
-
self._sa_engine
|
|
61
|
-
self._pgdata_dir
|
|
62
|
-
self._db_name
|
|
63
|
-
self._db_server
|
|
64
|
-
self._db_url
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
self.
|
|
69
|
-
self.
|
|
70
|
-
self.
|
|
71
|
-
self._http_address: Optional[str] = None
|
|
84
|
+
self._home = None
|
|
85
|
+
self._media_dir = None # computed media files
|
|
86
|
+
self._file_cache_dir = None # cached media files with external URL
|
|
87
|
+
self._dataset_cache_dir = None # cached datasets (eg, pytorch or COCO)
|
|
88
|
+
self._log_dir = None # log files
|
|
89
|
+
self._tmp_dir = None # any tmp files
|
|
90
|
+
self._sa_engine = None
|
|
91
|
+
self._pgdata_dir = None
|
|
92
|
+
self._db_name = None
|
|
93
|
+
self._db_server = None
|
|
94
|
+
self._db_url = None
|
|
95
|
+
self._default_time_zone = None
|
|
96
|
+
|
|
97
|
+
self.__optional_packages = {}
|
|
98
|
+
self._spacy_nlp = None
|
|
99
|
+
self._httpd = None
|
|
100
|
+
self._http_address = None
|
|
72
101
|
|
|
73
102
|
# logging-related state
|
|
74
103
|
self._logger = logging.getLogger('pixeltable')
|
|
@@ -76,13 +105,12 @@ class Env:
|
|
|
76
105
|
self._logger.propagate = False
|
|
77
106
|
self._logger.addFilter(self._log_filter)
|
|
78
107
|
self._default_log_level = logging.INFO
|
|
79
|
-
self._logfilename
|
|
108
|
+
self._logfilename = None
|
|
80
109
|
self._log_to_stdout = False
|
|
81
|
-
self._module_log_level
|
|
110
|
+
self._module_log_level = {} # module name -> log level
|
|
82
111
|
|
|
83
|
-
|
|
84
|
-
self.
|
|
85
|
-
self._config: Optional[Dict[str, Any]] = None
|
|
112
|
+
self._config_file = None
|
|
113
|
+
self._config = None
|
|
86
114
|
|
|
87
115
|
# create logging handler to also log to stdout
|
|
88
116
|
self._stdout_handler = logging.StreamHandler(stream=sys.stdout)
|
|
@@ -103,6 +131,19 @@ class Env:
|
|
|
103
131
|
assert self._http_address is not None
|
|
104
132
|
return self._http_address
|
|
105
133
|
|
|
134
|
+
@property
|
|
135
|
+
def default_time_zone(self) -> Optional[ZoneInfo]:
|
|
136
|
+
return self._default_time_zone
|
|
137
|
+
|
|
138
|
+
@default_time_zone.setter
|
|
139
|
+
def default_time_zone(self, tz: Optional[ZoneInfo]) -> None:
|
|
140
|
+
"""
|
|
141
|
+
This is not a publicly visible setter; it is only for testing purposes.
|
|
142
|
+
"""
|
|
143
|
+
tz_name = None if tz is None else tz.key
|
|
144
|
+
self.engine.dispose()
|
|
145
|
+
self._create_engine(time_zone_name=tz_name)
|
|
146
|
+
|
|
106
147
|
def configure_logging(
|
|
107
148
|
self,
|
|
108
149
|
*,
|
|
@@ -158,7 +199,8 @@ class Env:
|
|
|
158
199
|
self._module_log_level[module] = level
|
|
159
200
|
|
|
160
201
|
def is_installed_package(self, package_name: str) -> bool:
|
|
161
|
-
|
|
202
|
+
assert package_name in self.__optional_packages
|
|
203
|
+
return self.__optional_packages[package_name].is_installed
|
|
162
204
|
|
|
163
205
|
def _log_filter(self, record: logging.LogRecord) -> bool:
|
|
164
206
|
if record.name == 'pixeltable':
|
|
@@ -270,20 +312,35 @@ class Env:
|
|
|
270
312
|
self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
|
|
271
313
|
self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
|
|
272
314
|
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
315
|
+
tz_name = os.environ.get('PXT_TIME_ZONE', self._config.get('pxt_time_zone', None))
|
|
316
|
+
if tz_name is not None:
|
|
317
|
+
# Validate tzname
|
|
318
|
+
if not isinstance(tz_name, str):
|
|
319
|
+
self._logger.error(f'Invalid time zone specified in configuration.')
|
|
320
|
+
else:
|
|
321
|
+
try:
|
|
322
|
+
_ = ZoneInfo(tz_name)
|
|
323
|
+
except ZoneInfoNotFoundError:
|
|
324
|
+
self._logger.error(f'Invalid time zone specified in configuration: {tz_name}')
|
|
325
|
+
|
|
326
|
+
if reinit_db and self._store_db_exists():
|
|
327
|
+
self._drop_store_db()
|
|
328
|
+
|
|
329
|
+
create_db = not self._store_db_exists()
|
|
276
330
|
|
|
277
|
-
if
|
|
278
|
-
self._logger.info(f'creating database at {self.db_url}')
|
|
331
|
+
if create_db:
|
|
332
|
+
self._logger.info(f'creating database at: {self.db_url}')
|
|
279
333
|
self._create_store_db()
|
|
280
|
-
|
|
334
|
+
else:
|
|
335
|
+
self._logger.info(f'found database at: {self.db_url}')
|
|
336
|
+
|
|
337
|
+
# Create the SQLAlchemy engine. This will also set the default time zone.
|
|
338
|
+
self._create_engine(time_zone_name=tz_name, echo=echo)
|
|
339
|
+
|
|
340
|
+
if create_db:
|
|
281
341
|
from pixeltable.metadata import schema
|
|
282
342
|
schema.Base.metadata.create_all(self._sa_engine)
|
|
283
343
|
metadata.create_system_info(self._sa_engine)
|
|
284
|
-
else:
|
|
285
|
-
self._logger.info(f'found database {self.db_url}')
|
|
286
|
-
self._create_engine(echo=echo)
|
|
287
344
|
|
|
288
345
|
print(f'Connected to Pixeltable database at: {self.db_url}')
|
|
289
346
|
|
|
@@ -291,8 +348,21 @@ class Env:
|
|
|
291
348
|
self._set_up_runtime()
|
|
292
349
|
self.log_to_stdout(False)
|
|
293
350
|
|
|
294
|
-
def _create_engine(self, echo: bool = False) -> None:
|
|
295
|
-
|
|
351
|
+
def _create_engine(self, time_zone_name: Optional[str], echo: bool = False) -> None:
|
|
352
|
+
connect_args = {} if time_zone_name is None else {'options': f'-c timezone={time_zone_name}'}
|
|
353
|
+
self._sa_engine = sql.create_engine(
|
|
354
|
+
self.db_url,
|
|
355
|
+
echo=echo,
|
|
356
|
+
future=True,
|
|
357
|
+
isolation_level='AUTOCOMMIT',
|
|
358
|
+
connect_args=connect_args,
|
|
359
|
+
)
|
|
360
|
+
self._logger.info(f'Created SQLAlchemy engine at: {self.db_url}')
|
|
361
|
+
with self.engine.begin() as conn:
|
|
362
|
+
tz_name = conn.execute(sql.text('SHOW TIME ZONE')).scalar()
|
|
363
|
+
assert isinstance(tz_name, str)
|
|
364
|
+
self._logger.info(f'Database time zone is now: {tz_name}')
|
|
365
|
+
self._default_time_zone = ZoneInfo(tz_name)
|
|
296
366
|
|
|
297
367
|
def _store_db_exists(self) -> bool:
|
|
298
368
|
assert self._db_name is not None
|
|
@@ -308,7 +378,6 @@ class Env:
|
|
|
308
378
|
finally:
|
|
309
379
|
engine.dispose()
|
|
310
380
|
|
|
311
|
-
|
|
312
381
|
def _create_store_db(self) -> None:
|
|
313
382
|
assert self._db_name is not None
|
|
314
383
|
# create the db
|
|
@@ -416,61 +485,75 @@ class Env:
|
|
|
416
485
|
def _set_up_runtime(self) -> None:
|
|
417
486
|
"""Check for and start runtime services"""
|
|
418
487
|
self._start_web_server()
|
|
419
|
-
self.
|
|
488
|
+
self.__register_packages()
|
|
489
|
+
|
|
490
|
+
def __register_packages(self) -> None:
|
|
491
|
+
"""Declare optional packages that are utilized by some parts of the code."""
|
|
492
|
+
self.__register_package('anthropic')
|
|
493
|
+
self.__register_package('boto3')
|
|
494
|
+
self.__register_package('datasets')
|
|
495
|
+
self.__register_package('fireworks', library_name='fireworks-ai')
|
|
496
|
+
self.__register_package('label_studio_sdk', library_name='label-studio-sdk')
|
|
497
|
+
self.__register_package('mistralai')
|
|
498
|
+
self.__register_package('mistune')
|
|
499
|
+
self.__register_package('openai')
|
|
500
|
+
self.__register_package('openpyxl')
|
|
501
|
+
self.__register_package('pyarrow')
|
|
502
|
+
self.__register_package('sentence_transformers', library_name='sentence-transformers')
|
|
503
|
+
self.__register_package('spacy') # TODO: deal with en-core-web-sm
|
|
504
|
+
self.__register_package('tiktoken')
|
|
505
|
+
self.__register_package('together')
|
|
506
|
+
self.__register_package('toml')
|
|
507
|
+
self.__register_package('torch')
|
|
508
|
+
self.__register_package('torchvision')
|
|
509
|
+
self.__register_package('transformers')
|
|
510
|
+
self.__register_package('whisper', library_name='openai-whisper')
|
|
511
|
+
self.__register_package('whisperx')
|
|
512
|
+
self.__register_package('yolox', library_name='git+https://github.com/Megvii-BaseDetection/YOLOX@ac58e0a')
|
|
420
513
|
|
|
421
|
-
def _check_installed_packages(self) -> None:
|
|
422
|
-
def check(package: str) -> None:
|
|
423
|
-
if importlib.util.find_spec(package) is not None:
|
|
424
|
-
self._installed_packages[package] = []
|
|
425
|
-
else:
|
|
426
|
-
self._installed_packages[package] = None
|
|
427
|
-
|
|
428
|
-
check('toml')
|
|
429
|
-
check('datasets')
|
|
430
|
-
check('torch')
|
|
431
|
-
check('torchvision')
|
|
432
|
-
check('transformers')
|
|
433
|
-
check('sentence_transformers')
|
|
434
|
-
check('whisper')
|
|
435
|
-
check('yolox')
|
|
436
|
-
check('whisperx')
|
|
437
|
-
check('boto3')
|
|
438
|
-
check('fitz') # pymupdf
|
|
439
|
-
check('pyarrow')
|
|
440
|
-
check('spacy') # TODO: deal with en-core-web-sm
|
|
441
514
|
if self.is_installed_package('spacy'):
|
|
442
515
|
import spacy
|
|
443
|
-
|
|
444
516
|
self._spacy_nlp = spacy.load('en_core_web_sm')
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
517
|
+
|
|
518
|
+
def __register_package(self, package_name: str, library_name: Optional[str] = None) -> None:
|
|
519
|
+
self.__optional_packages[package_name] = PackageInfo(
|
|
520
|
+
is_installed=importlib.util.find_spec(package_name) is not None,
|
|
521
|
+
library_name=library_name or package_name # defaults to package_name unless specified otherwise
|
|
522
|
+
)
|
|
523
|
+
|
|
524
|
+
def require_package(self, package_name: str, min_version: Optional[list[int]] = None) -> None:
|
|
525
|
+
"""
|
|
526
|
+
Checks whether the specified optional package is available. If not, raises an exception
|
|
527
|
+
with an error message informing the user how to install it.
|
|
528
|
+
"""
|
|
529
|
+
assert package_name in self.__optional_packages
|
|
530
|
+
package_info = self.__optional_packages[package_name]
|
|
531
|
+
|
|
532
|
+
if not package_info.is_installed:
|
|
533
|
+
# Check again whether the package has been installed.
|
|
534
|
+
# We do this so that if a user gets an "optional library not found" error message, they can
|
|
535
|
+
# `pip install` the library and re-run the Pixeltable operation without having to restart
|
|
536
|
+
# their Python session.
|
|
537
|
+
package_info.is_installed = importlib.util.find_spec(package_name) is not None
|
|
538
|
+
if not package_info.is_installed:
|
|
539
|
+
# Still not found.
|
|
540
|
+
raise excs.Error(
|
|
541
|
+
f'This feature requires the `{package_name}` package. To install it, run: `pip install -U {package_info.library_name}`'
|
|
542
|
+
)
|
|
543
|
+
|
|
457
544
|
if min_version is None:
|
|
458
545
|
return
|
|
459
546
|
|
|
460
547
|
# check whether we have a version >= the required one
|
|
461
|
-
if
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
if len(min_version) < len(installed_version):
|
|
467
|
-
normalized_min_version = min_version + [0] * (len(installed_version) - len(min_version))
|
|
468
|
-
if any([a < b for a, b in zip(installed_version, normalized_min_version)]):
|
|
548
|
+
if package_info.version is None:
|
|
549
|
+
module = importlib.import_module(package_name)
|
|
550
|
+
package_info.version = [int(x) for x in module.__version__.split('.')]
|
|
551
|
+
|
|
552
|
+
if min_version > package_info.version:
|
|
469
553
|
raise excs.Error(
|
|
470
|
-
(
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
)
|
|
554
|
+
f'The installed version of package `{package_name}` is {".".join(str(v) for v in package_info.version)}, '
|
|
555
|
+
f'but version >={".".join(str(v) for v in min_version)} is required. '
|
|
556
|
+
f'To fix this, run: `pip install -U {package_info.library_name}`'
|
|
474
557
|
)
|
|
475
558
|
|
|
476
559
|
def num_tmp_files(self) -> int:
|
|
@@ -556,3 +639,10 @@ class ApiClient:
|
|
|
556
639
|
init_fn: Callable
|
|
557
640
|
param_names: list[str]
|
|
558
641
|
client_obj: Optional[Any] = None
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
@dataclass
|
|
645
|
+
class PackageInfo:
|
|
646
|
+
is_installed: bool
|
|
647
|
+
library_name: str # pypi library name (may be different from package name)
|
|
648
|
+
version: Optional[list[int]] = None # installed version, as a list of components (such as [3,0,2] for "3.0.2")
|
|
@@ -21,8 +21,9 @@ class AggregationNode(ExecNode):
|
|
|
21
21
|
self.input = input
|
|
22
22
|
self.group_by = group_by
|
|
23
23
|
self.input_exprs = list(input_exprs)
|
|
24
|
-
self.agg_fn_calls = agg_fn_calls
|
|
25
24
|
self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
|
|
25
|
+
# we need to make sure to refer to the same exprs that RowBuilder.eval() will use
|
|
26
|
+
self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
|
|
26
27
|
self.output_batch = DataRowBatch(tbl, row_builder, 0)
|
|
27
28
|
|
|
28
29
|
def _reset_agg_state(self, row_num: int) -> None:
|
|
@@ -19,7 +19,7 @@ class ComponentIterationNode(ExecNode):
|
|
|
19
19
|
super().__init__(input.row_builder, [], [], input)
|
|
20
20
|
self.view = view
|
|
21
21
|
iterator_args = [view.iterator_args.copy()]
|
|
22
|
-
self.row_builder.
|
|
22
|
+
self.row_builder.set_slot_idxs(iterator_args)
|
|
23
23
|
self.iterator_args = iterator_args[0]
|
|
24
24
|
assert isinstance(self.iterator_args, exprs.InlineDict)
|
|
25
25
|
self.iterator_args_ctx = self.row_builder.create_eval_ctx([self.iterator_args])
|
pixeltable/exec/sql_node.py
CHANGED
|
@@ -18,7 +18,7 @@ class SqlNode(ExecNode):
|
|
|
18
18
|
|
|
19
19
|
def __init__(
|
|
20
20
|
self, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder,
|
|
21
|
-
select_list: Iterable[exprs.Expr], set_pk: bool = False
|
|
21
|
+
select_list: Iterable[exprs.Expr], sql_elements: exprs.SqlElementCache, set_pk: bool = False
|
|
22
22
|
):
|
|
23
23
|
"""
|
|
24
24
|
Initialize self.stmt with expressions derived from select_list.
|
|
@@ -35,8 +35,9 @@ class SqlNode(ExecNode):
|
|
|
35
35
|
self.sql_exprs = exprs.ExprSet(select_list)
|
|
36
36
|
# unstored iter columns: we also need to retrieve whatever is needed to materialize the iter args
|
|
37
37
|
for iter_arg in row_builder.unstored_iter_args.values():
|
|
38
|
-
sql_subexprs = iter_arg.subexprs(filter=
|
|
39
|
-
|
|
38
|
+
sql_subexprs = iter_arg.subexprs(filter=sql_elements.contains, traverse_matches=False)
|
|
39
|
+
for e in sql_subexprs:
|
|
40
|
+
self.sql_exprs.add(e)
|
|
40
41
|
super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
|
|
41
42
|
|
|
42
43
|
# change rowid refs against a base table to rowid refs against the target table, so that we minimize
|
|
@@ -44,7 +45,7 @@ class SqlNode(ExecNode):
|
|
|
44
45
|
for rowid_ref in [e for e in self.sql_exprs if isinstance(e, exprs.RowidRef)]:
|
|
45
46
|
rowid_ref.set_tbl(tbl)
|
|
46
47
|
|
|
47
|
-
sql_select_list = [
|
|
48
|
+
sql_select_list = [sql_elements.get(e) for e in self.sql_exprs]
|
|
48
49
|
assert len(sql_select_list) == len(self.sql_exprs)
|
|
49
50
|
assert all(e is not None for e in sql_select_list)
|
|
50
51
|
self.set_pk = set_pk
|
|
@@ -204,7 +205,8 @@ class SqlScanNode(SqlNode):
|
|
|
204
205
|
set_pk: if True, sets the primary for each DataRow
|
|
205
206
|
exact_version_only: tables for which we only want to see rows created at the current version
|
|
206
207
|
"""
|
|
207
|
-
|
|
208
|
+
sql_elements = exprs.SqlElementCache()
|
|
209
|
+
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=set_pk)
|
|
208
210
|
# create Select stmt
|
|
209
211
|
if order_by_items is None:
|
|
210
212
|
order_by_items = []
|
|
@@ -230,10 +232,10 @@ class SqlScanNode(SqlNode):
|
|
|
230
232
|
if isinstance(e, exprs.SimilarityExpr):
|
|
231
233
|
order_by_clause.append(e.as_order_by_clause(asc))
|
|
232
234
|
else:
|
|
233
|
-
order_by_clause.append(
|
|
235
|
+
order_by_clause.append(sql_elements.get(e).desc() if not asc else sql_elements.get(e))
|
|
234
236
|
|
|
235
237
|
if where_clause is not None:
|
|
236
|
-
sql_where_clause =
|
|
238
|
+
sql_where_clause = sql_elements.get(where_clause)
|
|
237
239
|
assert sql_where_clause is not None
|
|
238
240
|
self.stmt = self.stmt.where(sql_where_clause)
|
|
239
241
|
if len(order_by_clause) > 0:
|
|
@@ -272,7 +274,8 @@ class SqlLookupNode(SqlNode):
|
|
|
272
274
|
sa_key_cols: list of key columns in the store table
|
|
273
275
|
key_vals: list of key values to look up
|
|
274
276
|
"""
|
|
275
|
-
|
|
277
|
+
sql_elements = exprs.SqlElementCache()
|
|
278
|
+
super().__init__(tbl, row_builder, select_list, sql_elements, set_pk=True)
|
|
276
279
|
target = tbl.tbl_version # the stored table we're scanning
|
|
277
280
|
refd_tbl_ids = exprs.Expr.list_tbl_ids(self.sql_exprs)
|
|
278
281
|
self.stmt = self.create_from_clause(tbl, self.stmt, refd_tbl_ids)
|
pixeltable/exprs/__init__.py
CHANGED
|
@@ -20,5 +20,6 @@ from .object_ref import ObjectRef
|
|
|
20
20
|
from .row_builder import RowBuilder, ColumnSlotIdx, ExecProfile
|
|
21
21
|
from .rowid_ref import RowidRef
|
|
22
22
|
from .similarity_expr import SimilarityExpr
|
|
23
|
+
from .sql_element_cache import SqlElementCache
|
|
23
24
|
from .type_cast import TypeCast
|
|
24
25
|
from .variable import Variable
|
|
@@ -6,11 +6,11 @@ import sqlalchemy as sql
|
|
|
6
6
|
|
|
7
7
|
import pixeltable.exceptions as excs
|
|
8
8
|
import pixeltable.type_system as ts
|
|
9
|
-
|
|
10
9
|
from .data_row import DataRow
|
|
11
10
|
from .expr import Expr
|
|
12
11
|
from .globals import ArithmeticOperator
|
|
13
12
|
from .row_builder import RowBuilder
|
|
13
|
+
from .sql_element_cache import SqlElementCache
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
class ArithmeticExpr(Expr):
|
|
@@ -54,10 +54,10 @@ class ArithmeticExpr(Expr):
|
|
|
54
54
|
def _op2(self) -> Expr:
|
|
55
55
|
return self.components[1]
|
|
56
56
|
|
|
57
|
-
def sql_expr(self) -> Optional[sql.
|
|
57
|
+
def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
58
58
|
assert self.col_type.is_int_type() or self.col_type.is_float_type() or self.col_type.is_json_type()
|
|
59
|
-
left = self._op1
|
|
60
|
-
right = self._op2
|
|
59
|
+
left = sql_elements.get(self._op1)
|
|
60
|
+
right = sql_elements.get(self._op2)
|
|
61
61
|
if left is None or right is None:
|
|
62
62
|
return None
|
|
63
63
|
if self.operator == ArithmeticOperator.ADD:
|
pixeltable/exprs/array_slice.py
CHANGED
|
@@ -8,6 +8,7 @@ from .data_row import DataRow
|
|
|
8
8
|
from .expr import Expr
|
|
9
9
|
from .globals import print_slice
|
|
10
10
|
from .row_builder import RowBuilder
|
|
11
|
+
from .sql_element_cache import SqlElementCache
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class ArraySlice(Expr):
|
|
@@ -41,7 +42,7 @@ class ArraySlice(Expr):
|
|
|
41
42
|
def _id_attrs(self) -> List[Tuple[str, Any]]:
|
|
42
43
|
return super()._id_attrs() + [('index', self.index)]
|
|
43
44
|
|
|
44
|
-
def sql_expr(self) -> Optional[sql.
|
|
45
|
+
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
45
46
|
return None
|
|
46
47
|
|
|
47
48
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
@@ -1,15 +1,16 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import enum
|
|
4
|
+
from typing import Optional, List, Any, Dict, Tuple
|
|
4
5
|
|
|
5
6
|
import sqlalchemy as sql
|
|
6
7
|
|
|
7
|
-
|
|
8
|
+
import pixeltable.type_system as ts
|
|
8
9
|
from .column_ref import ColumnRef
|
|
9
|
-
from .row_builder import RowBuilder
|
|
10
10
|
from .data_row import DataRow
|
|
11
|
-
|
|
12
|
-
|
|
11
|
+
from .expr import Expr
|
|
12
|
+
from .row_builder import RowBuilder
|
|
13
|
+
from .sql_element_cache import SqlElementCache
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
class ColumnPropertyRef(Expr):
|
|
@@ -45,7 +46,7 @@ class ColumnPropertyRef(Expr):
|
|
|
45
46
|
def __str__(self) -> str:
|
|
46
47
|
return f'{self._col_ref}.{self.prop.name.lower()}'
|
|
47
48
|
|
|
48
|
-
def sql_expr(self) -> Optional[sql.
|
|
49
|
+
def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
49
50
|
if not self._col_ref.col.is_stored:
|
|
50
51
|
return None
|
|
51
52
|
if self.prop == self.Property.ERRORTYPE:
|
|
@@ -56,7 +57,7 @@ class ColumnPropertyRef(Expr):
|
|
|
56
57
|
return self._col_ref.col.sa_errormsg_col
|
|
57
58
|
if self.prop == self.Property.FILEURL:
|
|
58
59
|
# the file url is stored as the column value
|
|
59
|
-
return self._col_ref
|
|
60
|
+
return sql_elements.get(self._col_ref)
|
|
60
61
|
return None
|
|
61
62
|
|
|
62
63
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
|
@@ -73,5 +74,6 @@ class ColumnPropertyRef(Expr):
|
|
|
73
74
|
@classmethod
|
|
74
75
|
def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
|
|
75
76
|
assert 'prop' in d
|
|
77
|
+
assert isinstance(components[0], ColumnRef)
|
|
76
78
|
return cls(components[0], cls.Property(d['prop']))
|
|
77
79
|
|
pixeltable/exprs/column_ref.py
CHANGED
|
@@ -7,6 +7,7 @@ import sqlalchemy as sql
|
|
|
7
7
|
from .expr import Expr
|
|
8
8
|
from .data_row import DataRow
|
|
9
9
|
from .row_builder import RowBuilder
|
|
10
|
+
from .sql_element_cache import SqlElementCache
|
|
10
11
|
import pixeltable.iterators as iters
|
|
11
12
|
import pixeltable.exceptions as excs
|
|
12
13
|
import pixeltable.catalog as catalog
|
|
@@ -92,7 +93,7 @@ class ColumnRef(Expr):
|
|
|
92
93
|
def __repr__(self) -> str:
|
|
93
94
|
return f'ColumnRef({self.col!r})'
|
|
94
95
|
|
|
95
|
-
def sql_expr(self) -> Optional[sql.
|
|
96
|
+
def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
96
97
|
return self.col.sa_col
|
|
97
98
|
|
|
98
99
|
def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
|
pixeltable/exprs/comparison.py
CHANGED
|
@@ -1,22 +1,25 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from datetime import datetime
|
|
4
3
|
from typing import Optional, List, Any, Dict
|
|
5
4
|
|
|
6
5
|
import sqlalchemy as sql
|
|
7
6
|
|
|
7
|
+
import pixeltable.exceptions as excs
|
|
8
|
+
import pixeltable.index as index
|
|
9
|
+
import pixeltable.type_system as ts
|
|
8
10
|
from .column_ref import ColumnRef
|
|
9
11
|
from .data_row import DataRow
|
|
10
12
|
from .expr import Expr
|
|
11
13
|
from .globals import ComparisonOperator
|
|
12
14
|
from .literal import Literal
|
|
13
15
|
from .row_builder import RowBuilder
|
|
14
|
-
|
|
15
|
-
import pixeltable.index as index
|
|
16
|
-
import pixeltable.type_system as ts
|
|
16
|
+
from .sql_element_cache import SqlElementCache
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class Comparison(Expr):
|
|
20
|
+
is_search_arg_comparison: bool
|
|
21
|
+
operator: ComparisonOperator
|
|
22
|
+
|
|
20
23
|
def __init__(self, operator: ComparisonOperator, op1: Expr, op2: Expr):
|
|
21
24
|
super().__init__(ts.BoolType())
|
|
22
25
|
self.operator = operator
|
|
@@ -62,8 +65,8 @@ class Comparison(Expr):
|
|
|
62
65
|
def _op2(self) -> Expr:
|
|
63
66
|
return self.components[1]
|
|
64
67
|
|
|
65
|
-
def sql_expr(self) -> Optional[sql.ClauseElement]:
|
|
66
|
-
left = self._op1
|
|
68
|
+
def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ClauseElement]:
|
|
69
|
+
left = sql_elements.get(self._op1)
|
|
67
70
|
if self.is_search_arg_comparison:
|
|
68
71
|
# reference the index value column if there is an index and this is not a snapshot
|
|
69
72
|
# (indices don't apply to snapshots)
|
|
@@ -76,7 +79,7 @@ class Comparison(Expr):
|
|
|
76
79
|
assert len(idx_info) == 1
|
|
77
80
|
left = idx_info[0].val_col.sa_col
|
|
78
81
|
|
|
79
|
-
right = self._op2
|
|
82
|
+
right = sql_elements.get(self._op2)
|
|
80
83
|
if left is None or right is None:
|
|
81
84
|
return None
|
|
82
85
|
|
|
@@ -9,6 +9,7 @@ from .data_row import DataRow
|
|
|
9
9
|
from .expr import Expr
|
|
10
10
|
from .globals import LogicalOperator
|
|
11
11
|
from .row_builder import RowBuilder
|
|
12
|
+
from .sql_element_cache import SqlElementCache
|
|
12
13
|
import pixeltable.type_system as ts
|
|
13
14
|
|
|
14
15
|
|
|
@@ -66,8 +67,8 @@ class CompoundPredicate(Expr):
|
|
|
66
67
|
non_matches = [op for op in self.components if not condition(op)]
|
|
67
68
|
return (matches, self.make_conjunction(non_matches))
|
|
68
69
|
|
|
69
|
-
def sql_expr(self) -> Optional[sql.
|
|
70
|
-
sql_exprs = [
|
|
70
|
+
def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
|
|
71
|
+
sql_exprs = [sql_elements.get(op) for op in self.components]
|
|
71
72
|
if any(e is None for e in sql_exprs):
|
|
72
73
|
return None
|
|
73
74
|
if self.operator == LogicalOperator.NOT:
|