pixeltable-0.2.26-py3-none-any.whl → pixeltable-0.5.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pixeltable/__init__.py +83 -19
- pixeltable/_query.py +1444 -0
- pixeltable/_version.py +1 -0
- pixeltable/catalog/__init__.py +7 -4
- pixeltable/catalog/catalog.py +2394 -119
- pixeltable/catalog/column.py +225 -104
- pixeltable/catalog/dir.py +38 -9
- pixeltable/catalog/globals.py +53 -34
- pixeltable/catalog/insertable_table.py +265 -115
- pixeltable/catalog/path.py +80 -17
- pixeltable/catalog/schema_object.py +28 -43
- pixeltable/catalog/table.py +1270 -677
- pixeltable/catalog/table_metadata.py +103 -0
- pixeltable/catalog/table_version.py +1270 -751
- pixeltable/catalog/table_version_handle.py +109 -0
- pixeltable/catalog/table_version_path.py +137 -42
- pixeltable/catalog/tbl_ops.py +53 -0
- pixeltable/catalog/update_status.py +191 -0
- pixeltable/catalog/view.py +251 -134
- pixeltable/config.py +215 -0
- pixeltable/env.py +736 -285
- pixeltable/exceptions.py +26 -2
- pixeltable/exec/__init__.py +7 -2
- pixeltable/exec/aggregation_node.py +39 -21
- pixeltable/exec/cache_prefetch_node.py +87 -109
- pixeltable/exec/cell_materialization_node.py +268 -0
- pixeltable/exec/cell_reconstruction_node.py +168 -0
- pixeltable/exec/component_iteration_node.py +25 -28
- pixeltable/exec/data_row_batch.py +11 -46
- pixeltable/exec/exec_context.py +26 -11
- pixeltable/exec/exec_node.py +35 -27
- pixeltable/exec/expr_eval/__init__.py +3 -0
- pixeltable/exec/expr_eval/evaluators.py +365 -0
- pixeltable/exec/expr_eval/expr_eval_node.py +413 -0
- pixeltable/exec/expr_eval/globals.py +200 -0
- pixeltable/exec/expr_eval/row_buffer.py +74 -0
- pixeltable/exec/expr_eval/schedulers.py +413 -0
- pixeltable/exec/globals.py +35 -0
- pixeltable/exec/in_memory_data_node.py +35 -27
- pixeltable/exec/object_store_save_node.py +293 -0
- pixeltable/exec/row_update_node.py +44 -29
- pixeltable/exec/sql_node.py +414 -115
- pixeltable/exprs/__init__.py +8 -5
- pixeltable/exprs/arithmetic_expr.py +79 -45
- pixeltable/exprs/array_slice.py +5 -5
- pixeltable/exprs/column_property_ref.py +40 -26
- pixeltable/exprs/column_ref.py +254 -61
- pixeltable/exprs/comparison.py +14 -9
- pixeltable/exprs/compound_predicate.py +9 -10
- pixeltable/exprs/data_row.py +213 -72
- pixeltable/exprs/expr.py +270 -104
- pixeltable/exprs/expr_dict.py +6 -5
- pixeltable/exprs/expr_set.py +20 -11
- pixeltable/exprs/function_call.py +383 -284
- pixeltable/exprs/globals.py +18 -5
- pixeltable/exprs/in_predicate.py +7 -7
- pixeltable/exprs/inline_expr.py +37 -37
- pixeltable/exprs/is_null.py +8 -4
- pixeltable/exprs/json_mapper.py +120 -54
- pixeltable/exprs/json_path.py +90 -60
- pixeltable/exprs/literal.py +61 -16
- pixeltable/exprs/method_ref.py +7 -6
- pixeltable/exprs/object_ref.py +19 -8
- pixeltable/exprs/row_builder.py +238 -75
- pixeltable/exprs/rowid_ref.py +53 -15
- pixeltable/exprs/similarity_expr.py +65 -50
- pixeltable/exprs/sql_element_cache.py +5 -5
- pixeltable/exprs/string_op.py +107 -0
- pixeltable/exprs/type_cast.py +25 -13
- pixeltable/exprs/variable.py +2 -2
- pixeltable/func/__init__.py +9 -5
- pixeltable/func/aggregate_function.py +197 -92
- pixeltable/func/callable_function.py +119 -35
- pixeltable/func/expr_template_function.py +101 -48
- pixeltable/func/function.py +375 -62
- pixeltable/func/function_registry.py +20 -19
- pixeltable/func/globals.py +6 -5
- pixeltable/func/mcp.py +74 -0
- pixeltable/func/query_template_function.py +151 -35
- pixeltable/func/signature.py +178 -49
- pixeltable/func/tools.py +164 -0
- pixeltable/func/udf.py +176 -53
- pixeltable/functions/__init__.py +44 -4
- pixeltable/functions/anthropic.py +226 -47
- pixeltable/functions/audio.py +148 -11
- pixeltable/functions/bedrock.py +137 -0
- pixeltable/functions/date.py +188 -0
- pixeltable/functions/deepseek.py +113 -0
- pixeltable/functions/document.py +81 -0
- pixeltable/functions/fal.py +76 -0
- pixeltable/functions/fireworks.py +72 -20
- pixeltable/functions/gemini.py +249 -0
- pixeltable/functions/globals.py +208 -53
- pixeltable/functions/groq.py +108 -0
- pixeltable/functions/huggingface.py +1088 -95
- pixeltable/functions/image.py +155 -84
- pixeltable/functions/json.py +8 -11
- pixeltable/functions/llama_cpp.py +31 -19
- pixeltable/functions/math.py +169 -0
- pixeltable/functions/mistralai.py +50 -75
- pixeltable/functions/net.py +70 -0
- pixeltable/functions/ollama.py +29 -36
- pixeltable/functions/openai.py +548 -160
- pixeltable/functions/openrouter.py +143 -0
- pixeltable/functions/replicate.py +15 -14
- pixeltable/functions/reve.py +250 -0
- pixeltable/functions/string.py +310 -85
- pixeltable/functions/timestamp.py +37 -19
- pixeltable/functions/together.py +77 -120
- pixeltable/functions/twelvelabs.py +188 -0
- pixeltable/functions/util.py +7 -2
- pixeltable/functions/uuid.py +30 -0
- pixeltable/functions/video.py +1528 -117
- pixeltable/functions/vision.py +26 -26
- pixeltable/functions/voyageai.py +289 -0
- pixeltable/functions/whisper.py +19 -10
- pixeltable/functions/whisperx.py +179 -0
- pixeltable/functions/yolox.py +112 -0
- pixeltable/globals.py +716 -236
- pixeltable/index/__init__.py +3 -1
- pixeltable/index/base.py +17 -21
- pixeltable/index/btree.py +32 -22
- pixeltable/index/embedding_index.py +155 -92
- pixeltable/io/__init__.py +12 -7
- pixeltable/io/datarows.py +140 -0
- pixeltable/io/external_store.py +83 -125
- pixeltable/io/fiftyone.py +24 -33
- pixeltable/io/globals.py +47 -182
- pixeltable/io/hf_datasets.py +96 -127
- pixeltable/io/label_studio.py +171 -156
- pixeltable/io/lancedb.py +3 -0
- pixeltable/io/pandas.py +136 -115
- pixeltable/io/parquet.py +40 -153
- pixeltable/io/table_data_conduit.py +702 -0
- pixeltable/io/utils.py +100 -0
- pixeltable/iterators/__init__.py +8 -4
- pixeltable/iterators/audio.py +207 -0
- pixeltable/iterators/base.py +9 -3
- pixeltable/iterators/document.py +144 -87
- pixeltable/iterators/image.py +17 -38
- pixeltable/iterators/string.py +15 -12
- pixeltable/iterators/video.py +523 -127
- pixeltable/metadata/__init__.py +33 -8
- pixeltable/metadata/converters/convert_10.py +2 -3
- pixeltable/metadata/converters/convert_13.py +2 -2
- pixeltable/metadata/converters/convert_15.py +15 -11
- pixeltable/metadata/converters/convert_16.py +4 -5
- pixeltable/metadata/converters/convert_17.py +4 -5
- pixeltable/metadata/converters/convert_18.py +4 -6
- pixeltable/metadata/converters/convert_19.py +6 -9
- pixeltable/metadata/converters/convert_20.py +3 -6
- pixeltable/metadata/converters/convert_21.py +6 -8
- pixeltable/metadata/converters/convert_22.py +3 -2
- pixeltable/metadata/converters/convert_23.py +33 -0
- pixeltable/metadata/converters/convert_24.py +55 -0
- pixeltable/metadata/converters/convert_25.py +19 -0
- pixeltable/metadata/converters/convert_26.py +23 -0
- pixeltable/metadata/converters/convert_27.py +29 -0
- pixeltable/metadata/converters/convert_28.py +13 -0
- pixeltable/metadata/converters/convert_29.py +110 -0
- pixeltable/metadata/converters/convert_30.py +63 -0
- pixeltable/metadata/converters/convert_31.py +11 -0
- pixeltable/metadata/converters/convert_32.py +15 -0
- pixeltable/metadata/converters/convert_33.py +17 -0
- pixeltable/metadata/converters/convert_34.py +21 -0
- pixeltable/metadata/converters/convert_35.py +9 -0
- pixeltable/metadata/converters/convert_36.py +38 -0
- pixeltable/metadata/converters/convert_37.py +15 -0
- pixeltable/metadata/converters/convert_38.py +39 -0
- pixeltable/metadata/converters/convert_39.py +124 -0
- pixeltable/metadata/converters/convert_40.py +73 -0
- pixeltable/metadata/converters/convert_41.py +12 -0
- pixeltable/metadata/converters/convert_42.py +9 -0
- pixeltable/metadata/converters/convert_43.py +44 -0
- pixeltable/metadata/converters/util.py +44 -18
- pixeltable/metadata/notes.py +21 -0
- pixeltable/metadata/schema.py +185 -42
- pixeltable/metadata/utils.py +74 -0
- pixeltable/mypy/__init__.py +3 -0
- pixeltable/mypy/mypy_plugin.py +123 -0
- pixeltable/plan.py +616 -225
- pixeltable/share/__init__.py +3 -0
- pixeltable/share/packager.py +797 -0
- pixeltable/share/protocol/__init__.py +33 -0
- pixeltable/share/protocol/common.py +165 -0
- pixeltable/share/protocol/operation_types.py +33 -0
- pixeltable/share/protocol/replica.py +119 -0
- pixeltable/share/publish.py +349 -0
- pixeltable/store.py +398 -232
- pixeltable/type_system.py +730 -267
- pixeltable/utils/__init__.py +40 -0
- pixeltable/utils/arrow.py +201 -29
- pixeltable/utils/av.py +298 -0
- pixeltable/utils/azure_store.py +346 -0
- pixeltable/utils/coco.py +26 -27
- pixeltable/utils/code.py +4 -4
- pixeltable/utils/console_output.py +46 -0
- pixeltable/utils/coroutine.py +24 -0
- pixeltable/utils/dbms.py +92 -0
- pixeltable/utils/description_helper.py +11 -12
- pixeltable/utils/documents.py +60 -61
- pixeltable/utils/exception_handler.py +36 -0
- pixeltable/utils/filecache.py +38 -22
- pixeltable/utils/formatter.py +88 -51
- pixeltable/utils/gcs_store.py +295 -0
- pixeltable/utils/http.py +133 -0
- pixeltable/utils/http_server.py +14 -13
- pixeltable/utils/iceberg.py +13 -0
- pixeltable/utils/image.py +17 -0
- pixeltable/utils/lancedb.py +90 -0
- pixeltable/utils/local_store.py +322 -0
- pixeltable/utils/misc.py +5 -0
- pixeltable/utils/object_stores.py +573 -0
- pixeltable/utils/pydantic.py +60 -0
- pixeltable/utils/pytorch.py +20 -20
- pixeltable/utils/s3_store.py +527 -0
- pixeltable/utils/sql.py +32 -5
- pixeltable/utils/system.py +30 -0
- pixeltable/utils/transactional_directory.py +4 -3
- pixeltable-0.5.7.dist-info/METADATA +579 -0
- pixeltable-0.5.7.dist-info/RECORD +227 -0
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info}/WHEEL +1 -1
- pixeltable-0.5.7.dist-info/entry_points.txt +2 -0
- pixeltable/__version__.py +0 -3
- pixeltable/catalog/named_function.py +0 -36
- pixeltable/catalog/path_dict.py +0 -141
- pixeltable/dataframe.py +0 -894
- pixeltable/exec/expr_eval_node.py +0 -232
- pixeltable/ext/__init__.py +0 -14
- pixeltable/ext/functions/__init__.py +0 -8
- pixeltable/ext/functions/whisperx.py +0 -77
- pixeltable/ext/functions/yolox.py +0 -157
- pixeltable/tool/create_test_db_dump.py +0 -311
- pixeltable/tool/create_test_video.py +0 -81
- pixeltable/tool/doc_plugins/griffe.py +0 -50
- pixeltable/tool/doc_plugins/mkdocstrings.py +0 -6
- pixeltable/tool/doc_plugins/templates/material/udf.html.jinja +0 -135
- pixeltable/tool/embed_udf.py +0 -9
- pixeltable/tool/mypy_plugin.py +0 -55
- pixeltable/utils/media_store.py +0 -76
- pixeltable/utils/s3.py +0 -16
- pixeltable-0.2.26.dist-info/METADATA +0 -400
- pixeltable-0.2.26.dist-info/RECORD +0 -156
- pixeltable-0.2.26.dist-info/entry_points.txt +0 -3
- {pixeltable-0.2.26.dist-info → pixeltable-0.5.7.dist-info/licenses}/LICENSE +0 -0

pixeltable/functions/timestamp.py
CHANGED

@@ -1,5 +1,5 @@
 """
-Pixeltable
+Pixeltable UDFs for `TimestampType`.
 
 Usage example:
 ```python
@@ -11,7 +11,6 @@ t.select(t.timestamp_col.year, t.timestamp_col.weekday()).collect()
 """
 
 from datetime import datetime
-from typing import Optional
 
 import sqlalchemy as sql
 
@@ -19,6 +18,8 @@ import pixeltable as pxt
 from pixeltable.env import Env
 from pixeltable.utils.code import local_public_names
 
+_SQL_ZERO = sql.literal(0)
+
 
 @pxt.udf(is_property=True)
 def year(self: datetime) -> int:
@@ -132,9 +133,11 @@ def astimezone(self: datetime, tz: str) -> datetime:
     Convert the datetime to the given time zone.
 
     Args:
-        tz: The time zone to convert to. Must be a valid time zone name from the
+        tz: The time zone to convert to. Must be a valid time zone name from the
+            [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
     """
     from zoneinfo import ZoneInfo
+
     tzinfo = ZoneInfo(tz)
     return self.astimezone(tzinfo)
 
@@ -190,7 +193,9 @@ def isoformat(self: datetime, sep: str = 'T', timespec: str = 'auto') -> str:
 
     Args:
         sep: Separator between date and time.
-        timespec: The number of additional terms in the output. See the
+        timespec: The number of additional terms in the output. See the
+            [`datetime.isoformat()`](https://docs.python.org/3/library/datetime.html#datetime.datetime.isoformat)
+            documentation for more details.
     """
     return self.isoformat(sep=sep, timespec=timespec)
 
@@ -203,14 +208,15 @@ def strftime(self: datetime, format: str) -> str:
     Equivalent to [`datetime.strftime()`](https://docs.python.org/3/library/datetime.html#datetime.datetime.strftime).
 
     Args:
-        format: The format string to control the output. For a complete list of formatting directives, see
+        format: The format string to control the output. For a complete list of formatting directives, see
+            [`strftime()` and `strptime()` Behavior](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-behavior).
     """
     return self.strftime(format)
 
 
 @pxt.udf(is_method=True)
 def make_timestamp(
-
+    year: int, month: int, day: int, hour: int = 0, minute: int = 0, second: int = 0, microsecond: int = 0
 ) -> datetime:
     """
     Create a timestamp.
@@ -222,17 +228,23 @@ def make_timestamp(
 
 @make_timestamp.to_sql
 def _(
-
-
-
+    year: sql.ColumnElement,
+    month: sql.ColumnElement,
+    day: sql.ColumnElement,
+    hour: sql.ColumnElement = _SQL_ZERO,
+    minute: sql.ColumnElement = _SQL_ZERO,
+    second: sql.ColumnElement = _SQL_ZERO,
+    microsecond: sql.ColumnElement = _SQL_ZERO,
 ) -> sql.ColumnElement:
     return sql.func.make_timestamptz(
-
-
-
-
-
-
+        year.cast(sql.Integer),
+        month.cast(sql.Integer),
+        day.cast(sql.Integer),
+        hour.cast(sql.Integer),
+        minute.cast(sql.Integer),
+        (second + microsecond / 1000000.0).cast(sql.Float),
+    )
+
 
 # @pxt.udf
 # def date(self: datetime) -> datetime:
@@ -258,9 +270,15 @@ def _(
 
 @pxt.udf(is_method=True)
 def replace(
-
-
-
+    self: datetime,
+    year: int | None = None,
+    month: int | None = None,
+    day: int | None = None,
+    hour: int | None = None,
+    minute: int | None = None,
+    second: int | None = None,
+    microsecond: int | None = None,
+) -> datetime:
     """
     Return a datetime with the same attributes, except for those attributes given new values by whichever keyword
     arguments are specified.
@@ -294,5 +312,5 @@ def posix_timestamp(self: datetime) -> float:
 __all__ = local_public_names(__name__)
 
 
-def __dir__():
+def __dir__() -> list[str]:
     return __all__
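
The timestamp changes above mostly tighten signatures and docs, but the new `_SQL_ZERO` default and the `make_timestamp.to_sql` body mean `make_timestamp` can now be translated into Postgres's `make_timestamptz()` rather than always being evaluated in Python. A minimal usage sketch against the new keyword defaults (the table `t` and its integer columns are illustrative, not part of the diff):

```python
import pixeltable as pxt
from pixeltable.functions.timestamp import make_timestamp

t = pxt.get_table('events')  # assumed table with int columns year/month/day

# hour/minute/second/microsecond default to 0 per the new signature; with the
# to_sql translation above, this expression is eligible for SQL pushdown
t.select(make_timestamp(year=t.year, month=t.month, day=t.day)).collect()
```
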
pixeltable/functions/together.py
CHANGED

@@ -1,13 +1,13 @@
 """
-Pixeltable
+Pixeltable UDFs
 that wrap various endpoints from the Together AI API. In order to use them, you must
 first `pip install together` and configure your Together AI credentials, as described in
-the [Working with Together AI](https://pixeltable.
+the [Working with Together AI](https://docs.pixeltable.com/notebooks/integrations/working-with-together-ai) tutorial.
 """
 
 import base64
 import io
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Any, Callable, TypeVar
 
 import numpy as np
 import PIL.Image
@@ -16,6 +16,7 @@ import tenacity
 
 import pixeltable as pxt
 import pixeltable.exceptions as excs
+import pixeltable.type_system as ts
 from pixeltable import env
 from pixeltable.func import Batch
 from pixeltable.utils.code import local_public_names
@@ -25,12 +26,13 @@ if TYPE_CHECKING:
 
 
 @env.register_client('together')
-def _(api_key: str) -> 'together.
+def _(api_key: str) -> 'together.AsyncTogether':
     import together
-    return together.Together(api_key=api_key)
 
+    return together.AsyncTogether(api_key=api_key)
 
-
+
+def _together_client() -> 'together.AsyncTogether':
     return env.Env.get().get_client('together')
 
 
@@ -39,6 +41,7 @@ T = TypeVar('T')
 
 def _retry(fn: Callable[..., T]) -> Callable[..., T]:
     import together
+
     return tenacity.retry(
         retry=tenacity.retry_if_exception_type(together.error.RateLimitError),
         wait=tenacity.wait_random_exponential(multiplier=1, max=60),
@@ -46,27 +49,17 @@ def _retry(fn: Callable[..., T]) -> Callable[..., T]:
     )(fn)
 
 
-@pxt.udf
-def completions(
-    prompt: str,
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-) -> dict:
+@pxt.udf(resource_pool='request-rate:together:chat')
+async def completions(prompt: str, *, model: str, model_kwargs: dict[str, Any] | None = None) -> dict:
     """
     Generate completions based on a given prompt using a specified model.
 
     Equivalent to the Together AI `completions` API endpoint.
-    For additional details, see:
+    For additional details, see: <https://docs.together.ai/reference/completions-1>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `together.rate_limits`, key `chat`). If no rate
+        limit is configured, uses a default of 600 RPM.
 
     __Requirements:__
 
@@ -75,61 +68,38 @@ def completions(
     Args:
         prompt: A string providing context for the model to complete.
         model: The name of the model to query.
-
-
+        model_kwargs: Additional keyword arguments for the Together `completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/completions-1>
 
     Returns:
         A dictionary containing the response and other metadata.
 
     Examples:
-        Add a computed column that applies the model `mistralai/Mixtral-8x7B-v0.1` to an existing Pixeltable column
-        of the table `tbl`:
+        Add a computed column that applies the model `mistralai/Mixtral-8x7B-v0.1` to an existing Pixeltable column
+        `tbl.prompt` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(response=completions(tbl.prompt, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-
-
-
-
-
-
-
-
-
-
-            logprobs=logprobs,
-            echo=echo,
-            n=n,
-            safety_model=safety_model,
-        )
-        .dict()
-    )
-
-
-@pxt.udf
-def chat_completions(
-    messages: list[dict[str, str]],
-    *,
-    model: str,
-    max_tokens: Optional[int] = None,
-    stop: Optional[list[str]] = None,
-    temperature: Optional[float] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[int] = None,
-    repetition_penalty: Optional[float] = None,
-    logprobs: Optional[int] = None,
-    echo: Optional[bool] = None,
-    n: Optional[int] = None,
-    safety_model: Optional[str] = None,
-    response_format: Optional[dict] = None,
-    tools: Optional[dict] = None,
-    tool_choice: Optional[dict] = None,
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().completions.create(prompt=prompt, model=model, **model_kwargs)
+    return result.dict()
+
+
+@pxt.udf(resource_pool='request-rate:together:chat')
+async def chat_completions(
+    messages: list[dict[str, str]], *, model: str, model_kwargs: dict[str, Any] | None = None
 ) -> dict:
     """
     Generate chat completions based on a given prompt using a specified model.
 
     Equivalent to the Together AI `chat/completions` API endpoint.
-    For additional details, see:
+    For additional details, see: <https://docs.together.ai/reference/chat-completions-1>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `together.rate_limits`, key `chat`). If no rate
+        limit is configured, uses a default of 600 RPM.
 
     __Requirements:__
 
@@ -138,39 +108,24 @@ def chat_completions(
     Args:
         messages: A list of messages comprising the conversation so far.
        model: The name of the model to query.
-
-
+        model_kwargs: Additional keyword arguments for the Together `chat/completions` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/chat-completions-1>
 
     Returns:
         A dictionary containing the response and other metadata.
 
     Examples:
-        Add a computed column that applies the model `mistralai/Mixtral-8x7B-v0.1` to an existing Pixeltable column
-        of the table `tbl`:
+        Add a computed column that applies the model `mistralai/Mixtral-8x7B-v0.1` to an existing Pixeltable column
+        `tbl.prompt` of the table `tbl`:
 
         >>> messages = [{'role': 'user', 'content': tbl.prompt}]
-        ... tbl
+        ... tbl.add_computed_column(response=chat_completions(messages, model='mistralai/Mixtral-8x7B-v0.1'))
     """
-
-
-
-
-
-            stop=stop,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            logprobs=logprobs,
-            echo=echo,
-            n=n,
-            safety_model=safety_model,
-            response_format=response_format,
-            tools=tools,
-            tool_choice=tool_choice,
-        )
-        .dict()
-    )
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().chat.completions.create(messages=messages, model=model, **model_kwargs)
+    return result.dict()
 
 
 _embedding_dimensions_cache = {
@@ -185,13 +140,17 @@ _embedding_dimensions_cache = {
 }
 
 
-@pxt.udf(batch_size=32)
-def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
+@pxt.udf(batch_size=32, resource_pool='request-rate:together:embeddings')
+async def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt.Float]]:
     """
     Query an embedding model for a given string of text.
 
     Equivalent to the Together AI `embeddings` API endpoint.
-    For additional details, see:
+    For additional details, see: <https://docs.together.ai/reference/embeddings-2>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `together.rate_limits`, key `embeddings`). If no rate
+        limit is configured, uses a default of 600 RPM.
 
     __Requirements:__
 
@@ -208,37 +167,32 @@ def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), pxt
         Add a computed column that applies the model `togethercomputer/m2-bert-80M-8k-retrieval`
         to an existing Pixeltable column `tbl.text` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(response=embeddings(tbl.text, model='togethercomputer/m2-bert-80M-8k-retrieval'))
     """
-    result =
+    result = await _together_client().embeddings.create(input=input, model=model)
     return [np.array(data.embedding, dtype=np.float64) for data in result.data]
 
 
 @embeddings.conditional_return_type
-def _(model: str) ->
+def _(model: str) -> ts.ArrayType:
     if model not in _embedding_dimensions_cache:
         # TODO: find some other way to retrieve a sample
-        return
+        return ts.ArrayType((None,), dtype=ts.FloatType())
     dimensions = _embedding_dimensions_cache[model]
-    return
-
-
-@pxt.udf
-def image_generations(
-    prompt: str,
-    *,
-    model: str,
-    steps: Optional[int] = None,
-    seed: Optional[int] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    negative_prompt: Optional[str] = None,
-) -> PIL.Image.Image:
+    return ts.ArrayType((dimensions,), dtype=ts.FloatType())
+
+
+@pxt.udf(resource_pool='request-rate:together:images')
+async def image_generations(prompt: str, *, model: str, model_kwargs: dict[str, Any] | None = None) -> PIL.Image.Image:
     """
     Generate images based on a given prompt using a specified model.
 
     Equivalent to the Together AI `images/generations` API endpoint.
-    For additional details, see:
+    For additional details, see: <https://docs.together.ai/reference/post_images-generations>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `together.rate_limits`, key `images`). If no rate
+        limit is configured, uses a default of 600 RPM.
 
     __Requirements:__
 
@@ -247,8 +201,8 @@ def image_generations(
     Args:
         prompt: A description of the desired images.
         model: The model to use for image generation.
-
-
+        model_kwargs: Additional keyword args for the Together `images/generations` API.
+            For details on the available parameters, see: <https://docs.together.ai/reference/post_images-generations>
 
     Returns:
         The generated image.
@@ -257,11 +211,14 @@ def image_generations(
         Add a computed column that applies the model `stabilityai/stable-diffusion-xl-base-1.0`
         to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
 
-        >>> tbl
+        >>> tbl.add_computed_column(
+        ...     response=image_generations(tbl.prompt, model='stabilityai/stable-diffusion-xl-base-1.0')
+        ... )
     """
-
-
-
+    if model_kwargs is None:
+        model_kwargs = {}
+
+    result = await _together_client().images.generate(prompt=prompt, model=model, **model_kwargs)
     if result.data[0].b64_json is not None:
         b64_bytes = base64.b64decode(result.data[0].b64_json)
         img = PIL.Image.open(io.BytesIO(b64_bytes))
@@ -282,5 +239,5 @@ def image_generations(
 __all__ = local_public_names(__name__)
 
 
-def __dir__():
+def __dir__() -> list[str]:
     return __all__
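
The Together UDFs above become async, collapse the per-endpoint keyword arguments into a single `model_kwargs` dict, and are throttled through `request-rate:together:*` resource pools (configurable via the `together.rate_limits` config section, with a 600 RPM default per the docstrings). A minimal migration sketch, assuming a table `tbl` with a string column `prompt` (the names are illustrative, not part of the diff):

```python
import pixeltable as pxt
from pixeltable.functions.together import chat_completions

tbl = pxt.get_table('docs')  # assumed table with a string column `prompt`
messages = [{'role': 'user', 'content': tbl.prompt}]

# 0.2.x: chat_completions(messages, model=..., max_tokens=300, temperature=0.7)
# 0.5.x: endpoint parameters travel in model_kwargs instead of named UDF arguments
tbl.add_computed_column(
    response=chat_completions(
        messages,
        model='mistralai/Mixtral-8x7B-v0.1',
        model_kwargs={'max_tokens': 300, 'temperature': 0.7},
    )
)
```
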
pixeltable/functions/twelvelabs.py
ADDED

@@ -0,0 +1,188 @@
+"""
+Pixeltable UDFs
+that wrap various endpoints from the TwelveLabs API. In order to use them, you must
+first `pip install twelvelabs` and configure your TwelveLabs credentials, as described in
+the [Working with TwelveLabs](https://docs.pixeltable.com/notebooks/integrations/working-with-twelvelabs) tutorial.
+"""
+
+from base64 import b64encode
+from typing import TYPE_CHECKING, Literal
+
+import numpy as np
+
+import pixeltable as pxt
+from pixeltable import env, type_system as ts
+from pixeltable.utils.code import local_public_names
+from pixeltable.utils.image import to_base64
+
+if TYPE_CHECKING:
+    from twelvelabs import AsyncTwelveLabs
+
+
+@env.register_client('twelvelabs')
+def _(api_key: str) -> 'AsyncTwelveLabs':
+    from twelvelabs import AsyncTwelveLabs
+
+    return AsyncTwelveLabs(api_key=api_key)
+
+
+def _twelvelabs_client() -> 'AsyncTwelveLabs':
+    return env.Env.get().get_client('twelvelabs')
+
+
+@pxt.udf(resource_pool='request-rate:twelvelabs')
+async def embed(text: str, image: pxt.Image | None = None, *, model_name: str) -> pxt.Array[np.float32] | None:
+    """
+    Creates an embedding vector for the given text, audio, image, or video input.
+
+    Each UDF signature corresponds to one of the four supported input types. If text is specified, it is possible to
+    specify an image as well, corresponding to the `text_image` embedding type in the TwelveLabs API. This is
+    (currently) the only way to include more than one input type at a time.
+
+    Equivalent to the TwelveLabs Embed API:
+    <https://docs.twelvelabs.io/v1.3/docs/guides/create-embeddings>
+
+    Request throttling:
+        Applies the rate limit set in the config (section `twelvelabs`, key `rate_limit`). If no rate
+        limit is configured, uses a default of 600 RPM.
+
+    __Requirements:__
+
+    - `pip install twelvelabs`
+
+    Args:
+        model_name: The name of the model to use. Check
+            [the TwelveLabs documentation](https://docs.twelvelabs.io/v1.3/sdk-reference/python/create-text-image-and-audio-embeddings)
+            for available models.
+        text: The text to embed.
+        image: If specified, the embedding will be created from both the text and the image.
+
+    Returns:
+        The embedding.
+
+    Examples:
+        Add a computed column `embed` for an embedding of a string column `input`:
+
+        >>> tbl.add_computed_column(
+        ...     embed=embed(model_name='marengo3.0', text=tbl.input)
+        ... )
+    """
+    env.Env.get().require_package('twelvelabs')
+    import twelvelabs
+
+    cl = _twelvelabs_client()
+    res: twelvelabs.EmbeddingSuccessResponse
+    if image is None:
+        # Text-only
+        res = await cl.embed.v_2.create(
+            input_type='text', model_name=model_name, text=twelvelabs.TextInputRequest(input_text=text)
+        )
+    else:
+        b64str = to_base64(image, format=('png' if image.has_transparency_data else 'jpeg'))
+        res = await cl.embed.v_2.create(
+            input_type='text_image',
+            model_name=model_name,
+            text_image=twelvelabs.TextImageInputRequest(
+                media_source=twelvelabs.MediaSource(base_64_string=b64str), input_text=text
+            ),
+        )
+    if not res.data:
+        raise pxt.Error(f"Didn't receive embedding for text: {text}\n{res}")
+    vector = res.data[0].embedding
+    return np.array(vector, dtype='float32')
+
+
+@embed.overload
+async def _(image: pxt.Image, *, model_name: str) -> pxt.Array[np.float32] | None:
+    env.Env.get().require_package('twelvelabs')
+    import twelvelabs
+
+    cl = _twelvelabs_client()
+    b64_str = to_base64(image, format=('png' if image.has_transparency_data else 'jpeg'))
+    res = await cl.embed.v_2.create(
+        input_type='image',
+        model_name=model_name,
+        image=twelvelabs.ImageInputRequest(media_source=twelvelabs.MediaSource(base_64_string=b64_str)),
+    )
+    if not res.data:
+        raise pxt.Error(f"Didn't receive embedding for image: {image}\n{res}")
+    vector = res.data[0].embedding
+    return np.array(vector, dtype='float32')
+
+
+@embed.overload
+async def _(
+    audio: pxt.Audio,
+    *,
+    model_name: str,
+    start_sec: float | None = None,
+    end_sec: float | None = None,
+    embedding_option: list[Literal['audio', 'transcription']] | None = None,
+) -> pxt.Array[np.float32] | None:
+    env.Env.get().require_package('twelvelabs')
+    import twelvelabs
+
+    cl = _twelvelabs_client()
+    with open(audio, 'rb') as fp:
+        b64_str = b64encode(fp.read()).decode('utf-8')
+    res = await cl.embed.v_2.create(
+        input_type='audio',
+        model_name=model_name,
+        audio=twelvelabs.AudioInputRequest(
+            media_source=twelvelabs.MediaSource(base_64_string=b64_str),
+            start_sec=start_sec,
+            end_sec=end_sec,
+            embedding_option=embedding_option,
+        ),
+    )
+    if not res.data:
+        raise pxt.Error(f"Didn't receive embedding for audio: {audio}\n{res}")
+    vector = res.data[0].embedding
+    return np.array(vector, dtype='float32')
+
+
+@embed.overload
+async def _(
+    video: pxt.Video,
+    *,
+    model_name: str,
+    start_sec: float | None = None,
+    end_sec: float | None = None,
+    embedding_option: list[Literal['visual', 'audio', 'transcription']] | None = None,
+) -> pxt.Array[np.float32] | None:
+    env.Env.get().require_package('twelvelabs')
+    import twelvelabs
+
+    cl = _twelvelabs_client()
+    with open(video, 'rb') as fp:
+        b64_str = b64encode(fp.read()).decode('utf-8')
+    res = await cl.embed.v_2.create(
+        input_type='video',
+        model_name=model_name,
+        video=twelvelabs.VideoInputRequest(
+            media_source=twelvelabs.MediaSource(base_64_string=b64_str),
+            start_sec=start_sec,
+            end_sec=end_sec,
+            embedding_option=embedding_option,
+        ),
+    )
+    if not res.data:
+        raise pxt.Error(f"Didn't receive embedding for video: {video}\n{res}")
+    vector = res.data[0].embedding
+    return np.array(vector, dtype='float32')
+
+
+@embed.conditional_return_type
+def _(model_name: str) -> ts.ArrayType:
+    if model_name == 'Marengo-retrieval-2.7':
+        return ts.ArrayType(shape=(1024,), dtype=np.dtype('float32'))
+    if model_name == 'marengo3.0':
+        return ts.ArrayType(shape=(512,), dtype=np.dtype('float32'))
+    return ts.ArrayType(dtype=np.dtype('float32'))
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__() -> list[str]:
+    return __all__
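
The new `embed` UDF above is a single entry point with overloads for text (optionally text + image), image, audio, and video inputs, plus a `conditional_return_type` hook that fixes the array shape for known models. A sketch of the video overload, which the docstring example does not cover (the table `videos` and column `clip` are illustrative, not part of the diff):

```python
import pixeltable as pxt
from pixeltable.functions.twelvelabs import embed

videos = pxt.get_table('videos')  # assumed table with a video column `clip`

# video overload: embed the first 10 seconds using only the visual signal
videos.add_computed_column(
    clip_embedding=embed(
        videos.clip,
        model_name='Marengo-retrieval-2.7',  # yields a (1024,) float32 array per the return-type hook
        start_sec=0.0,
        end_sec=10.0,
        embedding_option=['visual'],
    )
)
```
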
pixeltable/functions/util.py
CHANGED

@@ -1,5 +1,6 @@
 import PIL.Image
 
+from pixeltable.config import Config
 from pixeltable.env import Env
 
 
@@ -7,10 +8,14 @@ def resolve_torch_device(device: str, allow_mps: bool = True) -> str:
     Env.get().require_package('torch')
     import torch
 
+    mps_enabled = Config.get().get_bool_value('enable_mps')
+    if mps_enabled is None:
+        mps_enabled = True  # Default to True if not set in config
+
     if device == 'auto':
         if torch.cuda.is_available():
             return 'cuda'
-        if allow_mps and torch.backends.mps.is_available():
+        if mps_enabled and allow_mps and torch.backends.mps.is_available():
             return 'mps'
         return 'cpu'
     return device
@@ -21,7 +26,7 @@ def normalize_image_mode(image: PIL.Image.Image) -> PIL.Image.Image:
     Converts grayscale images to 3-channel for compatibility with models that only work with
     multichannel input.
     """
-    if image.mode
+    if image.mode in ('1', 'L'):
         return image.convert('RGB')
     if image.mode == 'LA':
         return image.convert('RGBA')
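
For reference, a short sketch of how the reworked device resolution behaves (assuming `torch` is installed; per the code above, the new `enable_mps` flag is read from the Pixeltable config and defaults to true when unset):

```python
from pixeltable.functions.util import resolve_torch_device

# 'auto' prefers CUDA, then MPS (only if both the `enable_mps` config flag and the
# allow_mps argument permit it), then CPU
device = resolve_torch_device('auto')

# explicit device strings are passed through unchanged
assert resolve_torch_device('cpu') == 'cpu'

# callers can still disable MPS per call, regardless of the config flag
cpu_or_cuda = resolve_torch_device('auto', allow_mps=False)
```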