pixeltable 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic. Click here for more details.
- pixeltable/__init__.py +1 -1
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/catalog.py +8 -7
- pixeltable/catalog/column.py +11 -8
- pixeltable/catalog/insertable_table.py +1 -1
- pixeltable/catalog/path_dict.py +8 -6
- pixeltable/catalog/table.py +20 -14
- pixeltable/catalog/table_version.py +92 -55
- pixeltable/catalog/table_version_path.py +7 -9
- pixeltable/catalog/view.py +3 -2
- pixeltable/dataframe.py +2 -2
- pixeltable/env.py +205 -86
- pixeltable/exceptions.py +5 -1
- pixeltable/exec/aggregation_node.py +2 -1
- pixeltable/exec/component_iteration_node.py +2 -2
- pixeltable/exec/sql_node.py +11 -8
- pixeltable/exprs/__init__.py +2 -2
- pixeltable/exprs/arithmetic_expr.py +4 -4
- pixeltable/exprs/array_slice.py +2 -1
- pixeltable/exprs/column_property_ref.py +9 -7
- pixeltable/exprs/column_ref.py +2 -1
- pixeltable/exprs/comparison.py +10 -7
- pixeltable/exprs/compound_predicate.py +3 -2
- pixeltable/exprs/data_row.py +19 -4
- pixeltable/exprs/expr.py +51 -41
- pixeltable/exprs/expr_set.py +32 -9
- pixeltable/exprs/function_call.py +62 -40
- pixeltable/exprs/in_predicate.py +3 -2
- pixeltable/exprs/inline_expr.py +200 -0
- pixeltable/exprs/is_null.py +3 -2
- pixeltable/exprs/json_mapper.py +5 -4
- pixeltable/exprs/json_path.py +7 -1
- pixeltable/exprs/literal.py +34 -7
- pixeltable/exprs/method_ref.py +3 -3
- pixeltable/exprs/object_ref.py +6 -5
- pixeltable/exprs/row_builder.py +25 -17
- pixeltable/exprs/rowid_ref.py +2 -1
- pixeltable/exprs/similarity_expr.py +2 -1
- pixeltable/exprs/sql_element_cache.py +30 -0
- pixeltable/exprs/type_cast.py +3 -3
- pixeltable/exprs/variable.py +2 -1
- pixeltable/ext/functions/whisperx.py +6 -4
- pixeltable/ext/functions/yolox.py +11 -9
- pixeltable/func/aggregate_function.py +1 -0
- pixeltable/func/function.py +28 -4
- pixeltable/functions/__init__.py +4 -2
- pixeltable/functions/anthropic.py +15 -5
- pixeltable/functions/fireworks.py +1 -1
- pixeltable/functions/globals.py +6 -1
- pixeltable/functions/huggingface.py +91 -14
- pixeltable/functions/image.py +20 -5
- pixeltable/functions/json.py +5 -5
- pixeltable/functions/mistralai.py +188 -0
- pixeltable/functions/openai.py +6 -10
- pixeltable/functions/string.py +3 -2
- pixeltable/functions/timestamp.py +95 -7
- pixeltable/functions/together.py +18 -11
- pixeltable/functions/video.py +2 -2
- pixeltable/functions/vision.py +69 -37
- pixeltable/functions/whisper.py +4 -1
- pixeltable/globals.py +5 -1
- pixeltable/io/hf_datasets.py +17 -15
- pixeltable/io/pandas.py +0 -2
- pixeltable/io/parquet.py +15 -14
- pixeltable/iterators/document.py +16 -15
- pixeltable/metadata/__init__.py +1 -1
- pixeltable/metadata/converters/convert_18.py +1 -1
- pixeltable/metadata/converters/convert_19.py +46 -0
- pixeltable/metadata/converters/convert_20.py +56 -0
- pixeltable/metadata/converters/util.py +29 -4
- pixeltable/metadata/notes.py +2 -0
- pixeltable/metadata/schema.py +5 -4
- pixeltable/plan.py +100 -78
- pixeltable/store.py +5 -1
- pixeltable/tool/create_test_db_dump.py +18 -6
- pixeltable/type_system.py +15 -15
- pixeltable/utils/documents.py +45 -42
- pixeltable/utils/formatter.py +2 -2
- pixeltable-0.2.19.dist-info/LICENSE +201 -0
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/METADATA +84 -24
- pixeltable-0.2.19.dist-info/RECORD +147 -0
- pixeltable/exprs/inline_array.py +0 -116
- pixeltable/exprs/inline_dict.py +0 -103
- pixeltable-0.2.17.dist-info/LICENSE +0 -18
- pixeltable-0.2.17.dist-info/RECORD +0 -144
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.17.dist-info → pixeltable-0.2.19.dist-info}/entry_points.txt +0 -0
pixeltable/functions/openai.py
CHANGED
|
@@ -9,10 +9,10 @@ import base64
|
|
|
9
9
|
import io
|
|
10
10
|
import pathlib
|
|
11
11
|
import uuid
|
|
12
|
-
from typing import
|
|
12
|
+
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
|
|
13
13
|
|
|
14
|
-
import PIL.Image
|
|
15
14
|
import numpy as np
|
|
15
|
+
import PIL.Image
|
|
16
16
|
import tenacity
|
|
17
17
|
|
|
18
18
|
import pixeltable as pxt
|
|
@@ -23,13 +23,11 @@ from pixeltable.utils.code import local_public_names
|
|
|
23
23
|
|
|
24
24
|
if TYPE_CHECKING:
|
|
25
25
|
import openai
|
|
26
|
-
from openai._types import NotGiven
|
|
27
26
|
|
|
28
27
|
|
|
29
28
|
@env.register_client('openai')
|
|
30
29
|
def _(api_key: str) -> 'openai.OpenAI':
|
|
31
30
|
import openai
|
|
32
|
-
|
|
33
31
|
return openai.OpenAI(api_key=api_key)
|
|
34
32
|
|
|
35
33
|
|
|
@@ -42,10 +40,9 @@ def _openai_client() -> 'openai.OpenAI':
|
|
|
42
40
|
# by OpenAI. Should we investigate making this more customizable in the future?
|
|
43
41
|
def _retry(fn: Callable) -> Callable:
|
|
44
42
|
import openai
|
|
45
|
-
|
|
46
43
|
return tenacity.retry(
|
|
47
44
|
retry=tenacity.retry_if_exception_type(openai.RateLimitError),
|
|
48
|
-
wait=tenacity.wait_random_exponential(multiplier=
|
|
45
|
+
wait=tenacity.wait_random_exponential(multiplier=1, max=60),
|
|
49
46
|
stop=tenacity.stop_after_attempt(20),
|
|
50
47
|
)(fn)
|
|
51
48
|
|
|
@@ -462,10 +459,9 @@ def moderations(input: str, *, model: Optional[str] = None) -> dict:
|
|
|
462
459
|
_T = TypeVar('_T')
|
|
463
460
|
|
|
464
461
|
|
|
465
|
-
def _opt(arg: _T) -> Union[_T, 'NotGiven']:
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
return arg if arg is not None else NOT_GIVEN
|
|
462
|
+
def _opt(arg: _T) -> Union[_T, 'openai.NotGiven']:
|
|
463
|
+
import openai
|
|
464
|
+
return arg if arg is not None else openai.NOT_GIVEN
|
|
469
465
|
|
|
470
466
|
|
|
471
467
|
__all__ = local_public_names(__name__)
|
pixeltable/functions/string.py
CHANGED
|
@@ -14,6 +14,7 @@ t.select(pxt_str.capitalize(t.str_col)).collect()
|
|
|
14
14
|
|
|
15
15
|
from typing import Any, Optional
|
|
16
16
|
|
|
17
|
+
import pixeltable.exceptions as excs
|
|
17
18
|
import pixeltable.func as func
|
|
18
19
|
from pixeltable.utils.code import local_public_names
|
|
19
20
|
|
|
@@ -352,7 +353,7 @@ def normalize(self: str, form: str) -> str:
|
|
|
352
353
|
form: Unicode normal form (`‘NFC’`, `‘NFKC’`, `‘NFD’`, `‘NFKD’`)
|
|
353
354
|
"""
|
|
354
355
|
import unicodedata
|
|
355
|
-
return unicodedata.normalize(form, self)
|
|
356
|
+
return unicodedata.normalize(form, self) # type: ignore[arg-type]
|
|
356
357
|
|
|
357
358
|
@func.udf(is_method=True)
|
|
358
359
|
def pad(self: str, width: int, side: str = 'left', fillchar: str = ' ') -> str:
|
|
@@ -579,7 +580,7 @@ def upper(self: str) -> str:
|
|
|
579
580
|
return self.upper()
|
|
580
581
|
|
|
581
582
|
@func.udf(is_method=True)
|
|
582
|
-
def wrap(self: str, width: int, **kwargs: Any) ->
|
|
583
|
+
def wrap(self: str, width: int, **kwargs: Any) -> list[str]:
|
|
583
584
|
"""
|
|
584
585
|
Wraps the single paragraph in string so every line is at most `width` characters long.
|
|
585
586
|
Returns a list of output lines, without final newlines.
|
pixeltable/functions/timestamp.py
CHANGED
|
@@ -13,11 +13,14 @@ t.select(t.timestamp_col.year, t.timestamp_col.weekday()).collect()
|
|
|
13
13
|
from datetime import datetime
|
|
14
14
|
from typing import Optional
|
|
15
15
|
|
|
16
|
+
import sqlalchemy as sql
|
|
17
|
+
|
|
18
|
+
from pixeltable.env import Env
|
|
16
19
|
import pixeltable.func as func
|
|
17
20
|
from pixeltable.utils.code import local_public_names
|
|
18
21
|
|
|
19
22
|
|
|
20
|
-
@func.udf(
|
|
23
|
+
@func.udf(is_property=True)
|
|
21
24
|
def year(self: datetime) -> int:
|
|
22
25
|
"""
|
|
23
26
|
Between [`MINYEAR`](https://docs.python.org/3/library/datetime.html#datetime.MINYEAR) and
|
|
@@ -28,7 +31,12 @@ def year(self: datetime) -> int:
|
|
|
28
31
|
return self.year
|
|
29
32
|
|
|
30
33
|
|
|
31
|
-
@
|
|
34
|
+
@year.to_sql
|
|
35
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
36
|
+
return sql.extract('year', self)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@func.udf(is_property=True)
|
|
32
40
|
def month(self: datetime) -> int:
|
|
33
41
|
"""
|
|
34
42
|
Between 1 and 12 inclusive.
|
|
@@ -38,7 +46,12 @@ def month(self: datetime) -> int:
|
|
|
38
46
|
return self.month
|
|
39
47
|
|
|
40
48
|
|
|
41
|
-
@
|
|
49
|
+
@month.to_sql
|
|
50
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
51
|
+
return sql.extract('month', self)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@func.udf(is_property=True)
|
|
42
55
|
def day(self: datetime) -> int:
|
|
43
56
|
"""
|
|
44
57
|
Between 1 and the number of days in the given month of the given year.
|
|
@@ -48,7 +61,12 @@ def day(self: datetime) -> int:
|
|
|
48
61
|
return self.day
|
|
49
62
|
|
|
50
63
|
|
|
51
|
-
@
|
|
64
|
+
@day.to_sql
|
|
65
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
66
|
+
return sql.extract('day', self)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@func.udf(is_property=True)
|
|
52
70
|
def hour(self: datetime) -> int:
|
|
53
71
|
"""
|
|
54
72
|
Between 0 and 23 inclusive.
|
|
@@ -58,7 +76,12 @@ def hour(self: datetime) -> int:
|
|
|
58
76
|
return self.hour
|
|
59
77
|
|
|
60
78
|
|
|
61
|
-
@
|
|
79
|
+
@hour.to_sql
|
|
80
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
81
|
+
return sql.extract('hour', self)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@func.udf(is_property=True)
|
|
62
85
|
def minute(self: datetime) -> int:
|
|
63
86
|
"""
|
|
64
87
|
Between 0 and 59 inclusive.
|
|
@@ -68,7 +91,12 @@ def minute(self: datetime) -> int:
|
|
|
68
91
|
return self.minute
|
|
69
92
|
|
|
70
93
|
|
|
71
|
-
@
|
|
94
|
+
@minute.to_sql
|
|
95
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
96
|
+
return sql.extract('minute', self)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@func.udf(is_property=True)
|
|
72
100
|
def second(self: datetime) -> int:
|
|
73
101
|
"""
|
|
74
102
|
Between 0 and 59 inclusive.
|
|
@@ -78,7 +106,12 @@ def second(self: datetime) -> int:
|
|
|
78
106
|
return self.second
|
|
79
107
|
|
|
80
108
|
|
|
81
|
-
@
|
|
109
|
+
@second.to_sql
|
|
110
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
111
|
+
return sql.extract('second', self)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@func.udf(is_property=True)
|
|
82
115
|
def microsecond(self: datetime) -> int:
|
|
83
116
|
"""
|
|
84
117
|
Between 0 and 999999 inclusive.
|
|
@@ -88,6 +121,24 @@ def microsecond(self: datetime) -> int:
|
|
|
88
121
|
return self.microsecond
|
|
89
122
|
|
|
90
123
|
|
|
124
|
+
@microsecond.to_sql
|
|
125
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
126
|
+
return sql.extract('microseconds', self) - sql.extract('second', self) * 1000000
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@func.udf(is_method=True)
|
|
130
|
+
def astimezone(self: datetime, tz: str) -> datetime:
|
|
131
|
+
"""
|
|
132
|
+
Convert the datetime to the given time zone.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
tz: The time zone to convert to. Must be a valid time zone name from the IANA Time Zone Database.
|
|
136
|
+
"""
|
|
137
|
+
from zoneinfo import ZoneInfo
|
|
138
|
+
tzinfo = ZoneInfo(tz)
|
|
139
|
+
return self.astimezone(tzinfo)
|
|
140
|
+
|
|
141
|
+
|
|
91
142
|
@func.udf(is_method=True)
|
|
92
143
|
def weekday(self: datetime) -> int:
|
|
93
144
|
"""
|
|
@@ -97,6 +148,12 @@ def weekday(self: datetime) -> int:
|
|
|
97
148
|
"""
|
|
98
149
|
return self.weekday()
|
|
99
150
|
|
|
151
|
+
|
|
152
|
+
@weekday.to_sql
|
|
153
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
154
|
+
return sql.extract('isodow', self) - 1
|
|
155
|
+
|
|
156
|
+
|
|
100
157
|
@func.udf(is_method=True)
|
|
101
158
|
def isoweekday(self: datetime) -> int:
|
|
102
159
|
"""
|
|
@@ -107,6 +164,11 @@ def isoweekday(self: datetime) -> int:
|
|
|
107
164
|
return self.isoweekday()
|
|
108
165
|
|
|
109
166
|
|
|
167
|
+
@isoweekday.to_sql
|
|
168
|
+
def _(self: sql.ColumnElement) -> sql.ColumnElement:
|
|
169
|
+
return sql.extract('isodow', self)
|
|
170
|
+
|
|
171
|
+
|
|
110
172
|
@func.udf(is_method=True)
|
|
111
173
|
def isocalendar(self: datetime) -> dict:
|
|
112
174
|
"""
|
|
@@ -146,6 +208,32 @@ def strftime(self: datetime, format: str) -> str:
|
|
|
146
208
|
return self.strftime(format)
|
|
147
209
|
|
|
148
210
|
|
|
211
|
+
@func.udf(is_method=True)
|
|
212
|
+
def make_timestamp(
|
|
213
|
+
year: int, month: int, day: int, hour: int = 0, minute: int = 0, second: int = 0, microsecond: int = 0
|
|
214
|
+
) -> datetime:
|
|
215
|
+
"""
|
|
216
|
+
Create a timestamp.
|
|
217
|
+
|
|
218
|
+
Equivalent to [`datetime()`](https://docs.python.org/3/library/datetime.html#datetime.datetime).
|
|
219
|
+
"""
|
|
220
|
+
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=Env.get().default_time_zone)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@make_timestamp.to_sql
|
|
224
|
+
def _(
|
|
225
|
+
year: sql.ColumnElement, month: sql.ColumnElement, day: sql.ColumnElement,
|
|
226
|
+
hour: sql.ColumnElement = sql.literal(0), minute: sql.ColumnElement = sql.literal(0),
|
|
227
|
+
second: sql.ColumnElement = sql.literal(0), microsecond: sql.ColumnElement = sql.literal(0)
|
|
228
|
+
) -> sql.ColumnElement:
|
|
229
|
+
return sql.func.make_timestamptz(
|
|
230
|
+
sql.cast(year, sql.Integer),
|
|
231
|
+
sql.cast(month, sql.Integer),
|
|
232
|
+
sql.cast(day, sql.Integer),
|
|
233
|
+
sql.cast(hour, sql.Integer),
|
|
234
|
+
sql.cast(minute, sql.Integer),
|
|
235
|
+
sql.cast(second + microsecond / 1000000.0, sql.Double))
|
|
236
|
+
|
|
149
237
|
# @func.udf
|
|
150
238
|
# def date(self: datetime) -> datetime:
|
|
151
239
|
# """
|
pixeltable/functions/together.py
CHANGED
|
@@ -6,25 +6,25 @@ the [Working with Together AI](https://pixeltable.readme.io/docs/together-ai) tu
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import base64
|
|
9
|
-
|
|
9
|
+
import io
|
|
10
|
+
from typing import TYPE_CHECKING, Callable, Optional
|
|
10
11
|
|
|
11
|
-
import PIL.Image
|
|
12
12
|
import numpy as np
|
|
13
|
+
import PIL.Image
|
|
14
|
+
import tenacity
|
|
13
15
|
|
|
14
|
-
import io
|
|
15
16
|
import pixeltable as pxt
|
|
16
17
|
from pixeltable import env
|
|
17
18
|
from pixeltable.func import Batch
|
|
18
19
|
from pixeltable.utils.code import local_public_names
|
|
19
20
|
|
|
20
21
|
if TYPE_CHECKING:
|
|
21
|
-
import together
|
|
22
|
+
import together # type: ignore[import-untyped]
|
|
22
23
|
|
|
23
24
|
|
|
24
25
|
@env.register_client('together')
|
|
25
26
|
def _(api_key: str) -> 'together.Together':
|
|
26
27
|
import together
|
|
27
|
-
|
|
28
28
|
return together.Together(api_key=api_key)
|
|
29
29
|
|
|
30
30
|
|
|
@@ -32,6 +32,15 @@ def _together_client() -> 'together.Together':
|
|
|
32
32
|
return env.Env.get().get_client('together')
|
|
33
33
|
|
|
34
34
|
|
|
35
|
+
def _retry(fn: Callable) -> Callable:
|
|
36
|
+
import together
|
|
37
|
+
return tenacity.retry(
|
|
38
|
+
retry=tenacity.retry_if_exception_type(together.error.RateLimitError),
|
|
39
|
+
wait=tenacity.wait_random_exponential(multiplier=1, max=60),
|
|
40
|
+
stop=tenacity.stop_after_attempt(20),
|
|
41
|
+
)(fn)
|
|
42
|
+
|
|
43
|
+
|
|
35
44
|
@pxt.udf
|
|
36
45
|
def completions(
|
|
37
46
|
prompt: str,
|
|
@@ -74,8 +83,7 @@ def completions(
|
|
|
74
83
|
>>> tbl['response'] = completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1')
|
|
75
84
|
"""
|
|
76
85
|
return (
|
|
77
|
-
_together_client()
|
|
78
|
-
.completions.create(
|
|
86
|
+
_retry(_together_client().completions.create)(
|
|
79
87
|
prompt=prompt,
|
|
80
88
|
model=model,
|
|
81
89
|
max_tokens=max_tokens,
|
|
@@ -139,8 +147,7 @@ def chat_completions(
|
|
|
139
147
|
... tbl['response'] = chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1')
|
|
140
148
|
"""
|
|
141
149
|
return (
|
|
142
|
-
_together_client()
|
|
143
|
-
.chat.completions.create(
|
|
150
|
+
_retry(_together_client().chat.completions.create)(
|
|
144
151
|
messages=messages,
|
|
145
152
|
model=model,
|
|
146
153
|
max_tokens=max_tokens,
|
|
@@ -198,7 +205,7 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[np.ndarray]:
|
|
|
198
205
|
|
|
199
206
|
>>> tbl['response'] = embeddings(tbl.text, model='togethercomputer/m2-bert-80M-8k-retrieval')
|
|
200
207
|
"""
|
|
201
|
-
result = _together_client().embeddings.create(input=input, model=model)
|
|
208
|
+
result = _retry(_together_client().embeddings.create)(input=input, model=model)
|
|
202
209
|
return [np.array(data.embedding, dtype=np.float64) for data in result.data]
|
|
203
210
|
|
|
204
211
|
|
|
@@ -248,7 +255,7 @@ def image_generations(
|
|
|
248
255
|
>>> tbl['response'] = image_generations(tbl.prompt, model='runwayml/stable-diffusion-v1-5')
|
|
249
256
|
"""
|
|
250
257
|
# TODO(aaron-siegel): Decompose CPU/GPU ops into separate functions
|
|
251
|
-
result = _together_client().images.generate(
|
|
258
|
+
result = _retry(_together_client().images.generate)(
|
|
252
259
|
prompt=prompt, model=model, steps=steps, seed=seed, height=height, width=width, negative_prompt=negative_prompt
|
|
253
260
|
)
|
|
254
261
|
b64_str = result.data[0].b64_json
|
pixeltable/functions/video.py
CHANGED
|
@@ -16,9 +16,9 @@ import uuid
|
|
|
16
16
|
from pathlib import Path
|
|
17
17
|
from typing import Optional
|
|
18
18
|
|
|
19
|
-
import
|
|
20
|
-
import av
|
|
19
|
+
import av # type: ignore[import-untyped]
|
|
21
20
|
import numpy as np
|
|
21
|
+
import PIL.Image
|
|
22
22
|
|
|
23
23
|
import pixeltable.env as env
|
|
24
24
|
import pixeltable.func as func
|
pixeltable/functions/vision.py
CHANGED
|
@@ -13,22 +13,16 @@ t.select(pxtv.draw_bounding_boxes(t.img, boxes=t.boxes, label=t.labels)).collect
|
|
|
13
13
|
|
|
14
14
|
import colorsys
|
|
15
15
|
import hashlib
|
|
16
|
-
import random
|
|
17
16
|
from collections import defaultdict
|
|
18
|
-
from typing import Optional, Union
|
|
17
|
+
from typing import Any, Optional, Union
|
|
19
18
|
|
|
20
|
-
import PIL.Image
|
|
21
|
-
import PIL.Image
|
|
22
19
|
import numpy as np
|
|
20
|
+
import PIL.Image
|
|
23
21
|
|
|
24
|
-
import pixeltable
|
|
25
|
-
import pixeltable.type_system as ts
|
|
22
|
+
import pixeltable as pxt
|
|
26
23
|
from pixeltable.utils.code import local_public_names
|
|
27
24
|
|
|
28
25
|
|
|
29
|
-
# TODO: figure out a better submodule structure
|
|
30
|
-
|
|
31
|
-
|
|
32
26
|
# the following function has been adapted from MMEval
|
|
33
27
|
# (sources at https://github.com/open-mmlab/mmeval)
|
|
34
28
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@@ -164,25 +158,41 @@ def __calculate_image_tpfp(
|
|
|
164
158
|
return tp, fp
|
|
165
159
|
|
|
166
160
|
|
|
167
|
-
@
|
|
168
|
-
return_type=ts.JsonType(nullable=False),
|
|
169
|
-
param_types=[
|
|
170
|
-
ts.JsonType(nullable=False),
|
|
171
|
-
ts.JsonType(nullable=False),
|
|
172
|
-
ts.JsonType(nullable=False),
|
|
173
|
-
ts.JsonType(nullable=False),
|
|
174
|
-
ts.JsonType(nullable=False),
|
|
175
|
-
],
|
|
176
|
-
)
|
|
161
|
+
@pxt.udf
|
|
177
162
|
def eval_detections(
|
|
178
163
|
pred_bboxes: list[list[int]],
|
|
179
164
|
pred_labels: list[int],
|
|
180
165
|
pred_scores: list[float],
|
|
181
166
|
gt_bboxes: list[list[int]],
|
|
182
167
|
gt_labels: list[int],
|
|
183
|
-
|
|
168
|
+
min_iou: float = 0.5,
|
|
169
|
+
) -> list[dict]:
|
|
184
170
|
"""
|
|
185
171
|
Evaluates the performance of a set of predicted bounding boxes against a set of ground truth bounding boxes.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
pred_bboxes: List of predicted bounding boxes, each represented as [xmin, ymin, xmax, ymax].
|
|
175
|
+
pred_labels: List of predicted labels.
|
|
176
|
+
pred_scores: List of predicted scores.
|
|
177
|
+
gt_bboxes: List of ground truth bounding boxes, each represented as [xmin, ymin, xmax, ymax].
|
|
178
|
+
gt_labels: List of ground truth labels.
|
|
179
|
+
min_iou: Minimum intersection-over-union (IoU) threshold for a predicted bounding box to be
|
|
180
|
+
considered a true positive.
|
|
181
|
+
|
|
182
|
+
Returns:
|
|
183
|
+
A list of dictionaries, one per label class, with the following structure:
|
|
184
|
+
```python
|
|
185
|
+
{
|
|
186
|
+
'min_iou': float, # The value of `min_iou` used for the detections
|
|
187
|
+
'class': int, # The label class
|
|
188
|
+
'tp': list[int], # List of 1's and 0's indicating true positives for each
|
|
189
|
+
# predicted bounding box of this class
|
|
190
|
+
'fp': list[int], # List of 1's and 0's indicating false positives for each
|
|
191
|
+
# predicted bounding box of this class; `fp[n] == 1 - tp[n]`
|
|
192
|
+
'scores': list[float], # List of predicted scores for each bounding box of this class
|
|
193
|
+
'num_gts': int, # Number of ground truth bounding boxes of this class
|
|
194
|
+
}
|
|
195
|
+
```
|
|
186
196
|
"""
|
|
187
197
|
class_idxs = list(set(pred_labels + gt_labels))
|
|
188
198
|
result: list[dict] = []
|
|
@@ -195,11 +205,11 @@ def eval_detections(
|
|
|
195
205
|
pred_filter = pred_classes_arr == class_idx
|
|
196
206
|
gt_filter = gt_classes_arr == class_idx
|
|
197
207
|
class_pred_scores = pred_scores_arr[pred_filter]
|
|
198
|
-
tp, fp = __calculate_image_tpfp(pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter],
|
|
208
|
+
tp, fp = __calculate_image_tpfp(pred_bboxes_arr[pred_filter], class_pred_scores, gt_bboxes_arr[gt_filter], min_iou)
|
|
199
209
|
ordered_class_pred_scores = -np.sort(-class_pred_scores)
|
|
200
210
|
result.append(
|
|
201
211
|
{
|
|
202
|
-
'min_iou':
|
|
212
|
+
'min_iou': min_iou,
|
|
203
213
|
'class': class_idx,
|
|
204
214
|
'tp': tp.tolist(),
|
|
205
215
|
'fp': fp.tolist(),
|
|
@@ -210,11 +220,20 @@ def eval_detections(
|
|
|
210
220
|
return result
|
|
211
221
|
|
|
212
222
|
|
|
213
|
-
@
|
|
214
|
-
class mean_ap(func.Aggregator):
|
|
223
|
+
@pxt.uda(update_types=[pxt.JsonType()], value_type=pxt.JsonType(), allows_std_agg=True, allows_window=False)
|
|
224
|
+
class mean_ap(pxt.Aggregator):
|
|
215
225
|
"""
|
|
216
226
|
Calculates the mean average precision (mAP) over
|
|
217
227
|
[`eval_detections()`][pixeltable.functions.vision.eval_detections] results.
|
|
228
|
+
|
|
229
|
+
__Parameters:__
|
|
230
|
+
|
|
231
|
+
- `eval_dicts` (list[dict]): List of dictionaries as returned by
|
|
232
|
+
[`eval_detections()`][pixeltable.functions.vision.eval_detections].
|
|
233
|
+
|
|
234
|
+
__Returns:__
|
|
235
|
+
|
|
236
|
+
- A `dict[int, float]` mapping each label class to an average precision (AP) value for that class.
|
|
218
237
|
"""
|
|
219
238
|
def __init__(self):
|
|
220
239
|
self.class_tpfp: dict[int, list[dict]] = defaultdict(list)
|
|
@@ -249,7 +268,7 @@ class mean_ap(func.Aggregator):
|
|
|
249
268
|
return result
|
|
250
269
|
|
|
251
270
|
|
|
252
|
-
def _create_label_colors(labels: list[Any]) -> dict[Any, str]:
|
|
271
|
+
def __create_label_colors(labels: list[Any]) -> dict[Any, str]:
|
|
253
272
|
"""
|
|
254
273
|
Create random colors for labels such that a particular label always gets the same color.
|
|
255
274
|
|
|
@@ -268,7 +287,7 @@ def _create_label_colors(labels: list[Any]) -> dict[Any, str]:
|
|
|
268
287
|
return result
|
|
269
288
|
|
|
270
289
|
|
|
271
|
-
@
|
|
290
|
+
@pxt.udf
|
|
272
291
|
def draw_bounding_boxes(
|
|
273
292
|
img: PIL.Image.Image,
|
|
274
293
|
boxes: list[list[int]],
|
|
@@ -327,34 +346,47 @@ def draw_bounding_boxes(
|
|
|
327
346
|
if color is not None:
|
|
328
347
|
box_colors = [color] * num_boxes
|
|
329
348
|
else:
|
|
330
|
-
label_colors =
|
|
349
|
+
label_colors = __create_label_colors(labels)
|
|
331
350
|
box_colors = [label_colors[label] for label in labels]
|
|
332
351
|
|
|
333
|
-
from PIL import ImageDraw, ImageFont
|
|
352
|
+
from PIL import ImageColor, ImageDraw, ImageFont
|
|
353
|
+
|
|
334
354
|
# set default font if not provided
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
txt_font = ImageFont.truetype(font=font, size=font_size or 10)
|
|
355
|
+
txt_font: Union[ImageFont.ImageFont, ImageFont.FreeTypeFont] = (
|
|
356
|
+
ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size or 10)
|
|
357
|
+
)
|
|
339
358
|
|
|
340
359
|
img_to_draw = img.copy()
|
|
341
360
|
draw = ImageDraw.Draw(img_to_draw, 'RGBA' if fill else 'RGB')
|
|
342
361
|
|
|
343
|
-
|
|
362
|
+
# Draw bounding boxes
|
|
363
|
+
for i, bbox in enumerate(boxes):
|
|
344
364
|
# determine color for the current box and label
|
|
345
365
|
color = box_colors[i % len(box_colors)]
|
|
346
366
|
|
|
347
367
|
if fill:
|
|
348
368
|
rgb_color = ImageColor.getrgb(color)
|
|
349
369
|
fill_color = rgb_color + (100,) # semi-transparent
|
|
350
|
-
draw.rectangle(bbox, outline=color, width=width, fill=fill_color)
|
|
370
|
+
draw.rectangle(bbox, outline=color, width=width, fill=fill_color) # type: ignore[arg-type]
|
|
351
371
|
else:
|
|
352
|
-
draw.rectangle(bbox, outline=color, width=width)
|
|
372
|
+
draw.rectangle(bbox, outline=color, width=width) # type: ignore[arg-type]
|
|
353
373
|
|
|
374
|
+
# Now draw labels separately, so they are not obscured by the boxes
|
|
375
|
+
for i, (bbox, label) in enumerate(zip(boxes, labels)):
|
|
354
376
|
if label is not None:
|
|
355
377
|
label_str = str(label)
|
|
356
|
-
|
|
357
|
-
|
|
378
|
+
_, _, text_width, text_height = draw.textbbox((0, 0), label_str, font=txt_font)
|
|
379
|
+
if bbox[1] - text_height - 2 >= 0:
|
|
380
|
+
# draw text above the box
|
|
381
|
+
y = bbox[1] - text_height - 2
|
|
382
|
+
else:
|
|
383
|
+
y = bbox[3]
|
|
384
|
+
if bbox[0] + text_width + 2 < img.width:
|
|
385
|
+
x = bbox[0]
|
|
386
|
+
else:
|
|
387
|
+
x = img.width - text_width - 2
|
|
388
|
+
draw.rectangle((x, y, x + text_width + 1, y + text_height + 1), fill='black')
|
|
389
|
+
draw.text((x, y), label_str, fill='white', font=txt_font)
|
|
358
390
|
|
|
359
391
|
return img_to_draw
|
|
360
392
|
|
pixeltable/functions/whisper.py
CHANGED
|
@@ -9,9 +9,10 @@ first `pip install openai-whisper`.
|
|
|
9
9
|
from typing import TYPE_CHECKING, Optional
|
|
10
10
|
|
|
11
11
|
import pixeltable as pxt
|
|
12
|
+
from pixeltable.env import Env
|
|
12
13
|
|
|
13
14
|
if TYPE_CHECKING:
|
|
14
|
-
from whisper import Whisper
|
|
15
|
+
from whisper import Whisper # type: ignore[import-untyped]
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
@pxt.udf(
|
|
@@ -71,6 +72,8 @@ def transcribe(
|
|
|
71
72
|
|
|
72
73
|
>>> tbl['result'] = transcribe(tbl.audio, model='base.en')
|
|
73
74
|
"""
|
|
75
|
+
Env.get().require_package('whisper')
|
|
76
|
+
Env.get().require_package('torch')
|
|
74
77
|
import torch
|
|
75
78
|
|
|
76
79
|
if decode_options is None:
|
pixeltable/globals.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import logging
|
|
3
|
-
from typing import Any, Optional, Union
|
|
3
|
+
from typing import Any, Iterable, Optional, Union
|
|
4
4
|
from uuid import UUID
|
|
5
5
|
|
|
6
6
|
import pandas as pd
|
|
@@ -487,3 +487,7 @@ def configure_logging(
|
|
|
487
487
|
remove: comma-separated list of module names
|
|
488
488
|
"""
|
|
489
489
|
return Env.get().configure_logging(to_stdout=to_stdout, level=level, add=add, remove=remove)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def array(elements: Iterable) -> exprs.Expr:
|
|
493
|
+
return exprs.InlineArray(elements)
|
pixeltable/io/hf_datasets.py
CHANGED
|
@@ -6,7 +6,7 @@ import random
|
|
|
6
6
|
import typing
|
|
7
7
|
from typing import Union, Optional, Any
|
|
8
8
|
|
|
9
|
-
import pixeltable
|
|
9
|
+
import pixeltable as pxt
|
|
10
10
|
import pixeltable.type_system as ts
|
|
11
11
|
from pixeltable import exceptions as excs
|
|
12
12
|
|
|
@@ -81,24 +81,26 @@ def import_huggingface_dataset(
|
|
|
81
81
|
dataset: Union[datasets.Dataset, datasets.DatasetDict],
|
|
82
82
|
*,
|
|
83
83
|
column_name_for_split: Optional[str] = None,
|
|
84
|
-
|
|
85
|
-
**kwargs,
|
|
86
|
-
) ->
|
|
87
|
-
"""Create a new
|
|
88
|
-
Requires datasets library to be installed.
|
|
84
|
+
schema_overrides: Optional[dict[str, Any]] = None,
|
|
85
|
+
**kwargs: Any,
|
|
86
|
+
) -> pxt.Table:
|
|
87
|
+
"""Create a new base table from a Huggingface dataset, or dataset dict with multiple splits.
|
|
88
|
+
Requires `datasets` library to be installed.
|
|
89
89
|
|
|
90
90
|
Args:
|
|
91
|
-
|
|
92
|
-
dataset: Huggingface datasets.Dataset
|
|
91
|
+
table_path: Path to the table.
|
|
92
|
+
dataset: Huggingface [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.Dataset)
|
|
93
|
+
or [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/en/package_reference/main_classes#datasets.DatasetDict)
|
|
94
|
+
to insert into the table.
|
|
93
95
|
column_name_for_split: column name to use for split information. If None, no split information will be stored.
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
schema_overrides: If specified, then for each (name, type) pair in `schema_overrides`, the column with
|
|
97
|
+
name `name` will be given type `type`, instead of being inferred from the `Dataset` or `DatasetDict`. The keys in
|
|
98
|
+
`schema_overrides` should be the column names of the `Dataset` or `DatasetDict` (whether or not they are valid
|
|
99
|
+
Pixeltable identifiers).
|
|
98
100
|
kwargs: Additional arguments to pass to `create_table`.
|
|
99
101
|
|
|
100
102
|
Returns:
|
|
101
|
-
|
|
103
|
+
A handle to the newly created [`Table`][pixeltable.Table].
|
|
102
104
|
"""
|
|
103
105
|
import datasets
|
|
104
106
|
import pixeltable as pxt
|
|
@@ -118,8 +120,8 @@ def import_huggingface_dataset(
|
|
|
118
120
|
dataset_dict = dataset
|
|
119
121
|
|
|
120
122
|
pixeltable_schema = huggingface_schema_to_pixeltable_schema(dataset)
|
|
121
|
-
if
|
|
122
|
-
pixeltable_schema.update(
|
|
123
|
+
if schema_overrides is not None:
|
|
124
|
+
pixeltable_schema.update(schema_overrides)
|
|
123
125
|
|
|
124
126
|
if column_name_for_split is not None:
|
|
125
127
|
if column_name_for_split in pixeltable_schema:
|