pixeltable 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (79) hide show
  1. pixeltable/__version__.py +2 -2
  2. pixeltable/catalog/catalog.py +8 -7
  3. pixeltable/catalog/column.py +11 -8
  4. pixeltable/catalog/insertable_table.py +1 -1
  5. pixeltable/catalog/path_dict.py +8 -6
  6. pixeltable/catalog/table.py +20 -13
  7. pixeltable/catalog/table_version.py +91 -54
  8. pixeltable/catalog/table_version_path.py +7 -9
  9. pixeltable/catalog/view.py +2 -1
  10. pixeltable/dataframe.py +1 -1
  11. pixeltable/env.py +173 -82
  12. pixeltable/exec/aggregation_node.py +2 -1
  13. pixeltable/exec/component_iteration_node.py +1 -1
  14. pixeltable/exec/sql_node.py +11 -8
  15. pixeltable/exprs/__init__.py +1 -0
  16. pixeltable/exprs/arithmetic_expr.py +4 -4
  17. pixeltable/exprs/array_slice.py +2 -1
  18. pixeltable/exprs/column_property_ref.py +9 -7
  19. pixeltable/exprs/column_ref.py +2 -1
  20. pixeltable/exprs/comparison.py +10 -7
  21. pixeltable/exprs/compound_predicate.py +3 -2
  22. pixeltable/exprs/data_row.py +19 -4
  23. pixeltable/exprs/expr.py +46 -35
  24. pixeltable/exprs/expr_set.py +32 -9
  25. pixeltable/exprs/function_call.py +56 -32
  26. pixeltable/exprs/in_predicate.py +3 -2
  27. pixeltable/exprs/inline_array.py +2 -1
  28. pixeltable/exprs/inline_dict.py +2 -1
  29. pixeltable/exprs/is_null.py +3 -2
  30. pixeltable/exprs/json_mapper.py +5 -4
  31. pixeltable/exprs/json_path.py +7 -1
  32. pixeltable/exprs/literal.py +34 -7
  33. pixeltable/exprs/method_ref.py +3 -3
  34. pixeltable/exprs/object_ref.py +6 -5
  35. pixeltable/exprs/row_builder.py +25 -17
  36. pixeltable/exprs/rowid_ref.py +2 -1
  37. pixeltable/exprs/similarity_expr.py +2 -1
  38. pixeltable/exprs/sql_element_cache.py +30 -0
  39. pixeltable/exprs/type_cast.py +3 -3
  40. pixeltable/exprs/variable.py +2 -1
  41. pixeltable/ext/functions/whisperx.py +4 -4
  42. pixeltable/ext/functions/yolox.py +6 -6
  43. pixeltable/func/aggregate_function.py +1 -0
  44. pixeltable/func/function.py +28 -4
  45. pixeltable/functions/__init__.py +4 -2
  46. pixeltable/functions/anthropic.py +107 -0
  47. pixeltable/functions/fireworks.py +2 -2
  48. pixeltable/functions/globals.py +6 -1
  49. pixeltable/functions/huggingface.py +2 -2
  50. pixeltable/functions/image.py +17 -2
  51. pixeltable/functions/json.py +5 -5
  52. pixeltable/functions/mistralai.py +188 -0
  53. pixeltable/functions/openai.py +6 -10
  54. pixeltable/functions/string.py +3 -2
  55. pixeltable/functions/timestamp.py +95 -7
  56. pixeltable/functions/together.py +5 -5
  57. pixeltable/functions/video.py +2 -2
  58. pixeltable/functions/vision.py +27 -17
  59. pixeltable/functions/whisper.py +1 -1
  60. pixeltable/io/hf_datasets.py +17 -15
  61. pixeltable/io/pandas.py +0 -2
  62. pixeltable/io/parquet.py +15 -14
  63. pixeltable/iterators/document.py +16 -15
  64. pixeltable/metadata/__init__.py +1 -1
  65. pixeltable/metadata/converters/convert_19.py +46 -0
  66. pixeltable/metadata/notes.py +1 -0
  67. pixeltable/metadata/schema.py +5 -4
  68. pixeltable/plan.py +100 -78
  69. pixeltable/store.py +5 -1
  70. pixeltable/tool/create_test_db_dump.py +4 -3
  71. pixeltable/type_system.py +12 -14
  72. pixeltable/utils/documents.py +45 -42
  73. pixeltable/utils/formatter.py +2 -2
  74. {pixeltable-0.2.16.dist-info → pixeltable-0.2.18.dist-info}/METADATA +79 -21
  75. pixeltable-0.2.18.dist-info/RECORD +147 -0
  76. pixeltable-0.2.16.dist-info/RECORD +0 -143
  77. {pixeltable-0.2.16.dist-info → pixeltable-0.2.18.dist-info}/LICENSE +0 -0
  78. {pixeltable-0.2.16.dist-info → pixeltable-0.2.18.dist-info}/WHEEL +0 -0
  79. {pixeltable-0.2.16.dist-info → pixeltable-0.2.18.dist-info}/entry_points.txt +0 -0
@@ -9,6 +9,12 @@ import pixeltable
9
9
  import pixeltable.catalog as catalog
10
10
  import pixeltable.exceptions as excs
11
11
  import pixeltable.type_system as ts
12
+ from .data_row import DataRow
13
+ from .expr import Expr
14
+ from .globals import print_slice
15
+ from .json_mapper import JsonMapper
16
+ from .row_builder import RowBuilder
17
+ from .sql_element_cache import SqlElementCache
12
18
 
13
19
  from .data_row import DataRow
14
20
  from .expr import Expr
@@ -140,7 +146,7 @@ class JsonPath(Expr):
140
146
  def _id_attrs(self) -> list[tuple[str, Any]]:
141
147
  return super()._id_attrs() + [('path_elements', self.path_elements)]
142
148
 
143
- def sql_expr(self) -> Optional[sql.ClauseElement]:
149
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
144
150
  """
145
151
  Postgres appears to have a bug: jsonb_path_query('{a: [{b: 0}, {b: 1}]}', '$.a.b') returns
146
152
  *two* rows (each containing col val 0), not a single row with [0, 0].
@@ -1,15 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import datetime
4
- from typing import Optional, List, Any, Dict, Tuple
4
+ from typing import Any, Dict, List, Optional, Tuple
5
5
 
6
6
  import sqlalchemy as sql
7
7
 
8
- import pixeltable.exceptions as excs
9
8
  import pixeltable.type_system as ts
9
+ from pixeltable.env import Env
10
+
10
11
  from .data_row import DataRow
11
12
  from .expr import Expr
12
13
  from .row_builder import RowBuilder
14
+ from .sql_element_cache import SqlElementCache
13
15
 
14
16
 
15
17
  class Literal(Expr):
@@ -22,6 +24,15 @@ class Literal(Expr):
22
24
  if col_type is None:
23
25
  raise TypeError(f'Not a valid literal: {val}')
24
26
  super().__init__(col_type)
27
+ if isinstance(val, datetime.datetime):
28
+ # Normalize the datetime to UTC: all timestamps are stored as UTC (both in the database and in literals)
29
+ if val.tzinfo is None:
30
+ # We have a naive datetime. Modify it to use the configured default time zone
31
+ default_tz = Env.get().default_time_zone
32
+ if default_tz is not None:
33
+ val = val.replace(tzinfo=default_tz)
34
+ # Now convert to UTC
35
+ val = val.astimezone(datetime.timezone.utc)
25
36
  self.val = val
26
37
  self.id = self._create_id()
27
38
 
@@ -29,17 +40,24 @@ class Literal(Expr):
29
40
  return 'Literal'
30
41
 
31
42
  def __str__(self) -> str:
32
- if self.col_type.is_string_type() or self.col_type.is_timestamp_type():
43
+ if self.col_type.is_string_type():
33
44
  return f"'{self.val}'"
45
+ if self.col_type.is_timestamp_type():
46
+ assert isinstance(self.val, datetime.datetime)
47
+ default_tz = Env.get().default_time_zone
48
+ return f"'{self.val.astimezone(default_tz).isoformat()}'"
34
49
  return str(self.val)
35
50
 
51
+ def __repr__(self) -> str:
52
+ return f'Literal({self.val!r})'
53
+
36
54
  def _equals(self, other: Literal) -> bool:
37
55
  return self.val == other.val
38
56
 
39
57
  def _id_attrs(self) -> List[Tuple[str, Any]]:
40
58
  return super()._id_attrs() + [('val', self.val)]
41
59
 
42
- def sql_expr(self) -> Optional[sql.ClauseElement]:
60
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
43
61
  # we need to return something here so that we can generate a Where clause for predicates
44
62
  # that involve literals (like Where c > 0)
45
63
  return sql.sql.expression.literal(self.val)
@@ -52,7 +70,12 @@ class Literal(Expr):
52
70
  # For some types, we need to explicitly record their type, because JSON does not know
53
71
  # how to interpret them unambiguously
54
72
  if self.col_type.is_timestamp_type():
55
- return {'val': self.val.isoformat(), 'val_t': self.col_type._type.name, **super()._as_dict()}
73
+ assert isinstance(self.val, datetime.datetime)
74
+ assert self.val.tzinfo == datetime.timezone.utc # Must be UTC in a literal
75
+ # Convert to ISO format in UTC (in keeping with the principle: all timestamps are
76
+ # stored as UTC in the database)
77
+ encoded_val = self.val.isoformat()
78
+ return {'val': encoded_val, 'val_t': self.col_type._type.name, **super()._as_dict()}
56
79
  else:
57
80
  return {'val': self.val, **super()._as_dict()}
58
81
 
@@ -61,6 +84,10 @@ class Literal(Expr):
61
84
  assert 'val' in d
62
85
  if 'val_t' in d:
63
86
  val_t = d['val_t']
87
+ # Currently the only special-cased literal type is TIMESTAMP
64
88
  assert val_t == ts.ColumnType.Type.TIMESTAMP.name
65
- return cls(datetime.datetime.fromisoformat(d['val']))
66
- return cls(d['val'])
89
+ dt = datetime.datetime.fromisoformat(d['val'])
90
+ assert dt.tzinfo == datetime.timezone.utc # Must be UTC in the database
91
+ return cls(dt)
92
+ else:
93
+ return cls(d['val'])
@@ -2,12 +2,12 @@ from typing import Any, Optional
2
2
 
3
3
  import sqlalchemy as sql
4
4
 
5
- import pixeltable.exceptions as excs
6
5
  import pixeltable.type_system as ts
7
6
  from pixeltable.exprs import Expr, FunctionCall
8
- from pixeltable.func import FunctionRegistry, CallableFunction
7
+ from pixeltable.func import FunctionRegistry
9
8
  from .data_row import DataRow
10
9
  from .row_builder import RowBuilder
10
+ from .sql_element_cache import SqlElementCache
11
11
 
12
12
 
13
13
  class MethodRef(Expr):
@@ -53,7 +53,7 @@ class MethodRef(Expr):
53
53
  def _id_attrs(self) -> list[tuple[str, Any]]:
54
54
  return super()._id_attrs() + [('method_name', self.method_name)]
55
55
 
56
- def sql_expr(self) -> Optional[sql.ClauseElement]:
56
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
57
57
  return None
58
58
 
59
59
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -1,14 +1,15 @@
1
1
  from __future__ import annotations
2
- from typing import Optional, List, Any, Dict, Tuple
3
- import copy
2
+
3
+ from typing import Optional
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
7
+ import pixeltable.type_system as ts
8
+ from .data_row import DataRow
7
9
  from .expr import Expr, ExprScope
8
10
  from .json_mapper import JsonMapper
9
- from .data_row import DataRow
10
11
  from .row_builder import RowBuilder
11
- import pixeltable.type_system as ts
12
+ from .sql_element_cache import SqlElementCache
12
13
 
13
14
 
14
15
  class ObjectRef(Expr):
@@ -32,7 +33,7 @@ class ObjectRef(Expr):
32
33
  def _equals(self, other: ObjectRef) -> bool:
33
34
  return self.owner is other.owner
34
35
 
35
- def sql_expr(self) -> Optional[sql.ClauseElement]:
36
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
36
37
  return None
37
38
 
38
39
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import sys
4
4
  import time
5
5
  from dataclasses import dataclass
6
- from typing import Optional, List, Any, Dict, Sequence, Tuple, Set
6
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple
7
7
 
8
8
  import sqlalchemy as sql
9
9
 
@@ -11,6 +11,7 @@ import pixeltable.catalog as catalog
11
11
  import pixeltable.exceptions as excs
12
12
  import pixeltable.func as func
13
13
  import pixeltable.utils as utils
14
+
14
15
  from .data_row import DataRow
15
16
  from .expr import Expr
16
17
  from .expr_set import ExprSet
@@ -22,7 +23,7 @@ class ExecProfile:
22
23
  self.eval_count = [0] * row_builder.num_materialized
23
24
  self.row_builder = row_builder
24
25
 
25
- def print(self, num_rows: int) -> str:
26
+ def print(self, num_rows: int) -> None:
26
27
  for i in range(self.row_builder.num_materialized):
27
28
  if self.eval_count[i] == 0:
28
29
  continue
@@ -57,7 +58,7 @@ class RowBuilder:
57
58
  target_exprs: List[Expr] # exprs corresponding to target_slot_idxs
58
59
 
59
60
  def __init__(
60
- self, output_exprs: Sequence[Expr], columns: Sequence[catalog.Column], input_exprs: Sequence[Expr]
61
+ self, output_exprs: Sequence[Expr], columns: Sequence[catalog.Column], input_exprs: Iterable[Expr]
61
62
  ):
62
63
  """
63
64
  Args:
@@ -96,7 +97,7 @@ class RowBuilder:
96
97
  expr = ColumnRef(col)
97
98
  expr = self._record_unique_expr(expr, recursive=False)
98
99
  self.add_table_column(col, expr.slot_idx)
99
- self.output_exprs.append(expr)
100
+ self.output_exprs.add(expr)
100
101
 
101
102
  # default eval ctx: all output exprs
102
103
  self.default_eval_ctx = self.create_eval_ctx(list(self.output_exprs), exclude=unique_input_exprs)
@@ -193,7 +194,7 @@ class RowBuilder:
193
194
  expr.components[i] = self._record_unique_expr(c, True)
194
195
  assert expr.slot_idx is None
195
196
  expr.slot_idx = self._next_slot_idx()
196
- self.unique_exprs.append(expr)
197
+ self.unique_exprs.add(expr)
197
198
  return expr
198
199
 
199
200
  def _record_output_expr_id(self, e: Expr, output_expr_id: int) -> None:
@@ -227,18 +228,25 @@ class RowBuilder:
227
228
  # merge dependencies and convert to list
228
229
  return sorted(set().union(*[dependencies[i] for i in target_slot_idxs]))
229
230
 
230
- def substitute_exprs(self, expr_list: list, remove_duplicates: bool = True) -> None:
231
- """Substitutes exprs with their executable counterparts from unique_exprs and optionally removes duplicates"""
232
- i = 0
233
- unique_ids: set[int] = set() # slot idxs within expr_list
234
- while i < len(expr_list):
235
- unique_expr = self.unique_exprs[expr_list[i]]
236
- if unique_expr.slot_idx in unique_ids and remove_duplicates:
237
- del expr_list[i]
238
- else:
239
- expr_list[i] = unique_expr
240
- unique_ids.add(unique_expr.slot_idx)
241
- i += 1
231
+ def set_slot_idxs(self, expr_list: Sequence[Expr], remove_duplicates: bool = True) -> None:
232
+ """
233
+ Recursively sets slot_idx in expr_list and its components
234
+
235
+ remove_duplicates == True: removes duplicates in-place
236
+ """
237
+ for e in expr_list:
238
+ self.__set_slot_idxs_aux(e)
239
+ if remove_duplicates:
240
+ deduped = list(ExprSet(expr_list))
241
+ expr_list[:] = deduped
242
+
243
+ def __set_slot_idxs_aux(self, e: Expr) -> None:
244
+ """Recursively sets slot_idx in e and its components"""
245
+ if e not in self.unique_exprs:
246
+ return
247
+ e.slot_idx = self.unique_exprs[e].slot_idx
248
+ for c in e.components:
249
+ self.__set_slot_idxs_aux(c)
242
250
 
243
251
  def get_dependencies(self, targets: List[Expr], exclude: Optional[List[Expr]] = None) -> List[Expr]:
244
252
  """
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
  from typing import Optional, List, Any, Dict, Tuple
3
+ from .sql_element_cache import SqlElementCache
3
4
  from uuid import UUID
4
5
 
5
6
  import sqlalchemy as sql
@@ -72,7 +73,7 @@ class RowidRef(Expr):
72
73
  self.tbl = tbl.tbl_version
73
74
  self.tbl_id = self.tbl.id
74
75
 
75
- def sql_expr(self) -> Optional[sql.ClauseElement]:
76
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
76
77
  tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
77
78
  rowid_cols = tbl.store_tbl.rowid_columns()
78
79
  return rowid_cols[self.rowid_component_idx]
@@ -1,4 +1,5 @@
1
1
  from typing import Optional, List, Any
2
+ from .sql_element_cache import SqlElementCache
2
3
 
3
4
  import sqlalchemy as sql
4
5
  import PIL.Image
@@ -56,7 +57,7 @@ class SimilarityExpr(Expr):
56
57
  def __str__(self) -> str:
57
58
  return f'{self.components[0]}.similarity({self.components[1]})'
58
59
 
59
- def sql_expr(self) -> Optional[sql.ClauseElement]:
60
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
60
61
  if not isinstance(self.components[1], Literal):
61
62
  raise excs.Error(f'similarity(): requires a string or a PIL.Image.Image object, not an expression')
62
63
  item = self.components[1].val
@@ -0,0 +1,30 @@
1
+ from typing import Iterable, Union, Optional
2
+
3
+ import sqlalchemy as sql
4
+
5
+ from .expr import Expr
6
+
7
+
8
+ class SqlElementCache:
9
+ """Cache of sql.ColumnElements for exprs"""
10
+
11
+ cache: dict[int, Optional[sql.ColumnElement]] # key: Expr.id
12
+
13
+ def __init__(self):
14
+ self.cache = {}
15
+
16
+ def get(self, e: Expr) -> Optional[sql.ColumnElement]:
17
+ """Returns the sql.ColumnElement for the given Expr, or None if Expr.sql_expr() returns None."""
18
+ try:
19
+ return self.cache[e.id]
20
+ except KeyError:
21
+ pass
22
+ el = e.sql_expr(self)
23
+ self.cache[e.id] = el
24
+ return el
25
+
26
+ def contains(self, items: Union[Expr, Iterable[Expr]]) -> bool:
27
+ """Returns True if every item has a (non-None) sql.ColumnElement."""
28
+ if isinstance(items, Expr):
29
+ return self.get(items) is not None
30
+ return all(self.get(e) is not None for e in items)
@@ -1,4 +1,3 @@
1
- import json
2
1
  from typing import Optional, Dict, List, Tuple, Any
3
2
 
4
3
  import sqlalchemy as sql
@@ -6,6 +5,7 @@ import sqlalchemy as sql
6
5
  import pixeltable.type_system as ts
7
6
  from .expr import DataRow, Expr
8
7
  from .row_builder import RowBuilder
8
+ from .sql_element_cache import SqlElementCache
9
9
 
10
10
 
11
11
  class TypeCast(Expr):
@@ -29,9 +29,9 @@ class TypeCast(Expr):
29
29
  def _id_attrs(self) -> List[Tuple[str, Any]]:
30
30
  return super()._id_attrs() + [('new_type', self.col_type)]
31
31
 
32
- def sql_expr(self) -> Optional[sql.ClauseElement]:
32
+ def sql_expr(self, _: SqlElementCache) -> Optional[sql.ClauseElement]:
33
33
  """
34
- `sql_expr` is unimplemented for now, in order to sidestep potentially thorny
34
+ sql_expr() is unimplemented for now, in order to sidestep potentially thorny
35
35
  questions about consistency of doing type conversions in both Python and Postgres.
36
36
  """
37
37
  return None
@@ -6,6 +6,7 @@ import pixeltable.type_system as ts
6
6
  from .data_row import DataRow
7
7
  from .expr import Expr
8
8
  from .row_builder import RowBuilder
9
+ from .sql_element_cache import SqlElementCache
9
10
 
10
11
 
11
12
  class Variable(Expr):
@@ -31,7 +32,7 @@ class Variable(Expr):
31
32
  def __str__(self) -> str:
32
33
  return self.name
33
34
 
34
- def sql_expr(self) -> NoReturn:
35
+ def sql_expr(self, _: SqlElementCache) -> NoReturn:
35
36
  raise NotImplementedError()
36
37
 
37
38
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> NoReturn:
@@ -1,9 +1,9 @@
1
- from typing import Optional, TYPE_CHECKING
1
+ from typing import TYPE_CHECKING, Optional
2
2
 
3
3
  from pixeltable.utils.code import local_public_names
4
4
 
5
5
  if TYPE_CHECKING:
6
- from whisperx.asr import FasterWhisperPipeline
6
+ from whisperx.asr import FasterWhisperPipeline # type: ignore[import-untyped]
7
7
 
8
8
  import pixeltable as pxt
9
9
 
@@ -40,7 +40,7 @@ def transcribe(
40
40
  >>> tbl['result'] = transcribe(tbl.audio, model='tiny.en')
41
41
  """
42
42
  import torch
43
- import whisperx
43
+ import whisperx # type: ignore[import-untyped]
44
44
 
45
45
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
46
46
  compute_type = compute_type or ('float16' if device == 'cuda' else 'int8')
@@ -60,7 +60,7 @@ def _lookup_model(model_id: str, device: str, compute_type: str) -> 'FasterWhisp
60
60
  return _model_cache[key]
61
61
 
62
62
 
63
- _model_cache = {}
63
+ _model_cache: dict[tuple[str, str, str], 'FasterWhisperPipeline'] = {}
64
64
 
65
65
 
66
66
  __all__ = local_public_names(__name__)
@@ -1,10 +1,10 @@
1
1
  import logging
2
2
  from pathlib import Path
3
- from typing import Iterable, Iterator, TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, Iterable, Iterator
4
4
  from urllib.request import urlretrieve
5
5
 
6
- import PIL.Image
7
6
  import numpy as np
7
+ import PIL.Image
8
8
 
9
9
  import pixeltable as pxt
10
10
  from pixeltable import env
@@ -14,8 +14,8 @@ from pixeltable.utils.code import local_public_names
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  import torch
17
- from yolox.exp import Exp
18
- from yolox.models import YOLOX
17
+ from yolox.exp import Exp # type: ignore[import-untyped]
18
+ from yolox.models import YOLOX # type: ignore[import-untyped]
19
19
 
20
20
  _logger = logging.getLogger('pixeltable')
21
21
 
@@ -47,7 +47,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
47
47
  >>> tbl['detections'] = yolox(tbl.image, model_id='yolox_m', threshold=0.8)
48
48
  """
49
49
  import torch
50
- from yolox.utils import postprocess
50
+ from yolox.utils import postprocess # type: ignore[import-untyped]
51
51
 
52
52
  model, exp = _lookup_model(model_id, 'cpu')
53
53
  image_tensors = list(_images_to_tensors(images, exp))
@@ -107,7 +107,7 @@ def yolo_to_coco(detections: dict) -> list:
107
107
 
108
108
  def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: 'Exp') -> Iterator['torch.Tensor']:
109
109
  import torch
110
- from yolox.data import ValTransform
110
+ from yolox.data import ValTransform # type: ignore[import-untyped]
111
111
 
112
112
  _val_transform = ValTransform(legacy=False)
113
113
  for image in images:
@@ -18,6 +18,7 @@ if TYPE_CHECKING:
18
18
  class Aggregator(abc.ABC):
19
19
  def update(self, *args: Any, **kwargs: Any) -> None:
20
20
  pass
21
+
21
22
  def value(self) -> Any:
22
23
  pass
23
24
 
@@ -5,10 +5,10 @@ import importlib
5
5
  import inspect
6
6
  from typing import Any, Callable, Dict, Optional, Tuple
7
7
 
8
+ import sqlalchemy as sql
9
+
8
10
  import pixeltable
9
- import pixeltable.exceptions as excs
10
11
  import pixeltable.type_system as ts
11
-
12
12
  from .globals import resolve_symbol
13
13
  from .signature import Signature
14
14
 
@@ -21,14 +21,29 @@ class Function(abc.ABC):
21
21
  via the member self_path.
22
22
  """
23
23
 
24
- def __init__(self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False):
24
+ signature: Signature
25
+ self_path: Optional[str]
26
+ is_method: bool
27
+ is_property: bool
28
+ _conditional_return_type: Optional[Callable[..., ts.ColumnType]]
29
+
30
+ # Translates a call to this function with the given arguments to its SQLAlchemy equivalent.
31
+ # Overridden for specific Function instances via the to_sql() decorator. The override must accept the same
32
+ # parameter names as the original function. Each parameter is going to be of type sql.ColumnElement.
33
+ _to_sql: Callable[..., Optional[sql.ColumnElement]]
34
+
35
+
36
+ def __init__(
37
+ self, signature: Signature, self_path: Optional[str] = None, is_method: bool = False, is_property: bool = False
38
+ ):
25
39
  # Check that stored functions cannot be declared using `is_method` or `is_property`:
26
40
  assert not ((is_method or is_property) and self_path is None)
27
41
  self.signature = signature
28
42
  self.self_path = self_path # fully-qualified path to self
29
43
  self.is_method = is_method
30
44
  self.is_property = is_property
31
- self._conditional_return_type: Optional[Callable[..., ts.ColumnType]] = None
45
+ self._conditional_return_type = None
46
+ self._to_sql = self.__default_to_sql
32
47
 
33
48
  @property
34
49
  def name(self) -> str:
@@ -88,6 +103,15 @@ class Function(abc.ABC):
88
103
  """Execute the function with the given arguments and return the result."""
89
104
  pass
90
105
 
106
+ def to_sql(self, fn: Callable[..., Optional[sql.ColumnElement]]) -> Callable[..., Optional[sql.ColumnElement]]:
107
+ """Instance decorator for specifying the SQL translation of this function"""
108
+ self._to_sql = fn
109
+ return fn
110
+
111
+ def __default_to_sql(self, *args: Any, **kwargs: Any) -> Optional[sql.ColumnElement]:
112
+ """The default implementation of SQL translation, which provides no translation"""
113
+ return None
114
+
91
115
  def __eq__(self, other: object) -> bool:
92
116
  if not isinstance(other, self.__class__):
93
117
  return False
@@ -1,7 +1,9 @@
1
- from . import audio, fireworks, huggingface, image, json, openai, string, timestamp, together, video, vision
2
- from .globals import *
3
1
  from pixeltable.utils.code import local_public_names
4
2
 
3
+ from . import (anthropic, audio, fireworks, huggingface, image, json, mistralai, openai, string, timestamp, together,
4
+ video, vision)
5
+ from .globals import *
6
+
5
7
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)
6
8
 
7
9
 
@@ -0,0 +1,107 @@
1
+ """
2
+ Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
3
+ that wrap various endpoints from the Anthropic API. In order to use them, you must
4
+ first `pip install anthropic` and configure your Anthropic credentials, as described in
5
+ the [Working with Anthropic](https://pixeltable.readme.io/docs/working-with-anthropic) tutorial.
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
9
+
10
+ import tenacity
11
+
12
+ import pixeltable as pxt
13
+ from pixeltable import env
14
+ from pixeltable.utils.code import local_public_names
15
+
16
+ if TYPE_CHECKING:
17
+ import anthropic
18
+
19
+
20
+ @env.register_client('anthropic')
21
+ def _(api_key: str) -> 'anthropic.Anthropic':
22
+ import anthropic
23
+ return anthropic.Anthropic(api_key=api_key)
24
+
25
+
26
+ def _anthropic_client() -> 'anthropic.Anthropic':
27
+ return env.Env.get().get_client('anthropic')
28
+
29
+
30
+ def _retry(fn: Callable) -> Callable:
31
+ import anthropic
32
+ return tenacity.retry(
33
+ retry=tenacity.retry_if_exception_type(anthropic.RateLimitError),
34
+ wait=tenacity.wait_random_exponential(multiplier=1, max=60),
35
+ stop=tenacity.stop_after_attempt(20),
36
+ )(fn)
37
+
38
+
39
+ @pxt.udf
40
+ def messages(
41
+ messages: list[dict[str, str]],
42
+ *,
43
+ model: str,
44
+ max_tokens: int = 1024,
45
+ metadata: Optional[dict[str, Any]] = None,
46
+ stop_sequences: Optional[list[str]] = None,
47
+ system: Optional[str] = None,
48
+ temperature: Optional[float] = None,
49
+ tool_choice: Optional[list[dict]] = None,
50
+ tools: Optional[dict] = None,
51
+ top_k: Optional[int] = None,
52
+ top_p: Optional[float] = None,
53
+ ) -> dict:
54
+ """
55
+ Create a Message.
56
+
57
+ Equivalent to the Anthropic `messages` API endpoint.
58
+ For additional details, see: <https://docs.anthropic.com/en/api/messages>
59
+
60
+ __Requirements:__
61
+
62
+ - `pip install anthropic`
63
+
64
+ Args:
65
+ messages: Input messages.
66
+ model: The model that will complete your prompt.
67
+
68
+ For details on the other parameters, see: <https://docs.anthropic.com/en/api/messages>
69
+
70
+ Returns:
71
+ A dictionary containing the response and other metadata.
72
+
73
+ Examples:
74
+ Add a computed column that applies the model `claude-3-haiku-20240307`
75
+ to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
76
+
77
+ >>> msgs = [{'role': 'user', 'content': tbl.prompt}]
78
+ ... tbl['response'] = messages(msgs, model='claude-3-haiku-20240307')
79
+ """
80
+ return _retry(_anthropic_client().messages.create)(
81
+ messages=messages,
82
+ model=model,
83
+ max_tokens=max_tokens,
84
+ metadata=_opt(metadata),
85
+ stop_sequences=_opt(stop_sequences),
86
+ system=_opt(system),
87
+ temperature=_opt(temperature),
88
+ tool_choice=_opt(tool_choice),
89
+ tools=_opt(tools),
90
+ top_k=_opt(top_k),
91
+ top_p=_opt(top_p),
92
+ ).dict()
93
+
94
+
95
+ _T = TypeVar('_T')
96
+
97
+
98
+ def _opt(arg: _T) -> Union[_T, 'anthropic.NotGiven']:
99
+ import anthropic
100
+ return arg if arg is not None else anthropic.NOT_GIVEN
101
+
102
+
103
+ __all__ = local_public_names(__name__)
104
+
105
+
106
+ def __dir__():
107
+ return __all__
@@ -12,7 +12,7 @@ from pixeltable import env
12
12
  from pixeltable.utils.code import local_public_names
13
13
 
14
14
  if TYPE_CHECKING:
15
- import fireworks.client
15
+ import fireworks.client # type: ignore[import-untyped]
16
16
 
17
17
 
18
18
  @env.register_client('fireworks')
@@ -60,7 +60,7 @@ def chat_completions(
60
60
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
61
61
 
62
62
  >>> messages = [{'role': 'user', 'content': tbl.prompt}]
63
- ... tbl['response'] = chat_completions(tbl.prompt, model='accounts/fireworks/models/mixtral-8x22b-instruct')
63
+ ... tbl['response'] = chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
64
64
  """
65
65
  kwargs = {'max_tokens': max_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature}
66
66
  kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
@@ -1,4 +1,4 @@
1
- from typing import Optional, Union
1
+ from typing import Optional, Union, Any
2
2
 
3
3
  import pixeltable.func as func
4
4
  import pixeltable.type_system as ts
@@ -25,6 +25,11 @@ class sum(func.Aggregator):
25
25
  def value(self) -> Union[int, float]:
26
26
  return self.sum
27
27
 
28
+ # @sum.to_sql
29
+ # def _(val: 'sqlalchemy.ColumnElements') -> Optional['sqlalchemy.ColumnElements']:
30
+ # import sqlalchemy as sql
31
+ # return sql.sql.functions.sum(val)
32
+
28
33
 
29
34
  @func.uda(update_types=[ts.IntType()], value_type=ts.IntType(), allows_window=True, requires_order_by=False)
30
35
  class count(func.Aggregator):
@@ -332,8 +332,8 @@ def _lookup_processor(model_id: str, create: Callable[[str], T]) -> T:
332
332
  return _processor_cache[key]
333
333
 
334
334
 
335
- _model_cache = {}
336
- _processor_cache = {}
335
+ _model_cache: dict[tuple[str, Callable, Optional[str]], Any] = {}
336
+ _processor_cache: dict[tuple[str, Callable], Any] = {}
337
337
 
338
338
 
339
339
  __all__ = local_public_names(__name__)