pixeltable 0.1.0__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pixeltable might be problematic. Click here for more details.

Files changed (147) hide show
  1. pixeltable/__init__.py +34 -6
  2. pixeltable/catalog/__init__.py +13 -0
  3. pixeltable/catalog/catalog.py +159 -0
  4. pixeltable/catalog/column.py +200 -0
  5. pixeltable/catalog/dir.py +32 -0
  6. pixeltable/catalog/globals.py +33 -0
  7. pixeltable/catalog/insertable_table.py +191 -0
  8. pixeltable/catalog/named_function.py +36 -0
  9. pixeltable/catalog/path.py +58 -0
  10. pixeltable/catalog/path_dict.py +139 -0
  11. pixeltable/catalog/schema_object.py +39 -0
  12. pixeltable/catalog/table.py +581 -0
  13. pixeltable/catalog/table_version.py +749 -0
  14. pixeltable/catalog/table_version_path.py +133 -0
  15. pixeltable/catalog/view.py +203 -0
  16. pixeltable/client.py +590 -30
  17. pixeltable/dataframe.py +540 -349
  18. pixeltable/env.py +359 -45
  19. pixeltable/exceptions.py +12 -21
  20. pixeltable/exec/__init__.py +9 -0
  21. pixeltable/exec/aggregation_node.py +78 -0
  22. pixeltable/exec/cache_prefetch_node.py +116 -0
  23. pixeltable/exec/component_iteration_node.py +79 -0
  24. pixeltable/exec/data_row_batch.py +95 -0
  25. pixeltable/exec/exec_context.py +22 -0
  26. pixeltable/exec/exec_node.py +61 -0
  27. pixeltable/exec/expr_eval_node.py +217 -0
  28. pixeltable/exec/in_memory_data_node.py +69 -0
  29. pixeltable/exec/media_validation_node.py +43 -0
  30. pixeltable/exec/sql_scan_node.py +225 -0
  31. pixeltable/exprs/__init__.py +24 -0
  32. pixeltable/exprs/arithmetic_expr.py +102 -0
  33. pixeltable/exprs/array_slice.py +71 -0
  34. pixeltable/exprs/column_property_ref.py +77 -0
  35. pixeltable/exprs/column_ref.py +105 -0
  36. pixeltable/exprs/comparison.py +77 -0
  37. pixeltable/exprs/compound_predicate.py +98 -0
  38. pixeltable/exprs/data_row.py +195 -0
  39. pixeltable/exprs/expr.py +586 -0
  40. pixeltable/exprs/expr_set.py +39 -0
  41. pixeltable/exprs/function_call.py +380 -0
  42. pixeltable/exprs/globals.py +69 -0
  43. pixeltable/exprs/image_member_access.py +115 -0
  44. pixeltable/exprs/image_similarity_predicate.py +58 -0
  45. pixeltable/exprs/inline_array.py +107 -0
  46. pixeltable/exprs/inline_dict.py +101 -0
  47. pixeltable/exprs/is_null.py +38 -0
  48. pixeltable/exprs/json_mapper.py +121 -0
  49. pixeltable/exprs/json_path.py +159 -0
  50. pixeltable/exprs/literal.py +54 -0
  51. pixeltable/exprs/object_ref.py +41 -0
  52. pixeltable/exprs/predicate.py +44 -0
  53. pixeltable/exprs/row_builder.py +355 -0
  54. pixeltable/exprs/rowid_ref.py +94 -0
  55. pixeltable/exprs/type_cast.py +53 -0
  56. pixeltable/exprs/variable.py +45 -0
  57. pixeltable/func/__init__.py +9 -0
  58. pixeltable/func/aggregate_function.py +194 -0
  59. pixeltable/func/batched_function.py +53 -0
  60. pixeltable/func/callable_function.py +69 -0
  61. pixeltable/func/expr_template_function.py +82 -0
  62. pixeltable/func/function.py +110 -0
  63. pixeltable/func/function_registry.py +227 -0
  64. pixeltable/func/globals.py +36 -0
  65. pixeltable/func/nos_function.py +202 -0
  66. pixeltable/func/signature.py +166 -0
  67. pixeltable/func/udf.py +163 -0
  68. pixeltable/functions/__init__.py +52 -103
  69. pixeltable/functions/eval.py +216 -0
  70. pixeltable/functions/fireworks.py +34 -0
  71. pixeltable/functions/huggingface.py +120 -0
  72. pixeltable/functions/image.py +16 -0
  73. pixeltable/functions/openai.py +256 -0
  74. pixeltable/functions/pil/image.py +148 -7
  75. pixeltable/functions/string.py +13 -0
  76. pixeltable/functions/together.py +122 -0
  77. pixeltable/functions/util.py +41 -0
  78. pixeltable/functions/video.py +62 -0
  79. pixeltable/iterators/__init__.py +3 -0
  80. pixeltable/iterators/base.py +48 -0
  81. pixeltable/iterators/document.py +311 -0
  82. pixeltable/iterators/video.py +89 -0
  83. pixeltable/metadata/__init__.py +54 -0
  84. pixeltable/metadata/converters/convert_10.py +18 -0
  85. pixeltable/metadata/schema.py +211 -0
  86. pixeltable/plan.py +656 -0
  87. pixeltable/store.py +418 -182
  88. pixeltable/tests/conftest.py +146 -88
  89. pixeltable/tests/functions/test_fireworks.py +42 -0
  90. pixeltable/tests/functions/test_functions.py +60 -0
  91. pixeltable/tests/functions/test_huggingface.py +158 -0
  92. pixeltable/tests/functions/test_openai.py +152 -0
  93. pixeltable/tests/functions/test_together.py +111 -0
  94. pixeltable/tests/test_audio.py +65 -0
  95. pixeltable/tests/test_catalog.py +27 -0
  96. pixeltable/tests/test_client.py +14 -14
  97. pixeltable/tests/test_component_view.py +370 -0
  98. pixeltable/tests/test_dataframe.py +439 -0
  99. pixeltable/tests/test_dirs.py +78 -62
  100. pixeltable/tests/test_document.py +120 -0
  101. pixeltable/tests/test_exprs.py +592 -135
  102. pixeltable/tests/test_function.py +297 -67
  103. pixeltable/tests/test_migration.py +43 -0
  104. pixeltable/tests/test_nos.py +54 -0
  105. pixeltable/tests/test_snapshot.py +208 -0
  106. pixeltable/tests/test_table.py +1195 -263
  107. pixeltable/tests/test_transactional_directory.py +42 -0
  108. pixeltable/tests/test_types.py +5 -11
  109. pixeltable/tests/test_video.py +151 -34
  110. pixeltable/tests/test_view.py +530 -0
  111. pixeltable/tests/utils.py +320 -45
  112. pixeltable/tool/create_test_db_dump.py +149 -0
  113. pixeltable/tool/create_test_video.py +81 -0
  114. pixeltable/type_system.py +445 -124
  115. pixeltable/utils/__init__.py +17 -46
  116. pixeltable/utils/arrow.py +98 -0
  117. pixeltable/utils/clip.py +12 -15
  118. pixeltable/utils/coco.py +136 -0
  119. pixeltable/utils/documents.py +39 -0
  120. pixeltable/utils/filecache.py +195 -0
  121. pixeltable/utils/help.py +11 -0
  122. pixeltable/utils/hf_datasets.py +157 -0
  123. pixeltable/utils/media_store.py +76 -0
  124. pixeltable/utils/parquet.py +167 -0
  125. pixeltable/utils/pytorch.py +91 -0
  126. pixeltable/utils/s3.py +13 -0
  127. pixeltable/utils/sql.py +17 -0
  128. pixeltable/utils/transactional_directory.py +35 -0
  129. pixeltable-0.2.4.dist-info/LICENSE +18 -0
  130. pixeltable-0.2.4.dist-info/METADATA +127 -0
  131. pixeltable-0.2.4.dist-info/RECORD +132 -0
  132. {pixeltable-0.1.0.dist-info → pixeltable-0.2.4.dist-info}/WHEEL +1 -1
  133. pixeltable/catalog.py +0 -1421
  134. pixeltable/exprs.py +0 -1745
  135. pixeltable/function.py +0 -269
  136. pixeltable/functions/clip.py +0 -10
  137. pixeltable/functions/pil/__init__.py +0 -23
  138. pixeltable/functions/tf.py +0 -21
  139. pixeltable/index.py +0 -57
  140. pixeltable/tests/test_dict.py +0 -24
  141. pixeltable/tests/test_functions.py +0 -11
  142. pixeltable/tests/test_tf.py +0 -69
  143. pixeltable/tf.py +0 -33
  144. pixeltable/utils/tf.py +0 -33
  145. pixeltable/utils/video.py +0 -32
  146. pixeltable-0.1.0.dist-info/METADATA +0 -34
  147. pixeltable-0.1.0.dist-info/RECORD +0 -36
@@ -1,119 +1,177 @@
1
- import numpy as np
2
- import pandas as pd
3
- import datetime
1
+ import json
2
+ import logging
3
+ import os
4
+ import pathlib
5
+ from typing import List
4
6
 
7
+ import numpy as np
5
8
  import pytest
6
9
 
7
- import pixeltable as pt
10
+ import pixeltable as pxt
8
11
  import pixeltable.catalog as catalog
9
- from pixeltable.type_system import \
10
- StringType, IntType, FloatType, BoolType, TimestampType, ImageType, JsonType
11
- from pixeltable.tests.utils import read_data_file, make_tbl, create_table_data
12
+ from pixeltable import exprs
13
+ from pixeltable import functions as ptf
14
+ from pixeltable.exprs import RELATIVE_PATH_ROOT as R
15
+ from pixeltable.metadata import SystemInfo, create_system_info
16
+ from pixeltable.metadata.schema import TableSchemaVersion, TableVersion, Table, Function, Dir
17
+ from pixeltable.tests.utils import read_data_file, create_test_tbl, create_all_datatypes_tbl, skip_test_if_not_installed
18
+ from pixeltable.type_system import StringType, ImageType, FloatType
12
19
 
13
20
 
14
21
  @pytest.fixture(scope='session')
15
- def init_db(tmp_path_factory) -> None:
22
+ def init_env(tmp_path_factory) -> None:
16
23
  from pixeltable.env import Env
24
+ # set the relevant env vars for Client() to connect to the test db
25
+
26
+ shared_home = pathlib.Path(os.environ.get('PIXELTABLE_HOME', str(pathlib.Path.home() / '.pixeltable')))
27
+ home_dir = str(tmp_path_factory.mktemp('base') / '.pixeltable')
28
+ os.environ['PIXELTABLE_HOME'] = home_dir
29
+ os.environ['PIXELTABLE_CONFIG'] = str(shared_home / 'config.yaml')
30
+ test_db = 'test'
31
+ os.environ['PIXELTABLE_DB'] = test_db
32
+ os.environ['PIXELTABLE_PGDATA'] = str(shared_home / 'pgdata')
33
+
34
+ # ensure this home dir exists
35
+ shared_home.mkdir(parents=True, exist_ok=True)
17
36
  # this also runs create_all()
18
- db_name = 'test'
19
- Env.get().set_up(tmp_path_factory.mktemp('base'), db_name=db_name, echo=True)
37
+ Env.get().set_up(echo=True)
20
38
  yield
21
39
  # leave db in place for debugging purposes
22
40
 
41
+ @pytest.fixture(scope='function')
42
+ def test_client(init_env) -> pxt.Client:
43
+ # Clean the DB *before* instantiating a client object. This is because some tests
44
+ # (such as test_migration.py) may leave the DB in a broken state, from which the
45
+ # client is uninstantiable.
46
+ clean_db()
47
+ cl = pxt.Client(reload=True)
48
+ cl.logging(level=logging.DEBUG, to_stdout=True)
49
+ yield cl
50
+
51
+
52
+ def clean_db(restore_tables: bool = True) -> None:
53
+ from pixeltable.env import Env
54
+ # The logic from Client.reset_catalog() has been moved here, so that it
55
+ # does not rely on instantiating a Client object. As before, UUID-named data tables will
56
+ # not be cleaned. If in the future it is desirable to clean out data tables as well,
57
+ # the commented lines may be used to drop ALL tables from the test db.
58
+ # sql_md = declarative_base().metadata
59
+ # sql_md.reflect(Env.get().engine)
60
+ # sql_md.drop_all(bind=Env.get().engine)
61
+ engine = Env.get().engine
62
+ SystemInfo.__table__.drop(engine, checkfirst=True)
63
+ TableSchemaVersion.__table__.drop(engine, checkfirst=True)
64
+ TableVersion.__table__.drop(engine, checkfirst=True)
65
+ Table.__table__.drop(engine, checkfirst=True)
66
+ Function.__table__.drop(engine, checkfirst=True)
67
+ Dir.__table__.drop(engine, checkfirst=True)
68
+ if restore_tables:
69
+ Dir.__table__.create(engine)
70
+ Function.__table__.create(engine)
71
+ Table.__table__.create(engine)
72
+ TableVersion.__table__.create(engine)
73
+ TableSchemaVersion.__table__.create(engine)
74
+ SystemInfo.__table__.create(engine)
75
+ create_system_info(engine)
76
+
23
77
 
24
78
  @pytest.fixture(scope='function')
25
- def test_db(init_db: None) -> pt.Db:
26
- cl = pt.Client()
27
- db = cl.create_db(f'test')
28
- yield db
29
- cl.drop_db(db.name, force=True)
79
+ def test_tbl(test_client: pxt.Client) -> catalog.Table:
80
+ return create_test_tbl(test_client)
30
81
 
82
+ # @pytest.fixture(scope='function')
83
+ # def test_stored_fn(test_client: pxt.Client) -> pxt.Function:
84
+ # @pxt.udf(return_type=pxt.IntType(), param_types=[pxt.IntType()])
85
+ # def test_fn(x):
86
+ # return x + 1
87
+ # test_client.create_function('test_fn', test_fn)
88
+ # return test_fn
31
89
 
32
90
  @pytest.fixture(scope='function')
33
- def test_tbl(test_db: pt.Db) -> catalog.Table:
34
- cols = [
35
- catalog.Column('c1', StringType(), nullable=False),
36
- catalog.Column('c2', IntType(), nullable=False),
37
- catalog.Column('c3', FloatType(), nullable=False),
38
- catalog.Column('c4', BoolType(), nullable=False),
39
- catalog.Column('c5', TimestampType(), nullable=False),
40
- catalog.Column('c6', JsonType(), nullable=False),
41
- catalog.Column('c7', JsonType(), nullable=False),
91
+ def test_tbl_exprs(test_tbl: catalog.Table) -> List[exprs.Expr]:
92
+ #def test_tbl_exprs(test_tbl: catalog.Table, test_stored_fn: pxt.Function) -> List[exprs.Expr]:
93
+
94
+ t = test_tbl
95
+ return [
96
+ t.c1,
97
+ t.c7['*'].f1,
98
+ exprs.Literal('test'),
99
+ exprs.InlineDict({
100
+ 'a': t.c1, 'b': t.c6.f1, 'c': 17,
101
+ 'd': exprs.InlineDict({'e': t.c2}),
102
+ 'f': exprs.InlineArray((t.c3, t.c3))
103
+ }),
104
+ exprs.InlineArray([[t.c2, t.c2], [t.c2, t.c2]]),
105
+ t.c2 > 5,
106
+ t.c2 == None,
107
+ ~(t.c2 > 5),
108
+ (t.c2 > 5) & (t.c1 == 'test'),
109
+ (t.c2 > 5) | (t.c1 == 'test'),
110
+ t.c7['*'].f5 >> [R[3], R[2], R[1], R[0]],
111
+ t.c8[0, 1:],
112
+ t.c2.astype(FloatType()),
113
+ (t.c2 + 1).astype(FloatType()),
114
+ t.c2.apply(str),
115
+ (t.c2 + 1).apply(str),
116
+ t.c3.apply(str),
117
+ t.c4.apply(str),
118
+ t.c5.apply(str),
119
+ t.c6.apply(str),
120
+ t.c1.apply(json.loads),
121
+ t.c8.errortype,
122
+ t.c8.errormsg,
123
+ ptf.sum(t.c2, group_by=t.c4, order_by=t.c3),
124
+ #test_stored_fn(t.c2),
42
125
  ]
43
- t = test_db.create_table('test__tbl', cols)
44
-
45
- num_rows = 100
46
- d1 = {
47
- 'f1': 'test string 1',
48
- 'f2': 1,
49
- 'f3': 1.0,
50
- 'f4': True,
51
- 'f5': [1.0, 2.0, 3.0, 4.0],
52
- 'f6': {
53
- 'f7': 'test string 2',
54
- 'f8': [1.0, 2.0, 3.0, 4.0],
55
- },
56
- }
57
- d2 = [d1, d1]
58
-
59
- c1_data = [f'test string {i}' for i in range(num_rows)]
60
- c2_data = [i for i in range(num_rows)]
61
- c3_data = [float(i) for i in range(num_rows)]
62
- c4_data = [bool(i % 2) for i in range(num_rows)]
63
- c5_data = [datetime.datetime.now()] * num_rows
64
- c6_data = []
65
- for i in range(num_rows):
66
- d = {
67
- 'f1': f'test string {i}',
68
- 'f2': i,
69
- 'f3': float(i),
70
- 'f4': bool(i % 2),
71
- 'f5': [1.0, 2.0, 3.0, 4.0],
72
- 'f6': {
73
- 'f7': 'test string 2',
74
- 'f8': [1.0, 2.0, 3.0, 4.0],
75
- },
76
- }
77
- c6_data.append(d)
78
-
79
- c7_data = [d2] * num_rows
80
- data = {'c1': c1_data, 'c2': c2_data, 'c3': c3_data, 'c4': c4_data, 'c5': c5_data, 'c6': c6_data, 'c7': c7_data}
81
- pd_df = pd.DataFrame(data=data)
82
- t.insert_pandas(pd_df)
83
- return t
84
126
 
127
+ @pytest.fixture(scope='function')
128
+ def all_datatypes_tbl(test_client: pxt.Client) -> catalog.Table:
129
+ return create_all_datatypes_tbl(test_client)
85
130
 
86
131
  @pytest.fixture(scope='function')
87
- def img_tbl(test_db: pt.Db) -> catalog.Table:
88
- cols = [
89
- catalog.Column('img', ImageType(), nullable=False, indexed=False),
90
- catalog.Column('category', StringType(), nullable=False),
91
- catalog.Column('split', StringType(), nullable=False),
92
- ]
132
+ def img_tbl(test_client: pxt.Client) -> catalog.Table:
133
+ schema = {
134
+ 'img': ImageType(nullable=False),
135
+ 'category': StringType(nullable=False),
136
+ 'split': StringType(nullable=False),
137
+ }
93
138
  # this table is not indexed in order to avoid the cost of computing embeddings
94
- tbl = test_db.create_table('test_img_tbl', cols)
95
- df = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
96
- tbl.insert_pandas(df)
139
+ tbl = test_client.create_table('test_img_tbl', schema)
140
+ rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
141
+ tbl.insert(rows)
97
142
  return tbl
98
143
 
144
+ @pytest.fixture(scope='function')
145
+ def img_tbl_exprs(img_tbl: catalog.Table) -> List[exprs.Expr]:
146
+ img_t = img_tbl
147
+ return [
148
+ img_t.img.width,
149
+ img_t.img.rotate(90),
150
+ # we're using a list here, not a tuple; the latter turns into a list during the back/forth conversion
151
+ img_t.img.rotate(90).resize([224, 224]),
152
+ img_t.img.fileurl,
153
+ img_t.img.localpath,
154
+ ]
99
155
 
100
156
  # TODO: why does this not work with a session scope? (some user tables don't get created with create_all())
101
157
  #@pytest.fixture(scope='session')
102
- #def indexed_img_tbl(init_db: None) -> catalog.Table:
103
- # cl = pt.Client()
158
+ #def indexed_img_tbl(init_env: None) -> catalog.Table:
159
+ # cl = pxt.Client()
104
160
  # db = cl.create_db('test_indexed')
105
161
  @pytest.fixture(scope='function')
106
- def indexed_img_tbl(test_db: pt.Db) -> catalog.Table:
107
- db = test_db
108
- cols = [
109
- catalog.Column('img', ImageType(), nullable=False, indexed=True),
110
- catalog.Column('category', StringType(), nullable=False),
111
- catalog.Column('split', StringType(), nullable=False),
112
- ]
113
- tbl = db.create_table('test_indexed_img_tbl', cols)
114
- df = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
115
- # select rows randomly in the hope of getting a good sample of the available categories
162
+ def indexed_img_tbl(test_client: pxt.Client) -> catalog.Table:
163
+ skip_test_if_not_installed('nos')
164
+ cl = test_client
165
+ schema = {
166
+ 'img': { 'type': ImageType(nullable=False), 'indexed': True },
167
+ 'category': StringType(nullable=False),
168
+ 'split': StringType(nullable=False),
169
+ }
170
+ tbl = cl.create_table('test_indexed_img_tbl', schema)
171
+ rows = read_data_file('imagenette2-160', 'manifest.csv', ['img'])
172
+ # select rows randomly in the hope of getting a good sample of the available categories
116
173
  rng = np.random.default_rng(17)
117
- idxs = rng.choice(np.arange(len(df)), size=40, replace=False)
118
- tbl.insert_pandas(df.iloc[idxs])
174
+ idxs = rng.choice(np.arange(len(rows)), size=40, replace=False)
175
+ rows = [rows[i] for i in idxs]
176
+ tbl.insert(rows)
119
177
  return tbl
@@ -0,0 +1,42 @@
1
+ import pytest
2
+
3
+ import pixeltable as pxt
4
+ import pixeltable.exceptions as excs
5
+ from pixeltable.tests.utils import skip_test_if_not_installed, validate_update_status
6
+
7
+
8
+ class TestFireworks:
9
+
10
+ def test_fireworks(self, test_client: pxt.Client) -> None:
11
+ skip_test_if_not_installed('fireworks')
12
+ TestFireworks.skip_test_if_no_fireworks_client()
13
+ cl = test_client
14
+ t = cl.create_table('test_tbl', {'input': pxt.StringType()})
15
+ from pixeltable.functions.fireworks import chat_completions
16
+ messages = [{'role': 'user', 'content': t.input}]
17
+ t['output'] = chat_completions(
18
+ messages=messages,
19
+ model='accounts/fireworks/models/llama-v2-7b-chat'
20
+ )
21
+ t['output_2'] = chat_completions(
22
+ messages=messages,
23
+ model='accounts/fireworks/models/llama-v2-7b-chat',
24
+ max_tokens=300,
25
+ top_k=40,
26
+ top_p=0.9,
27
+ temperature=0.7
28
+ )
29
+ validate_update_status(t.insert(input="How's everything going today?"), 1)
30
+ results = t.collect()
31
+ assert len(results['output'][0]['choices'][0]['message']['content']) > 0
32
+ assert len(results['output_2'][0]['choices'][0]['message']['content']) > 0
33
+
34
+ # This ensures that the test will be skipped, rather than returning an error, when no API key is
35
+ # available (for example, when a PR runs in CI).
36
+ @staticmethod
37
+ def skip_test_if_no_fireworks_client() -> None:
38
+ try:
39
+ import pixeltable.functions.fireworks
40
+ _ = pixeltable.functions.fireworks.fireworks_client()
41
+ except excs.Error as exc:
42
+ pytest.skip(str(exc))
@@ -0,0 +1,60 @@
1
+ import pixeltable as pxt
2
+ from pixeltable import catalog
3
+ from pixeltable.functions.pil.image import blend
4
+ from pixeltable.iterators import FrameIterator
5
+ from pixeltable.tests.utils import get_video_files, skip_test_if_not_installed
6
+ from pixeltable.type_system import VideoType, StringType
7
+
8
+
9
+ class TestFunctions:
10
+ def test_pil(self, img_tbl: catalog.Table) -> None:
11
+ t = img_tbl
12
+ _ = t[t.img, t.img.rotate(90), blend(t.img, t.img.rotate(90), 0.5)].show()
13
+
14
+ def test_eval_detections(self, test_client: pxt.Client) -> None:
15
+ skip_test_if_not_installed('nos')
16
+ cl = test_client
17
+ video_t = cl.create_table('video_tbl', {'video': VideoType()})
18
+ # create frame view
19
+ args = {'video': video_t.video, 'fps': 1}
20
+ v = cl.create_view('test_view', video_t, iterator_class=FrameIterator, iterator_args=args)
21
+
22
+ files = get_video_files()
23
+ video_t.insert(video=files[-1])
24
+ v.add_column(frame_s=v.frame.resize([640, 480]))
25
+ from pixeltable.functions.nos.object_detection_2d import yolox_nano, yolox_small, yolox_large
26
+ v.add_column(detections_a=yolox_nano(v.frame_s))
27
+ v.add_column(detections_b=yolox_small(v.frame_s))
28
+ v.add_column(gt=yolox_large(v.frame_s))
29
+ from pixeltable.functions.eval import eval_detections, mean_ap
30
+ res = v.select(
31
+ eval_detections(
32
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels
33
+ )).show()
34
+ v.add_column(
35
+ eval_a=eval_detections(
36
+ v.detections_a.bboxes, v.detections_a.labels, v.detections_a.scores, v.gt.bboxes, v.gt.labels))
37
+ v.add_column(
38
+ eval_b=eval_detections(
39
+ v.detections_b.bboxes, v.detections_b.labels, v.detections_b.scores, v.gt.bboxes, v.gt.labels))
40
+ ap_a = v.select(mean_ap(v.eval_a)).show()[0, 0]
41
+ ap_b = v.select(mean_ap(v.eval_b)).show()[0, 0]
42
+ common_classes = set(ap_a.keys()) & set(ap_b.keys())
43
+
44
+ ## TODO: the following assertion is failing on CI.
45
+ # This is not necessarily a bug, as the assert condition is not expected to always be true
46
+ # for k in common_classes:
47
+ # assert ap_a[k] <= ap_b[k]
48
+
49
+ def test_str(self, test_client: pxt.Client) -> None:
50
+ cl = test_client
51
+ t = cl.create_table('test_tbl', {'input': StringType()})
52
+ from pixeltable.functions.string import str_format
53
+ t.add_column(s1=str_format('ABC {0}', t.input))
54
+ t.add_column(s2=str_format('DEF {this}', this=t.input))
55
+ t.add_column(s3=str_format('GHI {0} JKL {this}', t.input, this=t.input))
56
+ status = t.insert(input='MNO')
57
+ assert status.num_rows == 1
58
+ assert status.num_excs == 0
59
+ row = t.head()[0]
60
+ assert row == {'input': 'MNO', 's1': 'ABC MNO', 's2': 'DEF MNO', 's3': 'GHI MNO JKL MNO'}
@@ -0,0 +1,158 @@
1
+ from typing import Dict, Any
2
+
3
+ import pytest
4
+
5
+ import pixeltable as pxt
6
+ from pixeltable.tests.utils import skip_test_if_not_installed, get_sentences, get_image_files, \
7
+ SAMPLE_IMAGE_URL
8
+ from pixeltable.type_system import StringType, JsonType, ImageType, BoolType, FloatType, ArrayType
9
+
10
+
11
+ class TestHuggingface:
12
+
13
+ def test_hf_function(self, test_client: pxt.Client) -> None:
14
+ skip_test_if_not_installed('sentence_transformers')
15
+ cl = test_client
16
+ t = cl.create_table('test_tbl', {'input': StringType(), 'bool_col': BoolType()})
17
+ from pixeltable.functions.huggingface import sentence_transformer
18
+ model_id = 'intfloat/e5-large-v2'
19
+ t.add_column(e5=sentence_transformer(t.input, model_id=model_id))
20
+ sents = get_sentences()
21
+ status = t.insert({'input': s, 'bool_col': True} for s in sents)
22
+ assert status.num_rows == len(sents)
23
+ assert status.num_excs == 0
24
+
25
+ # verify handling of constant params
26
+ with pytest.raises(ValueError) as exc_info:
27
+ t.add_column(e5_2=sentence_transformer(t.input, model_id=t.input))
28
+ assert ': parameter model_id must be a constant value' in str(exc_info.value)
29
+ with pytest.raises(ValueError) as exc_info:
30
+ t.add_column(e5_2=sentence_transformer(t.input, model_id=model_id, normalize_embeddings=t.bool_col))
31
+ assert ': parameter normalize_embeddings must be a constant value' in str(exc_info.value)
32
+
33
+ # make sure this doesn't cause an exception
34
+ # TODO: is there some way to capture the output?
35
+ t.describe()
36
+
37
+ def test_sentence_transformer(self, test_client: pxt.Client) -> None:
38
+ skip_test_if_not_installed('sentence_transformers')
39
+ cl = test_client
40
+ t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
41
+ sents = get_sentences(10)
42
+ status = t.insert({'input': s, 'input_list': sents} for s in sents)
43
+ assert status.num_rows == len(sents)
44
+ assert status.num_excs == 0
45
+
46
+ # run multiple models one at a time in order to exercise batching
47
+ from pixeltable.functions.huggingface import sentence_transformer, sentence_transformer_list
48
+ model_ids = ['sentence-transformers/all-mpnet-base-v2', 'BAAI/bge-reranker-base']
49
+ num_dims = [768, 768]
50
+ for idx, model_id in enumerate(model_ids):
51
+ col_name = f'embed{idx}'
52
+ t[col_name] = sentence_transformer(t.input, model_id=model_id, normalize_embeddings=True)
53
+ assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
54
+ list_col_name = f'embed_list{idx}'
55
+ t[list_col_name] = sentence_transformer_list(t.input_list, model_id=model_id, normalize_embeddings=True)
56
+ assert t.column_types()[list_col_name] == JsonType()
57
+
58
+ def verify_row(row: Dict[str, Any]) -> None:
59
+ for idx, (_, d) in enumerate(zip(model_ids, num_dims)):
60
+ assert row[f'embed{idx}'].shape == (d,)
61
+ assert len(row[f'embed_list{idx}']) == len(sents)
62
+ assert all(len(v) == d for v in row[f'embed_list{idx}'])
63
+
64
+ verify_row(t.tail(1)[0])
65
+
66
+ # execution still works after reload
67
+ cl = pxt.Client(reload=True)
68
+ t = cl.get_table('test_tbl')
69
+ status = t.insert({'input': s, 'input_list': sents} for s in sents)
70
+ assert status.num_rows == len(sents)
71
+ assert status.num_excs == 0
72
+ verify_row(t.tail(1)[0])
73
+
74
+ def test_cross_encoder(self, test_client: pxt.Client) -> None:
75
+ skip_test_if_not_installed('sentence_transformers')
76
+ cl = test_client
77
+ t = cl.create_table('test_tbl', {'input': StringType(), 'input_list': JsonType()})
78
+ sents = get_sentences(10)
79
+ status = t.insert({'input': s, 'input_list': sents} for s in sents)
80
+ assert status.num_rows == len(sents)
81
+ assert status.num_excs == 0
82
+
83
+ # run multiple models one at a time in order to exercise batching
84
+ from pixeltable.functions.huggingface import cross_encoder, cross_encoder_list
85
+ model_ids = ['cross-encoder/ms-marco-MiniLM-L-6-v2', 'cross-encoder/ms-marco-TinyBERT-L-2-v2']
86
+ for idx, model_id in enumerate(model_ids):
87
+ col_name = f'embed{idx}'
88
+ t[col_name] = cross_encoder(t.input, t.input, model_id=model_id)
89
+ assert t.column_types()[col_name] == FloatType()
90
+ list_col_name = f'embed_list{idx}'
91
+ t[list_col_name] = cross_encoder_list(t.input, t.input_list, model_id=model_id)
92
+ assert t.column_types()[list_col_name] == JsonType()
93
+
94
+ def verify_row(row: Dict[str, Any]) -> None:
95
+ for i in range(len(model_ids)):
96
+ assert len(row[f'embed_list{idx}']) == len(sents)
97
+ assert all(isinstance(v, float) for v in row[f'embed_list{idx}'])
98
+
99
+ verify_row(t.tail(1)[0])
100
+
101
+ # execution still works after reload
102
+ cl = pxt.Client(reload=True)
103
+ t = cl.get_table('test_tbl')
104
+ status = t.insert({'input': s, 'input_list': sents} for s in sents)
105
+ assert status.num_rows == len(sents)
106
+ assert status.num_excs == 0
107
+ verify_row(t.tail(1)[0])
108
+
109
+ def test_clip(self, test_client: pxt.Client) -> None:
110
+ skip_test_if_not_installed('transformers')
111
+ cl = test_client
112
+ t = cl.create_table('test_tbl', {'text': StringType(), 'img': ImageType()})
113
+ num_rows = 10
114
+ sents = get_sentences(num_rows)
115
+ imgs = get_image_files()[:num_rows]
116
+ status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
117
+ assert status.num_rows == len(sents)
118
+ assert status.num_excs == 0
119
+
120
+ # run multiple models one at a time in order to exercise batching
121
+ from pixeltable.functions.huggingface import clip_text, clip_image
122
+ model_ids = ['openai/clip-vit-base-patch32', 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K']
123
+ for idx, model_id in enumerate(model_ids):
124
+ col_name = f'embed_text{idx}'
125
+ t[col_name] = clip_text(t.text, model_id=model_id)
126
+ assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
127
+ col_name = f'embed_img{idx}'
128
+ t[col_name] = clip_image(t.img, model_id=model_id)
129
+ assert t.column_types()[col_name] == ArrayType((None,), dtype=FloatType(), nullable=False)
130
+
131
+ def verify_row(row: Dict[str, Any]) -> None:
132
+ for idx, _ in enumerate(model_ids):
133
+ assert row[f'embed_text{idx}'].shape == (512,)
134
+ assert row[f'embed_img{idx}'].shape == (512,)
135
+
136
+ verify_row(t.tail(1)[0])
137
+
138
+ # execution still works after reload
139
+ cl = pxt.Client(reload=True)
140
+ t = cl.get_table('test_tbl')
141
+ status = t.insert({'text': text, 'img': img} for text, img in zip(sents, imgs))
142
+ assert status.num_rows == len(sents)
143
+ assert status.num_excs == 0
144
+ verify_row(t.tail(1)[0])
145
+
146
+ def test_detr_for_object_detection(self, test_client: pxt.Client) -> None:
147
+ skip_test_if_not_installed('transformers')
148
+ cl = test_client
149
+ t = cl.create_table('test_tbl', {'img': ImageType()})
150
+ from pixeltable.functions.huggingface import detr_for_object_detection
151
+ t['detect'] = detr_for_object_detection(t.img, model_id='facebook/detr-resnet-50', threshold=0.8)
152
+ status = t.insert(img=SAMPLE_IMAGE_URL)
153
+ assert status.num_rows == 1
154
+ assert status.num_excs == 0
155
+ result = t.select(t.detect).collect()[0]['detect']
156
+ assert 'orange' in result['label_text']
157
+ assert 'bowl' in result['label_text']
158
+ assert 'broccoli' in result['label_text']