pixeltable 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. See this release's page on the package registry for more details.

Files changed (87):
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +9 -5
  4. pixeltable/catalog/insertable_table.py +0 -2
  5. pixeltable/catalog/table.py +16 -8
  6. pixeltable/catalog/table_version.py +3 -2
  7. pixeltable/dataframe.py +184 -110
  8. pixeltable/env.py +69 -18
  9. pixeltable/exec/__init__.py +2 -1
  10. pixeltable/exec/data_row_batch.py +6 -7
  11. pixeltable/exec/expr_eval_node.py +28 -28
  12. pixeltable/exec/sql_scan_node.py +7 -6
  13. pixeltable/exprs/__init__.py +4 -3
  14. pixeltable/exprs/column_ref.py +9 -0
  15. pixeltable/exprs/expr.py +15 -7
  16. pixeltable/exprs/function_call.py +17 -15
  17. pixeltable/exprs/image_member_access.py +9 -28
  18. pixeltable/exprs/in_predicate.py +96 -0
  19. pixeltable/exprs/inline_array.py +13 -11
  20. pixeltable/exprs/inline_dict.py +15 -13
  21. pixeltable/exprs/row_builder.py +7 -1
  22. pixeltable/exprs/similarity_expr.py +65 -0
  23. pixeltable/func/__init__.py +0 -2
  24. pixeltable/func/aggregate_function.py +3 -0
  25. pixeltable/func/callable_function.py +57 -13
  26. pixeltable/func/expr_template_function.py +11 -2
  27. pixeltable/func/function.py +35 -4
  28. pixeltable/func/signature.py +5 -15
  29. pixeltable/func/udf.py +6 -10
  30. pixeltable/functions/huggingface.py +23 -4
  31. pixeltable/functions/openai.py +34 -1
  32. pixeltable/functions/pil/image.py +61 -64
  33. pixeltable/functions/together.py +21 -0
  34. pixeltable/globals.py +425 -0
  35. pixeltable/index/base.py +3 -1
  36. pixeltable/index/embedding_index.py +87 -14
  37. pixeltable/io/__init__.py +3 -0
  38. pixeltable/{utils → io}/hf_datasets.py +48 -17
  39. pixeltable/io/pandas.py +148 -0
  40. pixeltable/{utils → io}/parquet.py +58 -33
  41. pixeltable/iterators/__init__.py +1 -1
  42. pixeltable/iterators/base.py +4 -0
  43. pixeltable/iterators/document.py +218 -97
  44. pixeltable/iterators/video.py +8 -9
  45. pixeltable/metadata/__init__.py +7 -3
  46. pixeltable/metadata/converters/convert_12.py +3 -0
  47. pixeltable/metadata/converters/convert_13.py +41 -0
  48. pixeltable/plan.py +2 -19
  49. pixeltable/store.py +2 -2
  50. pixeltable/tool/create_test_db_dump.py +32 -13
  51. pixeltable/type_system.py +13 -54
  52. pixeltable/utils/documents.py +42 -12
  53. pixeltable/utils/http_server.py +70 -0
  54. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/METADATA +10 -7
  55. pixeltable-0.2.6.dist-info/RECORD +119 -0
  56. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  57. pixeltable/client.py +0 -600
  58. pixeltable/exprs/image_similarity_predicate.py +0 -58
  59. pixeltable/func/batched_function.py +0 -53
  60. pixeltable/tests/conftest.py +0 -171
  61. pixeltable/tests/ext/test_yolox.py +0 -21
  62. pixeltable/tests/functions/test_fireworks.py +0 -43
  63. pixeltable/tests/functions/test_functions.py +0 -60
  64. pixeltable/tests/functions/test_huggingface.py +0 -158
  65. pixeltable/tests/functions/test_openai.py +0 -162
  66. pixeltable/tests/functions/test_together.py +0 -112
  67. pixeltable/tests/test_audio.py +0 -65
  68. pixeltable/tests/test_catalog.py +0 -27
  69. pixeltable/tests/test_client.py +0 -21
  70. pixeltable/tests/test_component_view.py +0 -379
  71. pixeltable/tests/test_dataframe.py +0 -440
  72. pixeltable/tests/test_dirs.py +0 -107
  73. pixeltable/tests/test_document.py +0 -120
  74. pixeltable/tests/test_exprs.py +0 -802
  75. pixeltable/tests/test_function.py +0 -332
  76. pixeltable/tests/test_index.py +0 -138
  77. pixeltable/tests/test_migration.py +0 -44
  78. pixeltable/tests/test_nos.py +0 -54
  79. pixeltable/tests/test_snapshot.py +0 -231
  80. pixeltable/tests/test_table.py +0 -1343
  81. pixeltable/tests/test_transactional_directory.py +0 -42
  82. pixeltable/tests/test_types.py +0 -52
  83. pixeltable/tests/test_video.py +0 -159
  84. pixeltable/tests/test_view.py +0 -535
  85. pixeltable/tests/utils.py +0 -442
  86. pixeltable-0.2.5.dist-info/RECORD +0 -139
  87. {pixeltable-0.2.5.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
Removed file: pixeltable/func/batched_function.py
@@ -1,53 +0,0 @@
1
- import inspect
2
- from typing import List, Dict, Any, Optional, Callable
3
- import abc
4
-
5
- from .function import Function
6
- from .signature import Signature
7
-
8
-
9
- class BatchedFunction(Function):
10
- """Base class for functions that can run on batches"""
11
-
12
- @abc.abstractmethod
13
- def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
14
- """Return the batch size for the given arguments, or None if the batch size is unknown.
15
- args/kwargs might be empty
16
- """
17
- raise NotImplementedError
18
-
19
- @abc.abstractmethod
20
- def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
21
- """Invoke the function for the given batch and return a batch of results"""
22
- raise NotImplementedError
23
-
24
-
25
- class ExplicitBatchedFunction(BatchedFunction):
26
- """
27
- A `BatchedFunction` that is defined by a signature and an explicit python
28
- `Callable`.
29
- """
30
- def __init__(self, signature: Signature, batch_size: Optional[int], invoker_fn: Callable, self_path: str):
31
- super().__init__(signature=signature, py_signature=inspect.signature(invoker_fn), self_path=self_path)
32
- self.batch_size = batch_size
33
- self.invoker_fn = invoker_fn
34
-
35
- def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
36
- return self.batch_size
37
-
38
- def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
39
- """Invoke the function for the given batch and return a batch of results"""
40
- constant_param_names = [p.name for p in self.signature.constant_parameters]
41
- kwargs = {k: v[0] for k, v in kwarg_batches.items() if k in constant_param_names}
42
- kwarg_batches = {k: v for k, v in kwarg_batches.items() if k not in constant_param_names}
43
- return self.invoker_fn(*arg_batches, **kwargs, **kwarg_batches)
44
-
45
- def validate_call(self, bound_args: Dict[str, Any]) -> None:
46
- """Verify constant parameters"""
47
- import pixeltable.exprs as exprs
48
- for param in self.signature.constant_parameters:
49
- if param.name in bound_args and isinstance(bound_args[param.name], exprs.Expr):
50
- raise ValueError(
51
- f'{self.display_name}(): '
52
- f'parameter {param.name} must be a constant value, not a Pixeltable expression'
53
- )
Removed file: pixeltable/tests/conftest.py
@@ -1,171 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- from typing import List
6
-
7
- import numpy as np
8
- import pytest
9
- import PIL.Image
10
-
11
- import pixeltable as pxt
12
- import pixeltable.catalog as catalog
13
- from pixeltable import exprs
14
- import pixeltable.functions as pxtf
15
- from pixeltable.exprs import RELATIVE_PATH_ROOT as R
16
- from pixeltable.metadata import SystemInfo, create_system_info
17
- from pixeltable.metadata.schema import TableSchemaVersion, TableVersion, Table, Function, Dir
18
- from pixeltable.tests.utils import read_data_file, create_test_tbl, create_all_datatypes_tbl, skip_test_if_not_installed
19
- from pixeltable.type_system import StringType, ImageType, FloatType
20
-
21
-
22
- @pytest.fixture(scope='session')
23
- def init_env(tmp_path_factory) -> None:
24
- from pixeltable.env import Env
25
- # set the relevant env vars for Client() to connect to the test db
26
-
27
- shared_home = pathlib.Path(os.environ.get('PIXELTABLE_HOME', str(pathlib.Path.home() / '.pixeltable')))
28
- home_dir = str(tmp_path_factory.mktemp('base') / '.pixeltable')
29
- os.environ['PIXELTABLE_HOME'] = home_dir
30
- os.environ['PIXELTABLE_CONFIG'] = str(shared_home / 'config.yaml')
31
- test_db = 'test'
32
- os.environ['PIXELTABLE_DB'] = test_db
33
- os.environ['PIXELTABLE_PGDATA'] = str(shared_home / 'pgdata')
34
-
35
- # ensure this home dir exits
36
- shared_home.mkdir(parents=True, exist_ok=True)
37
- # this also runs create_all()
38
- Env.get().set_up(echo=True)
39
- yield
40
- # leave db in place for debugging purposes
41
-
42
- @pytest.fixture(scope='function')
43
- def test_client(init_env) -> pxt.Client:
44
- # Clean the DB *before* instantiating a client object. This is because some tests
45
- # (such as test_migration.py) may leave the DB in a broken state, from which the
46
- # client is uninstantiable.
47
- clean_db()
48
- cl = pxt.Client(reload=True)
49
- cl.logging(level=logging.DEBUG, to_stdout=True)
50
- yield cl
51
-
52
-
53
- def clean_db(restore_tables: bool = True) -> None:
54
- from pixeltable.env import Env
55
- # The logic from Client.reset_catalog() has been moved here, so that it
56
- # does not rely on instantiating a Client object. As before, UUID-named data tables will
57
- # not be cleaned. If in the future it is desirable to clean out data tables as well,
58
- # the commented lines may be used to drop ALL tables from the test db.
59
- # sql_md = declarative_base().metadata
60
- # sql_md.reflect(Env.get().engine)
61
- # sql_md.drop_all(bind=Env.get().engine)
62
- engine = Env.get().engine
63
- SystemInfo.__table__.drop(engine, checkfirst=True)
64
- TableSchemaVersion.__table__.drop(engine, checkfirst=True)
65
- TableVersion.__table__.drop(engine, checkfirst=True)
66
- Table.__table__.drop(engine, checkfirst=True)
67
- Function.__table__.drop(engine, checkfirst=True)
68
- Dir.__table__.drop(engine, checkfirst=True)
69
- if restore_tables:
70
- Dir.__table__.create(engine)
71
- Function.__table__.create(engine)
72
- Table.__table__.create(engine)
73
- TableVersion.__table__.create(engine)
74
- TableSchemaVersion.__table__.create(engine)
75
- SystemInfo.__table__.create(engine)
76
- create_system_info(engine)
77
-
78
-
79
- @pytest.fixture(scope='function')
80
- def test_tbl(test_client: pxt.Client) -> catalog.Table:
81
- return create_test_tbl(test_client)
82
-
83
- # @pytest.fixture(scope='function')
84
- # def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
85
- # @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
86
- # def test_fn(x):
87
- # return x + 1
88
- # test_client.create_function('test_fn', test_fn)
89
- # return test_fn
90
-
91
- @pytest.fixture(scope='function')
92
- def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
93
- #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
94
-
95
- t = test_tbl
96
- return [
97
- t.c1,
98
- t.c7['*'].f1,
99
- exprs.Literal('test'),
100
- exprs.InlineDict({
101
- 'a': t.c1, 'b': t.c6.f1, 'c': 17,
102
- 'd': exprs.InlineDict({'e': t.c2}),
103
- 'f': exprs.InlineArray((t.c3, t.c3))
104
- }),
105
- exprs.InlineArray([[t.c2, t.c2], [t.c2, t.c2]]),
106
- t.c2 > 5,
107
- t.c2 == None,
108
- ~(t.c2 > 5),
109
- (t.c2 > 5) & (t.c1 == 'test'),
110
- (t.c2 > 5) | (t.c1 == 'test'),
111
- t.c7['*'].f5 >> [R[3], R[2], R[1], R[0]],
112
- t.c8[0, 1:],
113
- t.c2.astype(FloatType()),
114
- (t.c2 + 1).astype(FloatType()),
115
- t.c2.apply(str),
116
- (t.c2 + 1).apply(str),
117
- t.c3.apply(str),
118
- t.c4.apply(str),
119
- t.c5.apply(str),
120
- t.c6.apply(str),
121
- t.c1.apply(json.loads),
122
- t.c8.errortype,
123
- t.c8.errormsg,
124
- pxtf.sum(t.c2, group_by=t.c4, order_by=t.c3),
125
- ]
126
-
127
- @pytest.fixture(scope='function')
128
- def all_datatypes_tbl(test_client: pxt.Client) -> catalog.Table:
129
- return create_all_datatypes_tbl(test_client)
130
-
131
- @pytest.fixture(scope='function')
132
- def img_tbl(test_client: pxt.Client) -> catalog.Table:
133
- schema = {
134
- 'img': ImageType(nullable=False),
135
- 'category': StringType(nullable=False),
136
- 'split': StringType(nullable=False),
137
- }
138
- # this table is not indexed in order to avoid the cost of computing embeddings
139
- tbl = test_client.create_table('test_img_tbl', schema)
140
- rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
141
- tbl.insert(rows)
142
- return tbl
143
-
144
- @pytest.fixture(scope='function')
145
- def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
146
- img_t = img_tbl
147
- return [
148
- img_t.img.width,
149
- img_t.img.rotate(90),
150
- # we're using a list here, not a tuple; the latter turns into a list during the back/forth conversion
151
- img_t.img.rotate(90).resize([224, 224]),
152
- img_t.img.fileurl,
153
- img_t.img.localpath,
154
- ]
155
-
156
- @pytest.fixture(scope='function')
157
- def small_img_tbl(test_client: pxt.Client) -> catalog.Table:
158
- cl = test_client
159
- schema = {
160
- 'img': ImageType(nullable=False),
161
- 'category': StringType(nullable=False),
162
- 'split': StringType(nullable=False),
163
- }
164
- tbl = cl.create_table('test_indexed_img_tbl', schema)
165
- rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
166
- # select output_rows randomly in the hope of getting a good sample of the available categories
167
- rng = np.random.default_rng(17)
168
- idxs = rng.choice(np.arange(len(rows)), size=40, replace=False)
169
- rows = [rows[i] for i in idxs]
170
- tbl.insert(rows)
171
- return tbl
Removed file: pixeltable/tests/ext/test_yolox.py
@@ -1,21 +0,0 @@
1
- import pixeltable as pxt
2
- from pixeltable.tests.utils import skip_test_if_not_installed, get_image_files, validate_update_status
3
-
4
-
5
- class TestYolox:
6
-
7
- def test_yolox(self, test_client: pxt.Client):
8
- skip_test_if_not_installed('yolox')
9
- from pixeltable.ext.functions.yolox import yolox
10
- cl = test_client
11
- t = cl.create_table('yolox_test', {'image': pxt.ImageType()})
12
- t['detect_yolox_tiny'] = yolox(t.image, model_id='yolox_tiny')
13
- t['detect_yolox_nano'] = yolox(t.image, model_id='yolox_nano', threshold=0.2)
14
- t['yolox_nano_bboxes'] = t.detect_yolox_nano.bboxes
15
- images = get_image_files()[:10]
16
- validate_update_status(t.insert({'image': image} for image in images), expected_rows=10)
17
- rows = t.collect()
18
- # Verify correctly formed JSON
19
- assert all(list(result.keys()) == ['bboxes', 'labels', 'scores'] for result in rows['detect_yolox_tiny'])
20
- # Verify that bboxes are actually present in at least some of the rows.
21
- assert any(len(bboxes) > 0 for bboxes in rows['yolox_nano_bboxes'])
Removed file: pixeltable/tests/functions/test_fireworks.py
@@ -1,43 +0,0 @@
1
- import pytest
2
-
3
- import pixeltable as pxt
4
- import pixeltable.exceptions as excs
5
- from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
-
7
-
8
- @pytest.mark.remote_api
9
- class TestFireworks:
10
-
11
- def test_fireworks(self, test_client: pxt.Client) -> None:
12
- skip_test_if_not_installed('fireworks')
13
- TestFireworks.skip_test_if_no_fireworks_client()
14
- cl = test_client
15
- t = cl.create_table('test_tbl', {'input': pxt.StringType()})
16
- from pixeltable.functions.fireworks import chat_completions
17
- messages = [{'role': 'user', 'content': t.input}]
18
- t['output'] = chat_completions(
19
- messages=messages,
20
- model='accounts/fireworks/models/llama-v2-7b-chat'
21
- )
22
- t['output_2'] = chat_completions(
23
- messages=messages,
24
- model='accounts/fireworks/models/llama-v2-7b-chat',
25
- max_tokens=300,
26
- top_k=40,
27
- top_p=0.9,
28
- temperature=0.7
29
- )
30
- validate_update_status(t.insert(input="How's everything going today?"), 1)
31
- results = t.collect()
32
- assert len(results['output'][0]['choices'][0]['message']['content']) > 0
33
- assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
34
-
35
- # This ensures that the test will be skipped, rather than returning an error, when no API key is
36
- # available (for example, when a PR runs in CI).
37
- @staticmethod
38
- def skip_test_if_no_fireworks_client() -> None:
39
- try:
40
- import pixeltable.functions.fireworks
41
- _ = pixeltable.functions.fireworks.fireworks_client()
42
- except excs.Error as exc:
43
- pytest.skip(str(exc))
Removed file: pixeltable/tests/functions/test_functions.py
@@ -1,60 +0,0 @@
1
- import pixeltable as pxt
2
- from pixeltable import catalog
3
- from pixeltable.functions.pil.image import blend
4
- from pixeltable.iterators import FrameIterator
5
- from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
- from pixeltable.type_system import VideoType, StringType
7
-
8
-
9
- class TestFunctions:
10
- def test_pil(self, img_tbl: catalog.Table) -> None:
11
- t = img_tbl
12
- _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
13
-
14
- def test_eval_detections(self, test_client: pxt.Client) -> None:
15
- skip_test_if_not_installed('nos')
16
- cl = test_client
17
- video_t = cl.create_table('video_tbl', {'video': VideoType()})
18
- # create frame view
19
- args = {'video': video_t.video, 'fps': 1}
20
- v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
21
-
22
- files = get_video_files()
23
- video_t.insert(video=files[-1])
24
- v.add_column(frame_s=v.frame.resize([640, 480]))
25
- from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
26
- v.add_column(detections_a=yolox_nano(v.frame_s))
27
- v.add_column(detections_b=yolox_small(v.frame_s))
28
- v.add_column(gt=yolox_large(v.frame_s))
29
- from pixeltable.functions.eval import eval_detections, mean_ap
30
- res = v.select(
31
- eval_detections(
32
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
33
- )).show()
34
- v.add_column(
35
- eval_a=eval_detections(
36
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
37
- v.add_column(
38
- eval_b=eval_detections(
39
- v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
40
- ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
41
- ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
42
- common_classes = set(ap_a.keys()) & set(ap_b.keys())
43
-
44
- ## TODO: following assertion is failing on CI,
45
- # It is not necessarily a bug, as assert codition is not expected to be always true
46
- # for k in common_classes:
47
- # assert ap_a[k] <= ap_b[k]
48
-
49
- def test_str(self, test_client: pxt.Client) -> None:
50
- cl = test_client
51
- t = cl.create_table('test_tbl', {'input': StringType()})
52
- from pixeltable.functions.string import str_format
53
- t.add_column(s1=str_format('ABC {0}', t.input))
54
- t.add_column(s2=str_format('DEF {this}', this=t.input))
55
- t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
56
- status = t.insert(input='MNO')
57
- assert status.num_rows == 1
58
- assert status.num_excs == 0
59
- row = t.head()[0]
60
- assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
Removed file: pixeltable/tests/functions/test_huggingface.py
@@ -1,158 +0,0 @@
1
- from typing import Dict, Any
2
-
3
- import pytest
4
-
5
- import pixeltable as pxt
6
- from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
7
- SAMPLE_IMAGE_URL
8
- from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
9
-
10
-
11
- class TestHuggingface:
12
-
13
- def test_hf_function(self, test_client: pxt.Client) -> None:
14
- skip_test_if_not_installed('sentence_transformers')
15
- cl = test_client
16
- t = cl.create_table('test_tbl', {'input': StringType(), 'bool_col': BoolType()})
17
- from pixeltable.functions.huggingface import sentence_transformer
18
- model_id = 'intfloat/e5-large-v2'
19
- t.add_column(e5=sentence_transformer(t.input, model_id=model_id))
20
- sents = get_sentences()
21
- status = t.insert({'input': s, 'bool_col': True} for s in sents)
22
- assert status.num_rows == len(sents)
23
- assert status.num_excs == 0
24
-
25
- # verify handling of constant params
26
- with pytest.raises(ValueError) as exc_info:
27
- t.add_column(e5_2=sentence_transformer(t.input, model_id=t.input))
28
- assert ': parameter model_id must be a constant value' in str(exc_info.value)
29
- with pytest.raises(ValueError) as exc_info:
30
- t.add_column(e5_2=sentence_transformer(t.input, model_id=model_id, normalize_embeddings=t.bool_col))
31
- assert ': parameter normalize_embeddings must be a constant value' in str(exc_info.value)
32
-
33
- # make sure this doesn't cause an exception
34
- # TODO: is there some way to capture the output?
35
- t.describe()
36
-
37
- def test_sentence_transformer(self, test_client: pxt.Client) -> None:
38
- skip_test_if_not_installed('sentence_transformers')
39
- cl = test_client
40
- t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
41
- sents = get_sentences(10)
42
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
43
- assert status.num_rows == len(sents)
44
- assert status.num_excs == 0
45
-
46
- # run multiple models one at a time in order to exercise batching
47
- from pixeltable.functions.huggingface import sentence_transformer, sentence_transformer_list
48
- model_ids = ['sentence-transformers/all-mpnet-base-v2', 'BAAI/bge-reranker-base']
49
- num_dims = [768, 768]
50
- for idx, model_id in enumerate(model_ids):
51
- col_name = f'embed{idx}'
52
- t[col_name] = sentence_transformer(t.input, model_id=model_id, normalize_embeddings=True)
53
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
54
- list_col_name = f'embed_list{idx}'
55
- t[list_col_name] = sentence_transformer_list(t.input_list, model_id=model_id, normalize_embeddings=True)
56
- assert t.column_types()[list_col_name] == JsonType()
57
-
58
- def verify_row(row: Dict[str, Any]) -> None:
59
- for idx, (_, d) in enumerate(zip(model_ids, num_dims)):
60
- assert row[f'embed{idx}'].shape == (d,)
61
- assert len(row[f'embed_list{idx}']) == len(sents)
62
- assert all(len(v) == d for v in row[f'embed_list{idx}'])
63
-
64
- verify_row(t.tail(1)[0])
65
-
66
- # execution still works after reload
67
- cl = pxt.Client(reload=True)
68
- t = cl.get_table('test_tbl')
69
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
70
- assert status.num_rows == len(sents)
71
- assert status.num_excs == 0
72
- verify_row(t.tail(1)[0])
73
-
74
- def test_cross_encoder(self, test_client: pxt.Client) -> None:
75
- skip_test_if_not_installed('sentence_transformers')
76
- cl = test_client
77
- t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
78
- sents = get_sentences(10)
79
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
80
- assert status.num_rows == len(sents)
81
- assert status.num_excs == 0
82
-
83
- # run multiple models one at a time in order to exercise batching
84
- from pixeltable.functions.huggingface import cross_encoder, cross_encoder_list
85
- model_ids = ['cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-TinyBERT-L-2-v2']
86
- for idx, model_id in enumerate(model_ids):
87
- col_name = f'embed{idx}'
88
- t[col_name] = cross_encoder(t.input, t.input, model_id=model_id)
89
- assert t.column_types()[col_name] == FloatType()
90
- list_col_name = f'embed_list{idx}'
91
- t[list_col_name] = cross_encoder_list(t.input, t.input_list, model_id=model_id)
92
- assert t.column_types()[list_col_name] == JsonType()
93
-
94
- def verify_row(row: Dict[str, Any]) -> None:
95
- for i in range(len(model_ids)):
96
- assert len(row[f'embed_list{idx}']) == len(sents)
97
- assert all(isinstance(v, float) for v in row[f'embed_list{idx}'])
98
-
99
- verify_row(t.tail(1)[0])
100
-
101
- # execution still works after reload
102
- cl = pxt.Client(reload=True)
103
- t = cl.get_table('test_tbl')
104
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
105
- assert status.num_rows == len(sents)
106
- assert status.num_excs == 0
107
- verify_row(t.tail(1)[0])
108
-
109
- def test_clip(self, test_client: pxt.Client) -> None:
110
- skip_test_if_not_installed('transformers')
111
- cl = test_client
112
- t = cl.create_table('test_tbl', {'text': StringType(), 'img': ImageType()})
113
- num_rows = 10
114
- sents = get_sentences(num_rows)
115
- imgs = get_image_files()[:num_rows]
116
- status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
117
- assert status.num_rows == len(sents)
118
- assert status.num_excs == 0
119
-
120
- # run multiple models one at a time in order to exercise batching
121
- from pixeltable.functions.huggingface import clip_text, clip_image
122
- model_ids = ['openai/clip-vit-base-patch32', 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K']
123
- for idx, model_id in enumerate(model_ids):
124
- col_name = f'embed_text{idx}'
125
- t[col_name] = clip_text(t.text, model_id=model_id)
126
- assert t.column_types()[col_name].is_array_type()
127
- col_name = f'embed_img{idx}'
128
- t[col_name] = clip_image(t.img, model_id=model_id)
129
- assert t.column_types()[col_name].is_array_type()
130
-
131
- def verify_row(row: Dict[str, Any]) -> None:
132
- for idx, _ in enumerate(model_ids):
133
- assert row[f'embed_text{idx}'].shape == (512,)
134
- assert row[f'embed_img{idx}'].shape == (512,)
135
-
136
- verify_row(t.tail(1)[0])
137
-
138
- # execution still works after reload
139
- cl = pxt.Client(reload=True)
140
- t = cl.get_table('test_tbl')
141
- status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
142
- assert status.num_rows == len(sents)
143
- assert status.num_excs == 0
144
- verify_row(t.tail(1)[0])
145
-
146
- def test_detr_for_object_detection(self, test_client: pxt.Client) -> None:
147
- skip_test_if_not_installed('transformers')
148
- cl = test_client
149
- t = cl.create_table('test_tbl', {'img': ImageType()})
150
- from pixeltable.functions.huggingface import detr_for_object_detection
151
- t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
152
- status = t.insert(img=SAMPLE_IMAGE_URL)
153
- assert status.num_rows == 1
154
- assert status.num_excs == 0
155
- result = t.select(t.detect).collect()[0]['detect']
156
- assert 'orange' in result['label_text']
157
- assert 'bowl' in result['label_text']
158
- assert 'broccoli' in result['label_text']
Removed file: pixeltable/tests/functions/test_openai.py
@@ -1,162 +0,0 @@
1
- import pytest
2
-
3
- import pixeltable as pxt
4
- import pixeltable.exceptions as excs
5
- from pixeltable.tests.utils import SAMPLE_IMAGE_URL, skip_test_if_not_installed, validate_update_status
6
- from pixeltable.type_system import StringType, ImageType
7
-
8
-
9
- @pytest.mark.remote_api
10
- class TestOpenai:
11
-
12
- def test_audio(self, test_client: pxt.Client) -> None:
13
- skip_test_if_not_installed('openai')
14
- TestOpenai.skip_test_if_no_openai_client()
15
- cl = test_client
16
- t = cl.create_table('test_tbl', {'input': StringType()})
17
- from pixeltable.functions.openai import speech, transcriptions, translations
18
- t.add_column(speech=speech(t.input, model='tts-1', voice='onyx'))
19
- t.add_column(speech_2=speech(t.input, model='tts-1', voice='onyx', response_format='flac', speed=1.05))
20
- t.add_column(transcription=transcriptions(t.speech, model='whisper-1'))
21
- t.add_column(transcription_2=transcriptions(
22
- t.speech, model='whisper-1', language='en', prompt='Transcribe the contents of this recording.'
23
- ))
24
- t.add_column(translation=translations(t.speech, model='whisper-1'))
25
- t.add_column(translation_2=translations(
26
- t.speech, model='whisper-1', prompt='Translate the recording from Spanish into English.', temperature=0.05
27
- ))
28
- validate_update_status(t.insert([
29
- {'input': 'I am a banana.'},
30
- {'input': 'Es fácil traducir del español al inglés.'}
31
- ]), expected_rows=2)
32
- # The audio generation -> transcription loop on these examples should be simple and clear enough
33
- # that the unit test can reliably expect the output closely enough to pass these checks.
34
- results = t.collect()
35
- assert results[0]['transcription']['text'] in ['I am a banana.', "I'm a banana."]
36
- assert results[0]['transcription_2']['text'] in ['I am a banana.', "I'm a banana."]
37
- assert 'easy to translate' in results[1]['translation']['text']
38
- assert 'easy to translate' in results[1]['translation_2']['text']
39
-
40
- def test_chat_completions(self, test_client: pxt.Client) -> None:
41
- skip_test_if_not_installed('openai')
42
- TestOpenai.skip_test_if_no_openai_client()
43
- cl = test_client
44
- t = cl.create_table('test_tbl', {'input': StringType()})
45
- from pixeltable.functions.openai import chat_completions
46
- msgs = [
47
- {"role": "system", "content": "You are a helpful assistant."},
48
- {"role": "user", "content": t.input}
49
- ]
50
- t.add_column(input_msgs=msgs)
51
- t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
52
- # with inlined messages
53
- t.add_column(chat_output_2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
54
- # test a bunch of the parameters
55
- t.add_column(chat_output_3=chat_completions(
56
- model='gpt-3.5-turbo', messages=msgs, frequency_penalty=0.1, logprobs=True, top_logprobs=3,
57
- max_tokens=500, n=3, presence_penalty=0.1, seed=4171780, stop=['\n'], temperature=0.7, top_p=0.8,
58
- user='pixeltable'
59
- ))
60
- # test with JSON output enforced
61
- t.add_column(chat_output_4=chat_completions(
62
- model='gpt-3.5-turbo', messages=msgs, response_format={'type': 'json_object'}
63
- ))
64
- # TODO Also test the `tools` and `tool_choice` parameters.
65
- validate_update_status(t.insert(input='Give me an example of a typical JSON structure.'), 1)
66
- result = t.collect()
67
- assert len(result['chat_output'][0]['choices'][0]['message']['content']) > 0
68
- assert len(result['chat_output_2'][0]['choices'][0]['message']['content']) > 0
69
- assert len(result['chat_output_3'][0]['choices'][0]['message']['content']) > 0
70
- assert len(result['chat_output_4'][0]['choices'][0]['message']['content']) > 0
71
-
72
- # When OpenAI gets a request with `response_format` equal to `json_object`, but the prompt does not
73
- # contain the string "json", it refuses the request.
74
- # TODO This should probably not be throwing an exception, but rather logging the error in
75
- # `t.chat_output_4.errormsg` etc.
76
- with pytest.raises(excs.ExprEvalError) as exc_info:
77
- t.insert(input='Say something interesting.')
78
- assert "\\'messages\\' must contain the word \\'json\\'" in str(exc_info.value)
79
-
80
- def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
81
- skip_test_if_not_installed('openai')
82
- TestOpenai.skip_test_if_no_openai_client()
83
- cl = test_client
84
- t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
85
- from pixeltable.functions.openai import chat_completions, vision
86
- from pixeltable.functions.string import str_format
87
- t.add_column(response=vision(prompt="What's in this image?", image=t.img))
88
- # Also get the response the low-level way, by calling chat_completions
89
- msgs = [
90
- {'role': 'user',
91
- 'content': [
92
- {'type': 'text', 'text': t.prompt},
93
- {'type': 'image_url', 'image_url': {
94
- 'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
95
- }}
96
- ]}
97
- ]
98
- t.add_column(response_2=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300).choices[0].message.content)
99
- validate_update_status(t.insert(prompt="What's in this image?", img=SAMPLE_IMAGE_URL), 1)
100
- result = t.collect()['response_2'][0]
101
- assert len(result) > 0
102
-
103
- def test_embeddings(self, test_client: pxt.Client) -> None:
104
- skip_test_if_not_installed('openai')
105
- TestOpenai.skip_test_if_no_openai_client()
106
- cl = test_client
107
- from pixeltable.functions.openai import embeddings
108
- t = cl.create_table('test_tbl', {'input': StringType()})
109
- t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
110
- t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input, user='pixeltable'))
111
- validate_update_status(t.insert(input='Say something interesting.'), 1)
112
- _ = t.head()
113
-
114
- def test_moderations(self, test_client: pxt.Client) -> None:
115
- skip_test_if_not_installed('openai')
116
- TestOpenai.skip_test_if_no_openai_client()
117
- cl = test_client
118
- t = cl.create_table('test_tbl', {'input': StringType()})
119
- from pixeltable.functions.openai import moderations
120
- t.add_column(moderation=moderations(input=t.input))
121
- t.add_column(moderation_2=moderations(input=t.input, model='text-moderation-stable'))
122
- validate_update_status(t.insert(input='Say something interesting.'), 1)
123
- _ = t.head()
124
-
125
- def test_image_generations(self, test_client: pxt.Client) -> None:
126
- skip_test_if_not_installed('openai')
127
- TestOpenai.skip_test_if_no_openai_client()
128
- cl = test_client
129
- t = cl.create_table('test_tbl', {'input': StringType()})
130
- from pixeltable.functions.openai import image_generations
131
- t.add_column(img=image_generations(t.input))
132
- # Test dall-e-2 options
133
- t.add_column(img_2=image_generations(
134
- t.input, model='dall-e-2', size='512x512', user='pixeltable'
135
- ))
136
- validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
137
- assert t.collect()['img'][0].size == (1024, 1024)
138
- assert t.collect()['img_2'][0].size == (512, 512)
139
-
140
- @pytest.mark.skip('Test is expensive and slow')
141
- def test_image_generations_dall_e_3(self, test_client: pxt.Client) -> None:
142
- skip_test_if_not_installed('openai')
143
- TestOpenai.skip_test_if_no_openai_client()
144
- cl = test_client
145
- t = cl.create_table('test_tbl', {'input': StringType()})
146
- from pixeltable.functions.openai import image_generations
147
- # Test dall-e-3 options
148
- t.add_column(img_3=image_generations(
149
- t.input, model='dall-e-3', quality='hd', size='1792x1024', style='natural', user='pixeltable'
150
- ))
151
- validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
152
- assert t.collect()['img_3'][0].size == (1792, 1024)
153
-
154
- # This ensures that the test will be skipped, rather than returning an error, when no API key is
155
- # available (for example, when a PR runs in CI).
156
- @staticmethod
157
- def skip_test_if_no_openai_client() -> None:
158
- try:
159
- import pixeltable.functions.openai
160
- _ = pixeltable.functions.openai.openai_client()
161
- except excs.Error as exc:
162
- pytest.skip(str(exc))