pixeltable 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic; see the package's registry page for more details.

Files changed (99):
  1. pixeltable/__init__.py +18 -9
  2. pixeltable/__version__.py +3 -0
  3. pixeltable/catalog/column.py +31 -50
  4. pixeltable/catalog/insertable_table.py +7 -6
  5. pixeltable/catalog/table.py +171 -57
  6. pixeltable/catalog/table_version.py +417 -140
  7. pixeltable/catalog/table_version_path.py +2 -2
  8. pixeltable/dataframe.py +239 -121
  9. pixeltable/env.py +82 -16
  10. pixeltable/exec/__init__.py +2 -1
  11. pixeltable/exec/cache_prefetch_node.py +1 -1
  12. pixeltable/exec/data_row_batch.py +6 -7
  13. pixeltable/exec/expr_eval_node.py +28 -28
  14. pixeltable/exec/in_memory_data_node.py +11 -7
  15. pixeltable/exec/sql_scan_node.py +7 -6
  16. pixeltable/exprs/__init__.py +4 -3
  17. pixeltable/exprs/column_ref.py +9 -0
  18. pixeltable/exprs/comparison.py +3 -3
  19. pixeltable/exprs/data_row.py +5 -1
  20. pixeltable/exprs/expr.py +15 -7
  21. pixeltable/exprs/function_call.py +17 -15
  22. pixeltable/exprs/image_member_access.py +9 -28
  23. pixeltable/exprs/in_predicate.py +96 -0
  24. pixeltable/exprs/inline_array.py +13 -11
  25. pixeltable/exprs/inline_dict.py +15 -13
  26. pixeltable/exprs/literal.py +16 -4
  27. pixeltable/exprs/row_builder.py +15 -41
  28. pixeltable/exprs/similarity_expr.py +65 -0
  29. pixeltable/ext/__init__.py +5 -0
  30. pixeltable/ext/functions/yolox.py +92 -0
  31. pixeltable/func/__init__.py +0 -2
  32. pixeltable/func/aggregate_function.py +18 -15
  33. pixeltable/func/callable_function.py +57 -13
  34. pixeltable/func/expr_template_function.py +20 -3
  35. pixeltable/func/function.py +35 -4
  36. pixeltable/func/globals.py +24 -14
  37. pixeltable/func/signature.py +23 -27
  38. pixeltable/func/udf.py +13 -12
  39. pixeltable/functions/__init__.py +8 -8
  40. pixeltable/functions/eval.py +7 -8
  41. pixeltable/functions/huggingface.py +64 -17
  42. pixeltable/functions/openai.py +36 -3
  43. pixeltable/functions/pil/image.py +61 -64
  44. pixeltable/functions/together.py +21 -0
  45. pixeltable/functions/util.py +11 -0
  46. pixeltable/globals.py +425 -0
  47. pixeltable/index/__init__.py +2 -0
  48. pixeltable/index/base.py +51 -0
  49. pixeltable/index/embedding_index.py +168 -0
  50. pixeltable/io/__init__.py +3 -0
  51. pixeltable/{utils → io}/hf_datasets.py +48 -17
  52. pixeltable/io/pandas.py +148 -0
  53. pixeltable/{utils → io}/parquet.py +58 -33
  54. pixeltable/iterators/__init__.py +1 -1
  55. pixeltable/iterators/base.py +4 -0
  56. pixeltable/iterators/document.py +218 -97
  57. pixeltable/iterators/video.py +8 -9
  58. pixeltable/metadata/__init__.py +7 -3
  59. pixeltable/metadata/converters/convert_12.py +3 -0
  60. pixeltable/metadata/converters/convert_13.py +41 -0
  61. pixeltable/metadata/schema.py +45 -22
  62. pixeltable/plan.py +15 -51
  63. pixeltable/store.py +38 -41
  64. pixeltable/tool/create_test_db_dump.py +39 -4
  65. pixeltable/type_system.py +47 -96
  66. pixeltable/utils/documents.py +42 -12
  67. pixeltable/utils/http_server.py +70 -0
  68. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/METADATA +14 -10
  69. pixeltable-0.2.6.dist-info/RECORD +119 -0
  70. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/WHEEL +1 -1
  71. pixeltable/client.py +0 -604
  72. pixeltable/exprs/image_similarity_predicate.py +0 -58
  73. pixeltable/func/batched_function.py +0 -53
  74. pixeltable/tests/conftest.py +0 -177
  75. pixeltable/tests/functions/test_fireworks.py +0 -42
  76. pixeltable/tests/functions/test_functions.py +0 -60
  77. pixeltable/tests/functions/test_huggingface.py +0 -158
  78. pixeltable/tests/functions/test_openai.py +0 -152
  79. pixeltable/tests/functions/test_together.py +0 -111
  80. pixeltable/tests/test_audio.py +0 -65
  81. pixeltable/tests/test_catalog.py +0 -27
  82. pixeltable/tests/test_client.py +0 -21
  83. pixeltable/tests/test_component_view.py +0 -370
  84. pixeltable/tests/test_dataframe.py +0 -439
  85. pixeltable/tests/test_dirs.py +0 -107
  86. pixeltable/tests/test_document.py +0 -120
  87. pixeltable/tests/test_exprs.py +0 -805
  88. pixeltable/tests/test_function.py +0 -324
  89. pixeltable/tests/test_migration.py +0 -43
  90. pixeltable/tests/test_nos.py +0 -54
  91. pixeltable/tests/test_snapshot.py +0 -208
  92. pixeltable/tests/test_table.py +0 -1267
  93. pixeltable/tests/test_transactional_directory.py +0 -42
  94. pixeltable/tests/test_types.py +0 -22
  95. pixeltable/tests/test_video.py +0 -159
  96. pixeltable/tests/test_view.py +0 -530
  97. pixeltable/tests/utils.py +0 -408
  98. pixeltable-0.2.4.dist-info/RECORD +0 -132
  99. {pixeltable-0.2.4.dist-info → pixeltable-0.2.6.dist-info}/LICENSE +0 -0
--- a/pixeltable/func/batched_function.py
+++ /dev/null
@@ -1,53 +0,0 @@
1
- import inspect
2
- from typing import List, Dict, Any, Optional, Callable
3
- import abc
4
-
5
- from .function import Function
6
- from .signature import Signature
7
-
8
-
9
- class BatchedFunction(Function):
10
- """Base class for functions that can run on batches"""
11
-
12
- @abc.abstractmethod
13
- def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
14
- """Return the batch size for the given arguments, or None if the batch size is unknown.
15
- args/kwargs might be empty
16
- """
17
- raise NotImplementedError
18
-
19
- @abc.abstractmethod
20
- def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
21
- """Invoke the function for the given batch and return a batch of results"""
22
- raise NotImplementedError
23
-
24
-
25
- class ExplicitBatchedFunction(BatchedFunction):
26
- """
27
- A `BatchedFunction` that is defined by a signature and an explicit python
28
- `Callable`.
29
- """
30
- def __init__(self, signature: Signature, batch_size: Optional[int], invoker_fn: Callable, self_path: str):
31
- super().__init__(signature=signature, py_signature=inspect.signature(invoker_fn), self_path=self_path)
32
- self.batch_size = batch_size
33
- self.invoker_fn = invoker_fn
34
-
35
- def get_batch_size(self, *args: Any, **kwargs: Any) -> Optional[int]:
36
- return self.batch_size
37
-
38
- def invoke(self, arg_batches: List[List[Any]], kwarg_batches: Dict[str, List[Any]]) -> List[Any]:
39
- """Invoke the function for the given batch and return a batch of results"""
40
- constant_param_names = [p.name for p in self.signature.constant_parameters]
41
- kwargs = {k: v[0] for k, v in kwarg_batches.items() if k in constant_param_names}
42
- kwarg_batches = {k: v for k, v in kwarg_batches.items() if k not in constant_param_names}
43
- return self.invoker_fn(*arg_batches, **kwargs, **kwarg_batches)
44
-
45
- def validate_call(self, bound_args: Dict[str, Any]) -> None:
46
- """Verify constant parameters"""
47
- import pixeltable.exprs as exprs
48
- for param in self.signature.constant_parameters:
49
- if param.name in bound_args and isinstance(bound_args[param.name], exprs.Expr):
50
- raise ValueError(
51
- f'{self.display_name}(): '
52
- f'parameter {param.name} must be a constant value, not a Pixeltable expression'
53
- )
--- a/pixeltable/tests/conftest.py
+++ /dev/null
@@ -1,177 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- from typing import List
6
-
7
- import numpy as np
8
- import pytest
9
-
10
- import pixeltable as pxt
11
- import pixeltable.catalog as catalog
12
- from pixeltable import exprs
13
- from pixeltable import functions as ptf
14
- from pixeltable.exprs import RELATIVE_PATH_ROOT as R
15
- from pixeltable.metadata import SystemInfo, create_system_info
16
- from pixeltable.metadata.schema import TableSchemaVersion, TableVersion, Table, Function, Dir
17
- from pixeltable.tests.utils import read_data_file, create_test_tbl, create_all_datatypes_tbl, skip_test_if_not_installed
18
- from pixeltable.type_system import StringType, ImageType, FloatType
19
-
20
-
21
- @pytest.fixture(scope='session')
22
- def init_env(tmp_path_factory) -> None:
23
- from pixeltable.env import Env
24
- # set the relevant env vars for Client() to connect to the test db
25
-
26
- shared_home = pathlib.Path(os.environ.get('PIXELTABLE_HOME', str(pathlib.Path.home() / '.pixeltable')))
27
- home_dir = str(tmp_path_factory.mktemp('base') / '.pixeltable')
28
- os.environ['PIXELTABLE_HOME'] = home_dir
29
- os.environ['PIXELTABLE_CONFIG'] = str(shared_home / 'config.yaml')
30
- test_db = 'test'
31
- os.environ['PIXELTABLE_DB'] = test_db
32
- os.environ['PIXELTABLE_PGDATA'] = str(shared_home / 'pgdata')
33
-
34
- # ensure this home dir exits
35
- shared_home.mkdir(parents=True, exist_ok=True)
36
- # this also runs create_all()
37
- Env.get().set_up(echo=True)
38
- yield
39
- # leave db in place for debugging purposes
40
-
41
- @pytest.fixture(scope='function')
42
- def test_client(init_env) -> pxt.Client:
43
- # Clean the DB *before* instantiating a client object. This is because some tests
44
- # (such as test_migration.py) may leave the DB in a broken state, from which the
45
- # client is uninstantiable.
46
- clean_db()
47
- cl = pxt.Client(reload=True)
48
- cl.logging(level=logging.DEBUG, to_stdout=True)
49
- yield cl
50
-
51
-
52
- def clean_db(restore_tables: bool = True) -> None:
53
- from pixeltable.env import Env
54
- # The logic from Client.reset_catalog() has been moved here, so that it
55
- # does not rely on instantiating a Client object. As before, UUID-named data tables will
56
- # not be cleaned. If in the future it is desirable to clean out data tables as well,
57
- # the commented lines may be used to drop ALL tables from the test db.
58
- # sql_md = declarative_base().metadata
59
- # sql_md.reflect(Env.get().engine)
60
- # sql_md.drop_all(bind=Env.get().engine)
61
- engine = Env.get().engine
62
- SystemInfo.__table__.drop(engine, checkfirst=True)
63
- TableSchemaVersion.__table__.drop(engine, checkfirst=True)
64
- TableVersion.__table__.drop(engine, checkfirst=True)
65
- Table.__table__.drop(engine, checkfirst=True)
66
- Function.__table__.drop(engine, checkfirst=True)
67
- Dir.__table__.drop(engine, checkfirst=True)
68
- if restore_tables:
69
- Dir.__table__.create(engine)
70
- Function.__table__.create(engine)
71
- Table.__table__.create(engine)
72
- TableVersion.__table__.create(engine)
73
- TableSchemaVersion.__table__.create(engine)
74
- SystemInfo.__table__.create(engine)
75
- create_system_info(engine)
76
-
77
-
78
- @pytest.fixture(scope='function')
79
- def test_tbl(test_client: pxt.Client) -> catalog.Table:
80
- return create_test_tbl(test_client)
81
-
82
- # @pytest.fixture(scope='function')
83
- # def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
84
- # @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
85
- # def test_fn(x):
86
- # return x + 1
87
- # test_client.create_function('test_fn', test_fn)
88
- # return test_fn
89
-
90
- @pytest.fixture(scope='function')
91
- def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
92
- #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
93
-
94
- t = test_tbl
95
- return [
96
- t.c1,
97
- t.c7['*'].f1,
98
- exprs.Literal('test'),
99
- exprs.InlineDict({
100
- 'a': t.c1, 'b': t.c6.f1, 'c': 17,
101
- 'd': exprs.InlineDict({'e': t.c2}),
102
- 'f': exprs.InlineArray((t.c3, t.c3))
103
- }),
104
- exprs.InlineArray([[t.c2, t.c2], [t.c2, t.c2]]),
105
- t.c2 > 5,
106
- t.c2 == None,
107
- ~(t.c2 > 5),
108
- (t.c2 > 5) & (t.c1 == 'test'),
109
- (t.c2 > 5) | (t.c1 == 'test'),
110
- t.c7['*'].f5 >> [R[3], R[2], R[1], R[0]],
111
- t.c8[0, 1:],
112
- t.c2.astype(FloatType()),
113
- (t.c2 + 1).astype(FloatType()),
114
- t.c2.apply(str),
115
- (t.c2 + 1).apply(str),
116
- t.c3.apply(str),
117
- t.c4.apply(str),
118
- t.c5.apply(str),
119
- t.c6.apply(str),
120
- t.c1.apply(json.loads),
121
- t.c8.errortype,
122
- t.c8.errormsg,
123
- ptf.sum(t.c2, group_by=t.c4, order_by=t.c3),
124
- #test_stored_fn(t.c2),
125
- ]
126
-
127
- @pytest.fixture(scope='function')
128
- def all_datatypes_tbl(test_client: pxt.Client) -> catalog.Table:
129
- return create_all_datatypes_tbl(test_client)
130
-
131
- @pytest.fixture(scope='function')
132
- def img_tbl(test_client: pxt.Client) -> catalog.Table:
133
- schema = {
134
- 'img': ImageType(nullable=False),
135
- 'category': StringType(nullable=False),
136
- 'split': StringType(nullable=False),
137
- }
138
- # this table is not indexed in order to avoid the cost of computing embeddings
139
- tbl = test_client.create_table('test_img_tbl', schema)
140
- rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
141
- tbl.insert(rows)
142
- return tbl
143
-
144
- @pytest.fixture(scope='function')
145
- def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
146
- img_t = img_tbl
147
- return [
148
- img_t.img.width,
149
- img_t.img.rotate(90),
150
- # we're using a list here, not a tuple; the latter turns into a list during the back/forth conversion
151
- img_t.img.rotate(90).resize([224, 224]),
152
- img_t.img.fileurl,
153
- img_t.img.localpath,
154
- ]
155
-
156
- # TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
157
- #@pytest.fixture(scope='session')
158
- #def indexed_img_tbl(init_env: None) -> catalog.Table:
159
- # cl = pxt.Client()
160
- # db = cl.create_db('test_indexed')
161
- @pytest.fixture(scope='function')
162
- def indexed_img_tbl(test_client: pxt.Client) -> catalog.Table:
163
- skip_test_if_not_installed('nos')
164
- cl = test_client
165
- schema = {
166
- 'img': { 'type': ImageType(nullable=False), 'indexed': True },
167
- 'category': StringType(nullable=False),
168
- 'split': StringType(nullable=False),
169
- }
170
- tbl = cl.create_table('test_indexed_img_tbl', schema)
171
- rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
172
- # select output_rows randomly in the hope of getting a good sample of the available categories
173
- rng = np.random.default_rng(17)
174
- idxs = rng.choice(np.arange(len(rows)), size=40, replace=False)
175
- rows = [rows[i] for i in idxs]
176
- tbl.insert(rows)
177
- return tbl
--- a/pixeltable/tests/functions/test_fireworks.py
+++ /dev/null
@@ -1,42 +0,0 @@
1
- import pytest
2
-
3
- import pixeltable as pxt
4
- import pixeltable.exceptions as excs
5
- from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
-
7
-
8
- class TestFireworks:
9
-
10
- def test_fireworks(self, test_client: pxt.Client) -> None:
11
- skip_test_if_not_installed('fireworks')
12
- TestFireworks.skip_test_if_no_fireworks_client()
13
- cl = test_client
14
- t = cl.create_table('test_tbl', {'input': pxt.StringType()})
15
- from pixeltable.functions.fireworks import chat_completions
16
- messages = [{'role': 'user', 'content': t.input}]
17
- t['output'] = chat_completions(
18
- messages=messages,
19
- model='accounts/fireworks/models/llama-v2-7b-chat'
20
- )
21
- t['output_2'] = chat_completions(
22
- messages=messages,
23
- model='accounts/fireworks/models/llama-v2-7b-chat',
24
- max_tokens=300,
25
- top_k=40,
26
- top_p=0.9,
27
- temperature=0.7
28
- )
29
- validate_update_status(t.insert(input="How's everything going today?"), 1)
30
- results = t.collect()
31
- assert len(results['output'][0]['choices'][0]['message']['content']) > 0
32
- assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
33
-
34
- # This ensures that the test will be skipped, rather than returning an error, when no API key is
35
- # available (for example, when a PR runs in CI).
36
- @staticmethod
37
- def skip_test_if_no_fireworks_client() -> None:
38
- try:
39
- import pixeltable.functions.fireworks
40
- _ = pixeltable.functions.fireworks.fireworks_client()
41
- except excs.Error as exc:
42
- pytest.skip(str(exc))
--- a/pixeltable/tests/functions/test_functions.py
+++ /dev/null
@@ -1,60 +0,0 @@
1
- import pixeltable as pxt
2
- from pixeltable import catalog
3
- from pixeltable.functions.pil.image import blend
4
- from pixeltable.iterators import FrameIterator
5
- from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
- from pixeltable.type_system import VideoType, StringType
7
-
8
-
9
- class TestFunctions:
10
- def test_pil(self, img_tbl: catalog.Table) -> None:
11
- t = img_tbl
12
- _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
13
-
14
- def test_eval_detections(self, test_client: pxt.Client) -> None:
15
- skip_test_if_not_installed('nos')
16
- cl = test_client
17
- video_t = cl.create_table('video_tbl', {'video': VideoType()})
18
- # create frame view
19
- args = {'video': video_t.video, 'fps': 1}
20
- v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
21
-
22
- files = get_video_files()
23
- video_t.insert(video=files[-1])
24
- v.add_column(frame_s=v.frame.resize([640, 480]))
25
- from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
26
- v.add_column(detections_a=yolox_nano(v.frame_s))
27
- v.add_column(detections_b=yolox_small(v.frame_s))
28
- v.add_column(gt=yolox_large(v.frame_s))
29
- from pixeltable.functions.eval import eval_detections, mean_ap
30
- res = v.select(
31
- eval_detections(
32
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
33
- )).show()
34
- v.add_column(
35
- eval_a=eval_detections(
36
- v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
37
- v.add_column(
38
- eval_b=eval_detections(
39
- v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
40
- ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
41
- ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
42
- common_classes = set(ap_a.keys()) & set(ap_b.keys())
43
-
44
- ## TODO: following assertion is failing on CI,
45
- # It is not necessarily a bug, as assert codition is not expected to be always true
46
- # for k in common_classes:
47
- # assert ap_a[k] <= ap_b[k]
48
-
49
- def test_str(self, test_client: pxt.Client) -> None:
50
- cl = test_client
51
- t = cl.create_table('test_tbl', {'input': StringType()})
52
- from pixeltable.functions.string import str_format
53
- t.add_column(s1=str_format('ABC {0}', t.input))
54
- t.add_column(s2=str_format('DEF {this}', this=t.input))
55
- t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
56
- status = t.insert(input='MNO')
57
- assert status.num_rows == 1
58
- assert status.num_excs == 0
59
- row = t.head()[0]
60
- assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
--- a/pixeltable/tests/functions/test_huggingface.py
+++ /dev/null
@@ -1,158 +0,0 @@
1
- from typing import Dict, Any
2
-
3
- import pytest
4
-
5
- import pixeltable as pxt
6
- from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
7
- SAMPLE_IMAGE_URL
8
- from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
9
-
10
-
11
- class TestHuggingface:
12
-
13
- def test_hf_function(self, test_client: pxt.Client) -> None:
14
- skip_test_if_not_installed('sentence_transformers')
15
- cl = test_client
16
- t = cl.create_table('test_tbl', {'input': StringType(), 'bool_col': BoolType()})
17
- from pixeltable.functions.huggingface import sentence_transformer
18
- model_id = 'intfloat/e5-large-v2'
19
- t.add_column(e5=sentence_transformer(t.input, model_id=model_id))
20
- sents = get_sentences()
21
- status = t.insert({'input': s, 'bool_col': True} for s in sents)
22
- assert status.num_rows == len(sents)
23
- assert status.num_excs == 0
24
-
25
- # verify handling of constant params
26
- with pytest.raises(ValueError) as exc_info:
27
- t.add_column(e5_2=sentence_transformer(t.input, model_id=t.input))
28
- assert ': parameter model_id must be a constant value' in str(exc_info.value)
29
- with pytest.raises(ValueError) as exc_info:
30
- t.add_column(e5_2=sentence_transformer(t.input, model_id=model_id, normalize_embeddings=t.bool_col))
31
- assert ': parameter normalize_embeddings must be a constant value' in str(exc_info.value)
32
-
33
- # make sure this doesn't cause an exception
34
- # TODO: is there some way to capture the output?
35
- t.describe()
36
-
37
- def test_sentence_transformer(self, test_client: pxt.Client) -> None:
38
- skip_test_if_not_installed('sentence_transformers')
39
- cl = test_client
40
- t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
41
- sents = get_sentences(10)
42
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
43
- assert status.num_rows == len(sents)
44
- assert status.num_excs == 0
45
-
46
- # run multiple models one at a time in order to exercise batching
47
- from pixeltable.functions.huggingface import sentence_transformer, sentence_transformer_list
48
- model_ids = ['sentence-transformers/all-mpnet-base-v2', 'BAAI/bge-reranker-base']
49
- num_dims = [768, 768]
50
- for idx, model_id in enumerate(model_ids):
51
- col_name = f'embed{idx}'
52
- t[col_name] = sentence_transformer(t.input, model_id=model_id, normalize_embeddings=True)
53
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
54
- list_col_name = f'embed_list{idx}'
55
- t[list_col_name] = sentence_transformer_list(t.input_list, model_id=model_id, normalize_embeddings=True)
56
- assert t.column_types()[list_col_name] == JsonType()
57
-
58
- def verify_row(row: Dict[str, Any]) -> None:
59
- for idx, (_, d) in enumerate(zip(model_ids, num_dims)):
60
- assert row[f'embed{idx}'].shape == (d,)
61
- assert len(row[f'embed_list{idx}']) == len(sents)
62
- assert all(len(v) == d for v in row[f'embed_list{idx}'])
63
-
64
- verify_row(t.tail(1)[0])
65
-
66
- # execution still works after reload
67
- cl = pxt.Client(reload=True)
68
- t = cl.get_table('test_tbl')
69
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
70
- assert status.num_rows == len(sents)
71
- assert status.num_excs == 0
72
- verify_row(t.tail(1)[0])
73
-
74
- def test_cross_encoder(self, test_client: pxt.Client) -> None:
75
- skip_test_if_not_installed('sentence_transformers')
76
- cl = test_client
77
- t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
78
- sents = get_sentences(10)
79
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
80
- assert status.num_rows == len(sents)
81
- assert status.num_excs == 0
82
-
83
- # run multiple models one at a time in order to exercise batching
84
- from pixeltable.functions.huggingface import cross_encoder, cross_encoder_list
85
- model_ids = ['cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-TinyBERT-L-2-v2']
86
- for idx, model_id in enumerate(model_ids):
87
- col_name = f'embed{idx}'
88
- t[col_name] = cross_encoder(t.input, t.input, model_id=model_id)
89
- assert t.column_types()[col_name] == FloatType()
90
- list_col_name = f'embed_list{idx}'
91
- t[list_col_name] = cross_encoder_list(t.input, t.input_list, model_id=model_id)
92
- assert t.column_types()[list_col_name] == JsonType()
93
-
94
- def verify_row(row: Dict[str, Any]) -> None:
95
- for i in range(len(model_ids)):
96
- assert len(row[f'embed_list{idx}']) == len(sents)
97
- assert all(isinstance(v, float) for v in row[f'embed_list{idx}'])
98
-
99
- verify_row(t.tail(1)[0])
100
-
101
- # execution still works after reload
102
- cl = pxt.Client(reload=True)
103
- t = cl.get_table('test_tbl')
104
- status = t.insert({'input': s, 'input_list': sents} for s in sents)
105
- assert status.num_rows == len(sents)
106
- assert status.num_excs == 0
107
- verify_row(t.tail(1)[0])
108
-
109
- def test_clip(self, test_client: pxt.Client) -> None:
110
- skip_test_if_not_installed('transformers')
111
- cl = test_client
112
- t = cl.create_table('test_tbl', {'text': StringType(), 'img': ImageType()})
113
- num_rows = 10
114
- sents = get_sentences(num_rows)
115
- imgs = get_image_files()[:num_rows]
116
- status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
117
- assert status.num_rows == len(sents)
118
- assert status.num_excs == 0
119
-
120
- # run multiple models one at a time in order to exercise batching
121
- from pixeltable.functions.huggingface import clip_text, clip_image
122
- model_ids = ['openai/clip-vit-base-patch32', 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K']
123
- for idx, model_id in enumerate(model_ids):
124
- col_name = f'embed_text{idx}'
125
- t[col_name] = clip_text(t.text, model_id=model_id)
126
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
127
- col_name = f'embed_img{idx}'
128
- t[col_name] = clip_image(t.img, model_id=model_id)
129
- assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
130
-
131
- def verify_row(row: Dict[str, Any]) -> None:
132
- for idx, _ in enumerate(model_ids):
133
- assert row[f'embed_text{idx}'].shape == (512,)
134
- assert row[f'embed_img{idx}'].shape == (512,)
135
-
136
- verify_row(t.tail(1)[0])
137
-
138
- # execution still works after reload
139
- cl = pxt.Client(reload=True)
140
- t = cl.get_table('test_tbl')
141
- status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
142
- assert status.num_rows == len(sents)
143
- assert status.num_excs == 0
144
- verify_row(t.tail(1)[0])
145
-
146
- def test_detr_for_object_detection(self, test_client: pxt.Client) -> None:
147
- skip_test_if_not_installed('transformers')
148
- cl = test_client
149
- t = cl.create_table('test_tbl', {'img': ImageType()})
150
- from pixeltable.functions.huggingface import detr_for_object_detection
151
- t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
152
- status = t.insert(img=SAMPLE_IMAGE_URL)
153
- assert status.num_rows == 1
154
- assert status.num_excs == 0
155
- result = t.select(t.detect).collect()[0]['detect']
156
- assert 'orange' in result['label_text']
157
- assert 'bowl' in result['label_text']
158
- assert 'broccoli' in result['label_text']
--- a/pixeltable/tests/functions/test_openai.py
+++ /dev/null
@@ -1,152 +0,0 @@
1
- import pytest
2
-
3
- import pixeltable as pxt
4
- import pixeltable.exceptions as excs
5
- from pixeltable.tests.utils import SAMPLE_IMAGE_URL, skip_test_if_not_installed, validate_update_status
6
- from pixeltable.type_system import StringType, ImageType
7
-
8
-
9
- class TestOpenai:
10
-
11
- def test_audio(self, test_client: pxt.Client) -> None:
12
- skip_test_if_not_installed('openai')
13
- TestOpenai.skip_test_if_no_openai_client()
14
- cl = test_client
15
- t = cl.create_table('test_tbl', {'input': StringType()})
16
- from pixeltable.functions.openai import speech, transcriptions, translations
17
- t.add_column(speech=speech(t.input, model='tts-1', voice='onyx'))
18
- t.add_column(speech_2=speech(t.input, model='tts-1', voice='onyx', response_format='flac', speed=1.05))
19
- t.add_column(transcription=transcriptions(t.speech, model='whisper-1'))
20
- t.add_column(transcription_2=transcriptions(
21
- t.speech, model='whisper-1', language='en', prompt='Transcribe the contents of this recording.'
22
- ))
23
- t.add_column(translation=translations(t.speech, model='whisper-1'))
24
- t.add_column(translation_2=translations(
25
- t.speech, model='whisper-1', prompt='Translate the recording from Spanish into English.', temperature=0.7
26
- ))
27
- validate_update_status(t.insert([
28
- {'input': 'I am a banana.'},
29
- {'input': 'Es fácil traducir del español al inglés.'}
30
- ]), expected_rows=2)
31
- # The audio generation -> transcription loop on these examples should be simple and clear enough
32
- # that the unit test can reliably expect the output closely enough to pass these checks.
33
- results = t.collect()
34
- assert results[0]['transcription']['text'] in ['I am a banana.', "I'm a banana."]
35
- assert results[0]['transcription_2']['text'] in ['I am a banana.', "I'm a banana."]
36
- assert 'easy to translate from Spanish' in results[1]['translation']['text']
37
- assert 'easy to translate from Spanish' in results[1]['translation_2']['text']
38
-
39
- def test_chat_completions(self, test_client: pxt.Client) -> None:
40
- skip_test_if_not_installed('openai')
41
- TestOpenai.skip_test_if_no_openai_client()
42
- cl = test_client
43
- t = cl.create_table('test_tbl', {'input': StringType()})
44
- from pixeltable.functions.openai import chat_completions
45
- msgs = [
46
- {"role": "system", "content": "You are a helpful assistant."},
47
- {"role": "user", "content": t.input}
48
- ]
49
- t.add_column(input_msgs=msgs)
50
- t.add_column(chat_output=chat_completions(model='gpt-3.5-turbo', messages=t.input_msgs))
51
- # with inlined messages
52
- t.add_column(chat_output_2=chat_completions(model='gpt-3.5-turbo', messages=msgs))
53
- # test a bunch of the parameters
54
- t.add_column(chat_output_3=chat_completions(
55
- model='gpt-3.5-turbo', messages=msgs, frequency_penalty=0.1, logprobs=True, top_logprobs=3,
56
- max_tokens=500, n=3, presence_penalty=0.1, seed=4171780, stop=['\n'], temperature=0.7, top_p=0.8,
57
- user='pixeltable'
58
- ))
59
- # test with JSON output enforced
60
- t.add_column(chat_output_4=chat_completions(
61
- model='gpt-3.5-turbo', messages=msgs, response_format={'type': 'json_object'}
62
- ))
63
- # TODO Also test the `tools` and `tool_choice` parameters.
64
- validate_update_status(t.insert(input='Give me an example of a typical JSON structure.'), 1)
65
- result = t.collect()
66
- assert len(result['chat_output'][0]['choices'][0]['message']['content']) > 0
67
- assert len(result['chat_output_2'][0]['choices'][0]['message']['content']) > 0
68
- assert len(result['chat_output_3'][0]['choices'][0]['message']['content']) > 0
69
- assert len(result['chat_output_4'][0]['choices'][0]['message']['content']) > 0
70
-
71
- # When OpenAI gets a request with `response_format` equal to `json_object`, but the prompt does not
72
- # contain the string "json", it refuses the request.
73
- # TODO This should probably not be throwing an exception, but rather logging the error in
74
- # `t.chat_output_4.errormsg` etc.
75
- with pytest.raises(excs.ExprEvalError) as exc_info:
76
- t.insert(input='Say something interesting.')
77
- assert "\\'messages\\' must contain the word \\'json\\'" in str(exc_info.value)
78
-
79
- def test_gpt_4_vision(self, test_client: pxt.Client) -> None:
80
- skip_test_if_not_installed('openai')
81
- TestOpenai.skip_test_if_no_openai_client()
82
- cl = test_client
83
- t = cl.create_table('test_tbl', {'prompt': StringType(), 'img': ImageType()})
84
- from pixeltable.functions.openai import chat_completions, vision
85
- from pixeltable.functions.string import str_format
86
- t.add_column(response=vision(prompt="What's in this image?", image=t.img))
87
- # Also get the response the low-level way, by calling chat_completions
88
- msgs = [
89
- {'role': 'user',
90
- 'content': [
91
- {'type': 'text', 'text': t.prompt},
92
- {'type': 'image_url', 'image_url': {
93
- 'url': str_format('data:image/png;base64,{0}', t.img.b64_encode())
94
- }}
95
- ]}
96
- ]
97
- t.add_column(response_2=chat_completions(model='gpt-4-vision-preview', messages=msgs, max_tokens=300).choices[0].message.content)
98
- validate_update_status(t.insert(prompt="What's in this image?", img=SAMPLE_IMAGE_URL), 1)
99
- result = t.collect()['response_2'][0]
100
- assert len(result) > 0
101
-
102
- def test_embeddings(self, test_client: pxt.Client) -> None:
103
- skip_test_if_not_installed('openai')
104
- TestOpenai.skip_test_if_no_openai_client()
105
- cl = test_client
106
- from pixeltable.functions.openai import embeddings
107
- t = cl.create_table('test_tbl', {'input': StringType()})
108
- t.add_column(ada_embed=embeddings(model='text-embedding-ada-002', input=t.input))
109
- t.add_column(text_3=embeddings(model='text-embedding-3-small', input=t.input, user='pixeltable'))
110
- validate_update_status(t.insert(input='Say something interesting.'), 1)
111
- _ = t.head()
112
-
113
- def test_moderations(self, test_client: pxt.Client) -> None:
114
- skip_test_if_not_installed('openai')
115
- TestOpenai.skip_test_if_no_openai_client()
116
- cl = test_client
117
- t = cl.create_table('test_tbl', {'input': StringType()})
118
- from pixeltable.functions.openai import moderations
119
- t.add_column(moderation=moderations(input=t.input))
120
- t.add_column(moderation_2=moderations(input=t.input, model='text-moderation-stable'))
121
- validate_update_status(t.insert(input='Say something interesting.'), 1)
122
- _ = t.head()
123
-
124
- def test_image_generations(self, test_client: pxt.Client) -> None:
125
- skip_test_if_not_installed('openai')
126
- TestOpenai.skip_test_if_no_openai_client()
127
- cl = test_client
128
- t = cl.create_table('test_tbl', {'input': StringType()})
129
- from pixeltable.functions.openai import image_generations
130
- t.add_column(img=image_generations(t.input))
131
- # Test dall-e-2 options
132
- t.add_column(img_2=image_generations(
133
- t.input, model='dall-e-2', size='512x512', user='pixeltable'
134
- ))
135
- # Test dall-e-3 options
136
- t.add_column(img_3=image_generations(
137
- t.input, model='dall-e-3', quality='hd', size='1792x1024', style='natural', user='pixeltable'
138
- ))
139
- validate_update_status(t.insert(input='A friendly dinosaur playing tennis in a cornfield'), 1)
140
- assert t.collect()['img'][0].size == (1024, 1024)
141
- assert t.collect()['img_2'][0].size == (512, 512)
142
- assert t.collect()['img_3'][0].size == (1792, 1024)
143
-
144
- # This ensures that the test will be skipped, rather than returning an error, when no API key is
145
- # available (for example, when a PR runs in CI).
146
- @staticmethod
147
- def skip_test_if_no_openai_client() -> None:
148
- try:
149
- import pixeltable.functions.openai
150
- _ = pixeltable.functions.openai.openai_client()
151
- except excs.Error as exc:
152
- pytest.skip(str(exc))