pixeltable 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/table.py +247 -83
  3. pixeltable/catalog/view.py +5 -2
  4. pixeltable/dataframe.py +240 -92
  5. pixeltable/exec/__init__.py +1 -1
  6. pixeltable/exec/exec_node.py +6 -7
  7. pixeltable/exec/sql_node.py +91 -44
  8. pixeltable/exprs/__init__.py +1 -0
  9. pixeltable/exprs/arithmetic_expr.py +1 -1
  10. pixeltable/exprs/array_slice.py +1 -1
  11. pixeltable/exprs/column_property_ref.py +1 -1
  12. pixeltable/exprs/column_ref.py +29 -2
  13. pixeltable/exprs/comparison.py +1 -1
  14. pixeltable/exprs/compound_predicate.py +1 -1
  15. pixeltable/exprs/expr.py +11 -5
  16. pixeltable/exprs/expr_set.py +8 -0
  17. pixeltable/exprs/function_call.py +14 -11
  18. pixeltable/exprs/in_predicate.py +1 -1
  19. pixeltable/exprs/inline_expr.py +3 -3
  20. pixeltable/exprs/is_null.py +1 -1
  21. pixeltable/exprs/json_mapper.py +1 -1
  22. pixeltable/exprs/json_path.py +1 -1
  23. pixeltable/exprs/method_ref.py +1 -1
  24. pixeltable/exprs/rowid_ref.py +1 -1
  25. pixeltable/exprs/similarity_expr.py +4 -1
  26. pixeltable/exprs/sql_element_cache.py +4 -0
  27. pixeltable/exprs/type_cast.py +2 -2
  28. pixeltable/exprs/variable.py +3 -0
  29. pixeltable/func/expr_template_function.py +3 -0
  30. pixeltable/func/function.py +37 -1
  31. pixeltable/func/signature.py +1 -0
  32. pixeltable/functions/mistralai.py +0 -2
  33. pixeltable/functions/ollama.py +4 -4
  34. pixeltable/globals.py +32 -18
  35. pixeltable/index/embedding_index.py +6 -1
  36. pixeltable/io/__init__.py +1 -1
  37. pixeltable/io/parquet.py +39 -19
  38. pixeltable/iterators/__init__.py +1 -0
  39. pixeltable/iterators/image.py +100 -0
  40. pixeltable/iterators/video.py +7 -8
  41. pixeltable/metadata/__init__.py +1 -1
  42. pixeltable/metadata/converters/convert_22.py +17 -0
  43. pixeltable/metadata/notes.py +1 -0
  44. pixeltable/plan.py +129 -51
  45. pixeltable/store.py +1 -1
  46. pixeltable/tool/create_test_db_dump.py +4 -1
  47. pixeltable/type_system.py +1 -1
  48. pixeltable/utils/arrow.py +8 -3
  49. pixeltable/utils/description_helper.py +89 -0
  50. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/METADATA +28 -12
  51. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/RECORD +54 -51
  52. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/WHEEL +1 -1
  53. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/LICENSE +0 -0
  54. {pixeltable-0.2.24.dist-info → pixeltable-0.2.26.dist-info}/entry_points.txt +0 -0
pixeltable/plan.py CHANGED
@@ -1,4 +1,8 @@
1
- from typing import Any, Iterable, Optional, Sequence
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import enum
5
+ from typing import Any, Iterable, Optional, Sequence, Literal
2
6
  from uuid import UUID
3
7
 
4
8
  import sqlalchemy as sql
@@ -38,13 +42,46 @@ def _get_combined_ordering(
38
42
  return result
39
43
 
40
44
 
45
class JoinType(enum.Enum):
    """The kinds of join a FromClause can express (RIGHT joins are not implemented yet)."""

    INNER = 0
    LEFT = 1
    # TODO: implement
    # RIGHT = 2
    FULL_OUTER = 3
    CROSS = 4

    # String spellings of the join types, for use in type annotations.
    # NOTE: a typing alias is not a descriptor, so Enum registers this assignment as a *member*
    # (JoinType.LiteralType) at runtime; validated() therefore only considers the int-valued members.
    LiteralType = Literal['inner', 'left', 'full_outer', 'cross']

    @classmethod
    def validated(cls, name: str, error_prefix: str) -> JoinType:
        """Return the join type named by `name`, matched case-insensitively.

        Args:
            name: the join type's name, e.g. 'inner', 'left', 'full_outer' or 'cross' (any case).
            error_prefix: text placed at the start of the error message, e.g. the offending argument.

        Returns:
            The matching JoinType member.

        Raises:
            excs.Error: if `name` does not name a join type.
        """
        # Exclude non-join members such as LiteralType (see above) from both lookup and the error message.
        join_types = {
            member_name: member
            for member_name, member in cls.__members__.items()
            if isinstance(member.value, int)
        }
        try:
            return join_types[name.upper()]
        except KeyError:
            val_strs = ', '.join(f'{s.lower()!r}' for s in join_types)
            raise excs.Error(f'{error_prefix} must be one of: [{val_strs}]') from None
62
+
63
+
64
@dataclasses.dataclass
class JoinClause:
    """A single 'JOIN ... ON (...)' clause of a SELECT statement; the joined table itself is not part of it."""

    # what kind of join this is
    join_type: JoinType
    # the ON (...) condition; None when join_type is CROSS
    join_predicate: Optional[exprs.Expr]
69
+
70
+
71
@dataclasses.dataclass
class FromClause:
    """The FROM clause of a SELECT statement, i.e. 'FROM <tbl> JOIN ... ON (...) JOIN ...'."""

    # the tables referenced by the clause
    tbls: list[catalog.TableVersionPath]
    # the JOIN clauses; each instance gets its own (initially empty) list
    join_clauses: list[JoinClause] = dataclasses.field(default_factory=list)
76
+
77
+
41
78
  class Analyzer:
42
79
  """
43
80
  Performs semantic analysis of a query and stores the analysis state.
44
81
  """
45
82
 
46
- tbl: catalog.TableVersionPath
47
- all_exprs: list[exprs.Expr]
83
+ from_clause: FromClause
84
+ all_exprs: list[exprs.Expr] # union of all exprs, aside from sql_where_clause
48
85
  select_list: list[exprs.Expr]
49
86
  group_by_clause: Optional[list[exprs.Expr]] # None for non-aggregate queries; [] for agg query w/o grouping
50
87
  grouping_exprs: list[exprs.Expr] # [] for non-aggregate queries or agg query w/o grouping
@@ -63,12 +100,12 @@ class Analyzer:
63
100
  agg_order_by: list[exprs.Expr]
64
101
 
65
102
  def __init__(
66
- self, tbl: catalog.TableVersionPath, select_list: Sequence[exprs.Expr],
103
+ self, from_clause: FromClause, select_list: Sequence[exprs.Expr],
67
104
  where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[list[exprs.Expr]] = None,
68
105
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None):
69
106
  if order_by_clause is None:
70
107
  order_by_clause = []
71
- self.tbl = tbl
108
+ self.from_clause = from_clause
72
109
  self.sql_elements = exprs.SqlElementCache()
73
110
 
74
111
  # remove references to unstored computed cols
@@ -88,6 +125,9 @@ class Analyzer:
88
125
 
89
126
  # all exprs that are evaluated in Python; not executable
90
127
  self.all_exprs = self.select_list.copy()
128
+ for join_clause in from_clause.join_clauses:
129
+ if join_clause.join_predicate is not None:
130
+ self.all_exprs.append(join_clause.join_predicate)
91
131
  if self.group_by_clause is not None:
92
132
  self.all_exprs.extend(self.group_by_clause)
93
133
  self.all_exprs.extend(e for e, _ in self.order_by_clause)
@@ -224,7 +264,7 @@ class Planner:
224
264
  """Creates a plan for TableVersion.insert()"""
225
265
  assert not tbl.is_view()
226
266
  # stored_cols: all cols we need to store, incl computed cols (and indices)
227
- stored_cols = [c for c in tbl.cols if c.is_stored]
267
+ stored_cols = [c for c in tbl.cols_by_id.values() if c.is_stored]
228
268
  assert len(stored_cols) > 0 # there needs to be something to store
229
269
  row_builder = exprs.RowBuilder([], stored_cols, [])
230
270
 
@@ -316,7 +356,8 @@ class Planner:
316
356
  recomputed_cols = {c for c in recomputed_cols if c.is_stored}
317
357
  recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
318
358
  copied_cols = [
319
- col for col in target.cols if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
359
+ col for col in target.cols_by_id.values()
360
+ if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
320
361
  ]
321
362
  select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
322
363
  select_list.extend(update_targets.values())
@@ -329,7 +370,7 @@ class Planner:
329
370
  select_list.extend(recomputed_exprs)
330
371
 
331
372
  # we need to retrieve the PK columns of the existing rows
332
- plan = cls.create_query_plan(tbl, select_list, where_clause=where_clause, ignore_errors=True)
373
+ plan = cls.create_query_plan(FromClause(tbls=[tbl]), select_list, where_clause=where_clause, ignore_errors=True)
333
374
  all_base_cols = copied_cols + updated_cols + list(recomputed_base_cols) # same order as select_list
334
375
  # update row builder with column information
335
376
  for i, col in enumerate(all_base_cols):
@@ -373,7 +414,8 @@ class Planner:
373
414
  recomputed_cols = {c for c in recomputed_cols if c.is_stored}
374
415
  recomputed_base_cols = {col for col in recomputed_cols if col.tbl == target}
375
416
  copied_cols = [
376
- col for col in target.cols if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
417
+ col for col in target.cols_by_id.values()
418
+ if col.is_stored and not col in updated_cols and not col in recomputed_base_cols
377
419
  ]
378
420
  select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
379
421
  select_list.extend(exprs.ColumnRef(col) for col in updated_cols)
@@ -387,13 +429,12 @@ class Planner:
387
429
  # - SqlLookupNode to retrieve the existing rows
388
430
  # - RowUpdateNode to update the retrieved rows
389
431
  # - ExprEvalNode to evaluate the remaining output exprs
390
- analyzer = Analyzer(tbl, select_list)
432
+ analyzer = Analyzer(FromClause(tbls=[tbl]), select_list)
391
433
  sql_exprs = list(exprs.Expr.list_subexprs(
392
434
  analyzer.all_exprs, filter=analyzer.sql_elements.contains, traverse_matches=False))
393
435
  row_builder = exprs.RowBuilder(analyzer.all_exprs, [], sql_exprs)
394
436
  analyzer.finalize(row_builder)
395
437
  sql_lookup_node = exec.SqlLookupNode(tbl, row_builder, sql_exprs, sa_key_cols, key_vals)
396
- delete_where_clause = sql_lookup_node.where_clause
397
438
  col_vals = [{col: row[col].val for col in updated_cols} for row in batch]
398
439
  row_update_node = exec.RowUpdateNode(tbl, key_vals, len(rowids) > 0, col_vals, row_builder, sql_lookup_node)
399
440
  plan: exec.ExecNode = row_update_node
@@ -412,7 +453,8 @@ class Planner:
412
453
  plan.set_ctx(ctx)
413
454
  recomputed_user_cols = [c for c in recomputed_cols if c.name is not None]
414
455
  return (
415
- plan, row_update_node, delete_where_clause, list(updated_cols) + recomputed_user_cols, recomputed_user_cols
456
+ plan, row_update_node, sql_lookup_node.where_clause_element, list(updated_cols) + recomputed_user_cols,
457
+ recomputed_user_cols
416
458
  )
417
459
 
418
460
  @classmethod
@@ -439,7 +481,7 @@ class Planner:
439
481
  target = view.tbl_version # the one we need to update
440
482
  # retrieve all stored cols and all target exprs
441
483
  recomputed_cols = set(recompute_targets.copy())
442
- copied_cols = [col for col in target.cols if col.is_stored and not col in recomputed_cols]
484
+ copied_cols = [col for col in target.cols_by_id.values() if col.is_stored and not col in recomputed_cols]
443
485
  select_list: list[exprs.Expr] = [exprs.ColumnRef(col) for col in copied_cols]
444
486
  # resolve recomputed exprs to stored columns in the base
445
487
  recomputed_exprs = \
@@ -448,7 +490,8 @@ class Planner:
448
490
 
449
491
  # we need to retrieve the PK columns of the existing rows
450
492
  plan = cls.create_query_plan(
451
- view, select_list, where_clause=target.predicate, ignore_errors=True, exact_version_only=view.get_bases())
493
+ FromClause(tbls=[view]), select_list, where_clause=target.predicate, ignore_errors=True,
494
+ exact_version_only=view.get_bases())
452
495
  for i, col in enumerate(copied_cols + list(recomputed_cols)): # same order as select_list
453
496
  plan.row_builder.add_table_column(col, select_list[i].slot_idx)
454
497
  # TODO: avoid duplication with view_load_plan() logic (where does this belong?)
@@ -459,7 +502,7 @@ class Planner:
459
502
 
460
503
  @classmethod
461
504
  def create_view_load_plan(
462
- cls, view: catalog.TableVersionPath, propagates_insert: bool = False
505
+ cls, view: catalog.TableVersionPath, propagates_insert: bool = False
463
506
  ) -> tuple[exec.ExecNode, int]:
464
507
  """Creates a query plan for populating a view.
465
508
 
@@ -479,7 +522,7 @@ class Planner:
479
522
  # - we can ignore stored non-computed columns because they have a default value that is supplied directly by
480
523
  # the store
481
524
  target = view.tbl_version # the one we need to populate
482
- stored_cols = [c for c in target.cols if c.is_stored]
525
+ stored_cols = [c for c in target.cols_by_id.values() if c.is_stored]
483
526
  # 2. for component views: iterator args
484
527
  iterator_args = [target.iterator_args] if target.iterator_args is not None else []
485
528
 
@@ -489,16 +532,16 @@ class Planner:
489
532
  # 1. materialize exprs computed from the base that are needed for stored view columns
490
533
  # 2. if it's an iterator view, expand the base rows into component rows
491
534
  # 3. materialize stored view columns that haven't been produced by step 1
492
- base_output_exprs = [e for e in row_builder.default_eval_ctx.exprs if e.is_bound_by(view.base)]
535
+ base_output_exprs = [e for e in row_builder.default_eval_ctx.exprs if e.is_bound_by([view.base])]
493
536
  view_output_exprs = [
494
537
  e for e in row_builder.default_eval_ctx.target_exprs
495
- if e.is_bound_by(view) and not e.is_bound_by(view.base)
538
+ if e.is_bound_by([view]) and not e.is_bound_by([view.base])
496
539
  ]
497
540
  # if we're propagating an insert, we only want to see those base rows that were created for the current version
498
- base_analyzer = Analyzer(view, base_output_exprs, where_clause=target.predicate)
541
+ base_analyzer = Analyzer(FromClause(tbls=[view.base]), base_output_exprs, where_clause=target.predicate)
499
542
  base_eval_ctx = row_builder.create_eval_ctx(base_analyzer.all_exprs)
500
543
  plan = cls._create_query_plan(
501
- view.base, row_builder=row_builder, analyzer=base_analyzer, eval_ctx=base_eval_ctx, with_pk=True,
544
+ row_builder=row_builder, analyzer=base_analyzer, eval_ctx=base_eval_ctx, with_pk=True,
502
545
  exact_version_only=view.get_bases() if propagates_insert else [])
503
546
  exec_ctx = plan.ctx
504
547
  if target.is_component_view():
@@ -513,6 +556,13 @@ class Planner:
513
556
  plan.set_ctx(exec_ctx)
514
557
  return plan, len(row_builder.default_eval_ctx.target_exprs)
515
558
 
559
+ @classmethod
560
+ def _verify_join_clauses(cls, analyzer: Analyzer) -> None:
561
+ """Verify that join clauses are expressible in SQL"""
562
+ for join_clause in analyzer.from_clause.join_clauses:
563
+ if join_clause.join_predicate is not None and analyzer.sql_elements.get(join_clause.join_predicate) is None:
564
+ raise excs.Error(f'Join predicate {join_clause.join_predicate} not expressible in SQL')
565
+
516
566
  @classmethod
517
567
  def _verify_ordering(cls, analyzer: Analyzer, verify_agg: bool) -> None:
518
568
  """Verify that the various ordering requirements don't conflict"""
@@ -551,9 +601,7 @@ class Planner:
551
601
  return s1 <= s2
552
602
 
553
603
  @classmethod
554
- def _insert_prefetch_node(
555
- cls, tbl_id: UUID, output_exprs: list[exprs.Expr], row_builder: exprs.RowBuilder, input: exec.ExecNode
556
- ) -> exec.ExecNode:
604
+ def _insert_prefetch_node(cls, tbl_id: UUID, row_builder: exprs.RowBuilder, input: exec.ExecNode) -> exec.ExecNode:
557
605
  """Returns a CachePrefetchNode into the plan if needed, otherwise returns input"""
558
606
  # we prefetch external files for all media ColumnRefs, even those that aren't part of the dependencies
559
607
  # of output_exprs: if unstored iterator columns are present, we might need to materialize ColumnRefs that
@@ -570,7 +618,7 @@ class Planner:
570
618
 
571
619
  @classmethod
572
620
  def create_query_plan(
573
- cls, tbl: catalog.TableVersionPath, select_list: Optional[list[exprs.Expr]] = None,
621
+ cls, from_clause: FromClause, select_list: Optional[list[exprs.Expr]] = None,
574
622
  where_clause: Optional[exprs.Expr] = None, group_by_clause: Optional[list[exprs.Expr]] = None,
575
623
  order_by_clause: Optional[list[tuple[exprs.Expr, bool]]] = None, limit: Optional[int] = None,
576
624
  ignore_errors: bool = False, exact_version_only: Optional[list[catalog.TableVersion]] = None
@@ -586,7 +634,7 @@ class Planner:
586
634
  if exact_version_only is None:
587
635
  exact_version_only = []
588
636
  analyzer = Analyzer(
589
- tbl, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
637
+ from_clause, select_list, where_clause=where_clause, group_by_clause=group_by_clause,
590
638
  order_by_clause=order_by_clause)
591
639
  row_builder = exprs.RowBuilder(analyzer.all_exprs, [], [])
592
640
 
@@ -595,7 +643,7 @@ class Planner:
595
643
  # with_pk: for now, we always retrieve the PK, because we need it for the file cache
596
644
  eval_ctx = row_builder.create_eval_ctx(analyzer.select_list)
597
645
  plan = cls._create_query_plan(
598
- tbl, row_builder, analyzer=analyzer, eval_ctx=eval_ctx, limit=limit, with_pk=True,
646
+ row_builder=row_builder, analyzer=analyzer, eval_ctx=eval_ctx, limit=limit, with_pk=True,
599
647
  exact_version_only=exact_version_only)
600
648
  plan.ctx.ignore_errors = ignore_errors
601
649
  select_list.clear()
@@ -604,10 +652,9 @@ class Planner:
604
652
 
605
653
  @classmethod
606
654
  def _create_query_plan(
607
- cls, tbl: catalog.TableVersionPath, row_builder: exprs.RowBuilder, analyzer: Analyzer,
608
- eval_ctx: exprs.RowBuilder.EvalCtx,
609
- limit: Optional[int] = None, with_pk: bool = False,
610
- exact_version_only: Optional[list[catalog.TableVersion]] = None
655
+ cls, row_builder: exprs.RowBuilder, analyzer: Analyzer, eval_ctx: exprs.RowBuilder.EvalCtx,
656
+ limit: Optional[int] = None, with_pk: bool = False,
657
+ exact_version_only: Optional[list[catalog.TableVersion]] = None
611
658
  ) -> exec.ExecNode:
612
659
  """
613
660
  Create plan to materialize eval_ctx.
@@ -619,7 +666,6 @@ class Planner:
619
666
  """
620
667
  if exact_version_only is None:
621
668
  exact_version_only = []
622
- assert isinstance(tbl, catalog.TableVersionPath)
623
669
  sql_elements = analyzer.sql_elements
624
670
  is_python_agg = (
625
671
  not sql_elements.contains_all(analyzer.agg_fn_calls)
@@ -627,17 +673,19 @@ class Planner:
627
673
  )
628
674
  ctx = exec.ExecContext(row_builder)
629
675
  cls._verify_ordering(analyzer, verify_agg=is_python_agg)
676
+ cls._verify_join_clauses(analyzer)
630
677
 
631
- # materialized with SQL scan:
678
+ # materialized with SQL table scans (ie, single-table SELECT statements):
632
679
  # - select list subexprs that aren't aggregates
633
- # - Where clause conjuncts that can't be run in SQL
680
+ # - join clause subexprs
681
+ # - subexprs of Where clause conjuncts that can't be run in SQL
634
682
  # - all grouping exprs, if any aggregate function call can't be run in SQL (in that case, they all have to be
635
683
  # run in Python)
636
684
  candidates = list(exprs.Expr.list_subexprs(
637
685
  analyzer.select_list,
638
686
  filter=lambda e: (
639
- sql_elements.contains(e)
640
- and not e._contains(cls=exprs.FunctionCall, filter=lambda e: bool(e.is_agg_fn_call))
687
+ sql_elements.contains(e)
688
+ and not e._contains(cls=exprs.FunctionCall, filter=lambda e: bool(e.is_agg_fn_call))
641
689
  ),
642
690
  traverse_matches=False))
643
691
  if analyzer.filter is not None:
@@ -647,15 +695,44 @@ class Planner:
647
695
  candidates.extend(exprs.Expr.list_subexprs(
648
696
  analyzer.group_by_clause, filter=lambda e: sql_elements.contains(e), traverse_matches=False))
649
697
  # not isinstance(...): we don't want to materialize Literals via a Select
650
- sql_scan_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))
698
+ sql_exprs = exprs.ExprSet(e for e in candidates if not isinstance(e, exprs.Literal))
699
+
700
+ # create table scans; each scan produces subexprs of (sql_exprs + join clauses)
701
+ join_exprs = exprs.ExprSet(
702
+ join_clause.join_predicate
703
+ for join_clause in analyzer.from_clause.join_clauses
704
+ if join_clause.join_predicate is not None)
705
+ scan_target_exprs = sql_exprs | join_exprs
706
+ tbl_scan_plans: list[exec.SqlScanNode] = []
707
+ plan: exec.ExecNode
708
+ for tbl in analyzer.from_clause.tbls:
709
+ # materialize all subexprs of scan_target_exprs that are bound by tbl
710
+ tbl_scan_exprs = exprs.ExprSet(
711
+ exprs.Expr.list_subexprs(
712
+ scan_target_exprs,
713
+ filter=lambda e: e.is_bound_by([tbl]) and not isinstance(e, exprs.Literal),
714
+ traverse_matches=False))
715
+ plan = exec.SqlScanNode(
716
+ tbl, row_builder, select_list=tbl_scan_exprs,
717
+ set_pk=with_pk, exact_version_only=exact_version_only)
718
+ tbl_scan_plans.append(plan)
719
+
720
+ if len(analyzer.from_clause.join_clauses) > 0:
721
+ plan = exec.SqlJoinNode(
722
+ row_builder, inputs=tbl_scan_plans, join_clauses=analyzer.from_clause.join_clauses,
723
+ select_list=sql_exprs)
724
+ else:
725
+ plan = tbl_scan_plans[0]
651
726
 
652
- plan = exec.SqlScanNode(
653
- tbl, row_builder, select_list=sql_scan_exprs, where_clause=analyzer.sql_where_clause,
654
- filter=analyzer.filter, set_pk=with_pk, exact_version_only=exact_version_only)
727
+ if analyzer.sql_where_clause is not None:
728
+ plan.set_where(analyzer.sql_where_clause)
729
+ if analyzer.filter is not None:
730
+ plan.set_py_filter(analyzer.filter)
655
731
  if len(analyzer.window_fn_calls) > 0:
656
732
  # we need to order the input for window functions
657
- plan.add_order_by(analyzer.get_window_fn_ob_clause())
658
- plan = cls._insert_prefetch_node(tbl.tbl_version.id, analyzer.select_list, row_builder, plan)
733
+ plan.set_order_by(analyzer.get_window_fn_ob_clause())
734
+
735
+ plan = cls._insert_prefetch_node(tbl.tbl_version.id, row_builder, plan)
659
736
 
660
737
  if analyzer.group_by_clause is not None:
661
738
  # we're doing grouping aggregation; the input of the AggregateNode are the grouping exprs plus the
@@ -663,9 +740,9 @@ class Planner:
663
740
  agg_input = exprs.ExprSet(analyzer.grouping_exprs.copy())
664
741
  for fn_call in analyzer.agg_fn_calls:
665
742
  agg_input.update(fn_call.components)
666
- if not sql_scan_exprs.issuperset(agg_input):
743
+ if not sql_exprs.issuperset(agg_input):
667
744
  # we need an ExprEvalNode
668
- plan = exec.ExprEvalNode(row_builder, agg_input, sql_scan_exprs, input=plan)
745
+ plan = exec.ExprEvalNode(row_builder, agg_input, sql_exprs, input=plan)
669
746
 
670
747
  # batch size for aggregation input: this could be the entire table, so we need to divide it into
671
748
  # smaller batches; at the same time, we need to make the batches large enough to amortize the
@@ -689,16 +766,17 @@ class Planner:
689
766
  # we need an ExprEvalNode to evaluate the remaining output exprs
690
767
  plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, agg_output, input=plan)
691
768
  else:
692
- if not exprs.ExprSet(sql_scan_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
769
+ if not exprs.ExprSet(sql_exprs).issuperset(exprs.ExprSet(eval_ctx.target_exprs)):
693
770
  # we need an ExprEvalNode to evaluate the remaining output exprs
694
- plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_scan_exprs, input=plan)
771
+ plan = exec.ExprEvalNode(row_builder, eval_ctx.target_exprs, sql_exprs, input=plan)
695
772
  # we're returning everything to the user, so we might as well do it in a single batch
696
773
  ctx.batch_size = 0
697
774
 
698
- sql_node = plan.get_sql_node()
699
- assert sql_node is not None
700
775
  if len(analyzer.order_by_clause) > 0:
701
- sql_node.add_order_by(analyzer.order_by_clause)
776
+ # we have the last SqlNode we created produce the ordering
777
+ sql_node = plan.get_node(exec.SqlNode)
778
+ assert sql_node is not None
779
+ sql_node.set_order_by(analyzer.order_by_clause)
702
780
 
703
781
  if limit is not None:
704
782
  plan.set_limit(limit)
@@ -708,7 +786,7 @@ class Planner:
708
786
 
709
787
  @classmethod
710
788
  def analyze(cls, tbl: catalog.TableVersionPath, where_clause: exprs.Expr) -> Analyzer:
711
- return Analyzer(tbl, [], where_clause=where_clause)
789
+ return Analyzer(FromClause(tbls=[tbl]), [], where_clause=where_clause)
712
790
 
713
791
  @classmethod
714
792
  def create_add_column_plan(
@@ -721,9 +799,9 @@ class Planner:
721
799
  """
722
800
  assert isinstance(tbl, catalog.TableVersionPath)
723
801
  row_builder = exprs.RowBuilder(output_exprs=[], columns=[col], input_exprs=[])
724
- analyzer = Analyzer(tbl, row_builder.default_eval_ctx.target_exprs)
802
+ analyzer = Analyzer(FromClause(tbls=[tbl]), row_builder.default_eval_ctx.target_exprs)
725
803
  plan = cls._create_query_plan(
726
- tbl, row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True)
804
+ row_builder=row_builder, analyzer=analyzer, eval_ctx=row_builder.default_eval_ctx, with_pk=True)
727
805
  plan.ctx.batch_size = 16
728
806
  plan.ctx.show_pbar = True
729
807
  plan.ctx.ignore_errors = True
pixeltable/store.py CHANGED
@@ -159,7 +159,7 @@ class StoreBase:
159
159
  def count(self, conn: Optional[sql.engine.Connection] = None) -> int:
160
160
  """Return the number of rows visible in self.tbl_version"""
161
161
  stmt = (
162
- sql.select(sql.func.count('*')) # type: ignore
162
+ sql.select(sql.func.count('*'))
163
163
  .select_from(self.sa_tbl)
164
164
  .where(self.v_min_col <= self.tbl_version.version)
165
165
  .where(self.v_max_col > self.tbl_version.version)
pixeltable/tool/create_test_db_dump.py CHANGED
@@ -270,7 +270,10 @@ class Dumper:
270
270
  add_column('c6_to_string', t.c6.apply(json.dumps))
271
271
  add_column('c6_back_to_json', t[f'{col_prefix}_c6_to_string'].apply(json.loads))
272
272
 
273
- t.add_embedding_index(f'{col_prefix}_function_call', string_embed=embed_udf.clip_text_embed)
273
+ t.add_embedding_index(
274
+ f'{col_prefix}_function_call',
275
+ string_embed=pxt.functions.huggingface.clip_text.using(model_id='openai/clip-vit-base-patch32')
276
+ )
274
277
 
275
278
  # query()
276
279
  @t.query
pixeltable/type_system.py CHANGED
@@ -166,7 +166,7 @@ class ColumnType:
166
166
  if t == cls.Type.DOCUMENT:
167
167
  return DocumentType()
168
168
 
169
- def __str__(self) -> str:
169
+ def __repr__(self) -> str:
170
170
  return self._to_str(as_schema=False)
171
171
 
172
172
  def _to_str(self, as_schema: bool) -> str:
pixeltable/utils/arrow.py CHANGED
@@ -3,14 +3,17 @@ from typing import Any, Iterator, Optional, Union
3
3
 
4
4
  import numpy as np
5
5
  import pyarrow as pa
6
+ import datetime
6
7
 
7
8
  import pixeltable.type_system as ts
9
+ from pixeltable.env import Env
10
+
11
+ _tz_def = Env().get().default_time_zone
8
12
 
9
13
  _logger = logging.getLogger(__name__)
10
14
 
11
15
  _pa_to_pt: dict[pa.DataType, ts.ColumnType] = {
12
16
  pa.string(): ts.StringType(nullable=True),
13
- pa.timestamp('us'): ts.TimestampType(nullable=True),
14
17
  pa.bool_(): ts.BoolType(nullable=True),
15
18
  pa.uint8(): ts.IntType(nullable=True),
16
19
  pa.int8(): ts.IntType(nullable=True),
@@ -23,7 +26,7 @@ _pa_to_pt: dict[pa.DataType, ts.ColumnType] = {
23
26
 
24
27
  _pt_to_pa: dict[type[ts.ColumnType], pa.DataType] = {
25
28
  ts.StringType: pa.string(),
26
- ts.TimestampType: pa.timestamp('us'), # postgres timestamp is microseconds
29
+ ts.TimestampType: pa.timestamp('us', tz=datetime.timezone.utc), # postgres timestamp is microseconds
27
30
  ts.BoolType: pa.bool_(),
28
31
  ts.IntType: pa.int64(),
29
32
  ts.FloatType: pa.float32(),
@@ -39,7 +42,9 @@ def to_pixeltable_type(arrow_type: pa.DataType) -> Optional[ts.ColumnType]:
39
42
  """Convert a pyarrow DataType to a pixeltable ColumnType if one is defined.
40
43
  Returns None if no conversion is currently implemented.
41
44
  """
42
- if arrow_type in _pa_to_pt:
45
+ if isinstance(arrow_type, pa.TimestampType):
46
+ return ts.TimestampType(nullable=True)
47
+ elif arrow_type in _pa_to_pt:
43
48
  return _pa_to_pt[arrow_type]
44
49
  elif isinstance(arrow_type, pa.FixedShapeTensorType):
45
50
  dtype = to_pixeltable_type(arrow_type.value_type)
pixeltable/utils/description_helper.py ADDED
@@ -0,0 +1,89 @@
1
+ import dataclasses
2
+ from typing import Optional, Union
3
+
4
+ import pandas as pd
5
+ from pandas.io.formats.style import Styler
6
+
7
+
8
+ @dataclasses.dataclass
9
+ class _Descriptor:
10
+ body: Union[str, pd.DataFrame]
11
+ # The remaining fields only affect the behavior if `body` is a pd.DataFrame.
12
+ show_index: bool
13
+ show_header: bool
14
+ styler: Optional[Styler] = None
15
+
16
+
17
class DescriptionHelper:
    """
    Builds a long-form description of a Pixeltable object from a sequence of "descriptors".

    Each descriptor is either a string or a Pandas DataFrame, mixed in any order, and descriptors are rendered in the
    order they were appended. This suits descriptions that combine text with tables of differing schemas or
    formatting. The same sequence can be rendered as plain text (`to_string()`) or as HTML (`to_html()`).
    """
    __descriptors: list[_Descriptor]

    def __init__(self) -> None:
        self.__descriptors = []

    def append(
        self,
        descriptor: Union[str, pd.DataFrame],
        show_index: bool = False,
        show_header: bool = True,
        styler: Optional[Styler] = None,
    ) -> None:
        """Add a descriptor to the end of the sequence.

        `show_index`, `show_header` and `styler` only affect DataFrame descriptors.
        """
        entry = _Descriptor(descriptor, show_index, show_header, styler)
        self.__descriptors.append(entry)

    def to_string(self) -> str:
        """Render every descriptor as plain text, separating consecutive descriptors with a blank line."""
        return '\n\n'.join(map(self.__render_text, self.__descriptors))

    def to_html(self) -> str:
        """Render every descriptor as an HTML table, joining the tables with newlines."""
        return '\n'.join(self.__apply_styles(item).to_html() for item in self.__descriptors)

    @classmethod
    def __render_text(cls, descriptor: _Descriptor) -> str:
        """Return the plain-text rendering of a single descriptor."""
        body = descriptor.body
        if isinstance(body, str):
            return body
        # To hide the index we replace it with empty strings rather than passing `index=False` to
        # `df.to_string()`: in Pandas, `index=False` has side effects beyond hiding the index and yields worse
        # intercolumn spacing.
        if not descriptor.show_index:
            body = body.copy()
            body.index = [''] * len(body)  # type: ignore[assignment]
        # max_colwidth=50 is the same default Pandas uses in a DataFrame's __repr__() output.
        return body.to_string(header=descriptor.show_header, max_colwidth=50)

    @classmethod
    def __apply_styles(cls, descriptor: _Descriptor) -> Styler:
        """Return the Styler used to render a single descriptor as HTML."""
        body = descriptor.body
        if isinstance(body, str):
            # Render text as a one-cell table with no header or index, so that text and tables appearing in the
            # same DescriptionHelper share a consistent style.
            text_styler = pd.DataFrame([body]).style
            text_styler = text_styler.set_properties(
                None, **{'white-space': 'pre-wrap', 'text-align': 'left', 'font-weight': 'bold'}
            )
            return text_styler.hide(axis='index').hide(axis='columns')

        styler = body.style if descriptor.styler is None else descriptor.styler
        styler = styler.set_properties(None, **{'white-space': 'pre-wrap', 'text-align': 'left'})
        styler = styler.set_table_styles([dict(selector='th', props=[('text-align', 'left')])])
        if not descriptor.show_header:
            styler = styler.hide(axis='columns')
        if not descriptor.show_index:
            styler = styler.hide(axis='index')
        return styler