pixeltable 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (88) hide show
  1. pixeltable/__init__.py +7 -19
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +7 -7
  4. pixeltable/catalog/globals.py +3 -0
  5. pixeltable/catalog/insertable_table.py +9 -7
  6. pixeltable/catalog/table.py +220 -143
  7. pixeltable/catalog/table_version.py +36 -18
  8. pixeltable/catalog/table_version_path.py +0 -8
  9. pixeltable/catalog/view.py +3 -3
  10. pixeltable/dataframe.py +9 -24
  11. pixeltable/env.py +107 -36
  12. pixeltable/exceptions.py +7 -4
  13. pixeltable/exec/__init__.py +1 -1
  14. pixeltable/exec/aggregation_node.py +22 -15
  15. pixeltable/exec/component_iteration_node.py +62 -41
  16. pixeltable/exec/data_row_batch.py +7 -7
  17. pixeltable/exec/exec_node.py +35 -7
  18. pixeltable/exec/expr_eval_node.py +2 -1
  19. pixeltable/exec/in_memory_data_node.py +9 -9
  20. pixeltable/exec/sql_node.py +265 -136
  21. pixeltable/exprs/__init__.py +1 -0
  22. pixeltable/exprs/data_row.py +30 -19
  23. pixeltable/exprs/expr.py +15 -14
  24. pixeltable/exprs/expr_dict.py +55 -0
  25. pixeltable/exprs/expr_set.py +21 -15
  26. pixeltable/exprs/function_call.py +21 -8
  27. pixeltable/exprs/json_path.py +3 -6
  28. pixeltable/exprs/rowid_ref.py +2 -2
  29. pixeltable/exprs/sql_element_cache.py +5 -1
  30. pixeltable/ext/functions/whisperx.py +7 -2
  31. pixeltable/func/callable_function.py +2 -2
  32. pixeltable/func/function_registry.py +6 -7
  33. pixeltable/func/query_template_function.py +11 -12
  34. pixeltable/func/signature.py +17 -15
  35. pixeltable/func/udf.py +0 -4
  36. pixeltable/functions/__init__.py +1 -1
  37. pixeltable/functions/audio.py +4 -6
  38. pixeltable/functions/globals.py +86 -42
  39. pixeltable/functions/huggingface.py +12 -14
  40. pixeltable/functions/image.py +59 -45
  41. pixeltable/functions/json.py +0 -1
  42. pixeltable/functions/mistralai.py +2 -2
  43. pixeltable/functions/openai.py +22 -25
  44. pixeltable/functions/string.py +50 -50
  45. pixeltable/functions/timestamp.py +20 -20
  46. pixeltable/functions/together.py +26 -12
  47. pixeltable/functions/video.py +11 -20
  48. pixeltable/functions/whisper.py +2 -20
  49. pixeltable/globals.py +57 -56
  50. pixeltable/index/base.py +2 -2
  51. pixeltable/index/btree.py +7 -7
  52. pixeltable/index/embedding_index.py +8 -10
  53. pixeltable/io/external_store.py +11 -5
  54. pixeltable/io/globals.py +3 -1
  55. pixeltable/io/hf_datasets.py +4 -4
  56. pixeltable/io/label_studio.py +6 -6
  57. pixeltable/io/parquet.py +14 -13
  58. pixeltable/iterators/document.py +10 -8
  59. pixeltable/iterators/video.py +10 -1
  60. pixeltable/metadata/__init__.py +3 -2
  61. pixeltable/metadata/converters/convert_14.py +4 -2
  62. pixeltable/metadata/converters/convert_15.py +1 -1
  63. pixeltable/metadata/converters/convert_19.py +1 -0
  64. pixeltable/metadata/converters/convert_20.py +1 -1
  65. pixeltable/metadata/converters/util.py +9 -8
  66. pixeltable/metadata/schema.py +32 -21
  67. pixeltable/plan.py +136 -154
  68. pixeltable/store.py +51 -36
  69. pixeltable/tool/create_test_db_dump.py +7 -7
  70. pixeltable/tool/doc_plugins/griffe.py +3 -34
  71. pixeltable/tool/mypy_plugin.py +32 -0
  72. pixeltable/type_system.py +243 -60
  73. pixeltable/utils/arrow.py +10 -9
  74. pixeltable/utils/coco.py +4 -4
  75. pixeltable/utils/documents.py +1 -1
  76. pixeltable/utils/filecache.py +131 -84
  77. pixeltable/utils/formatter.py +1 -1
  78. pixeltable/utils/http_server.py +2 -5
  79. pixeltable/utils/media_store.py +6 -6
  80. pixeltable/utils/pytorch.py +10 -11
  81. pixeltable/utils/sql.py +2 -1
  82. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/METADATA +16 -7
  83. pixeltable-0.2.21.dist-info/RECORD +148 -0
  84. pixeltable/utils/help.py +0 -11
  85. pixeltable-0.2.19.dist-info/RECORD +0 -147
  86. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/LICENSE +0 -0
  87. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/WHEEL +0 -0
  88. {pixeltable-0.2.19.dist-info → pixeltable-0.2.21.dist-info}/entry_points.txt +0 -0
@@ -6,7 +6,7 @@ import inspect
6
6
  import logging
7
7
  import time
8
8
  import uuid
9
- from typing import TYPE_CHECKING, Any, Iterable, Optional
9
+ from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional
10
10
  from uuid import UUID
11
11
 
12
12
  import sqlalchemy as sql
@@ -453,7 +453,9 @@ class TableVersion:
453
453
  self.idxs_by_name[idx_name] = idx_info
454
454
 
455
455
  # add the columns and update the metadata
456
- status = self._add_columns([val_col, undo_col], conn)
456
+ # TODO support on_error='abort' for indices; it's tricky because of the way metadata changes are entangled
457
+ # with the database operations
458
+ status = self._add_columns([val_col, undo_col], conn, print_stats=False, on_error='ignore')
457
459
  # now create the index structure
458
460
  idx.create_index(self._store_idx_name(idx_id), val_col, conn)
459
461
 
@@ -478,7 +480,7 @@ class TableVersion:
478
480
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
479
481
  _logger.info(f'Dropped index {idx_md.name} on table {self.name}')
480
482
 
481
- def add_column(self, col: Column, print_stats: bool = False) -> UpdateStatus:
483
+ def add_column(self, col: Column, print_stats: bool, on_error: Literal['abort', 'ignore']) -> UpdateStatus:
482
484
  """Adds a column to the table.
483
485
  """
484
486
  assert not self.is_snapshot
@@ -498,9 +500,8 @@ class TableVersion:
498
500
  preceding_schema_version = self.schema_version
499
501
  self.schema_version = self.version
500
502
  with Env.get().engine.begin() as conn:
501
- status = self._add_columns([col], conn, print_stats=print_stats)
503
+ status = self._add_columns([col], conn, print_stats=print_stats, on_error=on_error)
502
504
  _ = self._add_default_index(col, conn)
503
- # TODO: what to do about errors?
504
505
  self._update_md(time.time(), conn, preceding_schema_version=preceding_schema_version)
505
506
  _logger.info(f'Added column {col.name} to table {self.name}, new version: {self.version}')
506
507
 
@@ -512,7 +513,13 @@ class TableVersion:
512
513
  _logger.info(f'Column {col.name}: {msg}')
513
514
  return status
514
515
 
515
- def _add_columns(self, cols: Iterable[Column], conn: sql.engine.Connection, print_stats: bool = False) -> UpdateStatus:
516
+ def _add_columns(
517
+ self,
518
+ cols: Iterable[Column],
519
+ conn: sql.engine.Connection,
520
+ print_stats: bool,
521
+ on_error: Literal['abort', 'ignore']
522
+ ) -> UpdateStatus:
516
523
  """Add and populate columns within the current transaction"""
517
524
  cols = list(cols)
518
525
  row_count = self.store_tbl.count(conn=conn)
@@ -550,10 +557,14 @@ class TableVersion:
550
557
  try:
551
558
  plan.ctx.set_conn(conn)
552
559
  plan.open()
553
- num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn)
560
+ try:
561
+ num_excs = self.store_tbl.load_column(col, plan, value_expr_slot_idx, conn, on_error)
562
+ except sql.exc.DBAPIError as exc:
563
+ # Wrap the DBAPIError in an excs.Error to unify processing in the subsequent except block
564
+ raise excs.Error(f'SQL error during execution of computed column `{col.name}`:\n{exc}') from exc
554
565
  if num_excs > 0:
555
566
  cols_with_excs.append(col)
556
- except sql.exc.DBAPIError as e:
567
+ except excs.Error as exc:
557
568
  self.cols.pop()
558
569
  for col in cols:
559
570
  # remove columns that we already added
@@ -564,7 +575,7 @@ class TableVersion:
564
575
  del self.cols_by_id[col.id]
565
576
  # we need to re-initialize the sqlalchemy schema
566
577
  self.store_tbl.create_sa_tbl()
567
- raise excs.Error(f'Error during SQL execution:\n{e}')
578
+ raise exc
568
579
  finally:
569
580
  plan.close()
570
581
 
@@ -689,21 +700,30 @@ class TableVersion:
689
700
  plan = Planner.create_insert_plan(self, rows, ignore_errors=not fail_on_exception)
690
701
  else:
691
702
  plan = Planner.create_df_insert_plan(self, df, ignore_errors=not fail_on_exception)
703
+
704
+ # this is a base table; we generate rowids during the insert
705
+ def rowids() -> Iterator[int]:
706
+ while True:
707
+ rowid = self.next_rowid
708
+ self.next_rowid += 1
709
+ yield rowid
710
+
692
711
  if conn is None:
693
712
  with Env.get().engine.begin() as conn:
694
- return self._insert(plan, conn, time.time(), print_stats)
713
+ return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
695
714
  else:
696
- return self._insert(plan, conn, time.time(), print_stats)
715
+ return self._insert(plan, conn, time.time(), print_stats=print_stats, rowids=rowids())
697
716
 
698
717
  def _insert(
699
- self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, print_stats: bool = False,
718
+ self, exec_plan: 'exec.ExecNode', conn: sql.engine.Connection, timestamp: float, *,
719
+ rowids: Optional[Iterator[int]] = None, print_stats: bool = False,
700
720
  ) -> UpdateStatus:
701
721
  """Insert rows produced by exec_plan and propagate to views"""
702
722
  # we're creating a new version
703
723
  self.version += 1
704
724
  result = UpdateStatus()
705
- num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(exec_plan, conn, v_min=self.version)
706
- self.next_rowid = num_rows
725
+ num_rows, num_excs, cols_with_excs = self.store_tbl.insert_rows(
726
+ exec_plan, conn, v_min=self.version, rowids=rowids)
707
727
  result.num_rows = num_rows
708
728
  result.num_excs = num_excs
709
729
  result.num_computed_values += exec_plan.ctx.num_computed_exprs * num_rows
@@ -714,7 +734,7 @@ class TableVersion:
714
734
  for view in self.mutable_views:
715
735
  from pixeltable.plan import Planner
716
736
  plan, _ = Planner.create_view_load_plan(view.path, propagates_insert=True)
717
- status = view._insert(plan, conn, timestamp, print_stats)
737
+ status = view._insert(plan, conn, timestamp, print_stats=print_stats)
718
738
  result.num_rows += status.num_rows
719
739
  result.num_excs += status.num_excs
720
740
  result.num_computed_values += status.num_computed_values
@@ -751,9 +771,7 @@ class TableVersion:
751
771
  raise excs.Error(f'Filter {analysis_info.filter} not expressible in SQL')
752
772
 
753
773
  with Env.get().engine.begin() as conn:
754
- plan, updated_cols, recomputed_cols = (
755
- Planner.create_update_plan(self.path, update_spec, [], where, cascade)
756
- )
774
+ plan, updated_cols, recomputed_cols = Planner.create_update_plan(self.path, update_spec, [], where, cascade)
757
775
  from pixeltable.exprs import SqlElementCache
758
776
  result = self.propagate_update(
759
777
  plan, where.sql_expr(SqlElementCache()) if where is not None else None, recomputed_cols,
@@ -91,14 +91,6 @@ class TableVersionPath:
91
91
  col = self.tbl_version.cols_by_name[col_name]
92
92
  return ColumnRef(col)
93
93
 
94
- def __getitem__(self, index: object) -> Union[exprs.ColumnRef, pxt.DataFrame]:
95
- """Return a ColumnRef for the given column name, or a DataFrame for the given slice.
96
- """
97
- if isinstance(index, str):
98
- # basically <tbl>.<colname>
99
- return self.__getattr__(index)
100
- return pxt.DataFrame(self).__getitem__(index)
101
-
102
94
  def columns(self) -> list[Column]:
103
95
  """Return all user columns visible in this tbl version path, including columns from bases"""
104
96
  result = list(self.tbl_version.cols_by_name.values())
@@ -52,11 +52,11 @@ class View(Table):
52
52
 
53
53
  @classmethod
54
54
  def _create(
55
- cls, dir_id: UUID, name: str, base: TableVersionPath, schema: Dict[str, Any],
56
- predicate: 'pxt.exprs.Expr', is_snapshot: bool, num_retained_versions: int, comment: str,
55
+ cls, dir_id: UUID, name: str, base: TableVersionPath, additional_columns: Dict[str, Any],
56
+ predicate: Optional['pxt.exprs.Expr'], is_snapshot: bool, num_retained_versions: int, comment: str,
57
57
  iterator_cls: Optional[Type[ComponentIterator]], iterator_args: Optional[Dict]
58
58
  ) -> View:
59
- columns = cls._create_columns(schema)
59
+ columns = cls._create_columns(additional_columns)
60
60
  cls._verify_schema(columns)
61
61
 
62
62
  # verify that filter can be evaluated in the context of the base
pixeltable/dataframe.py CHANGED
@@ -8,7 +8,7 @@ import logging
8
8
  import mimetypes
9
9
  import traceback
10
10
  from pathlib import Path
11
- from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Set, Tuple
11
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterator, List, Optional, Sequence, Set, Tuple, Union
12
12
 
13
13
  import pandas as pd
14
14
  import pandas.io.formats.style
@@ -97,8 +97,8 @@ class DataFrameResultSet:
97
97
  return self._rows[index[0]][col_idx]
98
98
  raise excs.Error(f'Bad index: {index}')
99
99
 
100
- def __iter__(self) -> DataFrameResultSetIterator:
101
- return DataFrameResultSetIterator(self)
100
+ def __iter__(self) -> Iterator[dict[str, Any]]:
101
+ return (self._row_to_dict(i) for i in range(len(self)))
102
102
 
103
103
  def __eq__(self, other):
104
104
  if not isinstance(other, DataFrameResultSet):
@@ -106,19 +106,6 @@ class DataFrameResultSet:
106
106
  return self.to_pandas().equals(other.to_pandas())
107
107
 
108
108
 
109
- class DataFrameResultSetIterator:
110
- def __init__(self, result_set: DataFrameResultSet):
111
- self._result_set = result_set
112
- self._idx = 0
113
-
114
- def __next__(self) -> Dict[str, Any]:
115
- if self._idx >= len(self._result_set):
116
- raise StopIteration
117
- row = self._result_set._row_to_dict(self._idx)
118
- self._idx += 1
119
- return row
120
-
121
-
122
109
  # # TODO: remove this; it's only here as a reminder that we still need to call release() in the current implementation
123
110
  # class AnalysisInfo:
124
111
  # def __init__(self, tbl: catalog.TableVersion):
@@ -296,7 +283,7 @@ class DataFrame:
296
283
 
297
284
  def _create_query_plan(self) -> exec.ExecNode:
298
285
  # construct a group-by clause if we're grouping by a table
299
- group_by_clause: List[exprs.Expr] = []
286
+ group_by_clause: Optional[list[exprs.Expr]] = None
300
287
  if self.grouping_tbl is not None:
301
288
  assert self.group_by_clause is None
302
289
  num_rowid_cols = len(self.grouping_tbl.store_tbl.rowid_columns())
@@ -315,8 +302,8 @@ class DataFrame:
315
302
  where_clause=self.where_clause,
316
303
  group_by_clause=group_by_clause,
317
304
  order_by_clause=self.order_by_clause if self.order_by_clause is not None else [],
318
- limit=self.limit_val if self.limit_val is not None else 0,
319
- ) # limit_val == 0: no limit_val
305
+ limit=self.limit_val
306
+ )
320
307
 
321
308
 
322
309
  def show(self, n: int = 20) -> DataFrameResultSet:
@@ -629,17 +616,15 @@ class DataFrame:
629
616
  if self.limit_val is not None:
630
617
  raise excs.Error(f'Cannot use `{op_name}` after `limit`')
631
618
 
632
- def __getitem__(self, index: object) -> DataFrame:
619
+ def __getitem__(self, index: Union[exprs.Expr, Sequence[exprs.Expr]]) -> DataFrame:
633
620
  """
634
621
  Allowed:
635
622
  - [List[Expr]]/[Tuple[Expr]]: setting the select list
636
623
  - [Expr]: setting a single-col select list
637
624
  """
638
- if isinstance(index, tuple):
639
- index = list(index)
640
625
  if isinstance(index, exprs.Expr):
641
- index = [index]
642
- if isinstance(index, list):
626
+ return self.select(index)
627
+ if isinstance(index, Sequence):
643
628
  return self.select(*index)
644
629
  raise TypeError(f'Invalid index type: {type(index)}')
645
630
 
pixeltable/env.py CHANGED
@@ -8,6 +8,7 @@ import importlib.util
8
8
  import inspect
9
9
  import logging
10
10
  import os
11
+ import shutil
11
12
  import subprocess
12
13
  import sys
13
14
  import threading
@@ -15,12 +16,12 @@ import uuid
15
16
  import warnings
16
17
  from dataclasses import dataclass
17
18
  from pathlib import Path
18
- from typing import TYPE_CHECKING, Any, Callable, Optional
19
+ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
19
20
  from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
20
21
 
21
22
  import pixeltable_pgserver
22
23
  import sqlalchemy as sql
23
- import yaml
24
+ import toml
24
25
  from tqdm import TqdmWarning
25
26
 
26
27
  import pixeltable.exceptions as excs
@@ -64,7 +65,7 @@ class Env:
64
65
  _log_to_stdout: bool
65
66
  _module_log_level: dict[str, int] # module name -> log level
66
67
  _config_file: Optional[Path]
67
- _config: Optional[dict[str, Any]]
68
+ _config: Optional[Config]
68
69
  _stdout_handler: logging.StreamHandler
69
70
  _initialized: bool
70
71
 
@@ -110,6 +111,7 @@ class Env:
110
111
  self._log_to_stdout = False
111
112
  self._module_log_level = {} # module name -> log level
112
113
 
114
+ # config
113
115
  self._config_file = None
114
116
  self._config = None
115
117
 
@@ -119,7 +121,8 @@ class Env:
119
121
  self._initialized = False
120
122
 
121
123
  @property
122
- def config(self):
124
+ def config(self) -> Config:
125
+ assert self._config is not None
123
126
  return self._config
124
127
 
125
128
  @property
@@ -227,30 +230,13 @@ class Env:
227
230
  home = Path(os.environ.get('PIXELTABLE_HOME', str(Path.home() / '.pixeltable')))
228
231
  assert self._home is None or self._home == home
229
232
  self._home = home
230
- self._config_file = Path(os.environ.get('PIXELTABLE_CONFIG', str(self._home / 'config.yaml')))
233
+ self._config_file = Path(os.environ.get('PIXELTABLE_CONFIG', str(self._home / 'config.toml')))
231
234
  self._media_dir = self._home / 'media'
232
235
  self._file_cache_dir = self._home / 'file_cache'
233
236
  self._dataset_cache_dir = self._home / 'dataset_cache'
234
237
  self._log_dir = self._home / 'logs'
235
238
  self._tmp_dir = self._home / 'tmp'
236
239
 
237
- # Read in the config
238
- if os.path.isfile(self._config_file):
239
- with open(self._config_file, 'r') as stream:
240
- try:
241
- self._config = yaml.safe_load(stream)
242
- except yaml.YAMLError as exc:
243
- self._logger.error(f'Could not read config file: {self._config_file}')
244
- self._config = {}
245
- else:
246
- self._config = {}
247
-
248
- # Disable spurious warnings
249
- warnings.simplefilter('ignore', category=TqdmWarning)
250
- if 'hide_warnings' in self._config and self._config['hide_warnings']:
251
- # Disable more warnings
252
- warnings.simplefilter('ignore', category=UserWarning)
253
-
254
240
  if self._home.exists() and not self._home.is_dir():
255
241
  raise RuntimeError(f'{self._home} is not a directory')
256
242
 
@@ -274,6 +260,22 @@ class Env:
274
260
  if not self._tmp_dir.exists():
275
261
  self._tmp_dir.mkdir()
276
262
 
263
+ # Read in the config
264
+ self._config = Config.from_file(self._config_file)
265
+ self._file_cache_size_g = self._config.get_float_value('file_cache_size_g')
266
+ if self._file_cache_size_g is None:
267
+ raise excs.Error(
268
+ 'pixeltable/file_cache_size_g is missing from configuration\n'
269
+ f'(either add a `file_cache_size_g` entry to the `pixeltable` section of {self._config_file},\n'
270
+ 'or set the PIXELTABLE_FILE_CACHE_SIZE_G environment variable)'
271
+ )
272
+
273
+ # Disable spurious warnings
274
+ warnings.simplefilter('ignore', category=TqdmWarning)
275
+ if self._config.get_bool_value('hide_warnings'):
276
+ # Disable more warnings
277
+ warnings.simplefilter('ignore', category=UserWarning)
278
+
277
279
  # configure _logger to log to a file
278
280
  self._logfilename = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + '.log'
279
281
  fh = logging.FileHandler(self._log_dir / self._logfilename, mode='w')
@@ -313,7 +315,7 @@ class Env:
313
315
  self._db_server = pixeltable_pgserver.get_server(self._pgdata_dir, cleanup_mode=None)
314
316
  self._db_url = self._db_server.get_uri(database=self._db_name, driver='psycopg')
315
317
 
316
- tz_name = os.environ.get('PXT_TIME_ZONE', self._config.get('pxt_time_zone', None))
318
+ tz_name = self.config.get_string_value('time_zone')
317
319
  if tz_name is not None:
318
320
  # Validate tzname
319
321
  if not isinstance(tz_name, str):
@@ -340,7 +342,7 @@ class Env:
340
342
 
341
343
  if create_db:
342
344
  from pixeltable.metadata import schema
343
- schema.Base.metadata.create_all(self._sa_engine)
345
+ schema.base_metadata.create_all(self._sa_engine)
344
346
  metadata.create_system_info(self._sa_engine)
345
347
 
346
348
  print(f'Connected to Pixeltable database at: {self.db_url}')
@@ -440,21 +442,18 @@ class Env:
440
442
  if cl.client_obj is not None:
441
443
  return cl.client_obj # Already initialized
442
444
 
443
- # Construct a client. For each client parameter, first check if the parameter is in the environment;
444
- # if not, look in Pixeltable config from `config.yaml`.
445
+ # Construct a client, retrieving each parameter from config.
445
446
 
446
447
  init_kwargs: dict[str, str] = {}
447
448
  for param in cl.param_names:
448
- environ = f'{name.upper()}_{param.upper()}'
449
- if environ in os.environ:
450
- init_kwargs[param] = os.environ[environ]
451
- elif name.lower() in self._config and param in self._config[name.lower()]:
452
- init_kwargs[param] = self._config[name.lower()][param.lower()]
453
- if param not in init_kwargs or init_kwargs[param] == '':
449
+ arg = self._config.get_string_value(param, section=name)
450
+ if arg is not None and len(arg) > 0:
451
+ init_kwargs[param] = arg
452
+ else:
454
453
  raise excs.Error(
455
454
  f'`{name}` client not initialized: parameter `{param}` is not configured.\n'
456
- f'To fix this, specify the `{environ}` environment variable, or put `{param.lower()}` in '
457
- f'the `{name.lower()}` section of $PIXELTABLE_HOME/config.yaml.'
455
+ f'To fix this, specify the `{name.upper()}_{param.upper()}` environment variable, or put `{param.lower()}` in '
456
+ f'the `{name.lower()}` section of $PIXELTABLE_HOME/config.toml.'
458
457
  )
459
458
 
460
459
  cl.client_obj = cl.init_fn(**init_kwargs)
@@ -506,7 +505,6 @@ class Env:
506
505
  self.__register_package('spacy')
507
506
  self.__register_package('tiktoken')
508
507
  self.__register_package('together')
509
- self.__register_package('toml')
510
508
  self.__register_package('torch')
511
509
  self.__register_package('torchvision')
512
510
  self.__register_package('transformers')
@@ -643,7 +641,7 @@ def register_client(name: str) -> Callable:
643
641
  Pixeltable will attempt to load the client parameters from config. For each
644
642
  config parameter:
645
643
  - If an environment variable named MY_CLIENT_API_KEY (for example) is set, use it;
646
- - Otherwise, look for 'api_key' in the 'my_client' section of config.yaml.
644
+ - Otherwise, look for 'api_key' in the 'my_client' section of config.toml.
647
645
 
648
646
  If all config parameters are found, Pixeltable calls the initialization function;
649
647
  otherwise it throws an exception.
@@ -660,6 +658,79 @@ def register_client(name: str) -> Callable:
660
658
  return decorator
661
659
 
662
660
 
661
+ class Config:
662
+ """
663
+ The (global) Pixeltable configuration, as loaded from `config.toml`. Provides methods for retrieving
664
+ configuration values, which can be set in the config file or as environment variables.
665
+ """
666
+ __config: dict[str, Any]
667
+
668
+ T = TypeVar('T')
669
+
670
+ @classmethod
671
+ def from_file(cls, path: Path) -> Config:
672
+ """
673
+ Loads configuration from the specified TOML file. If the file does not exist, it will be
674
+ created and populated with the default configuration.
675
+ """
676
+ if os.path.isfile(path):
677
+ with open(path, 'r') as stream:
678
+ try:
679
+ config_dict = toml.load(stream)
680
+ except Exception as exc:
681
+ raise excs.Error(f'Could not read config file: {str(path)}') from exc
682
+ else:
683
+ config_dict = cls.__create_default_config(path)
684
+ with open(path, 'w') as stream:
685
+ try:
686
+ toml.dump(config_dict, stream)
687
+ except Exception as exc:
688
+ raise excs.Error(f'Could not write config file: {str(path)}') from exc
689
+ logging.getLogger('pixeltable').info(f'Created default config file at: {str(path)}')
690
+ return cls(config_dict)
691
+
692
+ @classmethod
693
+ def __create_default_config(cls, config_path: Path) -> dict[str, Any]:
694
+ free_disk_space_bytes = shutil.disk_usage(config_path.parent).free
695
+ # Default cache size is 1/5 of free disk space
696
+ file_cache_size_g = free_disk_space_bytes / 5 / (1 << 30)
697
+ return {
698
+ 'pixeltable': {
699
+ 'file_cache_size_g': round(file_cache_size_g, 1),
700
+ 'hide_warnings': False,
701
+ }
702
+ }
703
+
704
+ def __init__(self, config: dict[str, Any]) -> None:
705
+ self.__config = config
706
+
707
+ def get_value(self, key: str, expected_type: type[T], section: str = 'pixeltable') -> Optional[T]:
708
+ env_var = f'{section.upper()}_{key.upper()}'
709
+ if env_var in os.environ:
710
+ value = os.environ[env_var]
711
+ elif section in self.__config and key in self.__config[section]:
712
+ value = self.__config[section][key]
713
+ else:
714
+ return None
715
+
716
+ try:
717
+ return expected_type(value) # type: ignore[call-arg]
718
+ except ValueError:
719
+ raise excs.Error(f'Invalid value for configuration parameter {section}.{key}: {value}')
720
+
721
+ def get_string_value(self, key: str, section: str = 'pixeltable') -> Optional[str]:
722
+ return self.get_value(key, str, section)
723
+
724
+ def get_int_value(self, key: str, section: str = 'pixeltable') -> Optional[int]:
725
+ return self.get_value(key, int, section)
726
+
727
+ def get_float_value(self, key: str, section: str = 'pixeltable') -> Optional[float]:
728
+ return self.get_value(key, float, section)
729
+
730
+ def get_bool_value(self, key: str, section: str = 'pixeltable') -> Optional[bool]:
731
+ return self.get_value(key, bool, section)
732
+
733
+
663
734
  _registered_clients: dict[str, ApiClient] = {}
664
735
 
665
736
 
pixeltable/exceptions.py CHANGED
@@ -1,6 +1,9 @@
1
- from typing import List, Any
2
- from types import TracebackType
3
1
  from dataclasses import dataclass
2
+ from types import TracebackType
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ if TYPE_CHECKING:
6
+ from pixeltable import exprs
4
7
 
5
8
 
6
9
  class Error(Exception):
@@ -9,11 +12,11 @@ class Error(Exception):
9
12
 
10
13
  @dataclass
11
14
  class ExprEvalError(Exception):
12
- expr: Any # exprs.Expr, but we're not importing pixeltable.exprs to avoid circular imports
15
+ expr: 'exprs.Expr'
13
16
  expr_msg: str
14
17
  exc: Exception
15
18
  exc_tb: TracebackType
16
- input_vals: List[Any]
19
+ input_vals: list[Any]
17
20
  row_num: int
18
21
 
19
22
 
@@ -8,4 +8,4 @@ from .expr_eval_node import ExprEvalNode
8
8
  from .in_memory_data_node import InMemoryDataNode
9
9
  from .media_validation_node import MediaValidationNode
10
10
  from .row_update_node import RowUpdateNode
11
- from .sql_node import SqlLookupNode, SqlScanNode
11
+ from .sql_node import SqlLookupNode, SqlScanNode, SqlAggregationNode, SqlNode
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import sys
5
- from typing import Iterable, List, Optional, Any
5
+ from typing import Iterable, Optional, Any, Iterator
6
6
 
7
7
  import pixeltable.catalog as catalog
8
8
  import pixeltable.exceptions as excs
@@ -13,17 +13,29 @@ from .exec_node import ExecNode
13
13
  _logger = logging.getLogger('pixeltable')
14
14
 
15
15
  class AggregationNode(ExecNode):
16
+ """
17
+ In-memory aggregation for UDAs.
18
+
19
+ At the moment, this returns all results in a single DataRowBatch.
20
+ """
21
+ group_by: Optional[list[exprs.Expr]]
22
+ input_exprs: list[exprs.Expr]
23
+ agg_fn_eval_ctx: exprs.RowBuilder.EvalCtx
24
+ agg_fn_calls: list[exprs.FunctionCall]
25
+ output_batch: DataRowBatch
26
+
16
27
  def __init__(
17
- self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: List[exprs.Expr],
18
- agg_fn_calls: List[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
28
+ self, tbl: catalog.TableVersion, row_builder: exprs.RowBuilder, group_by: Optional[list[exprs.Expr]],
29
+ agg_fn_calls: list[exprs.FunctionCall], input_exprs: Iterable[exprs.Expr], input: ExecNode
19
30
  ):
20
31
  super().__init__(row_builder, group_by + agg_fn_calls, input_exprs, input)
21
32
  self.input = input
22
33
  self.group_by = group_by
23
34
  self.input_exprs = list(input_exprs)
24
- self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=input_exprs)
35
+ self.agg_fn_eval_ctx = row_builder.create_eval_ctx(agg_fn_calls, exclude=self.input_exprs)
25
36
  # we need to make sure to refer to the same exprs that RowBuilder.eval() will use
26
37
  self.agg_fn_calls = self.agg_fn_eval_ctx.target_exprs
38
+ # create output_batch here, rather than in __iter__(), so we don't need to remember tbl and row_builder
27
39
  self.output_batch = DataRowBatch(tbl, row_builder, 0)
28
40
 
29
41
  def _reset_agg_state(self, row_num: int) -> None:
@@ -45,17 +57,14 @@ class AggregationNode(ExecNode):
45
57
  input_vals = [row[d.slot_idx] for d in fn_call.dependencies()]
46
58
  raise excs.ExprEvalError(fn_call, expr_msg, e, exc_tb, input_vals, row_num)
47
59
 
48
- def __next__(self) -> DataRowBatch:
49
- if self.output_batch is None:
50
- raise StopIteration
51
-
60
+ def __iter__(self) -> Iterator[DataRowBatch]:
52
61
  prev_row: Optional[exprs.DataRow] = None
53
- current_group: Optional[List[Any]] = None # the values of the group-by exprs
62
+ current_group: Optional[list[Any]] = None # the values of the group-by exprs
54
63
  num_input_rows = 0
55
64
  for row_batch in self.input:
56
65
  num_input_rows += len(row_batch)
57
66
  for row in row_batch:
58
- group = [row[e.slot_idx] for e in self.group_by]
67
+ group = [row[e.slot_idx] for e in self.group_by] if self.group_by is not None else None
59
68
  if current_group is None:
60
69
  current_group = group
61
70
  self._reset_agg_state(0)
@@ -71,9 +80,7 @@ class AggregationNode(ExecNode):
71
80
  self.row_builder.eval(prev_row, self.agg_fn_eval_ctx, profile=self.ctx.profile)
72
81
  self.output_batch.add_row(prev_row)
73
82
 
74
- result = self.output_batch
75
- result.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
76
- self.output_batch = None
77
- _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(result.rows)} rows')
78
- return result
83
+ self.output_batch.flush_imgs(None, self.stored_img_cols, self.flushed_img_slots)
84
+ _logger.debug(f'AggregateNode: consumed {num_input_rows} rows, returning {len(self.output_batch.rows)} rows')
85
+ yield self.output_batch
79
86