pixeltable 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pixeltable might be problematic.

Files changed (76)
  1. pixeltable/__init__.py +15 -33
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/catalog.py +1 -1
  4. pixeltable/catalog/column.py +28 -16
  5. pixeltable/catalog/dir.py +2 -2
  6. pixeltable/catalog/insertable_table.py +5 -55
  7. pixeltable/catalog/named_function.py +2 -2
  8. pixeltable/catalog/schema_object.py +2 -7
  9. pixeltable/catalog/table.py +298 -204
  10. pixeltable/catalog/table_version.py +104 -139
  11. pixeltable/catalog/table_version_path.py +22 -4
  12. pixeltable/catalog/view.py +20 -10
  13. pixeltable/dataframe.py +128 -25
  14. pixeltable/env.py +21 -14
  15. pixeltable/exec/exec_context.py +5 -0
  16. pixeltable/exec/exec_node.py +1 -0
  17. pixeltable/exec/in_memory_data_node.py +29 -24
  18. pixeltable/exec/sql_scan_node.py +1 -1
  19. pixeltable/exprs/column_ref.py +13 -8
  20. pixeltable/exprs/data_row.py +4 -0
  21. pixeltable/exprs/expr.py +16 -1
  22. pixeltable/exprs/function_call.py +4 -4
  23. pixeltable/exprs/row_builder.py +29 -20
  24. pixeltable/exprs/similarity_expr.py +4 -3
  25. pixeltable/ext/functions/yolox.py +2 -1
  26. pixeltable/func/__init__.py +1 -0
  27. pixeltable/func/aggregate_function.py +14 -12
  28. pixeltable/func/callable_function.py +8 -6
  29. pixeltable/func/expr_template_function.py +13 -19
  30. pixeltable/func/function.py +3 -6
  31. pixeltable/func/query_template_function.py +84 -0
  32. pixeltable/func/signature.py +68 -23
  33. pixeltable/func/udf.py +13 -10
  34. pixeltable/functions/__init__.py +6 -91
  35. pixeltable/functions/eval.py +26 -14
  36. pixeltable/functions/fireworks.py +25 -23
  37. pixeltable/functions/globals.py +62 -0
  38. pixeltable/functions/huggingface.py +20 -16
  39. pixeltable/functions/image.py +170 -1
  40. pixeltable/functions/openai.py +95 -128
  41. pixeltable/functions/string.py +10 -2
  42. pixeltable/functions/together.py +95 -84
  43. pixeltable/functions/util.py +16 -0
  44. pixeltable/functions/video.py +94 -16
  45. pixeltable/functions/whisper.py +78 -0
  46. pixeltable/globals.py +1 -1
  47. pixeltable/io/__init__.py +10 -0
  48. pixeltable/io/external_store.py +370 -0
  49. pixeltable/io/globals.py +50 -22
  50. pixeltable/{datatransfer → io}/label_studio.py +279 -166
  51. pixeltable/io/parquet.py +1 -1
  52. pixeltable/iterators/__init__.py +9 -0
  53. pixeltable/iterators/string.py +40 -0
  54. pixeltable/metadata/__init__.py +6 -8
  55. pixeltable/metadata/converters/convert_10.py +2 -4
  56. pixeltable/metadata/converters/convert_12.py +7 -2
  57. pixeltable/metadata/converters/convert_13.py +6 -8
  58. pixeltable/metadata/converters/convert_14.py +2 -4
  59. pixeltable/metadata/converters/convert_15.py +40 -25
  60. pixeltable/metadata/converters/convert_16.py +18 -0
  61. pixeltable/metadata/converters/util.py +11 -8
  62. pixeltable/metadata/schema.py +3 -6
  63. pixeltable/plan.py +8 -7
  64. pixeltable/store.py +1 -1
  65. pixeltable/tool/create_test_db_dump.py +145 -54
  66. pixeltable/tool/embed_udf.py +9 -0
  67. pixeltable/type_system.py +1 -2
  68. pixeltable/utils/code.py +34 -0
  69. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/METADATA +2 -2
  70. pixeltable-0.2.9.dist-info/RECORD +131 -0
  71. pixeltable/datatransfer/__init__.py +0 -1
  72. pixeltable/datatransfer/remote.py +0 -113
  73. pixeltable/functions/pil/image.py +0 -147
  74. pixeltable-0.2.7.dist-info/RECORD +0 -126
  75. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/LICENSE +0 -0
  76. {pixeltable-0.2.7.dist-info → pixeltable-0.2.9.dist-info}/WHEEL +0 -0
pixeltable/dataframe.py CHANGED
@@ -9,7 +9,7 @@ import logging
  import mimetypes
  import traceback
  from pathlib import Path
- from typing import List, Optional, Any, Dict, Generator, Tuple, Set
+ from typing import List, Optional, Any, Dict, Iterator, Tuple, Set

  import PIL.Image
  import cv2
@@ -22,6 +22,7 @@ import pixeltable.catalog as catalog
  import pixeltable.exceptions as excs
  import pixeltable.exprs as exprs
  import pixeltable.type_system as ts
+ import pixeltable.func as func
  from pixeltable.catalog import is_valid_identifier
  from pixeltable.env import Env
  from pixeltable.plan import Planner
@@ -344,7 +345,37 @@ class DataFrame:
  assert set(out_names) == seen_out_names
  return out_exprs, out_names

- def _exec(self) -> Generator[exprs.DataRow, None, None]:
+ def _vars(self) -> dict[str, exprs.Variable]:
+ """
+ Return a dict mapping variable name to Variable for all Variables contained in any component of the DataFrame
+ """
+ all_exprs: list[exprs.Expr] = []
+ all_exprs.extend(self._select_list_exprs)
+ if self.where_clause is not None:
+ all_exprs.append(self.where_clause)
+ if self.group_by_clause is not None:
+ all_exprs.extend(self.group_by_clause)
+ if self.order_by_clause is not None:
+ all_exprs.extend([expr for expr, _ in self.order_by_clause])
+ vars = exprs.Expr.list_subexprs(all_exprs, expr_class=exprs.Variable)
+ unique_vars: dict[str, exprs.Variable] = {}
+ for var in vars:
+ if var.name not in unique_vars:
+ unique_vars[var.name] = var
+ else:
+ if unique_vars[var.name].col_type != var.col_type:
+ raise excs.Error(f'Multiple definitions of parameter {var.name}')
+ return unique_vars
+
+ def parameters(self) -> dict[str, ColumnType]:
+ """Return a dict mapping parameter name to parameter type.
+
+ Parameters are Variables contained in any component of the DataFrame.
+ """
+ vars = self._vars()
+ return {name: var.col_type for name, var in vars.items()}
+
+ def _exec(self, conn: Optional[sql.engine.Connection] = None) -> Iterator[exprs.DataRow]:
  """Run the query and return rows as a generator.
  This function must not modify the state of the DataFrame, otherwise it breaks dataset caching.
  """
@@ -361,6 +392,7 @@

  for item in self._select_list_exprs:
  item.bind_rel_paths(None)
+
  plan = Planner.create_query_plan(
  self.tbl,
  self._select_list_exprs,
@@ -370,8 +402,8 @@
  limit=self.limit_val if self.limit_val is not None else 0,
  ) # limit_val == 0: no limit_val

- with Env.get().engine.begin() as conn:
- plan.ctx.conn = conn
+ def exec_plan(conn: sql.engine.Connection) -> Iterator[exprs.DataRow]:
+ plan.ctx.set_conn(conn)
  plan.open()
  try:
  for row_batch in plan:
@@ -379,7 +411,12 @@
  yield data_row
  finally:
  plan.close()
- return
+
+ if conn is None:
+ with Env.get().engine.begin() as conn:
+ yield from exec_plan(conn)
+ else:
+ yield from exec_plan(conn)

  def show(self, n: int = 20) -> DataFrameResultSet:
  assert n is not None
@@ -407,10 +444,54 @@
  def get_column_types(self) -> List[ColumnType]:
  return [expr.col_type for expr in self._select_list_exprs]

+ def bind(self, args: dict[str, Any]) -> DataFrame:
+ """Bind arguments to parameters and return a new DataFrame."""
+ # substitute Variables with the corresponding values according to 'args', converted to Literals
+ select_list_exprs = copy.deepcopy(self._select_list_exprs)
+ where_clause = copy.deepcopy(self.where_clause)
+ group_by_clause = copy.deepcopy(self.group_by_clause)
+ order_by_exprs = [copy.deepcopy(order_by_expr) for order_by_expr, _ in self.order_by_clause] \
+ if self.order_by_clause is not None else None
+
+ var_exprs: dict[exprs.Expr, exprs.Expr] = {}
+ vars = self._vars()
+ for arg_name, arg_val in args.items():
+ if arg_name not in vars:
+ # ignore unused variables
+ continue
+ var_expr = vars[arg_name]
+ arg_expr = exprs.Expr.from_object(arg_val)
+ if arg_expr is None:
+ raise excs.Error(f'Cannot convert argument {arg_val} to a Pixeltable expression')
+ var_exprs[var_expr] = arg_expr
+
+ exprs.Expr.list_substitute(select_list_exprs, var_exprs)
+ if where_clause is not None:
+ where_clause.substitute(var_exprs)
+ if group_by_clause is not None:
+ exprs.Expr.list_substitute(group_by_clause, var_exprs)
+ if order_by_exprs is not None:
+ exprs.Expr.list_substitute(order_by_exprs, var_exprs)
+
+ select_list = list(zip(select_list_exprs, self._column_names))
+ order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None
+ if order_by_exprs is not None:
+ order_by_clause = [
+ (expr, asc) for expr, asc in zip(order_by_exprs, [asc for _, asc in self.order_by_clause])
+ ]
+
+ return DataFrame(
+ self.tbl, select_list=select_list, where_clause=where_clause,
+ group_by_clause=group_by_clause, grouping_tbl=self.grouping_tbl,
+ order_by_clause=order_by_clause, limit=self.limit_val)
+
  def collect(self) -> DataFrameResultSet:
+ return self._collect()
+
+ def _collect(self, conn: Optional[sql.engine.Connection] = None) -> DataFrameResultSet:
  try:
  result_rows = []
- for data_row in self._exec():
+ for data_row in self._exec(conn):
  result_row = [data_row[e.slot_idx] for e in self._select_list_exprs]
  result_rows.append(result_row)
  except excs.ExprEvalError as e:
@@ -579,10 +660,10 @@
  if len(grouping_items) > 1:
  raise excs.Error(f'group_by(): only one table can be specified')
  # we need to make sure that the grouping table is a base of self.tbl
- base = self.tbl.find_tbl_version(item.tbl_version_path.tbl_id())
+ base = self.tbl.find_tbl_version(item._tbl_version_path.tbl_id())
  if base is None or base.id == self.tbl.tbl_id():
  raise excs.Error(f'group_by(): {item.name} is not a base table of {self.tbl.tbl_name()}')
- grouping_tbl = item.tbl_version_path.tbl_version
+ grouping_tbl = item._tbl_version_path.tbl_version
  break
  if not isinstance(item, exprs.Expr):
  raise excs.Error(f'Invalid expression in group_by(): {item}')
@@ -615,6 +696,7 @@
  )

  def limit(self, n: int) -> DataFrame:
+ # TODO: allow n to be a Variable that can be substituted in bind()
  assert n is not None and isinstance(n, int)
  return DataFrame(
  self.tbl,
@@ -643,7 +725,7 @@
  return self.select(*index)
  raise TypeError(f'Invalid index type: {type(index)}')

- def _as_dict(self) -> Dict[str, Any]:
+ def as_dict(self) -> Dict[str, Any]:
  """
  Returns:
  Dictionary representing this dataframe.
@@ -651,22 +733,46 @@
  tbl_versions = self.tbl.get_tbl_versions()
  d = {
  '_classname': 'DataFrame',
- 'tbl_ids': [str(t.id) for t in tbl_versions],
- 'tbl_versions': [t.version for t in tbl_versions],
- 'select_list': [(e.as_dict(), name) for (e, name) in self.select_list]
- if self.select_list is not None
- else None,
+ 'tbl': self.tbl.as_dict(),
+ 'select_list':
+ [(e.as_dict(), name) for (e, name) in self.select_list] if self.select_list is not None else None,
  'where_clause': self.where_clause.as_dict() if self.where_clause is not None else None,
- 'group_by_clause': [e.as_dict() for e in self.group_by_clause]
- if self.group_by_clause is not None
- else None,
- 'order_by_clause': [(e.as_dict(), asc) for (e, asc) in self.order_by_clause]
- if self.order_by_clause is not None
- else None,
+ 'group_by_clause':
+ [e.as_dict() for e in self.group_by_clause] if self.group_by_clause is not None else None,
+ 'grouping_tbl': self.grouping_tbl.as_dict() if self.grouping_tbl is not None else None,
+ 'order_by_clause':
+ [(e.as_dict(), asc) for (e,asc) in self.order_by_clause] if self.order_by_clause is not None else None,
  'limit_val': self.limit_val,
  }
  return d

+ @classmethod
+ def from_dict(cls, d: Dict[str, Any]) -> 'DataFrame':
+ tbl = catalog.TableVersionPath.from_dict(d['tbl'])
+ select_list = [(exprs.Expr.from_dict(e), name) for e, name in d['select_list']] \
+ if d['select_list'] is not None else None
+ where_clause = exprs.Predicate.from_dict(d['where_clause']) \
+ if d['where_clause'] is not None else None
+ group_by_clause = [exprs.Expr.from_dict(e) for e in d['group_by_clause']] \
+ if d['group_by_clause'] is not None else None
+ grouping_tbl = catalog.TableVersion.from_dict(d['grouping_tbl']) \
+ if d['grouping_tbl'] is not None else None
+ order_by_clause = [(exprs.Expr.from_dict(e), asc) for e, asc in d['order_by_clause']] \
+ if d['order_by_clause'] is not None else None
+ limit_val = d['limit_val']
+ return DataFrame(
+ tbl, select_list=select_list, where_clause=where_clause, group_by_clause=group_by_clause,
+ grouping_tbl=grouping_tbl, order_by_clause=order_by_clause, limit=limit_val)
+
+ def _hash_result_set(self) -> str:
+ """Return a hash that changes when the result set changes."""
+ d = self.as_dict()
+ # add list of referenced table versions (the actual versions, not the effective ones) in order to force cache
+ # invalidation when any of the referenced tables changes
+ d['tbl_versions'] = [tbl_version.version for tbl_version in self.tbl.get_tbl_versions()]
+ summary_string = json.dumps(d)
+ return hashlib.sha256(summary_string.encode()).hexdigest()
+
  def to_coco_dataset(self) -> Path:
  """Convert the dataframe to a COCO dataset.
  This dataframe must return a single json-typed output column in the following format:
@@ -686,9 +792,7 @@
  """
  from pixeltable.utils.coco import write_coco_dataset

- summary_string = json.dumps(self._as_dict())
- cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
-
+ cache_key = self._hash_result_set()
  dest_path = Env.get().dataset_cache_dir / f'coco_{cache_key}'
  if dest_path.exists():
  assert dest_path.is_dir()
@@ -737,8 +841,7 @@
  from pixeltable.io.parquet import save_parquet # pylint: disable=import-outside-toplevel
  from pixeltable.utils.pytorch import PixeltablePytorchDataset # pylint: disable=import-outside-toplevel

- summary_string = json.dumps(self._as_dict())
- cache_key = hashlib.sha256(summary_string.encode()).hexdigest()
+ cache_key = self._hash_result_set()

  dest_path = (Env.get().dataset_cache_dir / f'df_{cache_key}').with_suffix('.parquet') # pylint: disable = protected-access
  if dest_path.exists(): # fast path: use cache
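
The new parameters()/bind() pair turns a DataFrame with free Variables into a reusable, parameterized query, presumably in support of the new func/query_template_function.py module listed above. A minimal usage sketch (not part of this diff; only parameters() and bind() are confirmed above, while the exprs.Variable constructor and the 0.2.x create_table schema syntax are assumptions):

    import pixeltable as pxt
    import pixeltable.exprs as exprs
    import pixeltable.type_system as ts

    t = pxt.create_table('films', {'title': ts.StringType(), 'year': ts.IntType()})

    # a query with a free parameter 'year' (hypothetical Variable construction)
    year_param = exprs.Variable('year', ts.IntType())
    df = t.where(t.year >= year_param).select(t.title)

    print(df.parameters())           # e.g. {'year': IntType}
    bound = df.bind({'year': 2000})  # Variables are replaced by Literal exprs
    rows = bound.collect()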
pixeltable/env.py CHANGED
@@ -14,7 +14,7 @@ import uuid
  import warnings
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Callable, Optional, Dict, Any, List
+ from typing import Callable, Optional, Dict, Any, List, TYPE_CHECKING

  import pgserver
  import sqlalchemy as sql
@@ -25,6 +25,9 @@ import pixeltable.exceptions as excs
  from pixeltable import metadata
  from pixeltable.utils.http_server import make_server

+ if TYPE_CHECKING:
+ import spacy
+

  class Env:
  """
@@ -63,12 +66,10 @@
  # info about installed packages that are utilized by some parts of the code;
  # package name -> version; version == []: package is installed, but we haven't determined the version yet
  self._installed_packages: Dict[str, Optional[List[int]]] = {}
- self._spacy_nlp: Optional[Any] = None # spacy.Language
+ self._spacy_nlp: Optional[spacy.Language] = None
  self._httpd: Optional[http.server.HTTPServer] = None
  self._http_address: Optional[str] = None

- self._registered_clients: dict[str, ApiClient] = {}
-
  # logging-related state
  self._logger = logging.getLogger('pixeltable')
  self._logger.setLevel(logging.DEBUG) # allow everything to pass, we filter in _log_filter()
@@ -177,8 +178,6 @@
  if self._initialized:
  return

- # Disable spurious warnings
- warnings.simplefilter('ignore', category=TqdmWarning)
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'

  self._initialized = True
@@ -203,6 +202,12 @@
  else:
  self._config = {}

+ # Disable spurious warnings
+ warnings.simplefilter('ignore', category=TqdmWarning)
+ if 'hide_warnings' in self._config and self._config['hide_warnings']:
+ # Disable more warnings
+ warnings.simplefilter('ignore', category=UserWarning)
+

  if self._home.exists() and not self._home.is_dir():
  raise RuntimeError(f'{self._home} is not a directory')
@@ -354,11 +359,6 @@
  def _upgrade_metadata(self) -> None:
  metadata.upgrade_md(self._sa_engine)

- def _register_client(self, name: str, init_fn: Callable) -> None:
- sig = inspect.signature(init_fn)
- param_names = list(sig.parameters.keys())
- self._registered_clients[name] = ApiClient(init_fn=init_fn, param_names=param_names)
-
  def get_client(self, name: str) -> Any:
  """
  Gets the client with the specified name, initializing it if necessary.
@@ -366,7 +366,7 @@
  Args:
  - name: The name of the client
  """
- cl = self._registered_clients[name]
+ cl = _registered_clients[name]
  if cl.client_obj is not None:
  return cl.client_obj # Already initialized

@@ -430,6 +430,7 @@
  check('torchvision')
  check('transformers')
  check('sentence_transformers')
+ check('whisper')
  check('yolox')
  check('whisperx')
  check('boto3')
@@ -507,7 +508,7 @@
  return self._sa_engine

  @property
- def spacy_nlp(self) -> Any:
+ def spacy_nlp(self) -> spacy.Language:
  assert self._spacy_nlp is not None
  return self._spacy_nlp

@@ -537,11 +538,17 @@ def register_client(name: str) -> Callable:
  - name (str): The name of the API client (e.g., 'openai' or 'label-studio').
  """
  def decorator(fn: Callable) -> None:
- Env.get()._register_client(name, fn)
+ global _registered_clients
+ sig = inspect.signature(fn)
+ param_names = list(sig.parameters.keys())
+ _registered_clients[name] = ApiClient(init_fn=fn, param_names=param_names)

  return decorator


+ _registered_clients: dict[str, ApiClient] = {}
+
+
  @dataclass
  class ApiClient:
  init_fn: Callable
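
The client registry now lives at module level (_registered_clients) instead of on the Env instance, so @register_client no longer has to touch Env at import time. A minimal sketch of the registration pattern (not part of this diff; MyServiceClient and the 'my_service' name are placeholders):

    from pixeltable.env import Env, register_client

    @register_client('my_service')
    def _(api_key: str):
        # the decorator only records the init function and returns None,
        # hence the throwaway name '_'
        return MyServiceClient(api_key=api_key)  # placeholder client class

    client = Env.get().get_client('my_service')  # lazily initialized on first use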
pixeltable/exec/exec_context.py CHANGED
@@ -13,6 +13,7 @@ class ExecContext:
  ):
  self.show_pbar = show_pbar
  self.batch_size = batch_size
+ self.row_builder = row_builder
  self.profile = exprs.ExecProfile(row_builder)
  # num_rows is used to compute the total number of computed cells used for the progress bar
  self.num_rows: Optional[int] = None
@@ -20,3 +21,7 @@
  self.pk_clause = pk_clause
  self.num_computed_exprs = num_computed_exprs
  self.ignore_errors = ignore_errors
+
+ def set_conn(self, conn: sql.engine.Connection) -> None:
+ self.conn = conn
+ self.row_builder.set_conn(conn)
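
Together with the DataFrame._exec()/_collect() changes above, set_conn() is what lets an enclosing transaction hand its connection to the query plan instead of opening a new one. An illustrative sketch of that flow (internal APIs, not a public interface; df is a DataFrame as in the earlier sketch):

    from pixeltable.env import Env

    with Env.get().engine.begin() as conn:
        # conn is forwarded by _exec() to plan.ctx.set_conn(conn)
        result = df._collect(conn)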
pixeltable/exec/exec_node.py CHANGED
@@ -11,6 +11,7 @@ class ExecNode(abc.ABC):
  def __init__(
  self, row_builder: exprs.RowBuilder, output_exprs: Iterable[exprs.Expr],
  input_exprs: Iterable[exprs.Expr], input: Optional[ExecNode] = None):
+ self.output_exprs = output_exprs
  self.row_builder = row_builder
  self.input = input
  # we flush all image slots that aren't part of our output but are needed to create our output
pixeltable/exec/in_memory_data_node.py CHANGED
@@ -1,25 +1,29 @@
- from typing import List, Dict, Any, Optional
- import urllib
  import logging
- import os
+ from typing import List, Dict, Any, Optional

- from .data_row_batch import DataRowBatch
- from .exec_node import ExecNode
  import pixeltable.catalog as catalog
  import pixeltable.exprs as exprs
- import pixeltable.env as env
  from pixeltable.utils.media_store import MediaStore
-
+ from .data_row_batch import DataRowBatch
+ from .exec_node import ExecNode

  _logger = logging.getLogger('pixeltable')

  class InMemoryDataNode(ExecNode):
- """Outputs in-memory data as a row batch of a particular table"""
+ """
+ Outputs in-memory data as a DataRowBatch of a particular table.
+
+ Populates slots of all non-computed columns (ie, output ColumnRefs)
+ - with the values provided in the input rows
+ - if an input row doesn't provide a value, sets the slot to the column default
+ """
  def __init__(
  self, tbl: catalog.TableVersionPath, rows: List[Dict[str, Any]],
  row_builder: exprs.RowBuilder, start_row_id: int,
  ):
- super().__init__(row_builder, [], [], None)
+ # we materialize all output slots
+ output_exprs = [e for e in row_builder.get_output_exprs() if isinstance(e, exprs.ColumnRef)]
+ super().__init__(row_builder, output_exprs, [], None)
  assert tbl.is_insertable()
  self.tbl = tbl
  self.input_rows = rows
@@ -29,21 +33,22 @@ class InMemoryDataNode(ExecNode):

  def _open(self) -> None:
  """Create row batch and populate with self.input_rows"""
- column_info = {info.col.id: info for info in self.row_builder.output_slot_idxs()}
- # exclude system columns
- user_column_info = {info.col.name: info for _, info in column_info.items() if info.col.name is not None}
- # stored columns that are not computed
- inserted_col_ids = set([
- info.col.id for info in self.row_builder.output_slot_idxs()
- if info.col.is_stored and not info.col.is_computed
- ])
+ user_cols_by_name = {
+ col_ref.col.name: exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
+ for col_ref in self.output_exprs if col_ref.col.name is not None
+ }
+ output_cols_by_idx = {
+ col_ref.slot_idx: exprs.ColumnSlotIdx(col_ref.col, col_ref.slot_idx)
+ for col_ref in self.output_exprs
+ }
+ output_slot_idxs = {e.slot_idx for e in self.output_exprs}

  self.output_rows = DataRowBatch(self.tbl, self.row_builder, len(self.input_rows))
  for row_idx, input_row in enumerate(self.input_rows):
  # populate the output row with the values provided in the input row
- input_col_ids: List[int] = []
+ input_slot_idxs: set[int] = set()
  for col_name, val in input_row.items():
- col_info = user_column_info.get(col_name)
+ col_info = user_cols_by_name.get(col_name)
  assert col_info is not None

  if col_info.col.col_type.is_image_type() and isinstance(val, bytes):
@@ -52,12 +57,12 @@
  open(path, 'wb').write(val)
  val = path
  self.output_rows[row_idx][col_info.slot_idx] = val
- input_col_ids.append(col_info.col.id)
+ input_slot_idxs.add(col_info.slot_idx)

- # set the remaining stored non-computed columns to null
- null_col_ids = inserted_col_ids - set(input_col_ids)
- for col_id in null_col_ids:
- col_info = column_info.get(col_id)
+ # set the remaining output slots to their default values (presently None)
+ missing_slot_idxs = output_slot_idxs - input_slot_idxs
+ for slot_idx in missing_slot_idxs:
+ col_info = output_cols_by_idx.get(slot_idx)
  assert col_info is not None
  self.output_rows[row_idx][col_info.slot_idx] = None

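From the caller's perspective, the slot-based bookkeeping above is what fills omitted columns with their default (currently None) on insert. A minimal sketch (not part of this diff; the 0.2.x schema and insert syntax are assumptions):

    import pixeltable as pxt
    import pixeltable.type_system as ts

    t = pxt.create_table('people', {'name': ts.StringType(), 'age': ts.IntType(nullable=True)})

    # the second row omits 'age'; InMemoryDataNode sets the missing output slot to None
    t.insert([{'name': 'alice', 'age': 34}, {'name': 'bob'}])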
pixeltable/exec/sql_scan_node.py CHANGED
@@ -37,7 +37,6 @@ class SqlScanNode(ExecNode):
  order_by_items = []
  if exact_version_only is None:
  exact_version_only = []
- super().__init__(row_builder, [], [], None)
  self.tbl = tbl
  target = tbl.tbl_version # the stored table we're scanning
  self.sql_exprs = exprs.ExprSet(select_list)
@@ -45,6 +44,7 @@
  for iter_arg in row_builder.unstored_iter_args.values():
  sql_subexprs = iter_arg.subexprs(filter=lambda e: e.sql_expr() is not None, traverse_matches=False)
  [self.sql_exprs.append(e) for e in sql_subexprs]
+ super().__init__(row_builder, self.sql_exprs, [], None) # we materialize self.sql_exprs
  self.filter = filter
  self.filter_eval_ctx = \
  row_builder.create_eval_ctx([filter], exclude=select_list) if filter is not None else None
pixeltable/exprs/column_ref.py CHANGED
@@ -1,5 +1,5 @@
  from __future__ import annotations
- from typing import Optional, List, Any, Dict, Tuple
+ from typing import Optional, Any, Tuple
  from uuid import UUID

  import sqlalchemy as sql
@@ -38,7 +38,7 @@ class ColumnRef(Expr):
  self.iter_arg_ctx = iter_arg_ctx
  assert len(self.iter_arg_ctx.target_slot_idxs) == 1 # a single inline dict

- def _id_attrs(self) -> List[Tuple[str, Any]]:
+ def _id_attrs(self) -> list[Tuple[str, Any]]:
  return super()._id_attrs() + [('tbl_id', self.col.tbl.id), ('col_id', self.col.id)]

  def __getattr__(self, name: str) -> Expr:
@@ -64,8 +64,8 @@
  return super().__getattr__(name)

  def similarity(self, other: Any) -> Expr:
- if isinstance(other, Expr):
- raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
+ # if isinstance(other, Expr):
+ # raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
  item = Expr.from_object(other)
  if item is None or not(item.col_type.is_string_type() or item.col_type.is_image_type()):
  raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not a {type(other)}')
@@ -86,7 +86,8 @@

  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
  if not self.is_unstored_iter_col:
- assert data_row.has_val[self.slot_idx]
+ # supply default
+ data_row[self.slot_idx] = None
  return

  # if this is a new base row, we need to instantiate a new iterator
@@ -99,16 +100,20 @@
  res = next(self.iterator)
  data_row[self.slot_idx] = res[self.col.name]

- def _as_dict(self) -> Dict:
+ def _as_dict(self) -> dict:
  tbl = self.col.tbl
  version = tbl.version if tbl.is_snapshot else None
  return {'tbl_id': str(tbl.id), 'tbl_version': version, 'col_id': self.col.id}

  @classmethod
- def _from_dict(cls, d: Dict, components: List[Expr]) -> Expr:
+ def get_column(cls, d: dict) -> catalog.Column:
  tbl_id, version, col_id = UUID(d['tbl_id']), d['tbl_version'], d['col_id']
  tbl_version = catalog.Catalog.get().tbl_versions[(tbl_id, version)]
  # don't use tbl_version.cols_by_id here, this might be a snapshot reference to a column that was then dropped
  col = next(col for col in tbl_version.cols if col.id == col_id)
- return cls(col)
+ return col

+ @classmethod
+ def _from_dict(cls, d: dict, _: list[Expr]) -> Expr:
+ col = cls.get_column(d)
+ return cls(col)
pixeltable/exprs/data_row.py CHANGED
@@ -133,6 +133,10 @@ class DataRow:
  np.save(buffer, np_array)
  return buffer.getvalue()

+ # for JSON columns, we need to store None as an explicit NULL, otherwise it stores a json 'null'
+ if self.vals[index] is None and sa_col_type is not None and isinstance(sa_col_type, sql.JSON):
+ return sql.sql.null()
+
  return self.vals[index]

  def __setitem__(self, idx: object, val: Any) -> None:
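
For context on the hunk above: with SQLAlchemy's JSON column type (none_as_null=False, the default), a Python None is persisted as the JSON value 'null', whereas sqlalchemy.null() produces a SQL NULL. A standalone illustration, independent of pixeltable:

    import sqlalchemy as sql

    metadata = sql.MetaData()
    t = sql.Table('t', metadata, sql.Column('payload', sql.JSON))

    stmt_json_null = sql.insert(t).values(payload=None)        # stored as the JSON value 'null'
    stmt_sql_null = sql.insert(t).values(payload=sql.null())   # stored as SQL NULL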
pixeltable/exprs/expr.py CHANGED
@@ -158,7 +158,9 @@ class Expr(abc.ABC):
  return result

  @classmethod
- def copy_list(cls, expr_list: List[Expr]) -> List[Expr]:
+ def copy_list(cls, expr_list: Optional[List[Expr]]) -> Optional[List[Expr]]:
+ if expr_list is None:
+ return None
  return [e.copy() for e in expr_list]

  def __deepcopy__(self, memo=None) -> Expr:
@@ -297,6 +299,19 @@
  ids.update(e.tbl_ids())
  return ids

+ @classmethod
+ def get_refd_columns(cls, expr_dict: dict[str, Any]) -> list[catalog.Column]:
+ """Return Columns referenced by expr_dict."""
+ result: list[catalog.Column] = []
+ assert '_classname' in expr_dict
+ from .column_ref import ColumnRef
+ if expr_dict['_classname'] == 'ColumnRef':
+ result.append(ColumnRef.get_column(expr_dict))
+ if 'components' in expr_dict:
+ for component_dict in expr_dict['components']:
+ result.extend(cls.get_refd_columns(component_dict))
+ return result
+
  @classmethod
  def from_object(cls, o: object) -> Optional[Expr]:
  """
pixeltable/exprs/function_call.py CHANGED
@@ -54,7 +54,7 @@ class FunctionCall(Expr):
  self.arg_types: List[ts.ColumnType] = []
  self.kwarg_types: Dict[str, ts.ColumnType] = {}
  # the prefix of parameters that are bound can be passed by position
- for param in fn.py_signature.parameters.values():
+ for param in fn.signature.py_signature.parameters.values():
  if param.name not in bound_args or param.kind == inspect.Parameter.KEYWORD_ONLY:
  break
  arg = bound_args[param.name]
@@ -67,7 +67,7 @@
  self.arg_types.append(signature.parameters[param.name].col_type)

  # the remaining args are passed as keywords
- kw_param_names = set(bound_args.keys()) - set(list(fn.py_signature.parameters.keys())[:len(self.args)])
+ kw_param_names = set(bound_args.keys()) - set(list(fn.signature.py_signature.parameters.keys())[:len(self.args)])
  for param_name in kw_param_names:
  arg = bound_args[param_name]
  if isinstance(arg, Expr):
@@ -75,7 +75,7 @@
  self.components.append(arg.copy())
  else:
  self.kwargs[param_name] = (None, arg)
- if fn.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
+ if fn.signature.py_signature.parameters[param_name].kind != inspect.Parameter.VAR_KEYWORD:
  self.kwarg_types[param_name] = signature.parameters[param_name].col_type

  # window function state:
@@ -117,7 +117,7 @@
  self.id = self._create_id()

  def _create_rowid_refs(self, tbl: catalog.Table) -> List[Expr]:
- target = tbl.tbl_version_path.tbl_version
+ target = tbl._tbl_version_path.tbl_version
  return [RowidRef(target, i) for i in range(target.num_rowid_columns())]

  @classmethod