pixeltable 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (44)
  1. pixeltable/__init__.py +1 -2
  2. pixeltable/__version__.py +2 -2
  3. pixeltable/catalog/column.py +5 -0
  4. pixeltable/catalog/globals.py +16 -0
  5. pixeltable/catalog/insertable_table.py +82 -41
  6. pixeltable/catalog/table.py +78 -55
  7. pixeltable/catalog/table_version.py +18 -3
  8. pixeltable/catalog/view.py +9 -2
  9. pixeltable/env.py +1 -1
  10. pixeltable/exec/exec_node.py +1 -1
  11. pixeltable/exprs/__init__.py +2 -1
  12. pixeltable/exprs/arithmetic_expr.py +2 -0
  13. pixeltable/exprs/column_ref.py +36 -0
  14. pixeltable/exprs/expr.py +39 -9
  15. pixeltable/exprs/globals.py +12 -0
  16. pixeltable/exprs/json_mapper.py +1 -1
  17. pixeltable/exprs/json_path.py +0 -6
  18. pixeltable/exprs/similarity_expr.py +5 -20
  19. pixeltable/exprs/string_op.py +107 -0
  20. pixeltable/ext/functions/yolox.py +21 -64
  21. pixeltable/func/tools.py +2 -2
  22. pixeltable/functions/__init__.py +1 -1
  23. pixeltable/functions/globals.py +16 -5
  24. pixeltable/globals.py +85 -33
  25. pixeltable/io/__init__.py +3 -2
  26. pixeltable/io/datarows.py +138 -0
  27. pixeltable/io/external_store.py +8 -5
  28. pixeltable/io/globals.py +7 -160
  29. pixeltable/io/hf_datasets.py +21 -98
  30. pixeltable/io/pandas.py +29 -43
  31. pixeltable/io/parquet.py +17 -42
  32. pixeltable/io/table_data_conduit.py +569 -0
  33. pixeltable/io/utils.py +6 -21
  34. pixeltable/metadata/__init__.py +1 -1
  35. pixeltable/metadata/converters/convert_30.py +50 -0
  36. pixeltable/metadata/converters/util.py +26 -1
  37. pixeltable/metadata/notes.py +1 -0
  38. pixeltable/metadata/schema.py +3 -0
  39. pixeltable/utils/arrow.py +32 -7
  40. {pixeltable-0.3.9.dist-info → pixeltable-0.3.10.dist-info}/METADATA +1 -1
  41. {pixeltable-0.3.9.dist-info → pixeltable-0.3.10.dist-info}/RECORD +44 -40
  42. {pixeltable-0.3.9.dist-info → pixeltable-0.3.10.dist-info}/WHEEL +1 -1
  43. {pixeltable-0.3.9.dist-info → pixeltable-0.3.10.dist-info}/LICENSE +0 -0
  44. {pixeltable-0.3.9.dist-info → pixeltable-0.3.10.dist-info}/entry_points.txt +0 -0
pixeltable/exprs/column_ref.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import copy
  from typing import Any, Optional, Sequence
  from uuid import UUID

@@ -125,11 +126,46 @@ class ColumnRef(Expr):

          return super().__getattr__(name)

+     @classmethod
+     def find_embedding_index(
+         cls, col: catalog.Column, idx_name: Optional[str], method_name: str
+     ) -> dict[str, catalog.TableVersion.IndexInfo]:
+         """Return IndexInfo for a column, with an optional given name"""
+         # determine index to use
+         idx_info_dict = col.get_idx_info()
+         from pixeltable import index
+
+         embedding_idx_info = {
+             info: value for info, value in idx_info_dict.items() if isinstance(value.idx, index.EmbeddingIndex)
+         }
+         if len(embedding_idx_info) == 0:
+             raise excs.Error(f'No indices found for {method_name!r} on column {col.name!r}')
+         if idx_name is not None and idx_name not in embedding_idx_info:
+             raise excs.Error(f'Index {idx_name!r} not found for {method_name!r} on column {col.name!r}')
+         if len(embedding_idx_info) > 1:
+             if idx_name is None:
+                 raise excs.Error(
+                     f'Column {col.name!r} has multiple indices; use the index name to disambiguate: '
+                     f'`{method_name}(..., idx=<index_name>)`'
+                 )
+             idx_info = {idx_name: embedding_idx_info[idx_name]}
+         else:
+             idx_info = embedding_idx_info
+         return idx_info
+
      def similarity(self, item: Any, *, idx: Optional[str] = None) -> Expr:
          from .similarity_expr import SimilarityExpr

          return SimilarityExpr(self, item, idx_name=idx)

+     def embedding(self, *, idx: Optional[str] = None) -> ColumnRef:
+         idx_info = ColumnRef.find_embedding_index(self.col, idx, 'embedding')
+         assert len(idx_info) == 1
+         col = copy.copy(next(iter(idx_info.values())).val_col)
+         col.name = f'{self.col.name}_embedding_{idx if idx is not None else ""}'
+         col.create_sa_cols()
+         return ColumnRef(col)
+
      def default_column_name(self) -> Optional[str]:
          return str(self)

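A usage sketch for the new embedding() accessor (table and index names are hypothetical): it resolves the column's embedding index via find_embedding_index and returns a ColumnRef to the index's value column, so stored embeddings can be selected like any other column.

    import pixeltable as pxt

    t = pxt.get_table('imgs')                            # hypothetical table with an embedding index on 'img'
    t.select(t.img.embedding()).collect()                # works unqualified when there is exactly one index
    t.select(t.img.embedding(idx='clip_idx')).collect()  # pass idx=... when several indices exist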
 
pixeltable/exprs/expr.py CHANGED
@@ -17,7 +17,7 @@ from typing_extensions import Self, _AnnotatedAlias
  from pixeltable import catalog, exceptions as excs, func, type_system as ts

  from .data_row import DataRow
- from .globals import ArithmeticOperator, ComparisonOperator, LiteralPythonTypes, LogicalOperator
+ from .globals import ArithmeticOperator, ComparisonOperator, LiteralPythonTypes, LogicalOperator, StringOperator

  if TYPE_CHECKING:
      from pixeltable import exprs
@@ -605,10 +605,6 @@ class Expr(abc.ABC):
          # Return the `MethodRef` object itself; it requires arguments to become a `FunctionCall`
          return method_ref

-     def __rshift__(self, other: object) -> 'exprs.Expr':
-         # Implemented here for type-checking purposes
-         raise excs.Error('The `>>` operator can only be applied to Json expressions')
-
      def __bool__(self) -> bool:
          raise TypeError(
              f'Pixeltable expressions cannot be used in conjunction with Python boolean operators (and/or/not)\n{self!r}'
@@ -658,13 +654,17 @@ class Expr(abc.ABC):
      def __neg__(self) -> 'exprs.ArithmeticExpr':
          return self._make_arithmetic_expr(ArithmeticOperator.MUL, -1)

-     def __add__(self, other: object) -> 'exprs.ArithmeticExpr':
+     def __add__(self, other: object) -> Union[exprs.ArithmeticExpr, exprs.StringOp]:
+         if isinstance(self, str) or (isinstance(self, Expr) and self.col_type.is_string_type()):
+             return self._make_string_expr(StringOperator.CONCAT, other)
          return self._make_arithmetic_expr(ArithmeticOperator.ADD, other)

      def __sub__(self, other: object) -> 'exprs.ArithmeticExpr':
          return self._make_arithmetic_expr(ArithmeticOperator.SUB, other)

-     def __mul__(self, other: object) -> 'exprs.ArithmeticExpr':
+     def __mul__(self, other: object) -> Union['exprs.ArithmeticExpr', 'exprs.StringOp']:
+         if isinstance(self, str) or (isinstance(self, Expr) and self.col_type.is_string_type()):
+             return self._make_string_expr(StringOperator.REPEAT, other)
          return self._make_arithmetic_expr(ArithmeticOperator.MUL, other)

      def __truediv__(self, other: object) -> 'exprs.ArithmeticExpr':
@@ -676,13 +676,17 @@ class Expr(abc.ABC):
      def __floordiv__(self, other: object) -> 'exprs.ArithmeticExpr':
          return self._make_arithmetic_expr(ArithmeticOperator.FLOORDIV, other)

-     def __radd__(self, other: object) -> 'exprs.ArithmeticExpr':
+     def __radd__(self, other: object) -> Union['exprs.ArithmeticExpr', 'exprs.StringOp']:
+         if isinstance(other, str) or (isinstance(other, Expr) and other.col_type.is_string_type()):
+             return self._rmake_string_expr(StringOperator.CONCAT, other)
          return self._rmake_arithmetic_expr(ArithmeticOperator.ADD, other)

      def __rsub__(self, other: object) -> 'exprs.ArithmeticExpr':
          return self._rmake_arithmetic_expr(ArithmeticOperator.SUB, other)

-     def __rmul__(self, other: object) -> 'exprs.ArithmeticExpr':
+     def __rmul__(self, other: object) -> Union['exprs.ArithmeticExpr', 'exprs.StringOp']:
+         if isinstance(other, str) or (isinstance(other, Expr) and other.col_type.is_string_type()):
+             return self._rmake_string_expr(StringOperator.REPEAT, other)
          return self._rmake_arithmetic_expr(ArithmeticOperator.MUL, other)

      def __rtruediv__(self, other: object) -> 'exprs.ArithmeticExpr':
@@ -694,6 +698,32 @@ class Expr(abc.ABC):
      def __rfloordiv__(self, other: object) -> 'exprs.ArithmeticExpr':
          return self._rmake_arithmetic_expr(ArithmeticOperator.FLOORDIV, other)

+     def _make_string_expr(self, op: StringOperator, other: object) -> 'exprs.StringOp':
+         """
+         Make left-handed version of string expression.
+         """
+         from .literal import Literal
+         from .string_op import StringOp
+
+         if isinstance(other, Expr):
+             return StringOp(op, self, other)
+         if isinstance(other, typing.get_args(LiteralPythonTypes)):
+             return StringOp(op, self, Literal(other))
+         raise TypeError(f'Other must be Expr or literal: {type(other)}')
+
+     def _rmake_string_expr(self, op: StringOperator, other: object) -> 'exprs.StringOp':
+         """
+         Right-handed version of _make_string_expr. other must be a literal; if it were an Expr,
+         the operation would have already been evaluated in its left-handed form.
+         """
+         from .literal import Literal
+         from .string_op import StringOp
+
+         assert not isinstance(other, Expr)  # Else the left-handed form would have evaluated first
+         if isinstance(other, typing.get_args(LiteralPythonTypes)):
+             return StringOp(op, Literal(other), self)
+         raise TypeError(f'Other must be Expr or literal: {type(other)}')
+
      def _make_arithmetic_expr(self, op: ArithmeticOperator, other: object) -> 'exprs.ArithmeticExpr':
          """
          other: Union[Expr, LiteralPythonTypes]
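With these operators in place, `+` on a string expression builds a CONCAT StringOp and `*` with an integer builds a REPEAT StringOp; the reflected forms cover a plain str on the left. A minimal usage sketch, assuming a table `people` with a string column `name` (both names are hypothetical):

    import pixeltable as pxt

    t = pxt.get_table('people')          # hypothetical table with a string column 'name'
    t.select(t.name + '!').collect()     # CONCAT: string column + string literal
    t.select('Mr. ' + t.name).collect()  # __radd__: str literal on the left
    t.select(t.name * 3).collect()       # REPEAT: string column * int literal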
pixeltable/exprs/globals.py CHANGED
@@ -87,3 +87,15 @@ class ArithmeticOperator(enum.Enum):
          if self == self.FLOORDIV:
              return '//'
          raise AssertionError()
+
+
+ class StringOperator(enum.Enum):
+     CONCAT = 0
+     REPEAT = 1
+
+     def __str__(self) -> str:
+         if self == self.CONCAT:
+             return '+'
+         if self == self.REPEAT:
+             return '*'
+         raise AssertionError()
pixeltable/exprs/json_mapper.py CHANGED
@@ -86,7 +86,7 @@ class JsonMapper(Expr):
          return self._src_expr.equals(other._src_expr) and self._target_expr.equals(other._target_expr)

      def __repr__(self) -> str:
-         return f'{self._src_expr} >> {self._target_expr}'
+         return f'map({self._src_expr}, lambda R: {self._target_expr})'

      @property
      def _src_expr(self) -> Expr:
pixeltable/exprs/json_path.py CHANGED
@@ -110,12 +110,6 @@ class JsonPath(Expr):
              return JsonPath(self._anchor, [*self.path_elements, index])
          raise excs.Error(f'Invalid json list index: {index}')

-     def __rshift__(self, other: object) -> 'JsonMapper':
-         rhs_expr = Expr.from_object(other)
-         if rhs_expr is None:
-             raise excs.Error(f'>> requires an expression on the right-hand side, found {type(other)}')
-         return JsonMapper(self, rhs_expr)
-
      def default_column_name(self) -> Optional[str]:
          anchor_name = self._anchor.default_column_name() if self._anchor is not None else ''
          ret_name = f'{anchor_name}.{self._json_path()}'
pixeltable/exprs/similarity_expr.py CHANGED
@@ -23,26 +23,12 @@ class SimilarityExpr(Expr):

          self.components = [col_ref, item_expr]

-         # determine index to use
-         idx_info = col_ref.col.get_idx_info()
          from pixeltable import index

-         embedding_idx_info = {
-             info.name: info for info in idx_info.values() if isinstance(info.idx, index.EmbeddingIndex)
-         }
-         if len(embedding_idx_info) == 0:
-             raise excs.Error(f'No index found for column {col_ref.col!r}')
-         if idx_name is not None and idx_name not in embedding_idx_info:
-             raise excs.Error(f'Index {idx_name!r} not found for column {col_ref.col.name!r}')
-         if len(embedding_idx_info) > 1:
-             if idx_name is None:
-                 raise excs.Error(
-                     f'Column {col_ref.col.name!r} has multiple indices; use the index name to disambiguate: '
-                     f'`{col_ref.col.name}.similarity(..., idx=<name>)`'
-                 )
-             self.idx_info = embedding_idx_info[idx_name]
-         else:
-             self.idx_info = next(iter(embedding_idx_info.values()))
+         # determine index to use
+         idx_dict = ColumnRef.find_embedding_index(col_ref.col, idx_name, 'similarity')
+         assert len(idx_dict) == 1
+         self.idx_info = next(iter(idx_dict.values()))
          idx = self.idx_info.idx
          assert isinstance(idx, index.EmbeddingIndex)

@@ -86,8 +72,7 @@
          return self.idx_info.idx.order_by_clause(self.idx_info.val_col, item, is_asc)

      def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
-         # this should never get called
-         raise AssertionError()
+         raise excs.Error('similarity(): cannot be used in a computed column')

      def _as_dict(self) -> dict:
          return {'idx_name': self.idx_info.name, **super()._as_dict()}
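similarity() now resolves indices through the shared ColumnRef.find_embedding_index, so it accepts the same idx argument as embedding() and raises the same disambiguation error when a column has several embedding indices; eval() raising means it can only appear in queries, not computed columns. A usage sketch (table, column, and index names are hypothetical):

    import pixeltable as pxt

    t = pxt.get_table('docs')                      # hypothetical table with embedding indices on 'text'
    sim = t.text.similarity('transformer architectures', idx='minilm_idx')
    t.order_by(sim, asc=False).limit(5).collect()  # nearest neighbors via the named index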
pixeltable/exprs/string_op.py ADDED
@@ -0,0 +1,107 @@
+ from __future__ import annotations
+
+ from typing import Any, Optional, Union
+
+ import sqlalchemy as sql
+
+ import pixeltable.exceptions as excs
+ import pixeltable.type_system as ts
+
+ from .data_row import DataRow
+ from .expr import Expr
+ from .globals import StringOperator
+ from .row_builder import RowBuilder
+ from .sql_element_cache import SqlElementCache
+
+
+ class StringOp(Expr):
+     """
+     Allows operations on strings
+     """
+
+     operator: StringOperator
+
+     def __init__(self, operator: StringOperator, op1: Expr, op2: Expr):
+         super().__init__(ts.StringType(nullable=op1.col_type.nullable))
+         self.operator = operator
+         self.components = [op1, op2]
+         assert op1.col_type.is_string_type()
+         if operator in {StringOperator.CONCAT, StringOperator.REPEAT}:
+             if operator == StringOperator.CONCAT and not op2.col_type.is_string_type():
+                 raise excs.Error(
+                     f'{self}: {operator} on strings requires string type, but {op2} has type {op2.col_type}'
+                 )
+             if operator == StringOperator.REPEAT and not op2.col_type.is_int_type():
+                 raise excs.Error(f'{self}: {operator} on strings requires int type, but {op2} has type {op2.col_type}')
+         else:
+             raise excs.Error(
+                 f'{self}: invalid operation {operator} on strings; '
+                 f'only operators {StringOperator.CONCAT} and {StringOperator.REPEAT} are supported'
+             )
+         self.id = self._create_id()
+
+     @property
+     def _op1(self) -> Expr:
+         return self.components[0]
+
+     @property
+     def _op2(self) -> Expr:
+         return self.components[1]
+
+     def __repr__(self) -> str:
+         # add parentheses around operands that are StringOpExpr to express precedence
+         op1_str = f'({self._op1})' if isinstance(self._op1, StringOp) else str(self._op1)
+         op2_str = f'({self._op2})' if isinstance(self._op2, StringOp) else str(self._op2)
+         return f'{op1_str} {self.operator} {op2_str}'
+
+     def _equals(self, other: StringOp) -> bool:
+         return self.operator == other.operator
+
+     def _id_attrs(self) -> list[tuple[str, Any]]:
+         return [*super()._id_attrs(), ('operator', self.operator.value)]
+
+     def sql_expr(self, sql_elements: SqlElementCache) -> Optional[sql.ColumnElement]:
+         left = sql_elements.get(self._op1)
+         right = sql_elements.get(self._op2)
+         if left is None or right is None:
+             return None
+         if self.operator == StringOperator.CONCAT:
+             return left.concat(right)
+         if self.operator == StringOperator.REPEAT:
+             return sql.func.repeat(sql.cast(left, sql.String), sql.cast(right, sql.Integer))
+         return None
+
+     def eval(self, data_row: DataRow, row_builder: RowBuilder) -> None:
+         op1_val = data_row[self._op1.slot_idx]
+         op2_val = data_row[self._op2.slot_idx]
+         data_row[self.slot_idx] = self.eval_nullable(op1_val, op2_val)
+
+     def eval_nullable(self, op1_val: Union[str, None], op2_val: Union[int, str, None]) -> Union[str, None]:
+         """
+         Return the result of evaluating the expression on two nullable operands;
+         None is interpreted as SQL NULL
+         """
+         if op1_val is None or op2_val is None:
+             return None
+         return self.eval_non_null(op1_val, op2_val)
+
+     def eval_non_null(self, op1_val: str, op2_val: Union[int, str]) -> str:
+         """
+         Return the result of evaluating the expression on two non-null operands
+         """
+         assert self.operator in {StringOperator.CONCAT, StringOperator.REPEAT}
+         if self.operator == StringOperator.CONCAT:
+             assert isinstance(op2_val, str)
+             return op1_val + op2_val
+         else:
+             assert isinstance(op2_val, int)
+             return op1_val * op2_val
+
+     def _as_dict(self) -> dict:
+         return {'operator': self.operator.value, **super()._as_dict()}
+
+     @classmethod
+     def _from_dict(cls, d: dict, components: list[Expr]) -> StringOp:
+         assert 'operator' in d
+         assert len(components) == 2
+         return cls(StringOperator(d['operator']), components[0], components[1])
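For reference, a short sketch (not taken from the package) of the semantics StringOp implements: eval_non_null mirrors Python's own str operators, and sql_expr relies on standard SQLAlchemy constructs for the pushed-down form; the column name below is illustrative only.

    import sqlalchemy as sql

    # Python-side semantics (what eval_non_null computes):
    assert 'ab' + 'cd' == 'abcd'   # CONCAT
    assert 'ab' * 3 == 'ababab'    # REPEAT

    # SQL-side translation used by sql_expr (illustrative literals):
    name = sql.literal_column('t.name')
    concat_sql = name.concat(sql.literal('!'))   # renders as t.name || '!' on PostgreSQL
    repeat_sql = sql.func.repeat(sql.cast(name, sql.String), sql.cast(sql.literal(3), sql.Integer))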
pixeltable/ext/functions/yolox.py CHANGED
@@ -1,21 +1,15 @@
  import logging
- from pathlib import Path
- from typing import TYPE_CHECKING, Iterable, Iterator
- from urllib.request import urlretrieve
+ from typing import TYPE_CHECKING

- import numpy as np
  import PIL.Image

  import pixeltable as pxt
- from pixeltable import env
  from pixeltable.func import Batch
  from pixeltable.functions.util import normalize_image_mode
  from pixeltable.utils.code import local_public_names

  if TYPE_CHECKING:
-     import torch
-     from yolox.exp import Exp  # type: ignore[import-untyped]
-     from yolox.models import YOLOX  # type: ignore[import-untyped]
+     from yolox.models import Yolox, YoloxProcessor  # type: ignore[import-untyped]

  _logger = logging.getLogger('pixeltable')

@@ -30,7 +24,7 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
 
      __Requirements__:

-     - `pip install git+https://github.com/Megvii-BaseDetection/YOLOX`
+     - `pip install pixeltable-yolox`

      Args:
          model_id: one of: `yolox_nano`, `yolox_tiny`, `yolox_s`, `yolox_m`, `yolox_l`, `yolox_x`
@@ -46,31 +40,14 @@ def yolox(images: Batch[PIL.Image.Image], *, model_id: str, threshold: float = 0
      >>> tbl.add_computed_column(detections=yolox(tbl.image, model_id='yolox_m', threshold=0.8))
      """
      import torch
-     from yolox.utils import postprocess  # type: ignore[import-untyped]
-
-     model, exp = _lookup_model(model_id, 'cpu')
-     image_tensors = list(_images_to_tensors(images, exp))
-     batch_tensor = torch.stack(image_tensors)

+     model = _lookup_model(model_id, 'cpu')
+     processor = _lookup_processor(model_id)
+     normalized_images = [normalize_image_mode(image) for image in images]
      with torch.no_grad():
-         output_tensor = model(batch_tensor)
-
-     outputs = postprocess(output_tensor, 80, threshold, exp.nmsthre, class_agnostic=False)
-
-     results: list[dict] = []
-     for image in images:
-         ratio = min(exp.test_size[0] / image.height, exp.test_size[1] / image.width)
-         if outputs[0] is None:
-             results.append({'bboxes': [], 'scores': [], 'labels': []})
-         else:
-             results.append(
-                 {
-                     'bboxes': [(output[:4] / ratio).tolist() for output in outputs[0]],
-                     'scores': [output[4].item() * output[5].item() for output in outputs[0]],
-                     'labels': [int(output[6]) for output in outputs[0]],
-                 }
-             )
-     return results
+         tensor = processor(normalized_images)
+         output = model(tensor)
+     return processor.postprocess(normalized_images, output, threshold=threshold)


@@ -107,47 +84,27 @@ def yolo_to_coco(detections: dict) -> list:
      return result


- def _images_to_tensors(images: Iterable[PIL.Image.Image], exp: 'Exp') -> Iterator['torch.Tensor']:
-     import torch
-     from yolox.data import ValTransform  # type: ignore[import-untyped]
-
-     val_transform = ValTransform(legacy=False)
-     for image in images:
-         normalized_image = normalize_image_mode(image)
-         image_transform, _ = val_transform(np.array(normalized_image), None, exp.test_size)
-         yield torch.from_numpy(image_transform)
-
-
- def _lookup_model(model_id: str, device: str) -> tuple['YOLOX', 'Exp']:
-     import torch
-     from yolox.exp import get_exp
+ def _lookup_model(model_id: str, device: str) -> 'Yolox':
+     from yolox.models import Yolox

      key = (model_id, device)
-     if key in _model_cache:
-         return _model_cache[key]
+     if key not in _model_cache:
+         _model_cache[key] = Yolox.from_pretrained(model_id, device=device)

-     weights_url = f'https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/{model_id}.pth'
-     weights_file = Path(f'{env.Env.get().tmp_dir}/{model_id}.pth')
-     if not weights_file.exists():
-         _logger.info(f'Downloading weights for YOLOX model {model_id}: from {weights_url} -> {weights_file}')
-         urlretrieve(weights_url, weights_file)
+     return _model_cache[key]

-     exp = get_exp(exp_name=model_id)
-     model = exp.get_model().to(device)

-     model.eval()
-     model.head.training = False
-     model.training = False
+ def _lookup_processor(model_id: str) -> 'YoloxProcessor':
+     from yolox.models import YoloxProcessor

-     # Load in the weights from training
-     weights = torch.load(weights_file, map_location=torch.device(device))
-     model.load_state_dict(weights['model'])
+     if model_id not in _processor_cache:
+         _processor_cache[model_id] = YoloxProcessor(model_id)

-     _model_cache[key] = (model, exp)
-     return model, exp
+     return _processor_cache[model_id]


- _model_cache: dict[tuple[str, str], tuple['YOLOX', 'Exp']] = {}
+ _model_cache: dict[tuple[str, str], 'Yolox'] = {}
+ _processor_cache: dict[str, 'YoloxProcessor'] = {}


  __all__ = local_public_names(__name__)
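The rewritten helpers delegate entirely to the pixeltable-yolox package. A standalone sketch of that API as used above, outside of Pixeltable (model id from the docstring; the image path is hypothetical):

    import PIL.Image
    import torch
    from yolox.models import Yolox, YoloxProcessor  # provided by pixeltable-yolox

    model = Yolox.from_pretrained('yolox_m', device='cpu')
    processor = YoloxProcessor('yolox_m')
    images = [PIL.Image.open('example.jpg')]        # hypothetical input image
    with torch.no_grad():
        tensor = processor(images)
        output = model(tensor)
    detections = processor.postprocess(images, output, threshold=0.5)  # per-image detection dicts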
pixeltable/func/tools.py CHANGED
@@ -51,10 +51,10 @@ class Tool(pydantic.BaseModel):
      # The output of `tool_calls` must be a dict in standardized tool invocation format:
      # {tool_name: [{'args': {name1: value1, name2: value2, ...}}, ...], ...}
      def invoke(self, tool_calls: 'exprs.Expr') -> 'exprs.Expr':
-         from pixeltable import exprs
+         import pixeltable.functions as pxtf

          func_name = self.name or self.fn.name
-         return exprs.JsonMapper(tool_calls[func_name]['*'], self.__invoke_kwargs(exprs.RELATIVE_PATH_ROOT.args))
+         return pxtf.map(tool_calls[func_name]['*'], lambda x: self.__invoke_kwargs(x.args))

      def __invoke_kwargs(self, kwargs: 'exprs.Expr') -> 'exprs.FunctionCall':
          kwargs = {param.name: self.__extract_tool_arg(param, kwargs) for param in self.parameters.values()}
pixeltable/functions/__init__.py CHANGED
@@ -24,7 +24,7 @@ from . import (
      vision,
      whisper,
  )
- from .globals import count, max, mean, min, sum
+ from .globals import count, map, max, mean, min, sum
  __all__ = local_public_names(__name__, exclude=['globals']) + local_public_names(globals.__name__)


pixeltable/functions/globals.py CHANGED
@@ -1,15 +1,14 @@
  import builtins
  import typing
-
- from typing import _GenericAlias  # type: ignore[attr-defined] # isort: skip
- from typing import Optional, Union
+ from typing import Any, Callable, Optional, Union

  import sqlalchemy as sql

- import pixeltable.type_system as ts
- from pixeltable import exprs, func
+ from pixeltable import exceptions as excs, exprs, func, type_system as ts
  from pixeltable.utils.code import local_public_names

+ from typing import _GenericAlias  # type: ignore[attr-defined] # isort: skip
+

  # TODO: remove and replace calls with astype()
  def cast(expr: exprs.Expr, target_type: Union[ts.ColumnType, type, _GenericAlias]) -> exprs.Expr:
@@ -168,6 +167,18 @@ def _(val: sql.ColumnElement) -> Optional[sql.ColumnElement]:
      return sql.sql.func.avg(val)


+ def map(expr: exprs.Expr, fn: Callable[[exprs.Expr], Any]) -> exprs.Expr:
+     target_expr: exprs.Expr
+     try:
+         target_expr = exprs.Expr.from_object(fn(exprs.json_path.RELATIVE_PATH_ROOT))
+     except Exception as e:
+         raise excs.Error(
+             'Failed to evaluate map function. '
+             '(The `fn` argument to `map()` must produce a valid Pixeltable expression.)'
+         ) from e
+     return exprs.JsonMapper(expr, target_expr)
+
+
  __all__ = local_public_names(__name__)

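A usage sketch of the new map(), which replaces the removed `>>`/RELATIVE_PATH_ROOT idiom for mapping over JSON collections (table and column names are hypothetical; the pattern matches the invoke() rewrite in tools.py above):

    import pixeltable as pxt
    import pixeltable.functions as pxtf

    t = pxt.get_table('orders')                 # hypothetical table with a JSON column 'items'
    # previously written with the >> operator; now:
    t.select(pxtf.map(t.items['*'], lambda x: x.price)).collect()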