pixeltable 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pixeltable might be problematic.
- pixeltable/__init__.py +7 -19
- pixeltable/__version__.py +2 -2
- pixeltable/catalog/__init__.py +7 -7
- pixeltable/catalog/column.py +37 -11
- pixeltable/catalog/globals.py +21 -0
- pixeltable/catalog/insertable_table.py +6 -4
- pixeltable/catalog/table.py +227 -148
- pixeltable/catalog/table_version.py +66 -28
- pixeltable/catalog/table_version_path.py +0 -8
- pixeltable/catalog/view.py +18 -19
- pixeltable/dataframe.py +16 -32
- pixeltable/env.py +6 -1
- pixeltable/exec/__init__.py +1 -2
- pixeltable/exec/aggregation_node.py +27 -17
- pixeltable/exec/cache_prefetch_node.py +1 -1
- pixeltable/exec/data_row_batch.py +9 -26
- pixeltable/exec/exec_node.py +36 -7
- pixeltable/exec/expr_eval_node.py +19 -11
- pixeltable/exec/in_memory_data_node.py +14 -11
- pixeltable/exec/sql_node.py +266 -138
- pixeltable/exprs/__init__.py +1 -0
- pixeltable/exprs/arithmetic_expr.py +3 -1
- pixeltable/exprs/array_slice.py +7 -7
- pixeltable/exprs/column_property_ref.py +37 -10
- pixeltable/exprs/column_ref.py +93 -14
- pixeltable/exprs/comparison.py +5 -5
- pixeltable/exprs/compound_predicate.py +8 -7
- pixeltable/exprs/data_row.py +56 -36
- pixeltable/exprs/expr.py +65 -63
- pixeltable/exprs/expr_dict.py +55 -0
- pixeltable/exprs/expr_set.py +26 -15
- pixeltable/exprs/function_call.py +53 -24
- pixeltable/exprs/globals.py +4 -1
- pixeltable/exprs/in_predicate.py +8 -7
- pixeltable/exprs/inline_expr.py +4 -4
- pixeltable/exprs/is_null.py +4 -4
- pixeltable/exprs/json_mapper.py +11 -12
- pixeltable/exprs/json_path.py +5 -10
- pixeltable/exprs/literal.py +5 -5
- pixeltable/exprs/method_ref.py +5 -4
- pixeltable/exprs/object_ref.py +2 -1
- pixeltable/exprs/row_builder.py +88 -36
- pixeltable/exprs/rowid_ref.py +14 -13
- pixeltable/exprs/similarity_expr.py +12 -7
- pixeltable/exprs/sql_element_cache.py +12 -6
- pixeltable/exprs/type_cast.py +8 -6
- pixeltable/exprs/variable.py +5 -4
- pixeltable/ext/functions/whisperx.py +7 -2
- pixeltable/func/aggregate_function.py +1 -1
- pixeltable/func/callable_function.py +2 -2
- pixeltable/func/function.py +11 -10
- pixeltable/func/function_registry.py +6 -7
- pixeltable/func/query_template_function.py +11 -12
- pixeltable/func/signature.py +17 -15
- pixeltable/func/udf.py +0 -4
- pixeltable/functions/__init__.py +2 -2
- pixeltable/functions/audio.py +4 -6
- pixeltable/functions/globals.py +84 -42
- pixeltable/functions/huggingface.py +31 -34
- pixeltable/functions/image.py +59 -45
- pixeltable/functions/json.py +0 -1
- pixeltable/functions/llama_cpp.py +106 -0
- pixeltable/functions/mistralai.py +2 -2
- pixeltable/functions/ollama.py +147 -0
- pixeltable/functions/openai.py +22 -25
- pixeltable/functions/replicate.py +72 -0
- pixeltable/functions/string.py +59 -50
- pixeltable/functions/timestamp.py +20 -20
- pixeltable/functions/together.py +2 -2
- pixeltable/functions/video.py +11 -20
- pixeltable/functions/whisper.py +2 -20
- pixeltable/globals.py +65 -74
- pixeltable/index/base.py +2 -2
- pixeltable/index/btree.py +20 -7
- pixeltable/index/embedding_index.py +12 -14
- pixeltable/io/__init__.py +1 -2
- pixeltable/io/external_store.py +11 -5
- pixeltable/io/fiftyone.py +178 -0
- pixeltable/io/globals.py +98 -2
- pixeltable/io/hf_datasets.py +1 -1
- pixeltable/io/label_studio.py +6 -6
- pixeltable/io/parquet.py +14 -13
- pixeltable/iterators/base.py +3 -2
- pixeltable/iterators/document.py +10 -8
- pixeltable/iterators/video.py +126 -60
- pixeltable/metadata/__init__.py +4 -3
- pixeltable/metadata/converters/convert_14.py +4 -2
- pixeltable/metadata/converters/convert_15.py +1 -1
- pixeltable/metadata/converters/convert_19.py +1 -0
- pixeltable/metadata/converters/convert_20.py +1 -1
- pixeltable/metadata/converters/convert_21.py +34 -0
- pixeltable/metadata/converters/util.py +54 -12
- pixeltable/metadata/notes.py +1 -0
- pixeltable/metadata/schema.py +40 -21
- pixeltable/plan.py +149 -165
- pixeltable/py.typed +0 -0
- pixeltable/store.py +57 -37
- pixeltable/tool/create_test_db_dump.py +6 -6
- pixeltable/tool/create_test_video.py +1 -1
- pixeltable/tool/doc_plugins/griffe.py +3 -34
- pixeltable/tool/embed_udf.py +1 -1
- pixeltable/tool/mypy_plugin.py +55 -0
- pixeltable/type_system.py +260 -61
- pixeltable/utils/arrow.py +10 -9
- pixeltable/utils/coco.py +4 -4
- pixeltable/utils/documents.py +16 -2
- pixeltable/utils/filecache.py +9 -9
- pixeltable/utils/formatter.py +10 -11
- pixeltable/utils/http_server.py +2 -5
- pixeltable/utils/media_store.py +6 -6
- pixeltable/utils/pytorch.py +10 -11
- pixeltable/utils/sql.py +2 -1
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/METADATA +50 -13
- pixeltable-0.2.22.dist-info/RECORD +153 -0
- pixeltable/exec/media_validation_node.py +0 -43
- pixeltable/utils/help.py +0 -11
- pixeltable-0.2.20.dist-info/RECORD +0 -147
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/LICENSE +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/WHEEL +0 -0
- {pixeltable-0.2.20.dist-info → pixeltable-0.2.22.dist-info}/entry_points.txt +0 -0
pixeltable/functions/llama_cpp.py
ADDED

@@ -0,0 +1,106 @@
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional
+
+import pixeltable as pxt
+import pixeltable.exceptions as excs
+from pixeltable.env import Env
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    import llama_cpp
+
+
+@pxt.udf
+def create_chat_completion(
+    messages: list[dict],
+    *,
+    model_path: Optional[str] = None,
+    repo_id: Optional[str] = None,
+    repo_filename: Optional[str] = None,
+    args: Optional[dict[str, Any]] = None,
+) -> dict:
+    """
+    Generate a chat completion from a list of messages.
+
+    The model can be specified either as a local path, or as a repo_id and repo_filename that reference a pretrained
+    model on the Hugging Face model hub. Exactly one of `model_path` or `repo_id` must be provided; if `model_path`
+    is provided, then an optional `repo_filename` can also be specified.
+
+    For additional details, see the
+    [llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
+
+    Args:
+        messages: A list of messages to generate a response for.
+        model_path: Path to the model (if using a local model).
+        repo_id: The Hugging Face model repo id (if using a pretrained model).
+        repo_filename: A filename or glob pattern to match the model file in the repo (optional, if using a
+            pretrained model).
+        args: Additional arguments to pass to the `create_chat_completions` call, such as `max_tokens`, `temperature`,
+            `top_p`, and `top_k`. For details, see the
+            [llama_cpp create_chat_completions documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
+    """
+    Env.get().require_package('llama_cpp', min_version=[0, 3, 1])
+
+    if args is None:
+        args = {}
+
+    if (model_path is None) == (repo_id is None):
+        raise excs.Error('Exactly one of `model_path` or `repo_id` must be provided.')
+    if (repo_id is None) and (repo_filename is not None):
+        raise excs.Error('`repo_filename` can only be provided along with `repo_id`.')
+
+    n_gpu_layers = -1 if _is_gpu_available() else 0  # 0 = CPU only, -1 = offload all layers to GPU
+
+    if model_path is not None:
+        llm = _lookup_local_model(model_path, n_gpu_layers)
+    else:
+        Env.get().require_package('huggingface_hub')
+        llm = _lookup_pretrained_model(repo_id, repo_filename, n_gpu_layers)
+    return llm.create_chat_completion(messages, **args)  # type: ignore
+
+
+def _is_gpu_available() -> bool:
+    import llama_cpp
+
+    global _IS_GPU_AVAILABLE
+    if _IS_GPU_AVAILABLE is None:
+        llama_cpp_path = Path(llama_cpp.__file__).parent
+        lib = llama_cpp.llama_cpp.load_shared_library('llama', llama_cpp_path / 'lib')
+        _IS_GPU_AVAILABLE = bool(lib.llama_supports_gpu_offload())
+
+    return _IS_GPU_AVAILABLE
+
+
+def _lookup_local_model(model_path: str, n_gpu_layers: int) -> 'llama_cpp.Llama':
+    import llama_cpp
+
+    key = (model_path, None, n_gpu_layers)
+    if key not in _model_cache:
+        llm = llama_cpp.Llama(model_path, n_gpu_layers=n_gpu_layers)
+        _model_cache[key] = llm
+    return _model_cache[key]
+
+
+def _lookup_pretrained_model(repo_id: str, filename: Optional[str], n_gpu_layers: int) -> 'llama_cpp.Llama':
+    import llama_cpp
+
+    key = (repo_id, filename, n_gpu_layers)
+    if key not in _model_cache:
+        llm = llama_cpp.Llama.from_pretrained(
+            repo_id=repo_id,
+            filename=filename,
+            n_gpu_layers=n_gpu_layers
+        )
+        _model_cache[key] = llm
+    return _model_cache[key]
+
+
+_model_cache: dict[tuple[str, str, int], Any] = {}
+_IS_GPU_AVAILABLE: Optional[bool] = None
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__():
+    return __all__
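The new `llama_cpp` module exposes `create_chat_completion` as a Pixeltable UDF, so a GGUF model can be invoked from a computed column. Below is a minimal, hedged sketch of that usage: the table name, column name, Hugging Face repo id, and filename pattern are hypothetical placeholders, and the computed-column assignment pattern (`tbl['col'] = ...`) follows the examples in the `replicate.run` docstring later in this diff.

```python
import pixeltable as pxt
from pixeltable.functions import llama_cpp

# Hypothetical table with a string column of prompts.
t = pxt.create_table('llama_demo', {'prompt': pxt.String})

# Computed column: each row's prompt is sent to a GGUF model pulled from the
# Hugging Face hub (repo id and filename pattern below are placeholders).
t['response'] = llama_cpp.create_chat_completion(
    [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': t.prompt},
    ],
    repo_id='Qwen/Qwen2.5-0.5B-Instruct-GGUF',
    repo_filename='*q8_0.gguf',
    args={'max_tokens': 256, 'temperature': 0.7},
)
```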
@@ -140,8 +140,8 @@ _embedding_dimensions_cache: dict[str, int] = {
 }


-@pxt.udf(batch_size=16
-def embeddings(input: Batch[str], *, model: str) -> Batch[
+@pxt.udf(batch_size=16)
+def embeddings(input: Batch[str], *, model: str) -> Batch[pxt.Array[(None,), float]]:
     """
     Embeddings API.

pixeltable/functions/ollama.py
ADDED

@@ -0,0 +1,147 @@
+from typing import TYPE_CHECKING, Optional
+
+import numpy as np
+
+import pixeltable as pxt
+from pixeltable import env
+from pixeltable.func import Batch
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    import ollama
+
+
+@env.register_client('ollama')
+def _(host: str) -> 'ollama.Client':
+    import ollama
+    return ollama.Client(host=host)
+
+
+def _ollama_client() -> Optional['ollama.Client']:
+    try:
+        return env.Env.get().get_client('ollama')
+    except Exception:
+        return None
+
+
+@pxt.udf
+def generate(
+    prompt: str,
+    *,
+    model: str,
+    suffix: str = '',
+    system: str = '',
+    template: str = '',
+    context: Optional[list[int]] = None,
+    raw: bool = False,
+    format: str = '',
+    options: Optional[dict] = None,
+) -> dict:
+    """
+    Generate a response for a given prompt with a provided model.
+
+    Args:
+        prompt: The prompt to generate a response for.
+        model: The model name.
+        suffix: The text after the model response.
+        format: The format of the response; must be one of `'json'` or `''` (the empty string).
+        system: System message.
+        template: Prompt template to use.
+        context: The context parameter returned from a previous call to `generate()`.
+        raw: If `True`, no formatting will be applied to the prompt.
+        options: Additional options to pass to the `chat` call, such as `max_tokens`, `temperature`, `top_p`, and `top_k`.
+            For details, see the
+            [Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+            section of the Ollama documentation.
+    """
+    env.Env.get().require_package('ollama')
+    import ollama
+
+    client = _ollama_client() or ollama
+    return client.generate(
+        model=model,
+        prompt=prompt,
+        suffix=suffix,
+        system=system,
+        template=template,
+        context=context,
+        raw=raw,
+        format=format,
+        options=options,
+    )  # type: ignore[call-overload]
+
+
+@pxt.udf
+def chat(
+    messages: list[dict],
+    *,
+    model: str,
+    tools: Optional[list[dict]] = None,
+    format: str = '',
+    options: Optional[dict] = None,
+) -> dict:
+    """
+    Generate the next message in a chat with a provided model.
+
+    Args:
+        messages: The messages of the chat.
+        model: The model name.
+        tools: Tools for the model to use.
+        format: The format of the response; must be one of `'json'` or `''` (the empty string).
+        options: Additional options to pass to the `chat` call, such as `max_tokens`, `temperature`, `top_p`, and `top_k`.
+            For details, see the
+            [Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+            section of the Ollama documentation.
+    """
+    env.Env.get().require_package('ollama')
+    import ollama
+
+    client = _ollama_client() or ollama
+    return client.chat(
+        model=model,
+        messages=messages,
+        tools=tools,
+        format=format,
+        options=options,
+    )  # type: ignore[call-overload]
+
+
+@pxt.udf(batch_size=16)
+def embed(
+    input: Batch[str],
+    *,
+    model: str,
+    truncate: bool = True,
+    options: Optional[dict] = None,
+) -> Batch[pxt.Array[(None,), pxt.Float]]:
+    """
+    Generate embeddings from a model.
+
+    Args:
+        input: The input text to generate embeddings for.
+        model: The model name.
+        truncate: Truncates the end of each input to fit within context length.
+            Returns error if false and context length is exceeded.
+        options: Additional options to pass to the `embed` call.
+            For details, see the
+            [Valid Parameters and Values](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values)
+            section of the Ollama documentation.
+    """
+    env.Env.get().require_package('ollama')
+    import ollama
+
+    client = _ollama_client() or ollama
+    results = client.embed(
+        model=model,
+        input=input,
+        truncate=truncate,
+        options=options,  # type: ignore[arg-type]
+    )
+    return [np.array(data, dtype=np.float64) for data in results['embeddings']]
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__():
+    return __all__
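A similarly hedged sketch for the new `ollama` module, assuming a local Ollama server is running and the named models have already been pulled; the table, column, and model names are placeholders, not part of this release.

```python
import pixeltable as pxt
from pixeltable.functions import ollama

# Hypothetical table of prompts.
t = pxt.create_table('ollama_demo', {'prompt': pxt.String})

# Chat completion against a locally served model (placeholder model name).
t['reply'] = ollama.chat(
    messages=[{'role': 'user', 'content': t.prompt}],
    model='llama3.2',
    options={'temperature': 0.2},
)

# `embed` is batched and returns 1-D float arrays, one per input string
# (placeholder embedding model name).
t['embedding'] = ollama.embed(t.prompt, model='nomic-embed-text')
```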
pixeltable/functions/openai.py
CHANGED
@@ -16,7 +16,6 @@ import PIL.Image
 import tenacity

 import pixeltable as pxt
-import pixeltable.type_system as ts
 from pixeltable import env
 from pixeltable.func import Batch
 from pixeltable.utils.code import local_public_names
@@ -51,10 +50,10 @@ def _retry(fn: Callable) -> Callable:
 # Audio Endpoints


-@pxt.udf
+@pxt.udf
 def speech(
     input: str, *, model: str, voice: str, response_format: Optional[str] = None, speed: Optional[float] = None
-) ->
+) -> pxt.Audio:
     """
     Generates audio from the input text.

@@ -91,17 +90,9 @@ def speech(
     return output_filename


-@pxt.udf
-    param_types=[
-        ts.AudioType(),
-        ts.StringType(),
-        ts.StringType(nullable=True),
-        ts.StringType(nullable=True),
-        ts.FloatType(nullable=True),
-    ]
-)
+@pxt.udf
 def transcriptions(
-    audio:
+    audio: pxt.Audio,
     *,
     model: str,
     language: Optional[str] = None,
@@ -140,8 +131,14 @@ def transcriptions(
     return transcription.dict()


-@pxt.udf
-def translations(
+@pxt.udf
+def translations(
+    audio: pxt.Audio,
+    *,
+    model: str,
+    prompt: Optional[str] = None,
+    temperature: Optional[float] = None
+) -> dict:
     """
     Translates audio into English.

@@ -304,10 +301,10 @@ _embedding_dimensions_cache: dict[str, int] = {
 }


-@pxt.udf(batch_size=32
+@pxt.udf(batch_size=32)
 def embeddings(
     input: Batch[str], *, model: str, dimensions: Optional[int] = None, user: Optional[str] = None
-) -> Batch[
+) -> Batch[pxt.Array[(None,), float]]:
     """
     Creates an embedding vector representing the input text.

@@ -342,13 +339,13 @@ def embeddings(


 @embeddings.conditional_return_type
-def _(model: str, dimensions: Optional[int] = None) ->
+def _(model: str, dimensions: Optional[int] = None) -> pxt.ArrayType:
     if dimensions is None:
         if model not in _embedding_dimensions_cache:
             # TODO: find some other way to retrieve a sample
-            return
+            return pxt.ArrayType((None,), dtype=pxt.FloatType(), nullable=False)
         dimensions = _embedding_dimensions_cache.get(model, None)
-    return
+    return pxt.ArrayType((dimensions,), dtype=pxt.FloatType(), nullable=False)


 #####################################
@@ -408,17 +405,17 @@ def image_generations(


 @image_generations.conditional_return_type
-def _(size: Optional[str] = None) ->
+def _(size: Optional[str] = None) -> pxt.ImageType:
     if size is None:
-        return
+        return pxt.ImageType(size=(1024, 1024))
     x_pos = size.find('x')
     if x_pos == -1:
-        return
+        return pxt.ImageType()
     try:
         width, height = int(size[:x_pos]), int(size[x_pos + 1 :])
     except ValueError:
-        return
-    return
+        return pxt.ImageType()
+    return pxt.ImageType(size=(width, height))


 #####################################
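The `openai.py` changes above replace explicit `param_types`/`return_type` arguments to `@pxt.udf` with ordinary Python type hints (`pxt.Audio`, `pxt.Array[(None,), float]`, and typed conditional return types). The short sketch below is a hypothetical user-defined UDF written in that annotation style, for illustration only; the function name and logic are not part of this release.

```python
import numpy as np
import pixeltable as pxt

# Hypothetical UDF: parameter and return types are declared via annotations,
# mirroring the style the diff migrates to. pxt.Array[(None,), float] denotes
# a 1-D float array of unspecified length.
@pxt.udf
def char_codes(s: str, *, scale: float = 1.0) -> pxt.Array[(None,), float]:
    return np.array([ord(c) * scale for c in s], dtype=np.float64)
```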
pixeltable/functions/replicate.py
ADDED

@@ -0,0 +1,72 @@
+"""
+Pixeltable [UDFs](https://pixeltable.readme.io/docs/user-defined-functions-udfs)
+that wrap various endpoints from the Replicate API. In order to use them, you must
+first `pip install replicate` and configure your Replicate credentials, as described in
+the [Working with Replicate](https://pixeltable.readme.io/docs/working-with-replicate) tutorial.
+"""
+
+from typing import TYPE_CHECKING, Any
+
+import pixeltable as pxt
+from pixeltable.env import Env, register_client
+from pixeltable.utils.code import local_public_names
+
+if TYPE_CHECKING:
+    import replicate  # type: ignore[import-untyped]
+
+
+@register_client('replicate')
+def _(api_token: str) -> 'replicate.Client':
+    import replicate
+    return replicate.Client(api_token=api_token)
+
+
+def _replicate_client() -> 'replicate.Client':
+    return Env.get().get_client('replicate')
+
+
+@pxt.udf
+def run(
+    input: dict[str, Any],
+    *,
+    ref: str,
+) -> dict[str, Any]:
+    """
+    Run a model on Replicate.
+
+    For additional details, see: <https://replicate.com/docs/topics/models/run-a-model>
+
+    __Requirements:__
+
+    - `pip install replicate`
+
+    Args:
+        input: The input parameters for the model.
+        ref: The name of the model to run.
+
+    Returns:
+        The output of the model.
+
+    Examples:
+        Add a computed column that applies the model `meta/meta-llama-3-8b-instruct`
+        to an existing Pixeltable column `tbl.prompt` of the table `tbl`:
+
+        >>> input = {'system_prompt': 'You are a helpful assistant.', 'prompt': tbl.prompt}
+        ... tbl['response'] = run(input, ref='meta/meta-llama-3-8b-instruct')
+
+        Add a computed column that uses the model `black-forest-labs/flux-schnell`
+        to generate images from an existing Pixeltable column `tbl.prompt`:
+
+        >>> input = {'prompt': tbl.prompt, 'go_fast': True, 'megapixels': '1'}
+        ... tbl['response'] = run(input, ref='black-forest-labs/flux-schnell')
+        ... tbl['image'] = tbl.response.output[0].astype(pxt.Image)
+    """
+    Env.get().require_package('replicate')
+    return _replicate_client().run(ref, input, use_file_output=False)
+
+
+__all__ = local_public_names(__name__)
+
+
+def __dir__():
+    return __all__