pixeltable 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (94) hide show
  1. pixeltable/__init__.py +5 -3
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/__init__.py +1 -0
  4. pixeltable/catalog/catalog.py +335 -128
  5. pixeltable/catalog/column.py +21 -5
  6. pixeltable/catalog/dir.py +19 -6
  7. pixeltable/catalog/insertable_table.py +34 -37
  8. pixeltable/catalog/named_function.py +0 -4
  9. pixeltable/catalog/schema_object.py +28 -42
  10. pixeltable/catalog/table.py +195 -158
  11. pixeltable/catalog/table_version.py +187 -232
  12. pixeltable/catalog/table_version_handle.py +50 -0
  13. pixeltable/catalog/table_version_path.py +49 -33
  14. pixeltable/catalog/view.py +56 -96
  15. pixeltable/config.py +103 -0
  16. pixeltable/dataframe.py +90 -90
  17. pixeltable/env.py +98 -168
  18. pixeltable/exec/aggregation_node.py +5 -4
  19. pixeltable/exec/cache_prefetch_node.py +1 -1
  20. pixeltable/exec/component_iteration_node.py +13 -9
  21. pixeltable/exec/data_row_batch.py +3 -3
  22. pixeltable/exec/exec_context.py +0 -4
  23. pixeltable/exec/exec_node.py +3 -2
  24. pixeltable/exec/expr_eval/schedulers.py +2 -1
  25. pixeltable/exec/in_memory_data_node.py +9 -4
  26. pixeltable/exec/row_update_node.py +1 -2
  27. pixeltable/exec/sql_node.py +20 -16
  28. pixeltable/exprs/column_ref.py +9 -9
  29. pixeltable/exprs/comparison.py +1 -1
  30. pixeltable/exprs/data_row.py +4 -4
  31. pixeltable/exprs/expr.py +20 -5
  32. pixeltable/exprs/function_call.py +98 -58
  33. pixeltable/exprs/json_mapper.py +25 -8
  34. pixeltable/exprs/json_path.py +6 -5
  35. pixeltable/exprs/object_ref.py +16 -5
  36. pixeltable/exprs/row_builder.py +15 -15
  37. pixeltable/exprs/rowid_ref.py +21 -7
  38. pixeltable/func/__init__.py +1 -1
  39. pixeltable/func/function.py +38 -6
  40. pixeltable/func/query_template_function.py +3 -6
  41. pixeltable/func/tools.py +26 -26
  42. pixeltable/func/udf.py +1 -1
  43. pixeltable/functions/__init__.py +2 -0
  44. pixeltable/functions/anthropic.py +9 -3
  45. pixeltable/functions/fireworks.py +7 -4
  46. pixeltable/functions/globals.py +4 -5
  47. pixeltable/functions/huggingface.py +1 -5
  48. pixeltable/functions/image.py +17 -7
  49. pixeltable/functions/llama_cpp.py +1 -1
  50. pixeltable/functions/mistralai.py +1 -1
  51. pixeltable/functions/ollama.py +4 -4
  52. pixeltable/functions/openai.py +26 -23
  53. pixeltable/functions/string.py +23 -30
  54. pixeltable/functions/timestamp.py +11 -6
  55. pixeltable/functions/together.py +14 -12
  56. pixeltable/functions/util.py +1 -1
  57. pixeltable/functions/video.py +5 -4
  58. pixeltable/functions/vision.py +6 -9
  59. pixeltable/functions/whisper.py +3 -3
  60. pixeltable/globals.py +246 -260
  61. pixeltable/index/__init__.py +2 -0
  62. pixeltable/index/base.py +1 -1
  63. pixeltable/index/btree.py +3 -1
  64. pixeltable/index/embedding_index.py +11 -5
  65. pixeltable/io/external_store.py +11 -12
  66. pixeltable/io/label_studio.py +4 -3
  67. pixeltable/io/parquet.py +57 -56
  68. pixeltable/iterators/__init__.py +4 -2
  69. pixeltable/iterators/audio.py +11 -11
  70. pixeltable/iterators/document.py +10 -10
  71. pixeltable/iterators/string.py +1 -2
  72. pixeltable/iterators/video.py +14 -15
  73. pixeltable/metadata/__init__.py +9 -5
  74. pixeltable/metadata/converters/convert_10.py +0 -1
  75. pixeltable/metadata/converters/convert_15.py +0 -2
  76. pixeltable/metadata/converters/convert_23.py +0 -2
  77. pixeltable/metadata/converters/convert_24.py +3 -3
  78. pixeltable/metadata/converters/convert_25.py +1 -1
  79. pixeltable/metadata/converters/convert_27.py +0 -2
  80. pixeltable/metadata/converters/convert_28.py +0 -2
  81. pixeltable/metadata/converters/convert_29.py +7 -8
  82. pixeltable/metadata/converters/util.py +7 -7
  83. pixeltable/metadata/schema.py +27 -19
  84. pixeltable/plan.py +68 -40
  85. pixeltable/share/packager.py +12 -9
  86. pixeltable/store.py +37 -38
  87. pixeltable/type_system.py +41 -28
  88. pixeltable/utils/filecache.py +2 -1
  89. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/METADATA +1 -1
  90. pixeltable-0.3.7.dist-info/RECORD +174 -0
  91. pixeltable-0.3.5.dist-info/RECORD +0 -172
  92. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/LICENSE +0 -0
  93. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/WHEEL +0 -0
  94. {pixeltable-0.3.5.dist-info → pixeltable-0.3.7.dist-info}/entry_points.txt +0 -0
@@ -43,11 +43,11 @@ class JsonPath(Expr):
43
43
  self.id = self._create_id()
44
44
 
45
45
  def __repr__(self) -> str:
46
- # else "R": the anchor is RELATIVE_PATH_ROOT
47
- return (
48
- f'{str(self._anchor) if self._anchor is not None else "R"}'
49
- f'{"." if isinstance(self.path_elements[0], str) else ""}{self._json_path()}'
50
- )
46
+ # else 'R': the anchor is RELATIVE_PATH_ROOT
47
+ anchor_str = str(self._anchor) if self._anchor is not None else 'R'
48
+ if len(self.path_elements) == 0:
49
+ return anchor_str
50
+ return f'{anchor_str}{"." if isinstance(self.path_elements[0], str) else ""}{self._json_path()}'
51
51
 
52
52
  def _as_dict(self) -> dict:
53
53
  path_elements = [[el.start, el.stop, el.step] if isinstance(el, slice) else el for el in self.path_elements]
@@ -158,6 +158,7 @@ class JsonPath(Expr):
158
158
  return ''.join(result)
159
159
 
160
160
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
161
+ assert self._anchor is not None, self
161
162
  val = data_row[self._anchor.slot_idx]
162
163
  if self.compiled_path is not None:
163
164
  val = self.compiled_path.search(val)
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Optional
3
+ from typing import Any, Optional
4
4
 
5
5
  import sqlalchemy as sql
6
6
 
@@ -26,14 +26,22 @@ class ObjectRef(Expr):
26
26
  self.owner = owner
27
27
  self.id = self._create_id()
28
28
 
29
+ def _id_attrs(self) -> list[tuple[str, Any]]:
30
+ # We have no components, so we can't rely on the default behavior here (otherwise, all ObjectRef
31
+ # instances will be conflated into a single slot).
32
+ return [('addr', id(self))]
33
+
34
+ def substitute(self, subs: dict[Expr, Expr]) -> Expr:
35
+ # Just return self; we need to avoid creating a new id after doing the substitution, because otherwise
36
+ # we'll wind up in a situation where the scope_anchor of the enclosing JsonMapper is different from the
37
+ # nested ObjectRefs inside its target_expr (and therefore occupies a different slot_idx).
38
+ return self
39
+
29
40
  def scope(self) -> ExprScope:
30
41
  return self._scope
31
42
 
32
- def __str__(self) -> str:
33
- assert False
34
-
35
43
  def _equals(self, other: ObjectRef) -> bool:
36
- return self.owner is other.owner
44
+ return self.id == other.id
37
45
 
38
46
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
39
47
  return None
@@ -41,3 +49,6 @@ class ObjectRef(Expr):
41
49
  def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
42
50
  # this will be called, but the value has already been materialized elsewhere
43
51
  pass
52
+
53
+ def __repr__(self) -> str:
54
+ return f'ObjectRef({self.owner}, {self.id}, {self.owner.id})'
@@ -7,17 +7,15 @@ from typing import Any, Iterable, Optional, Sequence
7
7
  from uuid import UUID
8
8
 
9
9
  import numpy as np
10
- import sqlalchemy as sql
11
10
 
12
11
  import pixeltable.catalog as catalog
13
12
  import pixeltable.exceptions as excs
14
- import pixeltable.func as func
15
13
  import pixeltable.utils as utils
16
14
  from pixeltable.env import Env
17
15
  from pixeltable.utils.media_store import MediaStore
18
16
 
19
17
  from .data_row import DataRow
20
- from .expr import Expr
18
+ from .expr import Expr, ExprScope
21
19
  from .expr_set import ExprSet
22
20
 
23
21
 
@@ -174,11 +172,13 @@ class RowBuilder:
174
172
 
175
173
  def refs_unstored_iter_col(col_ref: ColumnRef) -> bool:
176
174
  tbl = col_ref.col.tbl
177
- return tbl.is_component_view() and tbl.is_iterator_column(col_ref.col) and not col_ref.col.is_stored
175
+ return (
176
+ tbl.get().is_component_view and tbl.get().is_iterator_column(col_ref.col) and not col_ref.col.is_stored
177
+ )
178
178
 
179
179
  unstored_iter_col_refs = [col_ref for col_ref in col_refs if refs_unstored_iter_col(col_ref)]
180
180
  component_views = [col_ref.col.tbl for col_ref in unstored_iter_col_refs]
181
- unstored_iter_args = {view.id: view.iterator_args.copy() for view in component_views}
181
+ unstored_iter_args = {view.id: view.get().iterator_args.copy() for view in component_views}
182
182
  self.unstored_iter_args = {
183
183
  id: self._record_unique_expr(arg, recursive=True) for id, arg in unstored_iter_args.items()
184
184
  }
@@ -236,13 +236,6 @@ class RowBuilder:
236
236
  """Return ColumnSlotIdx for output columns"""
237
237
  return self.table_columns
238
238
 
239
- def set_conn(self, conn: sql.engine.Connection) -> None:
240
- from .function_call import FunctionCall
241
-
242
- for expr in self.unique_exprs:
243
- if isinstance(expr, FunctionCall) and isinstance(expr.fn, func.QueryTemplateFunction):
244
- expr.fn.set_conn(conn)
245
-
246
239
  @property
247
240
  def num_materialized(self) -> int:
248
241
  return self.next_slot_idx
@@ -299,6 +292,7 @@ class RowBuilder:
299
292
  # this is input and therefore doesn't depend on other exprs
300
293
  continue
301
294
  for d in expr.dependencies():
295
+ assert d.slot_idx is not None, f'{expr}, {d}'
302
296
  if d.slot_idx in excluded_slot_idxs:
303
297
  continue
304
298
  dependencies[expr.slot_idx].add(d.slot_idx)
@@ -376,7 +370,12 @@ class RowBuilder:
376
370
  data_row.set_exc(slot_idx, exc)
377
371
 
378
372
  def eval(
379
- self, data_row: DataRow, ctx: EvalCtx, profile: Optional[ExecProfile] = None, ignore_errors: bool = False
373
+ self,
374
+ data_row: DataRow,
375
+ ctx: EvalCtx,
376
+ profile: Optional[ExecProfile] = None,
377
+ ignore_errors: bool = False,
378
+ force_eval: Optional[ExprScope] = None,
380
379
  ) -> None:
381
380
  """
382
381
  Populates the slots in data_row given in ctx.
@@ -384,10 +383,11 @@ class RowBuilder:
384
383
  and omits any of that expr's dependents's eval().
385
384
  profile: if present, populated with execution time of each expr.eval() call; indexed by expr.slot_idx
386
385
  ignore_errors: if False, raises ExprEvalError if any expr.eval() raises an exception
386
+ force_eval: forces exprs in the specified scope to be reevaluated, even if they already have a value
387
387
  """
388
388
  for expr in ctx.exprs:
389
389
  assert expr.slot_idx >= 0
390
- if data_row.has_val[expr.slot_idx] or data_row.has_exc(expr.slot_idx):
390
+ if expr.scope() != force_eval and (data_row.has_val[expr.slot_idx] or data_row.has_exc(expr.slot_idx)):
391
391
  continue
392
392
  try:
393
393
  start_time = time.perf_counter()
@@ -425,7 +425,7 @@ class RowBuilder:
425
425
  else:
426
426
  if col.col_type.is_image_type() and data_row.file_urls[slot_idx] is None:
427
427
  # we have yet to store this image
428
- filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.version))
428
+ filepath = str(MediaStore.prepare_media_path(col.tbl.id, col.id, col.tbl.get().version))
429
429
  data_row.flush_img(slot_idx, filepath)
430
430
  val = data_row.get_stored_val(slot_idx, col.sa_col.type)
431
431
  table_row[col.store_name()] = val
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Optional
3
+ from typing import TYPE_CHECKING, Any, Optional, cast
4
4
  from uuid import UUID
5
5
 
6
6
  import sqlalchemy as sql
@@ -13,6 +13,9 @@ from .expr import Expr
13
13
  from .row_builder import RowBuilder
14
14
  from .sql_element_cache import SqlElementCache
15
15
 
16
+ if TYPE_CHECKING:
17
+ from pixeltable import store
18
+
16
19
 
17
20
  class RowidRef(Expr):
18
21
  """A reference to a part of a table rowid
@@ -23,9 +26,15 @@ class RowidRef(Expr):
23
26
  (with and without a TableVersion).
24
27
  """
25
28
 
29
+ tbl: Optional[catalog.TableVersionHandle]
30
+ normalized_base: Optional[catalog.TableVersionHandle]
31
+ tbl_id: UUID
32
+ normalized_base_id: UUID
33
+ rowid_component_idx: int
34
+
26
35
  def __init__(
27
36
  self,
28
- tbl: catalog.TableVersion,
37
+ tbl: catalog.TableVersionHandle,
29
38
  idx: int,
30
39
  tbl_id: Optional[UUID] = None,
31
40
  normalized_base_id: Optional[UUID] = None,
@@ -37,8 +46,8 @@ class RowidRef(Expr):
37
46
  # (which has the same values as all its descendent views)
38
47
  normalized_base = tbl
39
48
  # don't try to reference tbl.store_tbl here
40
- while normalized_base.base is not None and normalized_base.base.num_rowid_columns() > idx:
41
- normalized_base = normalized_base.base
49
+ while normalized_base.get().base is not None and normalized_base.get().base.get().num_rowid_columns() > idx:
50
+ normalized_base = normalized_base.get().base
42
51
  self.normalized_base = normalized_base
43
52
  else:
44
53
  self.normalized_base = None
@@ -66,8 +75,13 @@ class RowidRef(Expr):
66
75
 
67
76
  def __repr__(self) -> str:
68
77
  # check if this is the pos column of a component view
69
- tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
70
- if tbl.is_component_view() and self.rowid_component_idx == tbl.store_tbl.pos_col_idx: # type: ignore[attr-defined]
78
+ from pixeltable import store
79
+
80
+ tbl = self.tbl.get() if self.tbl is not None else catalog.Catalog.get().get_tbl_version(self.tbl_id, None)
81
+ if (
82
+ tbl.is_component_view
83
+ and self.rowid_component_idx == cast(store.StoreComponentView, tbl.store_tbl).pos_col_idx
84
+ ):
71
85
  return catalog.globals._POS_COLUMN_NAME
72
86
  return ''
73
87
 
@@ -85,7 +99,7 @@ class RowidRef(Expr):
85
99
  self.tbl_id = self.tbl.id
86
100
 
87
101
  def sql_expr(self, _: SqlElementCache) -> Optional[sql.ColumnElement]:
88
- tbl = self.tbl if self.tbl is not None else catalog.Catalog.get().tbl_versions[(self.tbl_id, None)]
102
+ tbl = self.tbl.get() if self.tbl is not None else catalog.Catalog.get().get_tbl_version(self.tbl_id, None)
89
103
  rowid_cols = tbl.store_tbl.rowid_columns()
90
104
  return rowid_cols[self.rowid_component_idx]
91
105
 
@@ -1,7 +1,7 @@
1
1
  from .aggregate_function import AggregateFunction, Aggregator, uda
2
2
  from .callable_function import CallableFunction
3
3
  from .expr_template_function import ExprTemplateFunction
4
- from .function import Function
4
+ from .function import Function, InvalidFunction
5
5
  from .function_registry import FunctionRegistry
6
6
  from .query_template_function import QueryTemplateFunction, query
7
7
  from .signature import Batch, Parameter, Signature
@@ -62,7 +62,6 @@ class Function(ABC):
62
62
  # Check that stored functions cannot be declared using `is_method` or `is_property`:
63
63
  assert not ((is_method or is_property) and self_path is None)
64
64
  assert isinstance(signatures, list)
65
- assert len(signatures) > 0
66
65
  self.signatures = signatures
67
66
  self.self_path = self_path # fully-qualified path to self
68
67
  self.is_method = is_method
@@ -72,6 +71,10 @@ class Function(ABC):
72
71
  self._to_sql = self.__default_to_sql
73
72
  self._resource_pool = self.__default_resource_pool
74
73
 
74
+ @property
75
+ def is_valid(self) -> bool:
76
+ return len(self.signatures) > 0
77
+
75
78
  @property
76
79
  def name(self) -> str:
77
80
  assert self.self_path is not None
@@ -468,11 +471,18 @@ class Function(ABC):
468
471
  @classmethod
469
472
  def _from_dict(cls, d: dict) -> Function:
470
473
  """Default deserialization: load the symbol indicated by the stored symbol_path"""
471
- assert 'path' in d and d['path'] is not None
472
- assert 'signature' in d and d['signature'] is not None
473
- instance = resolve_symbol(d['path'])
474
- assert isinstance(instance, Function)
475
- return instance
474
+ path = d.get('path')
475
+ assert path is not None
476
+ try:
477
+ instance = resolve_symbol(path)
478
+ if isinstance(instance, Function):
479
+ return instance
480
+ else:
481
+ return InvalidFunction(
482
+ path, d, f'the symbol {path!r} is no longer a UDF. (Was the `@pxt.udf` decorator removed?)'
483
+ )
484
+ except (AttributeError, ImportError):
485
+ return InvalidFunction(path, d, f'the symbol {path!r} no longer exists. (Was the UDF moved or renamed?)')
476
486
 
477
487
  def to_store(self) -> tuple[dict, bytes]:
478
488
  """
@@ -490,3 +500,25 @@ class Function(ABC):
490
500
  Create a Function instance from the serialized representation returned by to_store()
491
501
  """
492
502
  raise NotImplementedError()
503
+
504
+
505
+ class InvalidFunction(Function):
506
+ fn_dict: dict[str, Any]
507
+ errormsg: str
508
+
509
+ def __init__(self, self_path: str, fn_dict: dict[str, Any], errormsg: str):
510
+ super().__init__([], self_path)
511
+ self.fn_dict = fn_dict
512
+ self.errormsg = errormsg
513
+
514
+ def _as_dict(self) -> dict:
515
+ """
516
+ Here we write out (verbatim) the original metadata that failed to load (and that resulted in the
517
+ InvalidFunction). Note that the InvalidFunction itself is never serlialized, so there is no corresponding
518
+ from_dict() method.
519
+ """
520
+ return self.fn_dict
521
+
522
+ @property
523
+ def is_async(self) -> bool:
524
+ return False
@@ -21,7 +21,7 @@ class QueryTemplateFunction(Function):
21
21
 
22
22
  template_df: Optional['DataFrame']
23
23
  self_name: Optional[str]
24
- conn: Optional[sql.engine.Connection]
24
+ # conn: Optional[sql.engine.Connection]
25
25
  defaults: dict[str, exprs.Literal]
26
26
 
27
27
  @classmethod
@@ -53,7 +53,7 @@ class QueryTemplateFunction(Function):
53
53
  # if we're running as part of an ongoing update operation, we need to use the same connection, otherwise
54
54
  # we end up with a deadlock
55
55
  # TODO: figure out a more general way to make execution state available
56
- self.conn = None
56
+ # self.conn = None
57
57
 
58
58
  # convert defaults to Literals
59
59
  self.defaults = {} # key: param name, value: default value converted to a Literal
@@ -67,9 +67,6 @@ class QueryTemplateFunction(Function):
67
67
  def _update_as_overload_resolution(self, signature_idx: int) -> None:
68
68
  pass # only one signature supported for QueryTemplateFunction
69
69
 
70
- def set_conn(self, conn: Optional[sql.engine.Connection]) -> None:
71
- self.conn = conn
72
-
73
70
  @property
74
71
  def is_async(self) -> bool:
75
72
  return True
@@ -82,7 +79,7 @@ class QueryTemplateFunction(Function):
82
79
  {param_name: default for param_name, default in self.defaults.items() if param_name not in bound_args}
83
80
  )
84
81
  bound_df = self.template_df.bind(bound_args)
85
- result = await bound_df._acollect(self.conn)
82
+ result = await bound_df._acollect()
86
83
  return list(result)
87
84
 
88
85
  @property
pixeltable/func/tools.py CHANGED
@@ -48,22 +48,27 @@ class Tool(pydantic.BaseModel):
48
48
  'additionalProperties': False, # TODO Handle kwargs?
49
49
  }
50
50
 
51
- # `tool_calls` must be in standardized tool invocation format:
52
- # {tool_name: {'args': {name1: value1, name2: value2, ...}}, ...}
53
- def invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.FunctionCall':
54
- kwargs = {param.name: self.__extract_tool_arg(param, tool_calls) for param in self.parameters.values()}
55
- return self.fn(**kwargs)
51
+ # The output of `tool_calls` must be a dict in standardized tool invocation format:
52
+ # {tool_name: [{'args': {name1: value1, name2: value2, ...}}, ...], ...}
53
+ def invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.Expr':
54
+ from pixeltable import exprs
56
55
 
57
- def __extract_tool_arg(self, param: Parameter, tool_calls: 'exprs.Expr') -> 'exprs.Expr':
58
56
  func_name = self.name or self.fn.name
57
+ return exprs.JsonMapper(tool_calls[func_name]['*'], self.__invoke_kwargs(exprs.RELATIVE_PATH_ROOT.args))
58
+
59
+ def __invoke_kwargs(self, kwargs: 'exprs.Expr') -> 'exprs.FunctionCall':
60
+ kwargs = {param.name: self.__extract_tool_arg(param, kwargs) for param in self.parameters.values()}
61
+ return self.fn(**kwargs)
62
+
63
+ def __extract_tool_arg(self, param: Parameter, kwargs: 'exprs.Expr') -> 'exprs.FunctionCall':
59
64
  if param.col_type.is_string_type():
60
- return _extract_str_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
65
+ return _extract_str_tool_arg(kwargs, param_name=param.name)
61
66
  if param.col_type.is_int_type():
62
- return _extract_int_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
67
+ return _extract_int_tool_arg(kwargs, param_name=param.name)
63
68
  if param.col_type.is_float_type():
64
- return _extract_float_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
69
+ return _extract_float_tool_arg(kwargs, param_name=param.name)
65
70
  if param.col_type.is_bool_type():
66
- return _extract_bool_tool_arg(tool_calls, func_name=func_name, param_name=param.name)
71
+ return _extract_bool_tool_arg(kwargs, param_name=param.name)
67
72
  assert False
68
73
 
69
74
 
@@ -113,34 +118,29 @@ class Tools(pydantic.BaseModel):
113
118
 
114
119
 
115
120
  @udf
116
- def _extract_str_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[str]:
117
- return _extract_arg(str, tool_calls, func_name, param_name)
121
+ def _extract_str_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[str]:
122
+ return _extract_arg(str, kwargs, param_name)
118
123
 
119
124
 
120
125
  @udf
121
- def _extract_int_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[int]:
122
- return _extract_arg(int, tool_calls, func_name, param_name)
126
+ def _extract_int_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[int]:
127
+ return _extract_arg(int, kwargs, param_name)
123
128
 
124
129
 
125
130
  @udf
126
- def _extract_float_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[float]:
127
- return _extract_arg(float, tool_calls, func_name, param_name)
131
+ def _extract_float_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[float]:
132
+ return _extract_arg(float, kwargs, param_name)
128
133
 
129
134
 
130
135
  @udf
131
- def _extract_bool_tool_arg(tool_calls: dict[str, Any], func_name: str, param_name: str) -> Optional[bool]:
132
- return _extract_arg(bool, tool_calls, func_name, param_name)
136
+ def _extract_bool_tool_arg(kwargs: dict[str, Any], param_name: str) -> Optional[bool]:
137
+ return _extract_arg(bool, kwargs, param_name)
133
138
 
134
139
 
135
140
  T = TypeVar('T')
136
141
 
137
142
 
138
- def _extract_arg(
139
- eval_fn: Callable[[Any], T], tool_calls: dict[str, Any], func_name: str, param_name: str
140
- ) -> Optional[T]:
141
- if func_name in tool_calls:
142
- arguments = tool_calls[func_name]['args']
143
- if param_name in arguments:
144
- return eval_fn(arguments[param_name])
145
- return None
143
+ def _extract_arg(eval_fn: Callable[[Any], T], kwargs: dict[str, Any], param_name: str) -> Optional[T]:
144
+ if param_name in kwargs:
145
+ return eval_fn(kwargs[param_name])
146
146
  return None
pixeltable/func/udf.py CHANGED
@@ -268,7 +268,7 @@ def from_table(
268
268
  params: list[Parameter] = []
269
269
 
270
270
  for t in ancestors:
271
- for name, col in t._tbl_version.cols_by_name.items():
271
+ for name, col in t._tbl_version.get().cols_by_name.items():
272
272
  assert name not in result_dict, f'Column name is not unique: {name}'
273
273
  if col.is_computed:
274
274
  # Computed column. Apply any existing substitutions and add the new expression to the subst dict.
@@ -1,3 +1,5 @@
1
+ # ruff: noqa: F401
2
+
1
3
  from pixeltable.utils.code import local_public_names
2
4
 
3
5
  from . import (
@@ -213,9 +213,15 @@ def invoke_tools(tools: Tools, response: exprs.Expr) -> exprs.InlineDict:
213
213
  @pxt.udf
214
214
  def _anthropic_response_to_pxt_tool_calls(response: dict) -> Optional[dict]:
215
215
  anthropic_tool_calls = [r for r in response['content'] if r['type'] == 'tool_use']
216
- if len(anthropic_tool_calls) > 0:
217
- return {tool_call['name']: {'args': tool_call['input']} for tool_call in anthropic_tool_calls}
218
- return None
216
+ if len(anthropic_tool_calls) == 0:
217
+ return None
218
+ pxt_tool_calls: dict[str, list[dict[str, Any]]] = {}
219
+ for tool_call in anthropic_tool_calls:
220
+ tool_name = tool_call['name']
221
+ if tool_name not in pxt_tool_calls:
222
+ pxt_tool_calls[tool_name] = []
223
+ pxt_tool_calls[tool_name].append({'args': tool_call['input']})
224
+ return pxt_tool_calls
219
225
 
220
226
 
221
227
  _T = TypeVar('_T')
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Optional
9
9
 
10
10
  import pixeltable as pxt
11
11
  from pixeltable import env
12
+ from pixeltable.config import Config
12
13
  from pixeltable.utils.code import local_public_names
13
14
 
14
15
  if TYPE_CHECKING:
@@ -41,7 +42,7 @@ async def chat_completions(
41
42
  Creates a model response for the given chat conversation.
42
43
 
43
44
  Equivalent to the Fireworks AI `chat/completions` API endpoint.
44
- For additional details, see: [https://docs.fireworks.ai/api-reference/post-chatcompletions](https://docs.fireworks.ai/api-reference/post-chatcompletions)
45
+ For additional details, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
45
46
 
46
47
  Request throttling:
47
48
  Applies the rate limit set in the config (section `fireworks`, key `rate_limit`). If no rate
@@ -55,7 +56,7 @@ async def chat_completions(
55
56
  messages: A list of messages comprising the conversation so far.
56
57
  model: The name of the model to use.
57
58
 
58
- For details on the other parameters, see: [https://docs.fireworks.ai/api-reference/post-chatcompletions](https://docs.fireworks.ai/api-reference/post-chatcompletions)
59
+ For details on the other parameters, see: <https://docs.fireworks.ai/api-reference/post-chatcompletions>
59
60
 
60
61
  Returns:
61
62
  A dictionary containing the response and other metadata.
@@ -65,7 +66,9 @@ async def chat_completions(
65
66
  to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
66
67
 
67
68
  >>> messages = [{'role': 'user', 'content': tbl.prompt}]
68
- ... tbl.add_computed_column(response=chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct'))
69
+ ... tbl.add_computed_column(
70
+ ... response=chat_completions(messages, model='accounts/fireworks/models/mixtral-8x22b-instruct')
71
+ ... )
69
72
  """
70
73
  kwargs = {'max_tokens': max_tokens, 'top_k': top_k, 'top_p': top_p, 'temperature': temperature}
71
74
  kwargs_not_none = {k: v for k, v in kwargs.items() if v is not None}
@@ -75,7 +78,7 @@ async def chat_completions(
75
78
  # res_sync_dict = res_sync.dict()
76
79
 
77
80
  if request_timeout is None:
78
- request_timeout = env.Env.get().config.get_int_value('timeout', section='fireworks') or 600
81
+ request_timeout = Config.get().get_int_value('timeout', section='fireworks') or 600
79
82
  # TODO: this timeout doesn't really work, I think it only applies to returning the stream, but not to the timing
80
83
  # of the chunks; addressing this would require a timeout for the task running this udf
81
84
  stream = _fireworks_client().chat.completions.acreate(
@@ -6,9 +6,8 @@ from typing import Optional, Union
6
6
 
7
7
  import sqlalchemy as sql
8
8
 
9
- import pixeltable.func as func
10
9
  import pixeltable.type_system as ts
11
- from pixeltable import exprs
10
+ from pixeltable import exprs, func
12
11
  from pixeltable.utils.code import local_public_names
13
12
 
14
13
 
@@ -50,7 +49,6 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
50
49
  @func.uda(
51
50
  allows_window=True,
52
51
  # Allow counting non-null values of any type
53
- # TODO: I couldn't include "Array" because we don't have a way to represent a generic array (of arbitrary dimension).
54
52
  # TODO: should we have an "Any" type that can be used here?
55
53
  type_substitutions=tuple(
56
54
  {T: Optional[t]} # type: ignore[misc]
@@ -60,6 +58,7 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
60
58
  ts.Float,
61
59
  ts.Bool,
62
60
  ts.Timestamp,
61
+ ts.Array,
63
62
  ts.Json,
64
63
  ts.Image,
65
64
  ts.Video,
@@ -107,7 +106,7 @@ class min(func.Aggregator, typing.Generic[T]):
107
106
 
108
107
  @min.to_sql
109
108
  def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
110
- if val.type.python_type == bool:
109
+ if val.type.python_type is bool:
111
110
  # TODO: min/max aggregation of booleans is not supported in Postgres (but it is in Python).
112
111
  # Right now we simply force the computation to be done in Python; we might consider implementing an alternate
113
112
  # way of doing it in SQL. (min/max of booleans is simply logical and/or, respectively.)
@@ -137,7 +136,7 @@ class max(func.Aggregator, typing.Generic[T]):
137
136
 
138
137
  @max.to_sql
139
138
  def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
140
- if val.type.python_type == bool:
139
+ if val.type.python_type is bool:
141
140
  # TODO: see comment in @min.to_sql.
142
141
  return None
143
142
  return sql.sql.func.max(val)
@@ -12,8 +12,8 @@ from typing import Any, Callable, Optional, TypeVar
12
12
  import PIL.Image
13
13
 
14
14
  import pixeltable as pxt
15
- import pixeltable.env as env
16
15
  import pixeltable.exceptions as excs
16
+ from pixeltable import env
17
17
  from pixeltable.func import Batch
18
18
  from pixeltable.functions.util import normalize_image_mode, resolve_torch_device
19
19
  from pixeltable.utils.code import local_public_names
@@ -50,7 +50,6 @@ def sentence_transformer(
50
50
  """
51
51
  env.Env.get().require_package('sentence_transformers')
52
52
  device = resolve_torch_device('auto')
53
- import torch
54
53
  from sentence_transformers import SentenceTransformer # type: ignore
55
54
 
56
55
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -76,7 +75,6 @@ def _(model_id: str) -> pxt.ArrayType:
76
75
  def sentence_transformer_list(sentences: list, *, model_id: str, normalize_embeddings: bool = False) -> list:
77
76
  env.Env.get().require_package('sentence_transformers')
78
77
  device = resolve_torch_device('auto')
79
- import torch
80
78
  from sentence_transformers import SentenceTransformer
81
79
 
82
80
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -117,7 +115,6 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
117
115
  """
118
116
  env.Env.get().require_package('sentence_transformers')
119
117
  device = resolve_torch_device('auto')
120
- import torch
121
118
  from sentence_transformers import CrossEncoder
122
119
 
123
120
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)
@@ -132,7 +129,6 @@ def cross_encoder(sentences1: Batch[str], sentences2: Batch[str], *, model_id: s
132
129
  def cross_encoder_list(sentence1: str, sentences2: list, *, model_id: str) -> list:
133
130
  env.Env.get().require_package('sentence_transformers')
134
131
  device = resolve_torch_device('auto')
135
- import torch
136
132
  from sentence_transformers import CrossEncoder
137
133
 
138
134
  # specifying the device, moves the model to device (gpu:cuda/mps, cpu)